Implemented client nickname and access token
All checks were successful
release-image / release-image (push) Successful in 2m9s
All checks were successful
release-image / release-image (push) Successful in 2m9s
This commit is contained in:
@@ -191,16 +191,20 @@ class ActiveClient:
|
|||||||
"""
|
"""
|
||||||
websocket = None
|
websocket = None
|
||||||
active_token: DiscordId = None
|
active_token: DiscordId = None
|
||||||
|
nickname = None
|
||||||
|
access_token = None
|
||||||
|
|
||||||
def __init__(self, websocket= None, active_token:DiscordId=None):
|
def __init__(self, websocket= None, active_token:DiscordId=None):
|
||||||
self.active_token = active_token
|
self.active_token = active_token
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
|
|
||||||
|
|
||||||
class UserRoles(str, Enum):
|
class UserRoles(str, Enum):
|
||||||
ADMIN = "admin"
|
ADMIN = "admin"
|
||||||
MOD = "mod"
|
MOD = "mod"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
||||||
|
|
||||||
class User:
|
class User:
|
||||||
"""
|
"""
|
||||||
A data model for a User entry.
|
A data model for a User entry.
|
||||||
|
|||||||
@@ -21,22 +21,28 @@ def role_required(required_role: UserRoles):
|
|||||||
current_user_identity = get_jwt_identity()
|
current_user_identity = get_jwt_identity()
|
||||||
user_id = current_user_identity['id']
|
user_id = current_user_identity['id']
|
||||||
|
|
||||||
# Make a DB call to get the user and their role
|
auth_type = current_user_identity['type']
|
||||||
user = await current_app.user_db_h.find_user({"_id": user_id})
|
|
||||||
|
|
||||||
print("YERERERE", user)
|
if auth_type == "node":
|
||||||
|
node = app.active_clients.get("client_id")
|
||||||
|
|
||||||
if not user:
|
if not node:
|
||||||
abort(401, "User not found or invalid token.") # User corresponding to token not found
|
abort(401, "Node not found or invalid token.")
|
||||||
|
|
||||||
user_role = user.role # Get the role from the fetched user object
|
if auth_type == "user":
|
||||||
|
# Make a DB call to get the user and their role
|
||||||
|
user = await current_app.user_db_h.find_user({"_id": user_id})
|
||||||
|
|
||||||
role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2}
|
if not user:
|
||||||
|
abort(401, "User not found or invalid token.") # User corresponding to token not found
|
||||||
|
|
||||||
if role_order[user_role] < role_order[required_role]:
|
user_role = user.role # Get the role from the fetched user object
|
||||||
abort(403, "Permission denied: Insufficient role.")
|
|
||||||
|
role_order = {UserRoles.USER: 0, UserRoles.MOD: 1, UserRoles.ADMIN: 2}
|
||||||
|
|
||||||
|
if role_order[user_role] < role_order[required_role]:
|
||||||
|
abort(403, "Permission denied: Insufficient role.")
|
||||||
|
|
||||||
# REMOVE current_app.ensure_sync() here
|
|
||||||
return await fn(*args, **kwargs) # Directly await the original async function
|
return await fn(*args, **kwargs) # Directly await the original async function
|
||||||
|
|
||||||
return decorated_view
|
return decorated_view
|
||||||
@@ -90,9 +96,10 @@ async def login_user():
|
|||||||
if not user or not check_password_hash(user.password_hash, password):
|
if not user or not check_password_hash(user.password_hash, password):
|
||||||
abort(401, "Invalid credentials")
|
abort(401, "Invalid credentials")
|
||||||
|
|
||||||
access_token = create_access_token(identity={"id": user._id, "username": user.username})
|
access_token = create_access_token(identity={"id": user._id, "username": user.username, "type": "user"})
|
||||||
return jsonify({"access_token": access_token, "role": user.role, "username": user.username, "user_id": user._id }), 200
|
return jsonify({"access_token": access_token, "role": user.role, "username": user.username, "user_id": user._id }), 200
|
||||||
|
|
||||||
|
# DEPRECATED
|
||||||
@auth_bp.route('/generate_api_key', methods=['POST'])
|
@auth_bp.route('/generate_api_key', methods=['POST'])
|
||||||
@jwt_required
|
@jwt_required
|
||||||
async def generate_api_key():
|
async def generate_api_key():
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from quart import Blueprint, jsonify, request, abort, current_app
|
|||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from internal.types import ActiveClient, NodeCommands, UserRoles
|
from internal.types import ActiveClient, NodeCommands, UserRoles
|
||||||
|
from quart_jwt_extended import create_access_token
|
||||||
from quart_jwt_extended import jwt_required
|
from quart_jwt_extended import jwt_required
|
||||||
from routers.auth import role_required
|
from routers.auth import role_required
|
||||||
|
|
||||||
@@ -15,11 +16,16 @@ nodes_bp = Blueprint('nodes', __name__)
|
|||||||
pending_requests = {}
|
pending_requests = {}
|
||||||
|
|
||||||
|
|
||||||
async def register_client(websocket, client_id):
|
async def register_client(websocket, client_id, nickname):
|
||||||
"""Registers a new client connection."""
|
"""Registers a new client connection."""
|
||||||
current_app.active_clients[client_id] = ActiveClient(websocket)
|
current_app.active_clients[client_id] = ActiveClient()
|
||||||
|
current_app.active_clients[client_id].websocket = websocket
|
||||||
|
current_app.active_clients[client_id].nickname = nickname
|
||||||
print(f"Client {client_id} connected.")
|
print(f"Client {client_id} connected.")
|
||||||
|
|
||||||
|
# Create a JWT for the client
|
||||||
|
current_app.active_clients[client_id].access_token = create_access_token(identity={"id": client_id, "username": nickname, "type": "node"})
|
||||||
|
|
||||||
# Start a task to listen for messages from this client
|
# Start a task to listen for messages from this client
|
||||||
asyncio.create_task(listen_to_client(websocket, client_id))
|
asyncio.create_task(listen_to_client(websocket, client_id))
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
import websockets
|
import websockets
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from quart import Quart, jsonify, request
|
from quart import Quart, jsonify, request, abort
|
||||||
from quart_cors import cors
|
from quart_cors import cors
|
||||||
from routers.systems import systems_bp
|
from routers.systems import systems_bp
|
||||||
from routers.nodes import nodes_bp, register_client, unregister_client
|
from routers.nodes import nodes_bp, register_client, unregister_client
|
||||||
@@ -17,28 +17,6 @@ from config.jwt_config import jwt, configure_jwt
|
|||||||
# --- WebSocket Server Components ---
|
# --- WebSocket Server Components ---
|
||||||
active_clients = {}
|
active_clients = {}
|
||||||
|
|
||||||
async def websocket_server_handler(websocket):
|
|
||||||
client_id = None
|
|
||||||
try:
|
|
||||||
handshake_message = await websocket.recv()
|
|
||||||
handshake_data = json.loads(handshake_message)
|
|
||||||
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
|
|
||||||
client_id = handshake_data["id"]
|
|
||||||
await register_client(websocket, client_id)
|
|
||||||
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success"}))
|
|
||||||
await websocket.wait_closed()
|
|
||||||
else:
|
|
||||||
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
|
|
||||||
await websocket.close()
|
|
||||||
except websockets.exceptions.ConnectionClosedError:
|
|
||||||
print(f"Client connection closed unexpectedly for {client_id}.")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Received invalid JSON from {client_id or 'an unknown client'}.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred with client {client_id}: {e}")
|
|
||||||
finally:
|
|
||||||
if client_id:
|
|
||||||
await unregister_client(client_id)
|
|
||||||
|
|
||||||
# --- Quart API Components ---
|
# --- Quart API Components ---
|
||||||
app = Quart(__name__)
|
app = Quart(__name__)
|
||||||
@@ -56,6 +34,33 @@ configure_jwt(app)
|
|||||||
jwt.init_app(app)
|
jwt.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_server_handler(websocket):
|
||||||
|
client_id = None
|
||||||
|
try:
|
||||||
|
handshake_message = await websocket.recv()
|
||||||
|
handshake_data = json.loads(handshake_message)
|
||||||
|
if handshake_data.get("type") == "handshake" and "id" in handshake_data:
|
||||||
|
client_id = handshake_data["id"]
|
||||||
|
client_nickname = handshake_data.get("nickname")
|
||||||
|
await register_client(websocket, client_id, nickname)
|
||||||
|
if not app.active_clients[client_id].access_token:
|
||||||
|
abort(500, "Error retrieving access token")
|
||||||
|
await websocket.send(json.dumps({"type": "handshake_ack", "status": "success", "access_token": app.active_clients[client_id].access_token}))
|
||||||
|
await websocket.wait_closed()
|
||||||
|
else:
|
||||||
|
print(f"Received invalid handshake from {websocket.remote_address}. Closing connection.")
|
||||||
|
await websocket.close()
|
||||||
|
except websockets.exceptions.ConnectionClosedError:
|
||||||
|
print(f"Client connection closed unexpectedly for {client_id}.")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"Received invalid JSON from {client_id or 'an unknown client'}.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred with client {client_id}: {e}")
|
||||||
|
finally:
|
||||||
|
if client_id:
|
||||||
|
await unregister_client(client_id)
|
||||||
|
|
||||||
|
|
||||||
@app.before_serving
|
@app.before_serving
|
||||||
async def startup_tasks(): # Combined startup logic
|
async def startup_tasks(): # Combined startup logic
|
||||||
"""Starts the WebSocket server and prepares other resources."""
|
"""Starts the WebSocket server and prepares other resources."""
|
||||||
|
|||||||
Reference in New Issue
Block a user