diff --git a/src/qslib/machine.py b/src/qslib/machine.py index 8d1c3b2..d718654 100644 --- a/src/qslib/machine.py +++ b/src/qslib/machine.py @@ -144,6 +144,8 @@ def __init__( max_access_level: AccessLevel | str = AccessLevel.Controller, port: int | None = None, ssl: bool | None = None, + client_certificate_path: str | None = None, + server_ca_file: str | None = None, _initial_access_level: AccessLevel | str = AccessLevel.Observer, ): self.host = host @@ -154,6 +156,8 @@ def __init__( self.max_access_level = AccessLevel(max_access_level) self._initial_access_level = AccessLevel(_initial_access_level) self._connection = None + self.client_certificate_path = client_certificate_path + self.server_ca_file = server_ca_file def connect(self) -> None: """Open the connection manually.""" @@ -165,6 +169,8 @@ def connect(self) -> None: ssl=self.ssl, password=self.password, initial_access_level=self._initial_access_level, + client_certificate_path=self.client_certificate_path, + server_ca_file=self.server_ca_file, ) loop.run_until_complete(self.connection.connect()) self._current_access_level = self.get_access_level()[0] diff --git a/src/qslib/qsconnection_async.py b/src/qslib/qsconnection_async.py index 5cbb07f..72812a0 100644 --- a/src/qslib/qsconnection_async.py +++ b/src/qslib/qsconnection_async.py @@ -25,14 +25,6 @@ log = logging.getLogger(__name__) -CTX = ssl.create_default_context() -CTX.check_hostname = False -CTX.verify_mode = ssl.CERT_NONE -CTX.minimum_version = ( - ssl.TLSVersion.SSLv3 -) # Yes, we actually need this for QS5 connections - - def _gen_auth_response(password: str, challenge_string: str) -> str: return hmac.digest(password.encode(), challenge_string.encode(), "md5").hex() @@ -240,6 +232,8 @@ def __init__( authenticate_on_connect: bool = True, initial_access_level: AccessLevel = AccessLevel.Observer, password: Optional[str] = None, + client_certificate_path: Optional[str] = None, + server_ca_file: Optional[str] = None, ): """Create a connection to a QuantStudio Instrument Server.""" self.host = host @@ -248,6 +242,8 @@ def __init__( self.password = password self._initial_access_level = initial_access_level self._authenticate_on_connect = authenticate_on_connect + self.client_certificate_path = client_certificate_path + self.server_ca_file = server_ca_file def _parse_access_line(self, aline: str) -> None: # pylint: disable=attribute-defined-outside-init @@ -274,6 +270,19 @@ async def connect( if initial_access_level is not None: self._initial_access_level = initial_access_level + + CTX = ssl.create_default_context() + CTX.check_hostname = False + CTX.verify_mode = ssl.CERT_NONE + CTX.minimum_version = ( + ssl.TLSVersion.SSLv3 + ) # Yes, we actually need this for QS5 connections + if self.client_certificate_path is not None: + CTX.load_cert_chain(self.client_certificate_path) + if self.server_ca_file is not None: + CTX.load_verify_locations(self.server_ca_file) + CTX.verify_mode = ssl.CERT_REQUIRED + self.loop = asyncio.get_running_loop() if (self.ssl is None) and (self.port is None):