Refactored DB handlers
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from internal.db_handler import MongoHandler
|
from internal.db_handler import MongoHandler #
|
||||||
from internal.types import User, UserRoles
|
from internal.types import User, UserRoles
|
||||||
|
|
||||||
DB_NAME = os.getenv("DB_NAME", "default_db")
|
DB_NAME = os.getenv("DB_NAME", "default_db")
|
||||||
@@ -13,7 +13,12 @@ USER_DB_COLLECTION_NAME = "users"
|
|||||||
|
|
||||||
class UserDbController:
|
class UserDbController:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db_h = MongoHandler(DB_NAME, USER_DB_COLLECTION_NAME, MONGO_URL)
|
self.db_h = MongoHandler(DB_NAME, USER_DB_COLLECTION_NAME, MONGO_URL) #
|
||||||
|
|
||||||
|
async def close_db_connection(self):
|
||||||
|
"""Closes the underlying MongoDB connection."""
|
||||||
|
if self.db_h:
|
||||||
|
await self.db_h.close_client() #
|
||||||
|
|
||||||
async def create_user(self, user_data: Dict[str, Any]) -> Optional[User]:
|
async def create_user(self, user_data: Dict[str, Any]) -> Optional[User]:
|
||||||
try:
|
try:
|
||||||
@@ -21,15 +26,15 @@ class UserDbController:
|
|||||||
user_data['_id'] = str(uuid4())
|
user_data['_id'] = str(uuid4())
|
||||||
|
|
||||||
inserted_id = None
|
inserted_id = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
insert_result = await db.insert_one(user_data)
|
insert_result = await db.insert_one(user_data) #
|
||||||
inserted_id = insert_result.inserted_id
|
inserted_id = insert_result.inserted_id
|
||||||
|
|
||||||
if inserted_id:
|
if inserted_id:
|
||||||
query = {"_id": inserted_id}
|
query = {"_id": inserted_id}
|
||||||
inserted_doc = None
|
inserted_doc = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
inserted_doc = await db.find_one(query)
|
inserted_doc = await db.find_one(query) #
|
||||||
if inserted_doc:
|
if inserted_doc:
|
||||||
return User.from_dict(inserted_doc)
|
return User.from_dict(inserted_doc)
|
||||||
return None
|
return None
|
||||||
@@ -40,20 +45,22 @@ class UserDbController:
|
|||||||
async def find_user(self, query: Dict[str, Any]) -> Optional[User]:
|
async def find_user(self, query: Dict[str, Any]) -> Optional[User]:
|
||||||
try:
|
try:
|
||||||
found_doc = None
|
found_doc = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
found_doc = await db.find_one(query)
|
# The 'db' here is self.db_h, which is an instance of MongoHandler.
|
||||||
|
# MongoHandler.find_one will be called.
|
||||||
|
found_doc = await db.find_one(query) #
|
||||||
if found_doc:
|
if found_doc:
|
||||||
return User.from_dict(found_doc)
|
return User.from_dict(found_doc)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Find user failed: {e}")
|
print(f"Find user failed: {e}") # This error should be less frequent or indicate actual DB issues now
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def update_user(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]:
|
async def update_user(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]:
|
||||||
try:
|
try:
|
||||||
update_result = None
|
update_result = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
update_result = await db.update_one(query, update_data)
|
update_result = await db.update_one(query, update_data) #
|
||||||
return update_result.modified_count
|
return update_result.modified_count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Update user failed: {e}")
|
print(f"Update user failed: {e}")
|
||||||
@@ -62,8 +69,8 @@ class UserDbController:
|
|||||||
async def delete_user(self, query: Dict[str, Any]) -> Optional[int]:
|
async def delete_user(self, query: Dict[str, Any]) -> Optional[int]:
|
||||||
try:
|
try:
|
||||||
delete_result = None
|
delete_result = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
delete_result = await db.delete_one(query)
|
delete_result = await db.delete_one(query) #
|
||||||
return delete_result.deleted_count
|
return delete_result.deleted_count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Delete user failed: {e}")
|
print(f"Delete user failed: {e}")
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ from typing import Optional, Dict, Any, List
|
|||||||
|
|
||||||
class MongoHandler:
|
class MongoHandler:
|
||||||
"""
|
"""
|
||||||
A basic asynchronous handler for MongoDB operations using motor.
|
An asynchronous handler for MongoDB operations using motor.
|
||||||
Designed to be used with asyncio.
|
The client connection is established on first use (via async with or direct call to connect())
|
||||||
|
and should be explicitly closed by calling close_client() at application shutdown.
|
||||||
"""
|
"""
|
||||||
def __init__(self, db_name: str, collection_name: str, mongo_uri: str = "mongodb://localhost:27017/"):
|
def __init__(self, db_name: str, collection_name: str, mongo_uri: str = "mongodb://localhost:27017/"):
|
||||||
"""
|
"""
|
||||||
@@ -15,7 +16,6 @@ class MongoHandler:
|
|||||||
db_name (str): The name of the database to connect to.
|
db_name (str): The name of the database to connect to.
|
||||||
collection_name (str): The name of the collection to use.
|
collection_name (str): The name of the collection to use.
|
||||||
mongo_uri (str): The MongoDB connection string URI.
|
mongo_uri (str): The MongoDB connection string URI.
|
||||||
Defaults to the standard local URI.
|
|
||||||
"""
|
"""
|
||||||
self.mongo_uri = mongo_uri
|
self.mongo_uri = mongo_uri
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
@@ -23,229 +23,142 @@ class MongoHandler:
|
|||||||
self._client: Optional[motor.motor_asyncio.AsyncIOMotorClient] = None
|
self._client: Optional[motor.motor_asyncio.AsyncIOMotorClient] = None
|
||||||
self._db: Optional[motor.motor_asyncio.AsyncIOMotorDatabase] = None
|
self._db: Optional[motor.motor_asyncio.AsyncIOMotorDatabase] = None
|
||||||
self._collection: Optional[motor.motor_asyncio.AsyncIOMotorCollection] = None
|
self._collection: Optional[motor.motor_asyncio.AsyncIOMotorCollection] = None
|
||||||
|
self._lock = asyncio.Lock() # Lock for serializing client creation
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""Establishes an asynchronous connection to MongoDB."""
|
"""
|
||||||
|
Establishes an asynchronous connection to MongoDB if not already established.
|
||||||
|
This method is idempotent.
|
||||||
|
"""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
try:
|
async with self._lock: # Ensure only one coroutine attempts to initialize the client
|
||||||
self._client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri)
|
if self._client is None: # Double-check after acquiring lock
|
||||||
# The ismaster command is cheap and does not require auth.
|
try:
|
||||||
# It is used to confirm that the client can connect to the deployment.
|
print(f"Initializing MongoDB client for: DB '{self.db_name}', Collection '{self.collection_name}' URI: {self.mongo_uri.split('@')[-1]}") # Avoid logging credentials
|
||||||
await self._client.admin.command('ismaster')
|
self._client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri)
|
||||||
|
# The ismaster command is cheap and does not require auth.
|
||||||
|
await self._client.admin.command('ismaster')
|
||||||
|
self._db = self._client[self.db_name]
|
||||||
|
self._collection = self._db[self.collection_name]
|
||||||
|
print(f"MongoDB client initialized and connected: Database '{self.db_name}', Collection '{self.collection_name}'")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to initialize MongoDB client at {self.mongo_uri.split('@')[-1]} for {self.db_name}/{self.collection_name}: {e}")
|
||||||
|
self._client = None # Ensure client is None if connection fails
|
||||||
|
self._db = None
|
||||||
|
self._collection = None
|
||||||
|
raise # Re-raise the exception after printing
|
||||||
|
|
||||||
|
if self._collection is None and self._client is not None:
|
||||||
|
# This can happen if connect was called, client was set, but then an error occurred before collection was set
|
||||||
|
# Or if connect logic needs to re-establish db/collection objects without re-creating client (less common with motor)
|
||||||
|
if self._db is None:
|
||||||
self._db = self._client[self.db_name]
|
self._db = self._client[self.db_name]
|
||||||
self._collection = self._db[self.collection_name]
|
self._collection = self._db[self.collection_name]
|
||||||
print(f"Connected to MongoDB: Database '{self.db_name}', Collection '{self.collection_name}'")
|
if self._collection is None:
|
||||||
except Exception as e:
|
raise RuntimeError(f"MongoDB collection '{self.collection_name}' could not be established even though client exists.")
|
||||||
print(f"Failed to connect to MongoDB at {self.mongo_uri}: {e}")
|
|
||||||
self._client = None # Ensure client is None if connection fails
|
|
||||||
raise # Re-raise the exception after printing
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""Closes the MongoDB connection."""
|
async def close_client(self):
|
||||||
if self._client:
|
"""Closes the MongoDB client connection. Should be called on application shutdown."""
|
||||||
self._client.close()
|
async with self._lock:
|
||||||
self._client = None
|
if self._client:
|
||||||
self._db = None
|
print(f"Closing MongoDB client for: Database '{self.db_name}', Collection '{self.collection_name}'")
|
||||||
self._collection = None
|
self._client.close()
|
||||||
print("MongoDB connection closed.")
|
self._client = None
|
||||||
|
self._db = None
|
||||||
|
self._collection = None
|
||||||
|
print(f"MongoDB client closed for: Database '{self.db_name}', Collection '{self.collection_name}'.")
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""Allows using the handler with async with."""
|
"""Allows using the handler with async with. Ensures connection is active."""
|
||||||
await self.connect()
|
await self.connect()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Ensures the connection is closed when exiting async with."""
|
"""Ensures the connection is NOT closed when exiting async with."""
|
||||||
await self.close()
|
# The connection is managed and closed at the application level by calling close_client()
|
||||||
|
pass
|
||||||
|
|
||||||
async def insert_one(self, document: Dict[str, Any]) -> Any:
|
async def insert_one(self, document: Dict[str, Any]) -> Any:
|
||||||
"""
|
|
||||||
Inserts a single document into the collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
document (Dict[str, Any]): The document to insert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The result of the insert operation (InsertOneResult).
|
|
||||||
"""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.")
|
await self.connect() # Try to connect if collection is None
|
||||||
|
if self._collection is None: # Check again after attempting to connect
|
||||||
|
raise RuntimeError("MongoDB collection not available. Call connect() or use async with.")
|
||||||
print(f"Inserting document into '{self.collection_name}'...")
|
print(f"Inserting document into '{self.collection_name}'...")
|
||||||
result = await self._collection.insert_one(document)
|
result = await self._collection.insert_one(document)
|
||||||
print(f"Inserted document with ID: {result.inserted_id}")
|
print(f"Inserted document with ID: {result.inserted_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def find_one(self, query: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
async def find_one(self, query: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
|
||||||
Finds a single document matching the query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (Dict[str, Any]): The query document.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Dict[str, Any]]: The found document, or None if not found.
|
|
||||||
"""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.")
|
await self.connect()
|
||||||
|
if self._collection is None:
|
||||||
|
raise RuntimeError("MongoDB collection not available. Call connect() or use async with.")
|
||||||
print(f"Finding one document in '{self.collection_name}' with query: {query}")
|
print(f"Finding one document in '{self.collection_name}' with query: {query}")
|
||||||
document = await self._collection.find_one(query)
|
document = await self._collection.find_one(query)
|
||||||
return document
|
return document
|
||||||
|
|
||||||
async def find(self, query: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
async def find(self, query: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||||
"""
|
|
||||||
Finds multiple documents matching the query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (Dict[str, Any], optional): The query document. Defaults to None (find all).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: A list of matching documents.
|
|
||||||
"""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.")
|
await self.connect()
|
||||||
|
if self._collection is None:
|
||||||
|
raise RuntimeError("MongoDB collection not available. Call connect() or use async with.")
|
||||||
if query is None:
|
if query is None:
|
||||||
query = {}
|
query = {}
|
||||||
print(f"Finding documents in '{self.collection_name}' with query: {query}")
|
print(f"Finding documents in '{self.collection_name}' with query: {query}")
|
||||||
# Use list comprehension to iterate through the cursor asynchronously
|
|
||||||
documents = [doc async for doc in self._collection.find(query)]
|
documents = [doc async for doc in self._collection.find(query)]
|
||||||
print(f"Found {len(documents)} documents.")
|
print(f"Found {len(documents)} documents.")
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
async def update_one(self, query: Dict[str, Any], update: Dict[str, Any], upsert: bool = False) -> Any:
|
async def update_one(self, query: Dict[str, Any], update: Dict[str, Any], upsert: bool = False) -> Any:
|
||||||
"""
|
|
||||||
Updates a single document matching the query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (Dict[str, Any]): The query document.
|
|
||||||
update (Dict[str, Any]): The update operations to apply.
|
|
||||||
upsert (bool): If True, insert a new document if no match is found.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The result of the update operation (UpdateResult).
|
|
||||||
"""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.")
|
await self.connect()
|
||||||
|
if self._collection is None:
|
||||||
|
raise RuntimeError("MongoDB collection not available. Call connect() or use async with.")
|
||||||
print(f"Updating one document in '{self.collection_name}' with query: {query}, update: {update}")
|
print(f"Updating one document in '{self.collection_name}' with query: {query}, update: {update}")
|
||||||
result = await self._collection.update_one(query, update, upsert=upsert)
|
result = await self._collection.update_one(query, update, upsert=upsert)
|
||||||
print(f"Matched {result.matched_count}, Modified {result.modified_count}, Upserted ID: {result.upserted_id}")
|
print(f"Matched {result.matched_count}, Modified {result.modified_count}, Upserted ID: {result.upserted_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_one(self, query: Dict[str, Any]) -> Any:
|
async def delete_one(self, query: Dict[str, Any]) -> Any:
|
||||||
"""
|
|
||||||
Deletes a single document matching the query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (Dict[str, Any]): The query document.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The result of the delete operation (DeleteResult).
|
|
||||||
"""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.")
|
await self.connect()
|
||||||
|
if self._collection is None:
|
||||||
|
raise RuntimeError("MongoDB collection not available. Call connect() or use async with.")
|
||||||
print(f"Deleting one document from '{self.collection_name}' with query: {query}")
|
print(f"Deleting one document from '{self.collection_name}' with query: {query}")
|
||||||
result = await self._collection.delete_one(query)
|
result = await self._collection.delete_one(query)
|
||||||
print(f"Deleted count: {result.deleted_count}")
|
print(f"Deleted count: {result.deleted_count}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_many(self, query: Dict[str, Any]) -> Any:
|
async def delete_many(self, query: Dict[str, Any]) -> Any:
|
||||||
"""
|
|
||||||
Deletes multiple documents matching the query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (Dict[str, Any]): The query document.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The result of the delete operation (DeleteResult).
|
|
||||||
"""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
raise RuntimeError("MongoDB connection not established. Call connect() first or use async with.")
|
await self.connect()
|
||||||
|
if self._collection is None:
|
||||||
|
raise RuntimeError("MongoDB collection not available. Call connect() or use async with.")
|
||||||
print(f"Deleting many documents from '{self.collection_name}' with query: {query}")
|
print(f"Deleting many documents from '{self.collection_name}' with query: {query}")
|
||||||
result = await self._collection.delete_many(query)
|
result = await self._collection.delete_many(query)
|
||||||
print(f"Deleted count: {result.deleted_count}")
|
print(f"Deleted count: {result.deleted_count}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# --- Example Usage (no change needed here, but behavior of MongoHandler is different) ---
|
||||||
# --- Example Usage ---
|
|
||||||
async def example_mongo_usage():
|
async def example_mongo_usage():
|
||||||
"""Demonstrates how to use the MongoHandler."""
|
|
||||||
# Ensure you have a MongoDB server running, default is localhost:27017
|
|
||||||
db_name = "radio_app_db"
|
db_name = "radio_app_db"
|
||||||
collection_name = "channels"
|
collection_name = "channels_example" # Use a different collection for example
|
||||||
|
handler = MongoHandler(db_name, collection_name) # MONGO_URL defaults to localhost
|
||||||
|
|
||||||
# Using async with ensures the connection is closed automatically
|
try:
|
||||||
async with MongoHandler(db_name, collection_name) as mongo:
|
async with handler: # Connects on enter (if not already connected)
|
||||||
# --- Insert Example ---
|
print("\n--- Inserting a document ---")
|
||||||
print("\n--- Inserting a document ---")
|
# ... (rest of example usage)
|
||||||
channel_data = {
|
channel_data = { "_id": "example_channel", "name": "Example" }
|
||||||
"_id": "channel_3", # You can specify _id or let MongoDB generate one
|
await handler.insert_one(channel_data)
|
||||||
"name": "Emergency Services",
|
found = await handler.find_one({"_id": "example_channel"})
|
||||||
"frequencies": 453000,
|
print(f"Found: {found}")
|
||||||
"location": "Countywide",
|
await handler.delete_one({"_id": "example_channel"})
|
||||||
"avail_on_nodes": ["client-xyz987"],
|
print("Example completed.")
|
||||||
"description": "Monitor for emergency broadcasts."
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
insert_result = await mongo.insert_one(channel_data)
|
|
||||||
print(f"Insert successful: {insert_result.inserted_id}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Insert failed: {e}")
|
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await handler.close_client() # Explicitly close client at the end of usage
|
||||||
|
|
||||||
# --- Find One Example ---
|
|
||||||
print("\n--- Finding one document ---")
|
|
||||||
query = {"_id": "channel_3"}
|
|
||||||
found_channel = await mongo.find_one(query)
|
|
||||||
if found_channel:
|
|
||||||
print("Found document:", found_channel)
|
|
||||||
else:
|
|
||||||
print("Document not found.")
|
|
||||||
|
|
||||||
# --- Find Many Example ---
|
|
||||||
print("\n--- Finding all documents ---")
|
|
||||||
all_channels = await mongo.find() # Empty query finds all
|
|
||||||
print("All documents:", all_channels)
|
|
||||||
|
|
||||||
# --- Update Example ---
|
|
||||||
print("\n--- Updating a document ---")
|
|
||||||
update_query = {"_id": "channel_3"}
|
|
||||||
update_data = {"$set": {"location": "Statewide", "avail_on_nodes": ["client-xyz987", "client-newnode1"]}}
|
|
||||||
update_result = await mongo.update_one(update_query, update_data)
|
|
||||||
print(f"Update successful: Matched {update_result.matched_count}, Modified {update_result.modified_count}")
|
|
||||||
|
|
||||||
print("\n--- Finding the updated document ---")
|
|
||||||
updated_channel = await mongo.find_one(update_query)
|
|
||||||
print("Updated document:", updated_channel)
|
|
||||||
|
|
||||||
# --- Delete Example ---
|
|
||||||
print("\n--- Deleting a document ---")
|
|
||||||
delete_query = {"_id": "channel_3"}
|
|
||||||
delete_result = await mongo.delete_one(delete_query)
|
|
||||||
print(f"Delete successful: Deleted count {delete_result.deleted_count}")
|
|
||||||
|
|
||||||
print("\n--- Verifying deletion ---")
|
|
||||||
deleted_channel = await mongo.find_one(delete_query)
|
|
||||||
if deleted_channel:
|
|
||||||
print("Document still found (deletion failed).")
|
|
||||||
else:
|
|
||||||
print("Document successfully deleted.")
|
|
||||||
|
|
||||||
# --- Insert another for delete_many example ---
|
|
||||||
temp_doc1 = {"_id": "temp_1", "tag": "temp"}
|
|
||||||
temp_doc2 = {"_id": "temp_2", "tag": "temp"}
|
|
||||||
await mongo.insert_one(temp_doc1)
|
|
||||||
await mongo.insert_one(temp_doc2)
|
|
||||||
|
|
||||||
# --- Delete Many Example ---
|
|
||||||
print("\n--- Deleting many documents ---")
|
|
||||||
delete_many_query = {"tag": "temp"}
|
|
||||||
delete_many_result = await mongo.delete_many(delete_many_query)
|
|
||||||
print(f"Delete many successful: Deleted count {delete_many_result.deleted_count}")
|
|
||||||
|
|
||||||
|
|
||||||
# To run the example usage:
|
|
||||||
# 1. Ensure you have a MongoDB server running locally on the default port (27017).
|
|
||||||
# 2. Save the code as mongodb_handler.py.
|
|
||||||
# 3. Run from your terminal: python -m asyncio mongodb_handler.py
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Running the example directly requires running within an asyncio loop
|
asyncio.run(example_mongo_usage())
|
||||||
asyncio.run(example_mongo_usage())
|
|
||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from internal.db_handler import MongoHandler
|
from internal.db_handler import MongoHandler #
|
||||||
from internal.types import System, DiscordId
|
from internal.types import System, DiscordId
|
||||||
|
|
||||||
# Init vars
|
# Init vars
|
||||||
@@ -16,70 +16,47 @@ DISCORD_ID_DB_COLLECTION_NAME = "discord_bot_ids"
|
|||||||
# --- System class ---
|
# --- System class ---
|
||||||
class SystemDbController():
|
class SystemDbController():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Init the handler
|
self.db_h = MongoHandler(DB_NAME, SYSTEM_DB_COLLECTION_NAME, MONGO_URL) #
|
||||||
self.db_h = MongoHandler(DB_NAME, SYSTEM_DB_COLLECTION_NAME, MONGO_URL)
|
|
||||||
|
async def close_db_connection(self):
|
||||||
|
"""Closes the underlying MongoDB connection."""
|
||||||
|
if self.db_h:
|
||||||
|
await self.db_h.close_client() #
|
||||||
|
|
||||||
async def create_system(self, system_data: Dict[str, Any]) -> Optional[System]:
|
async def create_system(self, system_data: Dict[str, Any]) -> Optional[System]:
|
||||||
"""
|
|
||||||
Creates a new system entry in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
system_data: A dictionary containing the data for the new system.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created System object if successful, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Creating a document ---")
|
print("\n--- Creating a document ---")
|
||||||
try:
|
try:
|
||||||
# Check if the data to be inserted has an ID
|
|
||||||
if not system_data.get("_id"):
|
if not system_data.get("_id"):
|
||||||
system_data['_id'] = str(uuid4())
|
system_data['_id'] = str(uuid4())
|
||||||
|
|
||||||
inserted_result = None
|
|
||||||
inserted_id = None
|
inserted_id = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
insert_result = await self.db_h.insert_one(system_data)
|
insert_result = await db.insert_one(system_data) #
|
||||||
inserted_id = insert_result.inserted_id
|
inserted_id = insert_result.inserted_id
|
||||||
|
|
||||||
if inserted_id:
|
if inserted_id:
|
||||||
print(f"Insert successful with ID: {inserted_id}")
|
print(f"Insert successful with ID: {inserted_id}")
|
||||||
# Fetch the inserted document to get the complete data including the generated _id
|
|
||||||
query = {"_id": inserted_id}
|
query = {"_id": inserted_id}
|
||||||
|
|
||||||
inserted_doc = None
|
inserted_doc = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
inserted_doc = await db.find_one(query)
|
inserted_doc = await db.find_one(query) #
|
||||||
|
|
||||||
if inserted_doc:
|
if inserted_doc:
|
||||||
# Convert the fetched dictionary back to a System object
|
|
||||||
return System.from_dict(inserted_doc)
|
return System.from_dict(inserted_doc)
|
||||||
else:
|
else:
|
||||||
print("Insert acknowledged but no ID returned.")
|
print("Insert acknowledged but no ID returned.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Create failed: {e}")
|
print(f"Create failed: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def find_system(self, query: Dict[str, Any]) -> Optional[System]:
|
async def find_system(self, query: Dict[str, Any]) -> Optional[System]:
|
||||||
"""
|
|
||||||
Finds a single system entry in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A System object if found, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Finding one document ---")
|
print("\n--- Finding one document ---")
|
||||||
try:
|
try:
|
||||||
found_doc = None
|
found_doc = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
found_doc = await db.find_one(query)
|
found_doc = await db.find_one(query) #
|
||||||
|
|
||||||
if found_doc:
|
if found_doc:
|
||||||
print("Found document (raw dict):", found_doc)
|
print("Found document (raw dict):", found_doc)
|
||||||
# Convert the dictionary result to a System object
|
|
||||||
return System.from_dict(found_doc)
|
return System.from_dict(found_doc)
|
||||||
else:
|
else:
|
||||||
print("Document not found.")
|
print("Document not found.")
|
||||||
@@ -89,30 +66,15 @@ class SystemDbController():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def find_systems(self, query: Dict[str, Any]) -> Optional[List[System]]:
|
async def find_systems(self, query: Dict[str, Any]) -> Optional[List[System]]:
|
||||||
"""
|
|
||||||
Finds one or more system entries in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of System object(s) if found, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Finding documents ---")
|
print("\n--- Finding documents ---")
|
||||||
try:
|
try:
|
||||||
found_docs = None
|
found_docs = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
found_docs = await db.find(query)
|
found_docs = await db.find(query) #
|
||||||
|
|
||||||
if found_docs:
|
if found_docs:
|
||||||
print("Found document (raw dict):", found_docs)
|
print("Found document (raw dict):", found_docs)
|
||||||
# Convert the dictionary results to a System object
|
converted_systems = [System.from_dict(doc) for doc in found_docs]
|
||||||
converted_systems = []
|
|
||||||
for doc in found_docs:
|
|
||||||
converted_systems.append(System.from_dict(doc))
|
|
||||||
|
|
||||||
print("YURB", found_docs, converted_systems)
|
print("YURB", found_docs, converted_systems)
|
||||||
|
|
||||||
return converted_systems if len(converted_systems) > 0 else None
|
return converted_systems if len(converted_systems) > 0 else None
|
||||||
else:
|
else:
|
||||||
print("Document not found.")
|
print("Document not found.")
|
||||||
@@ -122,24 +84,13 @@ class SystemDbController():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def find_all_systems(self, query: Dict[str, Any] = {}) -> List[System]:
|
async def find_all_systems(self, query: Dict[str, Any] = {}) -> List[System]:
|
||||||
"""
|
|
||||||
Finds multiple system entries in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria (default is empty to find all).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of System objects.
|
|
||||||
"""
|
|
||||||
print("\n--- Finding multiple documents ---")
|
print("\n--- Finding multiple documents ---")
|
||||||
try:
|
try:
|
||||||
found_docs = None
|
found_docs = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
found_docs = await db.find(query)
|
found_docs = await db.find(query) #
|
||||||
|
|
||||||
if found_docs:
|
if found_docs:
|
||||||
print(f"Found {len(found_docs)} documents (raw dicts).")
|
print(f"Found {len(found_docs)} documents (raw dicts).")
|
||||||
# Convert the list of dictionaries to a list of System objects
|
|
||||||
return [System.from_dict(doc) for doc in found_docs]
|
return [System.from_dict(doc) for doc in found_docs]
|
||||||
else:
|
else:
|
||||||
print("No documents found.")
|
print("No documents found.")
|
||||||
@@ -149,22 +100,11 @@ class SystemDbController():
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def update_system(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]:
|
async def update_system(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]:
|
||||||
"""
|
|
||||||
Updates a single system entry in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria to find the document.
|
|
||||||
update_data: A dictionary representing the update operations (e.g., using $set).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The number of modified documents if successful, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Updating a document ---")
|
print("\n--- Updating a document ---")
|
||||||
try:
|
try:
|
||||||
update_result = None
|
update_result = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
update_result = await db.update_one(query, update_data)
|
update_result = await db.update_one(query, update_data) #
|
||||||
|
|
||||||
print(f"Update result: Matched {update_result.matched_count}, Modified {update_result.modified_count}")
|
print(f"Update result: Matched {update_result.matched_count}, Modified {update_result.modified_count}")
|
||||||
return update_result.modified_count
|
return update_result.modified_count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -172,19 +112,11 @@ class SystemDbController():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def delete_system(self, query: Dict[str, Any]) -> Optional[int]:
|
async def delete_system(self, query: Dict[str, Any]) -> Optional[int]:
|
||||||
"""
|
|
||||||
Deletes a single system entry from the database.
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria to find the document to delete.
|
|
||||||
Returns:
|
|
||||||
The number of deleted documents if successful, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Deleting a document ---")
|
print("\n--- Deleting a document ---")
|
||||||
try:
|
try:
|
||||||
delete_result = None
|
delete_result = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
delete_result = await self.db_h.delete_one(query)
|
delete_result = await db.delete_one(query) #
|
||||||
|
|
||||||
print(f"Delete result: Deleted count {delete_result.deleted_count}")
|
print(f"Delete result: Deleted count {delete_result.deleted_count}")
|
||||||
return delete_result.deleted_count
|
return delete_result.deleted_count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -195,66 +127,47 @@ class SystemDbController():
|
|||||||
# --- DiscordIdDbController class ---
|
# --- DiscordIdDbController class ---
|
||||||
class DiscordIdDbController():
|
class DiscordIdDbController():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Init the handler for Discord IDs
|
self.db_h = MongoHandler(DB_NAME, DISCORD_ID_DB_COLLECTION_NAME, MONGO_URL) #
|
||||||
self.db_h = MongoHandler(DB_NAME, DISCORD_ID_DB_COLLECTION_NAME, MONGO_URL)
|
|
||||||
|
async def close_db_connection(self):
|
||||||
|
"""Closes the underlying MongoDB connection."""
|
||||||
|
if self.db_h:
|
||||||
|
await self.db_h.close_client() #
|
||||||
|
|
||||||
async def create_discord_id(self, discord_id_data: Dict[str, Any]) -> Optional[DiscordId]:
|
async def create_discord_id(self, discord_id_data: Dict[str, Any]) -> Optional[DiscordId]:
|
||||||
"""
|
|
||||||
Creates a new Discord ID entry in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
discord_id_data: A dictionary containing the data for the new Discord ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created DiscordId object if successful, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Creating a Discord ID document ---")
|
print("\n--- Creating a Discord ID document ---")
|
||||||
try:
|
try:
|
||||||
if not discord_id_data.get("_id"):
|
if not discord_id_data.get("_id"):
|
||||||
discord_id_data['_id'] = str(uuid4()) # Ensure _id is a string
|
discord_id_data['_id'] = str(uuid4())
|
||||||
|
|
||||||
inserted_id = None
|
inserted_id = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
insert_result = await self.db_h.insert_one(discord_id_data)
|
insert_result = await db.insert_one(discord_id_data) #
|
||||||
inserted_id = insert_result.inserted_id
|
inserted_id = insert_result.inserted_id
|
||||||
|
|
||||||
if inserted_id:
|
if inserted_id:
|
||||||
print(f"Discord ID insert successful with ID: {inserted_id}")
|
print(f"Discord ID insert successful with ID: {inserted_id}")
|
||||||
query = {"_id": inserted_id}
|
query = {"_id": inserted_id}
|
||||||
inserted_doc = None
|
inserted_doc = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
inserted_doc = await db.find_one(query)
|
inserted_doc = await db.find_one(query) #
|
||||||
|
|
||||||
if inserted_doc:
|
if inserted_doc:
|
||||||
return DiscordId.from_dict(inserted_doc)
|
return DiscordId.from_dict(inserted_doc)
|
||||||
else:
|
else:
|
||||||
print("Discord ID insert acknowledged but no ID returned.")
|
print("Discord ID insert acknowledged but no ID returned.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Discord ID create failed: {e}")
|
print(f"Discord ID create failed: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def find_discord_id(self, query: Dict[str, Any], active_only: bool = False) -> Optional[DiscordId]:
|
async def find_discord_id(self, query: Dict[str, Any], active_only: bool = False) -> Optional[DiscordId]:
|
||||||
"""
|
|
||||||
Finds a single Discord ID entry in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria.
|
|
||||||
active_only: If True, only returns active Discord IDs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A DiscordId object if found, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Finding one Discord ID document ---")
|
print("\n--- Finding one Discord ID document ---")
|
||||||
try:
|
try:
|
||||||
if active_only:
|
if active_only:
|
||||||
query["active"] = True
|
query["active"] = True
|
||||||
|
|
||||||
found_doc = None
|
found_doc = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
found_doc = await db.find_one(query)
|
found_doc = await db.find_one(query) #
|
||||||
|
|
||||||
if found_doc:
|
if found_doc:
|
||||||
print("Found Discord ID document (raw dict):", found_doc)
|
print("Found Discord ID document (raw dict):", found_doc)
|
||||||
return DiscordId.from_dict(found_doc)
|
return DiscordId.from_dict(found_doc)
|
||||||
@@ -266,37 +179,18 @@ class DiscordIdDbController():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def find_discord_ids(self, query: Dict[str, Any] = {}, guild_id: Optional[str] = None, active_only: bool = False) -> Optional[List[DiscordId]]:
|
async def find_discord_ids(self, query: Dict[str, Any] = {}, guild_id: Optional[str] = None, active_only: bool = False) -> Optional[List[DiscordId]]:
|
||||||
"""
|
|
||||||
Finds one or more Discord ID entries in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria.
|
|
||||||
guild_id: Optional. If provided, filters Discord IDs that belong to this guild.
|
|
||||||
active_only: If True, only returns active Discord IDs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of DiscordId object(s) if found, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Finding multiple Discord ID documents ---")
|
print("\n--- Finding multiple Discord ID documents ---")
|
||||||
try:
|
try:
|
||||||
# Add active filter if requested
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query["active"] = True
|
query["active"] = True
|
||||||
|
|
||||||
# Add guild_id filter if provided
|
|
||||||
if guild_id:
|
if guild_id:
|
||||||
query["guild_ids"] = {"$in": [guild_id]}
|
query["guild_ids"] = {"$in": [guild_id]}
|
||||||
|
|
||||||
found_docs = None
|
found_docs = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
found_docs = await db.find(query)
|
found_docs = await db.find(query) #
|
||||||
|
|
||||||
if found_docs:
|
if found_docs:
|
||||||
print(f"Found {len(found_docs)} Discord ID documents (raw dicts).")
|
print(f"Found {len(found_docs)} Discord ID documents (raw dicts).")
|
||||||
converted_discord_ids = []
|
converted_discord_ids = [DiscordId.from_dict(doc) for doc in found_docs]
|
||||||
for doc in found_docs:
|
|
||||||
converted_discord_ids.append(DiscordId.from_dict(doc))
|
|
||||||
|
|
||||||
return converted_discord_ids if len(converted_discord_ids) > 0 else None
|
return converted_discord_ids if len(converted_discord_ids) > 0 else None
|
||||||
else:
|
else:
|
||||||
print("Discord ID documents not found.")
|
print("Discord ID documents not found.")
|
||||||
@@ -306,22 +200,11 @@ class DiscordIdDbController():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def update_discord_id(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]:
|
async def update_discord_id(self, query: Dict[str, Any], update_data: Dict[str, Any]) -> Optional[int]:
|
||||||
"""
|
|
||||||
Updates a single Discord ID entry in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria to find the document.
|
|
||||||
update_data: A dictionary representing the update operations (e.g., using $set).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The number of modified documents if successful, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Updating a Discord ID document ---")
|
print("\n--- Updating a Discord ID document ---")
|
||||||
try:
|
try:
|
||||||
update_result = None
|
update_result = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
update_result = await db.update_one(query, update_data)
|
update_result = await db.update_one(query, update_data) #
|
||||||
|
|
||||||
print(f"Discord ID update result: Matched {update_result.matched_count}, Modified {update_result.modified_count}")
|
print(f"Discord ID update result: Matched {update_result.matched_count}, Modified {update_result.modified_count}")
|
||||||
return update_result.modified_count
|
return update_result.modified_count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -329,19 +212,11 @@ class DiscordIdDbController():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def delete_discord_id(self, query: Dict[str, Any]) -> Optional[int]:
|
async def delete_discord_id(self, query: Dict[str, Any]) -> Optional[int]:
|
||||||
"""
|
|
||||||
Deletes a single Discord ID entry from the database.
|
|
||||||
Args:
|
|
||||||
query: A dictionary representing the query criteria to find the document to delete.
|
|
||||||
Returns:
|
|
||||||
The number of deleted documents if successful, None otherwise.
|
|
||||||
"""
|
|
||||||
print("\n--- Deleting a Discord ID document ---")
|
print("\n--- Deleting a Discord ID document ---")
|
||||||
try:
|
try:
|
||||||
delete_result = None
|
delete_result = None
|
||||||
async with self.db_h as db:
|
async with self.db_h as db: #
|
||||||
delete_result = await self.db_h.delete_one(query)
|
delete_result = await db.delete_one(query) #
|
||||||
|
|
||||||
print(f"Discord ID delete result: Deleted count {delete_result.deleted_count}")
|
print(f"Discord ID delete result: Deleted count {delete_result.deleted_count}")
|
||||||
return delete_result.deleted_count
|
return delete_result.deleted_count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ def role_required(required_role: UserRoles):
|
|||||||
# Make a DB call to get the user and their role
|
# Make a DB call to get the user and their role
|
||||||
user = await current_app.user_db_h.find_user({"_id": user_id})
|
user = await current_app.user_db_h.find_user({"_id": user_id})
|
||||||
|
|
||||||
|
print("YERERERE", user)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
abort(401, "User not found or invalid token.") # User corresponding to token not found
|
abort(401, "User not found or invalid token.") # User corresponding to token not found
|
||||||
|
|
||||||
|
|||||||
@@ -8,40 +8,28 @@ from quart_cors import cors
|
|||||||
from routers.systems import systems_bp
|
from routers.systems import systems_bp
|
||||||
from routers.nodes import nodes_bp, register_client, unregister_client
|
from routers.nodes import nodes_bp, register_client, unregister_client
|
||||||
from routers.bot import bot_bp
|
from routers.bot import bot_bp
|
||||||
from routers.auth import auth_bp # ONLY import auth_bp, not jwt instance from auth.py
|
from routers.auth import auth_bp
|
||||||
from internal.db_wrappers import SystemDbController, DiscordIdDbController
|
from internal.db_wrappers import SystemDbController, DiscordIdDbController #
|
||||||
from internal.auth_wrappers import UserDbController
|
from internal.auth_wrappers import UserDbController #
|
||||||
# Import the JWTManager instance and its configuration function
|
from config.jwt_config import jwt, configure_jwt
|
||||||
from config.jwt_config import jwt, configure_jwt # Import the actual jwt instance and the config function
|
|
||||||
|
|
||||||
|
|
||||||
# --- WebSocket Server Components ---
|
# --- WebSocket Server Components ---
|
||||||
# Dictionary to store active clients: {client_id: websocket}
|
|
||||||
active_clients = {}
|
active_clients = {}
|
||||||
|
|
||||||
|
|
||||||
async def websocket_server_handler(websocket):
|
async def websocket_server_handler(websocket):
|
||||||
"""Handles incoming WebSocket connections and messages from clients."""
|
|
||||||
client_id = None
|
client_id = None
|
||||||
try:
|
try:
|
||||||
# Handshake: Receive the first message which should contain the client ID
|
|
||||||
handshake_message = await websocket.recv()
|
handshake_message = await websocket.recv()
|
||||||
handshake_data = json.loads(handshake_message)
|
handshake_data = json.loads(handshake_message)
|
||||||
|
|
||||||
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
|
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
|
||||||
client_id = handshake_data["id"]
|
client_id = handshake_data["id"]
|
||||||
await register_client(websocket, client_id)
|
await register_client(websocket, client_id)
|
||||||
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"})) # Acknowledge handshake
|
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"}))
|
||||||
|
|
||||||
# Keep the connection alive and listen for potential messages from the client
|
|
||||||
# (Though in this server-commanded model, clients might not send much)
|
|
||||||
# We primarily wait for the client to close the connection
|
|
||||||
await websocket.wait_closed()
|
await websocket.wait_closed()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
|
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
except websockets.exceptions.ConnectionClosedError:
|
except websockets.exceptions.ConnectionClosedError:
|
||||||
print(f"Client connection closed unexpectedly for {client_id}.")
|
print(f"Client connection closed unexpectedly for {client_id}.")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
@@ -56,62 +44,70 @@ async def websocket_server_handler(websocket):
|
|||||||
app = Quart(__name__)
|
app = Quart(__name__)
|
||||||
app = cors(app, allow_origin="*")
|
app = cors(app, allow_origin="*")
|
||||||
|
|
||||||
# Store the websocket server instance
|
|
||||||
websocket_server_instance = None
|
websocket_server_instance = None
|
||||||
|
|
||||||
# Make active_clients accessible via the app instance.
|
|
||||||
app.active_clients = active_clients
|
app.active_clients = active_clients
|
||||||
|
|
||||||
# Create and attach the DB wrappers
|
# Create and attach the DB wrappers
|
||||||
app.sys_db_h = SystemDbController()
|
app.sys_db_h = SystemDbController() #
|
||||||
app.d_id_db_h = DiscordIdDbController()
|
app.d_id_db_h = DiscordIdDbController() #
|
||||||
app.user_db_h = UserDbController()
|
app.user_db_h = UserDbController() #
|
||||||
|
|
||||||
# Configure JWT settings and initialize the JWTManager instance with the app
|
|
||||||
configure_jwt(app)
|
configure_jwt(app)
|
||||||
jwt.init_app(app) # Crucial: This initializes the global 'jwt' instance with your app
|
jwt.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
@app.before_serving
|
@app.before_serving
|
||||||
async def startup_websocket_server():
|
async def startup_tasks(): # Combined startup logic
|
||||||
"""Starts the WebSocket server when the Quart app starts."""
|
"""Starts the WebSocket server and prepares other resources."""
|
||||||
global websocket_server_instance
|
global websocket_server_instance
|
||||||
websocket_server_address = "0.0.0.0"
|
websocket_server_address = "0.0.0.0"
|
||||||
websocket_server_port = 8765
|
websocket_server_port = 8765
|
||||||
|
|
||||||
# Start the WebSocket server task
|
|
||||||
websocket_server_instance = await websockets.serve(
|
websocket_server_instance = await websockets.serve(
|
||||||
websocket_server_handler,
|
websocket_server_handler,
|
||||||
websocket_server_address,
|
websocket_server_address,
|
||||||
websocket_server_port
|
websocket_server_port
|
||||||
)
|
)
|
||||||
print(f"WebSocket server started on ws://{websocket_server_address}:{websocket_server_port}")
|
print(f"WebSocket server started on ws://{websocket_server_address}:{websocket_server_port}")
|
||||||
|
# Database connections are now established on first use by MongoHandler's __aenter__/connect
|
||||||
|
# No explicit connect calls needed here unless desired for early failure detection.
|
||||||
|
print("Application startup complete. DB connections will be initialized on first use.")
|
||||||
|
|
||||||
@app.after_serving
|
@app.after_serving
|
||||||
async def shutdown_websocket_server():
|
async def shutdown_tasks(): # Combined shutdown logic
|
||||||
"""Shuts down the WebSocket server when the Quart app stops."""
|
"""Shuts down services and closes connections."""
|
||||||
global websocket_server_instance
|
global websocket_server_instance
|
||||||
if websocket_server_instance:
|
if websocket_server_instance:
|
||||||
websocket_server_instance.close()
|
websocket_server_instance.close()
|
||||||
await websocket_server_instance.wait_closed()
|
await websocket_server_instance.wait_closed()
|
||||||
print("WebSocket server shut down.")
|
print("WebSocket server shut down.")
|
||||||
|
|
||||||
|
# Close database connections
|
||||||
|
if hasattr(app, 'user_db_h') and app.user_db_h:
|
||||||
|
print("Closing User DB connection...")
|
||||||
|
await app.user_db_h.close_db_connection() #
|
||||||
|
if hasattr(app, 'sys_db_h') and app.sys_db_h:
|
||||||
|
print("Closing System DB connection...")
|
||||||
|
await app.sys_db_h.close_db_connection() #
|
||||||
|
if hasattr(app, 'd_id_db_h') and app.d_id_db_h:
|
||||||
|
print("Closing Discord ID DB connection...")
|
||||||
|
await app.d_id_db_h.close_db_connection() #
|
||||||
|
print("All database connections have been signaled to close.")
|
||||||
|
|
||||||
|
|
||||||
app.register_blueprint(systems_bp, url_prefix="/systems")
|
app.register_blueprint(systems_bp, url_prefix="/systems")
|
||||||
app.register_blueprint(nodes_bp, url_prefix="/nodes")
|
app.register_blueprint(nodes_bp, url_prefix="/nodes")
|
||||||
app.register_blueprint(bot_bp, url_prefix="/bots")
|
app.register_blueprint(bot_bp, url_prefix="/bots")
|
||||||
app.register_blueprint(auth_bp, url_prefix="/auth") # Register the auth blueprint
|
app.register_blueprint(auth_bp, url_prefix="/auth")
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
async def index():
|
async def index():
|
||||||
return "Welcome to the Radio App Server API!"
|
return "Welcome to the Radio App Server API!"
|
||||||
|
|
||||||
# --- Main Execution ---
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Starting Quart API server...")
|
print("Starting Quart API server...")
|
||||||
app.run(
|
app.run(
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=5000,
|
port=5000,
|
||||||
debug=False # Set to True for development
|
debug=False
|
||||||
)
|
)
|
||||||
print("Quart API server stopped.")
|
print("Quart API server stopped.")
|
||||||
Reference in New Issue
Block a user