Skip to content

Commit

Permalink
Merge pull request #71 from orlandos-nl/jo/streaming-input
Browse files Browse the repository at this point in the history
Allow `with`-style streaming APIs for communicating with a running program
  • Loading branch information
Joannis authored Nov 8, 2024
2 parents e0413b2 + 5aa51c5 commit afe3204
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Sources/Citadel/Errors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ public enum CitadelError: Error {
case channelFailure
}

public struct AuthenticationFailed: Error, Equatable {}
public struct AuthenticationFailed: Error, Equatable {}
154 changes: 141 additions & 13 deletions Sources/Citadel/TTY/Client/TTY.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,55 @@ public enum ExecCommandOutput {
case stderr(ByteBuffer)
}

struct EmptySequence<Element>: Sendable, AsyncSequence {
struct AsyncIterator: AsyncIteratorProtocol {
func next() async throws -> Element? {
nil
}
}

func makeAsyncIterator() -> AsyncIterator {
AsyncIterator()
}
}

@available(macOS 15.0, *)
public struct TTYOutput: AsyncSequence {
internal let sequence: AsyncThrowingStream<ExecCommandOutput, Error>
public typealias Element = ExecCommandOutput

public struct AsyncIterator: AsyncIteratorProtocol {
public typealias Element = ExecCommandOutput
var iterator: AsyncThrowingStream<ExecCommandOutput, Error>.AsyncIterator

public mutating func next() async throws -> ExecCommandOutput? {
try await iterator.next()
}
}

public func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(iterator: sequence.makeAsyncIterator())
}
}

public struct TTYStdinWriter {
internal let channel: Channel

public func write(_ buffer: ByteBuffer) async throws {
try await channel.writeAndFlush(SSHChannelData(type: .channel, data: .byteBuffer(buffer)))
}

public func changeSize(cols: Int, rows: Int) async throws {
try await channel.triggerUserOutboundEvent(
SSHChannelRequestEvent.WindowChangeRequest(
terminalCharacterWidth: 0,
terminalRowHeight: 0,
terminalPixelWidth: 0,
terminalPixelHeight: 0
)
)
}
}

final class ExecCommandHandler: ChannelDuplexHandler {
enum Output {
Expand Down Expand Up @@ -126,7 +175,12 @@ extension SSHClient {
/// - maxResponseSize: The maximum size of the response. If the response is larger, the command will fail.
/// - mergeStreams: If the answer should also include stderr.
/// - inShell: Whether to request the remote server to start a shell before executing the command.
public func executeCommand(_ command: String, maxResponseSize: Int = .max, mergeStreams: Bool = false, inShell: Bool = false) async throws -> ByteBuffer {
public func executeCommand(
_ command: String,
maxResponseSize: Int = .max,
mergeStreams: Bool = false,
inShell: Bool = false
) async throws -> ByteBuffer {
var result = ByteBuffer()
let stream = try await executeCommandStream(command, inShell: inShell)

Expand Down Expand Up @@ -156,12 +210,27 @@ extension SSHClient {
/// - Parameters:
/// - command: The command to execute.
/// - inShell: Whether to request the remote server to start a shell before executing the command.
public func executeCommandStream(_ command: String, inShell: Bool = false) async throws -> AsyncThrowingStream<ExecCommandOutput, Error> {
var streamContinuation: AsyncThrowingStream<ExecCommandOutput, Error>.Continuation!
let stream = AsyncThrowingStream<ExecCommandOutput, Error>(bufferingPolicy: .unbounded) { continuation in
streamContinuation = continuation
}

public func executeCommandStream(
_ command: String,
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
inShell: Bool = false
) async throws -> AsyncThrowingStream<ExecCommandOutput, Error> {
try await _executeCommandStream(
environment: environment,
mode: inShell ? .tty(command: command) : .command(command)
).output
}

enum CommandMode {
case pty(SSHChannelRequestEvent.PseudoTerminalRequest), tty(command: String?), command(String)
}

internal func _executeCommandStream(
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
mode: CommandMode
) async throws -> (channel: Channel, output: AsyncThrowingStream<ExecCommandOutput, Error>) {
let (stream, streamContinuation) = AsyncThrowingStream<ExecCommandOutput, Error>.makeStream()

var hasReceivedChannelSuccess = false
var exitCode: Int?

Expand All @@ -180,9 +249,11 @@ extension SSHClient {
streamContinuation.finish()
}
case .channelSuccess:
if inShell, !hasReceivedChannelSuccess {
let commandData = SSHChannelData(type: .channel,
data: .byteBuffer(ByteBuffer(string: command + ";exit\n")))
if case .tty(.some(let command)) = mode, !hasReceivedChannelSuccess {
let commandData = SSHChannelData(
type: .channel,
data: .byteBuffer(ByteBuffer(string: command + ";exit\n"))
)
channel.writeAndFlush(commandData, promise: nil)
hasReceivedChannelSuccess = true
}
Expand All @@ -204,18 +275,75 @@ extension SSHClient {
return createChannel.futureResult
}.get()

if inShell {
for env in environment {
try await channel.triggerUserOutboundEvent(env)
}

switch mode {
case .pty(let request):
try await channel.triggerUserOutboundEvent(request)
fallthrough
case .tty:
try await channel.triggerUserOutboundEvent(SSHChannelRequestEvent.ShellRequest(
wantReply: true
))
} else {
case .command(let command):
try await channel.triggerUserOutboundEvent(SSHChannelRequestEvent.ExecRequest(
command: command,
wantReply: true
))
}

return stream
return (channel, stream)
}

@available(macOS 15.0, *)
public func withPTY(
_ request: SSHChannelRequestEvent.PseudoTerminalRequest,
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
perform: (_ inbound: TTYOutput, _ outbound: TTYStdinWriter) async throws -> Void
) async throws {
let (channel, output) = try await _executeCommandStream(
environment: environment,
mode: .pty(request)
)

func close() async throws {
try await channel.close()
}

do {
let inbound = TTYOutput(sequence: output)
try await perform(inbound, TTYStdinWriter(channel: channel))
try await close()
} catch {
try await close()
throw error
}
}

@available(macOS 15.0, *)
public func withTTY(
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
perform: (_ inbound: TTYOutput, _ outbound: TTYStdinWriter) async throws -> Void
) async throws {
let (channel, output) = try await _executeCommandStream(
environment: environment,
mode: .tty(command: nil)
)

func close() async throws {
try await channel.close()
}

do {
let inbound = TTYOutput(sequence: output)
try await perform(inbound, TTYStdinWriter(channel: channel))
try await close()
} catch {
try await close()
throw error
}
}

/// Executes a command on the remote server. This will return the pair of streams stdout and stderr of the command. If the command fails, the error will be thrown.
Expand Down
53 changes: 53 additions & 0 deletions Tests/CitadelTests/Citadel2Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,57 @@ final class Citadel2Tests: XCTestCase {

try await client.close()
}

@available(macOS 15.0, *)
func testStdinStream() async throws {
guard
let host = ProcessInfo.processInfo.environment["SSH_HOST"],
let _port = ProcessInfo.processInfo.environment["SSH_PORT"],
let port = Int(_port),
let username = ProcessInfo.processInfo.environment["SSH_USERNAME"],
let password = ProcessInfo.processInfo.environment["SSH_PASSWORD"]
else {
throw XCTSkip()
}

let client = try await SSHClient.connect(
host: host,
port: port,
authenticationMethod: .passwordBased(username: username, password: password),
hostKeyValidator: .acceptAnything(),
reconnect: .never
)

try await client.withTTY { inbound, outbound in
try await outbound.write(ByteBuffer(string: "cat"))
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
var i = UInt8.min
for try await value in inbound {
switch value {
case .stdout(let value):
for byte in value.readableBytesView {
XCTAssertEqual(byte, i)
i = i &+ 1
}
case .stderr:
XCTFail("Unexpected stderr")
}
}
}

group.addTask {
for i: UInt8 in .min ... .max {
let value = ByteBufferAllocator().buffer(integer: i)
try await outbound.write(value)
}
}

try await group.next()
group.cancelAll()
}
}

try await client.close()
}
}

0 comments on commit afe3204

Please sign in to comment.