From a07af23dad052c71a6a32580eaae0ed0af3aff43 Mon Sep 17 00:00:00 2001 From: zy Date: Thu, 9 Dec 2021 22:24:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fuso-api/src/core.rs | 35 ++++++++++++++++++++++----- fuso-core/src/ciphe.rs | 55 +++++++++++++++++------------------------- src/client.rs | 1 + 3 files changed, 52 insertions(+), 39 deletions(-) diff --git a/fuso-api/src/core.rs b/fuso-api/src/core.rs index 1d404ce..1756f80 100644 --- a/fuso-api/src/core.rs +++ b/fuso-api/src/core.rs @@ -30,6 +30,29 @@ pub fn now_mills() -> u64 { .as_secs() } +#[inline] +pub async fn copy(mut reader: R, mut writer: W) -> std::io::Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + loop { + let mut buf = Vec::new(); + buf.resize(0x2000, 0); + + let n = reader.read(&mut buf).await?; + + if n == 0 { + let _ = writer.close().await; + break Ok(()); + } + + buf.truncate(n); + + writer.write_all(&mut buf).await?; + } +} + #[derive(Debug, Clone)] pub struct Packet { magic: u32, @@ -317,12 +340,12 @@ where let (reader_s, writer_s) = self.split(); let (reader_t, writer_t) = to.split(); - smol::future::race( - smol::io::copy(reader_t, writer_s), - smol::io::copy(reader_s, writer_t), - ) - .await - .map_err(|e| error::Error::with_io(e))?; + smol::future::race(copy(reader_t, writer_s), copy(reader_s, writer_t)) + .await + .map_err(|e| { + log::warn!("{}", e); + error::Error::with_io(e) + })?; Ok(()) } diff --git a/fuso-core/src/ciphe.rs b/fuso-core/src/ciphe.rs index 3e8b8d8..49f5ee3 100644 --- a/fuso-core/src/ciphe.rs +++ b/fuso-core/src/ciphe.rs @@ -35,7 +35,7 @@ pub trait Cipher { pub struct Crypt { buf: Arc>>, target: T, - cipher: Arc>, + cipher: C, } #[derive(Clone)] @@ -74,7 +74,7 @@ where Crypt { target: self, buf: Arc::new(Mutex::new(Buffer::new())), - cipher: Arc::new(Mutex::new(c)), + cipher: c, } } } @@ -95,36 +95,31 @@ where let mut io_buf = io_buf.lock().unwrap(); if !io_buf.is_empty() { - log::info!("read buffer"); Pin::new(&mut *io_buf).poll_read(cx, buf) } else { match Pin::new(&mut self.target).poll_read(cx, buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Ready(Ok(0)) => Poll::Ready(Ok(0)), - Poll::Ready(Ok(n)) => { - let mut cipher = self.cipher.lock().unwrap(); - - match Pin::new(&mut *cipher).poll_decode(cx, &buf[..n]) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Ready(Ok(data)) => { - let total = buf.len(); - let mut cur = Cursor::new(buf); - - let write_len = if total >= data.len() { - cur.write_all(&data).unwrap(); - data.len() - } else { - cur.write_all(&data[..total]).unwrap(); - io_buf.push_back(&data[total..]); - total - }; - - Poll::Ready(Ok(write_len)) - } + Poll::Ready(Ok(n)) => match Pin::new(&mut self.cipher).poll_decode(cx, &buf[..n]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok(data)) => { + let total = buf.len(); + let mut cur = Cursor::new(buf); + + let write_len = if total >= data.len() { + cur.write_all(&data).unwrap(); + data.len() + } else { + cur.write_all(&data[..total]).unwrap(); + io_buf.push_back(&data[total..]); + total + }; + + Poll::Ready(Ok(write_len)) } - } + }, } } } @@ -142,16 +137,10 @@ where cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - let cipher = self.cipher.clone(); - let mut cipher = cipher.lock().unwrap(); - - match Pin::new(&mut *cipher).poll_encode(cx, buf) { + match Pin::new(&mut self.cipher).poll_encode(cx, buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Ready(Ok(data)) => { - let _ = Pin::new(&mut self.target).poll_write(cx, &data)?; - Poll::Ready(Ok(buf.len())) - } + Poll::Ready(Ok(data)) => Pin::new(&mut self.target).poll_write(cx, &data), } } diff --git a/src/client.rs b/src/client.rs index 9d84641..f18b837 100644 --- a/src/client.rs +++ b/src/client.rs @@ -199,6 +199,7 @@ fn main() { if let Err(e) = from.forward(to).await { log::debug!("[fuc] Forwarding failed {}", e); } + } .detach() })