This commit is contained in:
Logan
2026-04-06 00:22:03 -04:00
parent 2f0597c81b
commit 636a847ee1
9 changed files with 133 additions and 21 deletions
+3
View File
@@ -17,6 +17,9 @@ class Settings(BaseSettings):
# Node health
node_offline_threshold: int = 90 # seconds without checkin before marking offline
# Internal service key — allows server-side services (discord bot) to call C2 without Firebase
service_key: Optional[str] = None
class Config:
env_file = ".env"
+16
View File
@@ -2,6 +2,7 @@ from typing import Optional
from fastapi import HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from firebase_admin import auth as firebase_auth
from app.config import settings
_bearer = HTTPBearer(auto_error=False)
@@ -18,6 +19,21 @@ async def require_firebase_token(
raise HTTPException(status_code=401, detail="Invalid or expired token")
async def require_service_or_firebase_token(
credentials: Optional[HTTPAuthorizationCredentials] = Security(_bearer),
) -> dict:
"""Accept either a Firebase ID token or the internal service key."""
if not credentials:
raise HTTPException(status_code=401, detail="Missing authorization token")
token = credentials.credentials
if settings.service_key and token == settings.service_key:
return {"service": True}
try:
return firebase_auth.verify_id_token(token)
except Exception:
raise HTTPException(status_code=401, detail="Invalid or expired token")
async def require_admin_token(
credentials: Optional[HTTPAuthorizationCredentials] = Security(_bearer),
) -> dict:
+5 -5
View File
@@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.internal.logger import logger
from app.internal.mqtt_handler import mqtt_handler
from app.internal.node_sweeper import sweeper_loop
from app.internal.auth import require_firebase_token
from app.internal.auth import require_firebase_token, require_service_or_firebase_token
from app.routers import nodes, systems, calls, upload, tokens
@@ -32,10 +32,10 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(nodes.router, dependencies=[Depends(require_firebase_token)])
app.include_router(systems.router, dependencies=[Depends(require_firebase_token)])
app.include_router(calls.router, dependencies=[Depends(require_firebase_token)])
app.include_router(tokens.router, dependencies=[Depends(require_firebase_token)])
app.include_router(nodes.router, dependencies=[Depends(require_service_or_firebase_token)])
app.include_router(systems.router, dependencies=[Depends(require_service_or_firebase_token)])
app.include_router(calls.router, dependencies=[Depends(require_service_or_firebase_token)])
app.include_router(tokens.router, dependencies=[Depends(require_service_or_firebase_token)])
app.include_router(upload.router) # auth is per-node, handled inline
+2 -1
View File
@@ -53,7 +53,8 @@ async def send_command(node_id: str, cmd: CommandPayload):
payload = cmd.model_dump(exclude_none=True)
if cmd.action == "discord_join":
token = await assign_token(node_id)
preferred = payload.pop("preferred_token_id", None)
token = await assign_token(node_id, preferred_token_id=preferred)
if not token:
raise HTTPException(503, "No Discord bot tokens available in the pool.")
payload["token"] = token
+8 -3
View File
@@ -56,18 +56,23 @@ async def delete_token(token_id: str):
# Internal helpers — used by the nodes router, not exposed via HTTP
# ---------------------------------------------------------------------------
async def assign_token(node_id: str) -> Optional[str]:
async def assign_token(node_id: str, preferred_token_id: Optional[str] = None) -> Optional[str]:
"""
Find a free token, mark it as in-use, return the token string.
If preferred_token_id is given, try that token first (only if it's free).
Returns None if no tokens are available.
"""
def _find_free():
def _find_free(preferred: Optional[str]):
from app.internal.firestore import db
if preferred:
snap = db.collection("bot_tokens").document(preferred).get()
if snap.exists and not snap.to_dict().get("in_use"):
return [snap]
docs = db.collection("bot_tokens").where("in_use", "==", False).limit(1).stream()
return [d for d in docs]
import asyncio
results = await asyncio.to_thread(_find_free)
results = await asyncio.to_thread(_find_free, preferred_token_id)
if not results:
return None