diff --git a/aioftp/client.py b/aioftp/client.py index 0ca5301..4a2ec8a 100644 --- a/aioftp/client.py +++ b/aioftp/client.py @@ -32,10 +32,11 @@ logger = logging.getLogger(__name__) -async def open_connection(host, port, loop, create_connection): +async def open_connection(host, port, loop, create_connection, ssl=None): reader = asyncio.StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) - transport, _ = await create_connection(lambda: protocol, host, port) + transport, _ = await create_connection(lambda: protocol, + host, port, ssl=ssl) writer = asyncio.StreamWriter(transport, protocol, reader, loop) return reader, writer @@ -108,7 +109,8 @@ class BaseClient: def __init__(self, *, loop=None, create_connection=None, socket_timeout=None, read_speed_limit=None, write_speed_limit=None, path_timeout=None, - path_io_factory=pathio.PathIO, encoding="utf-8"): + path_io_factory=pathio.PathIO, encoding="utf-8", + ssl=None): self.loop = loop or asyncio.get_event_loop() self.create_connection = create_connection or \ self.loop.create_connection @@ -122,6 +124,7 @@ def __init__(self, *, loop=None, create_connection=None, self.path_io = path_io_factory(timeout=path_timeout, loop=loop) self.encoding = encoding self.stream = None + self.ssl = ssl async def connect(self, host, port=DEFAULT_PORT): self.server_host = host @@ -131,6 +134,7 @@ async def connect(self, host, port=DEFAULT_PORT): port, self.loop, self.create_connection, + self.ssl, ) self.stream = ThrottleStreamIO( reader, @@ -491,6 +495,14 @@ class Client(BaseClient): :param encoding: encoding to use for convertion strings to bytes :type encoding: :py:class:`str` + + :param ssl: if given and not false, a SSL/TLS transport is created + (by default a plain TCP transport is created). + If ssl is a ssl.SSLContext object, this context is used to create + the transport; if ssl is True, a default context returned from + ssl.create_default_context() is used. + Please look :py:meth:`asyncio.loop.create_connection` docs. + :type ssl: :py:class:`bool` or :py:class:`ssl.SSLContext` """ async def connect(self, host, port=DEFAULT_PORT): """ diff --git a/aioftp/server.py b/aioftp/server.py index bf30fc6..6a08b53 100644 --- a/aioftp/server.py +++ b/aioftp/server.py @@ -427,6 +427,7 @@ async def start(self, host=None, port=0, **kwargs): host, port, loop=self.loop, + ssl=self.ssl, **self._start_server_extra_arguments, ) for sock in self.server.sockets: @@ -776,6 +777,11 @@ class Server(AbstractServer): :param encoding: encoding to use for convertion strings to bytes :type encoding: :py:class:`str` + + :param ssl: can be set to an :py:class:`ssl.SSLContext` instance + to enable TLS over the accepted connections. + Please look :py:meth:`asyncio.loop.create_server` docs. + :type ssl: :py:class:`ssl.SSLContext` """ path_facts = ( ("st_size", "Size"), @@ -799,7 +805,8 @@ def __init__(self, read_speed_limit_per_connection=None, write_speed_limit_per_connection=None, data_ports=None, - encoding="utf-8"): + encoding="utf-8", + ssl=None): self.loop = loop or asyncio.get_event_loop() self.block_size = block_size self.socket_timeout = socket_timeout @@ -832,6 +839,7 @@ def __init__(self, ) self.throttle_per_user = {} self.encoding = encoding + self.ssl = ssl async def dispatcher(self, reader, writer): host, port, *_ = writer.transport.get_extra_info("peername", ("", "")) diff --git a/setup.py b/setup.py index be57d33..a23b233 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def run_tests(self): packages=find_packages(), python_requires=" >= 3.5.3", install_requires=[], - tests_require=["nose", "coverage"], + tests_require=["nose", "coverage", "trustme"], cmdclass={"test": NoseTestCommand}, include_package_data=True ) diff --git a/tests/common.py b/tests/common.py index cca0ccc..09e66a1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,15 +1,26 @@ import asyncio import functools -import pathlib import logging +import pathlib import shutil import socket +import ssl import nose +import trustme import aioftp +ca = trustme.CA() +server_cert = ca.issue_server_cert("127.0.0.1", "::1") + +ssl_server = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +server_cert.configure_cert(ssl_server) + +ssl_client = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) +ca.configure_trust(ssl_client) + PORT = 8888 @@ -23,7 +34,7 @@ def wrapper(): s_args, s_kwargs = server_args c_args, c_kwargs = client_args - def run_in_loop(s_args, s_kwargs, c_args, c_kwargs): + def run_in_loop(s_args, s_kwargs, c_args, c_kwargs, s_ssl=None, c_ssl=None): logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", @@ -31,8 +42,8 @@ def run_in_loop(s_args, s_kwargs, c_args, c_kwargs): ) loop = asyncio.new_event_loop() asyncio.set_event_loop(None) - server = aioftp.Server(*s_args, loop=loop, **s_kwargs) - client = aioftp.Client(*c_args, loop=loop, **c_kwargs) + server = aioftp.Server(*s_args, loop=loop, ssl=s_ssl, **s_kwargs) + client = aioftp.Client(*c_args, loop=loop, ssl=c_ssl, **c_kwargs) try: loop.run_until_complete(f(loop, client, server)) finally: @@ -46,8 +57,10 @@ def run_in_loop(s_args, s_kwargs, c_args, c_kwargs): for factory in (aioftp.PathIO, aioftp.AsyncPathIO): s_kwargs["path_io_factory"] = factory run_in_loop(s_args, s_kwargs, c_args, c_kwargs) + run_in_loop(s_args, s_kwargs, c_args, c_kwargs, ssl_server, ssl_client) else: run_in_loop(s_args, s_kwargs, c_args, c_kwargs) + run_in_loop(s_args, s_kwargs, c_args, c_kwargs, ssl_server, ssl_client) return wrapper diff --git a/tests/test-connection.py b/tests/test-connection.py index cc484b3..18f6a8d 100644 --- a/tests/test-connection.py +++ b/tests/test-connection.py @@ -98,7 +98,7 @@ async def test_pasv_connection_ports_not_added(loop, client, server): @with_connection async def test_pasv_connection_ports(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(2)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(2)] expected_data_ports = [30000, 30001] for i, client in enumerate(clients): @@ -128,7 +128,7 @@ async def test_data_ports_remains_empty(loop, client, server): @with_connection async def test_pasv_connection_port_reused(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(2)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(2)] for client in clients: diff --git a/tests/test-maximum-connections.py b/tests/test-maximum-connections.py index 20a9200..188450b 100644 --- a/tests/test-maximum-connections.py +++ b/tests/test-maximum-connections.py @@ -6,7 +6,7 @@ @with_connection async def test_multiply_connections_no_limits(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(4)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(4)] for client in clients: await client.connect("127.0.0.1", PORT) @@ -23,7 +23,7 @@ async def test_multiply_connections_no_limits(loop, client, server): @with_connection async def test_multiply_connections_limited_error(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(5)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)] for client in clients: await client.connect("127.0.0.1", PORT) @@ -53,7 +53,7 @@ async def test_multiply_user_commands(loop, client, server): async def test_multiply_connections_with_user_limited_error(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(5)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)] for client in clients: await client.connect("127.0.0.1", PORT) @@ -69,7 +69,7 @@ async def test_multiply_connections_with_user_limited_error(loop, client, @with_connection async def test_multiply_connections_relogin_balanced(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(5)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)] for client in clients[:-1]: await client.connect("127.0.0.1", PORT) @@ -90,7 +90,7 @@ async def test_multiply_connections_relogin_balanced(loop, client, server): @expect_codes_in_exception("421") async def test_multiply_connections_server_limit_error(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(5)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)] for client in clients: await client.connect("127.0.0.1", PORT) @@ -107,7 +107,7 @@ async def test_multiply_connections_server_limit_error(loop, client, server): async def test_multiply_connections_server_relogin_balanced(loop, client, server): - clients = [aioftp.Client(loop=loop) for _ in range(5)] + clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)] for client in clients[:-1]: await client.connect("127.0.0.1", PORT) diff --git a/tests/test-throttle.py b/tests/test-throttle.py index f0d9689..7549abf 100644 --- a/tests/test-throttle.py +++ b/tests/test-throttle.py @@ -293,15 +293,15 @@ async def test_server_global_write_throttle_multi_users(loop, client, server, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.download( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.download( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -335,15 +335,15 @@ async def test_server_global_read_throttle_multi_users(loop, client, server, *, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.upload( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.upload( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -378,15 +378,15 @@ async def test_server_per_connection_write_throttle_multi_users(loop, client, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.download( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.download( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -421,15 +421,15 @@ async def test_server_per_connection_read_throttle_multi_users(loop, client, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.upload( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.upload( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -463,15 +463,15 @@ async def test_server_user_per_connection_write_throttle_multi_users(loop, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.download( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.download( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -505,15 +505,15 @@ async def test_server_user_per_connection_read_throttle_multi_users(loop, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.upload( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.upload( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -546,15 +546,15 @@ async def test_server_user_global_write_throttle_multi_users(loop, client, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.download( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.download( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt" @@ -587,15 +587,15 @@ async def test_server_user_global_read_throttle_multi_users(loop, client, async def worker(fname): - client = aioftp.Client(loop=loop) - await client.connect("127.0.0.1", PORT) - await client.login() - await client.upload( + _client = aioftp.Client(loop=loop, ssl=client.ssl) + await _client.connect("127.0.0.1", PORT) + await _client.login() + await _client.upload( "tests/foo/foo.txt", str.format("tests/foo/{}", fname), write_into=True ) - await client.quit() + await _client.quit() fnames = ("bar.txt", "baz.txt", "hurr.txt") big_file = tmp_dir / "foo.txt"