diff --git a/label_studio/core/middleware.py b/label_studio/core/middleware.py index dca9182488eb..d988582d8868 100644 --- a/label_studio/core/middleware.py +++ b/label_studio/core/middleware.py @@ -205,6 +205,7 @@ def process_request(self, request) -> None: or # scim assign request.user implicitly, check CustomSCIMAuthCheckMiddleware (hasattr(request, 'is_scim') and request.is_scim) + or (hasattr(request, 'is_jwt') and request.is_jwt) ): return @@ -248,3 +249,4 @@ def process_response(self, request, response): del response['Content-Security-Policy-Report-Only'] delattr(response, '_override_report_only_csp') return response + diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 41ef0f8eeb41..96383d81724e 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -214,6 +214,7 @@ 'annoying', 'rest_framework', 'rest_framework.authtoken', + 'rest_framework_simplejwt.token_blacklist', 'drf_generators', 'core', 'users', @@ -229,6 +230,7 @@ 'labels_manager', 'ml_models', 'ml_model_providers', + 'jwt_auth', ] MIDDLEWARE = [ @@ -247,12 +249,13 @@ 'core.middleware.ContextLogMiddleware', 'core.middleware.DatabaseIsLockedRetryMiddleware', 'core.current_request.ThreadLocalMiddleware', + 'jwt_auth.middleware.JWTAuthenticationMiddleware', ] REST_FRAMEWORK = { 'DEFAULT_FILTER_BACKENDS': ['django_filters.rest_framework.DjangoFilterBackend'], 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework.authentication.TokenAuthentication', + 'jwt_auth.auth.TokenAuthenticationPhaseout', 'rest_framework.authentication.SessionAuthentication', ), 'DEFAULT_PERMISSION_CLASSES': [ diff --git a/label_studio/jwt_auth/apps.py b/label_studio/jwt_auth/apps.py new file mode 100644 index 000000000000..d63fb5be34bc --- /dev/null +++ b/label_studio/jwt_auth/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class JWTAuthConfig(AppConfig): + name = 'jwt_auth' diff --git a/label_studio/jwt_auth/auth.py b/label_studio/jwt_auth/auth.py new file mode 100644 index 000000000000..ea0842b4fbd8 --- /dev/null +++ b/label_studio/jwt_auth/auth.py @@ -0,0 +1,33 @@ +import logging + +from rest_framework.authentication import TokenAuthentication +from rest_framework.exceptions import AuthenticationFailed + +logger = logging.getLogger(__name__) + + +class TokenAuthenticationPhaseout(TokenAuthentication): + """TokenAuthentication that logs usage to help track basic token authentication usage.""" + + def authenticate(self, request): + """Authenticate the request and log if successful.""" + from core.feature_flags import flag_set + + auth_result = super().authenticate(request) + JWT_ACCESS_TOKEN_ENABLED = flag_set('fflag__feature_develop__prompts__dia_1829_jwt_token_auth') + if JWT_ACCESS_TOKEN_ENABLED and (auth_result is not None): + user, _ = auth_result + org = user.active_organization + org_id = org.id if org else None + + # raise 401 if legacy API token auth disabled (i.e. this token is no longer valid) + if org and (not org.jwt.legacy_api_tokens_enabled): + raise AuthenticationFailed( + 'Authentication token no longer valid: JWT authentication is required for this organization' + ) + + logger.info( + 'Basic token authentication used', + extra={'user_id': user.id, 'organization_id': org_id, 'endpoint': request.path} + ) + return auth_result diff --git a/label_studio/jwt_auth/middleware.py b/label_studio/jwt_auth/middleware.py new file mode 100644 index 000000000000..c0781cb2ee81 --- /dev/null +++ b/label_studio/jwt_auth/middleware.py @@ -0,0 +1,25 @@ +import logging + +logger = logging.getLogger(__name__) + + +class JWTAuthenticationMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + from core.feature_flags import flag_set + from rest_framework_simplejwt.authentication import JWTAuthentication + + JWT_ACCESS_TOKEN_ENABLED = flag_set('fflag__feature_develop__prompts__dia_1829_jwt_token_auth') + if JWT_ACCESS_TOKEN_ENABLED: + try: + # annoyingly, this only returns one object on failure so have to unpack awkwardly + user_and_token = JWTAuthentication().authenticate(request) + user = user_and_token[0] if user_and_token else None + if user and user.active_organization.jwt.api_tokens_enabled: + request.user = user + request.is_jwt = True + except: + logger.info('Could not auth using jwt, falling back to other auth methods') + return self.get_response(request) diff --git a/label_studio/jwt_auth/migrations/0001_initial.py b/label_studio/jwt_auth/migrations/0001_initial.py new file mode 100644 index 000000000000..1d3dbb9425e2 --- /dev/null +++ b/label_studio/jwt_auth/migrations/0001_initial.py @@ -0,0 +1,63 @@ +# Generated by Django 5.1.4 on 2025-02-03 15:51 + +import annoying.fields +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("organizations", "0006_alter_organizationmember_deleted_at"), + ] + + operations = [ + migrations.CreateModel( + name="jwtsettings", + fields=[ + ( + "organization", + annoying.fields.AutoOneToOneField( + on_delete=django.db.models.deletion.DO_NOTHING, + primary_key=True, + related_name='jwt', + serialize=False, + to='organizations.organization' + ) + ), + ( + "api_tokens_enabled", + models.BooleanField( + default=False, + help_text="Enable JWT API token authentication for this organization", + verbose_name="JWT API tokens enabled", + ), + ), + ( + 'api_token_ttl_days', + models.IntegerField( + default=30, + help_text='Number of days before JWT API tokens expire', + verbose_name='JWT API token time to live (days)') + ), + ( + "legacy_api_tokens_enabled", + models.BooleanField( + default=True, + help_text="Enable legacy API token authentication for this organization", + verbose_name="legacy API tokens enabled", + ), + ), + ( + "created_at", + models.DateTimeField(auto_now_add=True, verbose_name="created at"), + ), + ( + "updated_at", + models.DateTimeField(auto_now=True, verbose_name="updated at"), + ), + ], + ), + ] diff --git a/label_studio/jwt_auth/migrations/__init__.py b/label_studio/jwt_auth/migrations/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/label_studio/jwt_auth/models.py b/label_studio/jwt_auth/models.py new file mode 100644 index 000000000000..8a746d2bd3a9 --- /dev/null +++ b/label_studio/jwt_auth/models.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +from typing import Any + +from annoying.fields import AutoOneToOneField +from django.db import models +from django.utils.translation import gettext_lazy as _ +from organizations.models import Organization +from rest_framework_simplejwt.backends import TokenBackend +from rest_framework_simplejwt.tokens import RefreshToken, api_settings + + +class JWTSettings(models.Model): + """Organization-specific JWT settings for authentication""" + + organization = AutoOneToOneField(Organization, related_name='jwt', primary_key=True, on_delete=models.DO_NOTHING) + api_tokens_enabled = models.BooleanField( + _('JWT API tokens enabled'), default=False, help_text='Enable JWT API token authentication for this organization' + ) + api_token_ttl_days = models.IntegerField( + _('JWT API token time to live (days)'), default=30, help_text='Number of days before JWT API tokens expire' + ) + legacy_api_tokens_enabled = models.BooleanField( + _('legacy API tokens enabled'), default=True, help_text='Enable legacy API token authentication for this organization' + ) + + created_at = models.DateTimeField(_('created at'), auto_now_add=True) + updated_at = models.DateTimeField(_('updated at'), auto_now=True) + + def has_permission(self, user): + """Check if user has permission to modify JWT settings""" + if not self.organization.has_permission(user): + return False + return user.is_owner or (hasattr(user, 'is_administrator') and user.is_administrator) + + +class LSTokenBackend(TokenBackend): + """A custom JWT token backend that truncates tokens before storing in the database. + + Extends simlpe jwt's TokenBackend to provide methods for generating both + truncated tokens (header + payload only) and full tokens (header + payload + signature). + This preserves privacy of the token by not exposing the signature to the frontend. + """ + + def encode(self, payload: dict[str, Any]) -> str: + """Encode a payload into a truncated JWT token string. + + Args: + payload: Dictionary containing the JWT claims to encode + + Returns: + A truncated JWT string containing only the header and payload portions, + with the signature section removed + """ + header, payload, signature = super().encode(payload).split('.') + return '.'.join([header, payload]) + + def encode_full(self, payload: dict[str, Any]) -> str: + """Encode a payload into a complete JWT token string. + + Args: + payload: Dictionary containing the JWT claims to encode + + Returns: + A complete JWT string containing header, payload and signature portions + """ + return super().encode(payload) + + +class LSAPIToken(RefreshToken): + """API token that utilizes JWT, but stores a truncated version and expires + based on user settings + + This token class extends RefreshToken to provide organization-specific token + lifetimes and support for truncated tokens. It uses the LSTokenBackend to + securely store the token (without the signature). + """ + + _token_backend = LSTokenBackend( + api_settings.ALGORITHM, + api_settings.SIGNING_KEY, + api_settings.VERIFYING_KEY, + api_settings.AUDIENCE, + api_settings.ISSUER, + api_settings.JWK_URL, + api_settings.LEEWAY, + api_settings.JSON_ENCODER, + ) + + def get_full_jwt(self) -> str: + """Get the complete JWT token string (including the signature). + + Returns: + The full JWT token string with header, payload and signature + """ + return self.get_token_backend().encode_full(self.payload) + + +class TruncatedLSAPIToken(LSAPIToken): + """Handles JWT tokens that contain only header and payload (no signature). + Used when frontend has access to truncated refresh tokens only.""" + + def __init__(self, token: str) -> None: + full_token = token + '.' + ('x' * 43) + return super().__init__(full_token, verify=False) diff --git a/label_studio/jwt_auth/serializers.py b/label_studio/jwt_auth/serializers.py new file mode 100644 index 000000000000..68278fbfc23b --- /dev/null +++ b/label_studio/jwt_auth/serializers.py @@ -0,0 +1,53 @@ +from typing import Any + +from jwt_auth.models import JWTSettings, LSAPIToken, TruncatedLSAPIToken +from rest_framework import serializers + + +# Recommended implementation from JWT to support drf-yasg: +# https://django-rest-framework-simplejwt.readthedocs.io/en/latest/drf_yasg_integration.html +class TokenRefreshResponseSerializer(serializers.Serializer): + access = serializers.CharField() + + +class JWTSettingsSerializer(serializers.ModelSerializer): + class Meta: + model = JWTSettings + fields = ('api_tokens_enabled',) + + +class JWTSettingsUpdateSerializer(JWTSettingsSerializer): + pass + + +class LSAPITokenCreateSerializer(serializers.Serializer): + token = serializers.SerializerMethodField() + + def get_token(self, obj): + return obj.get_full_jwt() + + class Meta: + model = LSAPIToken + fields = ['token'] + + +class LSAPITokenListSerializer(LSAPITokenCreateSerializer): + def get_token(self, obj): + return obj.token + + +class LSAPITokenBlacklistSerializer(serializers.Serializer): + refresh = serializers.CharField(write_only=True) + token_class = LSAPIToken + + def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]: + token_str = attrs['refresh'] + if len(token_str.split('.')) == 2: + token = TruncatedLSAPIToken(token_str) + else: + token = LSAPIToken(token_str) + try: + token.blacklist() + except AttributeError: + pass + return {} diff --git a/label_studio/jwt_auth/urls.py b/label_studio/jwt_auth/urls.py new file mode 100644 index 000000000000..35589d8c88c1 --- /dev/null +++ b/label_studio/jwt_auth/urls.py @@ -0,0 +1,12 @@ +from django.urls import path + +from . import views + +app_name = 'jwt_auth' + +urlpatterns = [ + path('api/jwt/settings', views.JWTSettingsAPI.as_view(), name='api-jwt-settings'), + path('api/token/', views.LSAPITokenView.as_view(), name='token_manage'), + path('api/token/refresh/', views.DecoratedTokenRefreshView.as_view(), name='token_refresh'), + path('api/token/blacklist/', views.LSTokenBlacklistView.as_view(), name='token_blacklist'), +] diff --git a/label_studio/jwt_auth/views.py b/label_studio/jwt_auth/views.py new file mode 100644 index 000000000000..a645044931ba --- /dev/null +++ b/label_studio/jwt_auth/views.py @@ -0,0 +1,128 @@ +from core.permissions import all_permissions +from django.utils.decorators import method_decorator +from drf_yasg.utils import swagger_auto_schema +from rest_framework import generics, status +from rest_framework.generics import CreateAPIView +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework_simplejwt.token_blacklist.models import OutstandingToken +from rest_framework_simplejwt.views import TokenRefreshView, TokenViewBase + +from jwt_auth.models import JWTSettings, LSAPIToken, TruncatedLSAPIToken +from jwt_auth.serializers import (JWTSettingsSerializer, + JWTSettingsUpdateSerializer, + LSAPITokenBlacklistSerializer, + LSAPITokenCreateSerializer, + LSAPITokenListSerializer, + TokenRefreshResponseSerializer) + + +@method_decorator( + name='get', + decorator=swagger_auto_schema( + tags=['JWT'], + operation_summary='Retrieve JWT Settings', + operation_description='Retrieve JWT settings for the currently active organization.', + ), +) +@method_decorator( + name='post', + decorator=swagger_auto_schema( + tags=['JWT'], + operation_summary='Update JWT Settings', + operation_description='Update JWT settings for the currently active organization.', + ), +) +class JWTSettingsAPI(CreateAPIView): + queryset = JWTSettings.objects.all() + permission_required = all_permissions.organizations_change + + def get_serializer_class(self): + if self.request.method == 'GET': + return JWTSettingsSerializer + return JWTSettingsUpdateSerializer + + def get(self, request, *args, **kwargs): + jwt = request.user.active_organization.jwt + return Response(self.get_serializer(jwt).data) + + def post(self, request, *args, **kwargs): + jwt = request.user.active_organization.jwt + serializer = self.get_serializer(data=request.data, instance=jwt) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(serializer.data) + + +# Recommended implementation from JWT to support drf-yasg: +# https://django-rest-framework-simplejwt.readthedocs.io/en/latest/drf_yasg_integration.html +class DecoratedTokenRefreshView(TokenRefreshView): + @swagger_auto_schema( + tags=['JWT'], + responses={ + status.HTTP_200_OK: TokenRefreshResponseSerializer, + }, + ) + def post(self, request, *args, **kwargs): + return super().post(request, *args, **kwargs) + + +@method_decorator( + name='get', + decorator=swagger_auto_schema( + tags=['JWT'], + operation_summary='List API tokens', + operation_description='List all API tokens for the current user.', + ), +) +@method_decorator( + name='post', + decorator=swagger_auto_schema( + tags=['JWT'], + operation_summary='Create API token', + operation_description='Create a new API token for the current user.', + ), +) +class LSAPITokenView(generics.ListCreateAPIView): + permission_classes = [IsAuthenticated] + token_class = LSAPIToken + + def get_queryset(self): + return OutstandingToken.objects.filter(user_id=self.request.user.id) + + def list(self, request, *args, **kwargs): + outstanding_tokens = self.get_queryset() + + def _maybe_get_token(token: OutstandingToken): + try: + return TruncatedLSAPIToken(str(token.token)) + except: # expired/invalid token + return None + + token_objects = list(filter(None, [_maybe_get_token(token) for token in outstanding_tokens])) + + serializer = self.get_serializer(token_objects, many=True) + data = serializer.data + return Response(data) + + def get_serializer_class(self): + if self.request.method == 'POST': + return LSAPITokenCreateSerializer + return LSAPITokenListSerializer + + def perform_create(self, serializer): + token = self.token_class.for_user(self.request.user) + serializer.instance = token + + +class LSTokenBlacklistView(TokenViewBase): + _serializer_class = 'jwt_auth.serializers.LSAPITokenBlacklistSerializer' + + @swagger_auto_schema( + tags=['JWT'], + responses={ + status.HTTP_200_OK: LSAPITokenBlacklistSerializer, + }, + ) + def post(self, request, *args, **kwargs): + return super().post(request, *args, **kwargs) diff --git a/label_studio/organizations/functions.py b/label_studio/organizations/functions.py index 78228ccdc71c..005c4ea5fb5b 100644 --- a/label_studio/organizations/functions.py +++ b/label_studio/organizations/functions.py @@ -8,6 +8,8 @@ def create_organization(title, created_by): with transaction.atomic(): org = Organization.objects.create(title=title, created_by=created_by) OrganizationMember.objects.create(user=created_by, organization=org) + org.jwt.enabled = True + org.jwt.save() return org diff --git a/label_studio/tests/jwt_auth/__init__.py b/label_studio/tests/jwt_auth/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/label_studio/tests/jwt_auth/test_auth.py b/label_studio/tests/jwt_auth/test_auth.py new file mode 100644 index 000000000000..ea5f6e622cfe --- /dev/null +++ b/label_studio/tests/jwt_auth/test_auth.py @@ -0,0 +1,99 @@ +import logging + +import pytest +from jwt_auth.models import LSAPIToken +from organizations.models import Organization +from rest_framework import status +from rest_framework.authtoken.models import Token +from rest_framework.test import APIClient +from users.models import User + +from ..utils import mock_feature_flag + + +@pytest.mark.django_db +@pytest.fixture +def jwt_disabled_user(): + user = User.objects.create(email='jwt_disabled@example.com') + org = Organization.objects.create(created_by=user) + user.active_organization = org + user.save() + + jwt_settings = user.active_organization.jwt + jwt_settings.api_tokens_enabled = False + jwt_settings.save() + + return user + + +@pytest.mark.django_db +@pytest.fixture +def jwt_enabled_user(): + user = User.objects.create(email='jwt_enabled@example.com') + org = Organization.objects.create(created_by=user) + user.active_organization = org + user.save() + + jwt_settings = user.active_organization.jwt + jwt_settings.api_tokens_enabled = True + jwt_settings.save() + + return user + +@pytest.mark.django_db +@pytest.fixture +def legacy_disabled_user(): + user = User.objects.create(email='legacy_disabled@example.com') + org = Organization.objects.create(created_by=user) + user.active_organization = org + user.save() + + jwt_settings = user.active_organization.jwt + jwt_settings.legacy_api_tokens_enabled = False + jwt_settings.save() + + return user + + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_logging_when_basic_token_auth_used(jwt_disabled_user, caplog): + token, _ = Token.objects.get_or_create(user=jwt_disabled_user) + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f'Token {token.key}') + caplog.set_level(logging.INFO) + + client.get('/api/projects/') + basic_auth_logs = [record for record in caplog.records if record.message == 'Basic token authentication used'] + + assert len(basic_auth_logs) == 1 + record = basic_auth_logs[0] + assert record.user_id == jwt_disabled_user.id + assert record.organization_id == jwt_disabled_user.active_organization.id + assert record.endpoint == '/api/projects/' + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_no_logging_when_jwt_token_auth_used(jwt_enabled_user, caplog): + refresh = LSAPIToken.for_user(jwt_enabled_user) + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f'Bearer {refresh.access_token}') + caplog.set_level(logging.INFO) + + client.get('/api/projects/') + + basic_auth_logs = [record for record in caplog.records if record.message == 'Basic token authentication used'] + assert len(basic_auth_logs) == 0 + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_legacy_api_token_disabled_user_cannot_use_basic_token(legacy_disabled_user): + token, _ = Token.objects.get_or_create(user=legacy_disabled_user) + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f'Token {token.key}') + + response = client.get('/api/projects/') + + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/label_studio/tests/jwt_auth/test_middleware.py b/label_studio/tests/jwt_auth/test_middleware.py new file mode 100644 index 000000000000..2bb63b593c8a --- /dev/null +++ b/label_studio/tests/jwt_auth/test_middleware.py @@ -0,0 +1,83 @@ + +import pytest +from jwt_auth.models import LSAPIToken +from organizations.models import Organization +from rest_framework import status +from rest_framework.test import APIClient +from users.models import User + +from ..utils import mock_feature_flag + + +@pytest.mark.django_db +@pytest.fixture +def jwt_disabled_user(): + user = User.objects.create(email='jwt_disabled@example.com') + org = Organization.objects.create(created_by=user) + user.active_organization = org + user.save() + + jwt_settings = user.active_organization.jwt + jwt_settings.api_tokens_enabled = False + jwt_settings.save() + + return user + + +@pytest.mark.django_db +@pytest.fixture +def jwt_enabled_user(): + user = User.objects.create(email='jwt_enabled@example.com') + org = Organization.objects.create(created_by=user) + user.active_organization = org + user.save() + + jwt_settings = user.active_organization.jwt + jwt_settings.api_tokens_enabled = True + jwt_settings.save() + + return user + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_request_without_auth_header_returns_401(): + client = APIClient() + + response = client.get('/api/projects/') + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_request_with_invalid_token_returns_401(): + client = APIClient() + client.credentials(HTTP_AUTHORIZATION='Bearer invalid.token.here') + response = client.get('/api/projects/') + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_request_with_valid_token_returns_authenticated_user(jwt_enabled_user): + refresh = LSAPIToken.for_user(jwt_enabled_user) + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f'Bearer {refresh.access_token}') + + response = client.get('/api/projects/') + + assert response.status_code == status.HTTP_200_OK + assert response.wsgi_request.user == jwt_enabled_user + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_jwt_disabled_user_cannot_use_jwt_token(jwt_disabled_user): + refresh = LSAPIToken.for_user(jwt_disabled_user) + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f'Bearer {refresh.access_token}') + + response = client.get('/api/projects/') + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/label_studio/tests/jwt_auth/test_models.py b/label_studio/tests/jwt_auth/test_models.py new file mode 100644 index 000000000000..74c0a8102e35 --- /dev/null +++ b/label_studio/tests/jwt_auth/test_models.py @@ -0,0 +1,140 @@ +import pytest +from jwt_auth.models import LSAPIToken, LSTokenBackend, TruncatedLSAPIToken +from organizations.models import Organization, OrganizationMember +from rest_framework_simplejwt.settings import api_settings as simple_jwt_settings +from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken, OutstandingToken +from users.models import User + +from ..utils import mock_feature_flag + + +@pytest.fixture +@pytest.mark.django_db +def test_token_user(): + user = User.objects.create(email='test@example.com') + org = Organization(created_by=user) + org.save() + user.active_organization = org + user.save() + yield user + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_jwt_settings_permissions(): + user = User.objects.create() + org = Organization.objects.create(created_by=user) + OrganizationMember.objects.create( + user=user, + organization=org, + ) + jwt_settings = org.jwt + jwt_settings.api_tokens_enabled = True + + user.is_owner = True + user.save() + assert jwt_settings.has_permission(user) is True + + user.is_owner = False + user.save() + assert jwt_settings.has_permission(user) is False + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.fixture +def token_backend(): + return LSTokenBackend( + algorithm=simple_jwt_settings.ALGORITHM, + signing_key=simple_jwt_settings.SIGNING_KEY, + verifying_key=simple_jwt_settings.VERIFYING_KEY, + audience=simple_jwt_settings.AUDIENCE, + issuer=simple_jwt_settings.ISSUER, + jwk_url=simple_jwt_settings.JWK_URL, + leeway=simple_jwt_settings.LEEWAY, + json_encoder=simple_jwt_settings.JSON_ENCODER, + ) + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +def test_encode_returns_only_header_and_payload(token_backend): + payload = { + 'user_id': 123, + 'exp': 1735689600, # 2025-01-01 + 'iat': 1704153600, # 2024-01-02 + } + token = token_backend.encode(payload) + + parts = token.split('.') + assert len(parts) == 2 + + assert all(part.replace('-', '+').replace('_', '/') for part in parts) + assert all(part.replace('-', '+').replace('_', '/') for part in parts) + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +def test_encode_full_returns_complete_jwt(token_backend): + payload = { + 'user_id': 123, + 'exp': 1735689600, # 2025-01-01 + 'iat': 1704153600, # 2024-01-02 + } + token = token_backend.encode_full(payload) + + parts = token.split('.') + assert len(parts) == 3 + + assert all(part.replace('-', '+').replace('_', '/') for part in parts) + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +def test_encode_vs_encode_full_comparison(token_backend): + payload = { + 'user_id': 123, + 'exp': 1735689600, # 2025-01-01 + 'iat': 1704153600, # 2024-01-02 + } + partial_token = token_backend.encode(payload) + full_token = token_backend.encode_full(payload) + + assert full_token.startswith(partial_token) + + +@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True) +@pytest.mark.django_db +def test_token_lifecycle(test_token_user): + """Test full token lifecycle including creation, access token generation, blacklisting, and validation""" + # 1. Create an api token + refresh_token = LSAPIToken.for_user(test_token_user) + + # 2. Create an access token + access_token = refresh_token.access_token + access_token.verify() # Verify it's valid + + # 3. Get the (truncated) token from the db (like how the FE would get access, before revoking) + jti = refresh_token[simple_jwt_settings.JTI_CLAIM] + outstanding_token = OutstandingToken.objects.get(jti=jti) + truncated_token_str = outstanding_token.token + + # 4. Revoke (blacklist) the token + truncated_token = TruncatedLSAPIToken(truncated_token_str) + truncated_token.blacklist() + + # 5. Verify that the revoked token can no longer be used + assert BlacklistedToken.objects.filter(token__jti=jti).exists() + + +@pytest.mark.django_db +def test_token_creation_and_storage(test_token_user): + """Test that tokens are created and stored correctly with truncated format""" + token = LSAPIToken.for_user(test_token_user) + assert token is not None + + # Token in database shouldn't contain the signature + outstanding_token = OutstandingToken.objects.get(jti=token['jti']) + stored_token_parts = outstanding_token.token.split('.') + assert len(stored_token_parts) == 2 # Only header and payload + + # Full token should have all three JWT parts + full_token = token.get_full_jwt() + full_token_parts = full_token.split('.') + assert len(full_token_parts) == 3 # Header, payload, and signature diff --git a/label_studio/tests/utils.py b/label_studio/tests/utils.py index 8045295c0b2c..f36ce1737452 100644 --- a/label_studio/tests/utils.py +++ b/label_studio/tests/utils.py @@ -4,6 +4,7 @@ import re import tempfile from contextlib import contextmanager +from functools import wraps from pathlib import Path from types import SimpleNamespace from unittest import mock @@ -13,6 +14,7 @@ import requests_mock import ujson as json from box import Box +from core.feature_flags import flag_set from data_export.models import ConvertedFormat, Export from django.apps import apps from django.conf import settings @@ -390,3 +392,29 @@ def file_exists_in_storage(response, exists=True, file_path=None): file_path = export.file.path assert os.path.isfile(file_path) == exists + + +def mock_feature_flag(flag_name: str, value: bool, parent_module: str = 'core.feature_flags'): + """Decorator to mock a feature flag state for a test function. + + Args: + flag_name: Name of the feature flag to mock + value: True or False to set the flag state + parent_module: Module path containing the flag_set function to patch + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + def fake_flag_set(feature_flag, *flag_args, **flag_kwargs): + if feature_flag == flag_name: + return value + return flag_set(feature_flag, *flag_args, **flag_kwargs) + + with mock.patch(f'{parent_module}.flag_set', wraps=fake_flag_set): + return func(*args, **kwargs) + + return wrapper + + return decorator + diff --git a/poetry.lock b/poetry.lock index 4af016648a2e..e88fc3bf6f2c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1012,6 +1012,33 @@ files = [ [package.dependencies] django = ">=4.2" +[[package]] +name = "djangorestframework-simplejwt" +version = "5.4.0" +description = "A minimal JSON Web Token authentication plugin for Django REST Framework" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.12\" or python_version <= \"3.11\"" +files = [ + {file = "djangorestframework_simplejwt-5.4.0-py3-none-any.whl", hash = "sha256:7aec953db9ed4163430c16d086eecb0f028f814ce6bba62b06c25919261e9077"}, + {file = "djangorestframework_simplejwt-5.4.0.tar.gz", hash = "sha256:cccecce1a0e1a4a240fae80da73e5fc23055bababb8b67de88fa47cd36822320"}, +] + +[package.dependencies] +cryptography = {version = ">=3.3.1", optional = true, markers = "extra == \"crypto\""} +django = ">=4.2" +djangorestframework = ">=3.14" +pyjwt = ">=1.7.1,<3" + +[package.extras] +crypto = ["cryptography (>=3.3.1)"] +dev = ["Sphinx (>=1.6.5,<2)", "cryptography", "flake8", "freezegun", "ipython", "isort", "pep8", "pytest", "pytest-cov", "pytest-django", "pytest-watch", "pytest-xdist", "python-jose (==3.3.0)", "sphinx_rtd_theme (>=0.1.9)", "tox", "twine", "wheel"] +doc = ["Sphinx (>=1.6.5,<2)", "sphinx_rtd_theme (>=0.1.9)"] +lint = ["flake8", "isort", "pep8"] +python-jose = ["python-jose (==3.3.0)"] +test = ["cryptography", "freezegun", "pytest", "pytest-cov", "pytest-django", "pytest-xdist", "tox"] + [[package]] name = "dnspython" version = "2.6.1" @@ -3175,7 +3202,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -3420,7 +3446,7 @@ version = "2.8.0" description = "JSON Web Token implementation in Python" optional = false python-versions = ">=3.7" -groups = ["test"] +groups = ["main", "test"] markers = "python_version >= \"3.12\" or python_version <= \"3.11\"" files = [ {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, @@ -5073,4 +5099,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "9d90009c24cbbb441c972940430cafce40db1f9150e50b5353cfff797d5c18bf" +content-hash = "205188751f64814b3d884581d7aa074cb125f797f5365a6d00623abe090fb43b" diff --git a/pyproject.toml b/pyproject.toml index 65b150b8781e..32fe3211f5ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,6 +205,7 @@ django-csp = "3.7" openai = "^1.10.0" django-migration-linter = "^5.1.0" setuptools = ">=75.4.0" +djangorestframework-simplejwt = {extras = ["crypto"], version = "^5.4.0"} tldextract = ">=5.1.3" # Humansignal repo dependencies