Skip to content

Commit

Permalink
feat(socket): warp socket into a arc and a mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
fu050409 committed Apr 21, 2024
1 parent 4591a86 commit e85afcf
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 66 deletions.
46 changes: 29 additions & 17 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::env::args;
use std::time::Instant;

#[async_route]
fn handler(_sess: &mut Session) -> Response {
fn handler(_sess: Session) -> Response {
Ok(BaseResponse::TextResponse(
"每一个人都应该拥有守护信息与获得真实信息的神圣权利, 任何与之对抗的都是我们的敌人"
.to_string(),
Expand All @@ -20,15 +20,18 @@ fn handler(_sess: &mut Session) -> Response {
}

#[async_route]
fn welcome(_sess: &mut Session) -> Response {
fn welcome(mut sess: Session) -> Response {
Ok(BaseResponse::TextResponse(
format!("欢迎进入信息绝对安全区, 来自[{}]的朋友", 1),
format!(
"欢迎进入信息绝对安全区, 来自[{}]的朋友",
sess.request.as_mut().unwrap().get_ip()
),
200,
))
}

#[async_route]
fn json(_sess: &mut Session) -> Response {
fn json(_sess: Session) -> Response {
Ok(BaseResponse::JsonResponse(
json!({"status": true, "msg": "只身堕入极暗之永夜, 以期再世涅槃之阳光"}),
200,
Expand All @@ -45,24 +48,33 @@ async fn alive(mut _sess: Session) -> Response {

#[tokio::main]
async fn main() -> Result<()> {
let args: Vec<String> = args().collect();
let is_server = if args.len() == 1 { true } else { false };
if !is_server {
loop {
let mut args: Vec<String> = args().collect();
if args.len() <= 1 {
args.push("serve".to_string());
}
match args[1].as_str() {
"bench" => loop {
let now = Instant::now();
get("127.0.0.1:7076/path").await?;
let mut res = get("127.0.0.1:7076/welcome").await?;
println!("{}", res.text()?);
println!("执行时间: {}", now.elapsed().as_millis());
}
} else {
let mut router = Router::new();
},
"socket" => todo!(),
"serve" => {
let mut router = Router::new();

router.route(RoutePath::new("/handler", RouteType::Path), handler);
router.route(RoutePath::new("/handler", RouteType::Path), handler);

path_route!(&mut router, "/welcome" => welcome);
path_route!(&mut router, "/json" => json);
path_route!(&mut router, "/welcome" => welcome);
path_route!(&mut router, "/json" => json);

let mut server = Server::new("0.0.0.0", 7076, router);
server.run().await?;
let mut server = Server::new("0.0.0.0", 7076, router);
server.run().await?;
}
_ => {
print!("未知的指令: {}", args[1]);
}
}

Ok(())
}
20 changes: 10 additions & 10 deletions src/models/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl Client {

pub async fn send(&mut self, bytes: Vec<u8>) -> Result<()> {
let session = self.session.as_mut().unwrap();
let tcp = &mut session.socket;
let tcp = &mut session.socket.lock().await;
let mut oed = OED::new(session.aes_key.clone());
oed.from_bytes(bytes)?;
oed.to_stream(tcp, 5).await?;
Expand All @@ -158,23 +158,23 @@ impl Client {

pub async fn recv(&mut self) -> Result<Response> {
let session = self.session.as_mut().unwrap();
let tcp = &mut session.socket;
let tcp = &mut session.socket.lock().await;

let flag = OSC::from_stream(tcp).await?;
let flag = OSC::from_stream(tcp).await?.status_code;

let mut oed = OED::new(session.aes_key.clone());
oed.from_stream(tcp, 5).await?;

let osc = OSC::from_stream(tcp).await?;
let content = OED::new(session.aes_key.clone())
.from_stream(tcp, 5)
.await?
.get_data();

let response = Response::new(
self.plain_text.clone(),
oed.get_data(),
content,
self.olps.clone(),
osc.status_code,
OSC::from_stream(tcp).await?.status_code,
);

if flag.status_code == 1 {
if flag == 1 {
tcp.close().await?;
}
Ok(response)
Expand Down
25 changes: 13 additions & 12 deletions src/models/handler.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
//! # Oblivion Default Handler
use super::render::{BaseResponse, Response};
use super::session::Session;
use futures::FutureExt;
use super::{
render::{BaseResponse, Response},
session::Session,
};
use oblivion_codegen::async_route;

/// Not Found Handler
///
/// Handling a non-existent route request.
pub fn not_found(session: &mut Session) -> Response {
let olps = session.request.as_mut().unwrap().get_ip();
async move {
Ok(BaseResponse::TextResponse(
format!("Path {} is not found, error with code 404.", olps),
404,
))
}
.boxed()
#[async_route]
pub fn not_found(mut sess: Session) -> Response {
let olps = sess.request.as_mut().unwrap().get_ip();

Ok(BaseResponse::TextResponse(
format!("Path {} is not found, error with code 404.", olps),
404,
))
}
8 changes: 1 addition & 7 deletions src/models/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use anyhow::Result;
use regex::Regex;
use std::collections::HashMap;

pub type Handler = fn(&mut Session) -> Response;
pub type Handler = fn(Session) -> Response;

#[derive(Clone)]
pub struct Route {
Expand All @@ -18,12 +18,6 @@ impl Route {
Self { handler }
}

pub fn clone(&mut self) -> Self {
Self {
handler: self.handler.clone(),
}
}

pub fn get_handler(&mut self) -> Handler {
self.handler.clone()
}
Expand Down
27 changes: 13 additions & 14 deletions src/models/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use anyhow::{Error, Result};
use chrono::Local;
use colored::Colorize;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;

use super::packet::{OED, OSC};
use super::router::Router;
Expand All @@ -17,11 +16,9 @@ use super::session::Session;
async fn _handle(router: &mut Router, mut socket: Socket, peer: SocketAddr) -> Result<String> {
socket.set_ttl(20)?;

let session = Arc::new(Mutex::new(Session::new(socket)?));
let mut session = Session::new(socket)?;

let mut brd_sess = session.lock().await;

if let Err(error) = brd_sess.handshake(1).await {
if let Err(error) = session.handshake(1).await {
eprintln!(
"{} -> [{}] \"{}\" {}",
peer.ip().to_string().cyan(),
Expand All @@ -32,23 +29,25 @@ async fn _handle(router: &mut Router, mut socket: Socket, peer: SocketAddr) -> R
return Err(Error::from(error));
}

let header = brd_sess.header.as_ref().unwrap().clone();
let ip_addr = brd_sess.request.as_mut().unwrap().get_ip();
let aes_key = brd_sess.aes_key.clone().unwrap();
let header = session.header.as_ref().unwrap().clone();
let ip_addr = session.request.as_mut().unwrap().get_ip();
let aes_key = session.aes_key.clone().unwrap();

let arc_socket = Arc::clone(&session.socket);
let mut socket = arc_socket.lock().await;

let mut route = router.get_handler(&brd_sess.request.as_ref().unwrap().olps)?;
let handler = route.get_handler();
let mut callback = handler(&mut brd_sess).await?;
let mut route = router.get_handler(&session.request.as_ref().unwrap().olps)?;
let mut callback = route.get_handler()(session).await?;

let status_code = callback.get_status_code()?;

OSC::from_u32(1).to_stream(&mut brd_sess.socket).await?;
OSC::from_u32(1).to_stream(&mut socket).await?;
OED::new(Some(aes_key))
.from_bytes(callback.as_bytes()?)?
.to_stream(&mut brd_sess.socket, 5)
.to_stream(&mut socket, 5)
.await?;
OSC::from_u32(callback.get_status_code()?)
.to_stream(&mut brd_sess.socket)
.to_stream(&mut socket)
.await?;

let display = format!(
Expand Down
15 changes: 9 additions & 6 deletions src/models/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::sync::Arc;

use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use p256::{ecdh::EphemeralSecret, PublicKey};
use tokio::sync::Mutex;

use crate::utils::gear::Socket;
use crate::utils::generator::generate_key_pair;
Expand All @@ -15,7 +18,7 @@ pub struct Session {
pub(crate) aes_key: Option<Vec<u8>>,
pub request_time: DateTime<Local>,
pub request: Option<OblivionRequest>,
pub socket: Socket,
pub socket: Arc<Mutex<Socket>>,
}

impl Session {
Expand All @@ -28,7 +31,7 @@ impl Session {
aes_key: None,
request_time: Local::now(),
request: None,
socket,
socket: Arc::new(Mutex::new(socket)),
})
}

Expand All @@ -41,12 +44,12 @@ impl Session {
aes_key: None,
request_time: Local::now(),
request: None,
socket,
socket: Arc::new(Mutex::new(socket)),
})
}

pub async fn first_hand(&mut self) -> Result<()> {
let socket = &mut self.socket;
let socket = &mut self.socket.lock().await;
let header = self.header.as_ref().unwrap().as_bytes();
socket
.send(&[&length(&header.to_vec())?, header].concat())
Expand All @@ -60,7 +63,7 @@ impl Session {
}

pub async fn second_hand(&mut self) -> Result<()> {
let socket = &mut self.socket;
let socket = &mut self.socket.lock().await;
let peer = socket.peer_addr()?;
let len_header = socket.recv_usize().await?;
let header = socket.recv_str(len_header).await?;
Expand Down Expand Up @@ -90,10 +93,10 @@ impl Session {

pub async fn send(
&mut self,
socket: &mut Socket,
data: Vec<u8>,
status_code: u32,
) -> Result<()> {
let socket = &mut self.socket.lock().await;
OSC::from_u32(0).to_stream(socket).await?;
OED::new(Some(self.aes_key.clone().unwrap()))
.from_bytes(data)?
Expand Down

0 comments on commit e85afcf

Please sign in to comment.