import json
import os
import asyncio
from functools import partial
import traceback
from fastapi import FastAPI, HTTPException
import paho.mqtt.client as mqtt
from datetime import datetime
import firebase_admin
from firebase_admin import credentials, firestore
from pydantic import BaseModel
from typing import Any, Dict

app = FastAPI(title="Radio C2 Brain")

# Configuration
MQTT_BROKER = os.getenv("MQTT_BROKER", "mqtt-broker")
FIREBASE_CRED_JSON = os.getenv("FIREBASE_CRED_JSON")
FIRESTORE_DB_ID = os.getenv("FIRESTORE_DB_ID", "c2-server")
C2_ID = "central-brain-01"

# Database Init
if FIREBASE_CRED_JSON:
    print("Initializing Firebase with provided JSON credentials...")
    cred = credentials.Certificate(json.loads(FIREBASE_CRED_JSON))
    firebase_admin.initialize_app(cred)
else:
    print("Initializing Firebase with Application Default Credentials...")
    firebase_admin.initialize_app()

print(f"Connecting to Firestore Database: {FIRESTORE_DB_ID}")
db = firestore.client(database_id=FIRESTORE_DB_ID)

# Local cache for quick lookups
ACTIVE_NODES_CACHE = {}
MAIN_LOOP = None


# Pydantic Models
class NodeCommand(BaseModel):
    command: str
    payload: Dict[str, Any]


# Helper for async execution of blocking Firestore calls
async def async_firestore(func, *args, **kwargs):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))


def on_connect(client, userdata, flags, rc):
    print(f"Brain connected to MQTT Broker with result code {rc}")
    client.subscribe("nodes/+/checkin")
    client.subscribe("nodes/+/status")
    client.subscribe("nodes/+/metadata")


def on_message(client, userdata, msg):
    # Called on the paho network thread; hand the message to the asyncio loop
    if MAIN_LOOP:
        asyncio.run_coroutine_threadsafe(handle_message(msg), MAIN_LOOP)


async def update_last_seen(node_id):
    """Generic helper to update the timestamp on any contact."""
    try:
        doc_ref = db.collection("nodes").document(node_id)
        await async_firestore(doc_ref.set, {"last_seen": datetime.utcnow()}, merge=True)
    except Exception as e:
        print(f"Failed to update heartbeat for {node_id}: {e}")


async def handle_message(msg):
    topic_parts = msg.topic.split('/')
    if len(topic_parts) < 3:
        return

    node_id = topic_parts[1]
    event_type = topic_parts[2]

    try:
        payload = json.loads(msg.payload.decode())
        timestamp = datetime.utcnow()

        # 1. ALWAYS update last_seen if we hear from the node
        await update_last_seen(node_id)

        if event_type == "checkin":
            # Periodic heartbeat / check-in from the node
            print(f"Heartbeat/Checkin from {node_id}")
            data = {
                "node_id": node_id,
                "last_seen": timestamp,
                "status": payload.get("status", "online"),
                "active_system": payload.get("active_system"),
                "available_systems": payload.get("available_systems", []),
                "radio_state": "active" if payload.get("is_listening") else "idle",
                "location": payload.get("location")
            }
            doc_ref = db.collection("nodes").document(node_id)
            await async_firestore(doc_ref.set, data, merge=True)
            ACTIVE_NODES_CACHE[node_id] = data

        elif event_type == "status":
            # Handle explicit offline messages (LWT or clean shutdown)
            print(f"Status update for {node_id}: {payload.get('status')} (Reason: {payload.get('reason', 'unknown')})")
            status = payload.get("status")
            data = {"status": status, "last_seen": timestamp}
            # If the node went offline, its radio state can no longer be trusted
if status == "offline": data["radio_state"] = "unknown" doc_ref = db.collection("nodes").document(node_id) await async_firestore(doc_ref.set, data, merge=True) if node_id in ACTIVE_NODES_CACHE: ACTIVE_NODES_CACHE[node_id].update(data) elif event_type == "metadata": # Handle call start/end metadata events print(f"Metadata received from {node_id}: {payload.get('event')}") doc_data = { "node_id": node_id, "received_at": timestamp, "event_type": payload.get("event"), "node_timestamp": payload.get("timestamp"), **payload.get("metadata", {}) } await async_firestore(db.collection("metadata").add, doc_data) except Exception as e: print(f"Error processing MQTT message from {node_id}: {e}") traceback.print_exc() # MQTT Setup mqtt_client = mqtt.Client(client_id=C2_ID) mqtt_client.on_connect = on_connect mqtt_client.on_message = on_message async def initialize_node_states(): """ On startup: 1. Mark all known nodes as 'unknown' until they check in. 2. Publish a discovery request to trigger immediate check-ins. """ print("Initializing node states...") try: nodes_ref = db.collection("nodes") # Fetch all nodes (blocking call wrapped) def get_all_nodes(): return list(nodes_ref.stream()) docs = await async_firestore(get_all_nodes) batch = db.batch() count = 0 for doc in docs: doc_ref = nodes_ref.document(doc.id) batch.update(doc_ref, {"status": "unknown"}) count += 1 # Update local cache if present if doc.id in ACTIVE_NODES_CACHE: ACTIVE_NODES_CACHE[doc.id]["status"] = "unknown" if count > 0: await async_firestore(batch.commit) print(f"Reset {count} nodes to 'unknown' status.") # Publish discovery request print("Publishing discovery request...") mqtt_client.publish("nodes/discovery/request", json.dumps({"ts": datetime.utcnow().isoformat()}), qos=1) except Exception as e: print(f"Error initializing nodes: {e}") traceback.print_exc() async def node_sweeper(): """ Background task to check for stale nodes. Runs every 60 seconds. Marks nodes as 'offline' if last_seen > 90 seconds ago. """ print("Starting Node Sweeper...") while True: await asyncio.sleep(60) print("Sweeping nodes...") try: nodes_ref = db.collection("nodes") def get_all_nodes(): return list(nodes_ref.stream()) docs = await async_firestore(get_all_nodes) batch = db.batch() updates_count = 0 now = datetime.utcnow() for doc in docs: data = doc.to_dict() node_id = doc.id status = data.get("status") last_seen = data.get("last_seen") # Skip if already offline if status == "offline": continue is_stale = False if last_seen: # Handle timezone awareness (Firestore returns aware, utcnow is naive) if last_seen.tzinfo: last_seen = last_seen.replace(tzinfo=None) delta = (now - last_seen).total_seconds() if delta > 90: is_stale = True else: # No timestamp? Treat as stale if not offline is_stale = True if is_stale: print(f"Node {node_id} is stale. 
Marking offline.") doc_ref = nodes_ref.document(node_id) batch.update(doc_ref, {"status": "offline", "radio_state": "unknown"}) updates_count += 1 if node_id in ACTIVE_NODES_CACHE: ACTIVE_NODES_CACHE[node_id]["status"] = "offline" if updates_count > 0: await async_firestore(batch.commit) print(f"Sweeper marked {updates_count} nodes as offline.") except Exception as e: print(f"Error in node sweeper: {e}") traceback.print_exc() @app.on_event("startup") async def startup_event(): global MAIN_LOOP MAIN_LOOP = asyncio.get_running_loop() mqtt_client.connect_async(MQTT_BROKER, 1883, 60) mqtt_client.loop_start() # Start background tasks asyncio.create_task(initialize_node_states()) asyncio.create_task(node_sweeper()) @app.get("/nodes") async def get_nodes(): def get_docs(): return [ {**doc.to_dict(), "_id": doc.id} for doc in db.collection("nodes").stream() ] nodes = await async_firestore(get_docs) return nodes @app.post("/nodes/{node_id}/command") async def send_command_to_node(node_id: str, command: NodeCommand): if node_id not in ACTIVE_NODES_CACHE: raise HTTPException(status_code=404, detail="Node not found or is offline") topic = f"nodes/{node_id}/commands" message_payload = { "command": command.command, **command.payload } mqtt_client.publish(topic, json.dumps(message_payload), qos=1) return {"status": "command_sent", "node_id": node_id, "command": command.command}