Skip to content

Commit

Permalink
fix: refine
Browse files Browse the repository at this point in the history
  • Loading branch information
zk-steve committed Apr 1, 2024
1 parent 6ac9c62 commit 4554da0
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 78 deletions.
47 changes: 10 additions & 37 deletions src/adapter/src/repositories/grpc/gpt_answer_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use anyhow::Error;
use tonic::{
async_trait,
transport::{Channel, Endpoint},
Expand All @@ -15,27 +14,10 @@ use rust_core::{common::errors::CoreError, ports::gpt_answer::GptAnswerPort};
/// methods for connecting to the service, sending a question, and receiving an answer.
#[derive(Clone)]
pub struct GptAnswerClient {
client: Option<GptAnswerServiceClient<Channel>>,
endpoint: Endpoint,
}

impl GptAnswerClient {
/// Creates a new `GptAnswerClient` instance with the provided gRPC endpoint.
///
/// # Arguments
///
/// * `endpoint`: An `Endpoint` representing the gRPC communication endpoint.
///
/// # Returns
///
/// Returns a new instance of `GptAnswerClient`.
fn new(endpoint: Endpoint) -> Self {
Self {
client: None,
endpoint,
}
}

/// Initializes a new `GptAnswerClient` instance with the provided URI.
///
/// # Arguments
Expand All @@ -46,12 +28,10 @@ impl GptAnswerClient {
///
/// Returns a `Result` containing the initialized `GptAnswerClient` if successful,
/// or a `CoreError` if an error occurs during initialization.
pub async fn init(uri: String) -> Result<Self, CoreError> {
// Establish connection to the gRPC server
let endpoint =
Channel::from_shared(uri).map_err(|err| CoreError::InternalError(err.into()))?;

Ok(Self::new(endpoint))
pub fn new(uri: String) -> Result<Self, CoreError> {
Channel::from_shared(uri)
.map_err(|err| CoreError::InternalError(err.into()))
.map(|endpoint| Self { endpoint })
}

/// Establishes a connection to the GPT answer service at the specified URI.
Expand All @@ -60,13 +40,10 @@ impl GptAnswerClient {
///
/// Returns a `Result` containing the connected `GptAnswerServiceClient` if successful,
/// or a `CoreError` if an error occurs during connection.
pub async fn connect(&mut self) -> Result<(), CoreError> {
let client = GptAnswerServiceClient::connect(self.endpoint.clone())
pub async fn connect(&self) -> Result<GptAnswerServiceClient<Channel>, CoreError> {
GptAnswerServiceClient::connect(self.endpoint.clone())
.await
.map_err(|err| CoreError::InternalError(err.into()))?;

self.client = Some(client);
Ok(())
.map_err(|err| CoreError::InternalError(err.into()))
}
}

Expand All @@ -83,17 +60,13 @@ impl GptAnswerPort for GptAnswerClient {
/// Returns a `Result` containing the generated answer as a `String` if successful,
/// or a `CoreError` if an error occurs during communication with the service.
async fn get_answer(&self, question: &str) -> Result<String, CoreError> {
let client = self
.client
.as_ref()
.ok_or_else(|| CoreError::InternalError(Error::msg("Client not initialized")))?;

let request = tonic::Request::new(GetAnswerPayload {
question: question.to_string(),
});

let response = client
.clone()
let response = self
.connect()
.await?
.get_answer(request)
.await
.map_err(|err| CoreError::InternalError(err.into()))?;
Expand Down
14 changes: 11 additions & 3 deletions src/gpt_answer_server/src/controllers/gpt_answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@ use common::grpc::gpt_answer::gpt_answer::{
/// the `GptAnswerService` trait generated by Tonic, which defines the RPC methods for answering
/// questions.
#[derive(Debug, Default)]
pub struct GptAnswerServer;
pub struct GptAnswerServiceImpl {
pub dummy_prop: String,
}

impl GptAnswerServiceImpl {
pub fn new(dummy_prop: String) -> Self {
Self { dummy_prop }
}
}

#[tonic::async_trait]
impl GptAnswerService for GptAnswerServer {
impl GptAnswerService for GptAnswerServiceImpl {
/// Handle the gRPC `get_answer` request.
///
/// This method is called when a gRPC client sends a request to get an answer to a question.
Expand All @@ -38,7 +46,7 @@ impl GptAnswerService for GptAnswerServer {

// TODO: Implement your logic to generate an answer based on the question.
// Placeholder logic: Generate an answer string
let answer = format!("Answer to: {}", payload.question);
let answer = format!("Answer to: {}, {}", payload.question, self.dummy_prop);

// Construct a response containing the generated answer
let response = GetAnswerResponse { answer };
Expand Down
17 changes: 8 additions & 9 deletions src/gpt_answer_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@ use tonic::transport::Server;
use common::grpc::gpt_answer::gpt_answer::gpt_answer_service_server::GptAnswerServiceServer;
use common::loggers::telemetry::init_telemetry;
use common::options::parse_options;
use gpt_answer_server::controllers::gpt_answer::GptAnswerServer;
use gpt_answer_server::controllers::gpt_answer::GptAnswerServiceImpl;
use gpt_answer_server::options::Options;

pub async fn init_grpc_server(options: Options) {
let server_endpoint = &options.server_endpoint;
let address = server_endpoint.parse().unwrap();
println!("Starting GPT Answer server at {}", server_endpoint);
pub async fn serve(options: Options) {
let address = options.server_endpoint.parse().unwrap();
println!("Starting GPT Answer server at {}", options.server_endpoint);

let gpt_answer_server = GptAnswerServer::default();
let gpt_answer_service = GptAnswerServiceImpl::new("dummy_prop".to_string());
Server::builder()
.add_service(GptAnswerServiceServer::new(gpt_answer_server))
.add_service(GptAnswerServiceServer::new(gpt_answer_service))
.serve(address)
.await
.unwrap();
Expand Down Expand Up @@ -48,9 +47,9 @@ async fn main() {
options.log.level.as_str(),
);

let gpt_answer_server = tokio::spawn(init_grpc_server(options));
let server = tokio::spawn(serve(options));

tokio::try_join!(gpt_answer_server).expect("Failed to run servers");
tokio::try_join!(server).expect("Failed to run servers");

global::shutdown_tracer_provider();
}
Expand Down
4 changes: 1 addition & 3 deletions src/public/src/controllers/question.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub async fn update_question(
#[instrument(level = "info", skip(question_port, gpt_answer_client))]
pub async fn get_question_answer(
question_port: Arc<dyn QuestionPort + Send + Sync>,
mut gpt_answer_client: GptAnswerClient,
gpt_answer_client: Arc<GptAnswerClient>,
id: String,
) -> Result<impl Reply, Rejection> {
let question_id = QuestionId::from_str(&id).map_err(WarpError::from)?;
Expand All @@ -142,8 +142,6 @@ pub async fn get_question_answer(
.await
.map_err(WarpError::from)?;

gpt_answer_client.connect().await.map_err(WarpError::from)?;

let answer = gpt_answer_client
.get_answer(&question.content)
.await
Expand Down
26 changes: 15 additions & 11 deletions src/public/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::sync::Arc;

#[cfg_attr(debug_assertions, allow(dead_code, unused_imports))]
use openssl;
#[rustfmt::skip]
#[cfg_attr(debug_assertions, allow(dead_code, unused_imports))]
use diesel;

use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::sync::Arc;

use clap::{Parser, Subcommand};
use deadpool_diesel::postgres::Pool;
use deadpool_diesel::{Manager, Runtime};
use opentelemetry::global;
use tracing::info;

use adapter::repositories::grpc::gpt_answer_client::GptAnswerClient;
use adapter::repositories::in_memory::question::QuestionInMemoryRepository;
use adapter::repositories::postgres::question_db::QuestionDBRepository;
use cli::options::Options;
Expand Down Expand Up @@ -48,8 +50,8 @@ async fn main() {
options.log.level.as_str(),
);

let warp_server = tokio::spawn(run_warp_server(options));
tokio::try_join!(warp_server).expect("Failed to run servers");
let server = tokio::spawn(serve(options));
tokio::try_join!(server).expect("Failed to run servers");

global::shutdown_tracer_provider();
}
Expand All @@ -74,7 +76,7 @@ enum Commands {
Config,
}

pub async fn run_warp_server(options: Options) {
pub async fn serve(options: Options) {
let question_port: Arc<dyn QuestionPort + Send + Sync> = if options.db.in_memory.is_some() {
info!("Using in-memory database");
Arc::new(QuestionInMemoryRepository::new())
Expand All @@ -92,13 +94,15 @@ pub async fn run_warp_server(options: Options) {
Arc::new(QuestionInMemoryRepository::new())
};

let grpc_client = options.gpt_answer_service_url.clone();
let router = Router::new(question_port, grpc_client.into());
let routers = router.routes().await;
let gpt_answer_client =
Arc::new(GptAnswerClient::new(options.gpt_answer_service_url.to_string()).unwrap());

let router = Router::new(question_port, gpt_answer_client);
let routes = router.routes();
let address = SocketAddrV4::new(
Ipv4Addr::from_str(options.server.url.as_str()).unwrap(),
options.server.port,
);

warp::serve(routers).run(address).await
warp::serve(routes).run(address).await
}
19 changes: 6 additions & 13 deletions src/public/src/router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::sync::Arc;

use warp::http::Method;
use warp::{Filter, Rejection, Reply};

Expand All @@ -14,33 +15,25 @@ use crate::errors::return_error;
/// Router for handling HTTP requests related to questions.
pub struct Router {
question_port: Arc<dyn QuestionPort + Send + Sync + 'static>,
gpt_answer_service_url: Arc<String>,
gpt_answer_client: Arc<GptAnswerClient>,
}

impl Router {
/// Creates a new Router instance with the specified QuestionPort.
pub fn new(
question_port: Arc<dyn QuestionPort + Send + Sync + 'static>,
gpt_answer_service_url: Arc<String>,
gpt_answer_client: Arc<GptAnswerClient>,
) -> Self {
Router {
question_port: question_port.clone(),
gpt_answer_service_url: gpt_answer_service_url.clone(),
gpt_answer_client,
}
}

/// Configures and returns the Warp filter for handling HTTP requests.
pub async fn routes(self) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
pub fn routes(self) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
let store_filter = warp::any().map(move || self.question_port.clone());

let gpt_answer_client = GptAnswerClient::init(self.gpt_answer_service_url.to_string())
.await
.map_err(|err| {
tracing::error!("Failed to init GPT answer service: {:?}", err);
err
})
.unwrap();

let cors = warp::cors()
.allow_any_origin()
.allow_header("content-type")
Expand Down Expand Up @@ -85,7 +78,7 @@ impl Router {
let get_question_answer = warp::get()
.and(warp::path("questions"))
.and(store_filter.clone())
.and(warp::any().map(move || gpt_answer_client.clone()))
.and(warp::any().map(move || self.gpt_answer_client.clone()))
.and(warp::path::param::<String>())
.and(warp::path("answer"))
.and_then(get_question_answer);
Expand Down
9 changes: 7 additions & 2 deletions src/public/tests/questions_router_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod tests {
use warp::test::request;

use adapter::repositories::{
grpc::gpt_answer_client::GptAnswerClient,
in_memory::question::QuestionInMemoryRepository,
postgres::question_db::{QuestionDBRepository, MIGRATIONS},
};
Expand All @@ -32,8 +33,12 @@ mod tests {
{
let gpt_answer_service_url = "grpc://0.0.0.0:50051".to_string();

let router = Router::new(question_port, gpt_answer_service_url.into());
let routers = router.routes().await;
let gpt_answer_client: Arc<
adapter::repositories::grpc::gpt_answer_client::GptAnswerClient,
> = Arc::new(GptAnswerClient::new(gpt_answer_service_url.to_string()).unwrap());

let router = Router::new(question_port, gpt_answer_client);
let routers = router.routes();

let raw_question_id: String = rand::thread_rng().gen_range(1..=1000).to_string();
let question_id = QuestionId::from_str(&raw_question_id.clone()).unwrap();
Expand Down

0 comments on commit 4554da0

Please sign in to comment.