From 97fcf777fd565c96ddef2862e21c24936ff7d30f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8B=8F=E5=90=91=E5=A4=9C?= Date: Thu, 25 Apr 2024 21:12:55 +0800 Subject: [PATCH] perf: improve code quality --- src/models/packet.rs | 69 +++++++++++++++++------------------------- src/models/render.rs | 66 +++++----------------------------------- src/models/server.rs | 2 +- src/models/session.rs | 8 ++--- src/utils/generator.rs | 14 +++++---- src/utils/parser.rs | 4 +-- 6 files changed, 51 insertions(+), 112 deletions(-) diff --git a/src/models/packet.rs b/src/models/packet.rs index 8eac464..d4eb53c 100644 --- a/src/models/packet.rs +++ b/src/models/packet.rs @@ -28,25 +28,15 @@ impl OSC { } pub async fn from_stream(stream: &Socket) -> Result { - let status_code = stream.recv_u32().await?; - Ok(Self { status_code }) + Ok(Self { + status_code: stream.recv_u32().await?, + }) } pub async fn to_stream(&mut self, stream: &Socket) -> Result<()> { - stream.send(&self.plain_data()).await?; + stream.send(&self.status_code.to_be_bytes()).await?; Ok(()) } - - pub fn plain_data(&mut self) -> [u8; 4] { - let status_code = self.status_code as u32; - status_code.to_be_bytes() - } -} - -impl From for OSC { - fn from(value: u32) -> Self { - Self { status_code: value } - } } #[cfg(feature = "unsafe")] @@ -136,9 +126,9 @@ impl<'a> OKE<'a> { #[cfg(not(feature = "unsafe"))] pub struct OKE { - public_key: Option>>, + public_key: UnparsedPublicKey>, private_key: Option, - salt: Option>, + salt: Vec, remote_public_key: Option>>, shared_aes_key: Option>, } @@ -147,19 +137,19 @@ pub struct OKE { impl OKE { pub fn new( private_key: Option, - public_key: Option>>, - ) -> Result { - Ok(Self { + public_key: UnparsedPublicKey>, + ) -> Self { + Self { public_key, private_key, - salt: Some(generate_random_salt()), + salt: generate_random_salt(), remote_public_key: None, shared_aes_key: None, - }) + } } pub fn from_public_key_bytes(&mut self, public_key_bytes: &[u8]) -> Result<&mut Self> { - self.public_key = Some(UnparsedPublicKey::new(&X25519, public_key_bytes.to_owned())); + self.public_key = UnparsedPublicKey::new(&X25519, public_key_bytes.to_owned()); Ok(self) } @@ -170,8 +160,8 @@ impl OKE { let mut shared_key = SharedKey::new( self.private_key.take().unwrap(), self.remote_public_key.as_ref().unwrap(), - ); - self.shared_aes_key = Some(shared_key.hkdf(&self.salt.as_mut().unwrap())?); + )?; + self.shared_aes_key = Some(shared_key.hkdf(&self.salt)?); Ok(self) } @@ -180,12 +170,12 @@ impl OKE { let remote_public_key_bytes = stream.recv(remote_public_key_length).await?; self.remote_public_key = Some(UnparsedPublicKey::new(&X25519, remote_public_key_bytes)); let salt_length = stream.recv_usize().await?; - self.salt = Some(stream.recv(salt_length).await?); + self.salt = stream.recv(salt_length).await?; let mut shared_key = SharedKey::new( self.private_key.take().unwrap(), self.remote_public_key.as_ref().unwrap(), - ); - self.shared_aes_key = Some(shared_key.hkdf(&self.salt.as_mut().unwrap())?); + )?; + self.shared_aes_key = Some(shared_key.hkdf(&self.salt)?); Ok(self) } @@ -201,16 +191,15 @@ impl OKE { } pub fn plain_data(&mut self) -> Result> { - let public_key_bytes = self.public_key.clone().unwrap().as_ref().to_vec(); - let mut plain_data_bytes = length(&public_key_bytes)?.to_vec(); + let public_key_bytes = self.public_key.as_ref(); + let mut plain_data_bytes = length(public_key_bytes)?.to_vec(); plain_data_bytes.extend(public_key_bytes); Ok(plain_data_bytes) } pub fn plain_salt(&mut self) -> Result> { - let salt_bytes = self.salt.as_ref().unwrap(); - let mut plain_salt_bytes = length(&salt_bytes)?.to_vec(); - plain_salt_bytes.extend(salt_bytes); + let mut plain_salt_bytes = length(&self.salt)?.to_vec(); + plain_salt_bytes.extend(&self.salt); Ok(plain_salt_bytes) } @@ -220,7 +209,7 @@ impl OKE { } pub struct OED { - aes_key: Option>, + aes_key: Vec, data: Option>, encrypted_data: Option>, tag: Option>, @@ -229,7 +218,7 @@ pub struct OED { } impl OED { - pub fn new(aes_key: Option>) -> Self { + pub fn new(aes_key: Vec) -> Self { Self { aes_key, data: None, @@ -241,16 +230,14 @@ impl OED { } pub fn from_json_or_string(&mut self, json_or_str: String) -> Result<&mut Self, Exception> { - let (encrypted_data, tag, nonce) = - encrypt_plaintext(json_or_str, &self.aes_key.as_ref().unwrap())?; + let (encrypted_data, tag, nonce) = encrypt_plaintext(json_or_str, &self.aes_key)?; (self.encrypted_data, self.tag, self.nonce) = (Some(encrypted_data), Some(tag), Some(nonce)); Ok(self) } pub fn from_dict(&mut self, dict: Value) -> Result<&mut Self, Exception> { - let (encrypted_data, tag, nonce) = - encrypt_plaintext(dict.to_string(), &self.aes_key.as_ref().unwrap())?; + let (encrypted_data, tag, nonce) = encrypt_plaintext(dict.to_string(), &self.aes_key)?; (self.encrypted_data, self.tag, self.nonce) = (Some(encrypted_data), Some(tag), Some(nonce)); Ok(self) @@ -262,7 +249,7 @@ impl OED { } pub fn from_bytes(&mut self, data: Vec) -> Result<&mut Self, Exception> { - let (encrypted_data, tag, nonce) = encrypt_bytes(data, &self.aes_key.as_ref().unwrap())?; + let (encrypted_data, tag, nonce) = encrypt_bytes(data, &self.aes_key)?; (self.encrypted_data, self.tag, self.nonce) = (Some(encrypted_data), Some(tag), Some(nonce)); Ok(self) @@ -297,7 +284,7 @@ impl OED { match decrypt_bytes( self.encrypted_data.clone().unwrap(), self.tag.as_ref().unwrap(), - self.aes_key.as_ref().unwrap(), + &self.aes_key, self.nonce.as_ref().unwrap(), ) { Ok(data) => { @@ -315,7 +302,7 @@ impl OED { let encrypted_data = self.encrypted_data.as_ref().unwrap(); let mut remaining_data = &encrypted_data[..]; while !remaining_data.is_empty() { - let chunk_size = remaining_data.len().min(2048); + let chunk_size = remaining_data.len().min(1024); let chunk_length = chunk_size as u32; diff --git a/src/models/render.rs b/src/models/render.rs index 0e7cb4d..41d8c74 100644 --- a/src/models/render.rs +++ b/src/models/render.rs @@ -1,6 +1,5 @@ //! # Oblivion Render use anyhow::Result; -use futures::future::BoxFuture; use serde_json::Value; use crate::exceptions::Exception; @@ -12,64 +11,17 @@ pub enum BaseResponse { JsonResponse(Value, u32), } -pub type Response = BoxFuture<'static, Result>; - -pub struct FileResponse {} - -pub struct TextResponse { - status_code: u32, - text: String, -} - -impl TextResponse { - pub fn new(text: &str, status_code: u32) -> Self { - Self { - status_code, - text: text.to_string(), - } - } - - pub fn as_bytes(&self) -> Vec { - self.text.as_bytes().to_vec() - } - - pub fn get_status_code(&self) -> u32 { - self.status_code - } -} - -pub struct JsonResponse { - data: Value, - status_code: u32, -} - -impl JsonResponse { - pub fn new(data: Value, status_code: u32) -> Self { - Self { data, status_code } - } - - pub fn as_bytes(&self) -> Vec { - self.data.to_string().as_bytes().to_vec() - } - - pub fn get_status_code(&self) -> u32 { - self.status_code - } -} - impl BaseResponse { pub fn as_bytes(&self) -> Result, Exception> { match self { Self::FileResponse(_, _) => Err(Exception::UnsupportedMethod { method: "FileResponse".to_string(), }), - Self::TextResponse(text, status_code) => { - let tres = TextResponse::new(&text, *status_code); - Ok(tres.as_bytes()) + Self::TextResponse(text, _) => { + Ok(text.as_bytes().to_vec()) } - Self::JsonResponse(data, status_code) => { - let jres = JsonResponse::new(data.clone(), *status_code); - Ok(jres.as_bytes()) + Self::JsonResponse(data, _) => { + Ok(data.to_string().as_bytes().to_vec()) } } } @@ -79,13 +31,11 @@ impl BaseResponse { Self::FileResponse(_, _) => Err(Exception::UnsupportedMethod { method: "FileResponse".to_string(), }), - Self::TextResponse(text, status_code) => { - let tres = TextResponse::new(&text, *status_code); - Ok(tres.get_status_code()) + Self::TextResponse(_, status_code) => { + Ok(*status_code) } - Self::JsonResponse(data, status_code) => { - let jres = JsonResponse::new(data.clone(), *status_code); - Ok(jres.get_status_code()) + Self::JsonResponse(_, status_code) => { + Ok(*status_code) } } } diff --git a/src/models/server.rs b/src/models/server.rs index 2ce4a75..0d4a300 100644 --- a/src/models/server.rs +++ b/src/models/server.rs @@ -50,7 +50,7 @@ async fn _handle(router: &Router, stream: TcpStream, peer: SocketAddr) -> Result let status_code = callback.get_status_code()?; OSC::from_u32(1).to_stream(&socket).await?; - OED::new(Some(aes_key)) + OED::new(aes_key) .from_bytes(callback.as_bytes()?)? .to_stream(&socket) .await?; diff --git a/src/models/session.rs b/src/models/session.rs index ff8b0dc..cbb3426 100644 --- a/src/models/session.rs +++ b/src/models/session.rs @@ -83,7 +83,7 @@ impl Session { #[cfg(not(feature = "unsafe"))] let public_key = UnparsedPublicKey::new(&X25519, self.public_key.as_ref().to_vec()); #[cfg(not(feature = "unsafe"))] - let mut oke = OKE::new(self.private_key.take(), Some(public_key))?; + let mut oke = OKE::new(self.private_key.take(), public_key); oke.from_stream_with_salt(&socket).await?; self.aes_key = Some(oke.get_aes_key()); oke.to_stream(&socket).await?; @@ -103,7 +103,7 @@ impl Session { #[cfg(not(feature = "unsafe"))] let public_key = UnparsedPublicKey::new(&X25519, self.public_key.as_ref().to_vec()); #[cfg(not(feature = "unsafe"))] - let mut oke = OKE::new(self.private_key.take(), Some(public_key))?; + let mut oke = OKE::new(self.private_key.take(), public_key); oke.to_stream_with_salt(&socket).await?; oke.from_stream(&socket).await?; @@ -132,7 +132,7 @@ impl Session { let socket = &self.socket; OSC::from_u32(0).to_stream(socket).await?; - OED::new(self.aes_key.clone()) + OED::new(self.aes_key.clone().unwrap()) .from_bytes(data)? .to_stream(socket) .await?; @@ -157,7 +157,7 @@ impl Session { let socket = &self.socket; let flag = OSC::from_stream(socket).await?.status_code; - let content = OED::new(self.aes_key.clone()) + let content = OED::new(self.aes_key.clone().unwrap()) .from_stream(socket) .await? .get_data(); diff --git a/src/utils/generator.rs b/src/utils/generator.rs index 7348ef8..9a1e43d 100644 --- a/src/utils/generator.rs +++ b/src/utils/generator.rs @@ -30,8 +30,6 @@ use crate::exceptions::Exception; /// ``` #[cfg(not(feature = "unsafe"))] pub fn generate_key_pair() -> Result<(EphemeralPrivateKey, PublicKey), Exception> { - // let private_key = EphemeralSecret::random(&mut OsRng); - // let public_key = private_key.public_key(); let rng = SystemRandom::new(); let private_key = EphemeralPrivateKey::generate(&X25519, &rng).unwrap(); let public_key = private_key.compute_public_key().unwrap(); @@ -52,7 +50,7 @@ pub fn generate_key_pair() -> Result<(EphemeralSecret, PublicKey), Exception> { /// /// let salt = generate_random_salt(); /// let (private_key, public_key) = generate_key_pair().unwrap(); -/// +/// /// #[cfg(feature = "unsafe")] /// let mut shared_key = SharedKey::new(&private_key, &public_key); /// @@ -77,9 +75,13 @@ impl SharedKey { } #[cfg(not(feature = "unsafe"))] - pub fn new(private_key: EphemeralPrivateKey, public_key: &UnparsedPublicKey>) -> Self { - Self { - shared_key: agree_ephemeral(private_key, public_key, |key| key.to_vec()).unwrap(), + pub fn new( + private_key: EphemeralPrivateKey, + public_key: &UnparsedPublicKey>, + ) -> Result { + match agree_ephemeral(private_key, public_key, |key| key.to_vec()) { + Ok(shared_key) => Ok(Self { shared_key }), + Err(error) => Err(Exception::DecryptError { error }.into()), } } diff --git a/src/utils/parser.rs b/src/utils/parser.rs index 4aace8c..236b4d0 100644 --- a/src/utils/parser.rs +++ b/src/utils/parser.rs @@ -25,10 +25,10 @@ use crate::exceptions::Exception; /// ``` /// /// The `vec` in the above example is a `Vec` of length 39, and `length(&vec)` gets `b "0039".to_vec()`. -pub fn length(bytes: &Vec) -> Result<[u8; 4], Exception> { +pub fn length(bytes: &[u8]) -> Result<[u8; 4], Exception> { let size = bytes.len() as u32; - if size > 4096 { + if size > 2048 { return Err(Exception::DataTooLarge { size: size as usize, });