259 lines
8.7 KiB
Python
259 lines
8.7 KiB
Python
import json
|
|
import os
|
|
import asyncio
|
|
from functools import partial
|
|
import traceback
|
|
from fastapi import FastAPI, HTTPException
|
|
import paho.mqtt.client as mqtt
|
|
from datetime import datetime
|
|
import firebase_admin
|
|
from firebase_admin import credentials, firestore
|
|
from pydantic import BaseModel
|
|
from typing import Any, Dict
|
|
|
|
# FastAPI application object; the startup hook and HTTP routes in this module
# are registered on it.
app = FastAPI(title="Radio C2 Brain")

# Configuration
# Hostname of the MQTT broker; overridable via the MQTT_BROKER env variable.
MQTT_BROKER = os.getenv("MQTT_BROKER", "mqtt-broker")
# Service-account credentials as a JSON string; None when the variable is unset.
FIREBASE_CRED_JSON = os.getenv("FIREBASE_CRED_JSON")
# ID of the Firestore database to open (defaults to "c2-server").
FIRESTORE_DB_ID = os.getenv("FIRESTORE_DB_ID", "c2-server")
# Client ID this service passes to the MQTT client constructor.
C2_ID = "central-brain-01"
|
|
|
|
# Database Init
# Choose the Firebase credential source: fall back to Application Default
# Credentials unless a service-account JSON blob was supplied via environment.
if not FIREBASE_CRED_JSON:
    print("Initializing Firebase with Application Default Credentials...")
    firebase_admin.initialize_app()
else:
    print("Initializing Firebase with provided JSON credentials...")
    cred = credentials.Certificate(json.loads(FIREBASE_CRED_JSON))
    firebase_admin.initialize_app(cred)

# Module-wide Firestore client bound to the configured database ID.
print(f"Connecting to Firestore Database: {FIRESTORE_DB_ID}")
db = firestore.client(database_id=FIRESTORE_DB_ID)
|
|
|
|
# Local cache for quick lookups
# Maps node_id -> the latest record built for that node. New keys are only
# inserted when a check-in message is processed; status updates and the
# startup/sweeper tasks modify existing entries in place.
ACTIVE_NODES_CACHE = {}
# Event loop captured by the startup handler. on_message hands it to
# asyncio.run_coroutine_threadsafe; while it is still None, incoming MQTT
# messages are dropped.
MAIN_LOOP = None
|
|
|
|
# Pydantic Models
|
|
class NodeCommand(BaseModel):
    """Request body accepted by the node command endpoint."""

    # Command name; published under the "command" key of the MQTT message.
    command: str
    # Additional key/value pairs merged into the same MQTT message.
    payload: Dict[str, Any]
|
|
|
|
# Helper for async execution of blocking firestore calls
|
|
async def async_firestore(func, *args, **kwargs):
    """Run a blocking callable in the running loop's default executor.

    Args:
        func: Callable to invoke off the event loop (used here for blocking
            Firestore client calls).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    bound_call = partial(func, *args, **kwargs)
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(None, bound_call)
|
|
|
|
def on_connect(client, userdata, flags, rc):
    """MQTT on-connect callback: log the result code and subscribe.

    Args:
        client: MQTT client instance the subscriptions are made on.
        userdata: Unused.
        flags: Unused.
        rc: Connection result code; only logged.
    """
    print(f"Brain connected to MQTT Broker with result code {rc}")
    # Node check-ins and status changes, for every node ID.
    for topic_filter in ("nodes/+/checkin", "nodes/+/status"):
        client.subscribe(topic_filter)
|
|
|
|
def on_message(client, userdata, msg):
    """MQTT on-message callback: hand the message to the asyncio event loop.

    The message is dropped when MAIN_LOOP has not been set yet.

    Args:
        client: Unused.
        userdata: Unused.
        msg: Received MQTT message, passed on to ``handle_message``.
    """
    if not MAIN_LOOP:
        return
    pending = handle_message(msg)
    asyncio.run_coroutine_threadsafe(pending, MAIN_LOOP)
|
|
|
|
async def update_last_seen(node_id):
    """Merge a fresh ``last_seen`` timestamp into the node's Firestore document.

    Failures are logged and swallowed so a heartbeat problem never propagates
    to the caller.

    Args:
        node_id: Document ID inside the "nodes" collection.
    """
    try:
        node_doc = db.collection("nodes").document(node_id)
        heartbeat = {"last_seen": datetime.utcnow()}
        await async_firestore(node_doc.set, heartbeat, merge=True)
    except Exception as e:
        print(f"Failed to update heartbeat for {node_id}: {e}")
|
|
|
|
def _checkin_record(node_id, payload, timestamp):
    """Build the Firestore/cache record for a check-in payload."""
    return {
        "node_id": node_id,
        "last_seen": timestamp,
        "status": payload.get("status", "online"),
        "active_system": payload.get("active_system"),
        "available_systems": payload.get("available_systems", []),
        "radio_state": "active" if payload.get("is_listening") else "idle",
    }


def _status_record(payload, timestamp):
    """Build the Firestore/cache record for a status payload."""
    # NOTE: a payload without a "status" key yields status=None here.
    status = payload.get("status")
    record = {"status": status, "last_seen": timestamp}
    # If offline, the radio state can no longer be trusted.
    if status == "offline":
        record["radio_state"] = "unknown"
    return record


async def handle_message(msg):
    """Process one MQTT message published on ``nodes/<node_id>/<event>``.

    * ``checkin``: store the node's full record in Firestore and replace its
      cache entry.
    * ``status``: merge the new status into Firestore and into an existing
      cache entry (explicit offline messages also reset ``radio_state``).
    * anything else, or a JSON payload that is not an object: only refresh
      ``last_seen``.

    Both record types already carry ``last_seen``, so known events need a
    single Firestore write; the separate heartbeat write is reserved for the
    cases that write nothing else.

    Args:
        msg: MQTT message exposing ``topic`` (str) and ``payload`` (bytes).

    All processing errors are logged with a traceback and swallowed; topics
    with fewer than three segments are ignored.
    """
    topic_parts = msg.topic.split("/")
    if len(topic_parts) < 3:
        return

    node_id = topic_parts[1]
    event_type = topic_parts[2]

    try:
        payload = json.loads(msg.payload.decode())
        timestamp = datetime.utcnow()

        if not isinstance(payload, dict):
            # We still heard from the node, but there are no fields to read.
            await update_last_seen(node_id)
            print(f"Ignoring non-object payload from {node_id} on '{event_type}'")
            return

        if event_type == "checkin":
            # This now receives the periodic heartbeat
            print(f"Heartbeat/Checkin from {node_id}")
            data = _checkin_record(node_id, payload, timestamp)
            doc_ref = db.collection("nodes").document(node_id)
            await async_firestore(doc_ref.set, data, merge=True)
            ACTIVE_NODES_CACHE[node_id] = data

        elif event_type == "status":
            # Handle explicit Offline messages (LWT or clean shutdown)
            print(f"Status update for {node_id}: {payload.get('status')} (Reason: {payload.get('reason', 'unknown')})")
            data = _status_record(payload, timestamp)
            doc_ref = db.collection("nodes").document(node_id)
            await async_firestore(doc_ref.set, data, merge=True)

            # Only nodes that have checked in have a cache entry to update.
            if node_id in ACTIVE_NODES_CACHE:
                ACTIVE_NODES_CACHE[node_id].update(data)

        else:
            # Unknown event type: record the contact and nothing else.
            await update_last_seen(node_id)

    except Exception as e:
        print(f"Error processing MQTT message from {node_id}: {e}")
        traceback.print_exc()
|
|
|
|
|
|
# MQTT Setup
# Single module-level client created with C2_ID as its client ID. Only the
# callbacks are attached here; connecting and starting the network loop happen
# in the startup handler.
# NOTE(review): paho-mqtt 2.x expects a callback_api_version argument in the
# Client constructor — confirm which paho-mqtt version is pinned.
mqtt_client = mqtt.Client(client_id=C2_ID)
mqtt_client.on_connect = on_connect
mqtt_client.on_message = on_message
|
|
|
|
async def initialize_node_states():
    """
    On startup:
    1. Mark all known nodes as 'unknown' until they check in.
    2. Publish a discovery request to trigger immediate check-ins.

    The two steps are guarded separately, so a Firestore failure in step 1
    does not prevent the discovery request from being published. Errors in
    either step are logged with a traceback and swallowed.
    """
    print("Initializing node states...")

    # Step 1: reset every stored node to 'unknown'.
    try:
        nodes_ref = db.collection("nodes")

        # Fetch all nodes (blocking call wrapped)
        docs = await async_firestore(lambda: list(nodes_ref.stream()))

        batch = db.batch()
        count = 0

        for doc in docs:
            batch.update(nodes_ref.document(doc.id), {"status": "unknown"})
            count += 1

            # Update local cache if present
            if doc.id in ACTIVE_NODES_CACHE:
                ACTIVE_NODES_CACHE[doc.id]["status"] = "unknown"

        if count > 0:
            await async_firestore(batch.commit)
            print(f"Reset {count} nodes to 'unknown' status.")

    except Exception as e:
        print(f"Error initializing nodes: {e}")
        traceback.print_exc()

    # Step 2: ask nodes to check in, whether or not the reset succeeded.
    try:
        print("Publishing discovery request...")
        mqtt_client.publish(
            "nodes/discovery/request",
            json.dumps({"ts": datetime.utcnow().isoformat()}),
            qos=1,
        )
    except Exception as e:
        print(f"Error publishing discovery request: {e}")
        traceback.print_exc()
|
|
|
|
def _is_stale(last_seen, now, stale_after_seconds):
    """Return True when ``last_seen`` is missing, unusable, or too old.

    ``now`` must be a naive UTC datetime. An aware ``last_seen`` is converted
    to naive UTC before comparing, so non-UTC offsets are handled as well.
    """
    # No (usable) timestamp? Treat as stale.
    if not isinstance(last_seen, datetime):
        return True

    offset = last_seen.utcoffset()
    if offset is not None:
        last_seen = (last_seen - offset).replace(tzinfo=None)

    return (now - last_seen).total_seconds() > stale_after_seconds


async def node_sweeper(interval_seconds=60, stale_after_seconds=90):
    """
    Background task to check for stale nodes.

    Runs forever: sleeps ``interval_seconds`` (default 60), then marks every
    node that is not already 'offline' as 'offline' when its last_seen is
    missing or more than ``stale_after_seconds`` (default 90) old.

    The same fields are written to Firestore and to the local cache entry so
    the two stay in agreement. Errors during a sweep are logged with a
    traceback and the loop continues with the next sweep.

    Args:
        interval_seconds: Pause between sweeps, in seconds.
        stale_after_seconds: Age of last_seen beyond which a node is stale.
    """
    print("Starting Node Sweeper...")
    offline_fields = {"status": "offline", "radio_state": "unknown"}

    while True:
        await asyncio.sleep(interval_seconds)
        print("Sweeping nodes...")
        try:
            nodes_ref = db.collection("nodes")
            docs = await async_firestore(lambda: list(nodes_ref.stream()))

            batch = db.batch()
            updates_count = 0
            now = datetime.utcnow()

            for doc in docs:
                data = doc.to_dict()
                node_id = doc.id

                # Skip if already offline
                if data.get("status") == "offline":
                    continue

                if not _is_stale(data.get("last_seen"), now, stale_after_seconds):
                    continue

                print(f"Node {node_id} is stale. Marking offline.")
                batch.update(nodes_ref.document(node_id), dict(offline_fields))
                updates_count += 1

                # Mirror the exact Firestore update in the cache.
                if node_id in ACTIVE_NODES_CACHE:
                    ACTIVE_NODES_CACHE[node_id].update(offline_fields)

            if updates_count > 0:
                await async_firestore(batch.commit)
                print(f"Sweeper marked {updates_count} nodes as offline.")

        except Exception as e:
            print(f"Error in node sweeper: {e}")
            traceback.print_exc()
|
|
|
|
# Strong references to the background tasks started at startup. The event loop
# itself only keeps weak references to tasks, so an unreferenced task could be
# garbage-collected before it finishes.
_BACKGROUND_TASKS = set()


@app.on_event("startup")
async def startup_event():
    """Capture the event loop, start the MQTT client, launch background tasks.

    Sets the module-level MAIN_LOOP so the MQTT callback can schedule
    coroutines on this loop, opens the broker connection on port 1883 with a
    60 second keepalive, and starts ``initialize_node_states`` and
    ``node_sweeper`` as background tasks.
    """
    global MAIN_LOOP
    MAIN_LOOP = asyncio.get_running_loop()
    mqtt_client.connect_async(MQTT_BROKER, 1883, 60)
    mqtt_client.loop_start()

    # Start background tasks
    for coro in (initialize_node_states(), node_sweeper()):
        task = asyncio.create_task(coro)
        _BACKGROUND_TASKS.add(task)
        # Drop the reference once the task has finished.
        task.add_done_callback(_BACKGROUND_TASKS.discard)
|
|
|
|
@app.get("/nodes")
async def get_nodes():
    """Return every document in the "nodes" collection, tagged with its ID.

    Returns:
        A list of dicts, each holding the document's fields plus an "_id" key
        with the document ID (which replaces any stored "_id" field).
    """

    def load_all():
        # Blocking Firestore stream; executed off the event loop.
        records = []
        for snapshot in db.collection("nodes").stream():
            record = dict(snapshot.to_dict())
            record["_id"] = snapshot.id
            records.append(record)
        return records

    return await async_firestore(load_all)
|
|
|
|
@app.post("/nodes/{node_id}/command")
async def send_command_to_node(node_id: str, command: NodeCommand):
    """Publish a command to a node over MQTT.

    The message is the request's ``payload`` dict plus a "command" key and is
    published to ``nodes/<node_id>/commands`` with QoS 1. The "command" key is
    applied last so a same-named key inside ``payload`` cannot replace the
    actual command.

    Args:
        node_id: Target node; must be in the local cache and not marked
            offline there.
        command: Command name and extra payload fields.

    Returns:
        A dict confirming that the command was handed to the MQTT client.

    Raises:
        HTTPException: 404 when the node is not cached or its cached status
            is "offline".
    """
    cached_node = ACTIVE_NODES_CACHE.get(node_id)
    if cached_node is None or cached_node.get("status") == "offline":
        raise HTTPException(status_code=404, detail="Node not found or is offline")

    topic = f"nodes/{node_id}/commands"
    message_payload = {**command.payload, "command": command.command}

    mqtt_client.publish(topic, json.dumps(message_payload), qos=1)

    return {"status": "command_sent", "node_id": node_id, "command": command.command}