Skip to content

Commit

Permalink
Add SSL support (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Aug 25, 2024
1 parent 51d410e commit 3bce385
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ services:
ARTEMIS_HOST: artemis

artemis:
image: apache/activemq-artemis:2.34.0-alpine
image: apache/activemq-artemis:2.37.0-alpine
environment:
ARTEMIS_USER: admin
ARTEMIS_PASSWORD: ":=123"
Expand Down
5 changes: 4 additions & 1 deletion stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from functools import partial
from ssl import SSLContext
from types import TracebackType
from typing import ClassVar, Self
from typing import ClassVar, Literal, Self

from stompman.config import ConnectionParameters, Heartbeat
from stompman.connection import AbstractConnection, Connection
Expand Down Expand Up @@ -32,6 +33,7 @@ class Client:
on_heartbeat: Callable[[], None] | None = None

heartbeat: Heartbeat = field(default=Heartbeat(1000, 1000))
ssl: Literal[True] | SSLContext | None = None
connect_retry_attempts: int = 3
connect_retry_interval: int = 1
connect_timeout: int = 2
Expand Down Expand Up @@ -71,6 +73,7 @@ def __post_init__(self) -> None:
read_timeout=self.read_timeout,
read_max_chunk_size=self.read_max_chunk_size,
write_retry_attempts=self.write_retry_attempts,
ssl=self.ssl,
)

async def __aenter__(self) -> Self:
Expand Down
30 changes: 26 additions & 4 deletions stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from collections.abc import AsyncGenerator, Generator, Iterator
from contextlib import contextmanager, suppress
from dataclasses import dataclass
from typing import Protocol, Self, cast
from ssl import SSLContext
from typing import Literal, Protocol, Self, cast

from stompman.errors import ConnectionLostError
from stompman.frames import AnyClientFrame, AnyServerFrame
Expand All @@ -14,7 +15,14 @@
class AbstractConnection(Protocol):
@classmethod
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
cls,
*,
host: str,
port: int,
timeout: int,
read_max_chunk_size: int,
read_timeout: int,
ssl: Literal[True] | SSLContext | None,
) -> Self | None: ...
async def close(self) -> None: ...
def write_heartbeat(self) -> None: ...
Expand All @@ -36,17 +44,31 @@ class Connection(AbstractConnection):
writer: asyncio.StreamWriter
read_max_chunk_size: int
read_timeout: int
ssl: Literal[True] | SSLContext | None

@classmethod
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
cls,
*,
host: str,
port: int,
timeout: int,
read_max_chunk_size: int,
read_timeout: int,
ssl: Literal[True] | SSLContext | None,
) -> Self | None:
try:
reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)
except (TimeoutError, ConnectionError, socket.gaierror):
return None
else:
return cls(reader=reader, writer=writer, read_max_chunk_size=read_max_chunk_size, read_timeout=read_timeout)
return cls(
reader=reader,
writer=writer,
read_max_chunk_size=read_max_chunk_size,
read_timeout=read_timeout,
ssl=ssl,
)

async def close(self) -> None:
self.writer.close()
Expand Down
5 changes: 4 additions & 1 deletion stompman/connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from ssl import SSLContext
from types import TracebackType
from typing import TYPE_CHECKING, Self
from typing import TYPE_CHECKING, Literal, Self

from stompman.config import ConnectionParameters
from stompman.connection import AbstractConnection
Expand Down Expand Up @@ -34,6 +35,7 @@ class ConnectionManager:
connect_retry_attempts: int
connect_retry_interval: int
connect_timeout: int
ssl: Literal[True] | SSLContext | None
read_timeout: int
read_max_chunk_size: int
write_retry_attempts: int
Expand Down Expand Up @@ -63,6 +65,7 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
timeout=self.connect_timeout,
read_max_chunk_size=self.read_max_chunk_size,
read_timeout=self.read_timeout,
ssl=self.ssl,
):
return ActiveConnectionState(
connection=connection,
Expand Down
13 changes: 11 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any, Self, TypeVar
from ssl import SSLContext
from typing import Any, Literal, Self, TypeVar

import pytest
from polyfactory.factories.dataclass_factory import DataclassFactory
Expand Down Expand Up @@ -37,7 +38,14 @@ def noop_error_handler(exception: Exception, frame: stompman.MessageFrame) -> No
class BaseMockConnection(AbstractConnection):
@classmethod
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
cls,
*,
host: str,
port: int,
timeout: int,
read_max_chunk_size: int,
read_timeout: int,
ssl: Literal[True] | SSLContext | None,
) -> Self | None:
return cls()

Expand Down Expand Up @@ -78,6 +86,7 @@ class EnrichedConnectionManager(ConnectionManager):
read_timeout: int = 4
read_max_chunk_size: int = 5
write_retry_attempts: int = 3
ssl: Literal[True] | SSLContext | None = None


DataclassType = TypeVar("DataclassType")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

async def make_connection() -> Connection | None:
return await Connection.connect(
host="localhost", port=12345, timeout=2, read_max_chunk_size=1024 * 1024, read_timeout=2
host="localhost", port=12345, timeout=2, read_max_chunk_size=1024 * 1024, read_timeout=2, ssl=None
)


Expand Down
23 changes: 20 additions & 3 deletions tests/test_connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable
from typing import Self
from ssl import SSLContext
from typing import Literal, Self
from unittest import mock

import pytest
Expand Down Expand Up @@ -29,7 +30,14 @@ async def test_connect_attempts_ok(ok_on_attempt: int, monkeypatch: pytest.Monke
class MockConnection(BaseMockConnection):
@classmethod
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
cls,
*,
host: str,
port: int,
timeout: int,
read_max_chunk_size: int,
read_timeout: int,
ssl: Literal[True] | SSLContext | None,
) -> Self | None:
assert (host, port) == (manager.servers[0].host, manager.servers[0].port)
nonlocal attempts
Expand All @@ -42,6 +50,7 @@ async def connect(
timeout=timeout,
read_max_chunk_size=read_max_chunk_size,
read_timeout=read_timeout,
ssl=ssl,
)
if attempts == ok_on_attempt
else None
Expand All @@ -67,7 +76,14 @@ async def test_connect_to_any_server_ok() -> None:
class MockConnection(BaseMockConnection):
@classmethod
async def connect(
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
cls,
*,
host: str,
port: int,
timeout: int,
read_max_chunk_size: int,
read_timeout: int,
ssl: Literal[True] | SSLContext | None,
) -> Self | None:
return (
await super().connect(
Expand All @@ -76,6 +92,7 @@ async def connect(
timeout=timeout,
read_max_chunk_size=read_max_chunk_size,
read_timeout=read_timeout,
ssl=ssl,
)
if port == successful_server.port
else None
Expand Down

0 comments on commit 3bce385

Please sign in to comment.