From 50c3a6894ec04868c904ed21de1c55b49320cbe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8B=8F=E5=90=91=E5=A4=9C?= Date: Sun, 21 Apr 2024 17:20:28 +0800 Subject: [PATCH] feat(packet): allow hkdf and scrpyt algorithm --- src/api.rs | 53 ----------------------------------------- src/bin/main.rs | 41 ++++++++++++++++++++++++++++---- src/lib.rs | 23 +++++++++--------- src/models/packet.rs | 24 ++++++++++++------- src/models/server.rs | 6 ++--- src/models/session.rs | 15 ++++++++---- src/sessions.rs | 28 ---------------------- src/utils/generator.rs | 54 +++++++++++++++++++++++++++++------------- src/utils/parser.rs | 2 +- 9 files changed, 113 insertions(+), 133 deletions(-) delete mode 100644 src/api.rs delete mode 100644 src/sessions.rs diff --git a/src/api.rs b/src/api.rs deleted file mode 100644 index 5017591..0000000 --- a/src/api.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! # Oblivion API Interface -//! -//! Oblivion provides methods for making direct GET, POST, PUT, etc. requests. -use anyhow::Result; -use serde_json::Value; - -use crate::models::client::Response; - -use super::sessions::Session; - -/// Naked Oblivion Request Mode -/// -/// ```rust -/// use oblivion::api::request; -/// use oblivion::models::client::Response; -/// use oblivion::exceptions::OblivionException; -/// -/// #[tokio::test] -/// async fn run() -> Result { -/// request("get", "127.0.0.1:813/get", None, None, true).await -/// } -/// ``` -pub async fn request( - method: &str, - olps: &str, - data: Option, - file: Option>, -) -> Result { - let session = Session::new(); - session - .request(method.to_string(), olps.to_string(), data, file) - .await -} - -/// GET method -pub async fn get(olps: &str) -> Result { - request("get", olps, None, None).await -} - -/// POST method -pub async fn post(olps: &str, data: Value) -> Result { - request("post", olps, Some(data), None).await -} - -/// PUT method -pub async fn put(olps: &str, data: Option, file: Vec) -> Result { - request("put", olps, data, Some(file)).await -} - -#[deprecated(since = "1.0.0", note = "FORWARD method is no longer supported.")] -pub async fn forward(olps: &str, data: Option, file: Vec) -> Result { - request("forward", olps, data, Some(file)).await -} diff --git a/src/bin/main.rs b/src/bin/main.rs index 55ee060..98d5988 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -1,10 +1,11 @@ use anyhow::Result; -use oblivion::api::get; +use oblivion::models::client::Client; use oblivion::models::render::{BaseResponse, Response}; use oblivion::models::router::{RoutePath, RouteType, Router}; use oblivion::models::server::Server; use oblivion::models::session::Session; use oblivion::path_route; +use oblivion::utils::generator::{generate_key_pair, generate_random_salt, SharedKey}; use oblivion_codegen::async_route; use serde_json::json; use std::env::args; @@ -39,7 +40,8 @@ fn json(_sess: Session) -> Response { } #[async_route] -async fn alive(mut _sess: Session) -> Response { +async fn alive(mut sess: Session) -> Response { + sess.send("test".into(), 200).await?; Ok(BaseResponse::JsonResponse( json!({"status": true, "msg": "结束"}), 200, @@ -52,14 +54,42 @@ async fn main() -> Result<()> { if args.len() <= 1 { args.push("serve".to_string()); } + if args.len() <= 2 { + args.push("/welcome".to_string()); + } match args[1].as_str() { + "keygen" => { + let now = Instant::now(); + generate_key_pair()?; + println!("执行时间: {}", now.elapsed().as_millis()); + } + "dh" => { + let now = Instant::now(); + let (pr, pu) = generate_key_pair()?; + let (alice_pr, alice_pu) = generate_key_pair()?; + let salt = generate_random_salt(); + let mut shared_bob = SharedKey::new(&pr, &alice_pu); + let mut shared_alice = SharedKey::new(&alice_pr, &pu); + let bob_key = shared_bob.hkdf(&salt)?; + let alice_key = shared_alice.hkdf(&salt)?; + assert_eq!(bob_key, alice_key); + println!("执行时间: {}", now.elapsed().as_millis()); + } "bench" => loop { let now = Instant::now(); - let mut res = get("127.0.0.1:7076/welcome").await?; - println!("{}", res.text()?); + let mut client = Client::new("CONNECT", format!("127.0.0.1:7076{}", args[2]))?; + client.connect().await?; + client.recv().await?.text()?; + client.close().await?; println!("执行时间: {}", now.elapsed().as_millis()); }, - "socket" => todo!(), + "socket" => { + let mut client = Client::new("CONNECT", format!("127.0.0.1:7076{}", args[2]))?; + client.connect().await?; + client.recv().await?.text()?; + client.recv().await?.json()?; + client.close().await?; + } "serve" => { let mut router = Router::new(); @@ -67,6 +97,7 @@ async fn main() -> Result<()> { path_route!(&mut router, "/welcome" => welcome); path_route!(&mut router, "/json" => json); + path_route!(&mut router, "/alive" => alive); let mut server = Server::new("0.0.0.0", 7076, router); server.run().await?; diff --git a/src/lib.rs b/src/lib.rs index 7848029..4debce3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,9 @@ //! which makes it possible to apply it to message dispatching and just-in-time communication. pub extern crate oblivion_codegen; pub extern crate proc_macro; -pub mod api; + +/// # Oblivion Exceptions pub mod exceptions; -pub mod sessions; /// # Oblivion Utilities /// @@ -34,15 +34,15 @@ pub mod models; /// /// ```rust /// use oblivion::path_route; -/// use oblivion::utils::parser::OblivionRequest; /// use oblivion::models::render::{BaseResponse, Response}; /// use oblivion_codegen::async_route; /// use oblivion::models::router::Router; +/// use oblivion::models::session::Session; /// /// #[async_route] -/// fn welcome(mut req: OblivionRequest) -> Response { +/// fn welcome(mut sess: Session) -> Response { /// Ok(BaseResponse::TextResponse( -/// format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", req.get_ip()), +/// format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", sess.get_ip()), /// 200, /// )) /// } @@ -69,15 +69,15 @@ macro_rules! path_route { /// /// ```rust /// use oblivion::startswith_route; -/// use oblivion::utils::parser::OblivionRequest; /// use oblivion::models::render::{BaseResponse, Response}; /// use oblivion_codegen::async_route; /// use oblivion::models::router::Router; +/// use oblivion::models::session::Session; /// /// #[async_route] -/// fn welcome(mut req: OblivionRequest) -> Response { +/// fn welcome(mut sess: Session) -> Response { /// Ok(BaseResponse::TextResponse( -/// format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", req.get_ip()), +/// format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", sess.get_ip()), /// 200, /// )) /// } @@ -106,17 +106,16 @@ macro_rules! startswith_route { /// Regular routing can be simply implemented using regular routing macros: /// /// ```rust -/// use futures::future::{BoxFuture, FutureExt}; /// use oblivion::regex_route; -/// use oblivion::utils::parser::OblivionRequest; /// use oblivion::models::render::{BaseResponse, Response}; /// use oblivion_codegen::async_route; /// use oblivion::models::router::Router; +/// use oblivion::models::session::Session; /// /// #[async_route] -/// fn welcome(mut req: OblivionRequest) -> Response { +/// fn welcome(mut sess: Session) -> Response { /// Ok(BaseResponse::TextResponse( -/// format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", req.get_ip()), +/// format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", sess.get_ip()), /// 200, /// )) /// } diff --git a/src/models/packet.rs b/src/models/packet.rs index 72e3bcd..786cfae 100644 --- a/src/models/packet.rs +++ b/src/models/packet.rs @@ -3,7 +3,7 @@ use crate::exceptions::OblivionException; use crate::utils::decryptor::decrypt_bytes; use crate::utils::encryptor::{encrypt_bytes, encrypt_plaintext}; use crate::utils::gear::Socket; -use crate::utils::generator::{generate_random_salt, generate_shared_key}; +use crate::utils::generator::{generate_random_salt, SharedKey}; use crate::utils::parser::length; use anyhow::{Error, Result}; @@ -66,6 +66,12 @@ impl OSC { } } +impl From for OSC { + fn from(value: u32) -> Self { + Self { status_code: value } + } +} + pub struct OKE<'a> { public_key: Option, private_key: Option<&'a EphemeralSecret>, @@ -97,11 +103,11 @@ impl<'a> OKE<'a> { let remote_public_key_length = stream.recv_usize().await?; let remote_public_key_bytes = stream.recv(remote_public_key_length).await?; self.remote_public_key = Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes)?); - self.shared_aes_key = Some(generate_shared_key( + let mut shared_key = SharedKey::new( self.private_key.as_ref().unwrap(), self.remote_public_key.as_ref().unwrap(), - &self.salt.as_mut().unwrap(), - )?); + ); + self.shared_aes_key = Some(shared_key.hkdf(&self.salt.as_mut().unwrap())?); Ok(self) } @@ -111,11 +117,11 @@ impl<'a> OKE<'a> { self.remote_public_key = Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes)?); let salt_length = stream.recv_usize().await?; self.salt = Some(stream.recv(salt_length).await?); - self.shared_aes_key = Some(generate_shared_key( - self.private_key.unwrap(), - &self.remote_public_key.unwrap(), - self.salt.as_mut().unwrap(), - )?); + let mut shared_key = SharedKey::new( + self.private_key.as_ref().unwrap(), + self.remote_public_key.as_ref().unwrap(), + ); + self.shared_aes_key = Some(shared_key.hkdf(&self.salt.as_mut().unwrap())?); Ok(self) } diff --git a/src/models/server.rs b/src/models/server.rs index f3c399a..665c7b2 100644 --- a/src/models/server.rs +++ b/src/models/server.rs @@ -29,18 +29,18 @@ async fn _handle(router: &mut Router, mut socket: Socket, peer: SocketAddr) -> R return Err(Error::from(error)); } - let header = session.header.as_ref().unwrap().clone(); - let ip_addr = session.request.as_mut().unwrap().get_ip(); + let header = session.header(); + let ip_addr = session.get_ip(); let aes_key = session.aes_key.clone().unwrap(); let arc_socket = Arc::clone(&session.socket); - let mut socket = arc_socket.lock().await; let mut route = router.get_handler(&session.request.as_ref().unwrap().olps)?; let mut callback = route.get_handler()(session).await?; let status_code = callback.get_status_code()?; + let mut socket = arc_socket.lock().await; OSC::from_u32(1).to_stream(&mut socket).await?; OED::new(Some(aes_key)) .from_bytes(callback.as_bytes()?)? diff --git a/src/models/session.rs b/src/models/session.rs index 7741bfd..b829ffb 100644 --- a/src/models/session.rs +++ b/src/models/session.rs @@ -91,12 +91,9 @@ impl Session { Ok(()) } - pub async fn send( - &mut self, - data: Vec, - status_code: u32, - ) -> Result<()> { + pub async fn send(&mut self, data: Vec, status_code: u32) -> Result<()> { let socket = &mut self.socket.lock().await; + OSC::from_u32(0).to_stream(socket).await?; OED::new(Some(self.aes_key.clone().unwrap())) .from_bytes(data)? @@ -105,4 +102,12 @@ impl Session { OSC::from_u32(status_code).to_stream(socket).await?; Ok(()) } + + pub fn header(&mut self) -> String { + self.header.clone().unwrap() + } + + pub fn get_ip(&mut self) -> String { + self.request.as_mut().unwrap().get_ip() + } } diff --git a/src/sessions.rs b/src/sessions.rs deleted file mode 100644 index a1289f2..0000000 --- a/src/sessions.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! # Oblivion Sessions -use anyhow::Result; -use serde_json::Value; - -use crate::models::client::{Client, Response}; - -/// ## Oblivion Abstract Session -/// -/// Used to connect to the model and create a request session. -pub struct Session; - -impl Session { - pub fn new() -> Self { - Self - } - - pub async fn request( - &self, - method: String, - olps: String, - _data: Option, - _file: Option>, - ) -> Result { - let mut req = Client::new(method, olps)?; - req.prepare().await?; - Ok(req.recv().await?) - } -} diff --git a/src/utils/generator.rs b/src/utils/generator.rs index 84c24e0..2a69455 100644 --- a/src/utils/generator.rs +++ b/src/utils/generator.rs @@ -2,11 +2,14 @@ extern crate rand; extern crate ring; +use anyhow::Result; use elliptic_curve::rand_core::OsRng; +use p256::ecdh::SharedSecret; use p256::{ecdh::EphemeralSecret, EncodedPoint, PublicKey}; use ring::aead::AES_128_GCM; use ring::rand::{SecureRandom, SystemRandom}; use scrypt::{scrypt, Params}; +use sha2::Sha256; use crate::exceptions::OblivionException; @@ -31,28 +34,45 @@ pub fn generate_key_pair() -> Result<(EphemeralSecret, PublicKey), OblivionExcep /// Create an ECDH Shared Key /// /// ```rust -/// use oblivion::utils::generator::{generate_key_pair, generate_shared_key, generate_random_salt}; +/// use oblivion::utils::generator::{generate_key_pair, generate_random_salt, SharedKey}; /// /// let salt = generate_random_salt(); /// let (private_key, public_key) = generate_key_pair().unwrap(); /// -/// let shared_key = generate_shared_key(&private_key, &public_key, &salt).unwrap(); +/// let mut shared_key = SharedKey::new(&private_key, &public_key); +/// +/// shared_key.hkdf(&salt); +/// shared_key.scrypt(&salt); /// ``` -pub fn generate_shared_key( - private_key: &EphemeralSecret, - public_key: &PublicKey, - salt: &[u8], -) -> Result, OblivionException> { - let shared_key = private_key.diffie_hellman(&public_key); - let mut aes_key = [0u8; 16]; - match scrypt( - &shared_key.raw_secret_bytes().to_vec(), - &salt, - &Params::new(12, 8, 1, 16).unwrap(), - &mut aes_key, - ) { - Ok(_) => Ok(aes_key.to_vec()), - Err(error) => Err(OblivionException::InvalidOutputLen { error: error }), +pub struct SharedKey { + shared_key: SharedSecret, +} + +impl SharedKey { + pub fn new(private_key: &EphemeralSecret, public_key: &PublicKey) -> Self { + Self { + shared_key: private_key.diffie_hellman(&public_key), + } + } + + pub fn scrypt(&mut self, salt: &[u8]) -> Result> { + let mut aes_key = [0u8; 16]; + match scrypt( + &self.shared_key.raw_secret_bytes().to_vec(), + &salt, + &Params::new(12, 8, 1, 16).unwrap(), + &mut aes_key, + ) { + Ok(()) => Ok(aes_key.to_vec()), + Err(error) => Err(OblivionException::InvalidOutputLen { error }.into()), + } + } + + pub fn hkdf(&mut self, salt: &[u8]) -> Result> { + let key = self.shared_key.extract::(Some(salt)); + let mut aes_key = [0u8; 16]; + key.expand(&[], &mut aes_key).unwrap(); + Ok(aes_key.to_vec()) } } diff --git a/src/utils/parser.rs b/src/utils/parser.rs index 7025f14..d8d3811 100644 --- a/src/utils/parser.rs +++ b/src/utils/parser.rs @@ -21,7 +21,7 @@ use crate::exceptions::OblivionException; /// /// let vec = b"fw4rg45245ygergeqwrgqwerg342rg342gjisdu".to_vec(); /// -/// assert_eq!(b"0039".to_vec(), length(&vec).unwrap()); +/// assert_eq!((39 as u32).to_be_bytes(), length(&vec).unwrap()); /// ``` /// /// The `vec` in the above example is a `Vec` of length 39, and `length(&vec)` gets `b "0039".to_vec()`.