80 lines
2.8 KiB
Python
80 lines
2.8 KiB
Python
from django.contrib.auth.models import AnonymousUser
|
|
from django.db import close_old_connections
|
|
|
|
from channels.db import database_sync_to_async
|
|
from channels.middleware import BaseMiddleware
|
|
from channels.auth import AuthMiddlewareStack
|
|
from rest_framework_simplejwt.tokens import UntypedToken
|
|
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
|
|
from jwt import decode as jwt_decode
|
|
from django.conf import settings
|
|
from urllib.parse import parse_qs
|
|
from django.contrib.auth import get_user_model
|
|
|
|
|
|
# User model class returned by Django's get_user_model(), resolved once when
# this module is imported and used by the lookups below.
# NOTE(review): get_user_model() needs the app registry to be ready; importing
# this module before Django is set up (e.g. too early in asgi.py) would fail —
# confirm import order.
User = get_user_model()
|
|
|
|
|
|
@database_sync_to_async
def get_user_from_token(token):
    """Return the user a JSON Web Token belongs to, or ``AnonymousUser``.

    The token is validated with ``rest_framework_simplejwt``'s
    ``UntypedToken``, which raises ``TokenError`` when the token is not valid.
    The ``user_id`` claim of the validated payload is then used to load the
    user.

    Args:
        token: The raw encoded JWT string.

    Returns:
        The matching user instance, or ``AnonymousUser()`` when the token is
        rejected, carries no truthy ``user_id`` claim, or refers to a user
        that does not exist.
    """
    try:
        validated_token = UntypedToken(token)
    except (InvalidToken, TokenError):
        return AnonymousUser()

    # Reuse the payload UntypedToken just verified. Decoding the token a
    # second time with a hard-coded SECRET_KEY / HS256 raised uncaught jwt
    # exceptions whenever SIMPLE_JWT used a different key or algorithm.
    user_id = validated_token.payload.get('user_id')
    if not user_id:
        return AnonymousUser()

    # NOTE(review): UntypedToken does not check the token type, so refresh
    # tokens are accepted here too; use AccessToken if only access tokens
    # should authenticate a connection.
    try:
        return User.objects.get(id=user_id)
    except User.DoesNotExist:
        return AnonymousUser()
|
|
|
|
|
|
@database_sync_to_async
def get_user(scope):
    """Return the user logged in through the Django session in *scope*.

    The user id stored in the session must resolve to an existing user. The
    session's auth hash must also still match the user's current hash, which
    is the check Django applies to HTTP requests. As a result, sessions
    invalidated by a password change are refused.

    Args:
        scope: ASGI connection scope. A ``"session"`` entry is expected to
            have been added by an upstream session middleware.

    Returns:
        The matching user instance, or ``AnonymousUser()`` when there is no
        session, no logged-in user id, no such user, or a stale session hash.
    """
    from django.contrib.auth import HASH_SESSION_KEY, SESSION_KEY
    from django.utils.crypto import constant_time_compare

    # Compare with None rather than testing truthiness so a lazy session
    # object is not forced to load here.
    session = scope.get("session")
    if session is None:
        return AnonymousUser()

    user_id = session.get(SESSION_KEY)
    if not user_id:
        return AnonymousUser()

    try:
        user = User.objects.get(pk=user_id)
    except User.DoesNotExist:
        return AnonymousUser()

    # A password change rotates the user's session auth hash; sessions that
    # still carry the old hash must no longer authenticate.
    # NOTE(review): unlike django.contrib.auth.get_user, this does not consult
    # AUTHENTICATION_BACKENDS or SECRET_KEY_FALLBACKS.
    if hasattr(user, "get_session_auth_hash"):
        session_hash = session.get(HASH_SESSION_KEY)
        if not session_hash or not constant_time_compare(
            session_hash, user.get_session_auth_hash()
        ):
            return AnonymousUser()

    return user
|
|
|
|
|
|
class JwtOrSessionAuthMiddleware(BaseMiddleware):
    """Populate ``scope['user']`` from the Django session or from a JWT.

    Authentication order:

    1. The Django session (``scope['session']``), when one is present.
    2. A JWT taken from the ``token`` query-string parameter or, failing
       that, from an ``Authorization: Bearer <token>`` header.

    If neither yields a user, ``scope['user']`` is ``AnonymousUser()``.
    """

    async def __call__(self, scope, receive, send):
        # Copy before mutating so setting "user" does not leak back to the
        # caller's scope (ASGI middleware must not modify the scope it got).
        scope = dict(scope)

        # Close obsolete DB connections to avoid stale-connection problems.
        # NOTE(review): database_sync_to_async also does this around each
        # call; this may be redundant — confirm before removing.
        close_old_connections()

        scope['user'] = AnonymousUser()

        # Session authentication takes precedence.
        if "session" in scope:
            scope['user'] = await get_user(scope)

        # Fall back to JWT only when the session did not identify anyone.
        if scope['user'].is_anonymous:
            token = self._extract_token(scope)
            if token:
                scope['user'] = await get_user_from_token(token)

        return await super().__call__(scope, receive, send)

    @staticmethod
    def _extract_token(scope):
        """Return the raw JWT from the query string or Authorization header.

        The ``token`` query parameter wins over the header. Returns ``None``
        when neither source provides a token.
        """
        # errors="replace" keeps a malformed (non-UTF-8) query string from
        # crashing the connection with UnicodeDecodeError.
        query_string = scope.get('query_string', b'')
        query_params = parse_qs(query_string.decode('utf-8', errors='replace'))
        token = query_params.get('token', [None])[0]
        if token:
            return token

        # Header names arrive lowercased as bytes; duplicate headers resolve
        # to the last value.
        headers = dict(scope.get('headers', []))
        auth_header = headers.get(b'authorization', b'')
        if auth_header.startswith(b'Bearer '):
            return auth_header[len(b'Bearer '):].decode('utf-8', errors='replace')
        return None
|