diff --git a/mysql_mimic/connection.py b/mysql_mimic/connection.py index 5160192..b16c6ad 100644 --- a/mysql_mimic/connection.py +++ b/mysql_mimic/connection.py @@ -77,11 +77,11 @@ def __init__( @property def server_charset(self) -> CharacterSet: - return CharacterSet[self.session.variables.get("character_set_results")] + return CharacterSet[str(self.session.variables.get("character_set_results"))] @property def client_charset(self) -> CharacterSet: - return CharacterSet[self.session.variables.get("character_set_client")] + return CharacterSet[str(self.session.variables.get("character_set_client"))] async def start(self) -> None: self._task = asyncio.create_task(self._start()) @@ -131,7 +131,7 @@ async def connection_phase(self) -> None: handshake_v10 = packets.make_handshake_v10( capabilities=self.server_capabilities, server_charset=self.server_charset, - server_version=self.session.variables.get("version"), + server_version=str(self.session.variables.get("version")), connection_id=self.connection_id, auth_data=auth_data, status_flags=self.status_flags, diff --git a/mysql_mimic/variable_processor.py b/mysql_mimic/variable_processor.py index 7a84bd3..10ff7ca 100644 --- a/mysql_mimic/variable_processor.py +++ b/mysql_mimic/variable_processor.py @@ -1,10 +1,9 @@ from contextlib import contextmanager -from typing import Dict, Mapping, Generator +from typing import Dict, Mapping, Generator, MutableMapping, Any from sqlglot import expressions as exp from mysql_mimic.intercept import value_to_expression, expression_to_value -from mysql_mimic.variables import Variables variable_constants = { "CURRENT_USER", @@ -51,28 +50,28 @@ class VariableProcessor: """ def __init__( - self, functions: Mapping, variables: Variables, expression: exp.Expression + self, functions: Mapping, variables: MutableMapping, expression: exp.Expression ): self._functions = functions self._variables = variables self._expression = expression # Stores the original system variable values. - self._orig: Dict[str, str] = {} + self._orig: Dict[str, Any] = {} @contextmanager def set_variables(self) -> Generator[exp.Expression, None, None]: assignments = _get_var_assignments(self._expression) self._orig = {k: self._variables.get(k) for k in assignments} for k, v in assignments.items(): - self._variables.set(k, v) + self._variables[k] = v self._replace_variables() yield self._expression for k, v in self._orig.items(): - self._variables.set(k, v) + self._variables[k] = v def _replace_variables(self) -> None: """Replaces certain functions in the query with literals provided from the mapping in _functions, diff --git a/mysql_mimic/variables.py b/mysql_mimic/variables.py index e5a0d8a..d11a7a7 100644 --- a/mysql_mimic/variables.py +++ b/mysql_mimic/variables.py @@ -4,7 +4,7 @@ import re from datetime import timezone, timedelta from functools import lru_cache -from typing import Any, Callable, Tuple +from typing import Any, Callable, Tuple, Iterator, MutableMapping from mysql_mimic.charset import CharacterSet, Collation from mysql_mimic.errors import MysqlError, ErrorCode @@ -57,14 +57,32 @@ class Default: ... } -class Variables(abc.ABC): +class Variables(abc.ABC, MutableMapping[str, Any]): """ Abstract class for MySQL system variables. """ def __init__(self) -> None: # Current variable values - self.values: dict[str, Any] = {} + self._values: dict[str, Any] = {} + + def __getitem__(self, key: str) -> Any | None: + try: + return self.get_variable(key) + except MysqlError as e: + raise KeyError from e + + def __setitem__(self, key: str, value: Any) -> None: + return self.set(key, value) + + def __delitem__(self, key: str) -> None: + raise MysqlError(f"Cannot delete session variable {key}.") + + def __iter__(self) -> Iterator[str]: + return self._values.__iter__() + + def __len__(self) -> int: + return len(self._values) def get_schema(self, name: str) -> VariableSchema: schema = self.schema.get(name) @@ -84,19 +102,19 @@ def set(self, name: str, value: Any, force: bool = False) -> None: ) if value is DEFAULT or value is None: - self.values[name] = default + self._values[name] = default else: - self.values[name] = type_(value) + self._values[name] = type_(value) - def get(self, name: str) -> Any: + def get_variable(self, name: str) -> Any | None: name = name.lower() - if name in self.values: - return self.values[name] + if name in self._values: + return self._values[name] _, default, _ = self.get_schema(name) return default - def list(self) -> list[tuple[str, str]]: + def list(self) -> list[tuple[str, Any]]: return [(name, self.get(name)) for name in sorted(self.schema)] @property diff --git a/tests/test_variables.py b/tests/test_variables.py index 329c9d8..30aa0b3 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -1,9 +1,21 @@ from datetime import timezone, timedelta +from typing import Dict import pytest from mysql_mimic.errors import MysqlError -from mysql_mimic.variables import parse_timezone +from mysql_mimic.variables import ( + parse_timezone, + Variables, + VariableSchema, +) + + +class TestVars(Variables): + + @property + def schema(self) -> Dict[str, VariableSchema]: + return {"foo": (str, "bar", True)} def test_parse_timezone() -> None: @@ -17,3 +29,20 @@ def test_parse_timezone() -> None: with pytest.raises(MysqlError): parse_timezone("whoops") + + +def test_variable_mapping() -> None: + test_vars = TestVars() + + assert test_vars.get_variable("foo") == "bar" + assert test_vars["foo"] == "bar" + + test_vars["foo"] = "hello" + assert test_vars.get_variable("foo") == "hello" + assert test_vars["foo"] == "hello" + + with pytest.raises(KeyError): + assert test_vars["world"] + + with pytest.raises(MysqlError): + test_vars["world"] = "hello"