From 35158a4f4674c7a68abaa49a7113bc850fcdf058 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Fri, 27 Dec 2024 15:25:56 +0300 Subject: [PATCH] Implement read modes (default, stale, consistent) (#27) --- examples/kv.cpp | 16 +++++-- src/messages.h | 15 ++++-- src/raft.cpp | 123 ++++++++++++++++++++++++++++++++++-------------- src/raft.h | 12 ++++- src/server.cpp | 6 ++- 5 files changed, 126 insertions(+), 46 deletions(-) diff --git a/examples/kv.cpp b/examples/kv.cpp index 581ff7f..a14e8a6 100644 --- a/examples/kv.cpp +++ b/examples/kv.cpp @@ -74,7 +74,7 @@ TMessageHolder TKv::Prepare(TMessageHolder command) } template -NNet::TFuture Client(TPoller& poller, TSocket socket) { +NNet::TFuture Client(TPoller& poller, TSocket socket, uint32_t flags) { using TFileHandle = typename TPoller::TFileHandle; TFileHandle input{0, poller}; // stdin co_await socket.Connect(); @@ -112,11 +112,14 @@ NNet::TFuture Client(TPoller& poller, TSocket socket) { auto key = strtok(nullptr, sep); auto size = strlen(key); auto mes = NewHoldedMessage(sizeof(TReadKv) + size); + mes->Flags = flags; mes->KeySize = size; memcpy(mes->Data, key, size); req = mes; } else if (!strcmp(prefix, "list")) { - req = NewHoldedMessage(sizeof(TReadKv)); + auto mes = NewHoldedMessage(sizeof(TReadKv)); + mes->Flags = flags; + req = mes; } else if (!strcmp(prefix, "del")) { auto key = strtok(nullptr, sep); auto size = strlen(key); @@ -145,7 +148,7 @@ NNet::TFuture Client(TPoller& poller, TSocket socket) { } void usage(const char* prog) { - std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist]" << "\n"; + std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist] [--stale] [--consistent]" << "\n"; exit(0); } @@ -157,6 +160,7 @@ int main(int argc, char** argv) { uint32_t id = 0; bool server = false; bool persist = false; + uint32_t flags = 0; for (int i = 1; i < argc; i++) { if (!strcmp(argv[i], "--server")) { server = true; @@ -167,6 +171,10 @@ int main(int argc, char** argv) { id = atoi(argv[++i]); } else if (!strcmp(argv[i], "--persist")) { persist = true; + } else if (!strcmp(argv[i], "--stale")) { + flags |= TCommandRequest::EStale; + } else if (!strcmp(argv[i], "--consistent")) { + flags |= TCommandRequest::EConsistent; } else if (!strcmp(argv[i], "--help")) { usage(argv[0]); } @@ -212,7 +220,7 @@ int main(int argc, char** argv) { NNet::TAddress addr{hosts[0].Address, hosts[0].Port}; NNet::TSocket socket(std::move(addr), loop.Poller()); - auto h = Client(loop.Poller(), std::move(socket)); + auto h = Client(loop.Poller(), std::move(socket), flags); while (!h.done()) { loop.Step(); } diff --git a/src/messages.h b/src/messages.h index 85e982f..830030b 100644 --- a/src/messages.h +++ b/src/messages.h @@ -15,8 +15,10 @@ enum class EMessageType : uint32_t { REQUEST_VOTE_RESPONSE = 3, APPEND_ENTRIES_REQUEST = 4, APPEND_ENTRIES_RESPONSE = 5, - COMMAND_REQUEST = 6, - COMMAND_RESPONSE = 7, + INSTALL_SNAPSHOT_REQUEST = 6, // TODO: not implemented it + INSTALL_SNAPSHOT_RESPONSE = 7, // TODO: not implemented it + COMMAND_REQUEST = 8, + COMMAND_RESPONSE = 9, }; struct TMessage { @@ -43,9 +45,10 @@ struct TMessageEx: public TMessage { uint32_t Src = 0; uint32_t Dst = 0; uint64_t Term = 0; + uint64_t Seqno = 0; }; -static_assert(sizeof(TMessageEx) == sizeof(TMessage)+16); +static_assert(sizeof(TMessageEx) == sizeof(TMessage)+24); struct TRequestVoteRequest: public TMessageEx { static constexpr EMessageType MessageType = EMessageType::REQUEST_VOTE_REQUEST; @@ -89,7 +92,11 @@ struct TCommandRequest: public TMessage { static constexpr EMessageType MessageType = EMessageType::COMMAND_REQUEST; enum EFlags { ENone = 0, - EWrite = 1, + EWrite = 1, // + + // read semantics, default: read from leader w/o ping check, possible stale reads if there are 2 leaders + EStale = 2, // stale read, can read from follower + EConsistent = 4, // strong consistent read (wait for pings, low latency, no stale read) }; uint32_t Flags = ENone; uint32_t Cookie = 0; diff --git a/src/raft.cpp b/src/raft.cpp index 53a18f8..5d3afb6 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -68,7 +68,7 @@ TVolatileState& TVolatileState::Vote(uint32_t nodeId) return *this; } -TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state) +TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state, uint64_t seqno) { auto lastIndex = state.LastLogIndex; Indices.clear(); Indices.reserve(nservers); @@ -82,9 +82,9 @@ TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state) std::sort(Indices.begin(), Indices.end()); auto commitIndex = std::max(CommitIndex, Indices[nservers / 2]); if (state.LogTerm(commitIndex) == state.CurrentTerm) { + CommitSeqno = std::max(CommitSeqno, seqno); CommitIndex = commitIndex; } - // TODO: If state.LogTerm(commitIndex) < state.CurrentTerm need to append empty message to log return *this; } @@ -196,6 +196,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolderSrc, .Term = State->CurrentTerm, + .Seqno = message->Seqno, }, TAppendEntriesResponse { .MatchIndex = 0, @@ -234,7 +235,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolderSrc, .Term = State->CurrentTerm}, + TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm, .Seqno = message->Seqno}, TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success}); (*VolatileState) @@ -260,7 +261,7 @@ void TRaft::OnAppendEntries(TMessageHolder message) { .SetRpcDue(nodeId, ITimeSource::Time{}) .SetBatchSize(nodeId, 1024) .SetBackOff(nodeId, 1) - .CommitAdvance(Nservers, *State); + .CommitAdvance(Nservers, *State, message->Seqno); } else { auto backOff = std::max(VolatileState->BackOff[nodeId], 1); auto nextIndex = VolatileState->NextIndex[nodeId] > backOff @@ -294,7 +295,7 @@ TMessageHolder TRaft::CreateAppendEntries(uint32_t nodeId } auto mes = NewHoldedMessage( - TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm}, + TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm, .Seqno = Seqno++}, TAppendEntriesRequest { .PrevLogIndex = prevIndex, .PrevLogTerm = State->LogTerm(prevIndex), @@ -338,7 +339,7 @@ void TRaft::Leader(ITimeSource::Time now, TMessageHolder message) { OnRequestVote(now, std::move(maybeVoteRequest.Cast())); } else if (auto maybeAppendEntries = message.Maybe()) { OnAppendEntries(now, std::move(maybeAppendEntries.Cast())); - } + } } void TRaft::Become(EState newStateName) { @@ -408,6 +409,18 @@ void TRaft::LeaderTimeout(ITimeSource::Time now) { } } +uint64_t TRaft::ApproveRead() { + int seqno = Seqno; + for (auto& [id, node] : Nodes) { + node->Send(CreateAppendEntries(id)); + } + return seqno; +} + +uint64_t TRaft::CommitSeqno() const { + return VolatileState->CommitSeqno; +} + void TRaft::ProcessTimeout(ITimeSource::Time now) { if (StateName == EState::CANDIDATE || StateName == EState::FOLLOWER) { if (VolatileState->ElectionDue <= now) { @@ -508,33 +521,9 @@ void TRequestProcessor::CheckStateChange() { } } -void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { - auto stateName = Raft->CurrentStateName(); - auto leaderId = Raft->GetLeaderId(); - - // read request - if (! (command->Flags & TCommandRequest::EWrite)) { - if (replyTo) { - // TODO: possible stale read, should use max(LastIndex, LeaderLastIndex) - assert(Waiting.empty() || Waiting.back().Index <= Raft->GetLastIndex()); - Waiting.emplace(TWaiting{Raft->GetLastIndex(), std::move(command), replyTo}); - } - return; - } - - // write request - if (stateName == EState::LEADER) { - auto index = Raft->Append(std::move(Rsm->Prepare(command))); - if (replyTo) { - assert(Waiting.empty() || Waiting.back().Index <= index); - Waiting.emplace(TWaiting{index, std::move(command), replyTo}); - } - return; - } - - // forwarding write request +void TRequestProcessor::Forward(TMessageHolder command, const std::shared_ptr& replyTo) +{ if (!replyTo) { - // nothing to forward return; } @@ -543,9 +532,11 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); return; } - + + auto stateName = Raft->CurrentStateName(); + auto leaderId = Raft->GetLeaderId(); if (stateName == EState::CANDIDATE || leaderId == 0) { - WaitingStateChange.emplace(TWaiting{0, std::move(command), replyTo}); + WaitingStateChange.emplace(TWaiting{0, 0, std::move(command), replyTo}); return; } @@ -563,6 +554,56 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command assert(false && "Wrong state"); } +void TRequestProcessor::OnReadRequest(TMessageHolder command, const std::shared_ptr& replyTo) +{ + auto stateName = Raft->CurrentStateName(); + auto flags = command->Flags; + assert(!(flags & TCommandRequest::EWrite)); + + // stale read, default read (from leader) + if ((flags & TCommandRequest::EStale) || (!(flags & TCommandRequest::EConsistent) && stateName == EState::LEADER)) { + assert(Waiting.empty() || Waiting.back().Index <= Raft->GetLastIndex()); + Waiting.emplace(TWaiting{Raft->GetLastIndex(), 0, std::move(command), replyTo}); + return; + } + + if (stateName != EState::LEADER) { + return Forward(std::move(command), replyTo); + } + + // Consistent read + auto seqno = Raft->ApproveRead(); + assert(StrongWaiting.empty() || (StrongWaiting.back().Index <= Raft->GetLastIndex() && StrongWaiting.back().Seqno <= seqno)); + StrongWaiting.emplace(TWaiting{Raft->GetLastIndex(), seqno, std::move(command), replyTo}); +} + +void TRequestProcessor::OnWriteRequest(TMessageHolder command, const std::shared_ptr& replyTo) { + auto stateName = Raft->CurrentStateName(); + auto flags = command->Flags; + assert((flags & TCommandRequest::EWrite)); + + if (stateName == EState::LEADER) { + uint64_t index = Raft->Append(std::move(Rsm->Prepare(command))); + if (replyTo) { + assert(Waiting.empty() || Waiting.back().Index <= index); + // TODO: cleanup these queues on state change + Waiting.emplace(TWaiting{index, 0, std::move(command), replyTo}); + } + } else { + Forward(std::move(command), replyTo); + } +} + +void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { + auto flags = command->Flags; + + if (!(flags & TCommandRequest::EWrite)) { + return OnReadRequest(std::move(command), replyTo); + } else { + return OnWriteRequest(std::move(command), replyTo); + } +} + void TRequestProcessor::OnCommandResponse(TMessageHolder command) { // forwarded auto it = Cookie2Client.find(command->Cookie); @@ -611,6 +652,7 @@ void TRequestProcessor::ProcessWaiting() { while (!Waiting.empty() && Waiting.back().Index <= lastApplied) { auto w = Waiting.back(); Waiting.pop(); TMessageHolder reply; + auto cookie = w.Command->Cookie;; if (w.Command->Flags & TCommandRequest::EWrite) { while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) { WriteAnswers.pop(); @@ -622,7 +664,18 @@ void TRequestProcessor::ProcessWaiting() { } else { reply = Rsm->Read(std::move(w.Command), w.Index).Cast(); } - reply->Cookie = w.Command->Cookie; + reply->Cookie = cookie; + w.ReplyTo->Send(std::move(reply)); + } + + auto seqno = Raft->CommitSeqno(); + while (!StrongWaiting.empty() && StrongWaiting.back().Index <= lastApplied && StrongWaiting.back().Seqno <= seqno) { + auto w = StrongWaiting.back(); StrongWaiting.pop(); + TMessageHolder reply; + assert (!(w.Command->Flags & TCommandRequest::EWrite)); + auto cookie = w.Command->Cookie; + reply = Rsm->Read(std::move(w.Command), w.Index).Cast(); + reply->Cookie = cookie; w.ReplyTo->Send(std::move(reply)); } } diff --git a/src/raft.h b/src/raft.h index 17e6ae9..c6bb923 100644 --- a/src/raft.h +++ b/src/raft.h @@ -41,6 +41,7 @@ using TNodeDict = std::unordered_map>; struct TVolatileState { uint64_t CommitIndex = 0; + uint64_t CommitSeqno = 0; uint32_t LeaderId = 0; std::unordered_map NextIndex; std::unordered_map MatchIndex; @@ -54,7 +55,7 @@ struct TVolatileState { std::vector Indices; TVolatileState& Vote(uint32_t id); - TVolatileState& CommitAdvance(int nservers, const IState& state); + TVolatileState& CommitAdvance(int nservers, const IState& state, uint64_t seqno = 0); TVolatileState& SetCommitIndex(int index); TVolatileState& SetElectionDue(ITimeSource::Time); TVolatileState& SetNextIndex(uint32_t id, uint64_t nextIndex); @@ -89,6 +90,8 @@ class TRaft { uint64_t Append(TMessageHolder entry); uint32_t GetLeaderId() const; uint64_t GetLastIndex() const; + uint64_t ApproveRead(); + uint64_t CommitSeqno() const; // ut const auto& GetState() const { @@ -146,6 +149,7 @@ class TRaft { int Nservers; std::shared_ptr State; std::unique_ptr VolatileState; + uint64_t Seqno = 0; // for matching responses EState StateName; uint32_t Seed = 31337; @@ -167,16 +171,22 @@ class TRequestProcessor { void CleanUp(const std::shared_ptr& replyTo); private: + void Forward(TMessageHolder message, const std::shared_ptr& replyTo); + void OnReadRequest(TMessageHolder message, const std::shared_ptr& replyTo); + void OnWriteRequest(TMessageHolder message, const std::shared_ptr& replyTo); + std::shared_ptr Raft; std::shared_ptr Rsm; TNodeDict Nodes; struct TWaiting { uint64_t Index; + uint64_t Seqno = 0; TMessageHolder Command; std::shared_ptr ReplyTo; }; std::queue Waiting; + std::queue StrongWaiting; std::queue WaitingStateChange; struct TAnswer { diff --git a/src/server.cpp b/src/server.cpp index c79307b..766b9de 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -122,7 +122,7 @@ NNet::TVoidTask TRaftServer::InboundConnection(TSocket socket) { Nodes.insert(client); while (true) { auto mes = co_await TMessageReader(client->Sock()).Read(); - // client request + // client request if (auto maybeCommandRequest = mes.template Maybe()) { RequestProcessor->OnCommandRequest(std::move(maybeCommandRequest.Cast()), client); } else if (auto maybeCommandResponse = mes.template Maybe()) { @@ -170,10 +170,12 @@ NNet::TVoidTask TRaftServer::OutboundServe(std::shared_ptrIsConnected()) { + throw std::runtime_error("Not connected"); + } auto mes = co_await TMessageReader(node->Sock()).Read(); if (auto maybeCommandResponse = mes.template Maybe()) { RequestProcessor->OnCommandResponse(std::move(maybeCommandResponse.Cast())); - RequestProcessor->ProcessWaiting(); DrainNodes(); } else { std::cerr << "Wrong message type: " << mes->Type << std::endl;