Skip to content

Commit

Permalink
tidy up votes counter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreben committed Oct 11, 2024
1 parent bd59183 commit 0b82a2d
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 31 deletions.
6 changes: 3 additions & 3 deletions lsh/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ func Fit(data []uint64, labels []int, hash Hash, opts ...bitknn.Option) *Model {
}

// Predict1 predicts the label for a single input using the LSH model.
func (me *Model) Predict1(k int, x uint64, votes bitknn.Votes) int {
func (me *Model) Predict1(k int, x uint64, votes bitknn.VoteCounter) int {
me.PreallocateHeap(k)
return me.Predict1Into(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices)
}

// Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps.
func (me *Model) Predict1Alloc(k int, x uint64, votes bitknn.Votes) int {
func (me *Model) Predict1Alloc(k int, x uint64, votes bitknn.VoteCounter) int {
bucketDistances := make([]int, k+1)
bucketIDs := make([]uint64, k+1)
distances := make([]int, k+1)
Expand All @@ -89,7 +89,7 @@ func (me *Model) Predict1Alloc(k int, x uint64, votes bitknn.Votes) int {
}

// Predict1Into predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps.
func (me *Model) Predict1Into(k int, x uint64, votes bitknn.Votes, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int {
func (me *Model) Predict1Into(k int, x uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int {
xp := me.Hash.Hash1(x)
k, n := Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices)
me.Vote(k, distances, indices, votes)
Expand Down
3 changes: 1 addition & 2 deletions lsh/model_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
)

func Benchmark_Model_Predict1(b *testing.B) {
votes := make([]float64, 256)
type bench struct {
hashes []lsh.Hash
dataSize []int
Expand All @@ -35,7 +34,7 @@ func Benchmark_Model_Predict1(b *testing.B) {
model.PreallocateHeap(k)
b.ResetTimer()
for n := 0; n < b.N; n++ {
model.Predict1(k, query, bitknn.VoteSlice(votes))
model.Predict1(k, query, bitknn.DiscardVotes)
}
})
}
Expand Down
6 changes: 3 additions & 3 deletions lsh/model_wide.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ func FitWide(data [][]uint64, labels []int, hash HashWide, opts ...bitknn.Option
}

// Predict1 predicts the label for a single input using the LSH model.
func (me *WideModel) Predict1(k int, x []uint64, votes bitknn.Votes) int {
func (me *WideModel) Predict1(k int, x []uint64, votes bitknn.VoteCounter) int {
me.PreallocateHeap(k)
return me.Predict1Into(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.Narrow.HeapDistances, me.Narrow.HeapIndices)
}

// Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps.
func (me *WideModel) Predict1Alloc(k int, x []uint64, votes bitknn.Votes) int {
func (me *WideModel) Predict1Alloc(k int, x []uint64, votes bitknn.VoteCounter) int {
bucketDistances := make([]int, k+1)
bucketIDs := make([]uint64, k+1)
distances := make([]int, k+1)
Expand All @@ -82,7 +82,7 @@ func (me *WideModel) Predict1Alloc(k int, x []uint64, votes bitknn.Votes) int {
}

// Predict1Into predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps.
func (me *WideModel) Predict1Into(k int, x []uint64, votes bitknn.Votes, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int {
func (me *WideModel) Predict1Into(k int, x []uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int {
xp := me.Hash.Hash1Wide(x)
k0, _ := NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices)
me.WideModel.Narrow.Vote(k0, distances, indices, votes)
Expand Down
24 changes: 12 additions & 12 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,26 @@ func (me *Model) PreallocateHeap(k int) {
}

// Predicts the label of a single input point. Each call allocates two new slices of length K+1 for the neighbor heap.
func (me *Model) Predict1Alloc(k int, x uint64, votes Votes) {
func (me *Model) Predict1Alloc(k int, x uint64, votes VoteCounter) {
distances, indices := make([]int, k+1), make([]int, k+1)
me.Predict1Into(k, x, distances, indices, votes)
}

// Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap.
func (me *Model) Predict1(k int, x uint64, votes Votes) {
func (me *Model) Predict1(k int, x uint64, votes VoteCounter) {
me.HeapDistances = slice.OrAlloc(me.HeapDistances, k+1)
me.HeapIndices = slice.OrAlloc(me.HeapIndices, k+1)
me.Predict1Into(k, x, me.HeapDistances, me.HeapIndices, votes)
}

// Predicts the label of a single input point, using the given slices for the neighbor heap.
func (me *Model) Predict1Into(k int, x uint64, distances []int, indices []int, votes Votes) {
func (me *Model) Predict1Into(k int, x uint64, distances []int, indices []int, votes VoteCounter) {
k = Nearest(me.Data, k, x, distances, indices)
me.Vote(k, distances, indices, votes)
}

// Predicts the label of a single input point, using the given slices for the neighbor heap.
func (me *Model) Vote(k int, distances []int, indices []int, votes Votes) {
func (me *Model) Vote(k int, distances []int, indices []int, votes VoteCounter) {
votes.Clear()
switch me.DistanceWeighting {
case DistanceWeightingNone:
Expand Down Expand Up @@ -93,63 +93,63 @@ func (me *Model) Vote(k int, distances []int, indices []int, votes Votes) {
}
}

func (me *Model) votes1vc(k int, indices []int, votes Votes, f func(int) float64, distances []int) {
func (me *Model) votes1vc(k int, indices []int, votes VoteCounter, f func(int) float64, distances []int) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, f(distances[i])*me.Values[index])
}
}

func (me *Model) votes1c(k int, indices []int, votes Votes, f func(int) float64, distances []int) {
func (me *Model) votes1c(k int, indices []int, votes VoteCounter, f func(int) float64, distances []int) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, f(distances[i]))
}
}

func (me *Model) votes1vq(k int, indices []int, votes Votes, distances []int) {
func (me *Model) votes1vq(k int, indices []int, votes VoteCounter, distances []int) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, DistanceWeightingFuncQuadratic(distances[i])*me.Values[index])
}
}

func (me *Model) votes1q(k int, indices []int, votes Votes, distances []int) {
func (me *Model) votes1q(k int, indices []int, votes VoteCounter, distances []int) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, DistanceWeightingFuncQuadratic(distances[i]))
}
}

func (me *Model) votes1vl(k int, indices []int, votes Votes, distances []int) {
func (me *Model) votes1vl(k int, indices []int, votes VoteCounter, distances []int) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, DistanceWeightingFuncLinear(distances[i])*me.Values[index])
}
}

func (me *Model) votes1l(k int, indices []int, votes Votes, distances []int) {
func (me *Model) votes1l(k int, indices []int, votes VoteCounter, distances []int) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, DistanceWeightingFuncLinear(distances[i]))
}
}

func (me *Model) votes1v(k int, indices []int, votes Votes) {
func (me *Model) votes1v(k int, indices []int, votes VoteCounter) {
for i := range k {
index := indices[i]
label := me.Labels[index]
votes.Add(label, me.Values[index])
}
}

func (me *Model) votes1(k int, indices []int, votes Votes) {
func (me *Model) votes1(k int, indices []int, votes VoteCounter) {
for i := range k {
index := indices[i]
label := me.Labels[index]
Expand Down
12 changes: 7 additions & 5 deletions model_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
)

func Benchmark_Model_Predict1(b *testing.B) {
votes := make([]float64, 256)
type bench struct {
dataSize []int
k []int
Expand All @@ -31,7 +30,7 @@ func Benchmark_Model_Predict1(b *testing.B) {
model.PreallocateHeap(k)
b.ResetTimer()
for n := 0; n < b.N; n++ {
model.Predict1(k, query, bitknn.VoteSlice(votes))
model.Predict1(k, query, bitknn.DiscardVotes)
}
})
}
Expand Down Expand Up @@ -59,9 +58,10 @@ func Benchmark_Model_Predict1V(b *testing.B) {
query := rand.Uint64()

model.PreallocateHeap(k)
voteSlice := bitknn.VoteSlice(votes)
b.ResetTimer()
for n := 0; n < b.N; n++ {
model.Predict1(k, query, bitknn.VoteSlice(votes))
model.Predict1(k, query, &voteSlice)
}
})
}
Expand Down Expand Up @@ -89,10 +89,11 @@ func Benchmark_Model_Predict1D(b *testing.B) {
model.DistanceWeighting = d
model.DistanceWeightingFunc = func(d int) float64 { return 1 / float64(1+d) }
query := rand.Uint64()
voteSlice := bitknn.VoteSlice(votes)

b.ResetTimer()
for n := 0; n < b.N; n++ {
model.Predict1(k, query, bitknn.VoteSlice(votes))
model.Predict1(k, query, &voteSlice)
}
})
}
Expand Down Expand Up @@ -122,10 +123,11 @@ func Benchmark_Model_Predict1DV(b *testing.B) {
model.DistanceWeighting = d
model.DistanceWeightingFunc = func(d int) float64 { return 1 / float64(1+d) }
query := rand.Uint64()
voteSlice := bitknn.VoteSlice(votes)

b.ResetTimer()
for n := 0; n < b.N; n++ {
model.Predict1(k, query, bitknn.VoteSlice(votes))
model.Predict1(k, query, &voteSlice)
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions model_wide.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ func (me *WideModel) PreallocateHeap(k int) {

// Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap.
// Returns the number of neighbors found.
func (me *WideModel) Predict1(k int, x []uint64, votes Votes) int {
func (me *WideModel) Predict1(k int, x []uint64, votes VoteCounter) int {
me.Narrow.PreallocateHeap(k)
return me.Predict1Into(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes)
}

// Predicts the label of a single input point, using the given slices for the neighbor heap.
// Returns the number of neighbors found.
func (me *WideModel) Predict1Into(k int, x []uint64, distances []int, indices []int, votes Votes) int {
func (me *WideModel) Predict1Into(k int, x []uint64, distances []int, indices []int, votes VoteCounter) int {
k = NearestWide(me.WideData, k, x, distances, indices)
me.Narrow.Vote(k, distances, indices, votes)
return k
Expand Down
3 changes: 1 addition & 2 deletions model_wide_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
)

func Benchmark_WideModel_Predict1(b *testing.B) {
votes := make([]float64, 256)
type bench struct {
dim []int
dataSize []int
Expand All @@ -32,7 +31,7 @@ func Benchmark_WideModel_Predict1(b *testing.B) {
model.PreallocateHeap(k)
b.ResetTimer()
for n := 0; n < b.N; n++ {
model.Predict1(k, query, bitknn.VoteSlice(votes))
model.Predict1(k, query, bitknn.DiscardVotes)
}
})
}
Expand Down
15 changes: 13 additions & 2 deletions votes.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package bitknn

import "slices"

// Votes is a k-NN vote counter interface.
type Votes interface {
// VoteCounter is a k-NN vote counter interface.
type VoteCounter interface {
// Clear removes all votes.
Clear()

Expand All @@ -21,6 +21,17 @@ type Votes interface {
Add(label int, delta float64)
}

type discardVotes int

// DiscardVotes is a no-op vote counter.
const DiscardVotes = discardVotes(0)

func (me discardVotes) Clear() {}
func (me discardVotes) ArgMax() int { return 0 }
func (me discardVotes) Max() float64 { return 0 }
func (me discardVotes) Get(label int) float64 { return 0 }
func (me discardVotes) Add(label int, delta float64) {}

// VoteSlice is a dense vote counter that stores votes in a slice.
// It is efficient for small sets of class labels.
type VoteSlice []float64
Expand Down
29 changes: 29 additions & 0 deletions votes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,35 @@ import (

const eps = 1e-9

func TestVoteDiscard(t *testing.T) {
discard := bitknn.DiscardVotes
{
pre := discard
discard.Add(0, 1)
post := discard
if pre != post {
t.Fail()
}
}
if discard.Get(0) != 0 {
t.Fatal()
}
if discard.ArgMax() != 0 {
t.Fatal()
}
if discard.Max() != 0 {
t.Fatal()
}
{
pre := discard
discard.Clear()
post := discard
if pre != post {
t.Fail()
}
}
}

func TestVoteSlice_Clear(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
length := rapid.IntRange(0, 100).Draw(t, "length")
Expand Down

0 comments on commit 0b82a2d

Please sign in to comment.