diff --git a/core/common/authentication.py b/core/common/authentication.py index 63cc3ff7a..accad59f5 100644 --- a/core/common/authentication.py +++ b/core/common/authentication.py @@ -10,7 +10,7 @@ class OCLAuthentication(BaseAuthentication): 3. Uses Auth Service to determine auth class Django/OIDC """ def get_auth_class(self, request): - from core.common.services import AuthService + from core.services.auth.core import AuthService if AuthService.is_valid_django_token(request) or get(settings, 'TEST_MODE', False): klass = TokenAuthentication else: diff --git a/core/common/backends.py b/core/common/backends.py index d70d5cf82..f70054303 100644 --- a/core/common/backends.py +++ b/core/common/backends.py @@ -99,7 +99,7 @@ def get_auth_backend(self, request=None): if get(self, '_authentication_backend'): return get(self, '_authentication_backend') - from core.common.services import AuthService + from core.services.auth.core import AuthService if AuthService.is_valid_django_token(request) or get(settings, 'TEST_MODE', False): klass = ModelBackend else: diff --git a/core/common/management/commands/celery_beat_healthcheck.py b/core/common/management/commands/celery_beat_healthcheck.py index 0a7edc337..a697d708b 100644 --- a/core/common/management/commands/celery_beat_healthcheck.py +++ b/core/common/management/commands/celery_beat_healthcheck.py @@ -3,8 +3,7 @@ from django.conf import settings from django.core.management import BaseCommand -from core.common.services import RedisService -from core.users.models import UserProfile +from core.services.storages.redis import RedisService class Command(BaseCommand): diff --git a/core/common/services.py b/core/common/services.py deleted file mode 100644 index e9b14f5be..000000000 --- a/core/common/services.py +++ /dev/null @@ -1,499 +0,0 @@ -import base64 -import json - -import boto3 -import requests -from botocore.client import Config -from botocore.exceptions import NoCredentialsError, ClientError -from django.conf import settings -from django.contrib.auth.backends import ModelBackend -from django.core.files.base import ContentFile -from django.db import connection -from django_redis import get_redis_connection -from mozilla_django_oidc.contrib.drf import OIDCAuthentication -from pydash import get -from rest_framework.authentication import TokenAuthentication -from rest_framework.authtoken.models import Token - -from core.common.backends import OCLOIDCAuthenticationBackend - - -class CloudStorageServiceInterface: - """ - Interface for storage services - """ - - def __init__(self): - pass - - def upload_file(self, key, file_path=None, headers=None, binary=False, metadata=None): # pylint: disable=too-many-arguments - """ - Uploads binary file object to key given file_path - """ - - def upload_base64(self, doc_base64, file_name, append_extension=True, public_read=False, headers=None): # pylint: disable=too-many-arguments - """ - Uploads base64 file content to file_name - """ - - def url_for(self, file_path): - """ - Returns signed url for file_path - """ - - def public_url_for(self, file_path): - """ - Returns public (or unsigned) url for file_path - """ - - def exists(self, key): - """ - Checks if key (object) exists - """ - - def has_path(self, prefix='/', delimiter='/'): - """ - Checks if path exists - """ - - def get_last_key_from_path(self, prefix='/', delimiter='/'): - """ - Returns last key from path - """ - - def delete_objects(self, path): - """ - Deletes all objects in path - """ - - def remove(self, key): - """ - Removes object - """ - - -class S3(CloudStorageServiceInterface): - """ - Configured from settings.EXPORT_SERVICE - """ - GET = 'get_object' - PUT = 'put_object' - - def __init__(self): - super().__init__() - self.conn = self.__get_connection() - - def upload_file( - self, key, file_path=None, headers=None, binary=False, metadata=None - ): # pylint: disable=too-many-arguments - """Uploads file object""" - read_directive = 'rb' if binary else 'r' - file_path = file_path if file_path else key - return self._upload(key, open(file_path, read_directive).read(), headers, metadata) - - def upload_base64( # pylint: disable=too-many-arguments,inconsistent-return-statements - self, doc_base64, file_name, append_extension=True, public_read=False, headers=None - ): - """Uploads via base64 content with file name""" - _format = None - _doc_string = None - try: - _format, _doc_string = doc_base64.split(';base64,') - except: # pylint: disable=bare-except # pragma: no cover - pass - - if not _format or not _doc_string: # pragma: no cover - return - - if append_extension: - file_name_with_ext = file_name + "." + _format.split('/')[-1] - else: - if file_name and file_name.split('.')[-1].lower() not in [ - 'pdf', 'jpg', 'jpeg', 'bmp', 'gif', 'png' - ]: - file_name += '.jpg' - file_name_with_ext = file_name - - doc_data = ContentFile(base64.b64decode(_doc_string)) - if public_read: - self._upload_public(file_name_with_ext, doc_data) - else: - self._upload(file_name_with_ext, doc_data, headers) - - return file_name_with_ext - - def url_for(self, file_path): - return self._generate_signed_url(self.GET, file_path) if file_path else None - - def public_url_for(self, file_path): - url = f"http://{settings.AWS_STORAGE_BUCKET_NAME}.s3.amazonaws.com/{file_path}" - if settings.ENV != 'development': - url = url.replace('http://', 'https://') - return url - - def exists(self, key): - try: - self.__resource().meta.client.head_object(Key=key, Bucket=settings.AWS_STORAGE_BUCKET_NAME) - except (ClientError, NoCredentialsError): - return False - - return True - - def has_path(self, prefix='/', delimiter='/'): - return len(self.__fetch_keys(prefix, delimiter)) > 0 - - def get_last_key_from_path(self, prefix='/', delimiter='/'): - keys = self.__fetch_keys(prefix, delimiter, True) - key = sorted(keys, key=lambda k: k.get('LastModified'), reverse=True)[0] if len(keys) > 1 else get(keys, '0') - return get(key, 'Key') - - def delete_objects(self, path): # pragma: no cover - try: - keys = self.__fetch_keys(prefix=path) - if keys: - self.__resource().meta.client.delete_objects( - Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delete={'Objects': keys}) - except: # pylint: disable=bare-except - pass - - def remove(self, key): - try: - return self.__get_connection().delete_object( - Bucket=settings.AWS_STORAGE_BUCKET_NAME, - Key=key - ) - except NoCredentialsError: # pragma: no cover - pass - - return None - - # private - def _generate_signed_url(self, accessor, key, metadata=None): - params = { - 'Bucket': settings.AWS_STORAGE_BUCKET_NAME, - 'Key': key, - **(metadata or {}) - } - try: - return self.__get_connection().generate_presigned_url( - accessor, - Params=params, - ExpiresIn=60 * 60 * 24 * 7, # a week - ) - except NoCredentialsError: # pragma: no cover - pass - - return None - - def _upload(self, file_path, file_content, headers=None, metadata=None): - """Uploads via file content with file_path as path + name""" - url = self._generate_signed_url(self.PUT, file_path, metadata) - result = None - if url: - res = requests.put( - url, data=file_content, headers=headers - ) if headers else requests.put(url, data=file_content) - result = res.status_code - - return result - - def _upload_public(self, file_path, file_content): - try: - return self.__get_connection().upload_fileobj( - file_content, - settings.AWS_STORAGE_BUCKET_NAME, - file_path, - ExtraArgs={ - 'ACL': 'public-read' - }, - ) - except NoCredentialsError: # pragma: no cover - pass - - return None - - # protected - def __fetch_keys(self, prefix='/', delimiter='/', verbose=False): # pragma: no cover - prefix = prefix[1:] if prefix.startswith(delimiter) else prefix - s3_resource = self.__resource() - objects = s3_resource.meta.client.list_objects(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Prefix=prefix) - content = objects.get('Contents', []) - if verbose: - return content - return [{'Key': k} for k in [obj['Key'] for obj in content]] - - def __resource(self): - return self.__session().resource('s3') - - def __get_connection(self): - session = self.__session() - - return session.client( - 's3', - config=Config(region_name=settings.AWS_REGION_NAME, signature_version='s3v4') - ) - - @staticmethod - def __session(): - return boto3.Session( - aws_access_key_id=settings.AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, - region_name=settings.AWS_REGION_NAME - ) - - -class RedisService: # pragma: no cover - @staticmethod - def get_client(): - return get_redis_connection('default') - - def set(self, key, val, **kwargs): - return self.get_client().set(key, val, **kwargs) - - def set_json(self, key, val): - return self.get_client().set(key, json.dumps(val)) - - def get_formatted(self, key): - val = self.get(key) - if isinstance(val, bytes): - val = val.decode() - - try: - val = json.loads(val) - except: # pylint: disable=bare-except - pass - - return val - - def exists(self, key): - return self.get_client().exists(key) - - def get(self, key): - return self.get_client().get(key) - - def keys(self, pattern): - return self.get_client().keys(pattern) - - def get_int(self, key): - return int(self.get_client().get(key).decode('utf-8')) - - def get_pending_tasks(self, queue, include_task_names, exclude_task_names=None): - # queue = 'bulk_import_root' - # task_name = 'core.common.tasks.bulk_import_parallel_inline' - values = self.get_client().lrange(queue, 0, -1) - tasks = [] - exclude_task_names = exclude_task_names or [] - if values: - for value in values: - val = json.loads(value.decode('utf-8')) - headers = val.get('headers') - task_name = headers.get('task') - if headers.get('id') and task_name in include_task_names and task_name not in exclude_task_names: - tasks.append( - {'task_id': headers['id'], 'task_name': headers['task'], 'state': 'PENDING', 'queue': queue} - ) - return tasks - - -class PostgresQL: - @staticmethod - def create_seq(seq_name, owned_by, min_value=0, start=1): - with connection.cursor() as cursor: - cursor.execute( - f"CREATE SEQUENCE IF NOT EXISTS {seq_name} MINVALUE {min_value} START {start} OWNED BY {owned_by};") - - @staticmethod - def update_seq(seq_name, start): - with connection.cursor() as cursor: - cursor.execute(f"SELECT setval('{seq_name}', {start}, true);") - - @staticmethod - def drop_seq(seq_name): - with connection.cursor() as cursor: - cursor.execute(f"DROP SEQUENCE IF EXISTS {seq_name};") - - @staticmethod - def next_value(seq_name): - with connection.cursor() as cursor: - cursor.execute(f"SELECT nextval('{seq_name}');") - return cursor.fetchone()[0] - - @staticmethod - def last_value(seq_name): - with connection.cursor() as cursor: - cursor.execute(f"SELECT last_value from {seq_name};") - return cursor.fetchone()[0] - - -class AbstractAuthService: - def __init__(self, username=None, password=None, user=None): - self.username = username - self.password = password - self.user = user - if self.user: - self.username = self.user.username - elif self.username: - self.set_user() - - def set_user(self): - from core.users.models import UserProfile - self.user = UserProfile.objects.filter(username=self.username).first() - - def get_token(self): - pass - - def mark_verified(self, **kwargs): - return self.user.mark_verified(**kwargs) - - def update_password(self, password): - return self.user.update_password(password=password) - - -class DjangoAuthService(AbstractAuthService): - token_type = 'Token' - authentication_class = TokenAuthentication - authentication_backend_class = ModelBackend - - def get_token(self, check_password=True): - if check_password: - if not self.user.check_password(self.password): - return False - return self.token_type + ' ' + self.user.get_token() - - @staticmethod - def create_user(_): - return True - - def logout(self, _): - pass - - -class OIDCAuthService(AbstractAuthService): - """ - Service that interacts with OIDP for: - 1. exchanging auth_code with token - 2. migrating user from django to OIDP - """ - token_type = 'Bearer' - authentication_class = OIDCAuthentication - authentication_backend_class = OCLOIDCAuthenticationBackend - USERS_URL = settings.OIDC_SERVER_INTERNAL_URL + f'/admin/realms/{settings.OIDC_REALM}/users' - OIDP_ADMIN_TOKEN_URL = settings.OIDC_SERVER_INTERNAL_URL + '/realms/master/protocol/openid-connect/token' - - @staticmethod - def get_login_redirect_url(client_id, redirect_uri, state, nonce): - return f"{settings.OIDC_OP_AUTHORIZATION_ENDPOINT}?" \ - f"response_type=code id_token&" \ - f"client_id={client_id}&" \ - f"state={state}&" \ - f"nonce={nonce}&" \ - f"redirect_uri={redirect_uri}" - - @staticmethod - def get_registration_redirect_url(client_id, redirect_uri, state, nonce): - return f"{settings.OIDC_OP_REGISTRATION_ENDPOINT}?" \ - f"response_type=code id_token&" \ - f"client_id={client_id}&" \ - f"state={state}&" \ - f"nonce={nonce}&" \ - f"redirect_uri={redirect_uri}" - - @staticmethod - def get_logout_redirect_url(id_token_hint, redirect_uri): - return f"{settings.OIDC_OP_LOGOUT_ENDPOINT}?" \ - f"id_token_hint={id_token_hint}&" \ - f"post_logout_redirect_uri={redirect_uri}" - - @staticmethod - def credential_representation_from_hash(hash_, temporary=False): - algorithm, hashIterations, salt, hashedSaltedValue = hash_.split('$') - - return { - 'type': 'password', - 'hashedSaltedValue': hashedSaltedValue, - 'algorithm': algorithm.replace('_', '-'), - 'hashIterations': int(hashIterations), - 'salt': base64.b64encode(salt.encode()).decode('ascii').strip(), - 'temporary': temporary - } - - @classmethod - def add_user(cls, user, username, password): - response = requests.post( - cls.USERS_URL, - json={ - 'enabled': True, - 'emailVerified': user.verified, - 'firstName': user.first_name, - 'lastName': user.last_name, - 'email': user.email, - 'username': user.username, - 'credentials': [cls.credential_representation_from_hash(hash_=user.password)] - }, - verify=False, - headers=OIDCAuthService.get_admin_headers(username=username, password=password) - ) - if response.status_code == 201: - return True - - return response.json() - - @staticmethod - def get_admin_token(username, password): - response = requests.post( - OIDCAuthService.OIDP_ADMIN_TOKEN_URL, - data={ - 'grant_type': 'password', - 'username': username, - 'password': password, - 'client_id': 'admin-cli' - }, - verify=False, - ) - return response.json().get('access_token') - - @staticmethod - def exchange_code_for_token(code, redirect_uri, client_id, client_secret): - response = requests.post( - settings.OIDC_OP_TOKEN_ENDPOINT, - data={ - 'grant_type': 'authorization_code', - 'client_id': client_id, - 'client_secret': client_secret, - 'code': code, - 'redirect_uri': redirect_uri - } - ) - return response.json() - - @staticmethod - def get_admin_headers(**kwargs): - return {'Authorization': f'Bearer {OIDCAuthService.get_admin_token(**kwargs)}'} - - @staticmethod - def create_user(_): - """In OID auth, user signup needs to happen in OID first""" - pass # pylint: disable=unnecessary-pass - - -class AuthService: - """ - This returns Django or OIDC Auth service based on configured env vars. - """ - @staticmethod - def is_sso_enabled(): - return settings.OIDC_SERVER_URL and not get(settings, 'TEST_MODE', False) - - @staticmethod - def get(**kwargs): - if AuthService.is_sso_enabled(): - return OIDCAuthService(**kwargs) - return DjangoAuthService(**kwargs) - - @staticmethod - def is_valid_django_token(request): - authorization_header = request.META.get('HTTP_AUTHORIZATION') - if authorization_header and authorization_header.startswith('Token '): - token_key = authorization_header.replace('Token ', '') - return Token.objects.filter(key=token_key).exists() - return False diff --git a/core/common/tasks.py b/core/common/tasks.py index 16def409d..c8917b9b4 100644 --- a/core/common/tasks.py +++ b/core/common/tasks.py @@ -615,7 +615,7 @@ def delete_s3_objects(path): @app.task(ignore_result=True) def beat_healthcheck(): # pragma: no cover - from core.common.services import RedisService + from core.services.storages.redis import RedisService redis_service = RedisService() redis_service.set(settings.CELERYBEAT_HEALTHCHECK_KEY, str(datetime.now()), ex=120) diff --git a/core/common/tests.py b/core/common/tests.py index 60319012a..e2c0fe4a5 100644 --- a/core/common/tests.py +++ b/core/common/tests.py @@ -1,21 +1,17 @@ -import base64 import os import uuid from collections import OrderedDict -from unittest.mock import patch, Mock, mock_open, ANY +from unittest.mock import patch, Mock, ANY -import boto3 import django import factory -from botocore.exceptions import ClientError from colour_runner.django_runner import ColourRunnerMixin from django.conf import settings -from django.core.files.base import ContentFile, File +from django.core.files.base import File from django.core.management import call_command from django.test import TestCase from django.test.runner import DiscoverRunner from mock.mock import call -from moto import mock_s3 from requests.auth import HTTPBasicAuth from rest_framework.exceptions import ValidationError from rest_framework.test import APITestCase, APITransactionTestCase @@ -39,7 +35,6 @@ from .checksums import Checksum from .fhir_helpers import translate_fhir_query from .serializers import IdentifierSerializer -from .services import S3, PostgresQL, DjangoAuthService, OIDCAuthService from .validators import URIValidator from ..code_systems.serializers import CodeSystemDetailSerializer @@ -205,6 +200,7 @@ def create_lookup_concept_classes(user=None, org=None): names=[ConceptNameFactory.build(name="English")] ) + class OCLAPITransactionTestCase(APITransactionTestCase, BaseTestCase): @classmethod def setUpClass(cls): @@ -214,6 +210,7 @@ def setUpClass(cls): org = Organization.objects.get(id=1) org.members.add(1) + class OCLAPITestCase(APITestCase, BaseTestCase): @classmethod def setUpClass(cls): @@ -240,181 +237,6 @@ def factory_to_params(factory_klass, **kwargs): } -class S3Test(TestCase): - @mock_s3 - def test_upload(self): - _conn = boto3.resource('s3', region_name='us-east-1') - _conn.create_bucket(Bucket='oclapi2-dev') - - S3()._upload('some/path', 'content') # pylint: disable=protected-access - - self.assertEqual( - _conn.Object( - 'oclapi2-dev', - 'some/path' - ).get()['Body'].read().decode("utf-8"), - 'content' - ) - - @mock_s3 - def test_exists(self): - _conn = boto3.resource('s3', region_name='us-east-1') - _conn.create_bucket(Bucket='oclapi2-dev') - s3 = S3() - self.assertFalse(s3.exists('some/path')) - - s3._upload('some/path', 'content') # pylint: disable=protected-access - - self.assertTrue(s3.exists('some/path')) - - def test_upload_public(self): - conn_mock = Mock(upload_fileobj=Mock(return_value='success')) - - s3 = S3() - s3._S3__get_connection = Mock(return_value=conn_mock) # pylint: disable=protected-access - self.assertEqual(s3._upload_public('some/path', 'content'), 'success') # pylint: disable=protected-access - - conn_mock.upload_fileobj.assert_called_once_with( - 'content', - 'oclapi2-dev', - 'some/path', - ExtraArgs={'ACL': 'public-read'}, - ) - - def test_upload_file(self): - with patch("builtins.open", mock_open(read_data="file-content")) as mock_file: - s3 = S3() - s3._upload = Mock(return_value=200) # pylint: disable=protected-access - file_path = "path/to/file.ext" - res = s3.upload_file(key=file_path, headers={'header1': 'val1'}) - self.assertEqual(res, 200) - s3._upload.assert_called_once_with(file_path, 'file-content', {'header1': 'val1'}, None) # pylint: disable=protected-access - mock_file.assert_called_once_with(file_path, 'r') - - def test_upload_base64(self): - file_content = base64.b64encode(b'file-content') - s3 = S3() - s3_upload_mock = Mock() - s3._upload = s3_upload_mock # pylint: disable=protected-access - uploaded_file_name_with_ext = s3.upload_base64( - doc_base64='extension/ext;base64,' + file_content.decode(), - file_name='some-file-name', - ) - - self.assertEqual( - uploaded_file_name_with_ext, - 'some-file-name.ext' - ) - mock_calls = s3_upload_mock.mock_calls - self.assertEqual(len(mock_calls), 1) - self.assertEqual( - mock_calls[0][1][0], - 'some-file-name.ext' - ) - self.assertTrue( - isinstance(mock_calls[0][1][1], ContentFile) - ) - - def test_upload_base64_public(self): - file_content = base64.b64encode(b'file-content') - s3 = S3() - s3_upload_mock = Mock() - s3._upload_public = s3_upload_mock # pylint: disable=protected-access - uploaded_file_name_with_ext = s3.upload_base64( - doc_base64='extension/ext;base64,' + file_content.decode(), - file_name='some-file-name', - public_read=True, - ) - - self.assertEqual( - uploaded_file_name_with_ext, - 'some-file-name.ext' - ) - mock_calls = s3_upload_mock.mock_calls - self.assertEqual(len(mock_calls), 1) - self.assertEqual( - mock_calls[0][1][0], - 'some-file-name.ext' - ) - self.assertTrue( - isinstance(mock_calls[0][1][1], ContentFile) - ) - - def test_upload_base64_no_ext(self): - s3_upload_mock = Mock() - s3 = S3() - s3._upload = s3_upload_mock # pylint: disable=protected-access - file_content = base64.b64encode(b'file-content') - uploaded_file_name_with_ext = s3.upload_base64( - doc_base64='extension/ext;base64,' + file_content.decode(), - file_name='some-file-name', - append_extension=False, - ) - - self.assertEqual( - uploaded_file_name_with_ext, - 'some-file-name.jpg' - ) - mock_calls = s3_upload_mock.mock_calls - self.assertEqual(len(mock_calls), 1) - self.assertEqual( - mock_calls[0][1][0], - 'some-file-name.jpg' - ) - self.assertTrue( - isinstance(mock_calls[0][1][1], ContentFile) - ) - - @mock_s3 - def test_remove(self): - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket='oclapi2-dev') - - s3 = S3() - s3._upload('some/path', 'content') # pylint: disable=protected-access - - self.assertEqual( - conn.Object( - 'oclapi2-dev', - 'some/path' - ).get()['Body'].read().decode("utf-8"), - 'content' - ) - - s3.remove(key='some/path') - - with self.assertRaises(ClientError): - conn.Object('oclapi2-dev', 'some/path').get() - - @mock_s3 - def test_url_for(self): - _conn = boto3.resource('s3', region_name='us-east-1') - _conn.create_bucket(Bucket='oclapi2-dev') - - s3 = S3() - s3._upload('some/path', 'content') # pylint: disable=protected-access - _url = s3.url_for('some/path') - - self.assertTrue( - 'https://oclapi2-dev.s3.amazonaws.com/some/path' in _url - ) - self.assertTrue( - '&X-Amz-Credential=' in _url - ) - self.assertTrue( - '&X-Amz-Signature=' in _url - ) - self.assertTrue( - 'X-Amz-Expires=' in _url - ) - - def test_public_url_for(self): - self.assertEqual( - S3().public_url_for('some/path').replace('https://', 'http://'), - 'http://oclapi2-dev.s3.amazonaws.com/some/path' - ) - - class FhirHelpersTest(OCLTestCase): def test_language_to_default_locale(self): query_fields = list(CodeSystemDetailSerializer.Meta.fields) @@ -1073,171 +895,6 @@ def test_resources_report(self, email_message_mock): self.assertTrue('for the period of' in call_args['body']) -class PostgresQLTest(OCLTestCase): - @patch('core.common.services.connection') - def test_create_seq(self, db_connection_mock): - cursor_context_mock = Mock(execute=Mock()) - cursor_mock = Mock() - cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) - cursor_mock.__exit__ = Mock(return_value=None) - db_connection_mock.cursor = Mock(return_value=cursor_mock) - - self.assertEqual(PostgresQL.create_seq('foobar_seq', 'sources.uri', 1, 100), None) - - db_connection_mock.cursor.assert_called_once() - cursor_context_mock.execute.assert_called_once_with( - 'CREATE SEQUENCE IF NOT EXISTS foobar_seq MINVALUE 1 START 100 OWNED BY sources.uri;') - - @patch('core.common.services.connection') - def test_update_seq(self, db_connection_mock): - cursor_context_mock = Mock(execute=Mock()) - cursor_mock = Mock() - cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) - cursor_mock.__exit__ = Mock(return_value=None) - db_connection_mock.cursor = Mock(return_value=cursor_mock) - - self.assertEqual(PostgresQL.update_seq('foobar_seq', 1567), None) - - db_connection_mock.cursor.assert_called_once() - cursor_context_mock.execute.assert_called_once_with("SELECT setval('foobar_seq', 1567, true);") - - @patch('core.common.services.connection') - def test_drop_seq(self, db_connection_mock): - cursor_context_mock = Mock(execute=Mock()) - cursor_mock = Mock() - cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) - cursor_mock.__exit__ = Mock(return_value=None) - db_connection_mock.cursor = Mock(return_value=cursor_mock) - - self.assertEqual(PostgresQL.drop_seq('foobar_seq'), None) - - db_connection_mock.cursor.assert_called_once() - cursor_context_mock.execute.assert_called_once_with("DROP SEQUENCE IF EXISTS foobar_seq;") - - @patch('core.common.services.connection') - def test_next_value(self, db_connection_mock): - cursor_context_mock = Mock(execute=Mock(), fetchone=Mock(return_value=[1568])) - cursor_mock = Mock() - cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) - cursor_mock.__exit__ = Mock(return_value=None) - db_connection_mock.cursor = Mock(return_value=cursor_mock) - - self.assertEqual(PostgresQL.next_value('foobar_seq'), 1568) - - db_connection_mock.cursor.assert_called_once() - cursor_context_mock.execute.assert_called_once_with("SELECT nextval('foobar_seq');") - - @patch('core.common.services.connection') - def test_last_value(self, db_connection_mock): - cursor_context_mock = Mock(execute=Mock(), fetchone=Mock(return_value=[1567])) - cursor_mock = Mock() - cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) - cursor_mock.__exit__ = Mock(return_value=None) - db_connection_mock.cursor = Mock(return_value=cursor_mock) - - self.assertEqual(PostgresQL.last_value('foobar_seq'), 1567) - - db_connection_mock.cursor.assert_called_once() - cursor_context_mock.execute.assert_called_once_with("SELECT last_value from foobar_seq;") - - -class DjangoAuthServiceTest(OCLTestCase): - def test_get_token(self): - user = UserProfileFactory(username='foobar') - - token = DjangoAuthService(user=user, password='foobar').get_token(True) - self.assertEqual(token, False) - - user.set_password('foobar') - user.save() - - token = DjangoAuthService(username='foobar', password='foobar').get_token(True) - self.assertTrue('Token ' in token) - self.assertTrue(len(token), 64) - - -class OIDCAuthServiceTest(OCLTestCase): - def test_get_login_redirect_url(self): - self.assertEqual( - OIDCAuthService.get_login_redirect_url('client-id', 'http://localhost:4000', 'state', 'nonce'), - '/realms/ocl/protocol/openid-connect/auth?response_type=code id_token&client_id=client-id&' - 'state=state&nonce=nonce&redirect_uri=http://localhost:4000' - ) - - def test_get_logout_redirect_url(self): - self.assertEqual( - OIDCAuthService.get_logout_redirect_url('id-token-hint', 'http://localhost:4000'), - '/realms/ocl/protocol/openid-connect/logout?id_token_hint=id-token-hint&' - 'post_logout_redirect_uri=http://localhost:4000' - ) - - @patch('requests.post') - def test_exchange_code_for_token(self, post_mock): - post_mock.return_value = Mock(json=Mock(return_value={'token': 'token', 'foo': 'bar'})) - - result = OIDCAuthService.exchange_code_for_token( - 'code', 'http://localhost:4000', 'client-id', 'client-secret' - ) - - self.assertEqual(result, {'token': 'token', 'foo': 'bar'}) - post_mock.assert_called_once_with( - '/realms/ocl/protocol/openid-connect/token', - data={ - 'grant_type': 'authorization_code', - 'client_id': 'client-id', - 'client_secret': 'client-secret', - 'code': 'code', - 'redirect_uri': 'http://localhost:4000' - } - ) - - @patch('requests.post') - def test_get_admin_token(self, post_mock): - post_mock.return_value = Mock(json=Mock(return_value={'access_token': 'token', 'foo': 'bar'})) - - result = OIDCAuthService.get_admin_token('username', 'password') - - self.assertEqual(result, 'token') - post_mock.assert_called_once_with( - '/realms/master/protocol/openid-connect/token', - data={ - 'grant_type': 'password', - 'username': 'username', - 'password': 'password', - 'client_id': 'admin-cli' - }, - verify=False - ) - - @patch('core.common.services.OIDCAuthService.get_admin_token') - @patch('requests.post') - def test_add_user(self, post_mock, get_admin_token_mock): - post_mock.return_value = Mock(status_code=201, json=Mock(return_value={'foo': 'bar'})) - get_admin_token_mock.return_value = 'token' - user = UserProfileFactory(username='username') - user.set_password('password') - user.save() - - result = OIDCAuthService.add_user(user, 'username', 'password') - - self.assertEqual(result, True) - get_admin_token_mock.assert_called_once_with(username='username', password='password') - post_mock.assert_called_once_with( - '/admin/realms/ocl/users', - json={ - 'enabled': True, - 'emailVerified': user.verified, - 'firstName': user.first_name, - 'lastName': user.last_name, - 'email': user.email, - 'username': user.username, - 'credentials': ANY - }, - verify=False, - headers={'Authorization': 'Bearer token'} - ) - - class URIValidatorTest(OCLTestCase): validator = URIValidator() diff --git a/core/concepts/models.py b/core/concepts/models.py index b8dc6a8f6..ff94ac3f3 100644 --- a/core/concepts/models.py +++ b/core/concepts/models.py @@ -10,7 +10,6 @@ from core.common.constants import ISO_639_1, LATEST, HEAD, ALL from core.common.mixins import SourceChildMixin from core.common.models import VersionedModel, ConceptContainerModel -from core.common.services import PostgresQL from core.common.tasks import process_hierarchy_for_new_concept, process_hierarchy_for_concept_version, \ process_hierarchy_for_new_parent_concept_version, update_mappings_concept from core.common.utils import generate_temp_version, drop_version, \ @@ -20,6 +19,7 @@ PERSIST_CLONE_ERROR, PERSIST_CLONE_SPECIFY_USER_ERROR, ALREADY_EXISTS, CONCEPT_REGEX, MAX_LOCALES_LIMIT, \ MAX_NAMES_LIMIT, MAX_DESCRIPTIONS_LIMIT from core.concepts.mixins import ConceptValidationMixin +from core.services.storages.postgres import PostgresQL class AbstractLocalizedText(ChecksumModel): diff --git a/core/importers/models.py b/core/importers/models.py index fbcb261f4..b489d3512 100644 --- a/core/importers/models.py +++ b/core/importers/models.py @@ -14,7 +14,7 @@ from core.celery import app from core.collections.models import Collection from core.common.constants import HEAD -from core.common.services import RedisService +from core.services.storages.redis import RedisService from core.common.tasks import bulk_import_parts_inline, delete_organization, batch_index_resources, \ post_import_update_resource_counts from core.common.utils import drop_version, is_url_encoded_string, encode_string, to_parent_uri, chunks diff --git a/core/importers/views.py b/core/importers/views.py index e9688561a..4537ee31f 100644 --- a/core/importers/views.py +++ b/core/importers/views.py @@ -16,7 +16,7 @@ from core.celery import app from core.common.constants import DEPRECATED_API_HEADER -from core.common.services import RedisService +from core.services.storages.redis import RedisService from core.common.swagger_parameters import update_if_exists_param, task_param, result_param, username_param, \ file_upload_param, file_url_param, parallel_threads_param, verbose_param from core.common.utils import parse_bulk_import_task_id, task_exists, flower_get, queue_bulk_import, \ diff --git a/core/integration_tests/tests_collections.py b/core/integration_tests/tests_collections.py index 29c2c5d2c..aa893cdfc 100644 --- a/core/integration_tests/tests_collections.py +++ b/core/integration_tests/tests_collections.py @@ -1481,7 +1481,7 @@ def test_get_404(self): self.assertEqual(response.status_code, 404) - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_get_204_head(self, s3_exists_mock): s3_exists_mock.return_value = False @@ -1494,7 +1494,7 @@ def test_get_204_head(self, s3_exists_mock): self.assertEqual(response.status_code, 204) s3_exists_mock.assert_called_once_with(f"users/username/username_coll_vHEAD.{self.HEAD_updated_at}.zip") - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_get_204_for_version(self, s3_has_path_mock): s3_has_path_mock.return_value = False @@ -1507,9 +1507,9 @@ def test_get_204_for_version(self, s3_has_path_mock): self.assertEqual(response.status_code, 204) s3_has_path_mock.assert_called_once_with("users/username/username_coll_v1.") - @patch('core.common.services.S3.url_for') - @patch('core.common.services.S3.get_last_key_from_path') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.url_for') + @patch('core.services.storages.cloud.aws.S3.get_last_key_from_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_get_303_version(self, s3_has_path_mock, s3_get_last_key_from_path_mock, s3_url_for_mock): s3_has_path_mock.return_value = True s3_url = f"https://s3/users/username/username_coll_v1.{self.v1_updated_at}.zip" @@ -1531,8 +1531,8 @@ def test_get_303_version(self, s3_has_path_mock, s3_get_last_key_from_path_mock, s3_get_last_key_from_path_mock.assert_called_once_with("users/username/username_coll_v1.") s3_url_for_mock.assert_called_once_with(f"users/username/username_coll_v1.{self.v1_updated_at}.zip") - @patch('core.common.services.S3.url_for') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.url_for') + @patch('core.services.storages.cloud.aws.S3.exists') def test_get_303_head(self, s3_exists_mock, s3_url_for_mock): s3_exists_mock.return_value = True s3_url = f"https://s3/users/username/username_coll_vHEAD.{self.HEAD_updated_at}.zip" @@ -1570,7 +1570,7 @@ def test_post_405(self): self.assertEqual(response.status_code, 405) - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_post_303_head(self, s3_exists_mock): s3_exists_mock.return_value = True response = self.client.post( @@ -1583,7 +1583,7 @@ def test_post_303_head(self, s3_exists_mock): self.assertEqual(response['URL'], self.collection.uri + 'export/') s3_exists_mock.assert_called_once_with(f"users/username/username_coll_vHEAD.{self.HEAD_updated_at}.zip") - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_post_303_version(self, s3_has_path_mock): s3_has_path_mock.return_value = True response = self.client.post( @@ -1597,7 +1597,7 @@ def test_post_303_version(self, s3_has_path_mock): s3_has_path_mock.assert_called_once_with("users/username/username_coll_v1.") @patch('core.collections.views.export_collection') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_post_202_head(self, s3_exists_mock, export_collection_mock): s3_exists_mock.return_value = False response = self.client.post( @@ -1611,7 +1611,7 @@ def test_post_202_head(self, s3_exists_mock, export_collection_mock): export_collection_mock.delay.assert_called_once_with(self.collection.id) @patch('core.collections.views.export_collection') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_post_202_version(self, s3_has_path_mock, export_collection_mock): s3_has_path_mock.return_value = False response = self.client.post( @@ -1625,7 +1625,7 @@ def test_post_202_version(self, s3_has_path_mock, export_collection_mock): export_collection_mock.delay.assert_called_once_with(self.collection_v1.id) @patch('core.collections.views.export_collection') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_post_409_head(self, s3_exists_mock, export_collection_mock): s3_exists_mock.return_value = False export_collection_mock.delay.side_effect = AlreadyQueued('already-queued') @@ -1640,7 +1640,7 @@ def test_post_409_head(self, s3_exists_mock, export_collection_mock): export_collection_mock.delay.assert_called_once_with(self.collection.id) @patch('core.collections.views.export_collection') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_post_409_version(self, s3_has_path_mock, export_collection_mock): s3_has_path_mock.return_value = False export_collection_mock.delay.side_effect = AlreadyQueued('already-queued') @@ -2021,7 +2021,7 @@ def setUp(self): self.token = self.user.get_token() self.collection = UserCollectionFactory(mnemonic='coll1', user=self.user) - @patch('core.common.services.S3.upload_base64') + @patch('core.services.storages.cloud.aws.S3.upload_base64') def test_post_200(self, upload_base64_mock): upload_base64_mock.return_value = 'users/username/collections/coll1/logo.png' self.assertIsNone(self.collection.logo_url) diff --git a/core/integration_tests/tests_orgs.py b/core/integration_tests/tests_orgs.py index 2e5ea5ebe..bbe59e407 100644 --- a/core/integration_tests/tests_orgs.py +++ b/core/integration_tests/tests_orgs.py @@ -521,7 +521,7 @@ def setUp(self): self.user = UserProfileFactory(organizations=[self.organization]) self.token = self.user.get_token() - @patch('core.common.services.S3.upload_base64') + @patch('core.services.storages.cloud.aws.S3.upload_base64') def test_post_200(self, upload_base64_mock): upload_base64_mock.return_value = 'orgs/org-1/logo.png' self.assertIsNone(self.organization.logo_url) diff --git a/core/integration_tests/tests_sources.py b/core/integration_tests/tests_sources.py index 58129a07a..5310b0dad 100644 --- a/core/integration_tests/tests_sources.py +++ b/core/integration_tests/tests_sources.py @@ -958,7 +958,7 @@ def test_get_404(self): self.assertEqual(response.status_code, 404) - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_get_204_head(self, s3_exists_mock): s3_exists_mock.return_value = False @@ -971,7 +971,7 @@ def test_get_204_head(self, s3_exists_mock): self.assertEqual(response.status_code, 204) s3_exists_mock.assert_called_once_with(f"users/username/username_source1_vHEAD.{self.v1_updated_at}.zip") - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_get_204_version(self, s3_has_path_mock): s3_has_path_mock.return_value = False @@ -984,9 +984,9 @@ def test_get_204_version(self, s3_has_path_mock): self.assertEqual(response.status_code, 204) s3_has_path_mock.assert_called_once_with("users/username/username_source1_v1.") - @patch('core.common.services.S3.url_for') - @patch('core.common.services.S3.get_last_key_from_path') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.url_for') + @patch('core.services.storages.cloud.aws.S3.get_last_key_from_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_get_303_version(self, s3_has_path_mock, s3_get_last_key_from_path_mock, s3_url_for_mock): s3_has_path_mock.return_value = True s3_url = f'https://s3/users/username/username_source1_v1.{self.v1_updated_at}.zip' @@ -1007,8 +1007,8 @@ def test_get_303_version(self, s3_has_path_mock, s3_get_last_key_from_path_mock, s3_has_path_mock.assert_called_once_with("users/username/username_source1_v1.") s3_url_for_mock.assert_called_once_with(f"users/username/username_source1_v1.{self.v1_updated_at}.zip") - @patch('core.common.services.S3.url_for') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.url_for') + @patch('core.services.storages.cloud.aws.S3.exists') def test_get_303_head(self, s3_exists_mock, s3_url_for_mock): s3_url = f'https://s3/users/username/username_source1_vHEAD.{self.HEAD_updated_at}.zip' s3_url_for_mock.return_value = s3_url @@ -1028,8 +1028,8 @@ def test_get_303_head(self, s3_exists_mock, s3_url_for_mock): s3_exists_mock.assert_called_once_with(f"users/username/username_source1_vHEAD.{self.HEAD_updated_at}.zip") s3_url_for_mock.assert_called_once_with(f"users/username/username_source1_vHEAD.{self.HEAD_updated_at}.zip") - @patch('core.common.services.S3.url_for') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.url_for') + @patch('core.services.storages.cloud.aws.S3.exists') def test_get_200_head(self, s3_exists_mock, s3_url_for_mock): s3_url = f'https://s3/username/source1_vHEAD.{self.HEAD_updated_at}.zip' s3_url_for_mock.return_value = s3_url @@ -1046,9 +1046,9 @@ def test_get_200_head(self, s3_exists_mock, s3_url_for_mock): s3_exists_mock.assert_called_once_with(f"users/username/username_source1_vHEAD.{self.HEAD_updated_at}.zip") s3_url_for_mock.assert_called_once_with(f"users/username/username_source1_vHEAD.{self.HEAD_updated_at}.zip") - @patch('core.common.services.S3.url_for') - @patch('core.common.services.S3.get_last_key_from_path') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.url_for') + @patch('core.services.storages.cloud.aws.S3.get_last_key_from_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_get_200_version(self, s3_has_path_mock, s3_get_last_key_from_path_mock, s3_url_for_mock): s3_url = f'https://s3/users/username/username_source1_v1.{self.v1_updated_at}.zip' s3_url_for_mock.return_value = s3_url @@ -1068,7 +1068,7 @@ def test_get_200_version(self, s3_has_path_mock, s3_get_last_key_from_path_mock, s3_url_for_mock.assert_called_once_with(f"users/username/username_source1_v1.{self.v1_updated_at}.zip") @patch('core.sources.models.Source.is_exporting', new_callable=PropertyMock) - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_get_208_HEAD(self, s3_exists_mock, is_exporting_mock): is_exporting_mock.return_value = True @@ -1082,7 +1082,7 @@ def test_get_208_HEAD(self, s3_exists_mock, is_exporting_mock): s3_exists_mock.assert_not_called() @patch('core.sources.models.Source.is_exporting', new_callable=PropertyMock) - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_get_208_version(self, s3_has_path_mock, is_exporting_mock): is_exporting_mock.return_value = True @@ -1113,7 +1113,7 @@ def test_post_405(self): self.assertEqual(response.status_code, 405) - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_post_303_head(self, s3_exists_mock): s3_exists_mock.return_value = True response = self.client.post( @@ -1126,7 +1126,7 @@ def test_post_303_head(self, s3_exists_mock): self.assertEqual(response['URL'], self.source.uri + 'export/') s3_exists_mock.assert_called_once_with(f"users/username/username_source1_vHEAD.{self.HEAD_updated_at}.zip") - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_post_303_version(self, s3_has_path_mock): s3_has_path_mock.return_value = True response = self.client.post( @@ -1140,7 +1140,7 @@ def test_post_303_version(self, s3_has_path_mock): s3_has_path_mock.assert_called_once_with("users/username/username_source1_v1.") @patch('core.sources.views.export_source') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_post_202_head(self, s3_exists_mock, export_source_mock): s3_exists_mock.return_value = False response = self.client.post( @@ -1154,7 +1154,7 @@ def test_post_202_head(self, s3_exists_mock, export_source_mock): export_source_mock.delay.assert_called_once_with(self.source.id) @patch('core.sources.views.export_source') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_post_202_version(self, s3_has_path_mock, export_source_mock): s3_has_path_mock.return_value = False response = self.client.post( @@ -1168,7 +1168,7 @@ def test_post_202_version(self, s3_has_path_mock, export_source_mock): export_source_mock.delay.assert_called_once_with(self.source_v1.id) @patch('core.sources.views.export_source') - @patch('core.common.services.S3.exists') + @patch('core.services.storages.cloud.aws.S3.exists') def test_post_409_head(self, s3_exists_mock, export_source_mock): s3_exists_mock.return_value = False export_source_mock.delay.side_effect = AlreadyQueued('already-queued') @@ -1183,7 +1183,7 @@ def test_post_409_head(self, s3_exists_mock, export_source_mock): export_source_mock.delay.assert_called_once_with(self.source.id) @patch('core.sources.views.export_source') - @patch('core.common.services.S3.has_path') + @patch('core.services.storages.cloud.aws.S3.has_path') def test_post_409_version(self, s3_has_path_mock, export_source_mock): s3_has_path_mock.return_value = False export_source_mock.delay.side_effect = AlreadyQueued('already-queued') @@ -1230,7 +1230,7 @@ def test_delete_404_no_export(self, has_export_mock): @patch('core.sources.models.Source.version_export_path', new_callable=PropertyMock) @patch('core.sources.models.Source.has_export') - @patch('core.common.services.S3.remove') + @patch('core.services.storages.cloud.aws.S3.remove') def test_delete_204(self, s3_remove_mock, has_export_mock, export_path_mock): has_export_mock.return_value = True export_path_mock.return_value = 'v1/export/path' @@ -1302,7 +1302,7 @@ def setUp(self): self.token = self.user.get_token() self.source = UserSourceFactory(mnemonic='source1', user=self.user) - @patch('core.common.services.S3.upload_base64') + @patch('core.services.storages.cloud.aws.S3.upload_base64') def test_post_200(self, upload_base64_mock): upload_base64_mock.return_value = 'users/username/sources/source1/logo.png' self.assertIsNone(self.source.logo_url) diff --git a/core/integration_tests/tests_users.py b/core/integration_tests/tests_users.py index 6f6224d54..a582e0dac 100644 --- a/core/integration_tests/tests_users.py +++ b/core/integration_tests/tests_users.py @@ -115,7 +115,7 @@ def test_get_405(self, is_sso_enabled_mock): self.assertEqual(response.status_code, 405) - @patch('core.users.views.OIDCAuthService.get_registration_redirect_url') + @patch('core.users.views.OpenIDAuthService.get_registration_redirect_url') @patch('core.users.views.AuthService.is_sso_enabled') def test_get_200(self, is_sso_enabled_mock, get_registration_url_mock): is_sso_enabled_mock.return_value = True @@ -942,7 +942,7 @@ def test_delete(self): class OIDCodeExchangeViewTest(OCLAPITestCase): - @patch('core.users.views.OIDCAuthService') + @patch('core.users.views.OpenIDAuthService') def test_post_200(self, service_mock): service_mock.exchange_code_for_token = Mock(return_value='response') response = self.client.post( @@ -1027,7 +1027,7 @@ def test_get_405(self, is_sso_enabled_mock): self.assertEqual(response.status_code, 405) - @patch('core.users.views.OIDCAuthService.get_logout_redirect_url') + @patch('core.users.views.OpenIDAuthService.get_logout_redirect_url') @patch('core.users.views.AuthService.is_sso_enabled') def test_get_200(self, is_sso_enabled_mock, get_logout_url_mock): is_sso_enabled_mock.return_value = True diff --git a/core/mappings/models.py b/core/mappings/models.py index 5bf431e0c..309a5803a 100644 --- a/core/mappings/models.py +++ b/core/mappings/models.py @@ -10,7 +10,6 @@ from core.common.constants import NAMESPACE_REGEX, LATEST from core.common.mixins import SourceChildMixin from core.common.models import VersionedModel -from core.common.services import PostgresQL from core.common.tasks import batch_index_resources from core.common.utils import separate_version, to_parent_uri, generate_temp_version, \ encode_string, is_url_encoded_string @@ -18,6 +17,7 @@ MAPPING_IS_ALREADY_NOT_RETIRED, MAPPING_WAS_UNRETIRED, PERSIST_CLONE_ERROR, PERSIST_CLONE_SPECIFY_USER_ERROR, \ ALREADY_EXISTS from core.mappings.mixins import MappingValidationMixin +from core.services.storages.postgres import PostgresQL class Mapping(MappingValidationMixin, SourceChildMixin, VersionedModel): diff --git a/core/middlewares/middlewares.py b/core/middlewares/middlewares.py index 8ad76305c..ced5ec87a 100644 --- a/core/middlewares/middlewares.py +++ b/core/middlewares/middlewares.py @@ -6,8 +6,8 @@ from core.common.constants import VERSION_HEADER, REQUEST_USER_HEADER, RESPONSE_TIME_HEADER, REQUEST_URL_HEADER, \ REQUEST_METHOD_HEADER -from core.common.services import AuthService from core.common.utils import set_current_user, set_request_url +from core.services.auth.core import AuthService request_logger = logging.getLogger('request_logger') MAX_BODY_LENGTH = 50000 diff --git a/core/services/__init__.py b/core/services/__init__.py new file mode 100644 index 000000000..d9e8ae6be --- /dev/null +++ b/core/services/__init__.py @@ -0,0 +1,6 @@ +from django.conf import settings + + +__version__ = settings.VERSION + +from core.common.errbit import * diff --git a/core/services/auth/__init__.py b/core/services/auth/__init__.py new file mode 100644 index 000000000..d9e8ae6be --- /dev/null +++ b/core/services/auth/__init__.py @@ -0,0 +1,6 @@ +from django.conf import settings + + +__version__ = settings.VERSION + +from core.common.errbit import * diff --git a/core/services/auth/core.py b/core/services/auth/core.py new file mode 100644 index 000000000..8c406a7dc --- /dev/null +++ b/core/services/auth/core.py @@ -0,0 +1,53 @@ +from django.conf import settings +from pydash import get +from rest_framework.authtoken.models import Token + + +class AuthService: + """ + This returns Django or OIDC Auth service based on configured env vars. + """ + @staticmethod + def is_sso_enabled(): + return settings.OIDC_SERVER_URL and not get(settings, 'TEST_MODE', False) + + @staticmethod + def get(**kwargs): + from core.services.auth.openid import OpenIDAuthService + from core.services.auth.django import DjangoAuthService + + if AuthService.is_sso_enabled(): + return OpenIDAuthService(**kwargs) + return DjangoAuthService(**kwargs) + + @staticmethod + def is_valid_django_token(request): + authorization_header = request.META.get('HTTP_AUTHORIZATION') + if authorization_header and authorization_header.startswith('Token '): + token_key = authorization_header.replace('Token ', '') + return Token.objects.filter(key=token_key).exists() + return False + + +class AbstractAuthService: + def __init__(self, username=None, password=None, user=None): + self.username = username + self.password = password + self.user = user + if self.user: + self.username = self.user.username + elif self.username: + self.set_user() + + def set_user(self): + from core.users.models import UserProfile + self.user = UserProfile.objects.filter(username=self.username).first() + + def get_token(self): + pass + + def mark_verified(self, **kwargs): + return self.user.mark_verified(**kwargs) + + def update_password(self, password): + return self.user.update_password(password=password) diff --git a/core/services/auth/django.py b/core/services/auth/django.py new file mode 100644 index 000000000..c2f9423f7 --- /dev/null +++ b/core/services/auth/django.py @@ -0,0 +1,23 @@ +from django.contrib.auth.backends import ModelBackend +from rest_framework.authentication import TokenAuthentication + +from core.services.auth.core import AbstractAuthService + + +class DjangoAuthService(AbstractAuthService): + token_type = 'Token' + authentication_class = TokenAuthentication + authentication_backend_class = ModelBackend + + def get_token(self, check_password=True): + if check_password: + if not self.user.check_password(self.password): + return False + return self.token_type + ' ' + self.user.get_token() + + @staticmethod + def create_user(_): + return True + + def logout(self, _): + pass diff --git a/core/services/auth/openid.py b/core/services/auth/openid.py new file mode 100644 index 000000000..607f245b9 --- /dev/null +++ b/core/services/auth/openid.py @@ -0,0 +1,116 @@ +import base64 + +import requests +from django.conf import settings +from mozilla_django_oidc.contrib.drf import OIDCAuthentication + +from core.common.backends import OCLOIDCAuthenticationBackend +from core.services.auth.core import AbstractAuthService + + +class OpenIDAuthService(AbstractAuthService): + """ + Service that interacts with OIDP for: + 1. exchanging auth_code with token + 2. migrating user from django to OIDP + """ + token_type = 'Bearer' + authentication_class = OIDCAuthentication + authentication_backend_class = OCLOIDCAuthenticationBackend + USERS_URL = settings.OIDC_SERVER_INTERNAL_URL + f'/admin/realms/{settings.OIDC_REALM}/users' + OIDP_ADMIN_TOKEN_URL = settings.OIDC_SERVER_INTERNAL_URL + '/realms/master/protocol/openid-connect/token' + + @staticmethod + def get_login_redirect_url(client_id, redirect_uri, state, nonce): + return f"{settings.OIDC_OP_AUTHORIZATION_ENDPOINT}?" \ + f"response_type=code id_token&" \ + f"client_id={client_id}&" \ + f"state={state}&" \ + f"nonce={nonce}&" \ + f"redirect_uri={redirect_uri}" + + @staticmethod + def get_registration_redirect_url(client_id, redirect_uri, state, nonce): + return f"{settings.OIDC_OP_REGISTRATION_ENDPOINT}?" \ + f"response_type=code id_token&" \ + f"client_id={client_id}&" \ + f"state={state}&" \ + f"nonce={nonce}&" \ + f"redirect_uri={redirect_uri}" + + @staticmethod + def get_logout_redirect_url(id_token_hint, redirect_uri): + return f"{settings.OIDC_OP_LOGOUT_ENDPOINT}?" \ + f"id_token_hint={id_token_hint}&" \ + f"post_logout_redirect_uri={redirect_uri}" + + @staticmethod + def credential_representation_from_hash(hash_, temporary=False): + algorithm, hashIterations, salt, hashedSaltedValue = hash_.split('$') + + return { + 'type': 'password', + 'hashedSaltedValue': hashedSaltedValue, + 'algorithm': algorithm.replace('_', '-'), + 'hashIterations': int(hashIterations), + 'salt': base64.b64encode(salt.encode()).decode('ascii').strip(), + 'temporary': temporary + } + + @classmethod + def add_user(cls, user, username, password): + response = requests.post( + cls.USERS_URL, + json={ + 'enabled': True, + 'emailVerified': user.verified, + 'firstName': user.first_name, + 'lastName': user.last_name, + 'email': user.email, + 'username': user.username, + 'credentials': [cls.credential_representation_from_hash(hash_=user.password)] + }, + verify=False, + headers=OpenIDAuthService.get_admin_headers(username=username, password=password) + ) + if response.status_code == 201: + return True + + return response.json() + + @staticmethod + def get_admin_token(username, password): + response = requests.post( + OpenIDAuthService.OIDP_ADMIN_TOKEN_URL, + data={ + 'grant_type': 'password', + 'username': username, + 'password': password, + 'client_id': 'admin-cli' + }, + verify=False, + ) + return response.json().get('access_token') + + @staticmethod + def exchange_code_for_token(code, redirect_uri, client_id, client_secret): + response = requests.post( + settings.OIDC_OP_TOKEN_ENDPOINT, + data={ + 'grant_type': 'authorization_code', + 'client_id': client_id, + 'client_secret': client_secret, + 'code': code, + 'redirect_uri': redirect_uri + } + ) + return response.json() + + @staticmethod + def get_admin_headers(**kwargs): + return {'Authorization': f'Bearer {OpenIDAuthService.get_admin_token(**kwargs)}'} + + @staticmethod + def create_user(_): + """In OID auth, user signup needs to happen in OID first""" + pass # pylint: disable=unnecessary-pass diff --git a/core/services/auth/tests.py b/core/services/auth/tests.py new file mode 100644 index 000000000..da9a2a5f7 --- /dev/null +++ b/core/services/auth/tests.py @@ -0,0 +1,103 @@ +from unittest.mock import patch, Mock, ANY + +from core.services.auth.django import DjangoAuthService +from core.services.auth.openid import OpenIDAuthService +from core.common.tests import OCLTestCase +from core.users.tests.factories import UserProfileFactory + + +class DjangoAuthServiceTest(OCLTestCase): + def test_get_token(self): + user = UserProfileFactory(username='foobar') + + token = DjangoAuthService(user=user, password='foobar').get_token(True) + self.assertEqual(token, False) + + user.set_password('foobar') + user.save() + + token = DjangoAuthService(username='foobar', password='foobar').get_token(True) + self.assertTrue('Token ' in token) + self.assertTrue(len(token), 64) + + +class OpenIDAuthServiceTest(OCLTestCase): + def test_get_login_redirect_url(self): + self.assertEqual( + OpenIDAuthService.get_login_redirect_url('client-id', 'http://localhost:4000', 'state', 'nonce'), + '/realms/ocl/protocol/openid-connect/auth?response_type=code id_token&client_id=client-id&' + 'state=state&nonce=nonce&redirect_uri=http://localhost:4000' + ) + + def test_get_logout_redirect_url(self): + self.assertEqual( + OpenIDAuthService.get_logout_redirect_url('id-token-hint', 'http://localhost:4000'), + '/realms/ocl/protocol/openid-connect/logout?id_token_hint=id-token-hint&' + 'post_logout_redirect_uri=http://localhost:4000' + ) + + @patch('requests.post') + def test_exchange_code_for_token(self, post_mock): + post_mock.return_value = Mock(json=Mock(return_value={'token': 'token', 'foo': 'bar'})) + + result = OpenIDAuthService.exchange_code_for_token( + 'code', 'http://localhost:4000', 'client-id', 'client-secret' + ) + + self.assertEqual(result, {'token': 'token', 'foo': 'bar'}) + post_mock.assert_called_once_with( + '/realms/ocl/protocol/openid-connect/token', + data={ + 'grant_type': 'authorization_code', + 'client_id': 'client-id', + 'client_secret': 'client-secret', + 'code': 'code', + 'redirect_uri': 'http://localhost:4000' + } + ) + + @patch('requests.post') + def test_get_admin_token(self, post_mock): + post_mock.return_value = Mock(json=Mock(return_value={'access_token': 'token', 'foo': 'bar'})) + + result = OpenIDAuthService.get_admin_token('username', 'password') + + self.assertEqual(result, 'token') + post_mock.assert_called_once_with( + '/realms/master/protocol/openid-connect/token', + data={ + 'grant_type': 'password', + 'username': 'username', + 'password': 'password', + 'client_id': 'admin-cli' + }, + verify=False + ) + + @patch('core.services.auth.openid.OpenIDAuthService.get_admin_token') + @patch('requests.post') + def test_add_user(self, post_mock, get_admin_token_mock): + post_mock.return_value = Mock(status_code=201, json=Mock(return_value={'foo': 'bar'})) + get_admin_token_mock.return_value = 'token' + user = UserProfileFactory(username='username') + user.set_password('password') + user.save() + + result = OpenIDAuthService.add_user(user, 'username', 'password') + + self.assertEqual(result, True) + get_admin_token_mock.assert_called_once_with(username='username', password='password') + post_mock.assert_called_once_with( + '/admin/realms/ocl/users', + json={ + 'enabled': True, + 'emailVerified': user.verified, + 'firstName': user.first_name, + 'lastName': user.last_name, + 'email': user.email, + 'username': user.username, + 'credentials': ANY + }, + verify=False, + headers={'Authorization': 'Bearer token'} + ) diff --git a/core/services/storages/__init__.py b/core/services/storages/__init__.py new file mode 100644 index 000000000..d9e8ae6be --- /dev/null +++ b/core/services/storages/__init__.py @@ -0,0 +1,6 @@ +from django.conf import settings + + +__version__ = settings.VERSION + +from core.common.errbit import * diff --git a/core/services/storages/cloud/__init__.py b/core/services/storages/cloud/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/services/storages/cloud/aws.py b/core/services/storages/cloud/aws.py new file mode 100644 index 000000000..f7937a920 --- /dev/null +++ b/core/services/storages/cloud/aws.py @@ -0,0 +1,181 @@ +import base64 + +import boto3 +import requests +from botocore.config import Config +from botocore.exceptions import ClientError, NoCredentialsError +from django.conf import settings +from django.core.files.base import ContentFile +from pydash import get + +from core.services.storages.cloud.core import CloudStorageServiceInterface + + +class S3(CloudStorageServiceInterface): + """ + Configured from settings.EXPORT_SERVICE + """ + GET = 'get_object' + PUT = 'put_object' + + def __init__(self): + super().__init__() + self.conn = self.__get_connection() + + def upload_file( + self, key, file_path=None, headers=None, binary=False, metadata=None + ): # pylint: disable=too-many-arguments + """Uploads file object""" + read_directive = 'rb' if binary else 'r' + file_path = file_path if file_path else key + return self._upload(key, open(file_path, read_directive).read(), headers, metadata) + + def upload_base64( # pylint: disable=too-many-arguments,inconsistent-return-statements + self, doc_base64, file_name, append_extension=True, public_read=False, headers=None + ): + """Uploads via base64 content with file name""" + _format = None + _doc_string = None + try: + _format, _doc_string = doc_base64.split(';base64,') + except: # pylint: disable=bare-except # pragma: no cover + pass + + if not _format or not _doc_string: # pragma: no cover + return + + if append_extension: + file_name_with_ext = file_name + "." + _format.split('/')[-1] + else: + if file_name and file_name.split('.')[-1].lower() not in [ + 'pdf', 'jpg', 'jpeg', 'bmp', 'gif', 'png' + ]: + file_name += '.jpg' + file_name_with_ext = file_name + + doc_data = ContentFile(base64.b64decode(_doc_string)) + if public_read: + self._upload_public(file_name_with_ext, doc_data) + else: + self._upload(file_name_with_ext, doc_data, headers) + + return file_name_with_ext + + def url_for(self, file_path): + return self._generate_signed_url(self.GET, file_path) if file_path else None + + def public_url_for(self, file_path): + url = f"http://{settings.AWS_STORAGE_BUCKET_NAME}.s3.amazonaws.com/{file_path}" + if settings.ENV != 'development': + url = url.replace('http://', 'https://') + return url + + def exists(self, key): + try: + self.__resource().meta.client.head_object(Key=key, Bucket=settings.AWS_STORAGE_BUCKET_NAME) + except (ClientError, NoCredentialsError): + return False + + return True + + def has_path(self, prefix='/', delimiter='/'): + return len(self.__fetch_keys(prefix, delimiter)) > 0 + + def get_last_key_from_path(self, prefix='/', delimiter='/'): + keys = self.__fetch_keys(prefix, delimiter, True) + key = sorted(keys, key=lambda k: k.get('LastModified'), reverse=True)[0] if len(keys) > 1 else get(keys, '0') + return get(key, 'Key') + + def delete_objects(self, path): # pragma: no cover + try: + keys = self.__fetch_keys(prefix=path) + if keys: + self.__resource().meta.client.delete_objects( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delete={'Objects': keys}) + except: # pylint: disable=bare-except + pass + + def remove(self, key): + try: + return self.__get_connection().delete_object( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + Key=key + ) + except NoCredentialsError: # pragma: no cover + pass + + return None + + # private + def _generate_signed_url(self, accessor, key, metadata=None): + params = { + 'Bucket': settings.AWS_STORAGE_BUCKET_NAME, + 'Key': key, + **(metadata or {}) + } + try: + return self.__get_connection().generate_presigned_url( + accessor, + Params=params, + ExpiresIn=60 * 60 * 24 * 7, # a week + ) + except NoCredentialsError: # pragma: no cover + pass + + return None + + def _upload(self, file_path, file_content, headers=None, metadata=None): + """Uploads via file content with file_path as path + name""" + url = self._generate_signed_url(self.PUT, file_path, metadata) + result = None + if url: + res = requests.put( + url, data=file_content, headers=headers + ) if headers else requests.put(url, data=file_content) + result = res.status_code + + return result + + def _upload_public(self, file_path, file_content): + try: + return self.__get_connection().upload_fileobj( + file_content, + settings.AWS_STORAGE_BUCKET_NAME, + file_path, + ExtraArgs={ + 'ACL': 'public-read' + }, + ) + except NoCredentialsError: # pragma: no cover + pass + + return None + + # protected + def __fetch_keys(self, prefix='/', delimiter='/', verbose=False): # pragma: no cover + prefix = prefix[1:] if prefix.startswith(delimiter) else prefix + s3_resource = self.__resource() + objects = s3_resource.meta.client.list_objects(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Prefix=prefix) + content = objects.get('Contents', []) + if verbose: + return content + return [{'Key': k} for k in [obj['Key'] for obj in content]] + + def __resource(self): + return self.__session().resource('s3') + + def __get_connection(self): + session = self.__session() + + return session.client( + 's3', + config=Config(region_name=settings.AWS_REGION_NAME, signature_version='s3v4') + ) + + @staticmethod + def __session(): + return boto3.Session( + aws_access_key_id=settings.AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, + region_name=settings.AWS_REGION_NAME + ) diff --git a/core/services/storages/cloud/core.py b/core/services/storages/cloud/core.py new file mode 100644 index 000000000..8cc3a1c3b --- /dev/null +++ b/core/services/storages/cloud/core.py @@ -0,0 +1,52 @@ +class CloudStorageServiceInterface: + """ + Interface for storage services + """ + + def __init__(self): + pass + + def upload_file(self, key, file_path=None, headers=None, binary=False, metadata=None): # pylint: disable=too-many-arguments + """ + Uploads binary file object to key given file_path + """ + + def upload_base64(self, doc_base64, file_name, append_extension=True, public_read=False, headers=None): # pylint: disable=too-many-arguments + """ + Uploads base64 file content to file_name + """ + + def url_for(self, file_path): + """ + Returns signed url for file_path + """ + + def public_url_for(self, file_path): + """ + Returns public (or unsigned) url for file_path + """ + + def exists(self, key): + """ + Checks if key (object) exists + """ + + def has_path(self, prefix='/', delimiter='/'): + """ + Checks if path exists + """ + + def get_last_key_from_path(self, prefix='/', delimiter='/'): + """ + Returns last key from path + """ + + def delete_objects(self, path): + """ + Deletes all objects in path + """ + + def remove(self, key): + """ + Removes object + """ diff --git a/core/services/storages/cloud/tests.py b/core/services/storages/cloud/tests.py new file mode 100644 index 000000000..95116055b --- /dev/null +++ b/core/services/storages/cloud/tests.py @@ -0,0 +1,185 @@ +import base64 +from unittest.mock import Mock, patch, mock_open + +import boto3 +from botocore.exceptions import ClientError +from django.core.files.base import ContentFile +from django.test import TestCase +from moto import mock_s3 + +from core.services.storages.cloud.aws import S3 + + +class S3Test(TestCase): + @mock_s3 + def test_upload(self): + _conn = boto3.resource('s3', region_name='us-east-1') + _conn.create_bucket(Bucket='oclapi2-dev') + + S3()._upload('some/path', 'content') # pylint: disable=protected-access + + self.assertEqual( + _conn.Object( + 'oclapi2-dev', + 'some/path' + ).get()['Body'].read().decode("utf-8"), + 'content' + ) + + @mock_s3 + def test_exists(self): + _conn = boto3.resource('s3', region_name='us-east-1') + _conn.create_bucket(Bucket='oclapi2-dev') + s3 = S3() + self.assertFalse(s3.exists('some/path')) + + s3._upload('some/path', 'content') # pylint: disable=protected-access + + self.assertTrue(s3.exists('some/path')) + + def test_upload_public(self): + conn_mock = Mock(upload_fileobj=Mock(return_value='success')) + + s3 = S3() + s3._S3__get_connection = Mock(return_value=conn_mock) # pylint: disable=protected-access + self.assertEqual(s3._upload_public('some/path', 'content'), 'success') # pylint: disable=protected-access + + conn_mock.upload_fileobj.assert_called_once_with( + 'content', + 'oclapi2-dev', + 'some/path', + ExtraArgs={'ACL': 'public-read'}, + ) + + def test_upload_file(self): + with patch("builtins.open", mock_open(read_data="file-content")) as mock_file: + s3 = S3() + s3._upload = Mock(return_value=200) # pylint: disable=protected-access + file_path = "path/to/file.ext" + res = s3.upload_file(key=file_path, headers={'header1': 'val1'}) + self.assertEqual(res, 200) + s3._upload.assert_called_once_with(file_path, 'file-content', {'header1': 'val1'}, None) # pylint: disable=protected-access + mock_file.assert_called_once_with(file_path, 'r') + + def test_upload_base64(self): + file_content = base64.b64encode(b'file-content') + s3 = S3() + s3_upload_mock = Mock() + s3._upload = s3_upload_mock # pylint: disable=protected-access + uploaded_file_name_with_ext = s3.upload_base64( + doc_base64='extension/ext;base64,' + file_content.decode(), + file_name='some-file-name', + ) + + self.assertEqual( + uploaded_file_name_with_ext, + 'some-file-name.ext' + ) + mock_calls = s3_upload_mock.mock_calls + self.assertEqual(len(mock_calls), 1) + self.assertEqual( + mock_calls[0][1][0], + 'some-file-name.ext' + ) + self.assertTrue( + isinstance(mock_calls[0][1][1], ContentFile) + ) + + def test_upload_base64_public(self): + file_content = base64.b64encode(b'file-content') + s3 = S3() + s3_upload_mock = Mock() + s3._upload_public = s3_upload_mock # pylint: disable=protected-access + uploaded_file_name_with_ext = s3.upload_base64( + doc_base64='extension/ext;base64,' + file_content.decode(), + file_name='some-file-name', + public_read=True, + ) + + self.assertEqual( + uploaded_file_name_with_ext, + 'some-file-name.ext' + ) + mock_calls = s3_upload_mock.mock_calls + self.assertEqual(len(mock_calls), 1) + self.assertEqual( + mock_calls[0][1][0], + 'some-file-name.ext' + ) + self.assertTrue( + isinstance(mock_calls[0][1][1], ContentFile) + ) + + def test_upload_base64_no_ext(self): + s3_upload_mock = Mock() + s3 = S3() + s3._upload = s3_upload_mock # pylint: disable=protected-access + file_content = base64.b64encode(b'file-content') + uploaded_file_name_with_ext = s3.upload_base64( + doc_base64='extension/ext;base64,' + file_content.decode(), + file_name='some-file-name', + append_extension=False, + ) + + self.assertEqual( + uploaded_file_name_with_ext, + 'some-file-name.jpg' + ) + mock_calls = s3_upload_mock.mock_calls + self.assertEqual(len(mock_calls), 1) + self.assertEqual( + mock_calls[0][1][0], + 'some-file-name.jpg' + ) + self.assertTrue( + isinstance(mock_calls[0][1][1], ContentFile) + ) + + @mock_s3 + def test_remove(self): + conn = boto3.resource('s3', region_name='us-east-1') + conn.create_bucket(Bucket='oclapi2-dev') + + s3 = S3() + s3._upload('some/path', 'content') # pylint: disable=protected-access + + self.assertEqual( + conn.Object( + 'oclapi2-dev', + 'some/path' + ).get()['Body'].read().decode("utf-8"), + 'content' + ) + + s3.remove(key='some/path') + + with self.assertRaises(ClientError): + conn.Object('oclapi2-dev', 'some/path').get() + + @mock_s3 + def test_url_for(self): + _conn = boto3.resource('s3', region_name='us-east-1') + _conn.create_bucket(Bucket='oclapi2-dev') + + s3 = S3() + s3._upload('some/path', 'content') # pylint: disable=protected-access + _url = s3.url_for('some/path') + + self.assertTrue( + 'https://oclapi2-dev.s3.amazonaws.com/some/path' in _url + ) + self.assertTrue( + '&X-Amz-Credential=' in _url + ) + self.assertTrue( + '&X-Amz-Signature=' in _url + ) + self.assertTrue( + 'X-Amz-Expires=' in _url + ) + + def test_public_url_for(self): + self.assertEqual( + S3().public_url_for('some/path').replace('https://', 'http://'), + 'http://oclapi2-dev.s3.amazonaws.com/some/path' + ) diff --git a/core/services/storages/postgres.py b/core/services/storages/postgres.py new file mode 100644 index 000000000..3f0812aef --- /dev/null +++ b/core/services/storages/postgres.py @@ -0,0 +1,31 @@ +from django.db import connection + + +class PostgresQL: + @staticmethod + def create_seq(seq_name, owned_by, min_value=0, start=1): + with connection.cursor() as cursor: + cursor.execute( + f"CREATE SEQUENCE IF NOT EXISTS {seq_name} MINVALUE {min_value} START {start} OWNED BY {owned_by};") + + @staticmethod + def update_seq(seq_name, start): + with connection.cursor() as cursor: + cursor.execute(f"SELECT setval('{seq_name}', {start}, true);") + + @staticmethod + def drop_seq(seq_name): + with connection.cursor() as cursor: + cursor.execute(f"DROP SEQUENCE IF EXISTS {seq_name};") + + @staticmethod + def next_value(seq_name): + with connection.cursor() as cursor: + cursor.execute(f"SELECT nextval('{seq_name}');") + return cursor.fetchone()[0] + + @staticmethod + def last_value(seq_name): + with connection.cursor() as cursor: + cursor.execute(f"SELECT last_value from {seq_name};") + return cursor.fetchone()[0] diff --git a/core/services/storages/redis.py b/core/services/storages/redis.py new file mode 100644 index 000000000..f3c1b33cd --- /dev/null +++ b/core/services/storages/redis.py @@ -0,0 +1,56 @@ +import json + +from django_redis import get_redis_connection + + +class RedisService: # pragma: no cover + @staticmethod + def get_client(): + return get_redis_connection('default') + + def set(self, key, val, **kwargs): + return self.get_client().set(key, val, **kwargs) + + def set_json(self, key, val): + return self.get_client().set(key, json.dumps(val)) + + def get_formatted(self, key): + val = self.get(key) + if isinstance(val, bytes): + val = val.decode() + + try: + val = json.loads(val) + except: # pylint: disable=bare-except + pass + + return val + + def exists(self, key): + return self.get_client().exists(key) + + def get(self, key): + return self.get_client().get(key) + + def keys(self, pattern): + return self.get_client().keys(pattern) + + def get_int(self, key): + return int(self.get_client().get(key).decode('utf-8')) + + def get_pending_tasks(self, queue, include_task_names, exclude_task_names=None): + # queue = 'bulk_import_root' + # task_name = 'core.common.tasks.bulk_import_parallel_inline' + values = self.get_client().lrange(queue, 0, -1) + tasks = [] + exclude_task_names = exclude_task_names or [] + if values: + for value in values: + val = json.loads(value.decode('utf-8')) + headers = val.get('headers') + task_name = headers.get('task') + if headers.get('id') and task_name in include_task_names and task_name not in exclude_task_names: + tasks.append( + {'task_id': headers['id'], 'task_name': headers['task'], 'state': 'PENDING', 'queue': queue} + ) + return tasks diff --git a/core/services/storages/tests.py b/core/services/storages/tests.py new file mode 100644 index 000000000..9d7004a8d --- /dev/null +++ b/core/services/storages/tests.py @@ -0,0 +1,72 @@ +from unittest.mock import patch, Mock + +from core.services.storages.postgres import PostgresQL +from core.common.tests import OCLTestCase + + +class PostgresQLTest(OCLTestCase): + @patch('core.services.storages.postgres.connection') + def test_create_seq(self, db_connection_mock): + cursor_context_mock = Mock(execute=Mock()) + cursor_mock = Mock() + cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) + cursor_mock.__exit__ = Mock(return_value=None) + db_connection_mock.cursor = Mock(return_value=cursor_mock) + + self.assertEqual(PostgresQL.create_seq('foobar_seq', 'sources.uri', 1, 100), None) + + db_connection_mock.cursor.assert_called_once() + cursor_context_mock.execute.assert_called_once_with( + 'CREATE SEQUENCE IF NOT EXISTS foobar_seq MINVALUE 1 START 100 OWNED BY sources.uri;') + + @patch('core.services.storages.postgres.connection') + def test_update_seq(self, db_connection_mock): + cursor_context_mock = Mock(execute=Mock()) + cursor_mock = Mock() + cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) + cursor_mock.__exit__ = Mock(return_value=None) + db_connection_mock.cursor = Mock(return_value=cursor_mock) + + self.assertEqual(PostgresQL.update_seq('foobar_seq', 1567), None) + + db_connection_mock.cursor.assert_called_once() + cursor_context_mock.execute.assert_called_once_with("SELECT setval('foobar_seq', 1567, true);") + + @patch('core.services.storages.postgres.connection') + def test_drop_seq(self, db_connection_mock): + cursor_context_mock = Mock(execute=Mock()) + cursor_mock = Mock() + cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) + cursor_mock.__exit__ = Mock(return_value=None) + db_connection_mock.cursor = Mock(return_value=cursor_mock) + + self.assertEqual(PostgresQL.drop_seq('foobar_seq'), None) + + db_connection_mock.cursor.assert_called_once() + cursor_context_mock.execute.assert_called_once_with("DROP SEQUENCE IF EXISTS foobar_seq;") + + @patch('core.services.storages.postgres.connection') + def test_next_value(self, db_connection_mock): + cursor_context_mock = Mock(execute=Mock(), fetchone=Mock(return_value=[1568])) + cursor_mock = Mock() + cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) + cursor_mock.__exit__ = Mock(return_value=None) + db_connection_mock.cursor = Mock(return_value=cursor_mock) + + self.assertEqual(PostgresQL.next_value('foobar_seq'), 1568) + + db_connection_mock.cursor.assert_called_once() + cursor_context_mock.execute.assert_called_once_with("SELECT nextval('foobar_seq');") + + @patch('core.services.storages.postgres.connection') + def test_last_value(self, db_connection_mock): + cursor_context_mock = Mock(execute=Mock(), fetchone=Mock(return_value=[1567])) + cursor_mock = Mock() + cursor_mock.__enter__ = Mock(return_value=cursor_context_mock) + cursor_mock.__exit__ = Mock(return_value=None) + db_connection_mock.cursor = Mock(return_value=cursor_mock) + + self.assertEqual(PostgresQL.last_value('foobar_seq'), 1567) + + db_connection_mock.cursor.assert_called_once() + cursor_context_mock.execute.assert_called_once_with("SELECT last_value from foobar_seq;") diff --git a/core/settings.py b/core/settings.py index 359bbe12c..2dbb040b4 100644 --- a/core/settings.py +++ b/core/settings.py @@ -522,7 +522,7 @@ ERRBIT_KEY = os.environ.get('ERRBIT_KEY', 'errbit-key') # Repo Export Upload/download -EXPORT_SERVICE = os.environ.get('EXPORT_SERVICE', 'core.common.services.S3') +EXPORT_SERVICE = os.environ.get('EXPORT_SERVICE', 'core.services.storages.cloud.aws.S3') # Locales Repository URI # can either be /orgs/OCL/sources/Locales/ (old-style, ISO-639-2) diff --git a/core/sources/models.py b/core/sources/models.py index fa7e0776b..285e8a327 100644 --- a/core/sources/models.py +++ b/core/sources/models.py @@ -10,10 +10,10 @@ from core.common.constants import HEAD from core.common.models import ConceptContainerModel -from core.common.services import PostgresQL from core.common.tasks import update_mappings_source, index_source_concepts, index_source_mappings from core.common.validators import validate_non_negative from core.concepts.models import ConceptName, Concept +from core.services.storages.postgres import PostgresQL from core.sources.constants import SOURCE_TYPE, SOURCE_VERSION_TYPE, HIERARCHY_ROOT_MUST_BELONG_TO_SAME_SOURCE, \ HIERARCHY_MEANINGS, AUTO_ID_CHOICES, AUTO_ID_SEQUENTIAL, AUTO_ID_UUID, LOCALE_EXTERNAL_AUTO_ID_CHOICES diff --git a/core/sources/tests/tests.py b/core/sources/tests/tests.py index 75655d1e9..ce4c5bf74 100644 --- a/core/sources/tests/tests.py +++ b/core/sources/tests/tests.py @@ -7,7 +7,6 @@ from core.collections.tests.factories import OrganizationCollectionFactory from core.common.constants import HEAD, ACCESS_TYPE_EDIT, ACCESS_TYPE_NONE, ACCESS_TYPE_VIEW, \ OPENMRS_VALIDATION_SCHEMA -from core.common.services import PostgresQL from core.common.tasks import index_source_mappings, index_source_concepts from core.common.tasks import seed_children_to_new_version from core.common.tasks import update_source_active_concepts_count @@ -20,6 +19,7 @@ from core.mappings.documents import MappingDocument from core.mappings.tests.factories import MappingFactory from core.orgs.tests.factories import OrganizationFactory +from core.services.storages.postgres import PostgresQL from core.sources.documents import SourceDocument from core.sources.models import Source from core.sources.tests.factories import OrganizationSourceFactory, UserSourceFactory @@ -784,7 +784,7 @@ def test_autoid_start_from_validate_non_negative(self): ]: Source(**{field: 1}, mnemonic='foo', version='HEAD', name='foo').full_clean() - @patch('core.common.services.PostgresQL.create_seq') + @patch('core.services.storages.postgres.PostgresQL.create_seq') def test_autoid_field_changes(self, create_seq): org = OrganizationFactory(mnemonic='org') source = OrganizationSourceFactory(mnemonic='sequence', organization=org) diff --git a/core/users/tests/tests.py b/core/users/tests/tests.py index a30b84818..d4f23bd81 100644 --- a/core/users/tests/tests.py +++ b/core/users/tests/tests.py @@ -305,7 +305,7 @@ def test_get_405(self, is_sso_enabled_mock): self.assertEqual(response.status_code, 405) - @patch('core.users.views.OIDCAuthService.get_login_redirect_url') + @patch('core.users.views.OpenIDAuthService.get_login_redirect_url') @patch('core.users.views.AuthService.is_sso_enabled') def test_get_200(self, is_sso_enabled_mock, get_login_url_mock): is_sso_enabled_mock.return_value = True @@ -328,7 +328,7 @@ def setUp(self): self.user = UserProfileFactory(username='username1') self.token = self.user.get_token() - @patch('core.common.services.S3.upload_base64') + @patch('core.services.storages.cloud.aws.S3.upload_base64') def test_post_200(self, upload_base64_mock): upload_base64_mock.return_value = 'users/username1/logo.png' self.assertIsNone(self.user.logo_url) diff --git a/core/users/views.py b/core/users/views.py index 7af6045f1..9c4dcb895 100644 --- a/core/users/views.py +++ b/core/users/views.py @@ -27,14 +27,14 @@ from core.common.utils import parse_updated_since_param, from_string_to_date, get_truthy_values from core.common.views import BaseAPIView, BaseLogoView from core.orgs.models import Organization +from core.services.auth.core import AuthService +from core.services.auth.openid import OpenIDAuthService from core.users.constants import VERIFICATION_TOKEN_MISMATCH, VERIFY_EMAIL_MESSAGE, REACTIVATE_USER_MESSAGE from core.users.documents import UserProfileDocument from core.users.search import UserProfileFacetedSearch from core.users.serializers import UserDetailSerializer, UserCreateSerializer, UserListSerializer, UserSummarySerializer from .models import UserProfile from ..common import ERRBIT_LOGGER -from ..common.services import AuthService, OIDCAuthService - TRUTHY = get_truthy_values() @@ -59,7 +59,7 @@ def post(request): status=status.HTTP_400_BAD_REQUEST ) return Response( - OIDCAuthService.exchange_code_for_token(code, redirect_uri, client_id, client_secret)) + OpenIDAuthService.exchange_code_for_token(code, redirect_uri, client_id, client_secret)) # This API is only to migrate users from Django to OID, requires OID admin credentials in payload @@ -82,7 +82,7 @@ def post(self, request, **kwargs): # pylint: disable=unused-argument status=status.HTTP_400_BAD_REQUEST ) user = self.get_object() - result = OIDCAuthService.add_user(user=user, username=username, password=password) + result = OpenIDAuthService.add_user(user=user, username=username, password=password) return Response(result) @@ -101,7 +101,7 @@ class OIDCLogoutView(APIView): def get(request): if AuthService.is_sso_enabled(): return redirect( - OIDCAuthService.get_logout_redirect_url( + OpenIDAuthService.get_logout_redirect_url( request.query_params.get('id_token_hint'), request.query_params.get('post_logout_redirect_uri'), ) @@ -116,7 +116,7 @@ class TokenAuthenticationView(ObtainAuthToken): def get(request): if AuthService.is_sso_enabled(): return redirect( - OIDCAuthService.get_login_redirect_url( + OpenIDAuthService.get_login_redirect_url( request.query_params.get('client_id'), request.query_params.get('redirect_uri'), request.query_params.get('state'), @@ -271,7 +271,7 @@ class UserSignup(UserBaseView, mixins.CreateModelMixin): def get(request): if AuthService.is_sso_enabled(): return redirect( - OIDCAuthService.get_registration_redirect_url( + OpenIDAuthService.get_registration_redirect_url( request.query_params.get('client_id'), request.query_params.get('redirect_uri'), request.query_params.get('state'),