Implemented client nickname and access token
All checks were successful
release-image / release-image (push) Successful in 2m9s

This commit is contained in:
Logan Cusano
2025-06-07 23:08:24 -04:00
parent 1575d466f2
commit 6f64a8390a
4 changed files with 58 additions and 36 deletions

View File

@@ -191,16 +191,20 @@ class ActiveClient:
""" """
websocket = None websocket = None
active_token: DiscordId = None active_token: DiscordId = None
nickname = None
access_token = None
def __init__(self, websocket= None, active_token:DiscordId=None): def __init__(self, websocket= None, active_token:DiscordId=None):
self.active_token = active_token self.active_token = active_token
self.websocket = websocket self.websocket = websocket
class UserRoles(str, Enum): class UserRoles(str, Enum):
ADMIN = "admin" ADMIN = "admin"
MOD = "mod" MOD = "mod"
USER = "user" USER = "user"
class User: class User:
""" """
A data model for a User entry. A data model for a User entry.

View File

@@ -21,22 +21,28 @@ def role_required(required_role: UserRoles):
current_user_identity = get_jwt_identity() current_user_identity = get_jwt_identity()
user_id = current_user_identity['id'] user_id = current_user_identity['id']
# Make a DB call to get the user and their role auth_type = current_user_identity['type']
user = await current_app.user_db_h.find_user({"_id": user_id})
print("YERERERE", user) if auth_type == "node":
node = app.active_clients.get("client_id")
if not user: if not node:
abort(401, "User not found or invalid token.") # User corresponding to token not found abort(401, "Node not found or invalid token.")
user_role = user.role # Get the role from the fetched user object if auth_type == "user":
# Make a DB call to get the user and their role
user = await current_app.user_db_h.find_user({"_id": user_id})
role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2} if not user:
abort(401, "User not found or invalid token.") # User corresponding to token not found
if role_order[user_role] < role_order[required_role]: user_role = user.role # Get the role from the fetched user object
abort(403, "Permission denied: Insufficient role.")
role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2}
if role_order[user_role] < role_order[required_role]:
abort(403, "Permission denied: Insufficient role.")
# REMOVE current_app.ensure_sync() here
return await fn(*args, **kwargs) # Directly await the original async function return await fn(*args, **kwargs) # Directly await the original async function
return decorated_view return decorated_view
@@ -90,9 +96,10 @@ async def login_user():
if not user or not check_password_hash(user.password_hash, password): if not user or not check_password_hash(user.password_hash, password):
abort(401, "Invalid credentials") abort(401, "Invalid credentials")
access_token = create_access_token(identity={"id": user._id, "username": user.username}) access_token = create_access_token(identity={"id": user._id, "username": user.username, "type": "user"})
return jsonify({"access_token": access_token, "role": user.role, "username": user.username, "user_id": user._id }), 200 return jsonify({"access_token": access_token, "role": user.role, "username": user.username, "user_id": user._id }), 200
# DEPRECATED
@auth_bp.route('/generate_api_key', methods=['POST']) @auth_bp.route('/generate_api_key', methods=['POST'])
@jwt_required @jwt_required
async def generate_api_key(): async def generate_api_key():

View File

@@ -6,6 +6,7 @@ from quart import Blueprint, jsonify, request, abort, current_app
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from enum import Enum from enum import Enum
from internal.types import ActiveClient, NodeCommands, UserRoles from internal.types import ActiveClient, NodeCommands, UserRoles
from quart_jwt_extended import create_access_token
from quart_jwt_extended import jwt_required from quart_jwt_extended import jwt_required
from routers.auth import role_required from routers.auth import role_required
@@ -15,11 +16,16 @@ nodes_bp = Blueprint('nodes', __name__)
pending_requests = {} pending_requests = {}
async def register_client(websocket, client_id): async def register_client(websocket, client_id, nickname):
"""Registers a new client connection.""" """Registers a new client connection."""
current_app.active_clients[client_id] = ActiveClient(websocket) current_app.active_clients[client_id] = ActiveClient()
current_app.active_clients[client_id].websocket = websocket
current_app.active_clients[client_id].nickname = nickname
print(f"Client {client_id} connected.") print(f"Client {client_id} connected.")
# Create a JWT for the client
current_app.active_clients[client_id].access_token = create_access_token(identity={"id": client_id, "username": nickname, "type": "node"})
# Start a task to listen for messages from this client # Start a task to listen for messages from this client
asyncio.create_task(listen_to_client(websocket, client_id)) asyncio.create_task(listen_to_client(websocket, client_id))

View File

@@ -3,7 +3,7 @@ import asyncio
import websockets import websockets
import json import json
import uuid import uuid
from quart import Quart, jsonify, request from quart import Quart, jsonify, request, abort
from quart_cors import cors from quart_cors import cors
from routers.systems import systems_bp from routers.systems import systems_bp
from routers.nodes import nodes_bp, register_client, unregister_client from routers.nodes import nodes_bp, register_client, unregister_client
@@ -17,28 +17,6 @@ from config.jwt_config import jwt, configure_jwt
# --- WebSocket Server Components --- # --- WebSocket Server Components ---
active_clients = {} active_clients = {}
async def websocket_server_handler(websocket):
client_id = None
try:
handshake_message = await websocket.recv()
handshake_data = json.loads(handshake_message)
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
client_id = handshake_data["id"]
await register_client(websocket, client_id)
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"}))
await websocket.wait_closed()
else:
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
await websocket.close()
except websockets.exceptions.ConnectionClosedError:
print(f"Client connection closed unexpectedly for {client_id}.")
except json.JSONDecodeError:
print(f"Received invalid JSON from {client_id or 'an unknown client'}.")
except Exception as e:
print(f"An error occurred with client {client_id}: {e}")
finally:
if client_id:
await unregister_client(client_id)
# --- Quart API Components --- # --- Quart API Components ---
app = Quart(__name__) app = Quart(__name__)
@@ -56,6 +34,33 @@ configure_jwt(app)
jwt.init_app(app) jwt.init_app(app)
async def websocket_server_handler(websocket):
client_id = None
try:
handshake_message = await websocket.recv()
handshake_data = json.loads(handshake_message)
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
client_id = handshake_data["id"]
client_nickname = handshake_data.get("nickname")
await register_client(websocket, client_id, nickname)
if not app.active_clients[client_id].access_token:
abort(500, "Error retrieving access token")
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success", "access_token": app.active_clients[client_id].access_token}))
await websocket.wait_closed()
else:
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
await websocket.close()
except websockets.exceptions.ConnectionClosedError:
print(f"Client connection closed unexpectedly for {client_id}.")
except json.JSONDecodeError:
print(f"Received invalid JSON from {client_id or 'an unknown client'}.")
except Exception as e:
print(f"An error occurred with client {client_id}: {e}")
finally:
if client_id:
await unregister_client(client_id)
@app.before_serving @app.before_serving
async def startup_tasks(): # Combined startup logic async def startup_tasks(): # Combined startup logic
"""Starts the WebSocket server and prepares other resources.""" """Starts the WebSocket server and prepares other resources."""