Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add client max connection #136

Open
wants to merge 17 commits into
base: unstable
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions src/kiwi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ bool KiwiDB::Init() {
auto num = g_config.worker_threads_num + g_config.slave_threads_num;
options_.SetThreadNum(num);

options_.SetMaxClients(g_config.max_clients);

// now we only use fast cmd thread pool
auto status = cmd_threads_.Init(g_config.fast_cmd_threads_num, 1, "kiwi-cmd");
if (!status.ok()) {
Expand Down
1 change: 1 addition & 0 deletions src/net/event_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class EventServer final {
std::condition_variable cv_;

std::shared_ptr<Timer> timer_;

};

template <typename T>
Expand Down
6 changes: 6 additions & 0 deletions src/net/net_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ class NetOptions {

bool GetRwSeparation() const { return rwSeparation_; }

void SetMaxClients(uint32_t maxClients) { maxClients_ = maxClients; }

uint32_t GetMaxClients() const { return maxClients_; }

private:
bool rwSeparation_ = true; // Whether to separate read and write

int8_t threadNum_ = 1; // The number of threads

uint32_t maxClients_ = 1; // The maximum number of connections(default 40000)
};

} // namespace net
35 changes: 29 additions & 6 deletions src/net/thread_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,25 @@ class ThreadManager {

uint64_t DoTCPConnect(T &t, int fd, const std::shared_ptr<Connection> &conn);

uint32_t getClientCount() const { return clientCount_.load(); }

void clientCountIncrement() { clientCount_.fetch_add(1, std::memory_order_relaxed); }

void clientCountDecrement() { clientCount_.fetch_sub(1, std::memory_order_relaxed); }

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Maintain consistent naming convention and review memory ordering.

  1. Method names should follow the class's snake_case convention:

    • getClientCountget_client_count
    • clientCountIncrementclient_count_increment
    • clientCountDecrementclient_count_decrement
  2. Consider using std::memory_order_seq_cst instead of memory_order_relaxed for the atomic operations since the client count is used for critical decision-making in OnNetEventCreate.

-  uint32_t getClientCount() const { return clientCount_.load(); }
+  uint32_t get_client_count() const { return client_count_.load(); }

-  void clientCountIncrement() { clientCount_.fetch_add(1, std::memory_order_relaxed); }
+  void client_count_increment() { client_count_.fetch_add(1, std::memory_order_seq_cst); }

-  void clientCountDecrement() { clientCount_.fetch_sub(1, std::memory_order_relaxed); }
+  void client_count_decrement() { client_count_.fetch_sub(1, std::memory_order_seq_cst); }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
uint32_t getClientCount() const { return clientCount_.load(); }
void clientCountIncrement() { clientCount_.fetch_add(1, std::memory_order_relaxed); }
void clientCountDecrement() { clientCount_.fetch_sub(1, std::memory_order_relaxed); }
uint32_t get_client_count() const { return client_count_.load(); }
void client_count_increment() { client_count_.fetch_add(1, std::memory_order_seq_cst); }
void client_count_decrement() { client_count_.fetch_sub(1, std::memory_order_seq_cst); }
🧰 Tools
🪛 GitHub Actions: kiwi

[error] File has clang-format style issues

private:
const int8_t index_ = 0; // The index of the thread
std::atomic<bool> running_ = true; // Whether the thread is running

NetOptions netOptions_;

inline static std::atomic<uint32_t> clientCount_{0};

std::unique_ptr<IOThread> readThread_; // Read thread
std::unique_ptr<IOThread> writeThread_; // Write thread

// All connections for the current thread
std::unordered_map<uint64_t, std::pair<T, std::shared_ptr<Connection>>> connections_;
std::unordered_map<uint64_t, std::pair<T, std::shared_ptr<Connection>>>
connections_; // All connections for the current thread

std::shared_mutex mutex_;

Expand All @@ -116,7 +124,10 @@ class ThreadManager {
};

template <typename T>
requires HasSetFdFunction<T> ThreadManager<T>::~ThreadManager() { Stop(); }
requires HasSetFdFunction<T>
ThreadManager<T>::~ThreadManager() {
Stop();
}

template <typename T>
requires HasSetFdFunction<T>
Expand Down Expand Up @@ -145,6 +156,14 @@ void ThreadManager<T>::Stop() {
template <typename T>
requires HasSetFdFunction<T>
void ThreadManager<T>::OnNetEventCreate(int fd, const std::shared_ptr<Connection> &conn) {
if (getClientCount() >= netOptions_.GetMaxClients()) {
INFO("Max client connentions, refuse new connection fd:{%d}", fd);
std::string response = "-ERR max clients reached\r\n";
::send(fd, response.c_str(), response.size(), 0);
::close(fd);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加日志

return;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

返回 redis 格式的一个 response

}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve error handling and logging consistency.

  1. The log message format is inconsistent. Remove the curly braces around %d.
  2. Add error logging for connection closure.
  3. Handle potential errors from send/close system calls.
-    INFO("Max client connentions, refuse new connection fd:{%d}", fd);
+    INFO("Max client connections, refuse new connection fd: %d", fd);
     std::string response = "-ERR max clients reached\r\n";
-    ::send(fd, response.c_str(), response.size(), 0);
-    ::close(fd);
+    ssize_t sent = ::send(fd, response.c_str(), response.size(), 0);
+    if (sent < 0) {
+        ERROR("Failed to send error response to fd: %d, errno: %d", fd, errno);
+    }
+    if (::close(fd) < 0) {
+        ERROR("Failed to close fd: %d, errno: %d", fd, errno);
+    }

Also, fix the typo in "connentions" → "connections".

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (getClientCount() >= netOptions_.GetMaxClients()) {
INFO("Max client connentions, refuse new connection fd:{%d}", fd);
std::string response = "-ERR max clients reached\r\n";
::send(fd, response.c_str(), response.size(), 0);
::close(fd);
return;
}
if (getClientCount() >= netOptions_.GetMaxClients()) {
INFO("Max client connections, refuse new connection fd: %d", fd);
std::string response = "-ERR max clients reached\r\n";
ssize_t sent = ::send(fd, response.c_str(), response.size(), 0);
if (sent < 0) {
ERROR("Failed to send error response to fd: %d, errno: %d", fd, errno);
}
if (::close(fd) < 0) {
ERROR("Failed to close fd: %d, errno: %d", fd, errno);
}
return;
}
🧰 Tools
🪛 GitHub Actions: kiwi

[error] File has clang-format style issues


T t;
onInit_(&t);
auto connId = getConnId();
Expand All @@ -166,6 +185,7 @@ void ThreadManager<T>::OnNetEventCreate(int fd, const std::shared_ptr<Connection
}

onCreate_(connId, t, conn->addr_);
clientCountIncrement();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix potential race condition in client counting.

There's a race condition between checking the client count and incrementing it. Another thread could add a client between the check and increment, potentially exceeding the max limit.

 template <typename T>
 requires HasSetFdFunction<T>
 void ThreadManager<T>::OnNetEventCreate(int fd, const std::shared_ptr<Connection> &conn) {
-  if (getClientCount() >= netOptions_.GetMaxClients()) {
+  if (!client_count_.compare_exchange_strong(
+          expected,
+          expected + 1,
+          std::memory_order_seq_cst,
+          std::memory_order_seq_cst) ||
+      expected >= netOptions_.GetMaxClients()) {
     INFO("Max client connections, refuse new connection fd: %d", fd);
     std::string response = "-ERR max clients reached\r\n";
     ::send(fd, response.c_str(), response.size(), 0);
     ::close(fd);
     return;
   }
   
   // ... rest of the connection setup ...
-  clientCountIncrement();  // Remove this as we've already incremented atomically
 }

Also applies to: 225-225

🧰 Tools
🪛 GitHub Actions: kiwi

[error] File has clang-format style issues

}

template <typename T>
Expand Down Expand Up @@ -202,11 +222,14 @@ void ThreadManager<T>::OnNetEventClose(uint64_t connId, std::string &&err) {
iter->second.second->netEvent_->Close(); // close socket
onClose_(iter->second.first, std::move(err));
connections_.erase(iter);
clientCountDecrement();
}

template <typename T>
requires HasSetFdFunction<T>
void ThreadManager<T>::CloseConnection(uint64_t connId) { OnNetEventClose(connId, ""); }
void ThreadManager<T>::CloseConnection(uint64_t connId) {
OnNetEventClose(connId, "");
}

template <typename T>
requires HasSetFdFunction<T>
Expand Down Expand Up @@ -330,8 +353,8 @@ bool ThreadManager<T>::CreateWriteThread() {
}

template <typename T>
requires HasSetFdFunction<T> uint64_t ThreadManager<T>::DoTCPConnect(T &t, int fd,
const std::shared_ptr<Connection> &conn) {
requires HasSetFdFunction<T>
uint64_t ThreadManager<T>::DoTCPConnect(T &t, int fd, const std::shared_ptr<Connection> &conn) {
auto connId = getConnId();
if constexpr (IsPointer_v<T>) {
t->SetConnId(connId);
Expand Down
Loading