From d1f33b34c26012d48f7e1513d9a2f638342620a1 Mon Sep 17 00:00:00 2001 From: Sergey Grebenshchikov Date: Thu, 10 Oct 2024 02:12:00 +0200 Subject: [PATCH] return num of neighbors --- lsh/model_wide.go | 6 +++--- model_wide.go | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/lsh/model_wide.go b/lsh/model_wide.go index 8b3a010..855a554 100644 --- a/lsh/model_wide.go +++ b/lsh/model_wide.go @@ -84,7 +84,7 @@ func (me *WideModel) Predict1Alloc(k int, x []uint64, votes bitknn.Votes) int { // Predict1Into predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. func (me *WideModel) Predict1Into(k int, x []uint64, votes bitknn.Votes, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { xp := me.Hash.Hash1Wide(x) - k, n := NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) - me.WideModel.Narrow.Vote(k, distances, indices, votes) - return n + k0, _ := NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) + me.WideModel.Narrow.Vote(k0, distances, indices, votes) + return k0 } diff --git a/model_wide.go b/model_wide.go index 3a9251d..65ee23a 100644 --- a/model_wide.go +++ b/model_wide.go @@ -22,13 +22,16 @@ func (me *WideModel) PreallocateHeap(k int) { } // Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap. -func (me *WideModel) Predict1(k int, x []uint64, votes Votes) { +// Returns the number of neighbors found. +func (me *WideModel) Predict1(k int, x []uint64, votes Votes) int { me.Narrow.PreallocateHeap(k) - me.Predict1Into(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes) + return me.Predict1Into(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes) } // Predicts the label of a single input point, using the given slices for the neighbor heap. -func (me *WideModel) Predict1Into(k int, x []uint64, distances []int, indices []int, votes Votes) { +// Returns the number of neighbors found. +func (me *WideModel) Predict1Into(k int, x []uint64, distances []int, indices []int, votes Votes) int { k = NearestWide(me.WideData, k, x, distances, indices) me.Narrow.Vote(k, distances, indices, votes) + return k }