Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing the Variables class to implement MutableMapping. #67

Merged
merged 4 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mysql_mimic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def __init__(

@property
def server_charset(self) -> CharacterSet:
return CharacterSet[self.session.variables.get("character_set_results")]
return CharacterSet[str(self.session.variables.get("character_set_results"))]

@property
def client_charset(self) -> CharacterSet:
return CharacterSet[self.session.variables.get("character_set_client")]
return CharacterSet[str(self.session.variables.get("character_set_client"))]

async def start(self) -> None:
self._task = asyncio.create_task(self._start())
Expand Down Expand Up @@ -131,7 +131,7 @@ async def connection_phase(self) -> None:
handshake_v10 = packets.make_handshake_v10(
capabilities=self.server_capabilities,
server_charset=self.server_charset,
server_version=self.session.variables.get("version"),
server_version=str(self.session.variables.get("version")),
connection_id=self.connection_id,
auth_data=auth_data,
status_flags=self.status_flags,
Expand Down
11 changes: 5 additions & 6 deletions mysql_mimic/variable_processor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from contextlib import contextmanager
from typing import Dict, Mapping, Generator
from typing import Dict, Mapping, Generator, MutableMapping, Any

from sqlglot import expressions as exp

from mysql_mimic.intercept import value_to_expression, expression_to_value
from mysql_mimic.variables import Variables

variable_constants = {
"CURRENT_USER",
Expand Down Expand Up @@ -51,28 +50,28 @@ class VariableProcessor:
"""

def __init__(
self, functions: Mapping, variables: Variables, expression: exp.Expression
self, functions: Mapping, variables: MutableMapping, expression: exp.Expression
):
self._functions = functions
self._variables = variables
self._expression = expression

# Stores the original system variable values.
self._orig: Dict[str, str] = {}
self._orig: Dict[str, Any] = {}

@contextmanager
def set_variables(self) -> Generator[exp.Expression, None, None]:
assignments = _get_var_assignments(self._expression)
self._orig = {k: self._variables.get(k) for k in assignments}
for k, v in assignments.items():
self._variables.set(k, v)
self._variables[k] = v

self._replace_variables()

yield self._expression

for k, v in self._orig.items():
self._variables.set(k, v)
self._variables[k] = v

def _replace_variables(self) -> None:
"""Replaces certain functions in the query with literals provided from the mapping in _functions,
Expand Down
40 changes: 26 additions & 14 deletions mysql_mimic/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from datetime import timezone, timedelta
from functools import lru_cache
from typing import Any, Callable, Tuple
from typing import Any, Callable, Tuple, Iterator, MutableMapping

from mysql_mimic.charset import CharacterSet, Collation
from mysql_mimic.errors import MysqlError, ErrorCode
Expand Down Expand Up @@ -57,14 +57,34 @@ class Default: ...
}


class Variables(abc.ABC):
class Variables(abc.ABC, MutableMapping[str, Any]):
"""
Abstract class for MySQL system variables.
"""

def __init__(self) -> None:
# Current variable values
self.values: dict[str, Any] = {}
self._values: dict[str, Any] = {}

def __getitem__(self, key: str) -> Any | None:
cecilycarver marked this conversation as resolved.
Show resolved Hide resolved
key = key.lower()
if key in self._values:
return self._values[key]
_, default, _ = self.get_schema(key)

return default

def __setitem__(self, key: str, value: Any) -> None:
return self.set(key, value)

def __delitem__(self, key: str) -> None:
raise MysqlError(f"Cannot delete session variable {key}.")

def __iter__(self) -> Iterator[str]:
return self._values.__iter__()

def __len__(self) -> int:
return len(self._values)

def get_schema(self, name: str) -> VariableSchema:
schema = self.schema.get(name)
Expand All @@ -84,19 +104,11 @@ def set(self, name: str, value: Any, force: bool = False) -> None:
)

if value is DEFAULT or value is None:
self.values[name] = default
self._values[name] = default
else:
self.values[name] = type_(value)

def get(self, name: str) -> Any:
name = name.lower()
if name in self.values:
return self.values[name]
_, default, _ = self.get_schema(name)

return default
self._values[name] = type_(value)

def list(self) -> list[tuple[str, str]]:
def list(self) -> list[tuple[str, Any]]:
return [(name, self.get(name)) for name in sorted(self.schema)]

@property
Expand Down