Skip to content

Commit

Permalink
Check endpoint args.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds committed Feb 17, 2025
1 parent f369639 commit 4ee63e8
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 27 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Tmds.Ssh

The `Tmds.Ssh` is a modern, managed .NET SSH client implementation for .NET 6+.
The `Tmds.Ssh` is a modern, managed .NET SSH client library for .NET 6+.

## Getting Started

Expand Down Expand Up @@ -67,11 +67,16 @@ class SshClient : IDisposable

Task<SshDataStream> OpenTcpConnectionAsync(string host, int port, CancellationToken cancellationToken = default);
Task<SshDataStream> OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default);

Task<RemoteListener> ListenTcpAsync(string address, int port, CancellationToken cancellationToken = default);

// bindEP can be an IPEndPoint or a UnixDomainSocketEndPoint.
// remoteEP can be a RemoteHostEndPoint, a RemoteUnixEndPoint or a RemoteIPEndPoint.
Task<DirectForward> StartForwardAsync(EndPoint bindEP, RemoteEndPoint remoteEP, CancellationToken cancellationToken = default);
Task<SocksForward> StartForwardSocksAsync(EndPoint bindEP, CancellationToken cancellationToken = default);

Task<RemoteListener> ListenTcpAsync(string address, int port, CancellationToken cancellationToken = default);
// bindEP can be a RemoteIPListenEndPoint.
// localEP can be a DnsEndPoint or an IPEndPoint.
Task<RemoteForward> StartRemoteForwardAsync(RemoteEndPoint bindEP, EndPoint localEP, CancellationToken cancellationToken = default);

Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken);
Expand Down
11 changes: 10 additions & 1 deletion src/Tmds.Ssh/ArgumentValidation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,20 @@ namespace Tmds.Ssh;
static class ArgumentValidation
{
public static void ValidatePort(int port, bool allowZero, string argumentName = "port")
{
if (!IsValidPort(port, allowZero))
{
throw new ArgumentException($"Invalid port number: '{port}'.", argumentName);
}
}

private static bool IsValidPort(int port, bool allowZero)
{
if (port < 0 || port > 0xffff || (!allowZero && port == 0))
{
throw new ArgumentException(argumentName);
return false;
}
return true;
}

public static void ValidateIPListenAddress(string address, string argumentName = "address")
Expand Down
35 changes: 18 additions & 17 deletions src/Tmds.Ssh/ForwardServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,7 @@ protected async ValueTask StartAsync(SshSession session, ForwardProtocol forward
await ListenAsync(cancellationToken).ConfigureAwait(false);
_ctr = _session.ConnectionClosed.UnsafeRegister(o => ((ForwardServer<T, TTargetStream>)o!).Stop(ConnectionClosed), this);

_ = AcceptLoop();
}
catch (Exception ex)
{
Stop(ex);

throw;
}
}

protected abstract ValueTask ListenAsync(CancellationToken cancellationToken);

private async Task AcceptLoop()
{
Debug.Assert(_listenEndPoint is not null);
try
{
Debug.Assert(_listenEndPoint is not null);
if (_protocol == ForwardProtocol.Direct)
{
Debug.Assert(_targetEndPoint is not null);
Expand All @@ -171,6 +155,23 @@ private async Task AcceptLoop()
// Log stop when we've logged the start.
_logStopped = true;

_ = AcceptLoop();
}
catch (Exception ex)
{
Stop(ex);

throw;
}
}

protected abstract ValueTask ListenAsync(CancellationToken cancellationToken);

private async Task AcceptLoop()
{
try
{
Debug.Assert(_listenEndPoint is not null);
while (true)
{
(Stream? acceptedStream, string address) = await AcceptAsync().ConfigureAwait(false);
Expand Down
11 changes: 10 additions & 1 deletion src/Tmds.Ssh/LocalForwardServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal async ValueTask StartDirectForwardAsync(SshSession session, EndPoint bi
{
Debug.Assert(bindEP is not null);
CheckBindEndPoint(bindEP);
ArgumentNullException.ThrowIfNull(remoteEndPoint);
CheckTargetEndPoint(remoteEndPoint);

_localEndPoint = bindEP;
_remoteEndPoint = remoteEndPoint;
Expand Down Expand Up @@ -60,6 +60,15 @@ private static void CheckBindEndPoint(EndPoint bindEP)
}
}

private static void CheckTargetEndPoint(RemoteEndPoint remoteEndPoint)
{
ArgumentNullException.ThrowIfNull(remoteEndPoint);
if (remoteEndPoint is not RemoteHostEndPoint and not RemoteUnixEndPoint and not RemoteIPEndPoint)
{
throw new ArgumentException($"Unsupported RemoteEndPoint type: {remoteEndPoint.GetType().FullName}.");
}
}

protected override async Task<(Stream?, string)> AcceptAsync()
{
Debug.Assert(_serverSocket is not null);
Expand Down
2 changes: 1 addition & 1 deletion src/Tmds.Ssh/RemoteEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public sealed class RemoteIPListenEndPoint : RemoteEndPoint
public RemoteIPListenEndPoint(string address, int port)
{
ArgumentValidation.ValidateIPListenAddress(address);
ArgumentValidation.ValidatePort(port, allowZero: false);
ArgumentValidation.ValidatePort(port, allowZero: true);

Address = address;
Port = port;
Expand Down
25 changes: 20 additions & 5 deletions src/Tmds.Ssh/RemoteForwardServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,28 @@ internal async ValueTask StartDirectForwardAsync(SshSession session, RemoteEndPo
private static void CheckBindEndPoint(RemoteEndPoint bindEP)
{
ArgumentNullException.ThrowIfNull(bindEP);
// TODO...
if (bindEP is not RemoteIPListenEndPoint)
{
throw new ArgumentException($"Unsupported RemoteEndPoint type: {bindEP.GetType().FullName}.");
}
}

private static void CheckTargetEndPoint(EndPoint targetEP)
private static void CheckTargetEndPoint(EndPoint localEndPoint)
{
ArgumentNullException.ThrowIfNull(targetEP);
// TODO...
ArgumentNullException.ThrowIfNull(localEndPoint);
if (localEndPoint is DnsEndPoint dnsEndPoint)
{
ArgumentValidation.ValidatePort(dnsEndPoint.Port, allowZero: false, nameof(localEndPoint));
ArgumentValidation.ValidateHost(dnsEndPoint.Host, allowEmpty: false, nameof(localEndPoint));
}
else if (localEndPoint is IPEndPoint ipEndPoint)
{
ArgumentValidation.ValidatePort(ipEndPoint.Port, allowZero: false, nameof(localEndPoint));
}
else
{
throw new ArgumentException($"Unsupported EndPoint type: {localEndPoint.GetType().FullName}.");
}
}

protected override async Task<(Stream?, string)> AcceptAsync()
Expand Down Expand Up @@ -96,7 +111,7 @@ private static void CheckTargetEndPoint(EndPoint targetEP)

protected override void Stop()
{
_listener?.Dispose();
_listener?.Stop();
}

protected override async ValueTask ListenAsync(CancellationToken cancellationToken)
Expand Down

0 comments on commit 4ee63e8

Please sign in to comment.