-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_openvpnclient.py
262 lines (206 loc) · 8.77 KB
/
test_openvpnclient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Test the OpenVPNClient class.
The client and server authenticate each other using trusted certificate fingerprints. See:
https://github.com/openvpn/openvpn/blob/master/doc/man-sections/example-fingerprint.rst
Test cases:
1. Connect and disconnect the OpenVPN client manually
2. Connect and disconnect the OpenVPN client automatically using the context manager
3. Disconnect OpenVPN client automatically on SIGINT (Ctrl+C)
4. Disconnect when not connected
5. Connect when already connected
6. Invalid client configuration syntax
7. Server not reachable (invalid IP)
8. Wrong path to ovpn config file
9. Connection attempt timeout
"""
# ruff: noqa: S101, test code should use asserts
from __future__ import annotations
import os
import signal
import subprocess
from pathlib import Path
from subprocess import DEVNULL, PIPE, Popen
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable, Generator
import pytest
from openvpnclient import PID_FILE, STDERR_FILE, STDOUT_FILE, OpenVPNClient, Status
@pytest.fixture(autouse=True)
def check_no_lingering_files() -> Generator[None, None, None]:
    """Assert after every test that the PID, stderr and stdout files are gone."""
    # No setup is needed; everything after the yield runs as teardown.
    yield
    for leftover in (PID_FILE, STDERR_FILE, STDOUT_FILE):
        assert not leftover.exists()
@pytest.fixture
def openvpn_client(paths: dict[str, str]) -> OpenVPNClient:
    """Return a new OpenVPNClient for the unmodified client configuration.

    The fixture is function-scoped, so every test receives its own instance.
    """
    return OpenVPNClient(paths["clientconfig"])
@pytest.fixture(scope="module")
def server_details() -> dict[str, str]:
    """Return the network settings of the local test server.

    Keys:
        public_port: Port given to the server's ``--port`` option and written
            into the client ``remote`` directive.
        public_ip: Address written into the client ``remote`` directive.
        base_ip, netmask: Arguments of the server's ``--server`` option.
    """
    return {
        "public_port": "42812",
        "public_ip": "127.0.0.1",
        "base_ip": "127.0.0.0",
        "netmask": "255.255.255.0",
    }
@pytest.fixture(scope="module")
def tmp_dir(tmpdir_factory: pytest.TempdirFactory) -> str:
    """Create a module-wide scratch directory and return its path as a string."""
    scratch_dir = tmpdir_factory.mktemp("ovpn")
    return str(scratch_dir)
@pytest.fixture(scope="module")
def paths(tmp_dir: str) -> dict[str, str]:
    """Map logical names to the file paths used by the tests.

    Every entry lives directly under ``tmp_dir``. The files themselves are not
    created here; other fixtures write the credentials and configurations.
    """
    filenames = {
        "servercrt": "server.crt",
        "serverpkey": "server.key",
        "clientcrt": "client.crt",
        "clientpkey": "client.key",
        "clientconfig": "client.ovpn",
        "clientconfig_badserver": "badserver.ovpn",
        "clientconfig_badsyntax": "badsyntax.ovpn",
    }
    locations = {key: f"{tmp_dir}/{filename}" for key, filename in filenames.items()}
    # Deliberately the directory itself: used as a path that is not a config file.
    locations["not_a_config_path"] = tmp_dir
    return locations
@pytest.fixture(scope="module")
def fingerprint(
    gen_creds: Callable[[str, str, str], None], paths: dict[str, str]
) -> dict[str, str]:
    """Generate client/server certificates at `paths` and return their fingerprints.

    Returns:
        A dict with the keys ``"client"`` and ``"server"`` holding the SHA-256
        fingerprint reported by ``openssl x509`` for the respective certificate.

    Raises:
        subprocess.CalledProcessError: If an openssl invocation fails.
    """
    gen_creds("CLIENT", paths["clientpkey"], paths["clientcrt"])
    gen_creds("SERVER", paths["serverpkey"], paths["servercrt"])

    def get_fingerprint(certpath: str) -> str:
        """Return the fingerprint part of openssl's output for `certpath`."""
        # An argv list keeps a path containing whitespace as a single argument,
        # which splitting a formatted command string would not.
        completed = subprocess.run(
            ["openssl", "x509", "-fingerprint", "-sha256", "-in", certpath, "-noout"],
            stdout=PIPE,
            text=True,
            check=True,
        )
        # Keep only what follows the "=" separator in the printed line.
        return completed.stdout.split("=", 1)[1].strip()

    return {
        "client": get_fingerprint(paths["clientcrt"]),
        "server": get_fingerprint(paths["servercrt"]),
    }
@pytest.fixture(scope="module")
def gen_creds() -> Callable[[str, str, str], None]:
    """Return a helper that creates a private key and a self-signed certificate.

    The returned callable takes ``(ident, keypath, certpath)``: it writes a
    secp384r1 EC key to ``keypath`` and a certificate valid for one day with
    subject ``/CN=TEST<ident>`` to ``certpath``. It raises
    ``subprocess.CalledProcessError`` if openssl fails.
    """

    def gen(ident: str, keypath: str, certpath: str) -> None:
        """Run openssl to create the key and then the certificate signed with it."""
        # Argv lists (instead of splitting a formatted string) keep paths that
        # contain whitespace intact.
        keygen_argv = [
            "openssl", "ecparam",
            "-name", "secp384r1",
            "-genkey", "-noout",
            "-out", keypath,
        ]
        gen_cert_argv = [
            "openssl", "req", "-x509",
            "-new", "-key", keypath,
            "-out", certpath,
            "-sha256", "-days", "1", "-nodes",
            "-subj", f"/CN=TEST{ident}",
        ]
        subprocess.run(keygen_argv, check=True)
        subprocess.run(gen_cert_argv, check=True)

    return gen
@pytest.fixture(scope="module", autouse=True)
def gen_clientconfs(
    server_details: dict[str, str],
    fingerprint: dict[str, str],
    paths: dict[str, str],
) -> None:
    """Write the three client configurations used by the tests.

    * ``clientconfig``: the working configuration, with the client key and
      certificate inlined and the server fingerprint pinned.
    * ``clientconfig_badserver``: the same text with every "1" replaced by "3".
    * ``clientconfig_badsyntax``: the same text with every "client" replaced
      by "testing".
    """
    client_key = Path(paths["clientpkey"]).read_text(encoding="ascii")
    client_cert = Path(paths["clientcrt"]).read_text(encoding="ascii")
    conf = (
        "client\n"
        f"remote {server_details['public_ip']} {server_details['public_port']}\n"
        "explicit-exit-notify 5\n"
        "<key>\n"
        f"{client_key}\n"
        "</key>\n"
        "<cert>\n"
        f"{client_cert}\n"
        "</cert>\n"
        f"peer-fingerprint {fingerprint['server']}"
    )
    Path(paths["clientconfig"]).write_text(conf, encoding="ascii")

    # NOTE(review): both replacements act on the whole text, so any match inside
    # the inlined key/certificate or the fingerprint is rewritten too, not only
    # the `remote` / `client` lines. Kept as-is because the tests rely on it.
    Path(paths["clientconfig_badserver"]).write_text(
        conf.replace("1", "3"), encoding="ascii"
    )
    Path(paths["clientconfig_badsyntax"]).write_text(
        conf.replace("client", "testing"), encoding="ascii"
    )
@pytest.fixture(scope="module", autouse=True)
def local_server(
    server_details: dict[str, str],
    paths: dict[str, str],
    fingerprint: dict[str, str],
) -> Generator[None, None, None]:
    """Run a local OpenVPN server for the duration of the test module.

    The server is started through ``sudo``. When
    ``OpenVPNClient._must_supply_password()`` is true, ``sudo -S`` is used and
    the password is read from the ``SUDO_PASSWORD`` environment variable.
    On teardown the server is terminated with ``sudo kill`` and reaped.

    Raises:
        KeyError: If a password is required but ``SUDO_PASSWORD`` is not set.
    """
    import warnings  # only needed for the teardown diagnostic below

    must_supply_password = OpenVPNClient._must_supply_password()  # noqa: SLF001
    # `-S` makes sudo read the password from stdin.
    sudo_argv = ["sudo", "-S"] if must_supply_password else ["sudo"]
    sudo_input = (os.environ["SUDO_PASSWORD"] + "\n") if must_supply_password else None

    # An argv list avoids breaking certificate/key paths that contain whitespace.
    server_argv = [
        *sudo_argv,
        "openvpn",
        "--server", server_details["base_ip"], server_details["netmask"],
        "--port", server_details["public_port"],
        "--peer-fingerprint", fingerprint["client"],
        "--cert", paths["servercrt"],
        "--key", paths["serverpkey"],
        "--dev", "tun_server",
        "--dh", "none",
        "--verb", "3",
    ]
    # The server output is never read, so discard it: an unread pipe can fill
    # up and block the server process.
    srv_proc = Popen(
        server_argv, text=True, stdin=PIPE, stdout=DEVNULL, stderr=DEVNULL
    )
    if sudo_input is not None:
        srv_proc.stdin.write(sudo_input)
        srv_proc.stdin.flush()

    yield

    kill_proc = Popen(
        [*sudo_argv, "kill", str(srv_proc.pid)],
        text=True,
        stdin=PIPE,
        stdout=DEVNULL,
        stderr=DEVNULL,
    )
    # communicate() sends the password (if any), closes stdin and waits for the
    # kill helper to finish, so it is not left behind as a zombie.
    kill_proc.communicate(sudo_input)

    # Reap the server as well so it is really gone before the module ends.
    try:
        srv_proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        warnings.warn(
            "OpenVPN test server did not exit within 10 seconds after kill",
            RuntimeWarning,
            stacklevel=2,
        )
def test_connect_then_disconnect(openvpn_client: OpenVPNClient) -> None:
    """Connect manually, verify the status, then disconnect again."""
    openvpn_client.connect()
    status_after_connect = openvpn_client.status
    assert status_after_connect is Status.CONNECTED

    openvpn_client.disconnect()
    # Presumably -1 means that no client PID is recorded any more.
    remaining_pid = OpenVPNClient._get_pid()  # noqa: SLF001
    assert remaining_pid == -1
def test_context_manager(openvpn_client: OpenVPNClient) -> None:
    """Connect and disconnect implicitly through the ``with`` statement."""
    with openvpn_client as connected_client:
        assert connected_client.status is Status.CONNECTED

    # Checked after the block has exited, i.e. after the implicit disconnect.
    remaining_pid = OpenVPNClient._get_pid()  # noqa: SLF001
    assert remaining_pid == -1
def test_ctrlc_disconnects(paths: dict) -> None:
    """A SIGINT must disconnect a client connected with ``sigint_disconnect=True``."""
    client = OpenVPNClient(paths["clientconfig"])
    client.connect(sigint_disconnect=True)

    # Simulate Ctrl+C by sending SIGINT to the test process itself.
    own_pid = os.getpid()
    with pytest.raises(KeyboardInterrupt):
        os.kill(own_pid, signal.SIGINT)

    remaining_pid = client._get_pid()  # noqa: SLF001
    assert remaining_pid == -1
    assert client.status is Status.USER_CANCELLED
def test_disconnect_when_not_connected(openvpn_client: OpenVPNClient) -> None:
    """Calling disconnect() without a prior connect() raises ProcessLookupError."""
    expect_lookup_error = pytest.raises(ProcessLookupError)
    with expect_lookup_error:
        openvpn_client.disconnect()
def test_already_connected(openvpn_client: OpenVPNClient) -> None:
    """A second connect() while connected raises ConnectionRefusedError."""
    openvpn_client.connect()
    try:
        with pytest.raises(ConnectionRefusedError):
            openvpn_client.connect()
    finally:
        # Always disconnect, even when the expectation above fails, so a
        # still-running client cannot break the tests that follow.
        openvpn_client.disconnect()
def test_invalid_client_config_syntax(paths: dict) -> None:
    """Entering the context with the bad-syntax config raises TimeoutError."""
    broken_config = paths["clientconfig_badsyntax"]
    with pytest.raises(TimeoutError), OpenVPNClient(broken_config):
        # The body only runs if a connection was established after all.
        raise AssertionError("Should not reach here")  # noqa: EM101, TRY003
def test_server_not_reachable(paths: dict) -> None:
    """Entering the context with the bad-server config raises TimeoutError."""
    unreachable_config = paths["clientconfig_badserver"]
    with pytest.raises(TimeoutError), OpenVPNClient(unreachable_config):
        # The body only runs if a connection was established after all.
        raise AssertionError("Should not reach here")  # noqa: EM101, TRY003
def test_invalid_paths(paths: dict) -> None:
    """A path that is not a config file raises FileNotFoundError."""
    not_a_config = paths["not_a_config_path"]
    with pytest.raises(FileNotFoundError), OpenVPNClient(not_a_config):
        # The body only runs if the invalid path was accepted after all.
        raise AssertionError("Should not reach here")  # noqa: EM101, TRY003
def test_connection_attempt_timeout(paths: dict) -> None:
    """With ``connect_timeout=0.5`` the connection attempt raises TimeoutError."""
    valid_config = paths["clientconfig"]
    with (
        pytest.raises(TimeoutError),
        OpenVPNClient(valid_config, connect_timeout=0.5),
    ):
        # The body only runs if the connection came up within the timeout.
        raise AssertionError("Should not reach here")  # noqa: EM101, TRY003