diff --git a/.env.example b/.env.example index 18087e3..2d8d58a 100644 --- a/.env.example +++ b/.env.example @@ -7,8 +7,9 @@ NODE_LON=0.0 # MQTT — point to your C2 server MQTT_BROKER=localhost MQTT_PORT=1883 -MQTT_USER= -MQTT_PASS= +# Must match MQTT_NODE_USER/MQTT_NODE_PASS in the server's top-level .env +MQTT_USER=drb-node +MQTT_PASS=change-me-node # C2 server for audio upload (leave blank to disable upload) C2_URL=http://localhost:8888 diff --git a/drb-edge-node/app/internal/discord_radio.py b/drb-edge-node/app/internal/discord_radio.py index ab13a41..f4cd9eb 100644 --- a/drb-edge-node/app/internal/discord_radio.py +++ b/drb-edge-node/app/internal/discord_radio.py @@ -5,7 +5,9 @@ from discord.ext import commands from app.config import settings from app.internal.logger import logger -BOT_READY_TIMEOUT = 15 # seconds to wait for Discord bot to become ready +BOT_READY_TIMEOUT = 15 # seconds to wait for Discord bot to become ready +WATCHDOG_INTERVAL = 30 # seconds between voice-connection health checks +REJOIN_DELAY = 5 # seconds to wait before attempting a rejoin class RadioBot: @@ -13,12 +15,18 @@ class RadioBot: self._bot: Optional[commands.Bot] = None self._voice_client: Optional[discord.VoiceClient] = None self._task: Optional[asyncio.Task] = None + self._watchdog_task: Optional[asyncio.Task] = None self._ready_event: Optional[asyncio.Event] = None self._current_token: Optional[str] = None self._icecast_url = ( f"http://{settings.icecast_host}:{settings.icecast_port}{settings.icecast_mount}" ) + # Remembered so we can rejoin after an unexpected disconnect + self._guild_id: Optional[int] = None + self._channel_id: Optional[int] = None + self._was_streaming: bool = False + async def join(self, guild_id: int, channel_id: int, token: str, call_active: bool = False) -> bool: # (Re)start the bot if the token changed or the bot isn't running if self._current_token != token or not self._is_bot_running(): @@ -39,7 +47,9 @@ class RadioBot: if self._voice_client and 
async def _watchdog_loop(self):
    """Periodically verify the voice connection is alive; rejoin silently if not.

    Runs until cancelled. After an initial settling period the loop wakes
    every ``WATCHDOG_INTERVAL`` seconds. If a guild/channel is remembered
    and the voice client is missing or disconnected, it waits
    ``REJOIN_DELAY`` seconds and calls ``join()`` with the remembered
    guild, channel and token, resuming playback when ``_was_streaming``
    is set.

    Raises:
        asyncio.CancelledError: propagated when the task is cancelled so
            the task finishes in the cancelled state.
    """

    def rejoin_needed() -> bool:
        # Only act if we're supposed to be in a channel
        if not self._guild_id or not self._channel_id:
            return False
        voice = self._voice_client
        return voice is None or not voice.is_connected()

    await asyncio.sleep(WATCHDOG_INTERVAL)  # give initial join time to settle
    while True:
        try:
            await asyncio.sleep(WATCHDOG_INTERVAL)

            if not rejoin_needed():
                continue

            logger.warning("Watchdog: voice connection lost — attempting rejoin.")
            await asyncio.sleep(REJOIN_DELAY)

            # State can change while we were sleeping: leave()/stop() clear
            # the remembered channel, and another code path may already
            # have reconnected. Re-check so we never call join() with
            # cleared IDs or tear down a connection that is healthy again.
            if not rejoin_needed():
                continue

            # Snapshot the target before awaiting join() so every argument
            # comes from the same moment in time.
            guild_id = self._guild_id
            channel_id = self._channel_id
            token = self._current_token
            resume_stream = self._was_streaming

            # NOTE(review): join() appears to restart the bot when it is not
            # running, and _start_bot() calls stop(), which cancels this
            # very task and clears the remembered channel — confirm whether
            # a rejoin after a full bot crash can ever succeed.
            rejoined = await self.join(
                guild_id,
                channel_id,
                token,
                call_active=resume_stream,
            )
            if rejoined:
                logger.info("Watchdog: successfully rejoined voice channel.")
            else:
                logger.error("Watchdog: rejoin failed — will retry next cycle.")

        except asyncio.CancelledError:
            # Propagate cancellation instead of swallowing it.
            raise
        except Exception as e:
            # Best-effort loop: log and try again on the next cycle.
            logger.error(f"Watchdog error: {e}")
async def _maybe_request_key(self, delay: float = 5.0):
    """Ask the C2 server to re-deliver the API key if none has arrived.

    After connecting, wait ``delay`` seconds so that any retained api_key
    message has a chance to arrive. If ``credentials.get_api_key()`` is
    still falsy afterwards, publish an empty payload on the node's
    key-request topic with QoS 1.

    Args:
        delay: Seconds to wait before checking for a key. Defaults to the
            previous hard-coded value of 5 seconds.
    """
    await asyncio.sleep(delay)
    if credentials.get_api_key():
        # A key is already available; nothing to request.
        return
    logger.info("No API key on disk — requesting re-delivery from C2 server.")
    self._publish(self._t_key_request, {}, qos=1)