diff --git a/overhave/storage/emulation_storage.py b/overhave/storage/emulation_storage.py index a7727a7c..6a6b92bf 100644 --- a/overhave/storage/emulation_storage.py +++ b/overhave/storage/emulation_storage.py @@ -1,14 +1,15 @@ import abc import logging +import pickle import socket from typing import cast import sqlalchemy as sa -import sqlalchemy.orm as so from overhave import db from overhave.entities.settings import OverhaveEmulationSettings from overhave.storage import EmulationRunModel +from overhave.transport.redis.deps import make_redis, get_redis_settings from overhave.utils import get_current_time logger = logging.getLogger(__name__) @@ -52,6 +53,8 @@ class EmulationStorage(IEmulationStorage): """Class for emulation runs storage.""" def __init__(self, settings: OverhaveEmulationSettings): + self._redis = redis = make_redis(get_redis_settings()) + self._redis.set('allocated_ports', pickle.dumps([])) self._settings = settings self._emulation_ports_len = len(self._settings.emulation_ports) @@ -65,20 +68,9 @@ def create_emulation_run(emulation_id: int, initiated_by: str) -> int: session.flush() return emulation_run.id - def _get_next_port(self, session: so.Session) -> int: - runs_with_allocated_ports = ( # noqa: ECE001 - session.query(db.EmulationRun) - .filter(db.EmulationRun.port.isnot(None)) - .order_by(db.EmulationRun.id.desc()) - .limit(self._emulation_ports_len) - .all() - ) - allocated_sorted_runs = sorted( - runs_with_allocated_ports, - key=lambda t: t.changed_at, - ) - - allocated_ports = {run.port for run in allocated_sorted_runs} + def _get_next_port(self) -> int: + allocated_ports = self.get_allocated_ports() + logger.debug("Allocated ports: %s", allocated_ports) not_allocated_ports = set(self._settings.emulation_ports).difference(allocated_ports) logger.debug("Not allocated ports: %s", not_allocated_ports) @@ -88,12 +80,20 @@ def _get_next_port(self, session: so.Session) -> int: continue return port logger.debug("All not allocated ports are busy!") - for run in allocated_sorted_runs: - if self._is_port_in_use(cast(int, run.port)): + for port in allocated_ports: + if self._is_port_in_use(cast(int, port)): continue - return cast(int, run.port) + return cast(int, port) raise AllPortsAreBusyError("All ports are busy - could not find free port!") + def get_allocated_ports(self): + return pickle.loads(self._redis.get('allocated_ports')) + + def allocate_port(self, port): + new_allocated_ports = self.get_allocated_ports() + new_allocated_ports.append(port) + self._redis.set('allocated_ports', pickle.dumps(sorted(new_allocated_ports))) + def _is_port_in_use(self, port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex((self._settings.emulation_bind_ip, port)) == 0 @@ -102,7 +102,8 @@ def get_requested_emulation_run(self, emulation_run_id: int) -> EmulationRunMode with db.create_session() as session: emulation_run = session.query(db.EmulationRun).filter(db.EmulationRun.id == emulation_run_id).one() emulation_run.status = db.EmulationStatus.REQUESTED - emulation_run.port = self._get_next_port(session) + emulation_run.port = self._get_next_port() + self.allocate_port(emulation_run.port) emulation_run.changed_at = get_current_time() return EmulationRunModel.model_validate(emulation_run) diff --git a/tests/integration/emulation/test_emulator.py b/tests/integration/emulation/test_emulator.py index a214737e..cf1b57ce 100644 --- a/tests/integration/emulation/test_emulator.py +++ b/tests/integration/emulation/test_emulator.py @@ -18,7 +18,7 @@ class TestEmulator: def test_start_emulation( self, emulator: Emulator, emulation_task: EmulationTask, mock_subprocess_popen: MagicMock ) -> None: - with count_queries(8): + with count_queries(6): emulator.start_emulation(task=emulation_task) with create_test_session() as session: emulation_run_db = session.get(db.EmulationRun, emulation_task.data.emulation_run_id) @@ -33,7 +33,7 @@ def test_start_emulation_with_error( emulation_task: EmulationTask, mock_subprocess_popen: MagicMock, ) -> None: - with count_queries(8): + with count_queries(6): emulator.start_emulation(task=emulation_task) with create_test_session() as session: emulation_run_db = session.get(db.EmulationRun, emulation_task.data.emulation_run_id) diff --git a/tests/unit/publication/gitlab/conftest.py b/tests/unit/publication/gitlab/conftest.py index 8c4221e7..8ec2851e 100644 --- a/tests/unit/publication/gitlab/conftest.py +++ b/tests/unit/publication/gitlab/conftest.py @@ -1,6 +1,7 @@ -from typing import Callable, Mapping, Sequence, cast +from typing import Callable, Mapping, Sequence, cast, Any import pytest +from _pytest.fixtures import FixtureRequest from faker import Faker from pytest_mock import MockFixture @@ -66,6 +67,15 @@ def get_tokenizer_settings(): return get_tokenizer_settings +@pytest.fixture() +def test_tokenizer_client( + request: FixtureRequest +) -> TokenizerClient: + return TokenizerClient(settings=TokenizerClientSettings( + enabled=True, + **request.param, + )) + @pytest.fixture() def test_gitlab_publisher_with_default_reviewers( diff --git a/tests/unit/publication/gitlab/tokenizer/test_client.py b/tests/unit/publication/gitlab/tokenizer/test_client.py new file mode 100644 index 00000000..fce9784b --- /dev/null +++ b/tests/unit/publication/gitlab/tokenizer/test_client.py @@ -0,0 +1,40 @@ +from unittest.mock import patch, MagicMock + +import pytest + +from overhave.publication.gitlab.tokenizer import TokenizerClient +from overhave.publication.gitlab.tokenizer.client import InvalidUrlException + + +class TestTokenizerClient: + """Tests for :class:`TokenizerClient`.""" + + @pytest.mark.parametrize('test_tokenizer_client', + [{"initiator": "peka", "remote_key": "pepe", "remote_key_name": "sad-pepe", + "url": "https://ya.ru"}], indirect=True) + def test_tokenizer_client_get_token_works( + self, test_tokenizer_client + ) -> None: + token_mock = "magic_token" + draft_id_mock = 4 + + client = test_tokenizer_client + + request_mock = MagicMock() + request_mock.json = MagicMock(return_value={"token": token_mock}) + + with patch.object(TokenizerClient, '_make_request', return_value=request_mock) as make_request: + tokenizerClient = client.get_token(draft_id_mock) + assert tokenizerClient.token == token_mock + make_request.assert_called_once() + + @pytest.mark.parametrize('test_tokenizer_client', + [{"initiator": "peka", "remote_key": "pepe", "remote_key_name": "sad-pepe"}], + indirect=True) + def test_tokenizer_client_get_token_url_validation_raises_error( + self, test_tokenizer_client + ) -> None: + draft_id_mock = 4 + + with pytest.raises(InvalidUrlException): + test_tokenizer_client.get_token(draft_id_mock)