-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsharded.go
62 lines (49 loc) · 1.15 KB
/
sharded.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package batchify
import (
"sync"
"github.com/samber/go-batchify/internal"
"github.com/samber/go-batchify/pkg/hasher"
)
// newShardedBatch builds a shardedBatchImpl that dispatches each input to one
// of the given batches, chosen by shardingFn. The shard count is derived from
// len(batches), and the slice is stored as-is (it is not copied).
//
// NOTE(review): an empty batches slice yields a shard count of 0; presumably
// callers reject that before getting here — confirm.
func newShardedBatch[I comparable, O any](
	batches []Batch[I, O],
	shardingFn hasher.Hasher[I],
) *shardedBatchImpl[I, O] {
	impl := &shardedBatchImpl[I, O]{
		shardingFn: shardingFn,
		batches:    batches,
	}
	impl.shards = uint64(len(batches))
	return impl
}
// Compile-time check that *shardedBatchImpl satisfies the Batch interface.
var _ Batch[string, int] = (*shardedBatchImpl[string, int])(nil)

// shardedBatchImpl is a Batch that spreads work across several underlying
// Batch instances ("shards"). Do sends each input to exactly one shard,
// picked by shardingFn. Flush and Stop are applied to every shard.
type shardedBatchImpl[I comparable, O any] struct {
	// NOTE(review): internal.NoCopy is presumably a marker that lets `go vet`
	// flag accidental copies of this struct — confirm in the internal package.
	_ internal.NoCopy
	// shards caches len(batches) as a uint64 so it can be passed straight to
	// ComputeHash.
	shards uint64
	// batches holds one Batch per shard. It is indexed by the value that
	// shardingFn.ComputeHash returns.
	batches []Batch[I, O]
	// shardingFn maps an input to its shard index.
	shardingFn hasher.Hasher[I]
}
// Do forwards input to the shard that shardingFn assigns to it. It returns
// that shard's result and error unchanged.
func (b *shardedBatchImpl[I, O]) Do(input I) (output O, err error) {
	target := b.batches[b.shardingFn.ComputeHash(input, b.shards)]
	return target.Do(input)
}
// Flush calls Flush on every shard concurrently. It blocks until all of those
// calls have returned.
func (b *shardedBatchImpl[I, O]) Flush() {
	b.forEachShard(func(shard Batch[I, O]) {
		shard.Flush()
	})
}

// Stop calls Stop on every shard concurrently. It blocks until all of those
// calls have returned.
func (b *shardedBatchImpl[I, O]) Stop() {
	b.forEachShard(func(shard Batch[I, O]) {
		shard.Stop()
	})
}

// forEachShard runs fn once per shard, each call in its own goroutine, and
// waits for all of them to finish before returning.
func (b *shardedBatchImpl[I, O]) forEachShard(fn func(Batch[I, O])) {
	var wg sync.WaitGroup
	wg.Add(len(b.batches))
	for _, shard := range b.batches {
		// shard is passed as an argument so each goroutine gets its own copy,
		// independent of the module's Go version (pre-1.22 loop-var semantics).
		// The parameter is deliberately not named b, so it does not shadow the
		// receiver.
		go func(s Batch[I, O]) {
			defer wg.Done()
			fn(s)
		}(shard)
	}
	wg.Wait()
}