From af10851002523aebc73b627b44cdd1815a3b57e7 Mon Sep 17 00:00:00 2001 From: Logan Cusano Date: Mon, 26 May 2025 02:36:21 -0400 Subject: [PATCH] Refactored DB handlers --- app/internal/auth_wrappers.py | 33 +++-- app/internal/db_handler.py | 249 +++++++++++----------------------- app/internal/db_wrappers.py | 209 ++++++---------------------- app/routers/auth.py | 2 + app/server.py | 64 ++++----- 5 files changed, 175 insertions(+), 382 deletions(-) diff --git a/app/internal/auth_wrappers.py b/app/internal/auth_wrappers.py index db36d20..468fbf5 100644 --- a/app/internal/auth_wrappers.py +++ b/app/internal/auth_wrappers.py @@ -3,7 +3,7 @@ import os import asyncio from uuid import uuid4 from typing import Optional, List, Dict, Any -from internal.db_handler import MongoHandler +from internal.db_handler import MongoHandler # from internal.types import User, UserRoles DB_NAME = os.getenv("DB_NAME", "default_db") @@ -13,7 +13,12 @@ USER_DB_COLLECTION_NAME = "users" class UserDbController: def __init__(self): - self.db_h = MongoHandler(DB_NAME, USER_DB_COLLECTION_NAME, MONGO_URL) + self.db_h = MongoHandler(DB_NAME, USER_DB_COLLECTION_NAME, MONGO_URL) # + + async def close_db_connection(self): + """Closes the underlying MongoDB connection.""" + if self.db_h: + await self.db_h.close_client() # async def create_user(self, user_data: Dict[str, Any]) -> Optional[User]: try: @@ -21,15 +26,15 @@ class UserDbController: user_data['_id'] = str(uuid4()) inserted_id = None - async with self.db_h as db: - insert_result = await db.insert_one(user_data) + async with self.db_h as db: # + insert_result = await db.insert_one(user_data) # inserted_id = insert_result.inserted_id if inserted_id: query = {"_id": inserted_id} inserted_doc = None - async with self.db_h as db: - inserted_doc = await db.find_one(query) + async with self.db_h as db: # + inserted_doc = await db.find_one(query) # if inserted_doc: return User.from_dict(inserted_doc) return None @@ -40,20 +45,22 @@ class UserDbController: async def find_user(self, query: Dict[str, Any]) -> Optional[User]: try: found_doc = None - async with self.db_h as db: - found_doc = await db.find_one(query) + async with self.db_h as db: # + # The 'db' here is self.db_h, which is an instance of MongoHandler. + # MongoHandler.find_one will be called. + found_doc = await db.find_one(query) # if found_doc: return User.from_dict(found_doc) return None except Exception as e: - print(f"Find user failed: {e}") + print(f"Find user failed: {e}") # This error should be less frequent or indicate actual DB issues now return None async def update_user(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]: try: update_result = None - async with self.db_h as db: - update_result = await db.update_one(query, update_data) + async with self.db_h as db: # + update_result = await db.update_one(query, update_data) # return update_result.modified_count except Exception as e: print(f"Update user failed: {e}") @@ -62,8 +69,8 @@ class UserDbController: async def delete_user(self, query: Dict[str, Any]) -> Optional[int]: try: delete_result = None - async with self.db_h as db: - delete_result = await db.delete_one(query) + async with self.db_h as db: # + delete_result = await db.delete_one(query) # return delete_result.deleted_count except Exception as e: print(f"Delete user failed: {e}") diff --git a/app/internal/db_handler.py b/app/internal/db_handler.py index f89878e..eb54f9d 100644 --- a/app/internal/db_handler.py +++ b/app/internal/db_handler.py @@ -4,8 +4,9 @@ from typing import Optional, Dict, Any, List class MongoHandler: """ - A basic asynchronous handler for MongoDB operations using motor. - Designed to be used with asyncio. + An asynchronous handler for MongoDB operations using motor. + The client connection is established on first use (via async with or direct call to connect()) + and should be explicitly closed by calling close_client() at application shutdown. """ def __init__(self, db_name: str, collection_name: str, mongo_uri: str = "mongodb://localhost:27017/"): """ @@ -15,7 +16,6 @@ class MongoHandler: db_name (str): The name of the database to connect to. collection_name (str): The name of the collection to use. mongo_uri (str): The MongoDB connection string URI. - Defaults to the standard local URI. """ self.mongo_uri = mongo_uri self.db_name = db_name @@ -23,229 +23,142 @@ class MongoHandler: self._client: Optional[motor.motor_asyncio.AsyncIOMotorClient] = None self._db: Optional[motor.motor_asyncio.AsyncIOMotorDatabase] = None self._collection: Optional[motor.motor_asyncio.AsyncIOMotorCollection] = None + self._lock = asyncio.Lock() # Lock for serializing client creation async def connect(self): - """Establishes an asynchronous connection to MongoDB.""" + """ + Establishes an asynchronous connection to MongoDB if not already established. + This method is idempotent. + """ if self._client is None: - try: - self._client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri) - # The ismaster command is cheap and does not require auth. - # It is used to confirm that the client can connect to the deployment. - await self._client.admin.command('ismaster') + async with self._lock: # Ensure only one coroutine attempts to initialize the client + if self._client is None: # Double-check after acquiring lock + try: + print(f"Initializing MongoDB client for: DB '{self.db_name}', Collection '{self.collection_name}' URI: {self.mongo_uri.split('@')[-1]}") # Avoid logging credentials + self._client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri) + # The ismaster command is cheap and does not require auth. + await self._client.admin.command('ismaster') + self._db = self._client[self.db_name] + self._collection = self._db[self.collection_name] + print(f"MongoDB client initialized and connected: Database '{self.db_name}', Collection '{self.collection_name}'") + except Exception as e: + print(f"Failed to initialize MongoDB client at {self.mongo_uri.split('@')[-1]} for {self.db_name}/{self.collection_name}: {e}") + self._client = None # Ensure client is None if connection fails + self._db = None + self._collection = None + raise # Re-raise the exception after printing + + if self._collection is None and self._client is not None: + # This can happen if connect was called, client was set, but then an error occurred before collection was set + # Or if connect logic needs to re-establish db/collection objects without re-creating client (less common with motor) + if self._db is None: self._db = self._client[self.db_name] - self._collection = self._db[self.collection_name] - print(f"Connected to MongoDB: Database '{self.db_name}', Collection '{self.collection_name}'") - except Exception as e: - print(f"Failed to connect to MongoDB at {self.mongo_uri}: {e}") - self._client = None # Ensure client is None if connection fails - raise # Re-raise the exception after printing + self._collection = self._db[self.collection_name] + if self._collection is None: + raise RuntimeError(f"MongoDB collection '{self.collection_name}' could not be established even though client exists.") - async def close(self): - """Closes the MongoDB connection.""" - if self._client: - self._client.close() - self._client = None - self._db = None - self._collection = None - print("MongoDB connection closed.") + + async def close_client(self): + """Closes the MongoDB client connection. Should be called on application shutdown.""" + async with self._lock: + if self._client: + print(f"Closing MongoDB client for: Database '{self.db_name}', Collection '{self.collection_name}'") + self._client.close() + self._client = None + self._db = None + self._collection = None + print(f"MongoDB client closed for: Database '{self.db_name}', Collection '{self.collection_name}'.") async def __aenter__(self): - """Allows using the handler with async with.""" + """Allows using the handler with async with. Ensures connection is active.""" await self.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """Ensures the connection is closed when exiting async with.""" - await self.close() + """Ensures the connection is NOT closed when exiting async with.""" + # The connection is managed and closed at the application level by calling close_client() + pass async def insert_one(self, document: Dict[str, Any]) -> Any: - """ - Inserts a single document into the collection. - - Args: - document (Dict[str, Any]): The document to insert. - - Returns: - Any: The result of the insert operation (InsertOneResult). - """ if self._collection is None: - raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.") + await self.connect() # Try to connect if collection is None + if self._collection is None: # Check again after attempting to connect + raise RuntimeError("MongoDB collection not available. Call connect() or use async with.") print(f"Inserting document into '{self.collection_name}'...") result = await self._collection.insert_one(document) print(f"Inserted document with ID: {result.inserted_id}") return result async def find_one(self, query: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """ - Finds a single document matching the query. - - Args: - query (Dict[str, Any]): The query document. - - Returns: - Optional[Dict[str, Any]]: The found document, or None if not found. - """ if self._collection is None: - raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.") + await self.connect() + if self._collection is None: + raise RuntimeError("MongoDB collection not available. Call connect() or use async with.") print(f"Finding one document in '{self.collection_name}' with query: {query}") document = await self._collection.find_one(query) return document async def find(self, query: Dict[str, Any] = None) -> List[Dict[str, Any]]: - """ - Finds multiple documents matching the query. - - Args: - query (Dict[str, Any], optional): The query document. Defaults to None (find all). - - Returns: - List[Dict[str, Any]]: A list of matching documents. - """ if self._collection is None: - raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.") + await self.connect() + if self._collection is None: + raise RuntimeError("MongoDB collection not available. Call connect() or use async with.") if query is None: query = {} print(f"Finding documents in '{self.collection_name}' with query: {query}") - # Use list comprehension to iterate through the cursor asynchronously documents = [doc async for doc in self._collection.find(query)] print(f"Found {len(documents)} documents.") return documents async def update_one(self, query: Dict[str, Any], update: Dict[str, Any], upsert: bool = False) -> Any: - """ - Updates a single document matching the query. - - Args: - query (Dict[str, Any]): The query document. - update (Dict[str, Any]): The update operations to apply. - upsert (bool): If True, insert a new document if no match is found. - - Returns: - Any: The result of the update operation (UpdateResult). - """ if self._collection is None: - raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.") + await self.connect() + if self._collection is None: + raise RuntimeError("MongoDB collection not available. Call connect() or use async with.") print(f"Updating one document in '{self.collection_name}' with query: {query}, update: {update}") result = await self._collection.update_one(query, update, upsert=upsert) print(f"Matched {result.matched_count}, Modified {result.modified_count}, Upserted ID: {result.upserted_id}") return result async def delete_one(self, query: Dict[str, Any]) -> Any: - """ - Deletes a single document matching the query. - - Args: - query (Dict[str, Any]): The query document. - - Returns: - Any: The result of the delete operation (DeleteResult). - """ if self._collection is None: - raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.") + await self.connect() + if self._collection is None: + raise RuntimeError("MongoDB collection not available. Call connect() or use async with.") print(f"Deleting one document from '{self.collection_name}' with query: {query}") result = await self._collection.delete_one(query) print(f"Deleted count: {result.deleted_count}") return result async def delete_many(self, query: Dict[str, Any]) -> Any: - """ - Deletes multiple documents matching the query. - - Args: - query (Dict[str, Any]): The query document. - - Returns: - Any: The result of the delete operation (DeleteResult). - """ if self._collection is None: - raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.") + await self.connect() + if self._collection is None: + raise RuntimeError("MongoDB collection not available. Call connect() or use async with.") print(f"Deleting many documents from '{self.collection_name}' with query: {query}") result = await self._collection.delete_many(query) print(f"Deleted count: {result.deleted_count}") return result - -# --- Example Usage --- +# --- Example Usage (no change needed here, but behavior of MongoHandler is different) --- async def example_mongo_usage(): - """Demonstrates how to use the MongoHandler.""" - # Ensure you have a MongoDB server running, default is localhost:27017 db_name = "radio_app_db" - collection_name = "channels" + collection_name = "channels_example" # Use a different collection for example + handler = MongoHandler(db_name, collection_name) # MONGO_URL defaults to localhost - # Using async with ensures the connection is closed automatically - async with MongoHandler(db_name, collection_name) as mongo: - # --- Insert Example --- - print("\n--- Inserting a document ---") - channel_data = { - "_id": "channel_3", # You can specify _id or let MongoDB generate one - "name": "Emergency Services", - "frequencies": 453000, - "location": "Countywide", - "avail_on_nodes": ["client-xyz987"], - "description": "Monitor for emergency broadcasts." - } - try: - insert_result = await mongo.insert_one(channel_data) - print(f"Insert successful: {insert_result.inserted_id}") - except Exception as e: - print(f"Insert failed: {e}") + try: + async with handler: # Connects on enter (if not already connected) + print("\n--- Inserting a document ---") + # ... (rest of example usage) + channel_data = { "_id": "example_channel", "name": "Example" } + await handler.insert_one(channel_data) + found = await handler.find_one({"_id": "example_channel"}) + print(f"Found: {found}") + await handler.delete_one({"_id": "example_channel"}) + print("Example completed.") + finally: + await handler.close_client() # Explicitly close client at the end of usage - # --- Find One Example --- - print("\n--- Finding one document ---") - query = {"_id": "channel_3"} - found_channel = await mongo.find_one(query) - if found_channel: - print("Found document:", found_channel) - else: - print("Document not found.") - - # --- Find Many Example --- - print("\n--- Finding all documents ---") - all_channels = await mongo.find() # Empty query finds all - print("All documents:", all_channels) - - # --- Update Example --- - print("\n--- Updating a document ---") - update_query = {"_id": "channel_3"} - update_data = {"$set": {"location": "Statewide", "avail_on_nodes": ["client-xyz987", "client-newnode1"]}} - update_result = await mongo.update_one(update_query, update_data) - print(f"Update successful: Matched {update_result.matched_count}, Modified {update_result.modified_count}") - - print("\n--- Finding the updated document ---") - updated_channel = await mongo.find_one(update_query) - print("Updated document:", updated_channel) - - # --- Delete Example --- - print("\n--- Deleting a document ---") - delete_query = {"_id": "channel_3"} - delete_result = await mongo.delete_one(delete_query) - print(f"Delete successful: Deleted count {delete_result.deleted_count}") - - print("\n--- Verifying deletion ---") - deleted_channel = await mongo.find_one(delete_query) - if deleted_channel: - print("Document still found (deletion failed).") - else: - print("Document successfully deleted.") - - # --- Insert another for delete_many example --- - temp_doc1 = {"_id": "temp_1", "tag": "temp"} - temp_doc2 = {"_id": "temp_2", "tag": "temp"} - await mongo.insert_one(temp_doc1) - await mongo.insert_one(temp_doc2) - - # --- Delete Many Example --- - print("\n--- Deleting many documents ---") - delete_many_query = {"tag": "temp"} - delete_many_result = await mongo.delete_many(delete_many_query) - print(f"Delete many successful: Deleted count {delete_many_result.deleted_count}") - - -# To run the example usage: -# 1. Ensure you have a MongoDB server running locally on the default port (27017). -# 2. Save the code as mongodb_handler.py. -# 3. Run from your terminal: python -m asyncio mongodb_handler.py if __name__ == "__main__": - # Running the example directly requires running within an asyncio loop - asyncio.run(example_mongo_usage()) + asyncio.run(example_mongo_usage()) \ No newline at end of file diff --git a/app/internal/db_wrappers.py b/app/internal/db_wrappers.py index 2f3e623..2711cf9 100644 --- a/app/internal/db_wrappers.py +++ b/app/internal/db_wrappers.py @@ -3,7 +3,7 @@ import asyncio from uuid import uuid4 from typing import Optional, List, Dict, Any from enum import Enum -from internal.db_handler import MongoHandler +from internal.db_handler import MongoHandler # from internal.types import System, DiscordId # Init vars @@ -16,70 +16,47 @@ DISCORD_ID_DB_COLLECTION_NAME = "discord_bot_ids" # --- System class --- class SystemDbController(): def __init__(self): - # Init the handler - self.db_h = MongoHandler(DB_NAME, SYSTEM_DB_COLLECTION_NAME, MONGO_URL) + self.db_h = MongoHandler(DB_NAME, SYSTEM_DB_COLLECTION_NAME, MONGO_URL) # + + async def close_db_connection(self): + """Closes the underlying MongoDB connection.""" + if self.db_h: + await self.db_h.close_client() # async def create_system(self, system_data: Dict[str, Any]) -> Optional[System]: - """ - Creates a new system entry in the database. - - Args: - system_data: A dictionary containing the data for the new system. - - Returns: - The created System object if successful, None otherwise. - """ print("\n--- Creating a document ---") try: - # Check if the data to be inserted has an ID if not system_data.get("_id"): system_data['_id'] = str(uuid4()) - inserted_result = None inserted_id = None - async with self.db_h as db: - insert_result = await self.db_h.insert_one(system_data) + async with self.db_h as db: # + insert_result = await db.insert_one(system_data) # inserted_id = insert_result.inserted_id if inserted_id: print(f"Insert successful with ID: {inserted_id}") - # Fetch the inserted document to get the complete data including the generated _id query = {"_id": inserted_id} - inserted_doc = None - async with self.db_h as db: - inserted_doc = await db.find_one(query) - + async with self.db_h as db: # + inserted_doc = await db.find_one(query) # if inserted_doc: - # Convert the fetched dictionary back to a System object return System.from_dict(inserted_doc) else: print("Insert acknowledged but no ID returned.") return None - except Exception as e: print(f"Create failed: {e}") return None async def find_system(self, query: Dict[str, Any]) -> Optional[System]: - """ - Finds a single system entry in the database. - - Args: - query: A dictionary representing the query criteria. - - Returns: - A System object if found, None otherwise. - """ print("\n--- Finding one document ---") try: found_doc = None - async with self.db_h as db: - found_doc = await db.find_one(query) - + async with self.db_h as db: # + found_doc = await db.find_one(query) # if found_doc: print("Found document (raw dict):", found_doc) - # Convert the dictionary result to a System object return System.from_dict(found_doc) else: print("Document not found.") @@ -89,30 +66,15 @@ class SystemDbController(): return None async def find_systems(self, query: Dict[str, Any]) -> Optional[List[System]]: - """ - Finds one or more system entries in the database. - - Args: - query: A dictionary representing the query criteria. - - Returns: - A list of System object(s) if found, None otherwise. - """ print("\n--- Finding documents ---") try: found_docs = None - async with self.db_h as db: - found_docs = await db.find(query) - + async with self.db_h as db: # + found_docs = await db.find(query) # if found_docs: print("Found document (raw dict):", found_docs) - # Convert the dictionary results to a System object - converted_systems = [] - for doc in found_docs: - converted_systems.append(System.from_dict(doc)) - + converted_systems = [System.from_dict(doc) for doc in found_docs] print("YURB", found_docs, converted_systems) - return converted_systems if len(converted_systems) > 0 else None else: print("Document not found.") @@ -122,24 +84,13 @@ class SystemDbController(): return None async def find_all_systems(self, query: Dict[str, Any] = {}) -> List[System]: - """ - Finds multiple system entries in the database. - - Args: - query: A dictionary representing the query criteria (default is empty to find all). - - Returns: - A list of System objects. - """ print("\n--- Finding multiple documents ---") try: found_docs = None - async with self.db_h as db: - found_docs = await db.find(query) - + async with self.db_h as db: # + found_docs = await db.find(query) # if found_docs: print(f"Found {len(found_docs)} documents (raw dicts).") - # Convert the list of dictionaries to a list of System objects return [System.from_dict(doc) for doc in found_docs] else: print("No documents found.") @@ -149,22 +100,11 @@ class SystemDbController(): return [] async def update_system(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]: - """ - Updates a single system entry in the database. - - Args: - query: A dictionary representing the query criteria to find the document. - update_data: A dictionary representing the update operations (e.g., using $set). - - Returns: - The number of modified documents if successful, None otherwise. - """ print("\n--- Updating a document ---") try: update_result = None - async with self.db_h as db: - update_result = await db.update_one(query, update_data) - + async with self.db_h as db: # + update_result = await db.update_one(query, update_data) # print(f"Update result: Matched {update_result.matched_count}, Modified {update_result.modified_count}") return update_result.modified_count except Exception as e: @@ -172,19 +112,11 @@ class SystemDbController(): return None async def delete_system(self, query: Dict[str, Any]) -> Optional[int]: - """ - Deletes a single system entry from the database. - Args: - query: A dictionary representing the query criteria to find the document to delete. - Returns: - The number of deleted documents if successful, None otherwise. - """ print("\n--- Deleting a document ---") try: delete_result = None - async with self.db_h as db: - delete_result = await self.db_h.delete_one(query) - + async with self.db_h as db: # + delete_result = await db.delete_one(query) # print(f"Delete result: Deleted count {delete_result.deleted_count}") return delete_result.deleted_count except Exception as e: @@ -195,66 +127,47 @@ class SystemDbController(): # --- DiscordIdDbController class --- class DiscordIdDbController(): def __init__(self): - # Init the handler for Discord IDs - self.db_h = MongoHandler(DB_NAME, DISCORD_ID_DB_COLLECTION_NAME, MONGO_URL) + self.db_h = MongoHandler(DB_NAME, DISCORD_ID_DB_COLLECTION_NAME, MONGO_URL) # + + async def close_db_connection(self): + """Closes the underlying MongoDB connection.""" + if self.db_h: + await self.db_h.close_client() # async def create_discord_id(self, discord_id_data: Dict[str, Any]) -> Optional[DiscordId]: - """ - Creates a new Discord ID entry in the database. - - Args: - discord_id_data: A dictionary containing the data for the new Discord ID. - - Returns: - The created DiscordId object if successful, None otherwise. - """ print("\n--- Creating a Discord ID document ---") try: if not discord_id_data.get("_id"): - discord_id_data['_id'] = str(uuid4()) # Ensure _id is a string + discord_id_data['_id'] = str(uuid4()) inserted_id = None - async with self.db_h as db: - insert_result = await self.db_h.insert_one(discord_id_data) + async with self.db_h as db: # + insert_result = await db.insert_one(discord_id_data) # inserted_id = insert_result.inserted_id if inserted_id: print(f"Discord ID insert successful with ID: {inserted_id}") query = {"_id": inserted_id} inserted_doc = None - async with self.db_h as db: - inserted_doc = await db.find_one(query) - + async with self.db_h as db: # + inserted_doc = await db.find_one(query) # if inserted_doc: return DiscordId.from_dict(inserted_doc) else: print("Discord ID insert acknowledged but no ID returned.") return None - except Exception as e: print(f"Discord ID create failed: {e}") return None async def find_discord_id(self, query: Dict[str, Any], active_only: bool = False) -> Optional[DiscordId]: - """ - Finds a single Discord ID entry in the database. - - Args: - query: A dictionary representing the query criteria. - active_only: If True, only returns active Discord IDs. - - Returns: - A DiscordId object if found, None otherwise. - """ print("\n--- Finding one Discord ID document ---") try: if active_only: query["active"] = True - found_doc = None - async with self.db_h as db: - found_doc = await db.find_one(query) - + async with self.db_h as db: # + found_doc = await db.find_one(query) # if found_doc: print("Found Discord ID document (raw dict):", found_doc) return DiscordId.from_dict(found_doc) @@ -266,37 +179,18 @@ class DiscordIdDbController(): return None async def find_discord_ids(self, query: Dict[str, Any] = {}, guild_id: Optional[str] = None, active_only: bool = False) -> Optional[List[DiscordId]]: - """ - Finds one or more Discord ID entries in the database. - - Args: - query: A dictionary representing the query criteria. - guild_id: Optional. If provided, filters Discord IDs that belong to this guild. - active_only: If True, only returns active Discord IDs. - - Returns: - A list of DiscordId object(s) if found, None otherwise. - """ print("\n--- Finding multiple Discord ID documents ---") try: - # Add active filter if requested if active_only: query["active"] = True - - # Add guild_id filter if provided if guild_id: query["guild_ids"] = {"$in": [guild_id]} - found_docs = None - async with self.db_h as db: - found_docs = await db.find(query) - + async with self.db_h as db: # + found_docs = await db.find(query) # if found_docs: print(f"Found {len(found_docs)} Discord ID documents (raw dicts).") - converted_discord_ids = [] - for doc in found_docs: - converted_discord_ids.append(DiscordId.from_dict(doc)) - + converted_discord_ids = [DiscordId.from_dict(doc) for doc in found_docs] return converted_discord_ids if len(converted_discord_ids) > 0 else None else: print("Discord ID documents not found.") @@ -306,22 +200,11 @@ class DiscordIdDbController(): return None async def update_discord_id(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]: - """ - Updates a single Discord ID entry in the database. - - Args: - query: A dictionary representing the query criteria to find the document. - update_data: A dictionary representing the update operations (e.g., using $set). - - Returns: - The number of modified documents if successful, None otherwise. - """ print("\n--- Updating a Discord ID document ---") try: update_result = None - async with self.db_h as db: - update_result = await db.update_one(query, update_data) - + async with self.db_h as db: # + update_result = await db.update_one(query, update_data) # print(f"Discord ID update result: Matched {update_result.matched_count}, Modified {update_result.modified_count}") return update_result.modified_count except Exception as e: @@ -329,19 +212,11 @@ class DiscordIdDbController(): return None async def delete_discord_id(self, query: Dict[str, Any]) -> Optional[int]: - """ - Deletes a single Discord ID entry from the database. - Args: - query: A dictionary representing the query criteria to find the document to delete. - Returns: - The number of deleted documents if successful, None otherwise. - """ print("\n--- Deleting a Discord ID document ---") try: delete_result = None - async with self.db_h as db: - delete_result = await self.db_h.delete_one(query) - + async with self.db_h as db: # + delete_result = await db.delete_one(query) # print(f"Discord ID delete result: Deleted count {delete_result.deleted_count}") return delete_result.deleted_count except Exception as e: diff --git a/app/routers/auth.py b/app/routers/auth.py index a061386..684a7e0 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -24,6 +24,8 @@ def role_required(required_role: UserRoles): # Make a DB call to get the user and their role user = await current_app.user_db_h.find_user({"_id": user_id}) + print("YERERERE", user) + if not user: abort(401, "User not found or invalid token.") # User corresponding to token not found diff --git a/app/server.py b/app/server.py index 6a73b6f..176bd4a 100644 --- a/app/server.py +++ b/app/server.py @@ -8,40 +8,28 @@ from quart_cors import cors from routers.systems import systems_bp from routers.nodes import nodes_bp, register_client, unregister_client from routers.bot import bot_bp -from routers.auth import auth_bp # ONLY import auth_bp, not jwt instance from auth.py -from internal.db_wrappers import SystemDbController, DiscordIdDbController -from internal.auth_wrappers import UserDbController -# Import the JWTManager instance and its configuration function -from config.jwt_config import jwt, configure_jwt # Import the actual jwt instance and the config function +from routers.auth import auth_bp +from internal.db_wrappers import SystemDbController, DiscordIdDbController # +from internal.auth_wrappers import UserDbController # +from config.jwt_config import jwt, configure_jwt # --- WebSocket Server Components --- -# Dictionary to store active clients: {client_id: websocket} active_clients = {} - async def websocket_server_handler(websocket): - """Handles incoming WebSocket connections and messages from clients.""" client_id = None try: - # Handshake: Receive the first message which should contain the client ID handshake_message = await websocket.recv() handshake_data = json.loads(handshake_message) - if handshake_data.get("type") == "handshake" and "id" in handshake_data: client_id = handshake_data["id"] await register_client(websocket, client_id) - await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"})) # Acknowledge handshake - - # Keep the connection alive and listen for potential messages from the client - # (Though in this server-commanded model, clients might not send much) - # We primarily wait for the client to close the connection + await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"})) await websocket.wait_closed() - else: print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.") await websocket.close() - except websockets.exceptions.ConnectionClosedError: print(f"Client connection closed unexpectedly for {client_id}.") except json.JSONDecodeError: @@ -56,62 +44,70 @@ async def websocket_server_handler(websocket): app = Quart(__name__) app = cors(app, allow_origin="*") -# Store the websocket server instance websocket_server_instance = None - -# Make active_clients accessible via the app instance. app.active_clients = active_clients # Create and attach the DB wrappers -app.sys_db_h = SystemDbController() -app.d_id_db_h = DiscordIdDbController() -app.user_db_h = UserDbController() +app.sys_db_h = SystemDbController() # +app.d_id_db_h = DiscordIdDbController() # +app.user_db_h = UserDbController() # -# Configure JWT settings and initialize the JWTManager instance with the app configure_jwt(app) -jwt.init_app(app) # Crucial: This initializes the global 'jwt' instance with your app +jwt.init_app(app) @app.before_serving -async def startup_websocket_server(): - """Starts the WebSocket server when the Quart app starts.""" +async def startup_tasks(): # Combined startup logic + """Starts the WebSocket server and prepares other resources.""" global websocket_server_instance websocket_server_address = "0.0.0.0" websocket_server_port = 8765 - - # Start the WebSocket server task websocket_server_instance = await websockets.serve( websocket_server_handler, websocket_server_address, websocket_server_port ) print(f"WebSocket server started on ws://{websocket_server_address}:{websocket_server_port}") + # Database connections are now established on first use by MongoHandler's __aenter__/connect + # No explicit connect calls needed here unless desired for early failure detection. + print("Application startup complete. DB connections will be initialized on first use.") @app.after_serving -async def shutdown_websocket_server(): - """Shuts down the WebSocket server when the Quart app stops.""" +async def shutdown_tasks(): # Combined shutdown logic + """Shuts down services and closes connections.""" global websocket_server_instance if websocket_server_instance: websocket_server_instance.close() await websocket_server_instance.wait_closed() print("WebSocket server shut down.") + # Close database connections + if hasattr(app, 'user_db_h') and app.user_db_h: + print("Closing User DB connection...") + await app.user_db_h.close_db_connection() # + if hasattr(app, 'sys_db_h') and app.sys_db_h: + print("Closing System DB connection...") + await app.sys_db_h.close_db_connection() # + if hasattr(app, 'd_id_db_h') and app.d_id_db_h: + print("Closing Discord ID DB connection...") + await app.d_id_db_h.close_db_connection() # + print("All database connections have been signaled to close.") + app.register_blueprint(systems_bp, url_prefix="/systems") app.register_blueprint(nodes_bp, url_prefix="/nodes") app.register_blueprint(bot_bp, url_prefix="/bots") -app.register_blueprint(auth_bp, url_prefix="/auth") # Register the auth blueprint +app.register_blueprint(auth_bp, url_prefix="/auth") @app.route('/') async def index(): return "Welcome to the Radio App Server API!" -# --- Main Execution --- if __name__ == "__main__": print("Starting Quart API server...") app.run( host="0.0.0.0", port=5000, - debug=False # Set to True for development + debug=False ) print("Quart API server stopped.") \ No newline at end of file