Files
drb-c2-core/app/c2_main.py
2025-12-28 12:04:21 -05:00

136 lines
4.6 KiB
Python

import asyncio
import json
import os
import traceback
from datetime import datetime, timezone
from functools import partial
from typing import Any, Dict

import firebase_admin
import paho.mqtt.client as mqtt
from fastapi import FastAPI, HTTPException
from firebase_admin import credentials, firestore
from pydantic import BaseModel
# FastAPI application; the startup hook and HTTP routes in this module are
# registered on this instance.
app = FastAPI(
    title="Radio C2 Brain",
)
# Configuration (every value can be overridden through environment variables)
MQTT_BROKER = os.getenv("MQTT_BROKER", "mqtt-broker")
# Service-account key as a JSON string (it is parsed with json.loads, not read
# from a path). When unset, Application Default Credentials are used instead.
FIREBASE_CRED_JSON = os.getenv("FIREBASE_CRED_JSON")
FIRESTORE_DB_ID = os.getenv("FIRESTORE_DB_ID", "c2-server")
# MQTT client id of this C2 instance. A broker keeps only one live session per
# client id, so every replica needs its own value. The default preserves the
# previously hard-coded id.
C2_ID = os.getenv("C2_ID", "central-brain-01")
# Database Init
if not FIREBASE_CRED_JSON:
    # No explicit key supplied: let firebase_admin pick up Application
    # Default Credentials from the environment.
    print("Initializing Firebase with Application Default Credentials...")
    firebase_admin.initialize_app()
else:
    print("Initializing Firebase with provided JSON credentials...")
    # The env var carries the key material itself, so parse it into a dict.
    cred = credentials.Certificate(json.loads(FIREBASE_CRED_JSON))
    firebase_admin.initialize_app(cred)

print(f"Connecting to Firestore Database: {FIRESTORE_DB_ID}")
# Single Firestore client shared by the MQTT handlers and the HTTP routes.
db = firestore.client(database_id=FIRESTORE_DB_ID)
# Local cache for quick lookups: node_id -> data from the node's latest
# check-in. Status events also refresh the cached "status" field.
ACTIVE_NODES_CACHE: Dict[str, Dict[str, Any]] = {}
# Event loop captured at startup. The paho callback thread uses it to schedule
# message handling on the asyncio side; it stays None until startup runs.
MAIN_LOOP = None
# Pydantic Models
class NodeCommand(BaseModel):
    """Request body for ``POST /nodes/{node_id}/command``."""

    # Command name; published to the node as the "command" key of the MQTT message.
    command: str
    # Extra arguments; their keys are merged into the same MQTT JSON message.
    payload: Dict[str, Any]
# Helper for async execution of blocking firestore calls
async def async_firestore(func, *args, **kwargs):
    """Run a blocking callable in the default executor and await its result.

    The Firestore client is synchronous, so calling it directly from a
    coroutine would stall the event loop.

    Args:
        func: The blocking callable to run.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    # run_in_executor only forwards positional arguments, so bind everything
    # (including keyword arguments such as merge=True) up front.
    bound_call = partial(func, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(None, bound_call)
def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback: log the result code and subscribe.

    Subscribing here, instead of once at startup, restores the subscriptions
    every time paho re-establishes the connection.

    Args:
        client: The paho client that connected.
        userdata: Unused.
        flags: Unused.
        rc: Connection result code reported by the broker (0 means accepted).
    """
    print(f"Brain connected to MQTT Broker with result code {rc}")
    for topic_filter in ("nodes/+/checkin", "nodes/+/status"):
        client.subscribe(topic_filter)
def on_message(client, userdata, msg):
    """paho-mqtt message callback: hand the message to the asyncio loop.

    paho invokes this on its own network thread (started by ``loop_start``),
    so the coroutine is submitted thread-safely rather than awaited. Messages
    that arrive before ``MAIN_LOOP`` has been captured are dropped.
    """
    target_loop = MAIN_LOOP
    if not target_loop:
        return
    asyncio.run_coroutine_threadsafe(handle_message(msg), target_loop)
async def handle_message(msg):
    """Process one MQTT message from a node and mirror it into Firestore.

    Topics are expected to look like ``nodes/<node_id>/<event_type>``; topics
    with fewer than three segments are ignored. ``checkin`` events upsert the
    node's full record and replace its ``ACTIVE_NODES_CACHE`` entry.
    ``status`` events update only ``status`` and ``last_seen``. Any other
    event type is ignored.

    Every error (undecodable or non-object JSON, Firestore failures, ...) is
    logged and swallowed, so one bad message cannot break the consumer.

    Args:
        msg: paho ``MQTTMessage`` with ``topic`` (str) and ``payload`` (bytes).
    """
    topic_parts = msg.topic.split('/')
    if len(topic_parts) < 3:
        return
    node_id = topic_parts[1]
    event_type = topic_parts[2]
    try:
        payload = json.loads(msg.payload.decode())
        if not isinstance(payload, dict):
            # Both handlers call payload.get(); reject arrays/scalars with a
            # clear message instead of an AttributeError traceback.
            print(f"Ignoring MQTT message from {node_id}: payload is not a JSON object")
            return
        # Timezone-aware UTC. datetime.utcnow() is deprecated and returns a
        # naive value.
        timestamp = datetime.now(timezone.utc)
        if event_type == "checkin":
            await _handle_checkin(node_id, payload, timestamp)
        elif event_type == "status":
            await _handle_status(node_id, payload, timestamp)
    except Exception as e:
        # Top-level boundary of the MQTT consumer: log and keep going.
        print(f"Error processing MQTT message from {node_id}: {e}")
        traceback.print_exc()


async def _handle_checkin(node_id, payload, timestamp):
    """Upsert the node's check-in record in Firestore and cache it locally."""
    print(f"Processing checkin for {node_id}...")
    data = {
        "node_id": node_id,
        "last_seen": timestamp,
        "status": payload.get("status", "online"),
        "active_system": payload.get("active_system"),
        "available_systems": payload.get("available_systems", []),
        "config": payload,
        "radio_state": "active" if payload.get("is_listening") else "idle",
    }
    doc_ref = db.collection("nodes").document(node_id)
    print(f"Writing to Firestore: {doc_ref.path} in DB {FIRESTORE_DB_ID}")
    await async_firestore(doc_ref.set, data, merge=True)
    print(f"Successfully updated checkin for {node_id}")
    ACTIVE_NODES_CACHE[node_id] = data


async def _handle_status(node_id, payload, timestamp):
    """Update a node's status/last_seen in Firestore and in the local cache."""
    print(f"Processing status update for {node_id}...")
    # NOTE(review): a status message without a "status" key writes status=None.
    status = payload.get("status")
    doc_ref = db.collection("nodes").document(node_id)
    print(f"Writing to Firestore: {doc_ref.path} in DB {FIRESTORE_DB_ID}")
    await async_firestore(doc_ref.set, {"status": status, "last_seen": timestamp}, merge=True)
    print(f"Successfully updated status for {node_id}")
    cached = ACTIVE_NODES_CACHE.get(node_id)
    if cached is not None:
        # Mirror both fields just written to Firestore so the cache does not go stale.
        cached["status"] = status
        cached["last_seen"] = timestamp
# MQTT Setup: a single client for the whole process. It is connected in the
# startup hook and used by the command endpoint to publish.
mqtt_client = mqtt.Client(client_id=C2_ID)
mqtt_client.on_message = on_message
mqtt_client.on_connect = on_connect
@app.on_event("startup")
async def startup_event():
    """Capture the running event loop and start the MQTT client.

    ``MAIN_LOOP`` is set before the network loop starts, because
    ``on_message`` needs it to hand incoming messages to asyncio.
    ``connect_async`` plus ``loop_start`` connect on paho's background thread,
    so startup does not block on the broker.

    The broker port defaults to 1883 and can be overridden with ``MQTT_PORT``.
    """
    global MAIN_LOOP
    MAIN_LOOP = asyncio.get_running_loop()
    mqtt_port = int(os.getenv("MQTT_PORT", "1883"))
    mqtt_client.connect_async(MQTT_BROKER, mqtt_port, 60)
    mqtt_client.loop_start()


@app.on_event("shutdown")
async def shutdown_event():
    """Disconnect from the broker and stop paho's network thread.

    This stops the thread before the event loop is closed, so ``on_message``
    can no longer submit work to a dead loop, and it sends a clean DISCONNECT
    to the broker.
    """
    mqtt_client.disconnect()
    mqtt_client.loop_stop()
@app.get("/nodes")
async def get_nodes():
    """Return every document in the Firestore ``nodes`` collection.

    Each item holds the stored document fields plus the document id under
    ``_id``. The blocking stream is drained in a worker thread via
    ``async_firestore``.
    """
    def _fetch_all():
        collection = db.collection("nodes")
        # dict(mapping, _id=...) copies the fields and then sets/overrides "_id".
        return [dict(snapshot.to_dict(), _id=snapshot.id) for snapshot in collection.stream()]

    return await async_firestore(_fetch_all)
@app.post("/nodes/{node_id}/command")
async def send_command_to_node(node_id: str, command: NodeCommand):
    """Publish a command to one node on ``nodes/<node_id>/commands`` (QoS 1).

    The MQTT message is a JSON object containing the keys of
    ``command.payload`` plus the command name under ``"command"``.

    Args:
        node_id: Target node id, taken from the URL path.
        command: Command name and extra arguments from the request body.

    Returns:
        A dict with ``status``, ``node_id`` and ``command`` describing the publish.

    Raises:
        HTTPException: 404 if the node has not checked in since this process
            started (i.e. it is not in ACTIVE_NODES_CACHE).
    """
    # NOTE(review): only cache membership is checked. A node that later
    # reported an offline status is still accepted, despite the wording below.
    if node_id not in ACTIVE_NODES_CACHE:
        raise HTTPException(status_code=404, detail="Node not found or is offline")
    topic = f"nodes/{node_id}/commands"
    message_payload = {"command": command.command, **command.payload}
    # A "command" key inside payload must not replace the requested command;
    # otherwise the node would run something other than what we report back.
    message_payload["command"] = command.command
    # NOTE(review): the MQTTMessageInfo returned by publish() is not inspected,
    # so "command_sent" means "handed to paho", not "delivered".
    mqtt_client.publish(topic, json.dumps(message_payload), qos=1)
    return {"status": "command_sent", "node_id": node_id, "command": command.command}