Skip to content

Commit

Permalink
Support remote forwarding from/to unix sockets. (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Feb 19, 2025
1 parent ca0dcf8 commit a6a99d4
Show file tree
Hide file tree
Showing 13 changed files with 288 additions and 32 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ class SshClient : IDisposable
Task<SshDataStream> OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default);

Task<RemoteListener> ListenTcpAsync(string address, int port, CancellationToken cancellationToken = default);
Task<RemoteListener> ListenUnixAsync(string path, CancellationToken cancellationToken = default);

// bindEP can be an IPEndPoint or a UnixDomainSocketEndPoint.
// remoteEP can be a RemoteHostEndPoint, a RemoteUnixEndPoint or a RemoteIPEndPoint.
Task<DirectForward> StartForwardAsync(EndPoint bindEP, RemoteEndPoint remoteEP, CancellationToken cancellationToken = default);
Task<SocksForward> StartForwardSocksAsync(EndPoint bindEP, CancellationToken cancellationToken = default);

// bindEP can be a RemoteIPListenEndPoint.
// localEP can be a DnsEndPoint or an IPEndPoint.
// bindEP can be a RemoteIPListenEndPoint or a RemoteUnixEndPoint.
// localEP can be a DnsEndPoint, an IPEndPoint or a UnixDomainSocketEndPoint.
Task<RemoteForward> StartRemoteForwardAsync(RemoteEndPoint bindEP, EndPoint localEP, CancellationToken cancellationToken = default);

Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken);
Expand Down Expand Up @@ -156,6 +157,7 @@ class RemoteForward : IDisposable
class RemoteListener : IDisposable
{
// For ListenTcpAsync, type is RemoteIPListenEndPoint.
// For ListenUnixAsync, type is UnixDomainSocketEndPoint.
RemoteEndPoint ListenEndPoint { get; }

// This method throws when the SshClient disconnects (SshConnectionClosedException), or the RemoteListener is disposed (ObjectDisposedException).
Expand All @@ -167,6 +169,7 @@ class RemoteListener : IDisposable
struct RemoteConnection : IDisposable
{
// For ListenTcpAsync, type is RemoteIPEndPoint.
// For ListenUnixAsync, value is 'null'.
RemoteEndPoint? RemoteEndPoint { get; }

SshDataStream? Stream { get; }
Expand Down
1 change: 1 addition & 0 deletions src/Tmds.Ssh/AlgorithmNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ static class AlgorithmNames
public static Name ClientExtensionNegotiation => Name.FromKnownNameString(KnownNameStrings.ClientExtensionNegotiation);
// Channel types
public static Name ForwardTcpIp => Name.FromKnownNameString(KnownNameStrings.ForwardTcpIp);
public static Name ForwardStreamLocal => Name.FromKnownNameString(KnownNameStrings.ForwardStreamLocal);

// For GetSignatureAlgorithmsForKeyType
internal static readonly Name[] SshRsaAlgorithms = [ RsaSshSha2_512, RsaSshSha2_256 ];
Expand Down
17 changes: 17 additions & 0 deletions src/Tmds.Ssh/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ public static async Task<Stream> ConnectTcpAsync(string host, int port, Cancella
return await _defaultConnect(context, cancellationToken).ConfigureAwait(false);
}

public static async Task<Stream> ConnectUnixAsync(UnixDomainSocketEndPoint endPoint, CancellationToken cancellationToken)
{
var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
try
{
await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);

return new NetworkStream(socket, ownsSocket: true);
}
catch
{
socket.Dispose();

throw;
}
}

public static async ValueTask<Stream> ConnectAsync(ConnectCallback? connect, Proxy? proxy, ConnectContext context, CancellationToken cancellationToken)
{
connect ??= _defaultConnect;
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/KnownNameStrings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ static class KnownNameStrings
internal const string ServerStrictKex = "kex-strict-s-v00@openssh.com";
internal const string ClientExtensionNegotiation = "ext-info-c";
internal const string ForwardTcpIp = "forwarded-tcpip";
internal const string ForwardStreamLocal = "forwarded-streamlocal@openssh.com";

public static string? FindKnownName(ReadOnlySpan<char> name)
{
Expand Down Expand Up @@ -105,6 +106,7 @@ static class KnownNameStrings
case ServerStrictKex: return ServerStrictKex;
case ClientExtensionNegotiation: return ClientExtensionNegotiation;
case ForwardTcpIp: return ForwardTcpIp;
case ForwardStreamLocal: return ForwardStreamLocal;
default: return null;
}
}
Expand Down
14 changes: 13 additions & 1 deletion src/Tmds.Ssh/RemoteForwardServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Net;
using System.Diagnostics;
using Microsoft.Extensions.Logging;
using System.Net.Sockets;

namespace Tmds.Ssh;

Expand Down Expand Up @@ -42,7 +43,7 @@ internal async ValueTask StartDirectForwardAsync(SshSession session, RemoteEndPo
private static void CheckBindEndPoint(RemoteEndPoint bindEP)
{
ArgumentNullException.ThrowIfNull(bindEP);
if (bindEP is not RemoteIPListenEndPoint)
if (bindEP is not RemoteIPListenEndPoint and not RemoteUnixEndPoint)
{
throw new ArgumentException($"Unsupported RemoteEndPoint type: {bindEP.GetType().FullName}.");
}
Expand All @@ -60,6 +61,8 @@ private static void CheckTargetEndPoint(EndPoint localEndPoint)
{
ArgumentValidation.ValidatePort(ipEndPoint.Port, allowZero: false, nameof(localEndPoint));
}
else if (localEndPoint is UnixDomainSocketEndPoint unixEndPoint)
{ }
else
{
throw new ArgumentException($"Unsupported EndPoint type: {localEndPoint.GetType().FullName}.");
Expand Down Expand Up @@ -101,6 +104,10 @@ private static void CheckTargetEndPoint(EndPoint localEndPoint)
{
connect = Connect.ConnectTcpAsync(dnsEndPoint.Host, dnsEndPoint.Port, ct);
}
else if (endPoint is UnixDomainSocketEndPoint unixEndPoint)
{
connect = Connect.ConnectUnixAsync(unixEndPoint, ct);
}
else
{
throw new InvalidOperationException("Invalid endpoint");
Expand Down Expand Up @@ -128,6 +135,11 @@ protected override async ValueTask ListenAsync(CancellationToken cancellationTok

updateEndPoint = ipEndPoint.Port == 0;
}
else if (bindEP is RemoteUnixEndPoint unixEndPoint)
{
_listener = new RemoteListener();
await _listener.OpenUnixAsync(_session, unixEndPoint.Path, cancellationToken).ConfigureAwait(false);
}
else
{
// Type must be validated before calling this method.
Expand Down
29 changes: 25 additions & 4 deletions src/Tmds.Ssh/RemoteListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,21 @@ private void Stop(Exception stopReason)
{
_ctr.Dispose();

string? address = null;
string address;
ushort port = 0;
if (_listenEndPoint is RemoteIPListenEndPoint ipListenEndPoint)
{
address = ipListenEndPoint.Address;
port = (ushort)ipListenEndPoint.Port;
}
else if (_listenEndPoint is RemoteUnixEndPoint unixEndPoint)
{
address = unixEndPoint.Path;
}
else
{
Debug.Assert(false);
throw new IndexOutOfRangeException(_forwardType);
}
Debug.Assert(address is not null);
_session?.StopRemoteForward(_forwardType, address, port);
}

Expand All @@ -115,7 +118,18 @@ private async Task OpenAsync(SshSession session, Name forwardType, string addres
try
{
port = await _session.StartRemoteForwardAsync(forwardType, address, port, _connectionChannel.Writer, cancellationToken).ConfigureAwait(false);
_listenEndPoint = new RemoteIPListenEndPoint(address, port);
if (forwardType == AlgorithmNames.ForwardTcpIp)
{
_listenEndPoint = new RemoteIPListenEndPoint(address, port);
}
else if (forwardType == AlgorithmNames.ForwardStreamLocal)
{
_listenEndPoint = new RemoteUnixEndPoint(address);
}
else
{
throw new IndexOutOfRangeException(forwardType);
}
_ctr = _session.ConnectionClosed.UnsafeRegister(o => ((RemoteListener)o!).Stop(ConnectionClosed), this);
}
catch (Exception ex)
Expand All @@ -133,4 +147,11 @@ internal Task OpenTcpAsync(SshSession session, string address, int port, Cancell

return OpenAsync(session, AlgorithmNames.ForwardTcpIp, address, (ushort)port, cancellationToken);
}

internal Task OpenUnixAsync(SshSession session, string path, CancellationToken cancellationToken)
{
ArgumentException.ThrowIfNullOrEmpty(path);

return OpenAsync(session, AlgorithmNames.ForwardStreamLocal, path, 0, cancellationToken);
}
}
9 changes: 9 additions & 0 deletions src/Tmds.Ssh/SshClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ public async Task<RemoteListener> ListenTcpAsync(string address, int port, Cance
return listener;
}

public async Task<RemoteListener> ListenUnixAsync(string path, CancellationToken cancellationToken = default)
{
SshSession session = await GetSessionAsync(cancellationToken).ConfigureAwait(false);

var listener = new RemoteListener();
await listener.OpenUnixAsync(session, path, cancellationToken).ConfigureAwait(false);
return listener;
}

public Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken)
=> OpenSftpClientAsync(null, cancellationToken);

Expand Down
34 changes: 34 additions & 0 deletions src/Tmds.Ssh/SshSequencePoolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ uint32 port number to bind
return packet.Move();
}

public static Packet CreateCancelLocalStreamForwardMessage(this SequencePool sequencePool, string path)
{
/*
byte SSH2_MSG_GLOBAL_REQUEST
string "cancel-streamlocal-forward@openssh.com"
boolean FALSE
string socket path
*/
using var packet = sequencePool.RentPacket();
var writer = packet.GetWriter();
writer.WriteMessageId(MessageId.SSH_MSG_GLOBAL_REQUEST);
writer.WriteString("cancel-streamlocal-forward@openssh.com");
writer.WriteBoolean(false); // want reply
writer.WriteString(path);
return packet.Move();
}

public static Packet CreateTcpIpForwardMessage(this SequencePool sequencePool, string address, ushort port)
{
/*
Expand All @@ -304,6 +321,23 @@ uint32 port number to bind
return packet.Move();
}

public static Packet CreateStreamLocalForwardMessage(this SequencePool sequencePool, string path)
{
/*
byte SSH2_MSG_GLOBAL_REQUEST
string "streamlocal-forward@openssh.com"
boolean TRUE
string socket path
*/
using var packet = sequencePool.RentPacket();
var writer = packet.GetWriter();
writer.WriteMessageId(MessageId.SSH_MSG_GLOBAL_REQUEST);
writer.WriteString("streamlocal-forward@openssh.com");
writer.WriteBoolean(true); // want reply
writer.WriteString(path);
return packet.Move();
}

public static Packet CreateKeepAliveMessage(this SequencePool sequencePool)
{
using var packet = sequencePool.RentPacket();
Expand Down
102 changes: 81 additions & 21 deletions src/Tmds.Ssh/SshSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,15 @@ uint32 originator port
listenAddress = new ListenAddress(channelType, address, (ushort)port);
}
}
else if (channelType == AlgorithmNames.ForwardStreamLocal)
{
/*
string socket path
string reserved for future use
*/
string path = reader.ReadUtf8String();
listenAddress = new ListenAddress(channelType, path, 0);
}

if (listenAddress != default)
{
Expand Down Expand Up @@ -936,42 +945,76 @@ public async Task<ushort> StartRemoteForwardAsync(Name forwardType, string addre
{
pendingReply = SendGlobalRequestAsync(_sequencePool.CreateTcpIpForwardMessage(address, port), wantReply: true, wantPayload: port == 0, runSync: true);
}
else if (forwardType == AlgorithmNames.ForwardStreamLocal)
{
pendingReply = SendGlobalRequestAsync(_sequencePool.CreateStreamLocalForwardMessage(address), wantReply: true, runSync: true);
}
else
{
throw new ArgumentException(nameof(forwardType));
throw new IndexOutOfRangeException(forwardType);
}

try
{
GlobalRequestReply reply = await pendingReply.ConfigureAwait(false);

if (reply.Id != MessageId.SSH_MSG_REQUEST_SUCCESS)
{
throw new SshException($"Request to start forward failed on the server with {reply.Id}.");
}

if (forwardType == AlgorithmNames.ForwardTcpIp && port == 0)
{
SequenceReader reader = new(reply.Payload);
reader.ReadMessageId();
port = (ushort)reader.ReadUInt32();
}
await pendingReply.WaitAsync(cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
_ = HandleReplyAsync(pendingReply, onCancel: true);

ListenAddress listenAddress = new ListenAddress(forwardType, address, port);
lock (_gate)
{
_remoteListeners ??= new();
_remoteListeners[listenAddress] = listener;
}
throw;
}

return port;
try
{
return await HandleReplyAsync(pendingReply, onCancel: false).ConfigureAwait(false);
}
finally
{
// We complete the request from the receive loop to update the remote listeners Dictionary.
// To prevent our caller from blocking that loop, this method MUST Yield before returning.
await Task.Yield();
}

async Task<ushort> HandleReplyAsync(Task<GlobalRequestReply> pendingReply, bool onCancel)
{
try
{
GlobalRequestReply reply = await pendingReply.ConfigureAwait(false);

if (reply.Id != MessageId.SSH_MSG_REQUEST_SUCCESS)
{
throw new SshException($"Request to start forward failed on the server with {reply.Id}.");
}

if (forwardType == AlgorithmNames.ForwardTcpIp && port == 0)
{
SequenceReader reader = new(reply.Payload);
reader.ReadMessageId();
port = (ushort)reader.ReadUInt32();
}

ListenAddress listenAddress = new ListenAddress(forwardType, address, port);
if (onCancel)
{
SendRemoteForwardCancel(listenAddress);
}
else
{
lock (_gate)
{
_remoteListeners ??= new();
_remoteListeners[listenAddress] = listener;
}
}

return port;
}
catch when (onCancel)
{
return 0;
}
}
}

public void StopRemoteForward(Name forwardType, string address, ushort port)
Expand All @@ -986,9 +1029,26 @@ public void StopRemoteForward(Name forwardType, string address, ushort port)
}

if (sendCancel)
{
SendRemoteForwardCancel(listenAddress);
}
}

private void SendRemoteForwardCancel(ListenAddress listenAddress)
{
Name forwardType = listenAddress.ForwardType;
if (forwardType == AlgorithmNames.ForwardTcpIp)
{
TrySendPacket(_sequencePool.CreateCancelTcpIpForwardMessage(listenAddress.Address, listenAddress.Port));
}
else if (forwardType == AlgorithmNames.ForwardStreamLocal)
{
TrySendPacket(_sequencePool.CreateCancelLocalStreamForwardMessage(listenAddress.Address));
}
else
{
throw new IndexOutOfRangeException(forwardType);
}
}

private static string ReplaceAnyAddress(Name forwardType, string address)
Expand Down
Loading

0 comments on commit a6a99d4

Please sign in to comment.