Files
drb-c2-core/app/c2_main.py
2025-12-29 19:02:20 -05:00

275 lines
9.3 KiB
Python

import json
import os
import asyncio
from functools import partial
import traceback
from fastapi import FastAPI, HTTPException
import paho.mqtt.client as mqtt
from datetime import datetime
import firebase_admin
from firebase_admin import credentials, firestore
from pydantic import BaseModel
from typing import Any, Dict
app = FastAPI(title="Radio C2 Brain")

# Configuration (overridable through environment variables)
MQTT_BROKER = os.getenv("MQTT_BROKER", "mqtt-broker")
FIREBASE_CRED_JSON = os.getenv("FIREBASE_CRED_JSON")
FIRESTORE_DB_ID = os.getenv("FIRESTORE_DB_ID", "c2-server")
# MQTT client id of this service. An MQTT broker disconnects an existing
# session when another client connects with the same id, so running a second
# instance against the same broker needs a distinct value.
C2_ID = os.getenv("C2_ID", "central-brain-01")

# Database Init: explicit service-account JSON if provided, otherwise
# Application Default Credentials.
if FIREBASE_CRED_JSON:
    print("Initializing Firebase with provided JSON credentials...")
    cred = credentials.Certificate(json.loads(FIREBASE_CRED_JSON))
    firebase_admin.initialize_app(cred)
else:
    print("Initializing Firebase with Application Default Credentials...")
    firebase_admin.initialize_app()
print(f"Connecting to Firestore Database: {FIRESTORE_DB_ID}")
db = firestore.client(database_id=FIRESTORE_DB_ID)

# Local cache for quick lookups: node_id -> dict of node fields last written by
# this process (populated on checkin). Lives only in memory.
ACTIVE_NODES_CACHE = {}
# Event loop captured in the startup hook so MQTT callbacks, which run outside
# the loop, can schedule coroutines onto it. None until startup.
MAIN_LOOP = None
# Pydantic Models
class NodeCommand(BaseModel):
    """Request body for ``POST /nodes/{node_id}/command``.

    ``command`` names the action; ``payload`` carries its arguments. The
    endpoint merges both into a single JSON object and publishes it to the
    node over MQTT.
    """

    command: str  # action name forwarded to the node
    payload: Dict[str, Any]  # free-form arguments; required (no default)
async def async_firestore(func, *args, **kwargs):
    """Await a blocking callable by running it in the default executor.

    Used for the synchronous Firestore client calls so they do not block the
    asyncio event loop. Positional and keyword arguments are forwarded to
    ``func``; its return value (or exception) is propagated to the awaiter.
    """
    bound_call = partial(func, *args, **kwargs)
    running_loop = asyncio.get_running_loop()
    result = await running_loop.run_in_executor(None, bound_call)
    return result
def on_connect(client, userdata, flags, rc):
    """paho ``on_connect`` callback: log the result code and subscribe.

    Runs whenever paho reports a connection result, so the subscriptions are
    restored after a reconnect. ``+`` matches a single topic level, so each
    filter receives ``nodes/<node_id>/<event>`` messages.
    """
    print(f"Brain connected to MQTT Broker with result code {rc}")
    for event in ("checkin", "status", "metadata"):
        client.subscribe(f"nodes/+/{event}")
def on_message(client, userdata, msg):
    """paho ``on_message`` callback: hand the message to the asyncio loop.

    paho runs this outside the event loop, so ``handle_message`` is scheduled
    thread-safely onto MAIN_LOOP. Messages that arrive before MAIN_LOOP has
    been set are dropped.
    """
    loop = MAIN_LOOP
    if not loop:
        return
    asyncio.run_coroutine_threadsafe(handle_message(msg), loop)
async def update_last_seen(node_id):
    """Best-effort stamp of ``last_seen`` on ``nodes/<node_id>``.

    Merges ``{"last_seen": datetime.utcnow()}`` (naive UTC) into the node
    document, leaving other fields untouched. Any failure is printed and
    swallowed.
    """
    try:
        node_doc = db.collection("nodes").document(node_id)
        stamp = {"last_seen": datetime.utcnow()}
        await async_firestore(node_doc.set, stamp, merge=True)
    except Exception as err:
        print(f"Failed to update heartbeat for {node_id}: {err}")
async def handle_message(msg):
    """Persist one node MQTT message to Firestore.

    ``msg.topic`` is split on ``/``. The second segment is the node id and the
    third is the event type: ``checkin``, ``status`` or ``metadata``. Topics
    with fewer than three segments are ignored. The payload must decode
    (UTF-8) to a JSON object; anything else is logged and ignored.

    Every accepted message refreshes the node's ``last_seen``:
    - ``checkin`` and ``status`` include it in the node write they already
      perform, so the node document is written once per message.
    - Other events call ``update_last_seen`` separately.

    Exceptions are logged with a traceback and swallowed, so one bad message
    does not stop processing of the others.
    """
    topic_parts = msg.topic.split('/')
    if len(topic_parts) < 3:
        return
    node_id = topic_parts[1]
    event_type = topic_parts[2]
    try:
        payload = json.loads(msg.payload.decode())
        if not isinstance(payload, dict):
            print(f"Ignoring non-object payload from {node_id} on {msg.topic}")
            return
        timestamp = datetime.utcnow()

        if event_type == "checkin":
            # Periodic heartbeat / full state report from the node.
            print(f"Heartbeat/Checkin from {node_id}")
            data = {
                "node_id": node_id,
                "last_seen": timestamp,
                "status": payload.get("status", "online"),
                "active_system": payload.get("active_system"),
                "available_systems": payload.get("available_systems", []),
                "radio_state": "active" if payload.get("is_listening") else "idle",
                "location": payload.get("location"),
            }
            doc_ref = db.collection("nodes").document(node_id)
            await async_firestore(doc_ref.set, data, merge=True)
            ACTIVE_NODES_CACHE[node_id] = data

        elif event_type == "status":
            # Explicit status notices (LWT or clean shutdown).
            status = payload.get("status")
            print(f"Status update for {node_id}: {status} (Reason: {payload.get('reason', 'unknown')})")
            data = {"status": status, "last_seen": timestamp}
            if status == "offline":
                data["radio_state"] = "unknown"
            doc_ref = db.collection("nodes").document(node_id)
            await async_firestore(doc_ref.set, data, merge=True)
            if node_id in ACTIVE_NODES_CACHE:
                ACTIVE_NODES_CACHE[node_id].update(data)

        else:
            # This branch does not write the node document itself, so
            # refresh last_seen with a separate write.
            await update_last_seen(node_id)
            if event_type == "metadata":
                # Call start/end metadata events are appended as new documents.
                print(f"Metadata received from {node_id}: {payload.get('event')}")
                extra = payload.get("metadata")
                if extra is None:
                    extra = {}
                elif not isinstance(extra, dict):
                    print(f"Ignoring non-object 'metadata' field from {node_id}")
                    extra = {}
                # The envelope fields go last so keys inside the free-form
                # metadata cannot overwrite them. These are the topic-derived
                # node_id, the server-side received_at, and the payload's
                # top-level event/timestamp.
                doc_data = {
                    **extra,
                    "node_id": node_id,
                    "received_at": timestamp,
                    "event_type": payload.get("event"),
                    "node_timestamp": payload.get("timestamp"),
                }
                await async_firestore(db.collection("metadata").add, doc_data)
    except Exception as e:
        print(f"Error processing MQTT message from {node_id}: {e}")
        traceback.print_exc()
# MQTT Setup: a single client for the whole process. Its callbacks are
# attached here; the connection itself is opened in startup_event.
mqtt_client = mqtt.Client(client_id=C2_ID)
mqtt_client.on_message = on_message
mqtt_client.on_connect = on_connect
async def initialize_node_states():
    """Reset persisted node states on startup and ask nodes to check in.

    1. Set ``status="unknown"`` on every document in the ``nodes`` collection,
       in one batch, and mirror that into ACTIVE_NODES_CACHE. Nodes stay
       "unknown" until they check in again.
    2. Publish a discovery request on ``nodes/discovery/request`` (QoS 1) to
       trigger immediate check-ins.

    The steps are independent. A Firestore error in step 1 is logged and the
    discovery request is still published. Errors in either step are logged,
    never raised.
    """
    print("Initializing node states...")
    try:
        nodes_ref = db.collection("nodes")

        # Fetch all nodes (blocking stream, run in the executor).
        def get_all_nodes():
            return list(nodes_ref.stream())

        docs = await async_firestore(get_all_nodes)
        batch = db.batch()
        count = 0
        for doc in docs:
            batch.update(nodes_ref.document(doc.id), {"status": "unknown"})
            count += 1
            if doc.id in ACTIVE_NODES_CACHE:
                ACTIVE_NODES_CACHE[doc.id]["status"] = "unknown"
        if count > 0:
            await async_firestore(batch.commit)
            print(f"Reset {count} nodes to 'unknown' status.")
    except Exception as e:
        print(f"Error initializing nodes: {e}")
        traceback.print_exc()

    # NOTE(review): this can run before the broker connection is up. paho is
    # expected to queue QoS>0 publishes until connected; confirm for the
    # paho-mqtt version in use.
    try:
        print("Publishing discovery request...")
        mqtt_client.publish("nodes/discovery/request", json.dumps({"ts": datetime.utcnow().isoformat()}), qos=1)
    except Exception as e:
        print(f"Error publishing discovery request: {e}")
        traceback.print_exc()
async def node_sweeper(interval=60, stale_after=90):
    """Background task that marks nodes offline when they stop reporting.

    Every ``interval`` seconds it reads all documents in the ``nodes``
    collection. A node that is not already offline is marked with
    ``status="offline"`` and ``radio_state="unknown"`` when its ``last_seen``
    is:
    - missing, or
    - not a datetime, or
    - older than ``stale_after`` seconds.

    Changes are committed in one batch and mirrored into ACTIVE_NODES_CACHE.
    Errors in a sweep are logged and the loop continues; the coroutine only
    ends when cancelled.

    Args:
        interval: Seconds to wait before each sweep (default 60).
        stale_after: Age in seconds beyond which a node is stale (default 90).
    """
    print("Starting Node Sweeper...")
    offline_fields = {"status": "offline", "radio_state": "unknown"}
    while True:
        await asyncio.sleep(interval)
        print("Sweeping nodes...")
        try:
            nodes_ref = db.collection("nodes")

            def get_all_nodes():
                return list(nodes_ref.stream())

            docs = await async_firestore(get_all_nodes)
            batch = db.batch()
            updates_count = 0
            # Naive UTC, matching the datetime.utcnow() values this service writes.
            now = datetime.utcnow()
            for doc in docs:
                data = doc.to_dict() or {}
                node_id = doc.id
                if data.get("status") == "offline":
                    continue
                last_seen = data.get("last_seen")
                if isinstance(last_seen, datetime):
                    offset = last_seen.utcoffset()
                    if offset is not None:
                        # Convert aware values to naive UTC. Simply dropping
                        # tzinfo is only correct when the offset is zero.
                        last_seen = last_seen.replace(tzinfo=None) - offset
                    is_stale = (now - last_seen).total_seconds() > stale_after
                else:
                    # No usable timestamp: treat as stale.
                    is_stale = True
                if not is_stale:
                    continue
                print(f"Node {node_id} is stale. Marking offline.")
                batch.update(nodes_ref.document(node_id), dict(offline_fields))
                updates_count += 1
                if node_id in ACTIVE_NODES_CACHE:
                    # Keep the cache consistent with what was written.
                    ACTIVE_NODES_CACHE[node_id].update(offline_fields)
            # NOTE(review): a checkin landing between the read and this commit
            # can be overwritten with "offline" until the node's next heartbeat.
            if updates_count > 0:
                await async_firestore(batch.commit)
                print(f"Sweeper marked {updates_count} nodes as offline.")
        except Exception as e:
            print(f"Error in node sweeper: {e}")
            traceback.print_exc()
# Strong references to fire-and-forget background tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task (e.g. the endless
# sweeper) can be garbage-collected mid-execution.
_BACKGROUND_TASKS = set()


@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: connect MQTT and start the background tasks.

    1. Store the running loop in MAIN_LOOP so MQTT callbacks can schedule
       coroutines onto it. This happens before MQTT starts.
    2. Start the MQTT client. ``connect_async`` + ``loop_start`` connect from
       paho's network thread with a 60 s keepalive. The broker port defaults
       to 1883 and can be overridden with the MQTT_PORT env var.
    3. Launch initialize_node_states() (one-shot) and node_sweeper() (runs
       until cancelled). References to both are kept in _BACKGROUND_TASKS
       until they finish.
    """
    global MAIN_LOOP
    MAIN_LOOP = asyncio.get_running_loop()
    mqtt_port = int(os.getenv("MQTT_PORT", "1883"))
    mqtt_client.connect_async(MQTT_BROKER, mqtt_port, 60)
    mqtt_client.loop_start()
    for coro in (initialize_node_states(), node_sweeper()):
        task = asyncio.create_task(coro)
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_TASKS.discard)
@app.get("/nodes")
async def get_nodes():
    """Return every document in the ``nodes`` collection.

    Each item is the stored document fields plus ``_id`` set to the document
    id. The blocking Firestore stream runs through async_firestore.
    """
    def fetch_all():
        result = []
        for snap in db.collection("nodes").stream():
            entry = snap.to_dict()
            entry["_id"] = snap.id
            result.append(entry)
        return result

    return await async_firestore(fetch_all)
@app.post("/nodes/{node_id}/command")
async def send_command_to_node(node_id: str, command: NodeCommand):
    """Publish a command to one node on ``nodes/<node_id>/commands`` (QoS 1).

    The published JSON object contains the fields of ``command.payload`` plus
    a ``"command"`` key holding ``command.command``.

    Raises:
        HTTPException: 404 if either condition holds:
            - no checkin from the node has been seen since this process
              started, or
            - the node's last known status is ``"offline"``.

    Returns:
        ``{"status": "command_sent", "node_id": ..., "command": ...}``. This
        only means the message was handed to the MQTT client, not that the
        node received or executed it.
    """
    node = ACTIVE_NODES_CACHE.get(node_id)
    # Cache entries are never removed. The status handler and the sweeper
    # only set "status" to "offline", so presence alone would let commands
    # through to offline nodes.
    if node is None or node.get("status") == "offline":
        raise HTTPException(status_code=404, detail="Node not found or is offline")
    topic = f"nodes/{node_id}/commands"
    # "command" goes last so a "command" key inside the free-form payload
    # cannot replace the command that was requested and is echoed back.
    message_payload = {**command.payload, "command": command.command}
    mqtt_client.publish(topic, json.dumps(message_payload), qos=1)
    return {"status": "command_sent", "node_id": node_id, "command": command.command}