Skip to content

Commit

Permalink
[coro_rpc][fix] keep protocol compatible/use context<T>::tag() instea…
Browse files Browse the repository at this point in the history
…d o… (#555)
  • Loading branch information
poor-circle authored Jan 3, 2024
1 parent 73a0016 commit 3d8252f
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 130 deletions.
15 changes: 5 additions & 10 deletions include/ylt/coro_rpc/impl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ class context_base {

using return_type = return_msg_type;

void response_error(coro_rpc::errc error_code, std::string_view error_msg) {
void response_error(coro_rpc::err_code error_code,
std::string_view error_msg) {
if (!check_status())
AS_UNLIKELY { return; };
self_->conn_->template response_error<rpc_protocol>(
error_code, error_msg, self_->req_head_, self_->is_delay_);
}

void response_error(coro_rpc::errc error_code) {
void response_error(coro_rpc::err_code error_code) {
response_error(error_code, make_error_message(error_code));
}
/*!
Expand Down Expand Up @@ -190,13 +190,8 @@ class context_base {
self_->conn_->set_rpc_call_type(
coro_connection::rpc_call_type::callback_with_delay);
}

template <typename T>
void set_tag(T &&tag) {
self_->conn_->set_tag(std::forward<T>(tag));
}

std::any get_tag() { return self_->conn_->get_tag(); }
std::any &tag() { return self_->conn_->tag(); }
const std::any &tag() const { return self_->conn_->tag(); }
};

template <typename T>
Expand Down
12 changes: 4 additions & 8 deletions include/ylt/coro_rpc/impl/coro_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
return {};
};
std::string body_buf;
std::string header_buf =
rpc_protocol::prepare_response(body_buf, req_head, 0, ec, error_msg);
std::string header_buf = rpc_protocol::prepare_response(
body_buf, req_head, 0, ec, error_msg, true);
response(std::move(header_buf), std::move(body_buf), std::move(attach_ment),
shared_from_this(), is_delay)
.via(executor_)
Expand Down Expand Up @@ -356,12 +356,8 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
conn_id_ = conn_id;
}

template <typename T>
void set_tag(T &&tag) {
tag_ = std::forward<T>(tag);
}

std::any get_tag() { return tag_; }
std::any &tag() { return tag_; }
const std::any &tag() const { return tag_; }

auto &get_executor() { return *executor_; }

Expand Down
64 changes: 38 additions & 26 deletions include/ylt/coro_rpc/impl/coro_rpc_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class coro_rpc_client {
* @param timeout_duration RPC call timeout
* @return error code
*/
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::errc> reconnect(
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::err_code> reconnect(
std::string host, std::string port,
std::chrono::steady_clock::duration timeout_duration =
std::chrono::seconds(5)) {
Expand All @@ -191,7 +191,7 @@ class coro_rpc_client {
return connect(is_reconnect_t{true});
}

[[nodiscard]] async_simple::coro::Lazy<coro_rpc::errc> reconnect(
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::err_code> reconnect(
std::string endpoint,
std::chrono::steady_clock::duration timeout_duration =
std::chrono::seconds(5)) {
Expand All @@ -214,7 +214,7 @@ class coro_rpc_client {
* @param timeout_duration RPC call timeout
* @return error code
*/
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::errc> connect(
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::err_code> connect(
std::string host, std::string port,
std::chrono::steady_clock::duration timeout_duration =
std::chrono::seconds(5)) {
Expand All @@ -224,7 +224,7 @@ class coro_rpc_client {
std::chrono::duration_cast<std::chrono::milliseconds>(timeout_duration);
return connect();
}
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::errc> connect(
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::err_code> connect(
std::string_view endpoint,
std::chrono::steady_clock::duration timeout_duration =
std::chrono::seconds(5)) {
Expand Down Expand Up @@ -389,8 +389,8 @@ class coro_rpc_client {
is_timeout_ = false;
has_closed_ = false;
}
static bool is_ok(coro_rpc::errc ec) noexcept { return !ec; }
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::errc> connect(
static bool is_ok(coro_rpc::err_code ec) noexcept { return !ec; }
[[nodiscard]] async_simple::coro::Lazy<coro_rpc::err_code> connect(
is_reconnect_t is_reconnect = is_reconnect_t{false}) {
#ifdef YLT_ENABLE_SSL
if (!ssl_init_ret_) {
Expand Down Expand Up @@ -447,7 +447,7 @@ class coro_rpc_client {
}
#endif

co_return coro_rpc::errc{};
co_return coro_rpc::err_code{};
};
#ifdef YLT_ENABLE_SSL
[[nodiscard]] bool init_ssl_impl() {
Expand Down Expand Up @@ -664,9 +664,9 @@ class coro_rpc_client {
file << resp_attachment_buf_;
file.close();
#endif
r = handle_response_buffer<R>(read_buf_,
coro_rpc::errc{header.err_code});
if (!r) {
bool ec = false;
r = handle_response_buffer<R>(read_buf_, header.err_code, ec);
if (ec) {
close();
}
co_return r;
Expand Down Expand Up @@ -740,29 +740,41 @@ class coro_rpc_client {
}

template <typename T>
rpc_result<T, coro_rpc_protocol> handle_response_buffer(
std::string &buffer, coro_rpc::errc rpc_errc) {
rpc_result<T, coro_rpc_protocol> handle_response_buffer(std::string &buffer,
uint8_t rpc_errc,
bool &error_happen) {
rpc_return_type_t<T> ret;
struct_pack::errc ec;
coro_rpc_protocol::rpc_error err;
if (rpc_errc == coro_rpc::errc{}) {
ec = struct_pack::deserialize_to(ret, buffer);
if (ec == struct_pack::errc::ok) {
if constexpr (std::is_same_v<T, void>) {
return {};
}
else {
return std::move(ret);
if (rpc_errc == 0)
AS_LIKELY {
ec = struct_pack::deserialize_to(ret, buffer);
if (ec == struct_pack::errc::ok) {
if constexpr (std::is_same_v<T, void>) {
return {};
}
else {
return std::move(ret);
}
}
}
}
else {
err.code = rpc_errc;
ec = struct_pack::deserialize_to(err.msg, buffer);
if (ec == struct_pack::errc::ok) {
return rpc_result<T, coro_rpc_protocol>{unexpect_t{}, std::move(err)};
if (rpc_errc != UINT8_MAX) {
ec = struct_pack::deserialize_to(err.msg, buffer);
if (ec == struct_pack::errc::ok) {
error_happen = true;
return rpc_result<T, coro_rpc_protocol>{unexpect_t{}, std::move(err)};
}
}
else {
ec = struct_pack::deserialize_to(err, buffer);
if (ec == struct_pack::errc::ok) {
return rpc_result<T, coro_rpc_protocol>{unexpect_t{}, std::move(err)};
}
}
}
error_happen = true;
// deserialize failed.
err = {errc::invalid_argument, "failed to deserialize rpc return value"};
return rpc_result<T, coro_rpc_protocol>{unexpect_t{}, std::move(err)};
Expand Down Expand Up @@ -807,8 +819,8 @@ class coro_rpc_client {

#ifdef UNIT_TEST_INJECT
public:
coro_rpc::errc sync_connect(const std::string &host,
const std::string &port) {
coro_rpc::err_code sync_connect(const std::string &host,
const std::string &port) {
return async_simple::coro::syncAwait(connect(host, port));
}

Expand Down
20 changes: 10 additions & 10 deletions include/ylt/coro_rpc/impl/coro_rpc_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class coro_rpc_server_base {
*
* @return error code if start failed, otherwise block until server stop.
*/
[[nodiscard]] coro_rpc::errc start() noexcept {
[[nodiscard]] coro_rpc::err_code start() noexcept {
auto ret = async_start();
if (ret) {
ret.value().wait();
Expand All @@ -115,10 +115,10 @@ class coro_rpc_server_base {
}
}

[[nodiscard]] coro_rpc::expected<async_simple::Future<coro_rpc::errc>,
coro_rpc::errc>
[[nodiscard]] coro_rpc::expected<async_simple::Future<coro_rpc::err_code>,
coro_rpc::err_code>
async_start() noexcept {
coro_rpc::errc ec{};
coro_rpc::err_code ec{};
{
std::unique_lock lock(start_mtx_);
if (flag_ != stat::init) {
Expand All @@ -128,7 +128,7 @@ class coro_rpc_server_base {
else if (flag_ == stat::stop) {
ELOGV(INFO, "has stoped");
}
return coro_rpc::unexpected<coro_rpc::errc>{
return coro_rpc::unexpected<coro_rpc::err_code>{
coro_rpc::errc::server_has_ran};
}
ec = listen();
Expand All @@ -147,11 +147,11 @@ class coro_rpc_server_base {
}
}
if (!ec) {
async_simple::Promise<coro_rpc::errc> promise;
async_simple::Promise<coro_rpc::err_code> promise;
auto future = promise.getFuture();
accept().start([p = std::move(promise)](auto &&res) mutable {
if (res.hasError()) {
p.setValue(coro_rpc::errc::io_error);
p.setValue(coro_rpc::err_code{coro_rpc::errc::io_error});
}
else {
p.setValue(res.value());
Expand All @@ -160,7 +160,7 @@ class coro_rpc_server_base {
return std::move(future);
}
else {
return coro_rpc::unexpected<coro_rpc::errc>{ec};
return coro_rpc::unexpected<coro_rpc::err_code>{ec};
}
}

Expand Down Expand Up @@ -285,7 +285,7 @@ class coro_rpc_server_base {
auto &get_io_context_pool() noexcept { return pool_; }

private:
coro_rpc::errc listen() {
coro_rpc::err_code listen() {
ELOGV(INFO, "begin to listen");
using asio::ip::tcp;
auto endpoint = tcp::endpoint(tcp::v4(), port_);
Expand Down Expand Up @@ -319,7 +319,7 @@ class coro_rpc_server_base {
return {};
}

async_simple::coro::Lazy<coro_rpc::errc> accept() {
async_simple::coro::Lazy<coro_rpc::err_code> accept() {
for (;;) {
auto executor = pool_.get_executor();
asio::ip::tcp::socket socket(executor->get_asio_executor());
Expand Down
27 changes: 25 additions & 2 deletions include/ylt/coro_rpc/impl/errno.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/
#include <ylt/struct_pack/util.h>

#include <cstdint>
#pragma once
namespace coro_rpc {
enum class errc : uint16_t {
Expand All @@ -27,10 +29,31 @@ enum class errc : uint16_t {
interrupted,
function_not_registered,
protocol_error,
unknown_protocol_version,
message_too_large,
server_has_ran,
user_defined_err_min = 256,
user_defined_err_max = 65535
};
struct err_code {
public:
errc ec;
err_code() : ec(errc::ok) {}
explicit err_code(uint16_t ec) : ec{ec} {};
err_code(errc ec) : ec(ec){};
err_code& operator=(errc ec) {
this->ec = ec;
return *this;
}
err_code& operator=(uint16_t ec) {
this->ec = errc{ec};
return *this;
}
err_code(const err_code& err_code) = default;
err_code& operator=(const err_code& o) = default;
bool operator!() const { return ec == errc::ok; }
operator errc() const { return ec; }
operator bool() const { return static_cast<uint16_t>(ec); }
explicit operator uint16_t() const { return static_cast<uint16_t>(ec); }
uint16_t val() const { return static_cast<uint16_t>(ec); }
};
inline bool operator!(errc ec) { return ec == errc::ok; }
inline std::string_view make_error_message(errc ec) {
Expand Down
Loading

0 comments on commit 3d8252f

Please sign in to comment.