Implemented client nickname and access token
All checks were successful
release-image / release-image (push) Successful in 2m9s
All checks were successful
release-image / release-image (push) Successful in 2m9s
This commit is contained in:
@@ -191,16 +191,20 @@ class ActiveClient:
|
||||
"""
|
||||
websocket = None
|
||||
active_token: DiscordId = None
|
||||
nickname = None
|
||||
access_token = None
|
||||
|
||||
def __init__(self, websocket= None, active_token:DiscordId=None):
|
||||
self.active_token = active_token
|
||||
self.websocket = websocket
|
||||
|
||||
|
||||
class UserRoles(str, Enum):
|
||||
ADMIN = "admin"
|
||||
MOD = "mod"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class User:
|
||||
"""
|
||||
A data model for a User entry.
|
||||
|
||||
@@ -21,22 +21,28 @@ def role_required(required_role: UserRoles):
|
||||
current_user_identity = get_jwt_identity()
|
||||
user_id = current_user_identity['id']
|
||||
|
||||
# Make a DB call to get the user and their role
|
||||
user = await current_app.user_db_h.find_user({"_id": user_id})
|
||||
auth_type = current_user_identity['type']
|
||||
|
||||
print("YERERERE", user)
|
||||
if auth_type == "node":
|
||||
node = app.active_clients.get("client_id")
|
||||
|
||||
if not user:
|
||||
abort(401, "User not found or invalid token.") # User corresponding to token not found
|
||||
if not node:
|
||||
abort(401, "Node not found or invalid token.")
|
||||
|
||||
user_role = user.role # Get the role from the fetched user object
|
||||
if auth_type == "user":
|
||||
# Make a DB call to get the user and their role
|
||||
user = await current_app.user_db_h.find_user({"_id": user_id})
|
||||
|
||||
role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2}
|
||||
if not user:
|
||||
abort(401, "User not found or invalid token.") # User corresponding to token not found
|
||||
|
||||
if role_order[user_role] < role_order[required_role]:
|
||||
abort(403, "Permission denied: Insufficient role.")
|
||||
user_role = user.role # Get the role from the fetched user object
|
||||
|
||||
role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2}
|
||||
|
||||
if role_order[user_role] < role_order[required_role]:
|
||||
abort(403, "Permission denied: Insufficient role.")
|
||||
|
||||
# REMOVE current_app.ensure_sync() here
|
||||
return await fn(*args, **kwargs) # Directly await the original async function
|
||||
|
||||
return decorated_view
|
||||
@@ -90,9 +96,10 @@ async def login_user():
|
||||
if not user or not check_password_hash(user.password_hash, password):
|
||||
abort(401, "Invalid credentials")
|
||||
|
||||
access_token = create_access_token(identity={"id": user._id, "username": user.username})
|
||||
access_token = create_access_token(identity={"id": user._id, "username": user.username, "type": "user"})
|
||||
return jsonify({"access_token": access_token, "role": user.role, "username": user.username, "user_id": user._id }), 200
|
||||
|
||||
# DEPRECATED
|
||||
@auth_bp.route('/generate_api_key', methods=['POST'])
|
||||
@jwt_required
|
||||
async def generate_api_key():
|
||||
|
||||
@@ -6,6 +6,7 @@ from quart import Blueprint, jsonify, request, abort, current_app
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from enum import Enum
|
||||
from internal.types import ActiveClient, NodeCommands, UserRoles
|
||||
from quart_jwt_extended import create_access_token
|
||||
from quart_jwt_extended import jwt_required
|
||||
from routers.auth import role_required
|
||||
|
||||
@@ -15,11 +16,16 @@ nodes_bp = Blueprint('nodes', __name__)
|
||||
pending_requests = {}
|
||||
|
||||
|
||||
async def register_client(websocket, client_id):
|
||||
async def register_client(websocket, client_id, nickname):
|
||||
"""Registers a new client connection."""
|
||||
current_app.active_clients[client_id] = ActiveClient(websocket)
|
||||
current_app.active_clients[client_id] = ActiveClient()
|
||||
current_app.active_clients[client_id].websocket = websocket
|
||||
current_app.active_clients[client_id].nickname = nickname
|
||||
print(f"Client {client_id} connected.")
|
||||
|
||||
# Create a JWT for the client
|
||||
current_app.active_clients[client_id].access_token = create_access_token(identity={"id": client_id, "username": nickname, "type": "node"})
|
||||
|
||||
# Start a task to listen for messages from this client
|
||||
asyncio.create_task(listen_to_client(websocket, client_id))
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import uuid
|
||||
from quart import Quart, jsonify, request
|
||||
from quart import Quart, jsonify, request, abort
|
||||
from quart_cors import cors
|
||||
from routers.systems import systems_bp
|
||||
from routers.nodes import nodes_bp, register_client, unregister_client
|
||||
@@ -17,28 +17,6 @@ from config.jwt_config import jwt, configure_jwt
|
||||
# --- WebSocket Server Components ---
|
||||
active_clients = {}
|
||||
|
||||
async def websocket_server_handler(websocket):
|
||||
client_id = None
|
||||
try:
|
||||
handshake_message = await websocket.recv()
|
||||
handshake_data = json.loads(handshake_message)
|
||||
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
|
||||
client_id = handshake_data["id"]
|
||||
await register_client(websocket, client_id)
|
||||
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"}))
|
||||
await websocket.wait_closed()
|
||||
else:
|
||||
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
|
||||
await websocket.close()
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
print(f"Client connection closed unexpectedly for {client_id}.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Received invalid JSON from {client_id or 'an unknown client'}.")
|
||||
except Exception as e:
|
||||
print(f"An error occurred with client {client_id}: {e}")
|
||||
finally:
|
||||
if client_id:
|
||||
await unregister_client(client_id)
|
||||
|
||||
# --- Quart API Components ---
|
||||
app = Quart(__name__)
|
||||
@@ -56,6 +34,33 @@ configure_jwt(app)
|
||||
jwt.init_app(app)
|
||||
|
||||
|
||||
async def websocket_server_handler(websocket):
|
||||
client_id = None
|
||||
try:
|
||||
handshake_message = await websocket.recv()
|
||||
handshake_data = json.loads(handshake_message)
|
||||
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
|
||||
client_id = handshake_data["id"]
|
||||
client_nickname = handshake_data.get("nickname")
|
||||
await register_client(websocket, client_id, nickname)
|
||||
if not app.active_clients[client_id].access_token:
|
||||
abort(500, "Error retrieving access token")
|
||||
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success", "access_token": app.active_clients[client_id].access_token}))
|
||||
await websocket.wait_closed()
|
||||
else:
|
||||
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
|
||||
await websocket.close()
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
print(f"Client connection closed unexpectedly for {client_id}.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Received invalid JSON from {client_id or 'an unknown client'}.")
|
||||
except Exception as e:
|
||||
print(f"An error occurred with client {client_id}: {e}")
|
||||
finally:
|
||||
if client_id:
|
||||
await unregister_client(client_id)
|
||||
|
||||
|
||||
@app.before_serving
|
||||
async def startup_tasks(): # Combined startup logic
|
||||
"""Starts the WebSocket server and prepares other resources."""
|
||||
|
||||
Reference in New Issue
Block a user