Skip to content

Commit

Permalink
fix: garbage collect Stream when dropped (libp2p#167)
Browse files Browse the repository at this point in the history
This fixes issue libp2p#166 where a `Stream` would never be removed
until its connection is closed, given that the `Connection` holds a
copy of the stream `Sender`.

Co-authored-by: Max Inden <mail@max-inden.de>
Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
  • Loading branch information
3 people authored Jul 11, 2023
1 parent e7c17ff commit dcff3d5
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 50 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Users have to move to the `poll_` functions of `Connection`.
See [PR #164](https://github.com/libp2p/rust-yamux/pull/164).

- Fix a bug where `Stream`s would not be dropped until their corresponding `Connection` was dropped.
See [PR #167](https://github.com/libp2p/rust-yamux/pull/167).

# 0.11.1

- Avoid race condition between pending frames and closing stream.
Expand Down
52 changes: 52 additions & 0 deletions test-harness/tests/poll_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,55 @@ fn write_deadlock() {
.unwrap(),
);
}

#[test]
fn close_through_drop_of_stream_propagates_to_remote() {
let _ = env_logger::try_init();
let mut pool = LocalPool::new();

let (server_endpoint, client_endpoint) = futures_ringbuf::Endpoint::pair(1024, 1024);
let mut server = Connection::new(server_endpoint, Config::default(), Mode::Server);
let mut client = Connection::new(client_endpoint, Config::default(), Mode::Client);

// Spawn client, opening a stream, writing to the stream, dropping the stream, driving the
// client connection state machine.
pool.spawner()
.spawn_obj(
async {
let mut stream = future::poll_fn(|cx| client.poll_new_outbound(cx))
.await
.unwrap();
stream.write_all(&[42]).await.unwrap();
drop(stream);

noop_server(stream::poll_fn(move |cx| client.poll_next_inbound(cx))).await;
}
.boxed()
.into(),
)
.unwrap();

// Accept inbound stream.
let mut stream_server_side = pool
.run_until(future::poll_fn(|cx| server.poll_next_inbound(cx)))
.unwrap()
.unwrap();

// Spawn server connection state machine.
pool.spawner()
.spawn_obj(
noop_server(stream::poll_fn(move |cx| server.poll_next_inbound(cx)))
.boxed()
.into(),
)
.unwrap();

// Expect to eventually receive close on stream.
pool.run_until(async {
let mut buf = Vec::new();
stream_server_side.read_to_end(&mut buf).await?;
assert_eq!(buf, vec![42]);
Ok::<(), std::io::Error>(())
})
.unwrap();
}
50 changes: 30 additions & 20 deletions yamux/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ use closing::Closing;
use futures::stream::SelectAll;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::task::{Context, Waker};
use std::{fmt, sync::Arc, task::Poll};
Expand Down Expand Up @@ -354,7 +355,7 @@ struct Active<T> {
socket: Fuse<frame::Io<T>>,
next_id: u32,

streams: IntMap<StreamId, Stream>,
streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,

Expand Down Expand Up @@ -530,7 +531,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone());
self.streams.insert(id, stream.clone_shared());

Poll::Ready(Ok(stream))
}
Expand All @@ -551,13 +552,12 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.push_back(Frame::close_stream(id, ack).into());
}

fn on_drop_stream(&mut self, id: StreamId) {
let stream = self.streams.remove(&id).expect("stream not found");
fn on_drop_stream(&mut self, stream_id: StreamId) {
let s = self.streams.remove(&stream_id).expect("stream not found");

log::trace!("{}: removing dropped {}", self.id, stream);
let stream_id = stream.id();
log::trace!("{}: removing dropped stream {}", self.id, stream_id);
let frame = {
let mut shared = stream.shared();
let mut shared = s.lock();
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
Expand Down Expand Up @@ -627,7 +627,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
let id = frame.header().stream_id();
if let Some(stream) = self.streams.get(&id) {
stream
.shared()
.lock()
.update_state(self.id, id, State::Open { acknowledged: true });
}
if let Some(waker) = self.new_outbound_stream_waker.take() {
Expand Down Expand Up @@ -678,7 +678,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
if frame.header().flags().contains(header::RST) {
// stream reset
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.shared();
let mut shared = s.lock();
shared.update_state(self.id, stream_id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
Expand Down Expand Up @@ -736,12 +736,12 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
if window_update.is_none() {
stream.set_flag(stream::Flag::Ack)
}
self.streams.insert(stream_id, stream.clone());
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream, window_update);
}

if let Some(stream) = self.streams.get_mut(&stream_id) {
let mut shared = stream.shared();
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
if frame.body().len() > shared.window as usize {
log::error!(
"{}/{}: frame body larger than window of stream",
Expand Down Expand Up @@ -801,7 +801,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
if frame.header().flags().contains(header::RST) {
// stream reset
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.shared();
let mut shared = s.lock();
shared.update_state(self.id, stream_id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
Expand Down Expand Up @@ -839,12 +839,12 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.shared()
.update_state(self.id, stream_id, State::RecvClosed);
}
self.streams.insert(stream_id, stream.clone());
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream, None);
}

if let Some(stream) = self.streams.get_mut(&stream_id) {
let mut shared = stream.shared();
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.credit += frame.header().credit();
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
Expand Down Expand Up @@ -938,9 +938,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
/// The ACK backlog is defined as the number of outbound streams that have not yet been acknowledged.
fn ack_backlog(&mut self) -> usize {
self.streams
.values()
.filter(|s| s.is_outbound(self.mode))
.filter(|s| s.is_pending_ack())
.iter()
// Whether this is an outbound stream.
//
// Clients use odd IDs and servers use even IDs.
// A stream is outbound if:
//
// - Its ID is odd and we are the client.
// - Its ID is even and we are the server.
.filter(|(id, _)| match self.mode {
Mode::Client => id.is_client(),
Mode::Server => id.is_server(),
})
.filter(|(_, s)| s.lock().is_pending_ack())
.count()
}

Expand All @@ -960,7 +970,7 @@ impl<T> Active<T> {
/// Close and drop all `Stream`s and wake any pending `Waker`s.
fn drop_all_streams(&mut self) {
for (id, s) in self.streams.drain() {
let mut shared = s.shared();
let mut shared = s.lock();
shared.update_state(self.id, id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
Expand Down
44 changes: 14 additions & 30 deletions yamux/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
header::{Data, Header, StreamId, WindowUpdate},
Frame,
},
Config, Mode, WindowUpdateMode, DEFAULT_CREDIT,
Config, WindowUpdateMode, DEFAULT_CREDIT,
};
use futures::{
channel::mpsc,
Expand Down Expand Up @@ -161,26 +161,7 @@ impl Stream {

/// Whether we are still waiting for the remote to acknowledge this stream.
pub fn is_pending_ack(&self) -> bool {
matches!(
self.shared().state(),
State::Open {
acknowledged: false
}
)
}

/// Whether this is an outbound stream.
///
/// Clients use odd IDs and servers use even IDs.
/// A stream is outbound if:
///
/// - Its ID is odd and we are the client.
/// - Its ID is even and we are the server.
pub(crate) fn is_outbound(&self, our_mode: Mode) -> bool {
match our_mode {
Mode::Client => self.id.is_client(),
Mode::Server => self.id.is_server(),
}
self.shared().is_pending_ack()
}

/// Set the flag that should be set on the next outbound frame header.
Expand All @@ -192,15 +173,8 @@ impl Stream {
self.shared.lock()
}

pub(crate) fn clone(&self) -> Self {
Stream {
id: self.id,
conn: self.conn,
config: self.config.clone(),
sender: self.sender.clone(),
flag: self.flag,
shared: self.shared.clone(),
}
pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
self.shared.clone()
}

fn write_zero_err(&self) -> io::Error {
Expand Down Expand Up @@ -551,4 +525,14 @@ impl Shared {
None
}
}

/// Whether we are still waiting for the remote to acknowledge this stream.
pub fn is_pending_ack(&self) -> bool {
matches!(
self.state(),
State::Open {
acknowledged: false
}
)
}
}

0 comments on commit dcff3d5

Please sign in to comment.