diff --git a/Cargo.lock b/Cargo.lock index 162040f..13154ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anyhow" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" + [[package]] name = "autocfg" version = "1.2.0" @@ -460,6 +466,7 @@ dependencies = [ name = "oblivion" version = "1.1.0" dependencies = [ + "anyhow", "elliptic-curve", "futures", "oblivion-codegen", @@ -478,7 +485,7 @@ dependencies = [ [[package]] name = "oblivion-codegen" -version = "0.1.0" +version = "0.2.0" dependencies = [ "futures", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 163b129..a52d2fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ oblivion-codegen = { path = "oblivion-codegen" } proc-macro2 = "1" futures = "0.3" thiserror = "1" +anyhow = "1.0" pyo3 = { version = "0.20", optional = true } diff --git a/oblivion-codegen/Cargo.toml b/oblivion-codegen/Cargo.toml index 271a323..a46ace0 100644 --- a/oblivion-codegen/Cargo.toml +++ b/oblivion-codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "oblivion-codegen" -version = "0.1.0" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/oblivion-codegen/src/lib.rs b/oblivion-codegen/src/lib.rs index 02c41a7..16765b3 100644 --- a/oblivion-codegen/src/lib.rs +++ b/oblivion-codegen/src/lib.rs @@ -37,14 +37,15 @@ pub fn async_route(_: TokenStream, item: TokenStream) -> TokenStream { let func_name = &input.sig.ident; let func_args = &input.sig.inputs; + let func_return = &input.sig.output; let func_block = input.block; let expanded = quote! { - pub fn #func_name(#func_args) -> BoxFuture<'static, BaseResponse> + pub fn #func_name(#func_args) #func_return { - async move { + Box::pin(async move { #func_block - }.boxed() + }) } }; diff --git a/src/bin/main.rs b/src/bin/main.rs index 0afcc55..f2fa7bf 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -1,6 +1,5 @@ -use futures::future::{BoxFuture, FutureExt}; use oblivion::api::get; -use oblivion::models::render::BaseResponse; +use oblivion::models::render::{BaseResponse, Response}; use oblivion::models::router::{RoutePath, RouteType, Router}; use oblivion::models::server::Server; use oblivion::path_route; @@ -11,28 +10,28 @@ use std::env::args; use std::time::Instant; #[async_route] -fn handler(mut _req: OblivionRequest) -> BaseResponse { - BaseResponse::TextResponse( +fn handler(mut _req: OblivionRequest) -> Response { + Ok(BaseResponse::TextResponse( "每一个人都应该拥有守护信息与获得真实信息的神圣权利, 任何与之对抗的都是我们的敌人" .to_string(), 200, - ) + )) } #[async_route] -fn welcome(mut req: OblivionRequest) -> BaseResponse { - BaseResponse::TextResponse( +fn welcome(mut req: OblivionRequest) -> Response { + Ok(BaseResponse::TextResponse( format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", req.get_ip()), 200, - ) + )) } #[async_route] -fn json(_req: OblivionRequest) -> BaseResponse { - BaseResponse::JsonResponse( +fn json(_req: OblivionRequest) -> Response { + Ok(BaseResponse::JsonResponse( json!({"status": true, "msg": "只身堕入极暗之永夜, 以期再世涅槃之阳光"}), 200, - ) + )) } #[tokio::main] diff --git a/src/models/handler.rs b/src/models/handler.rs index 5aa1199..c396899 100644 --- a/src/models/handler.rs +++ b/src/models/handler.rs @@ -1,19 +1,18 @@ //! # Oblivion Default Handler -use super::render::BaseResponse; +use super::{render::BaseResponse, render::Response}; use crate::utils::parser::OblivionRequest; -use futures::future::{BoxFuture, FutureExt}; use oblivion_codegen::async_route; /// Not Found Handler /// /// Handling a non-existent route request. #[async_route] -pub fn not_found(mut request: OblivionRequest) -> BaseResponse { - BaseResponse::TextResponse( +pub fn not_found(mut request: OblivionRequest) -> Response { + Ok(BaseResponse::TextResponse( format!( "Path {} is not found, error with code 404.", request.get_olps() ), 404, - ) + )) } diff --git a/src/models/render.rs b/src/models/render.rs index ae83f83..52e2e37 100644 --- a/src/models/render.rs +++ b/src/models/render.rs @@ -1,5 +1,7 @@ //! # Oblivion Render +use futures::future::BoxFuture; use serde_json::Value; +use anyhow::Result; use crate::exceptions::OblivionException; @@ -10,6 +12,8 @@ pub enum BaseResponse { JsonResponse(Value, i32), } +pub type Response = BoxFuture<'static, Result>; + pub struct FileResponse {} pub struct TextResponse { @@ -20,7 +24,7 @@ pub struct TextResponse { impl TextResponse { pub fn new(text: &str, status_code: i32) -> Result { Ok(Self { - status_code: status_code, + status_code, text: text.to_string(), }) } diff --git a/src/models/router.rs b/src/models/router.rs index 5593fcf..c527646 100644 --- a/src/models/router.rs +++ b/src/models/router.rs @@ -1,18 +1,17 @@ //! # Oblivion Router use super::handler::not_found; -use super::render::BaseResponse; +use super::render::Response; use crate::utils::parser::OblivionRequest; -use futures::future::BoxFuture; use regex::Regex; use std::collections::HashMap; #[derive(Clone)] pub struct Route { - handler: fn(OblivionRequest) -> BoxFuture<'static, BaseResponse>, + handler: fn(OblivionRequest) -> Response, } impl Route { - pub fn new(handler: fn(OblivionRequest) -> BoxFuture<'static, BaseResponse>) -> Self { + pub fn new(handler: fn(OblivionRequest) -> Response) -> Self { Self { handler: handler } } @@ -22,7 +21,7 @@ impl Route { } } - pub fn get_handler(&mut self) -> fn(OblivionRequest) -> BoxFuture<'static, BaseResponse> { + pub fn get_handler(&mut self) -> fn(OblivionRequest) -> Response { self.handler.clone() } } @@ -75,7 +74,7 @@ impl Router { pub fn route( &mut self, path: RoutePath, - handler: fn(OblivionRequest) -> BoxFuture<'static, BaseResponse>, + handler: fn(OblivionRequest) -> Response, ) -> &mut Self { self.routes.insert(path.clone(), Route { handler: handler }); self diff --git a/src/models/server.rs b/src/models/server.rs index e8113d4..c368d16 100644 --- a/src/models/server.rs +++ b/src/models/server.rs @@ -1,10 +1,8 @@ //! # Oblivion Server use std::net::SocketAddr; -use crate::models::packet::{OED, OKE, OSC}; - use crate::exceptions::OblivionException; - +use crate::models::packet::{OED, OKE, OSC}; use crate::utils::gear::Socket; use crate::utils::generator::generate_key_pair; use crate::utils::parser::OblivionRequest; @@ -12,6 +10,7 @@ use crate::utils::parser::OblivionRequest; use p256::ecdh::EphemeralSecret; use p256::PublicKey; +use anyhow::{anyhow, Error, Result}; use serde_json::from_slice; use tokio::net::{TcpListener, TcpStream}; @@ -27,7 +26,7 @@ pub struct ServerConnection { } impl ServerConnection { - pub fn new() -> Result { + pub fn new() -> Result { let (private_key, public_key) = generate_key_pair()?; Ok(Self { @@ -41,7 +40,7 @@ impl ServerConnection { &mut self, stream: &mut Socket, peer: SocketAddr, - ) -> Result { + ) -> Result { let len_header = stream.recv_len().await?; let header = stream.recv_str(len_header).await?; let mut request = OblivionRequest::new(&header)?; @@ -57,20 +56,20 @@ impl ServerConnection { if request.method == "POST" { let mut oed = OED::new(self.aes_key.clone()); oed.from_stream(stream, 5).await?; - request.set_post(from_slice(&oed.get_data()).unwrap()); + request.set_post(from_slice(&oed.get_data())?); } else if request.method == "GET" { } else if request.method == "PUT" { let mut oed = OED::new(self.aes_key.clone()); oed.from_stream(stream, 5).await?; - request.set_post(from_slice(&oed.get_data()).unwrap()); + request.set_post(from_slice(&oed.get_data())?); let mut oed = OED::new(self.aes_key.clone()); oed.from_stream(stream, 5).await?; request.set_put(oed.get_data()); } else { - return Err(OblivionException::UnsupportedMethod { + return Err(Error::from(OblivionException::UnsupportedMethod { method: request.method, - }); + })); }; Ok(request) } @@ -79,7 +78,7 @@ impl ServerConnection { &mut self, stream: &mut Socket, peer: SocketAddr, - ) -> Result { + ) -> Result { self.handshake(stream, peer).await } } @@ -92,9 +91,9 @@ pub async fn response( stream: &mut Socket, request: OblivionRequest, aes_key: Vec, -) -> Result { +) -> Result { let handler = route.get_handler(); - let mut callback = handler(request).await; + let mut callback = handler(request).await?; let mut oed = OED::new(Some(aes_key)); oed.from_bytes(callback.as_bytes()?)?; @@ -109,18 +108,18 @@ async fn _handle( router: &mut Router, stream: &mut Socket, peer: SocketAddr, -) -> Result<(OblivionRequest, i32), OblivionException> { +) -> Result<(OblivionRequest, i32)> { stream.set_ttl(20); let mut connection = ServerConnection::new()?; let mut request = match connection.solve(stream, peer).await { Ok(request) => request, Err(_) => { - return Err(OblivionException::ServerError { + return Err(anyhow!(OblivionException::ServerError { method: "CONNECT".to_string(), ipaddr: peer.ip().to_string(), olps: "-".to_string(), status_code: 500, - }) + })) } }; @@ -135,12 +134,12 @@ async fn _handle( { Ok(status_code) => status_code, Err(_) => { - return Err(OblivionException::ServerError { + return Err(anyhow!(OblivionException::ServerError { method: request.get_method(), ipaddr: request.get_ip(), olps: request.get_olps(), status_code: 501, - }); + })); } };