Skip to content

Commit

Permalink
clean up ssh.py submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
mpenkov committed Feb 20, 2024
1 parent efc32ca commit 3b8372c
Showing 1 changed file with 87 additions and 38 deletions.
125 changes: 87 additions & 38 deletions smart_open/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
import logging
import urllib.parse

from typing import (
Dict,
Callable,
Tuple,
)

try:
import paramiko
except ImportError:
Expand Down Expand Up @@ -70,6 +76,25 @@ def _str2bool(string):
return True
raise ValueError(f"Expected 'yes' / 'no', got {string}.")

#
# The parameter names used by Paramiko (and smart_open) differ slightly from
# those used in ~/.ssh/config, so we use a mapping to bridge the gap.
#
# The keys are option names as they appear in Paramiko (and smart_open).
# The values are tuples containing:
#
# 1. their corresponding names in the ~/.ssh/config file
# 2. a callable to convert the parameter value from a string to the appropriate type
#
# Maps each Paramiko (and smart_open) connect option to a pair of:
# (name of the matching ~/.ssh/config option, converter from its string value).
_PARAMIKO_CONFIG_MAP: Dict[str, Tuple[str, Callable]] = {
    # The only option in this table that is parsed as a number.
    "timeout": ("connecttimeout", float),
    # Every remaining option is a 'yes' / 'no' flag converted by _str2bool.
    **{
        param_name: (sshcfg_name, _str2bool)
        for param_name, sshcfg_name in (
            ("compress", "compression"),
            ("gss_auth", "gssapiauthentication"),
            ("gss_kex", "gssapikeyexchange"),
            ("gss_deleg_creds", "gssapidelegatecredentials"),
            ("gss_trust_dns", "gssapitrustdns"),
        )
    },
}


def parse_uri(uri_as_string):
split_uri = urllib.parse.urlsplit(uri_as_string)
Expand Down Expand Up @@ -117,57 +142,81 @@ def _maybe_fetch_config(host, username=None, password=None, port=None, transport
transport_params["connect_kwargs"] = {}

# Attempt to load an OpenSSH config.
# NOTE: connections configured in this way are not guaranteed to perform exactly as
# they do in typical usage due to mismatches between the set of OpenSSH configuration
# options and those that Paramiko supports. We provide a best attempt,
# and support:
#
# Connections configured in this way are not guaranteed to perform exactly
# as they do in typical usage due to mismatches between the set of OpenSSH
# configuration options and those that Paramiko supports. We provide a best
# attempt, and support:
#
# - hostname -> address resolution
# - username inference
# - port inference
# - identityfile inference
# - connection timeout inference
# - compression selection
# - GSS configuration
for config_filename in _SSH_CONFIG_FILES:
if os.path.exists(config_filename):
#
connect_params = transport_params["connect_kwargs"]
config_files = [f for f in _SSH_CONFIG_FILES if os.path.exists(f)]
#
# This is the actual name of the host. The input host may actually be an
# alias.
#
actual_hostname = ""

for config_filename in config_files:
try:
cfg = paramiko.SSHConfig.from_path(config_filename)
except PermissionError:
continue

if host not in cfg.get_hostnames():
continue

cfg = cfg.lookup(host)
if username is None:
username = cfg.get("user", None)

if not actual_hostname:
actual_hostname = cfg["hostname"]

if port is None:
try:
cfg = paramiko.SSHConfig.from_path(config_filename)
except PermissionError:
continue
if host in cfg.get_hostnames():
cfg = cfg.lookup(host)
host = cfg["hostname"]
if username is None:
username = cfg.get("user", None)
if port is None and cfg.get("port", None) is not None:
port = int(cfg["port"])

# Special case, as we can have multiple identity files, so we check that the
# identityfile list has len > 0. This should be redundant, but keeping it for safety.
if (transport_params["connect_kwargs"].get("key_filename", None) is None
and "identityfile" in cfg and len(cfg.get("identityfile", []))
):
transport_params["connect_kwargs"]["key_filename"] = cfg["identityfile"]

# Map parameters from config to their required values for Paramiko's `connect` fn.
_connect_kwarg_map = dict(
timeout=dict(key="connecttimeout", type=float),
compress=dict(key="compression", type=_str2bool),
gss_auth=dict(key="gssapiauthentication", type=_str2bool),
gss_kex=dict(key="gssapikeyexchange", type=_str2bool),
gss_deleg_creds=dict(key="gssapidelegatecredentials", type=_str2bool),
gss_trust_dns=dict(key="gssapitrustdns", type=_str2bool)
)
for target, field in _connect_kwarg_map.items():
if (
transport_params["connect_kwargs"].get(target, None) is None and field["key"] in cfg
):
transport_params["connect_kwargs"][target] = field["type"](cfg[field["key"]])
port = int(cfg["port"])
except (IndexError, ValueError):
#
# Nb. ignore missing/invalid port numbers
#
pass

#
# Special case, as we can have multiple identity files, so we check
# that the identityfile list has len > 0. This should be redundant, but
# keeping it for safety.
#
if connect_params.get("key_filename") is None:
identityfile = cfg.get("identityfile", [])
if len(identityfile):
connect_params["key_filename"] = identityfile

for param_name, (sshcfg_name, from_str) in _PARAMIKO_CONFIG_MAP.items():
if connect_params.get(param_name) is None and sshcfg_name in cfg:
connect_params[param_name] = from_str(cfg[sshcfg_name])

#
# Continue working through other config files, if there are any,
# as they may contain more options for our host
#

if port is None:
port = DEFAULT_PORT

if not username:
username = getpass.getuser()

if actual_hostname:
host = actual_hostname

return host, username, password, port, transport_params


Expand Down

0 comments on commit 3b8372c

Please sign in to comment.