Skip to content

Commit

Permalink
Implement read modes (default, stale, consistent) (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius authored Dec 27, 2024
1 parent 6bacab0 commit 35158a4
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 46 deletions.
16 changes: 12 additions & 4 deletions examples/kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ TMessageHolder<TLogEntry> TKv::Prepare(TMessageHolder<TCommandRequest> command)
}

template<typename TPoller, typename TSocket>
NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
NNet::TFuture<void> Client(TPoller& poller, TSocket socket, uint32_t flags) {
using TFileHandle = typename TPoller::TFileHandle;
TFileHandle input{0, poller}; // stdin
co_await socket.Connect();
Expand Down Expand Up @@ -112,11 +112,14 @@ NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
auto key = strtok(nullptr, sep);
auto size = strlen(key);
auto mes = NewHoldedMessage<TReadKv>(sizeof(TReadKv) + size);
mes->Flags = flags;
mes->KeySize = size;
memcpy(mes->Data, key, size);
req = mes;
} else if (!strcmp(prefix, "list")) {
req = NewHoldedMessage<TReadKv>(sizeof(TReadKv));
auto mes = NewHoldedMessage<TReadKv>(sizeof(TReadKv));
mes->Flags = flags;
req = mes;
} else if (!strcmp(prefix, "del")) {
auto key = strtok(nullptr, sep);
auto size = strlen(key);
Expand Down Expand Up @@ -145,7 +148,7 @@ NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
}

void usage(const char* prog) {
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist]" << "\n";
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist] [--stale] [--consistent]" << "\n";
exit(0);
}

Expand All @@ -157,6 +160,7 @@ int main(int argc, char** argv) {
uint32_t id = 0;
bool server = false;
bool persist = false;
uint32_t flags = 0;
for (int i = 1; i < argc; i++) {
if (!strcmp(argv[i], "--server")) {
server = true;
Expand All @@ -167,6 +171,10 @@ int main(int argc, char** argv) {
id = atoi(argv[++i]);
} else if (!strcmp(argv[i], "--persist")) {
persist = true;
} else if (!strcmp(argv[i], "--stale")) {
flags |= TCommandRequest::EStale;
} else if (!strcmp(argv[i], "--consistent")) {
flags |= TCommandRequest::EConsistent;
} else if (!strcmp(argv[i], "--help")) {
usage(argv[0]);
}
Expand Down Expand Up @@ -212,7 +220,7 @@ int main(int argc, char** argv) {
NNet::TAddress addr{hosts[0].Address, hosts[0].Port};
NNet::TSocket socket(std::move(addr), loop.Poller());

auto h = Client(loop.Poller(), std::move(socket));
auto h = Client(loop.Poller(), std::move(socket), flags);
while (!h.done()) {
loop.Step();
}
Expand Down
15 changes: 11 additions & 4 deletions src/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ enum class EMessageType : uint32_t {
REQUEST_VOTE_RESPONSE = 3,
APPEND_ENTRIES_REQUEST = 4,
APPEND_ENTRIES_RESPONSE = 5,
COMMAND_REQUEST = 6,
COMMAND_RESPONSE = 7,
INSTALL_SNAPSHOT_REQUEST = 6, // TODO: not implemented it
INSTALL_SNAPSHOT_RESPONSE = 7, // TODO: not implemented it
COMMAND_REQUEST = 8,
COMMAND_RESPONSE = 9,
};

struct TMessage {
Expand All @@ -43,9 +45,10 @@ struct TMessageEx: public TMessage {
uint32_t Src = 0;
uint32_t Dst = 0;
uint64_t Term = 0;
uint64_t Seqno = 0;
};

static_assert(sizeof(TMessageEx) == sizeof(TMessage)+16);
static_assert(sizeof(TMessageEx) == sizeof(TMessage)+24);

struct TRequestVoteRequest: public TMessageEx {
static constexpr EMessageType MessageType = EMessageType::REQUEST_VOTE_REQUEST;
Expand Down Expand Up @@ -89,7 +92,11 @@ struct TCommandRequest: public TMessage {
static constexpr EMessageType MessageType = EMessageType::COMMAND_REQUEST;
enum EFlags {
ENone = 0,
EWrite = 1,
EWrite = 1, //

// read semantics, default: read from leader w/o ping check, possible stale reads if there are 2 leaders
EStale = 2, // stale read, can read from follower
EConsistent = 4, // strong consistent read (wait for pings, low latency, no stale read)
};
uint32_t Flags = ENone;
uint32_t Cookie = 0;
Expand Down
123 changes: 88 additions & 35 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ TVolatileState& TVolatileState::Vote(uint32_t nodeId)
return *this;
}

TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state)
TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state, uint64_t seqno)
{
auto lastIndex = state.LastLogIndex;
Indices.clear(); Indices.reserve(nservers);
Expand All @@ -82,9 +82,9 @@ TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state)
std::sort(Indices.begin(), Indices.end());
auto commitIndex = std::max(CommitIndex, Indices[nservers / 2]);
if (state.LogTerm(commitIndex) == state.CurrentTerm) {
CommitSeqno = std::max(CommitSeqno, seqno);
CommitIndex = commitIndex;
}
// TODO: If state.LogTerm(commitIndex) < state.CurrentTerm need to append empty message to log
return *this;
}

Expand Down Expand Up @@ -196,6 +196,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
.Src = Id,
.Dst = message->Src,
.Term = State->CurrentTerm,
.Seqno = message->Seqno,
},
TAppendEntriesResponse {
.MatchIndex = 0,
Expand Down Expand Up @@ -234,7 +235,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
}

auto reply = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm, .Seqno = message->Seqno},
TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success});

(*VolatileState)
Expand All @@ -260,7 +261,7 @@ void TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message) {
.SetRpcDue(nodeId, ITimeSource::Time{})
.SetBatchSize(nodeId, 1024)
.SetBackOff(nodeId, 1)
.CommitAdvance(Nservers, *State);
.CommitAdvance(Nservers, *State, message->Seqno);
} else {
auto backOff = std::max(VolatileState->BackOff[nodeId], 1);
auto nextIndex = VolatileState->NextIndex[nodeId] > backOff
Expand Down Expand Up @@ -294,7 +295,7 @@ TMessageHolder<TAppendEntriesRequest> TRaft::CreateAppendEntries(uint32_t nodeId
}

auto mes = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm},
TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm, .Seqno = Seqno++},
TAppendEntriesRequest {
.PrevLogIndex = prevIndex,
.PrevLogTerm = State->LogTerm(prevIndex),
Expand Down Expand Up @@ -338,7 +339,7 @@ void TRaft::Leader(ITimeSource::Time now, TMessageHolder<TMessage> message) {
OnRequestVote(now, std::move(maybeVoteRequest.Cast()));
} else if (auto maybeAppendEntries = message.Maybe<TAppendEntriesRequest>()) {
OnAppendEntries(now, std::move(maybeAppendEntries.Cast()));
}
}
}

void TRaft::Become(EState newStateName) {
Expand Down Expand Up @@ -408,6 +409,18 @@ void TRaft::LeaderTimeout(ITimeSource::Time now) {
}
}

uint64_t TRaft::ApproveRead() {
int seqno = Seqno;
for (auto& [id, node] : Nodes) {
node->Send(CreateAppendEntries(id));
}
return seqno;
}

uint64_t TRaft::CommitSeqno() const {
return VolatileState->CommitSeqno;
}

void TRaft::ProcessTimeout(ITimeSource::Time now) {
if (StateName == EState::CANDIDATE || StateName == EState::FOLLOWER) {
if (VolatileState->ElectionDue <= now) {
Expand Down Expand Up @@ -508,33 +521,9 @@ void TRequestProcessor::CheckStateChange() {
}
}

void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
auto stateName = Raft->CurrentStateName();
auto leaderId = Raft->GetLeaderId();

// read request
if (! (command->Flags & TCommandRequest::EWrite)) {
if (replyTo) {
// TODO: possible stale read, should use max(LastIndex, LeaderLastIndex)
assert(Waiting.empty() || Waiting.back().Index <= Raft->GetLastIndex());
Waiting.emplace(TWaiting{Raft->GetLastIndex(), std::move(command), replyTo});
}
return;
}

// write request
if (stateName == EState::LEADER) {
auto index = Raft->Append(std::move(Rsm->Prepare(command)));
if (replyTo) {
assert(Waiting.empty() || Waiting.back().Index <= index);
Waiting.emplace(TWaiting{index, std::move(command), replyTo});
}
return;
}

// forwarding write request
void TRequestProcessor::Forward(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo)
{
if (!replyTo) {
// nothing to forward
return;
}

Expand All @@ -543,9 +532,11 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
return;
}


auto stateName = Raft->CurrentStateName();
auto leaderId = Raft->GetLeaderId();
if (stateName == EState::CANDIDATE || leaderId == 0) {
WaitingStateChange.emplace(TWaiting{0, std::move(command), replyTo});
WaitingStateChange.emplace(TWaiting{0, 0, std::move(command), replyTo});
return;
}

Expand All @@ -563,6 +554,56 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command
assert(false && "Wrong state");
}

void TRequestProcessor::OnReadRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo)
{
auto stateName = Raft->CurrentStateName();
auto flags = command->Flags;
assert(!(flags & TCommandRequest::EWrite));

// stale read, default read (from leader)
if ((flags & TCommandRequest::EStale) || (!(flags & TCommandRequest::EConsistent) && stateName == EState::LEADER)) {
assert(Waiting.empty() || Waiting.back().Index <= Raft->GetLastIndex());
Waiting.emplace(TWaiting{Raft->GetLastIndex(), 0, std::move(command), replyTo});
return;
}

if (stateName != EState::LEADER) {
return Forward(std::move(command), replyTo);
}

// Consistent read
auto seqno = Raft->ApproveRead();
assert(StrongWaiting.empty() || (StrongWaiting.back().Index <= Raft->GetLastIndex() && StrongWaiting.back().Seqno <= seqno));
StrongWaiting.emplace(TWaiting{Raft->GetLastIndex(), seqno, std::move(command), replyTo});
}

void TRequestProcessor::OnWriteRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
auto stateName = Raft->CurrentStateName();
auto flags = command->Flags;
assert((flags & TCommandRequest::EWrite));

if (stateName == EState::LEADER) {
uint64_t index = Raft->Append(std::move(Rsm->Prepare(command)));
if (replyTo) {
assert(Waiting.empty() || Waiting.back().Index <= index);
// TODO: cleanup these queues on state change
Waiting.emplace(TWaiting{index, 0, std::move(command), replyTo});
}
} else {
Forward(std::move(command), replyTo);
}
}

void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
auto flags = command->Flags;

if (!(flags & TCommandRequest::EWrite)) {
return OnReadRequest(std::move(command), replyTo);
} else {
return OnWriteRequest(std::move(command), replyTo);
}
}

void TRequestProcessor::OnCommandResponse(TMessageHolder<TCommandResponse> command) {
// forwarded
auto it = Cookie2Client.find(command->Cookie);
Expand Down Expand Up @@ -611,6 +652,7 @@ void TRequestProcessor::ProcessWaiting() {
while (!Waiting.empty() && Waiting.back().Index <= lastApplied) {
auto w = Waiting.back(); Waiting.pop();
TMessageHolder<TCommandResponse> reply;
auto cookie = w.Command->Cookie;;
if (w.Command->Flags & TCommandRequest::EWrite) {
while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) {
WriteAnswers.pop();
Expand All @@ -622,7 +664,18 @@ void TRequestProcessor::ProcessWaiting() {
} else {
reply = Rsm->Read(std::move(w.Command), w.Index).Cast<TCommandResponse>();
}
reply->Cookie = w.Command->Cookie;
reply->Cookie = cookie;
w.ReplyTo->Send(std::move(reply));
}

auto seqno = Raft->CommitSeqno();
while (!StrongWaiting.empty() && StrongWaiting.back().Index <= lastApplied && StrongWaiting.back().Seqno <= seqno) {
auto w = StrongWaiting.back(); StrongWaiting.pop();
TMessageHolder<TCommandResponse> reply;
assert (!(w.Command->Flags & TCommandRequest::EWrite));
auto cookie = w.Command->Cookie;
reply = Rsm->Read(std::move(w.Command), w.Index).Cast<TCommandResponse>();
reply->Cookie = cookie;
w.ReplyTo->Send(std::move(reply));
}
}
Expand Down
12 changes: 11 additions & 1 deletion src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;

struct TVolatileState {
uint64_t CommitIndex = 0;
uint64_t CommitSeqno = 0;
uint32_t LeaderId = 0;
std::unordered_map<uint32_t, uint64_t> NextIndex;
std::unordered_map<uint32_t, uint64_t> MatchIndex;
Expand All @@ -54,7 +55,7 @@ struct TVolatileState {
std::vector<uint64_t> Indices;

TVolatileState& Vote(uint32_t id);
TVolatileState& CommitAdvance(int nservers, const IState& state);
TVolatileState& CommitAdvance(int nservers, const IState& state, uint64_t seqno = 0);
TVolatileState& SetCommitIndex(int index);
TVolatileState& SetElectionDue(ITimeSource::Time);
TVolatileState& SetNextIndex(uint32_t id, uint64_t nextIndex);
Expand Down Expand Up @@ -89,6 +90,8 @@ class TRaft {
uint64_t Append(TMessageHolder<TLogEntry> entry);
uint32_t GetLeaderId() const;
uint64_t GetLastIndex() const;
uint64_t ApproveRead();
uint64_t CommitSeqno() const;

// ut
const auto& GetState() const {
Expand Down Expand Up @@ -146,6 +149,7 @@ class TRaft {
int Nservers;
std::shared_ptr<IState> State;
std::unique_ptr<TVolatileState> VolatileState;
uint64_t Seqno = 0; // for matching responses

EState StateName;
uint32_t Seed = 31337;
Expand All @@ -167,16 +171,22 @@ class TRequestProcessor {
void CleanUp(const std::shared_ptr<INode>& replyTo);

private:
void Forward(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);
void OnReadRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);
void OnWriteRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);

std::shared_ptr<TRaft> Raft;
std::shared_ptr<IRsm> Rsm;
TNodeDict Nodes;

struct TWaiting {
uint64_t Index;
uint64_t Seqno = 0;
TMessageHolder<TCommandRequest> Command;
std::shared_ptr<INode> ReplyTo;
};
std::queue<TWaiting> Waiting;
std::queue<TWaiting> StrongWaiting;
std::queue<TWaiting> WaitingStateChange;

struct TAnswer {
Expand Down
6 changes: 4 additions & 2 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ NNet::TVoidTask TRaftServer<TSocket>::InboundConnection(TSocket socket) {
Nodes.insert(client);
while (true) {
auto mes = co_await TMessageReader(client->Sock()).Read();
// client request
// client request
if (auto maybeCommandRequest = mes.template Maybe<TCommandRequest>()) {
RequestProcessor->OnCommandRequest(std::move(maybeCommandRequest.Cast()), client);
} else if (auto maybeCommandResponse = mes.template Maybe<TCommandResponse>()) {
Expand Down Expand Up @@ -170,10 +170,12 @@ NNet::TVoidTask TRaftServer<TSocket>::OutboundServe(std::shared_ptr<TNode<TSocke
while (true) {
bool error = false;
try {
if (!node->IsConnected()) {
throw std::runtime_error("Not connected");
}
auto mes = co_await TMessageReader(node->Sock()).Read();
if (auto maybeCommandResponse = mes.template Maybe<TCommandResponse>()) {
RequestProcessor->OnCommandResponse(std::move(maybeCommandResponse.Cast()));
RequestProcessor->ProcessWaiting();
DrainNodes();
} else {
std::cerr << "Wrong message type: " << mes->Type << std::endl;
Expand Down

0 comments on commit 35158a4

Please sign in to comment.