fix: Fix memory leak for receive buffers #46

Merged 4 commits on Jan 11, 2025


bits0rcerer
Contributor

@bits0rcerer bits0rcerer commented Jan 10, 2025

This PR fixes a memory leak that occurred because we did not deallocate the rx buffers of our connection. This PR introduces a new struct handling the allocation and deallocation.

Closer inspection revealed a potential cause for #28:
alloc::alloc() returns a null pointer if the allocation fails. This PR now handles that case (docs).
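
A minimal sketch of the null check described above, using only `std::alloc`. The names `alloc_checked` and `AllocFailed` are hypothetical and not taken from this PR; the zero-size guard is an extra assumption, added because `alloc::alloc` has undefined behavior for zero-sized layouts.

```rust
use std::alloc::{self, Layout};

/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq)]
struct AllocFailed;

/// Allocates `layout` and reports failure instead of handing out a null pointer.
fn alloc_checked(layout: Layout) -> Result<*mut u8, AllocFailed> {
    // `alloc::alloc` must not be called with a zero-sized layout, so reject it up front.
    if layout.size() == 0 {
        return Err(AllocFailed);
    }
    // SAFETY: the layout has a non-zero size.
    let ptr = unsafe { alloc::alloc(layout) };
    // A null pointer signals that the allocator could not satisfy the request.
    if ptr.is_null() {
        return Err(AllocFailed);
    }
    Ok(ptr)
}

fn main() {
    let layout = Layout::from_size_align(4096, 4096).expect("valid layout");
    let ptr = alloc_checked(layout).expect("allocation failed");
    // SAFETY: `ptr` was returned by `alloc::alloc` with this exact layout.
    unsafe { alloc::dealloc(ptr, layout) };
}
```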

@bits0rcerer bits0rcerer force-pushed the main branch 2 times, most recently from d8669fd to c32d43e on January 10, 2025 16:17
Owner

@sbernauer sbernauer left a comment


Many thanks for catching and debugging this!
Stupid me is not used to explicitly freeing stuff ^^

It looks good in general to me, but I would have some feedback on it:

  1. I would prefer to move the new ConnectionBuffer struct into e.g. connection_buffer.rs
  2. Ideally we only pass the buffer_size into ConnectionBuffer::new and let it do all the heavy lifting. This way server.rs does not need to know about page_size, layout and such.
  3. For consistency reasons I would prefer a snafu error enum at the top of connection_buffer.rs. Not to say that the current error handling is the best in the world; this is mostly for consistency. I'm totally happy to change the error handling (nowadays I would probably start with something simpler), but we should do it across the code base in that case.

I took the liberty of suggesting a diff below, so that you don't need to make all the changes.
But please don't take it for granted; feel free to review it from your side as well.

diff --git a/breakwater/src/connection_buffer.rs b/breakwater/src/connection_buffer.rs
new file mode 100644
index 0000000..13c5904
--- /dev/null
+++ b/breakwater/src/connection_buffer.rs
@@ -0,0 +1,72 @@
+use std::alloc::{self, LayoutError};
+
+use log::warn;
+use memadvise::{Advice, MemAdviseError};
+use snafu::{ResultExt, Snafu};
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display("Failed to create memory layout"))]
+    CreateMemoryLayout {
+        source: LayoutError,
+        buffer_size: usize,
+        page_size: usize,
+    },
+
+    #[snafu(display("Allocation failed (alloc::alloc returned null ptr) for layout {layout:?}"))]
+    AllocationFailed { layout: alloc::Layout },
+}
+
+pub struct ConnectionBuffer {
+    ptr: *mut u8,
+    layout: alloc::Layout,
+}
+unsafe impl Send for ConnectionBuffer {}
+
+/// Allocates a memory slice with the specified size, which can be used for client connections.
+///
+/// It takes care of de-allocating the memory slice on [`Drop`].
+/// It also `memadvise`s the memory slice, so that the Kernel is aware that we are going to
+/// sequentially read it.
+impl ConnectionBuffer {
+    pub fn new(buffer_size: usize) -> Result<Self, Error> {
+        let page_size = page_size::get();
+        let layout = alloc::Layout::from_size_align(buffer_size, page_size).context(
+            CreateMemoryLayoutSnafu {
+                buffer_size,
+                page_size,
+            },
+        )?;
+
+        let ptr = unsafe { alloc::alloc(layout) };
+
+        if ptr.is_null() {
+            AllocationFailedSnafu { layout }.fail()?;
+        }
+
+        if let Err(err) = memadvise::advise(ptr as _, layout.size(), Advice::Sequential) {
+            // [`MemAdviseError`] does not implement Debug...
+            let err = match err {
+                MemAdviseError::NullAddress => "NullAddress",
+                MemAdviseError::InvalidLength => "InvalidLength",
+                MemAdviseError::UnalignedAddress => "UnalignedAddress",
+                MemAdviseError::InvalidRange => "InvalidRange",
+            };
+            warn!("Failed to memadvise sequential read access for buffer to kernel. This should not affect any client connections, but might cause some minor performance degradation: {err}");
+        }
+
+        Ok(Self { ptr, layout })
+    }
+
+    pub fn as_slice_mut(&mut self) -> &mut [u8] {
+        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.layout.size()) }
+    }
+}
+
+impl Drop for ConnectionBuffer {
+    fn drop(&mut self) {
+        unsafe {
+            alloc::dealloc(self.ptr, self.layout);
+        }
+    }
+}
diff --git a/breakwater/src/main.rs b/breakwater/src/main.rs
index f0340bd..7cc88c0 100644
--- a/breakwater/src/main.rs
+++ b/breakwater/src/main.rs
@@ -25,6 +25,7 @@ use crate::sinks::native_display::NativeDisplaySink;
 use crate::sinks::vnc::VncSink;
 
 mod cli_args;
+mod connection_buffer;
 mod prometheus_exporter;
 mod server;
 mod sinks;
diff --git a/breakwater/src/server.rs b/breakwater/src/server.rs
index 7d15153..267081f 100644
--- a/breakwater/src/server.rs
+++ b/breakwater/src/server.rs
@@ -1,11 +1,10 @@
-use std::alloc;
 use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::{cmp::min, net::IpAddr, sync::Arc, time::Duration};
 
 use breakwater_parser::{FrameBuffer, OriginalParser, Parser};
-use log::{debug, info, warn};
-use memadvise::{Advice, MemAdviseError};
+use log::{debug, info};
+use memadvise::Advice;
 use snafu::{ResultExt, Snafu};
 use tokio::{
     io::{AsyncReadExt, AsyncWriteExt},
@@ -14,22 +13,16 @@ use tokio::{
     time::Instant,
 };
 
-use crate::statistics::StatisticsEvent;
+use crate::{
+    connection_buffer::{self, ConnectionBuffer},
+    statistics::StatisticsEvent,
+};
 
 const CONNECTION_DENIED_TEXT: &[u8] = b"Connection denied as connection limit is reached";
 
 // Every client connection spawns a new thread, so we need to limit the number of stat events we send
 const STATISTICS_REPORT_INTERVAL: Duration = Duration::from_millis(250);
 
-#[derive(Debug)]
-pub struct BufferAllocationError;
-impl std::fmt::Display for BufferAllocationError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{self:?}")
-    }
-}
-impl snafu::Error for BufferAllocationError {}
-
 #[derive(Debug, Snafu)]
 pub enum Error {
     #[snafu(display("Failed to bind to listen address {listen_address:?}"))]
@@ -49,8 +42,8 @@ pub enum Error {
         source: mpsc::error::SendError<StatisticsEvent>,
     },
 
-    #[snafu(display("Failed to allocate network buffer"))]
-    BufferAllocation { source: BufferAllocationError },
+    #[snafu(display("Failed to allocate network connection buffer"))]
+    BufferAllocation { source: connection_buffer::Error },
 }
 
 pub struct Server<FB: FrameBuffer> {
@@ -91,9 +84,6 @@ impl<FB: FrameBuffer + Send + Sync + 'static> Server<FB> {
             mpsc::unbounded_channel::<IpAddr>();
         let connection_dropped_tx = self.max_connections_per_ip.map(|_| connection_dropped_tx);
 
-        let page_size = page_size::get();
-        debug!("System has a page size of {page_size} bytes");
-
         loop {
             let (mut socket, socket_addr) = self
                 .listener
@@ -144,7 +134,6 @@ impl<FB: FrameBuffer + Send + Sync + 'static> Server<FB> {
                     ip,
                     fb_for_thread,
                     statistics_tx_for_thread,
-                    page_size,
                     network_buffer_size,
                     connection_dropped_tx_clone,
                 )
@@ -154,53 +143,11 @@ impl<FB: FrameBuffer + Send + Sync + 'static> Server<FB> {
     }
 }
 
-struct ConnectionBuffer {
-    ptr: *mut u8,
-    layout: alloc::Layout,
-}
-unsafe impl Send for ConnectionBuffer {}
-
-impl ConnectionBuffer {
-    fn new(layout: alloc::Layout) -> Result<Self, BufferAllocationError> {
-        let ptr = unsafe { alloc::alloc(layout) };
-
-        if ptr.is_null() {
-            return Err(BufferAllocationError);
-        }
-
-        if let Err(err) = memadvise::advise(ptr as _, layout.size(), Advice::Sequential) {
-            // [`MemAdviseError`] does not implement Debug...
-            let err = match err {
-                MemAdviseError::NullAddress => "NullAddress",
-                MemAdviseError::InvalidLength => "InvalidLength",
-                MemAdviseError::UnalignedAddress => "UnalignedAddress",
-                MemAdviseError::InvalidRange => "InvalidRange",
-            };
-            warn!("Failed to memadvise sequential read access for buffer to kernel. This should not effect any client connections, but might having some minor performance degration: {err}");
-        }
-
-        Ok(Self { ptr, layout })
-    }
-
-    fn as_slice_mut(&mut self) -> &mut [u8] {
-        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.layout.size()) }
-    }
-}
-
-impl Drop for ConnectionBuffer {
-    fn drop(&mut self) {
-        unsafe {
-            alloc::dealloc(self.ptr, self.layout);
-        }
-    }
-}
-
 pub async fn handle_connection<FB: FrameBuffer>(
     mut stream: impl AsyncReadExt + AsyncWriteExt + Send + Unpin,
     ip: IpAddr,
     fb: Arc<FB>,
     statistics_tx: mpsc::Sender<StatisticsEvent>,
-    page_size: usize,
     network_buffer_size: usize,
     connection_dropped_tx: Option<mpsc::UnboundedSender<IpAddr>>,
 ) -> Result<(), Error> {
@@ -211,9 +158,7 @@ pub async fn handle_connection<FB: FrameBuffer>(
         .await
         .context(WriteToStatisticsChannelSnafu)?;
 
-    let layout = alloc::Layout::from_size_align(network_buffer_size, page_size)
-        .expect("invalid network buffer size for page size");
-    let mut recv_buf = ConnectionBuffer::new(layout).context(BufferAllocationSnafu)?;
+    let mut recv_buf = ConnectionBuffer::new(network_buffer_size).context(BufferAllocationSnafu)?;
     let buffer = recv_buf.as_slice_mut();
     let mut response_buf = Vec::new();
 
diff --git a/breakwater/src/tests.rs b/breakwater/src/tests.rs
index cc2370b..541b260 100644
--- a/breakwater/src/tests.rs
+++ b/breakwater/src/tests.rs
@@ -116,7 +116,6 @@ async fn test_safe<FB: FrameBuffer>(
         fb.clone(),
         statistics_channel.0,
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -191,7 +190,6 @@ async fn test_drawing_rect<FB: FrameBuffer>(
         Arc::clone(&fb),
         statistics_channel.0.clone(),
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -206,7 +204,6 @@ async fn test_drawing_rect<FB: FrameBuffer>(
         Arc::clone(&fb),
         statistics_channel.0.clone(),
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -221,7 +218,6 @@ async fn test_drawing_rect<FB: FrameBuffer>(
         Arc::clone(&fb),
         statistics_channel.0.clone(),
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -236,7 +232,6 @@ async fn test_drawing_rect<FB: FrameBuffer>(
         Arc::clone(&fb),
         statistics_channel.0.clone(),
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -278,7 +273,6 @@ async fn test_binary_set_pixel<FB: FrameBuffer>(
         fb,
         statistics_channel.0,
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -452,7 +446,6 @@ async fn test_binary_sync_pixels_larger_than_buffer<FB: FrameBuffer>(fb: Arc<FB>
         fb,
         statistics_channel().0,
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
@@ -469,7 +462,6 @@ async fn assert_returns(input: &[u8], expected: &str) {
         fb(),
         statistics_channel().0,
         DEFAULT_NETWORK_BUFFER_SIZE,
-        page_size::get(),
         None,
     )
     .await
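
The ownership pattern in the suggested diff can be tried in isolation with a std-only sketch. Everything named here is hypothetical: `OwnedBuffer` stands in for `ConnectionBuffer`, the alignment is a fixed assumed value rather than the real page size, the `memadvise` call and snafu errors are left out, and `alloc_zeroed` is swapped in for `alloc` so that the slice never exposes uninitialized bytes.

```rust
use std::alloc::{self, Layout};
use std::thread;

/// Hypothetical stand-in for `ConnectionBuffer`, built on std only.
struct OwnedBuffer {
    ptr: *mut u8,
    layout: Layout,
}

// SAFETY: the buffer exclusively owns its allocation, so it may move between threads.
unsafe impl Send for OwnedBuffer {}

impl OwnedBuffer {
    fn new(size: usize) -> Option<Self> {
        // Assumed alignment for this sketch; the diff queries the real page size instead.
        const ASSUMED_PAGE_SIZE: usize = 4096;
        // Zero-sized layouts must not be passed to the global allocator.
        if size == 0 {
            return None;
        }
        let layout = Layout::from_size_align(size, ASSUMED_PAGE_SIZE).ok()?;
        // SAFETY: the layout has a non-zero size.
        let ptr = unsafe { alloc::alloc_zeroed(layout) };
        if ptr.is_null() {
            return None;
        }
        Some(Self { ptr, layout })
    }

    fn as_slice_mut(&mut self) -> &mut [u8] {
        // SAFETY: `ptr` points to `layout.size()` initialized bytes owned by `self`.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.layout.size()) }
    }
}

impl Drop for OwnedBuffer {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated with exactly this layout and is freed only once.
        unsafe { alloc::dealloc(self.ptr, self.layout) };
    }
}

fn main() {
    let mut buf = OwnedBuffer::new(8192).expect("allocation failed");
    // Moving the buffer into a thread mirrors one buffer per client connection.
    let handle = thread::spawn(move || {
        let slice = buf.as_slice_mut();
        slice[0] = 42;
        slice.len()
        // `buf` goes out of scope at the end of the closure, which frees the allocation.
    });
    assert_eq!(handle.join().unwrap(), 8192);
}
```

Because deallocation lives in `Drop`, every exit path of the connection handler, including early returns through `?`, releases the buffer without any explicit cleanup code.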

@sbernauer sbernauer changed the title fix memory leak for receive buffers fix: Fix memory leak for receive buffers Jan 11, 2025
@sbernauer
Owner

Can you please also add a changelog entry?

@bits0rcerer
Contributor Author

I took the liberty of suggesting a diff below, so that you don't need to make all the changes.
But please don't take it for granted; feel free to review it from your side as well.

LGTM 👍

For consistency reasons I would prefer a snafu error enum at the top of connection_buffer.rs. Not to say that the current error handling is the best in the world; this is mostly for consistency. I'm totally happy to change the error handling (nowadays I would probably start with something simpler), but we should do it across the code base in that case.

Never used snafu. My go-to is eyre for application code, and thiserror for library code or if I feel fancy.

snafu seems a bit intense here
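
For comparison, here is a rough sketch of the same two error cases written by hand against `std::error::Error`, with no snafu, thiserror or eyre involved. The names `BufferError` and `make_layout` are hypothetical; only the variant names and display messages mirror the suggested diff.

```rust
use std::alloc::{Layout, LayoutError};
use std::error::Error;
use std::fmt;

/// Hypothetical hand-written counterpart of the snafu enum in the diff.
#[derive(Debug)]
enum BufferError {
    CreateMemoryLayout {
        source: LayoutError,
        buffer_size: usize,
        page_size: usize,
    },
    AllocationFailed {
        layout: Layout,
    },
}

impl fmt::Display for BufferError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CreateMemoryLayout { .. } => write!(f, "Failed to create memory layout"),
            Self::AllocationFailed { layout } => {
                write!(f, "Allocation failed for layout {layout:?}")
            }
        }
    }
}

impl Error for BufferError {
    // Expose the underlying cause so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::CreateMemoryLayout { source, .. } => Some(source),
            Self::AllocationFailed { .. } => None,
        }
    }
}

/// Builds the layout and attaches the context that snafu's `.context(...)` would add.
fn make_layout(buffer_size: usize, page_size: usize) -> Result<Layout, BufferError> {
    Layout::from_size_align(buffer_size, page_size).map_err(|source| {
        BufferError::CreateMemoryLayout {
            source,
            buffer_size,
            page_size,
        }
    })
}

fn main() {
    // An alignment of 3 is not a power of two, so creating the layout fails.
    let err = make_layout(1024, 3).unwrap_err();
    println!("{err} (details: {err:?})");

    // The second variant would be returned when the allocator hands back a null pointer.
    let layout = make_layout(1024, 4096).expect("valid layout");
    println!("{}", BufferError::AllocationFailed { layout });
}
```

The derive crates mainly generate the `Display` and `Error` impls shown here; which one to use is largely a matter of consistency across the code base, as discussed above.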

@bits0rcerer bits0rcerer requested a review from sbernauer January 11, 2025 16:57
@sbernauer sbernauer merged commit 9ecb53d into sbernauer:main Jan 11, 2025
8 checks passed
@sbernauer sbernauer mentioned this pull request Jan 11, 2025
@sbernauer
Owner

Thanks for pulling in the changes, released as 0.16.3 🚀

At work we use snafu, so I started with that back in the day. Nowadays I would go with either anyhow or eyre.
