From 6f64a8390ad1bc18915fd6530469c1a20f1d51a3 Mon Sep 17 00:00:00 2001 From: Logan Cusano Date: Sat, 7 Jun 2025 23:08:24 -0400 Subject: [PATCH] Implemented client nickname and access token --- app/internal/types.py | 4 ++++ app/routers/auth.py | 29 ++++++++++++++---------- app/routers/nodes.py | 10 +++++++-- app/server.py | 51 ++++++++++++++++++++++++------------------- 4 files changed, 58 insertions(+), 36 deletions(-) diff --git a/app/internal/types.py b/app/internal/types.py index c1a3faa..9356ab8 100644 --- a/app/internal/types.py +++ b/app/internal/types.py @@ -191,16 +191,20 @@ class ActiveClient: """ websocket = None active_token: DiscordId = None + nickname = None + access_token = None def __init__(self, websocket= None, active_token:DiscordId=None): self.active_token = active_token self.websocket = websocket + class UserRoles(str, Enum): ADMIN = "admin" MOD = "mod" USER = "user" + class User: """ A data model for a User entry. diff --git a/app/routers/auth.py b/app/routers/auth.py index 684a7e0..f6c0557 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -21,22 +21,28 @@ def role_required(required_role: UserRoles): current_user_identity = get_jwt_identity() user_id = current_user_identity['id'] - # Make a DB call to get the user and their role - user = await current_app.user_db_h.find_user({"_id": user_id}) + auth_type = current_user_identity['type'] - print("YERERERE", user) + if auth_type == "node": + node = app.active_clients.get("client_id") - if not user: - abort(401, "User not found or invalid token.") # User corresponding to token not found + if not node: + abort(401, "Node not found or invalid token.") - user_role = user.role # Get the role from the fetched user object + if auth_type == "user": + # Make a DB call to get the user and their role + user = await current_app.user_db_h.find_user({"_id": user_id}) - role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2} + if not user: + abort(401, "User not found or invalid token.") # User corresponding to token not found - if role_order[user_role] < role_order[required_role]: - abort(403, "Permission denied: Insufficient role.") + user_role = user.role # Get the role from the fetched user object + + role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2} + + if role_order[user_role] < role_order[required_role]: + abort(403, "Permission denied: Insufficient role.") - # REMOVE current_app.ensure_sync() here return await fn(*args, **kwargs) # Directly await the original async function return decorated_view @@ -90,9 +96,10 @@ async def login_user(): if not user or not check_password_hash(user.password_hash, password): abort(401, "Invalid credentials") - access_token = create_access_token(identity={"id": user._id, "username": user.username}) + access_token = create_access_token(identity={"id": user._id, "username": user.username, "type": "user"}) return jsonify({"access_token": access_token, "role": user.role, "username": user.username, "user_id": user._id }), 200 +# DEPRECATED @auth_bp.route('/generate_api_key', methods=['POST']) @jwt_required async def generate_api_key(): diff --git a/app/routers/nodes.py b/app/routers/nodes.py index 0d506f5..e22afc0 100644 --- a/app/routers/nodes.py +++ b/app/routers/nodes.py @@ -6,6 +6,7 @@ from quart import Blueprint, jsonify, request, abort, current_app from werkzeug.exceptions import HTTPException from enum import Enum from internal.types import ActiveClient, NodeCommands, UserRoles +from quart_jwt_extended import create_access_token from quart_jwt_extended import jwt_required from routers.auth import role_required @@ -15,11 +16,16 @@ nodes_bp = Blueprint('nodes', __name__) pending_requests = {} -async def register_client(websocket, client_id): +async def register_client(websocket, client_id, nickname): """Registers a new client connection.""" - current_app.active_clients[client_id] = ActiveClient(websocket) + current_app.active_clients[client_id] = ActiveClient() + current_app.active_clients[client_id].websocket = websocket + current_app.active_clients[client_id].nickname = nickname print(f"Client {client_id} connected.") + # Create a JWT for the client + current_app.active_clients[client_id].access_token = create_access_token(identity={"id": client_id, "username": nickname, "type": "node"}) + # Start a task to listen for messages from this client asyncio.create_task(listen_to_client(websocket, client_id)) diff --git a/app/server.py b/app/server.py index 176bd4a..f0924cc 100644 --- a/app/server.py +++ b/app/server.py @@ -3,7 +3,7 @@ import asyncio import websockets import json import uuid -from quart import Quart, jsonify, request +from quart import Quart, jsonify, request, abort from quart_cors import cors from routers.systems import systems_bp from routers.nodes import nodes_bp, register_client, unregister_client @@ -17,28 +17,6 @@ from config.jwt_config import jwt, configure_jwt # --- WebSocket Server Components --- active_clients = {} -async def websocket_server_handler(websocket): - client_id = None - try: - handshake_message = await websocket.recv() - handshake_data = json.loads(handshake_message) - if handshake_data.get("type") == "handshake" and "id" in handshake_data: - client_id = handshake_data["id"] - await register_client(websocket, client_id) - await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"})) - await websocket.wait_closed() - else: - print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.") - await websocket.close() - except websockets.exceptions.ConnectionClosedError: - print(f"Client connection closed unexpectedly for {client_id}.") - except json.JSONDecodeError: - print(f"Received invalid JSON from {client_id or 'an unknown client'}.") - except Exception as e: - print(f"An error occurred with client {client_id}: {e}") - finally: - if client_id: - await unregister_client(client_id) # --- Quart API Components --- app = Quart(__name__) @@ -56,6 +34,33 @@ configure_jwt(app) jwt.init_app(app) +async def websocket_server_handler(websocket): + client_id = None + try: + handshake_message = await websocket.recv() + handshake_data = json.loads(handshake_message) + if handshake_data.get("type") == "handshake" and "id" in handshake_data: + client_id = handshake_data["id"] + client_nickname = handshake_data.get("nickname") + await register_client(websocket, client_id, nickname) + if not app.active_clients[client_id].access_token: + abort(500, "Error retrieving access token") + await websocket.send(json.dumps({"type": "handshake_ack", "status": "success", "access_token": app.active_clients[client_id].access_token})) + await websocket.wait_closed() + else: + print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.") + await websocket.close() + except websockets.exceptions.ConnectionClosedError: + print(f"Client connection closed unexpectedly for {client_id}.") + except json.JSONDecodeError: + print(f"Received invalid JSON from {client_id or 'an unknown client'}.") + except Exception as e: + print(f"An error occurred with client {client_id}: {e}") + finally: + if client_id: + await unregister_client(client_id) + + @app.before_serving async def startup_tasks(): # Combined startup logic """Starts the WebSocket server and prepares other resources."""