-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmpsc_queue.cpp
129 lines (117 loc) · 4.25 KB
/
mpsc_queue.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <atomic>
#include <stdint.h>
using namespace std;
using enum std::memory_order;
// Number of slots in every Queue.
// Requirements (enforced below):
//  * a power of 2, so `ticket % CAPACITY` and `ticket / CAPACITY` stay consistent
//    when the 32-bit ticket counters wrap around;
//  * at most 2^24, so 2^32 / CAPACITY is a multiple of 256 and the 8-bit per-slot
//    `turn` counter wraps in step with the tickets.
#define CAPACITY 16384
static_assert(CAPACITY > 0 && (CAPACITY & (CAPACITY - 1)) == 0, "CAPACITY must be a power of 2");
static_assert(CAPACITY <= (1u << 24), "CAPACITY must be at most 2^24 so 8-bit turns wrap consistently");
/// Bounded multi-producer / single-consumer ring of `int`s.
///
/// Producers claim a ticket from `write_ticket`; ticket % CAPACITY selects the
/// slot and ticket / CAPACITY (mod 256) is the "turn" (lap) that slot must be on
/// before that producer may write. The consumer walks tickets in order via the
/// non-atomic `read_ticket`, so only one thread may consume.
struct Queue {
    // Next ticket handed to a producer (shared by all producers).
    alignas(64) atomic<uint32_t> write_ticket{0};
    // Next ticket the consumer will take; plain integer, consumer-side only.
    alignas(64) uint32_t read_ticket{0};
    struct {
        // alignas(64) on the first member pads each slot to at least 64 bytes
        // (presumably one cache line) so neighbouring slots do not false-share.
        alignas(64) atomic<uint8_t> turn{0}; // lap allowed to write this slot next
        atomic<bool> full{false};            // true while `item` holds an unread value
        int item{0};
    } slots[CAPACITY];
};
// Blocking API
/// Appends `item` to the queue. Blocks (via atomic wait) while the slot for this
/// producer's ticket is still on an earlier lap, i.e. while the queue is full.
void enqueue(Queue *queue, int item) {
    // Unique ticket per producer; ordering with the consumer comes from `turn`/`full`.
    const uint32_t ticket = queue->write_ticket.fetch_add(1, relaxed);
    auto &cell = queue->slots[ticket % CAPACITY];
    const auto my_turn = static_cast<uint8_t>(ticket / CAPACITY);
    // Acquire pairs with the consumer's release store of `turn`, so its read of
    // the previous item in this slot happens before we overwrite it.
    for (uint8_t seen = cell.turn.load(acquire); seen != my_turn; seen = cell.turn.load(acquire))
        cell.turn.wait(seen, acquire); // sleep until `turn` changes from `seen`
    cell.item = item;
    cell.full.store(true, release); // publish `item` to the consumer
    cell.full.notify_one();         // wake the (single) consumer if it is waiting
}
/// Removes and returns the next item in ticket order, blocking while its slot is
/// still empty. Single consumer only: `read_ticket` is advanced non-atomically.
int dequeue(Queue *queue) {
    const uint32_t ticket = queue->read_ticket;
    queue->read_ticket = ticket + 1;
    auto &cell = queue->slots[ticket % CAPACITY];
    const auto lap = static_cast<uint8_t>(ticket / CAPACITY);
    // Returns once `full` is no longer false; acquire pairs with the producer's release.
    cell.full.wait(false, acquire);
    const int value = cell.item;
    cell.full.store(false, relaxed);
    // Hand the slot to the producer of the next lap (255 wraps to 0).
    cell.turn.store(static_cast<uint8_t>(lap + 1), release);
    // Producers holding tickets for later laps of this slot may all be waiting.
    cell.turn.notify_all();
    return value;
}
// Polling API
/// Non-blocking enqueue. Returns false if the queue looked full at the time of
/// the check; otherwise stores `item` and returns true.
///
/// The per-slot `turn` is an 8-bit lap counter that wraps every 256 laps, so the
/// distance between our lap and the slot's lap must be computed modulo 256 and
/// read as a signed value. A plain `int` subtraction misreads the wrap: our lap 0
/// vs slot lap 255 (really "one lap ahead -> full") would look like "lapped" and
/// spin, and our lap 255 vs slot lap 0 (really "stale ticket") would report full.
bool try_enqueue(Queue *queue, int item) {
    uint32_t try_ticket = queue->write_ticket.load(relaxed); // Serialization with writers.
    for (;;) {
        uint32_t slot = try_ticket % CAPACITY;
        uint8_t turn = static_cast<uint8_t>(try_ticket / CAPACITY);
        uint8_t current_turn = queue->slots[slot].turn.load(acquire); // Serialization with reader.
        // Wrap-aware lap distance, valid while the true distance is within [-128, 127]:
        //   > 0 : slot still belongs to an earlier lap -> queue is full.
        //   < 0 : slot already moved past our lap -> our ticket is stale.
        int8_t turns_remaining = static_cast<int8_t>(static_cast<uint8_t>(turn - current_turn));
        if (turns_remaining > 0)
            return false; // Queue is full.
        if (turns_remaining < 0) {
            try_ticket = queue->write_ticket.load(relaxed); // Another writer lapped us, try again.
            continue;
        }
        // Our lap: claim the ticket. On failure try_ticket is reloaded and we retry.
        if (queue->write_ticket.compare_exchange_weak(try_ticket, try_ticket + 1, relaxed)) {
            queue->slots[slot].item = item;
            queue->slots[slot].full.store(true, release); // Serialization with reader.
            // Wakes a blocked dequeue; may be costly in some implementations.
            // Remove this if you only use Polling and not Blocking.
            queue->slots[slot].full.notify_one();
            return true;
        }
    }
}
/// Non-blocking dequeue. If the next slot in ticket order has been filled, copies
/// its item to *out_item, frees the slot for the next lap and returns true;
/// otherwise returns false and leaves *out_item untouched. Single consumer only.
bool try_dequeue(Queue *queue, int *out_item) {
    const uint32_t ticket = queue->read_ticket;
    auto &cell = queue->slots[ticket % CAPACITY];
    // Acquire pairs with the producer's release store of `full`.
    const bool has_item = cell.full.load(acquire);
    if (!has_item)
        return false; // Queue is empty.
    const auto lap = static_cast<uint8_t>(ticket / CAPACITY);
    *out_item = cell.item;
    cell.full.store(false, relaxed);
    cell.turn.store(static_cast<uint8_t>(lap + 1), release); // hand slot to next lap
    // Wakes producers blocked in enqueue; may be costly in some implementations.
    // Remove this if you only use Polling and not Blocking.
    cell.turn.notify_all();
    queue->read_ticket = ticket + 1;
    return true;
}
// Test
#include <thread>
#include <assert.h>
/// Consumer side of the stress test. Drains 5 writers x 1,000,000 items encoded
/// as writer_id * 1000000 + sequence_number: the first half with the blocking
/// dequeue, the second half by spinning on try_dequeue. Asserts that no item is
/// corrupted, that each writer's items arrive in increasing order, and that every
/// item is received exactly once. (Checks vanish if NDEBUG is defined.)
void reader_thread(Queue *queue) {
    constexpr int kWriters = 5;
    constexpr int kItemsPerWriter = 1000000;
    constexpr int kTotalItems = kWriters * kItemsPerWriter;
    // static: ~20 MB of counters, likely too large for a default thread stack.
    static int counters[kWriters][kItemsPerWriter];
    int last_writer_data[kWriters] = { -1, -1, -1, -1, -1 };
    for (int i = 0; i < kTotalItems; ++i) {
        int item;
        if (i < kTotalItems / 2) {
            item = dequeue(queue);
        } else {
            while (!try_dequeue(queue, &item)) {
                // spin until an item is available
            }
        }
        // A corrupted negative item would pass `writer < kWriters` and then index
        // `counters` out of bounds, so reject it before decoding.
        assert(item >= 0);
        int writer = item / kItemsPerWriter;
        int data = item % kItemsPerWriter;
        assert(writer < kWriters); // Ensure no data corruption.
        ++(counters[writer][data]);
        assert(last_writer_data[writer] < data); // Ensure data is correctly sequenced FIFO per writer.
        last_writer_data[writer] = data;
    }
    for (int writer = 0; writer < kWriters; ++writer)
        for (int i = 0; i < kItemsPerWriter; ++i)
            assert(counters[writer][i] == 1); // Ensure all items have been received exactly once.
}
/// Producer side of the stress test. Each call takes the next writer id (0, 1, ...)
/// and sends 1,000,000 items id * 1000000 + seq in increasing seq: the first
/// 500,000 through the blocking enqueue, the rest by spinning on try_enqueue.
void writer_thread(Queue *queue) {
    static atomic<int> id_dispenser; // static storage, so it starts at 0
    const int id = id_dispenser.fetch_add(1);
    const int base = id * 1000000;
    for (int seq = 0; seq < 1000000; ++seq) {
        const int item = base + seq;
        if (seq < 500000) {
            enqueue(queue, item);
        } else {
            while (!try_enqueue(queue, item)) {
                // spin until a slot frees up
            }
        }
    }
}
int main(void) {
static struct Queue queue;
thread reader(reader_thread, &queue);
thread writer0(writer_thread, &queue);
thread writer1(writer_thread, &queue);
thread writer2(writer_thread, &queue);
thread writer3(writer_thread, &queue);
thread writer4(writer_thread, &queue);
reader.join();
writer0.join();
writer1.join();
writer2.join();
writer3.join();
writer4.join();
}