Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow async session factory #55

Merged
merged 4 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Kerberos
Expand All @@ -30,12 +30,15 @@ jobs:
run: |
make test
- name: Lint
if: matrix.python-version != '3.7'
run: |
make lint
- name: Format
if: matrix.python-version != '3.7'
run: |
make format-check
- name: Type annotations
if: matrix.python-version != '3.7'
run: |
make types
mysql-connector-j:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ celerybeat.pid
.venv
env/
venv/
venv*/
ENV/
env.bak/
venv.bak/
Expand Down
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ disable=
unused-argument,
redefined-outer-name,
too-many-statements,
multiple-statements,

[FORMAT]

Expand Down
8 changes: 5 additions & 3 deletions integration/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ async def get_user(self, username):
password = user.get("password")
return User(
name=username,
auth_string=NativePasswordAuthPlugin.create_auth_string(password)
if password
else None,
auth_string=(
NativePasswordAuthPlugin.create_auth_string(password)
if password
else None
),
auth_plugin=NativePasswordAuthPlugin.name,
)
elif auth_plugin == "authentication_kerberos":
Expand Down
1 change: 1 addition & 0 deletions mysql_mimic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementation of the mysql server wire protocol"""

from mysql_mimic.auth import (
User,
IdentityProvider,
Expand Down
8 changes: 5 additions & 3 deletions mysql_mimic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,11 @@ async def handle_stmt_fetch(self, data: bytes) -> None:

await self.stream.write(
self.ok_or_eof(
flags=types.ServerStatus.SERVER_STATUS_LAST_ROW_SENT
if done
else types.ServerStatus.SERVER_STATUS_CURSOR_EXISTS
flags=(
types.ServerStatus.SERVER_STATUS_LAST_ROW_SENT
if done
else types.ServerStatus.SERVER_STATUS_CURSOR_EXISTS
)
)
)

Expand Down
3 changes: 1 addition & 2 deletions mysql_mimic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ class BaseInfoSchema:
Base InfoSchema interface used by the `Session` class.
"""

async def query(self, expression: exp.Expression) -> AllowedResult:
...
async def query(self, expression: exp.Expression) -> AllowedResult: ...


class InfoSchema(BaseInfoSchema):
Expand Down
12 changes: 9 additions & 3 deletions mysql_mimic/server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import asyncio
import inspect
from ssl import SSLContext
from socket import socket
from typing import Callable, Any, Optional, Sequence
from typing import Callable, Any, Optional, Sequence, Awaitable

from mysql_mimic.auth import IdentityProvider, SimpleIdentityProvider
from mysql_mimic.connection import Connection
Expand Down Expand Up @@ -36,7 +37,7 @@ class MysqlServer:

def __init__(
self,
session_factory: Callable[[], BaseSession] = Session,
session_factory: Callable[[], BaseSession | Awaitable[BaseSession]] = Session,
capabilities: Capabilities = DEFAULT_SERVER_CAPABILITIES,
control: Control | None = None,
identity_provider: IdentityProvider | None = None,
Expand All @@ -57,9 +58,14 @@ async def _client_connected_cb(
) -> None:
stream = MysqlStream(reader, writer)

if inspect.iscoroutinefunction(self.session_factory):
session = await self.session_factory()
else:
session = self.session_factory()

connection = Connection(
stream=stream,
session=self.session_factory(),
session=session,
control=self.control,
server_capabilities=self.capabilities,
identity_provider=self.identity_provider,
Expand Down
4 changes: 3 additions & 1 deletion mysql_mimic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def nonce(nbytes: int) -> bytes:

def find_tables(expression: exp.Expression) -> List[exp.Table]:
"""Find all tables in an expression"""
if isinstance(expression, (exp.Subqueryable, exp.Subquery)):
if isinstance(
expression, (exp.Select, exp.Subquery, exp.Union, exp.Except, exp.Intersect)
):
return [
source
for scope in traverse_scope(expression)
Expand Down
6 changes: 2 additions & 4 deletions mysql_mimic/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from mysql_mimic.errors import MysqlError, ErrorCode


class Default:
...
class Default: ...


VariableType = Callable[[Any], Any]
Expand Down Expand Up @@ -102,8 +101,7 @@ def list(self) -> list[tuple[str, str]]:

@property
@abc.abstractmethod
def schema(self) -> dict[str, VariableSchema]:
...
def schema(self) -> dict[str, VariableSchema]: ...


class GlobalVariables(Variables):
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
return await loop.run_in_executor(None, func_call)


@pytest.fixture
def session() -> MockSession:
@pytest_asyncio.fixture
async def session() -> MockSession:
return MockSession()


Expand Down
Loading