diff --git a/caikit_tgis_backend/tgis_backend.py b/caikit_tgis_backend/tgis_backend.py
index 6362ea5..aa6238b 100644
--- a/caikit_tgis_backend/tgis_backend.py
+++ b/caikit_tgis_backend/tgis_backend.py
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""This module implements a TGIS backend configuration
-"""
+"""This module implements a TGIS backend configuration"""
 
 # Standard
+from copy import deepcopy
 from threading import Lock
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 # Third Party
 import grpc
@@ -35,6 +35,7 @@
 log = alog.use_channel("TGISBKND")
 error = error_handler.get(log)
 
+
 # pylint: disable=too-many-instance-attributes
 class TGISBackend(BackendBase):
     """Caikit backend with a connection to the TGIS server. If no connection
@@ -67,7 +68,7 @@ def __init__(self, config: Optional[dict] = None):
         self._mutex = Lock()
         self._local_tgis = None
         self._managed_tgis = None
-        self._model_connections = {}
+        self._model_connections: Dict[str, TGISConnection] = {}
         self._test_connections = self.config.get("test_connections", False)
         self._connect_timeout = self.config.get("connect_timeout", None)
 
@@ -75,7 +76,9 @@ def __init__(self, config: Optional[dict] = None):
         # TGIS instance or running a local copy
         connection_cfg = self.config.get("connection") or {}
         error.type_check("", dict, connection=connection_cfg)
-        self._remote_models_cfg = self.config.get("remote_models") or {}
+        self._remote_models_cfg: Dict[str, dict] = (
+            self.config.get("remote_models") or {}
+        )
         error.type_check("", dict, connection=self._remote_models_cfg)
         local_cfg = self.config.get("local") or {}
         error.type_check("", dict, local=local_cfg)
@@ -114,19 +117,9 @@ def __init__(self, config: Optional[dict] = None):
                 model_id,
             )
             if self._test_connections:
-                try:
-                    model_conn.test_connection(timeout=self._connect_timeout)
-                except grpc.RpcError as err:
-                    log.warning(
-                        "",
-                        "Unable to connect to model %s: %s",
-                        model_id,
-                        err,
-                        exc_info=True,
-                    )
-                    model_conn = None
+                model_conn = self._test_connection(model_conn, self._connect_timeout)
             if model_conn is not None:
-                self._model_connections[model_id] = model_conn
+                self._safely_update_state(model_id, model_conn)
 
         # We manage a local TGIS instance if there are no remote connections
         # specified as either a valid base connection or remote_connections
@@ -182,25 +175,52 @@ def get_connection(
         if not model_conn and create and not self.local_tgis and conn_cfg:
             model_conn = TGISConnection.from_config(model_id, conn_cfg)
             if self._test_connections:
-                try:
-                    model_conn.test_connection()
-                except grpc.RpcError as err:
-                    log.warning(
-                        "",
-                        "Unable to connect to model %s: %s",
-                        model_id,
-                        err,
-                        exc_info=True,
-                    )
-                    model_conn = None
+                model_conn = self._test_connection(model_conn)
             if model_conn is not None:
-                # NOTE: setdefault used here to avoid the need to hold the mutex
-                #   when running the connection test. It's possible that two
-                #   threads would stimulate the creation of the connection
-                #   concurrently, so just keep whichever dict update lands first
-                self._model_connections.setdefault(model_id, model_conn)
+                self._safely_update_state(model_id, model_conn)
+
         return model_conn
 
+    def register_model_connection(
+        self,
+        model_id: str,
+        conn_cfg: Optional[Dict[str, Any]] = None,
+        fill_with_defaults: bool = True,
+    ) -> None:
+        """
+        Register a remote model connection.
+
+        If the model connection is already registered, do nothing.
+
+        Otherwise create and register the model connection using the TGISBackend's
+        config connection, or the `conn_cfg` if provided.
+
+        If `fill_with_defaults == True`, missing keys in `conn_cfg` will be populated
+        with defaults from the TGISBackend's config connection.
+        """
+        if model_id in self._model_connections:
+            return  # Model connection exists --> do nothing
+
+        # Craft new connection config
+        new_conn_cfg = {}
+        if conn_cfg is None:
+            new_conn_cfg = deepcopy(self._base_connection_cfg)
+        else:
+            if fill_with_defaults:
+                new_conn_cfg = deepcopy(self._base_connection_cfg)
+            new_conn_cfg.update(conn_cfg)
+
+        # Create model connection
+        model_conn = TGISConnection.from_config(model_id, new_conn_cfg)
+
+        error.value_check("", model_conn is not None)
+
+        # Register model connection
+        if self._test_connections:
+            model_conn = self._test_connection(model_conn)
+        if model_conn is not None:
+            self._safely_update_state(model_id, model_conn, new_conn_cfg)
+
     def get_client(self, model_id: str) -> generation_pb2_grpc.GenerationServiceStub:
         model_conn = self.get_connection(model_id)
         if model_conn is None and self.local_tgis:
@@ -269,6 +289,48 @@ def model_loaded(self) -> bool:
             self._managed_tgis is not None and self._managed_tgis.is_ready()
         )
 
+    def _test_connection(
+        self, model_conn: Optional[TGISConnection], timeout: Optional[float] = None
+    ) -> Optional[TGISConnection]:
+        """
+        Returns the TGISConnection if successful, else returns None.
+        """
+        if model_conn is None:
+            return None
+
+        try:
+            model_conn.test_connection(timeout=timeout)
+        except grpc.RpcError as err:
+            log.warning(
+                "",
+                "Unable to connect to model %s: %s",
+                model_conn.model_id,
+                err,
+                exc_info=True,
+            )
+            model_conn = None
+
+        return model_conn
+
+    def _safely_update_state(
+        self,
+        model_id: str,
+        model_connections: Optional[TGISConnection] = None,
+        remote_models_cfg: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Update the `_model_connections` and `_remote_models_cfg` state dictionaries in a
+        thread safe manner.
+        """
+        # NOTE: setdefault used here to avoid the need to hold the mutex
+        #   when running the connection test. It's possible that two
+        #   threads would stimulate the creation of the connection
+        #   concurrently, so just keep whichever dict update lands first
+        if model_connections is not None:
+            self._model_connections.setdefault(model_id, model_connections)
+        if remote_models_cfg is not None:
+            self._remote_models_cfg.setdefault(model_id, remote_models_cfg)
+
 
 # Register local backend
 register_backend_type(TGISBackend)
diff --git a/caikit_tgis_backend/tgis_connection.py b/caikit_tgis_backend/tgis_connection.py
index 1458cd8..5045708 100644
--- a/caikit_tgis_backend/tgis_connection.py
+++ b/caikit_tgis_backend/tgis_connection.py
@@ -20,7 +20,7 @@
 from collections.abc import Container
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 import os
 import shutil
 
@@ -90,19 +90,22 @@ class TGISConnection:
     TLS_HN_OVERRIDE_KEY = "hostname_override"
 
     @classmethod
-    def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
+    def from_config(
+        cls, model_id: str, config: Dict[str, Any]
+    ) -> Optional["TGISConnection"]:
         """Create an instance from a connection template and a model_id"""
         hostname = config.get(cls.HOSTNAME_KEY)
         if hostname:
-            hostname = hostname.format(
-                **{
+            error.type_check("", str, hostname=hostname)
+
+            hostname = hostname.format_map(
+                {
                     cls.HOSTNAME_TEMPLATE_MODEL_ID: model_id,
                 }
             )
             log.debug("Resolved hostname [%s] for model %s", hostname, model_id)
 
             tls_hostname_override = config.get(cls.TLS_HN_OVERRIDE_KEY)
-
             lb_policy = config.get(cls.LB_POLICY_KEY) or None
             error.type_check(
                 "",
diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py
index 888dd20..b255f36 100644
--- a/tests/test_tgis_backend.py
+++ b/tests/test_tgis_backend.py
@@ -16,6 +16,9 @@
 """
 
 # Standard
+from copy import deepcopy
+from dataclasses import asdict
+from typing import Any, Dict, Optional
 from unittest import mock
 import os
 import tempfile
@@ -32,16 +35,16 @@
 # Local
 from caikit_tgis_backend import TGISBackend
 from caikit_tgis_backend.protobufs import generation_pb2
-from tests.tgis_mock import (
-    TGISMock,
-    tgis_mock_insecure,
-    tgis_mock_insecure_health_delay,
-    tgis_mock_mtls,
-    tgis_mock_tls,
-)
+from caikit_tgis_backend.tgis_connection import TGISConnection
+from tests.tgis_mock import tgis_mock_insecure  # noqa
+from tests.tgis_mock import tgis_mock_insecure_health_delay  # noqa
+from tests.tgis_mock import tgis_mock_mtls  # noqa
+from tests.tgis_mock import tgis_mock_tls  # noqa
+from tests.tgis_mock import TGISMock
 
 ## Helpers #####################################################################
 
+
 # for convenience in managing the multiple parts of the fixture
 class MockTGISFixture:
     def __init__(
@@ -575,7 +578,6 @@ def test_tgis_backend_config_load_prompt_artifacts():
     """Make sure that loading prompt artifacts behaves as expected"""
     with tempfile.TemporaryDirectory() as source_dir:
         with tempfile.TemporaryDirectory() as prompt_dir:
-
             # Make some source files
             source_fnames = ["prompt1.pt", "prompt2.pt"]
             source_files = [os.path.join(source_dir, fname) for fname in source_fnames]
@@ -681,6 +683,94 @@ def test_tgis_backend_config_load_prompt_artifacts():
                 tgis_be.load_prompt_artifacts("buz", prompt_id1, source_files[0])
 
 
+@pytest.mark.parametrize(
+    argnames=["model_id", "conn_cfg", "fill", "expected_conn_cfg"],
+    argvalues=[
+        (
+            "model1",
+            None,
+            False,
+            {
+                "hostname": "localhost:1234",
+                "model_id": "model1",
+                "lb_policy": "abc",
+            },
+        ),
+        (
+            "model1",
+            None,
+            True,
+            {
+                "hostname": "localhost:1234",
+                "model_id": "model1",
+                "lb_policy": "abc",
+            },
+        ),
+        (
+            "model1",
+            {"hostname": "myhost"},
+            False,
+            {"hostname": "myhost", "model_id": "model1"},
+        ),
+        (
+            "model1",
+            {"hostname": "myhost"},
+            True,
+            {"hostname": "myhost", "model_id": "model1", "lb_policy": "abc"},
+        ),
+    ],
+)
+def test_tgis_backend_register_model_connection(
+    model_id: str,
+    conn_cfg: Optional[dict],
+    fill: bool,
+    expected_conn_cfg: Dict[str, Any],
+):
+    """Test that register_model_connection correctly adds a TGISConnection to the _model_connections dictionary"""
+    tgis_be = TGISBackend(
+        {
+            "connection": {"hostname": "localhost:1234", "grpc_lb_policy_name": "abc"},
+            "remote_models": {},
+        }
+    )
+
+    # Assert new model is not in backend
+    assert model_id not in tgis_be._remote_models_cfg
+    assert model_id not in tgis_be._model_connections
+    backup_base_cfg = deepcopy(tgis_be._base_connection_cfg)
+
+    # Register model
+    tgis_be.register_model_connection(model_id, conn_cfg, fill_with_defaults=fill)
+    assert model_id in tgis_be._remote_models_cfg
+    assert model_id in tgis_be._model_connections
+    assert isinstance(tgis_be._model_connections[model_id], TGISConnection)
+    assert {
+        k: v
+        for k, v in asdict(tgis_be._model_connections[model_id]).items()
+        if v is not None
+    } == expected_conn_cfg
+
+    # Re-register -> no change to existing model
+    tgis_be.register_model_connection(model_id, {"hostname": "{model_id}.mycluster"})
+    assert {
+        k: v
+        for k, v in asdict(tgis_be._model_connections[model_id]).items()
+        if v is not None
+    } == expected_conn_cfg
+
+    # Confirm get_connection works
+    conn = tgis_be.get_connection(model_id, create=False)
+    assert isinstance(conn, TGISConnection)
+    assert {
+        k: v
+        for k, v in asdict(conn).items()
+        if v is not None
+    } == expected_conn_cfg
+
+    # Confirm that the source _base_connection_cfg wasn't mutated
+    assert tgis_be._base_connection_cfg == backup_base_cfg
+
+
 ## Failure Tests ###############################################################
 
 
diff --git a/tests/test_tgis_connection.py b/tests/test_tgis_connection.py
index e35d3b7..a425d84 100644
--- a/tests/test_tgis_connection.py
+++ b/tests/test_tgis_connection.py
@@ -14,6 +14,7 @@
 """
 Unit tests for the TGISConnection class
 """
+
 # Standard
 from contextlib import contextmanager
 from pathlib import Path
@@ -27,7 +28,7 @@
 
 # Local
 from caikit_tgis_backend.tgis_connection import TGISConnection
-from tests.tgis_mock import tgis_mock_insecure
+from tests.tgis_mock import tgis_mock_insecure  # noqa
 
 
 @contextmanager