Skip to content

Commit

Permalink
SFTP: correct receive(id, timeout)
Browse files Browse the repository at this point in the history
Simplify the implementation, and make sure it works when called
directly with a non-zero timeout.
  • Loading branch information
tomaswolf committed Jul 25, 2024
1 parent 145eef1 commit d9d1277
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,6 @@ protected void checkResponseStatus(SftpResponse response) throws IOException {
protected void checkResponseStatus(int cmd, int id, SftpStatus status) throws IOException {
if (!status.isOk()) {
throwStatusException(cmd, id, status);
} else if (log.isTraceEnabled()) {
log.trace("throwStatusException({})[id={}] cmd={} status={}",
getClientChannel(), id, SftpConstants.getCommandMessageName(cmd),
status);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -72,7 +72,7 @@
public class DefaultSftpClient extends AbstractSftpClient {
private final ClientSession clientSession;
private final ChannelSubsystem channel;
private final Map<Integer, Buffer> messages = new HashMap<>();
private final Map<Integer, Buffer> messages = new ConcurrentHashMap<>();
private final AtomicInteger cmdId = new AtomicInteger(100);
private final Buffer receiveBuffer = new ByteArrayBuffer();
private final AtomicInteger versionHolder = new AtomicInteger(0);
Expand Down Expand Up @@ -305,33 +305,15 @@ public int send(int cmd, Buffer buffer) throws IOException {

@Override
public Buffer receive(int id) throws IOException {
Session session = getClientSession();
Duration idleTimeout = CoreModuleProperties.IDLE_TIMEOUT.getRequired(session);
Duration idleTimeout = CoreModuleProperties.IDLE_TIMEOUT.getRequired(getClientSession());
if (GenericUtils.isNegativeOrNull(idleTimeout)) {
idleTimeout = CoreModuleProperties.IDLE_TIMEOUT.getRequiredDefault();
}

Instant now = Instant.now();
Instant waitEnd = now.plus(idleTimeout);
boolean traceEnabled = log.isTraceEnabled();
for (int count = 1;; count++) {
if (isClosing() || (!isOpen())) {
throw new SshException("Channel is being closed");
}
if (now.compareTo(waitEnd) > 0) {
throw new SshException("Timeout expired while waiting for id=" + id);
}

Buffer buffer = receive(id, Duration.between(now, waitEnd));
if (buffer != null) {
return buffer;
}

now = Instant.now();
if (traceEnabled) {
log.trace("receive({}) check iteration #{} for id={} remain time={}", this, count, id, idleTimeout);
}
Buffer result = receive(id, idleTimeout);
if (result == null) {
throw new SshException("Timeout expired while waiting for id=" + id);
}
return result;
}

@Override
Expand All @@ -342,19 +324,30 @@ public Buffer receive(int id, long idleTimeout) throws IOException {
@Override
public Buffer receive(int id, Duration idleTimeout) throws IOException {
synchronized (messages) {
Buffer buffer = messages.remove(id);
if (buffer != null) {
return buffer;
}
if (GenericUtils.isPositive(idleTimeout)) {
try {
messages.wait(idleTimeout.toMillis(), idleTimeout.getNano() % 1_000_000);
} catch (InterruptedException e) {
throw (IOException) new InterruptedIOException("Interrupted while waiting for messages").initCause(e);
Instant waitUntil = Instant.now().plus(idleTimeout);
for (;;) {
if (isClosing() || !isOpen()) {
throw new SshException("Channel is being closed");
}
Buffer buffer = messages.remove(id);
if (buffer != null) {
return buffer;
}
Duration waitFor = Duration.between(Instant.now(), waitUntil);
if (!GenericUtils.isPositive(waitFor)) {
break; // Timeout expired
}
try {
messages.wait(waitFor.toMillis(), waitFor.getNano() % 1_000_000);
} catch (InterruptedException e) {
throw (IOException) new InterruptedIOException("Interrupted while waiting for messages").initCause(e);
}
}
}
// Try one last time.
return messages.remove(id);
}
return null;
}

protected void init(ClientSession session, SftpVersionSelector initialVersionSelector, Duration initializationTimeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.sshd.sftp.client.impl;

import java.io.IOException;
import java.time.Duration;
import java.util.Collection;
import java.util.Deque;
import java.util.LinkedList;
Expand Down Expand Up @@ -163,7 +164,7 @@ public void flush() throws IOException {
log.debug("flush({}) waiting for ack #{}: {}", this, ackIndex, ack);
}

Buffer buf = client.receive(ack.id, 0L);
Buffer buf = client.receive(ack.id, Duration.ZERO);
if (buf == null) {
if (debugEnabled) {
log.debug("flush({}) no response for ack #{}: {}", this, ackIndex, ack);
Expand Down

0 comments on commit d9d1277

Please sign in to comment.