Skip to content

Commit

Permalink
perf: improve code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
fu050409 committed Apr 25, 2024
1 parent ea0ff66 commit 97fcf77
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 112 deletions.
69 changes: 28 additions & 41 deletions src/models/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,15 @@ impl OSC {
}

pub async fn from_stream(stream: &Socket) -> Result<Self> {
let status_code = stream.recv_u32().await?;
Ok(Self { status_code })
Ok(Self {
status_code: stream.recv_u32().await?,
})
}

pub async fn to_stream(&mut self, stream: &Socket) -> Result<()> {
stream.send(&self.plain_data()).await?;
stream.send(&self.status_code.to_be_bytes()).await?;
Ok(())
}

pub fn plain_data(&mut self) -> [u8; 4] {
let status_code = self.status_code as u32;
status_code.to_be_bytes()
}
}

impl From<u32> for OSC {
fn from(value: u32) -> Self {
Self { status_code: value }
}
}

#[cfg(feature = "unsafe")]
Expand Down Expand Up @@ -136,9 +126,9 @@ impl<'a> OKE<'a> {

#[cfg(not(feature = "unsafe"))]
pub struct OKE {
public_key: Option<UnparsedPublicKey<Vec<u8>>>,
public_key: UnparsedPublicKey<Vec<u8>>,
private_key: Option<EphemeralPrivateKey>,
salt: Option<Vec<u8>>,
salt: Vec<u8>,
remote_public_key: Option<UnparsedPublicKey<Vec<u8>>>,
shared_aes_key: Option<Vec<u8>>,
}
Expand All @@ -147,19 +137,19 @@ pub struct OKE {
impl OKE {
pub fn new(
private_key: Option<EphemeralPrivateKey>,
public_key: Option<UnparsedPublicKey<Vec<u8>>>,
) -> Result<Self, Exception> {
Ok(Self {
public_key: UnparsedPublicKey<Vec<u8>>,
) -> Self {
Self {
public_key,
private_key,
salt: Some(generate_random_salt()),
salt: generate_random_salt(),
remote_public_key: None,
shared_aes_key: None,
})
}
}

pub fn from_public_key_bytes(&mut self, public_key_bytes: &[u8]) -> Result<&mut Self> {
self.public_key = Some(UnparsedPublicKey::new(&X25519, public_key_bytes.to_owned()));
self.public_key = UnparsedPublicKey::new(&X25519, public_key_bytes.to_owned());
Ok(self)
}

Expand All @@ -170,8 +160,8 @@ impl OKE {
let mut shared_key = SharedKey::new(
self.private_key.take().unwrap(),
self.remote_public_key.as_ref().unwrap(),
);
self.shared_aes_key = Some(shared_key.hkdf(&self.salt.as_mut().unwrap())?);
)?;
self.shared_aes_key = Some(shared_key.hkdf(&self.salt)?);
Ok(self)
}

Expand All @@ -180,12 +170,12 @@ impl OKE {
let remote_public_key_bytes = stream.recv(remote_public_key_length).await?;
self.remote_public_key = Some(UnparsedPublicKey::new(&X25519, remote_public_key_bytes));
let salt_length = stream.recv_usize().await?;
self.salt = Some(stream.recv(salt_length).await?);
self.salt = stream.recv(salt_length).await?;
let mut shared_key = SharedKey::new(
self.private_key.take().unwrap(),
self.remote_public_key.as_ref().unwrap(),
);
self.shared_aes_key = Some(shared_key.hkdf(&self.salt.as_mut().unwrap())?);
)?;
self.shared_aes_key = Some(shared_key.hkdf(&self.salt)?);
Ok(self)
}

Expand All @@ -201,16 +191,15 @@ impl OKE {
}

pub fn plain_data(&mut self) -> Result<Vec<u8>> {
let public_key_bytes = self.public_key.clone().unwrap().as_ref().to_vec();
let mut plain_data_bytes = length(&public_key_bytes)?.to_vec();
let public_key_bytes = self.public_key.as_ref();
let mut plain_data_bytes = length(public_key_bytes)?.to_vec();
plain_data_bytes.extend(public_key_bytes);
Ok(plain_data_bytes)
}

pub fn plain_salt(&mut self) -> Result<Vec<u8>> {
let salt_bytes = self.salt.as_ref().unwrap();
let mut plain_salt_bytes = length(&salt_bytes)?.to_vec();
plain_salt_bytes.extend(salt_bytes);
let mut plain_salt_bytes = length(&self.salt)?.to_vec();
plain_salt_bytes.extend(&self.salt);
Ok(plain_salt_bytes)
}

Expand All @@ -220,7 +209,7 @@ impl OKE {
}

pub struct OED {
aes_key: Option<Vec<u8>>,
aes_key: Vec<u8>,
data: Option<Vec<u8>>,
encrypted_data: Option<Vec<u8>>,
tag: Option<Vec<u8>>,
Expand All @@ -229,7 +218,7 @@ pub struct OED {
}

impl OED {
pub fn new(aes_key: Option<Vec<u8>>) -> Self {
pub fn new(aes_key: Vec<u8>) -> Self {
Self {
aes_key,
data: None,
Expand All @@ -241,16 +230,14 @@ impl OED {
}

pub fn from_json_or_string(&mut self, json_or_str: String) -> Result<&mut Self, Exception> {
let (encrypted_data, tag, nonce) =
encrypt_plaintext(json_or_str, &self.aes_key.as_ref().unwrap())?;
let (encrypted_data, tag, nonce) = encrypt_plaintext(json_or_str, &self.aes_key)?;
(self.encrypted_data, self.tag, self.nonce) =
(Some(encrypted_data), Some(tag), Some(nonce));
Ok(self)
}

pub fn from_dict(&mut self, dict: Value) -> Result<&mut Self, Exception> {
let (encrypted_data, tag, nonce) =
encrypt_plaintext(dict.to_string(), &self.aes_key.as_ref().unwrap())?;
let (encrypted_data, tag, nonce) = encrypt_plaintext(dict.to_string(), &self.aes_key)?;
(self.encrypted_data, self.tag, self.nonce) =
(Some(encrypted_data), Some(tag), Some(nonce));
Ok(self)
Expand All @@ -262,7 +249,7 @@ impl OED {
}

pub fn from_bytes(&mut self, data: Vec<u8>) -> Result<&mut Self, Exception> {
let (encrypted_data, tag, nonce) = encrypt_bytes(data, &self.aes_key.as_ref().unwrap())?;
let (encrypted_data, tag, nonce) = encrypt_bytes(data, &self.aes_key)?;
(self.encrypted_data, self.tag, self.nonce) =
(Some(encrypted_data), Some(tag), Some(nonce));
Ok(self)
Expand Down Expand Up @@ -297,7 +284,7 @@ impl OED {
match decrypt_bytes(
self.encrypted_data.clone().unwrap(),
self.tag.as_ref().unwrap(),
self.aes_key.as_ref().unwrap(),
&self.aes_key,
self.nonce.as_ref().unwrap(),
) {
Ok(data) => {
Expand All @@ -315,7 +302,7 @@ impl OED {
let encrypted_data = self.encrypted_data.as_ref().unwrap();
let mut remaining_data = &encrypted_data[..];
while !remaining_data.is_empty() {
let chunk_size = remaining_data.len().min(2048);
let chunk_size = remaining_data.len().min(1024);

let chunk_length = chunk_size as u32;

Expand Down
66 changes: 8 additions & 58 deletions src/models/render.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! # Oblivion Render
use anyhow::Result;
use futures::future::BoxFuture;
use serde_json::Value;

use crate::exceptions::Exception;
Expand All @@ -12,64 +11,17 @@ pub enum BaseResponse {
JsonResponse(Value, u32),
}

pub type Response = BoxFuture<'static, Result<BaseResponse>>;

pub struct FileResponse {}

pub struct TextResponse {
status_code: u32,
text: String,
}

impl TextResponse {
pub fn new(text: &str, status_code: u32) -> Self {
Self {
status_code,
text: text.to_string(),
}
}

pub fn as_bytes(&self) -> Vec<u8> {
self.text.as_bytes().to_vec()
}

pub fn get_status_code(&self) -> u32 {
self.status_code
}
}

pub struct JsonResponse {
data: Value,
status_code: u32,
}

impl JsonResponse {
pub fn new(data: Value, status_code: u32) -> Self {
Self { data, status_code }
}

pub fn as_bytes(&self) -> Vec<u8> {
self.data.to_string().as_bytes().to_vec()
}

pub fn get_status_code(&self) -> u32 {
self.status_code
}
}

impl BaseResponse {
pub fn as_bytes(&self) -> Result<Vec<u8>, Exception> {
match self {
Self::FileResponse(_, _) => Err(Exception::UnsupportedMethod {
method: "FileResponse".to_string(),
}),
Self::TextResponse(text, status_code) => {
let tres = TextResponse::new(&text, *status_code);
Ok(tres.as_bytes())
Self::TextResponse(text, _) => {
Ok(text.as_bytes().to_vec())
}
Self::JsonResponse(data, status_code) => {
let jres = JsonResponse::new(data.clone(), *status_code);
Ok(jres.as_bytes())
Self::JsonResponse(data, _) => {
Ok(data.to_string().as_bytes().to_vec())
}
}
}
Expand All @@ -79,13 +31,11 @@ impl BaseResponse {
Self::FileResponse(_, _) => Err(Exception::UnsupportedMethod {
method: "FileResponse".to_string(),
}),
Self::TextResponse(text, status_code) => {
let tres = TextResponse::new(&text, *status_code);
Ok(tres.get_status_code())
Self::TextResponse(_, status_code) => {
Ok(*status_code)
}
Self::JsonResponse(data, status_code) => {
let jres = JsonResponse::new(data.clone(), *status_code);
Ok(jres.get_status_code())
Self::JsonResponse(_, status_code) => {
Ok(*status_code)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn _handle(router: &Router, stream: TcpStream, peer: SocketAddr) -> Result
let status_code = callback.get_status_code()?;

OSC::from_u32(1).to_stream(&socket).await?;
OED::new(Some(aes_key))
OED::new(aes_key)
.from_bytes(callback.as_bytes()?)?
.to_stream(&socket)
.await?;
Expand Down
8 changes: 4 additions & 4 deletions src/models/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl Session {
#[cfg(not(feature = "unsafe"))]
let public_key = UnparsedPublicKey::new(&X25519, self.public_key.as_ref().to_vec());
#[cfg(not(feature = "unsafe"))]
let mut oke = OKE::new(self.private_key.take(), Some(public_key))?;
let mut oke = OKE::new(self.private_key.take(), public_key);
oke.from_stream_with_salt(&socket).await?;
self.aes_key = Some(oke.get_aes_key());
oke.to_stream(&socket).await?;
Expand All @@ -103,7 +103,7 @@ impl Session {
#[cfg(not(feature = "unsafe"))]
let public_key = UnparsedPublicKey::new(&X25519, self.public_key.as_ref().to_vec());
#[cfg(not(feature = "unsafe"))]
let mut oke = OKE::new(self.private_key.take(), Some(public_key))?;
let mut oke = OKE::new(self.private_key.take(), public_key);
oke.to_stream_with_salt(&socket).await?;
oke.from_stream(&socket).await?;

Expand Down Expand Up @@ -132,7 +132,7 @@ impl Session {
let socket = &self.socket;

OSC::from_u32(0).to_stream(socket).await?;
OED::new(self.aes_key.clone())
OED::new(self.aes_key.clone().unwrap())
.from_bytes(data)?
.to_stream(socket)
.await?;
Expand All @@ -157,7 +157,7 @@ impl Session {
let socket = &self.socket;

let flag = OSC::from_stream(socket).await?.status_code;
let content = OED::new(self.aes_key.clone())
let content = OED::new(self.aes_key.clone().unwrap())
.from_stream(socket)
.await?
.get_data();
Expand Down
14 changes: 8 additions & 6 deletions src/utils/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ use crate::exceptions::Exception;
/// ```
#[cfg(not(feature = "unsafe"))]
pub fn generate_key_pair() -> Result<(EphemeralPrivateKey, PublicKey), Exception> {
// let private_key = EphemeralSecret::random(&mut OsRng);
// let public_key = private_key.public_key();
let rng = SystemRandom::new();
let private_key = EphemeralPrivateKey::generate(&X25519, &rng).unwrap();
let public_key = private_key.compute_public_key().unwrap();
Expand All @@ -52,7 +50,7 @@ pub fn generate_key_pair() -> Result<(EphemeralSecret, PublicKey), Exception> {
///
/// let salt = generate_random_salt();
/// let (private_key, public_key) = generate_key_pair().unwrap();
///
///
/// #[cfg(feature = "unsafe")]
/// let mut shared_key = SharedKey::new(&private_key, &public_key);
///
Expand All @@ -77,9 +75,13 @@ impl SharedKey {
}

#[cfg(not(feature = "unsafe"))]
pub fn new(private_key: EphemeralPrivateKey, public_key: &UnparsedPublicKey<Vec<u8>>) -> Self {
Self {
shared_key: agree_ephemeral(private_key, public_key, |key| key.to_vec()).unwrap(),
pub fn new(
private_key: EphemeralPrivateKey,
public_key: &UnparsedPublicKey<Vec<u8>>,
) -> Result<Self> {
match agree_ephemeral(private_key, public_key, |key| key.to_vec()) {
Ok(shared_key) => Ok(Self { shared_key }),
Err(error) => Err(Exception::DecryptError { error }.into()),
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/utils/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ use crate::exceptions::Exception;
/// ```
///
/// The `vec` in the above example is a `Vec<u8>` of length 39, and `length(&vec)` gets `b "0039".to_vec()`.
pub fn length(bytes: &Vec<u8>) -> Result<[u8; 4], Exception> {
pub fn length(bytes: &[u8]) -> Result<[u8; 4], Exception> {
let size = bytes.len() as u32;

if size > 4096 {
if size > 2048 {
return Err(Exception::DataTooLarge {
size: size as usize,
});
Expand Down

0 comments on commit 97fcf77

Please sign in to comment.