Skip to content

Commit

Permalink
Support TCP/IP remote forward. (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Feb 14, 2025
1 parent e0cc202 commit 5b43d19
Show file tree
Hide file tree
Showing 15 changed files with 659 additions and 17 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class SshClient : IDisposable
Task<DirectForward> StartForwardAsync(EndPoint bindEP, RemoteEndPoint remoteEP, CancellationToken cancellationToken = default);
Task<SocksForward> StartForwardSocksAsync(EndPoint bindEP, CancellationToken cancellationToken = default);

Task<RemoteListener> ListenTcpAsync(string address, int port, CancellationToken cancellationToken = default);

Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken);
Task<SftpClient> OpenSftpClientAsync(SftpClientOptions? options = null, CancellationToken cancellationToken = default)
}
Expand Down Expand Up @@ -125,6 +127,8 @@ class RemoteIPEndPoint(IPAddress address, int port) : RemoteEndPoint
{ }
class RemoteUnixEndPoint(string path) : RemoteEndPoint
{ }
class RemoteIPListenEndPoint(string address, int port) : RemoteEndPoint
{ }
class DirectForward : IDisposable
{
EndPoint LocalEndPoint { get; }
Expand All @@ -137,6 +141,26 @@ class SocksForward : IDisposable
CancellationToken Stopped { get; }
void ThrowIfStopped();
}
class RemoteListener : IDisposable
{
// For ListenTcpAsync, type is RemoteIPListenEndPoint.
RemoteEndPoint ListenEndPoint { get; }

// This method throws when the SshClient disconnects (SshConnectionClosedException), or the RemoteListener is disposed (ObjectDisposedException).
// Calling Stop makes the method return a default(RemoteConnection) instead.
ValueTask<RemoteConnection> AcceptAsync(CancellationToken cancellationToken = default);

void Stop();
}
struct RemoteConnection : IDisposable
{
// For ListenTcpAsync, type is RemoteIPEndPoint.
RemoteEndPoint? RemoteEndPoint { get; }

SshDataStream? Stream { get; }
bool HasStream { get; }
Stream MoveStream(); // Transfers ownership of the Stream to the caller.
}
class SftpClient : IDisposable
{
// Note: umask is applied on the server.
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/AlgorithmNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ static class AlgorithmNames
public static Name ServerStrictKex => Name.FromKnownNameString(KnownNameStrings.ServerStrictKex);
// Extension Negotiation
public static Name ClientExtensionNegotiation => Name.FromKnownNameString(KnownNameStrings.ClientExtensionNegotiation);
// Channel types
public static Name ForwardTcpIp => Name.FromKnownNameString(KnownNameStrings.ForwardTcpIp);

// For GetSignatureAlgorithmsForKeyType
internal static readonly Name[] SshRsaAlgorithms = [ RsaSshSha2_512, RsaSshSha2_256 ];
Expand Down
35 changes: 31 additions & 4 deletions src/Tmds.Ssh/ArgumentValidation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@ public static void ValidatePort(int port, bool allowZero, string argumentName =
}
}

public static void ValidateIPListenAddress(string address, string argumentName = "address")
{
if (address is null)
{
throw new ArgumentNullException(argumentName);
}

if (address.Length == 0)
{
throw new ArgumentException("The address can not be empty.", argumentName);
}

if (address == Constants.AnyAddress)
{
return;
}

if (!IsValidHostName(address))
{
throw new ArgumentException("The address is not valid", argumentName);
}
}

public static void ValidateHost(string host, bool allowEmpty = false, string argumentName = "host")
{
if (host is null)
Expand All @@ -29,12 +52,16 @@ public static void ValidateHost(string host, bool allowEmpty = false, string arg
throw new ArgumentException("The host can not be empty.", argumentName);
}

UriHostNameType hostNameType = Uri.CheckHostName(host);
bool isValid = hostNameType is UriHostNameType.IPv4 or UriHostNameType.IPv6 or UriHostNameType.Dns;

if (!isValid)
if (!IsValidHostName(host))
{
throw new ArgumentException("The host name is not valid", argumentName);
}
}

private static bool IsValidHostName(string address)
{
// Check whether the name is an IPv4/IPv6/DNS name using 'Uri.CheckHostName'.
// Disallow IPv6 addresses to be enclosed with '[]'.
return !address.StartsWith('[') && Uri.CheckHostName(address) is UriHostNameType.IPv4 or UriHostNameType.IPv6 or UriHostNameType.Dns;
}
}
1 change: 1 addition & 0 deletions src/Tmds.Ssh/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ static class Constants
public const int MaxMPIntLength = 1024; // Arbitrary limit, may be increased.
public const int MaxBannerPackets = 1024; // Abitrary limit
public const int MaxPartialAuths = 16; // Abitrary limit
public const string AnyAddress = "*"; // Public API listen address for "all IPv4 and all IPv6".
}
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/KnownNameStrings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ static class KnownNameStrings
internal const string ClientStrictKex = "kex-strict-c-v00@openssh.com";
internal const string ServerStrictKex = "kex-strict-s-v00@openssh.com";
internal const string ClientExtensionNegotiation = "ext-info-c";
internal const string ForwardTcpIp = "forwarded-tcpip";

public static string? FindKnownName(ReadOnlySpan<char> name)
{
Expand Down Expand Up @@ -103,6 +104,7 @@ static class KnownNameStrings
case ClientStrictKex: return ClientStrictKex;
case ServerStrictKex: return ServerStrictKex;
case ClientExtensionNegotiation: return ClientExtensionNegotiation;
case ForwardTcpIp: return ForwardTcpIp;
default: return null;
}
}
Expand Down
28 changes: 28 additions & 0 deletions src/Tmds.Ssh/RemoteConnection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

namespace Tmds.Ssh;

public struct RemoteConnection : IDisposable
{
internal RemoteConnection(SshDataStream stream, RemoteEndPoint? remoteEndPoint)
{
Stream = stream;
RemoteEndPoint = remoteEndPoint;
}

public bool HasStream => Stream is not null;

public SshDataStream MoveStream()
{
var stream = Stream;
Stream = null;
return stream ?? throw new InvalidOperationException("There is no stream to obtain.");
}

public RemoteEndPoint? RemoteEndPoint { get; }
public SshDataStream? Stream { get; private set; }

public void Dispose()
=> Stream?.Dispose();
}
21 changes: 20 additions & 1 deletion src/Tmds.Ssh/RemoteEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public RemoteIPEndPoint(IPAddress address, int port)
}

public override string ToString()
=> _toString ??= Address.AddressFamily == AddressFamily.InterNetworkV6 ? $"[{Address}]:{Port}" : $"{Address}:{Port}";
=> _toString ??= Address.AddressFamily == AddressFamily.InterNetworkV6 ? $"[{Address}]:{Port}" : $"{Address}:{Port}";
}

public sealed class RemoteUnixEndPoint : RemoteEndPoint
Expand All @@ -63,4 +63,23 @@ public RemoteUnixEndPoint(string path)

public override string ToString()
=> Path;
}

public sealed class RemoteIPListenEndPoint : RemoteEndPoint
{
public string Address { get; }
public int Port { get; }
private string? _toString;

public RemoteIPListenEndPoint(string address, int port)
{
ArgumentValidation.ValidateIPListenAddress(address);
ArgumentValidation.ValidatePort(port, allowZero: false);

Address = address;
Port = port;
}

public override string ToString()
=> _toString ??= Address.Contains(':') ? $"[{Address}]:{Port}" : $"{Address}:{Port}";
}
136 changes: 136 additions & 0 deletions src/Tmds.Ssh/RemoteListener.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System.Diagnostics;
using System.Threading.Channels;

namespace Tmds.Ssh;

public sealed class RemoteListener : IDisposable
{
// Sentinel stop reasons.
private static readonly Exception ConnectionClosed = new();
private static readonly Exception Disposed = new();
private static readonly Exception Stopped = new();

private readonly Channel<RemoteConnection> _connectionChannel;

public RemoteEndPoint ListenEndPoint => _listenEndPoint ?? throw new InvalidOperationException("Not started");

private SshSession? _session;
private RemoteEndPoint? _listenEndPoint;
private Name _forwardType;
private CancellationTokenRegistration _ctr;
private Exception? _stopReason;

public void Stop()
=> Stop(Stopped);

public void Dispose()
=> Stop(Disposed);

public async ValueTask<RemoteConnection> AcceptAsync(CancellationToken cancellationToken = default)
{
while (true)
{
if (!await _connectionChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
Exception? stopReason = _stopReason;
if (ReferenceEquals(stopReason, Stopped))
{
// return 'null' when the user called 'Stop' to indicate no more connections should be accepted.
return default;
}
else if (ReferenceEquals(stopReason, Disposed))
{
throw new ObjectDisposedException(GetType().FullName);
}
else if (ReferenceEquals(stopReason, ConnectionClosed))
{
throw _session!.CreateCloseException();
}
else
{
throw new SshException($"{GetType().FullName} stopped due to an unexpected error.", stopReason);
}
}

// TryRead may return false if we're competing with Stop.
if (_connectionChannel.Reader.TryRead(out RemoteConnection remoteConnection))
{
Debug.Assert(remoteConnection.HasStream);
if (remoteConnection.HasStream)
{
return new RemoteConnection(remoteConnection.MoveStream(), remoteConnection.RemoteEndPoint);
}
}
}
}

private void Stop(Exception stopReason)
{
if (Interlocked.CompareExchange(ref _stopReason, stopReason, null) != null)
{
return;
}

if (_listenEndPoint is not null)
{
_ctr.Dispose();

string? address = null;
ushort port = 0;
if (_listenEndPoint is RemoteIPListenEndPoint ipListenEndPoint)
{
address = ipListenEndPoint.Address;
port = (ushort)ipListenEndPoint.Port;
}
else
{
Debug.Assert(false);
}
Debug.Assert(address is not null);
_session?.StopRemoteForward(_forwardType, address, port);
}

_connectionChannel.Writer.Complete();

while (_connectionChannel.Reader.TryRead(out RemoteConnection connection))
{
Debug.Assert(connection.HasStream);
connection.Dispose();
}
}

internal RemoteListener()
{
_connectionChannel = Channel.CreateUnbounded<RemoteConnection>();
}

private async Task OpenAsync(SshSession session, Name forwardType, string address, ushort port, CancellationToken cancellationToken)
{
_session = session;
_forwardType = forwardType;

try
{
port = await _session.StartRemoteForwardAsync(forwardType, address, port, _connectionChannel.Writer, cancellationToken).ConfigureAwait(false);
_listenEndPoint = new RemoteIPListenEndPoint(address, port);
_ctr = _session.ConnectionClosed.UnsafeRegister(o => ((RemoteListener)o!).Stop(ConnectionClosed), this);
}
catch (Exception ex)
{
Stop(ex);

throw;
}
}

internal Task OpenTcpAsync(SshSession session, string address, int port, CancellationToken cancellationToken)
{
ArgumentValidation.ValidateIPListenAddress(address);
ArgumentValidation.ValidatePort(port, allowZero: true);

return OpenAsync(session, AlgorithmNames.ForwardTcpIp, address, (ushort)port, cancellationToken);
}
}
6 changes: 6 additions & 0 deletions src/Tmds.Ssh/SshChannel.OpenMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ sealed partial class SshChannel : ISshChannel
public void TrySendChannelOpenDirectStreamLocalMessage(string socketPath)
=> TrySendPacket(_sequencePool.CreateChannelOpenDirectStreamLocalMessage(LocalChannel, (uint)_receiveWindow, (uint)ReceiveMaxPacket, socketPath));

public void TrySendChannelOpenConfirmationMessage(uint remoteChannel)
{
TrySendPacket(_sequencePool.CreateChannelOpenConfirmationMessage(remoteChannel, LocalChannel, (uint)_receiveWindow, (uint)ReceiveMaxPacket));
_sendCloseOnDispose = 1;
}

public void TrySendChannelOpenSessionMessage()
=> TrySendPacket(_sequencePool.CreateChannelOpenSessionMessage(LocalChannel, (uint)_receiveWindow, (uint)ReceiveMaxPacket));

Expand Down
7 changes: 6 additions & 1 deletion src/Tmds.Ssh/SshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@ enum AbortState
Disposed // By user
}

public SshChannel(SshSession client, SequencePool sequencePool, uint channelNumber, Type channelType, Action<SshChannel>? onAbort = null)
public SshChannel(SshSession client, SequencePool sequencePool, uint channelNumber, Type channelType, Action<SshChannel>? onAbort = null,
uint remoteChannel = 0, int sendMaxPacket = 0, int sendWindow = 0)
{
LocalChannel = channelNumber;
_client = client;
_sequencePool = sequencePool;
_receiveWindow = MaxWindowSize;
_channelType = channelType;
_onAbort = onAbort;

RemoteChannel = remoteChannel;
SendMaxPacket = sendMaxPacket;
_sendWindow = sendWindow;
}

public CancellationToken ChannelAborted
Expand Down
9 changes: 9 additions & 0 deletions src/Tmds.Ssh/SshClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,15 @@ public async Task<SocksForward> StartSocksForward(EndPoint bindEP, CancellationT
return forward;
}

public async Task<RemoteListener> ListenTcpAsync(string address, int port, CancellationToken cancellationToken = default)
{
SshSession session = await GetSessionAsync(cancellationToken).ConfigureAwait(false);

var listener = new RemoteListener();
await listener.OpenTcpAsync(session, address, port, cancellationToken).ConfigureAwait(false);
return listener;
}

public Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken)
=> OpenSftpClientAsync(null, cancellationToken);

Expand Down
Loading

0 comments on commit 5b43d19

Please sign in to comment.