From e385b14078702694912e7b900709704b250e81b6 Mon Sep 17 00:00:00 2001 From: Sergey Grebenshchikov Date: Sat, 12 Oct 2024 03:33:21 +0200 Subject: [PATCH] -lsh --- README.md | 86 +------ example_test.go | 2 - lsh/example_test.go | 37 --- lsh/hashes.go | 511 --------------------------------------- lsh/hashes_test.go | 460 ----------------------------------- lsh/hashes_wide_test.go | 271 --------------------- lsh/model.go | 115 --------- lsh/model_bench_test.go | 52 ---- lsh/model_test.go | 129 ---------- lsh/model_wide.go | 108 --------- lsh/model_wide_test.go | 90 ------- lsh/nearest.go | 136 ----------- lsh/nearest_test.go | 99 -------- lsh/nearest_wide.go | 115 --------- lsh/nearest_wide_test.go | 60 ----- model.go | 3 +- model_wide_test.go | 24 ++ 17 files changed, 34 insertions(+), 2264 deletions(-) delete mode 100644 lsh/example_test.go delete mode 100644 lsh/hashes.go delete mode 100644 lsh/hashes_test.go delete mode 100644 lsh/hashes_wide_test.go delete mode 100644 lsh/model.go delete mode 100644 lsh/model_bench_test.go delete mode 100644 lsh/model_test.go delete mode 100644 lsh/model_wide.go delete mode 100644 lsh/model_wide_test.go delete mode 100644 lsh/nearest.go delete mode 100644 lsh/nearest_test.go delete mode 100644 lsh/nearest_wide.go delete mode 100644 lsh/nearest_wide_test.go diff --git a/README.md b/README.md index 6543d0d..7bebe0e 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,9 @@ If your vectors are **longer than 64 bits**, you can [pack](#packing-wide-data) You can optionally weigh class votes by distance, or specify different vote values per data point. -The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh) implements several [Locality-Sensitive Hashing (LSH)](https://en.m.wikipedia.org/wiki/Locality-sensitive_hashing) schemes for `uint64` feature vectors. - **Contents** - [Usage](#usage) - [Basic usage](#basic-usage) - - [LSH](#lsh) - [Packing wide data](#packing-wide-data) - [ARM64 NEON Support](#arm64-neon-support) - [Options](#options) @@ -35,19 +32,20 @@ There are just three methods you'll typically need: - **Fit** *(data, labels, [\[options\]](#options))*: create a model from a dataset - Variants: [`bitknn.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Fit), [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide), [`lsh.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit), [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide) + Variants: [`bitknn.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Fit), [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) + - **Find** *(k, point)*: Given a point, return the *k* nearest neighbor's indices and distances. - Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`lsh.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Find), [`lsh.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Find), [`bitknn.WideModel.FindV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.FindV) (vectorized on ARM64 with NEON instructions) + Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`bitknn.WideModel.FindV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.FindV) (vectorized on ARM64 with NEON instructions) - **Predict** *(k, point, votes)*: Predict the label for a given point based on its nearest neighbors, write the label votes into the provided vote counter. - Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`lsh.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Predict), [`lsh.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Predict), [`bitknn.WideModel.PredictV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.PredictV) (vectorized on ARM64 with NEON instructions). + Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`bitknn.WideModel.PredictV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.PredictV) (vectorized on ARM64 with NEON instructions). -Each of the above methods is available on each model type. There are four model types in total: +Each of the above methods is available on either model type: -- **Exact k-NN** models: [`bitknn.Model`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model) (64 bits), [`bitknn.WideModel`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel) (*N* * 64 bits) -- **Approximate (ANN)** models: [`lsh.Model`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model) (64 bits), [`lsh.WideModel`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel) (*N* * 64 bits) +- [`bitknn.Model`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model) (64 bits) +- [`bitknn.WideModel`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel) (*N* * 64 bits) ### Basic usage @@ -86,76 +84,14 @@ func main() { } ``` -### LSH - -Locality-Sensitive Hashing (LSH) is a type of approximate k-NN search. It's faster at the expense of accuracy. - -LSH works by hashing data points such that points that are close in Hamming space tend to land in the same bucket. In particular, for *k*=1 only one bucket needs to be examined. - -```go -package main - -import ( - "fmt" - "github.com/keilerkonzept/bitknn/lsh" - "github.com/keilerkonzept/bitknn" -) - -func main() { - // feature vectors packed into uint64s - data := []uint64{0b101010, 0b111000, 0b000111} - // class labels - labels := []int{0, 1, 1} - - // Define a hash function (e.g., MinHash) - hash := lsh.RandomMinHash() - - // Fit an LSH model - model := lsh.Fit(data, labels, hash, bitknn.WithLinearDistanceWeighting()) - - // one vote counter per class - votes := make([]float64, 2) - - k := 2 - model.Predict(k, 0b101011, bitknn.VoteSlice(votes)) - // or, just return the nearest neighbor's distances and indices: - // distances,indices := model.Find(k, 0b101011) - - fmt.Println("Votes:", votes) - - // you can also use a map for the votes - votesMap := make(map[int]float64) - model.Predict(k, 0b101011, bitknn.VoteMap(votesMap)) - fmt.Println("Votes for 0:", votesMap[0]) -} -``` - -The model accepts anything that implements the [`lsh.Hash` interface](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Hash) as a hash function. Several functions are pre-defined: - -- [MinHash](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHash): An implementation of the [MinHash scheme](https://en.m.wikipedia.org/wiki/MinHash) for bit vectors. - - Constructors: [RandomMinHash](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHash), [RandomMinHashR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashR). -- [MinHashes](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHash): Concatenation of several *MinHash*es. - - Constructors: [RandomMinHashes](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashes), [RandomMinHashesR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashesR). -- [Blur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Blur): A threshold-based variation on bit sampling. - - Constructors: [RandomBlur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBlur), [RandomBlurR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBlurR), [BoxBlur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BoxBlur), . -- [BitSample](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BitSample): A random sampling of bits from the feature vector. - - Constructors: [RandomBitSample](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBitSample), [RandomBitSampleR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBitSampleR). - -For datasets of vectors longer than 64 bits, the `lsh` package also provides a [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide) function, and "wide" versions of the hash functions ([MinHashWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHashWide), [BlurWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BlurWide), [BitSampleWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BitSampleWide)) - -The [`lsh.Fit`/`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit) functions accept the same [Options](#options) as the others. - ### Packing wide data If your vectors are longer than 64 bits, you can still use `bitknn` if you [pack](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) them into `[]uint64`. The [`pack` package](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) defines helper functions to pack `string`s and `[]byte`s into `[]uint64`s. > It's faster to use a `[][]uint64` allocated using a flat backing slice, laid out in one contiguous memory block. If you already have a non-contiguous `[][]uint64`, you can use [`pack.ReallocateFlat`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack#ReallocateFlat) to re-allocate the dataset using a flat 1d backing slice. -The exact k-NN model in `bitknn` and the approximate-NN model in `lsh` each have a `Wide` variant that accepts slice-valued data points: +The wide model fitting function is [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) and accepts the same [Options](#options) as the "narrow" one: + ```go package main @@ -178,8 +114,6 @@ func main() { labels := []int{0, 1, 1} model := bitknn.FitWide(data, labels, bitknn.WithLinearDistanceWeighting()) - // also using LSH: - // model := lsh.FitWide(data, labels, lsh.RandomMinHash(), bitknn.WithLinearDistanceWeighting()) // one vote counter per class votes := make([]float64, 2) @@ -192,8 +126,6 @@ func main() { } ``` -The wide model fitting function [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) accepts the same [Options](#options) as the "narrow" one. - ### ARM64 NEON Support For ARM64 CPUs with NEON instructions, `bitknn` has a [vectorized distance function for `[]uint64s`s](internal/neon/distance_arm64.s) that is about twice as fast as what the compiler generates. diff --git a/example_test.go b/example_test.go index 44809fe..ee86859 100644 --- a/example_test.go +++ b/example_test.go @@ -46,8 +46,6 @@ func ExampleFitWide() { labels := []int{0, 1, 1} model := bitknn.FitWide(data, labels, bitknn.WithLinearDistanceWeighting()) - // also using LSH: - // model := lsh.FitWide(data, labels, lsh.RandomMinHash(), bitknn.WithLinearDistanceWeighting()) // one vote counter per class votes := make([]float64, 2) diff --git a/lsh/example_test.go b/lsh/example_test.go deleted file mode 100644 index 69edb4a..0000000 --- a/lsh/example_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package lsh_test - -import ( - "fmt" - - "github.com/keilerkonzept/bitknn" - "github.com/keilerkonzept/bitknn/lsh" -) - -func Example() { - // feature vectors packed into uint64s - data := []uint64{0b101010, 0b111000, 0b000111} - // class labels - labels := []int{0, 1, 1} - - // Define a hash function - hash := lsh.BitSample(0xF0F0F0) - - // Fit an LSH model - model := lsh.Fit(data, labels, hash, bitknn.WithLinearDistanceWeighting()) - - // one vote counter per class - votes := make([]float64, 2) - - k := 2 - model.Predict(k, 0b101011, bitknn.VoteSlice(votes)) - - fmt.Println("Votes:", bitknn.VoteSlice(votes)) - - // you can also use a map for the votes - votesMap := make(map[int]float64) - model.Predict(k, 0b101011, bitknn.VoteMap(votesMap)) - fmt.Println("Votes for 0:", votesMap[0]) - // Output: - // Votes: [0.5 0.25] - // Votes for 0: 0.5 -} diff --git a/lsh/hashes.go b/lsh/hashes.go deleted file mode 100644 index b8da553..0000000 --- a/lsh/hashes.go +++ /dev/null @@ -1,511 +0,0 @@ -package lsh - -import ( - "math/bits" - "math/rand/v2" -) - -// Hash is a uint64 hash function for the lsh package. -type Hash interface { - Hash1(uint64) uint64 - Hash(in []uint64, out []uint64) -} - -// HashWide is a []uint64 hash function for the lsh package. -type HashWide interface { - Hash1Wide([]uint64) uint64 - HashWide(in [][]uint64, out []uint64) -} - -// HashFunc is a function type that implements the Hash interface. -type HashFunc func(uint64) uint64 - -// Hash1 applies the function to a single uint64 value. -func (me HashFunc) Hash1(x uint64) uint64 { return me(x) } - -// Hash applies the function to a slice of uint64 values. -func (me HashFunc) Hash(data []uint64, out []uint64) { - for i, d := range data { - out[i] = me(d) - } -} - -// HashWide1 is a HashWide that applies a Hash only to the first dimension. -type HashWide1 struct { - Single Hash -} - -// Hash1 applies the function to a single uint64 value. -func (me *HashWide1) Hash1Wide(x []uint64) uint64 { - return me.Single.Hash1(x[0]) -} - -// Hash applies the function to a slice of uint64 values. -func (me *HashWide1) HashWide(data [][]uint64, out []uint64) { - for i, d := range data { - out[i] = me.Hash1Wide(d) - } -} - -// HashCompose is the composition of several hash functions. -type HashCompose []Hash - -// Hash1 applies the function to a single uint64 value. -func (me HashCompose) Hash1(x uint64) uint64 { - for _, h := range me { - x = h.Hash1(x) - } - return x -} - -// Hash applies the function to a slice of uint64 values. -func (me HashCompose) Hash(data []uint64, out []uint64) { - for _, h := range me { - h.Hash(data, out) - data = out - } -} - -// NoHash is the identity function. Used as a dummy [Hash] for testing. -type NoHash struct{} - -// Hash1 returns the given value. -func (me NoHash) Hash1(x uint64) uint64 { return x } - -// Hash copies the input slice to the output slice. -func (me NoHash) Hash(data []uint64, out []uint64) { - copy(out, data) -} - -// ConstantHash is a constant 0 function. Used as a dummy [Hash] for testing. -type ConstantHash struct{} - -// Hash1 returns the given value. -func (me ConstantHash) Hash1(x uint64) uint64 { return 0 } - -// Hash1 returns the given value. -func (me ConstantHash) Hash1Wide(x []uint64) uint64 { return 0 } - -// Hash clears the output slice. -func (me ConstantHash) Hash(data []uint64, out []uint64) { - clear(out) -} - -// Hash clears the output slice. -func (me ConstantHash) HashWide(data [][]uint64, out []uint64) { - clear(out) -} - -// MinHashes is a concatenation of [MinHash]es -type MinHashes []MinHash - -// RandomMinHashes creates n random MinHash functions. -func RandomMinHashesR(n int, rand *rand.Rand) MinHashes { - out := make([]MinHash, n) - for i := range out { - out[i] = RandomMinHashR(rand) - } - return out -} - -// RandomMinHashes creates [n] random MinHashes. -func RandomMinHashes(n int) MinHashes { - out := make([]MinHash, n) - for i := range out { - out[i] = RandomMinHash() - } - return out -} - -// Hash1 applies each MinHash to the given uint64 value and concatenates the MinHash bits. -func (me MinHashes) Hash1(x uint64) uint64 { - var out uint64 - for _, h := range me { - out <<= 6 // log2(64) - out |= h.Hash1(x) - } - return out -} - -// Has1 applies the MinHashes to each uint64 value in the slice. -func (me MinHashes) Hash(data []uint64, out []uint64) { - for i, d := range data { - var m uint64 - for _, h := range me { - m <<= 6 // log2(64) - m |= h.Hash1(d) - } - out[i] = m - } -} - -// MinHash is a MinHash function for Hamming space. -type MinHash []uint64 - -// RandomMinHashR returns a random [MinHash]. -func RandomMinHashR(rand *rand.Rand) MinHash { - ones := rand.Perm(64) - out := make([]uint64, 64) - for i := range out { - out[i] = 1 << ones[i] - } - return out -} - -// RandomMinHash returns a random [MinHash]. -func RandomMinHash() MinHash { - ones := rand.Perm(64) - out := make([]uint64, 64) - for i := range out { - out[i] = 1 << ones[i] - } - return out -} - -// Hash1 hashes a single uint64 value. -func (me MinHash) Hash1(x uint64) uint64 { - for j, m := range me { - if (x & m) != 0 { - return uint64(j) - } - } - return 0 // never reached -} - -// Hash hashes a slice of uint64 values. -func (me MinHash) Hash(data []uint64, out []uint64) { - for i, d := range data { - for j, m := range me { - if (d & m) != 0 { - out[i] = uint64(j) - break - } - } - } -} - -// Blur hashes values based on thresholding the number of bits in common with the given bitmasks. -// For bitmasks of consecutive set bits, this is in effect a "blur" of the bit vector. -type Blur struct { - Masks []uint64 // Bitmasks - Threshold int // Minimum number of common bits required to set the output bit -} - -// Hash1 hashes a single uint64 value. -func (me Blur) Hash1(x uint64) uint64 { - var bx uint64 - for _, b := range me.Masks { - bx <<= 1 - if bits.OnesCount64(x&b) >= me.Threshold { - bx |= 1 - } - } - return bx -} - -// Hash hashes a slice of uint64 values. -func (me Blur) Hash(data []uint64, out []uint64) { - for i, d := range data { - var bx uint64 - for _, b := range me.Masks { - bx <<= 1 - if bits.OnesCount64(d&b) >= me.Threshold { - bx |= 1 - } - } - out[i] = bx - } -} - -// RandomBlurR generates a Blur of [n] bitmasks with the given number [numBits] of set bits. -func RandomBlurR(numBits int, n int, rand *rand.Rand) Blur { - bits := make([]uint64, n) - threshold := numBits/2 + 1 - for i := range n { - b := uint64(RandomBitSampleR(numBits, rand)) - bits[i] = b - } - return Blur{ - Masks: bits, - Threshold: threshold, - } -} - -// RandomBlur generates a Blur of [n] bitmasks with the given number [numBits] of set bits. -func RandomBlur(numBits int, n int) Blur { - bits := make([]uint64, n) - threshold := numBits/2 + 1 - for i := range n { - b := uint64(RandomBitSample(numBits)) - bits[i] = b - } - return Blur{ - Masks: bits, - Threshold: threshold, - } -} - -// BitSample is a random sample of bits in a uint64 value. -// Only the bits set in the BitSample are kept. -type BitSample uint64 - -// Hash1 hashes a single uint64 value. -func (me BitSample) Hash1(x uint64) uint64 { - return x & uint64(me) -} - -// Hash hashes a slice of uint64 values. -func (me BitSample) Hash(data []uint64, out []uint64) { - for i, d := range data { - out[i] = d & uint64(me) - } -} - -// RandomBitSample generates a BitSample with a specified number of bits set to 1. -func RandomBitSample(numBitsSet int) BitSample { - ones := rand.Perm(64) - var out uint64 - for i := 0; i < numBitsSet; i++ { - out |= uint64(1) << ones[i] - } - return BitSample(out) -} - -// RandomBitSample generates a BitSample with a specified number of bits set to 1. -func RandomBitSampleR(numBitsSet int, rand *rand.Rand) BitSample { - ones := rand.Perm(64) - var out uint64 - for i := 0; i < numBitsSet; i++ { - out |= uint64(1) << ones[i] - } - return BitSample(out) -} - -// BoxBlur generates a Blur that averages groups of neighboring bits for each bit in the output. -func BoxBlur(radius int, step int) Blur { - mask := uint64(1<= me.Threshold { - bx |= 1 - } - } - return bx -} - -// Hash hashes a slice of uint64 values. -func (me BlurWide) HashWide(data [][]uint64, out []uint64) { - for i, d := range data { - var bx uint64 - for _, bs := range me.Masks { - bx <<= 1 - count := 0 - for _, b := range bs { - count += bits.OnesCount64(d[b.D] & b.M) - } - if count >= me.Threshold { - bx |= 1 - } - } - out[i] = bx - } -} - -// RandomBlurWideR generates a BlurWide of [n] bitmasks with the given number [numBits] of set bits. -func RandomBlurWideR(dims int, numBits int, n int, rand *rand.Rand) BlurWide { - masks := make([][]BlurWideMask, n) - threshold := numBits/2 + 1 - for i := range n { - mask := make([]BlurWideMask, 0, numBits) - ones := rand.Perm(64 * dims)[:numBits] - for _, bitIndex := range ones { - dim := bitIndex / 64 - bit := bitIndex % 64 - mask = append(mask, BlurWideMask{ - D: dim, - M: 1 << bit, - }) - } - masks[i] = mask - } - return BlurWide{ - Masks: masks, - Threshold: threshold, - } -} - -// RandomBlurWide generates a BlurWide of [n] bitmasks with the given number [numBits] of set bits. -func RandomBlurWide(dims int, numBits int, n int) BlurWide { - masks := make([][]BlurWideMask, n) - threshold := numBits/2 + 1 - for i := range n { - mask := make([]BlurWideMask, 0, numBits) - ones := rand.Perm(64 * dims)[:numBits] - for _, bitIndex := range ones { - dim := bitIndex / 64 - bit := bitIndex % 64 - mask = append(mask, BlurWideMask{ - D: dim, - M: 1 << bit, - }) - } - masks[i] = mask - } - return BlurWide{ - Masks: masks, - Threshold: threshold, - } -} diff --git a/lsh/hashes_test.go b/lsh/hashes_test.go deleted file mode 100644 index 34821b6..0000000 --- a/lsh/hashes_test.go +++ /dev/null @@ -1,460 +0,0 @@ -package lsh_test - -import ( - "math/bits" - "reflect" - "testing" - - "github.com/keilerkonzept/bitknn/internal/testrandom" - "github.com/keilerkonzept/bitknn/lsh" - "pgregory.net/rapid" -) - -func TestHashCompose(t *testing.T) { - h1 := lsh.RandomBlurR(3, 20, testrandom.Source) - h2 := lsh.RandomMinHash() - h := lsh.HashCompose{h1, h2} - rapid.Check(t, func(t *rapid.T) { - q := rapid.Uint64().Draw(t, "q") - qs := rapid.SliceOf(rapid.Uint64()).Draw(t, "qs") - if h.Hash1(q) != h2.Hash1(h1.Hash1(q)) { - t.Fatal() - } - out12 := make([]uint64, len(qs)) - out := make([]uint64, len(qs)) - h.Hash(qs, out) - h1.Hash(qs, out12) - h2.Hash(out12, out12) - if !reflect.DeepEqual(out, out12) { - t.Fatal() - } - }) -} - -func TestMinHash(t *testing.T) { - t.Run("RandomMinHash", func(t *testing.T) { - h := lsh.RandomMinHash() - if len(h) != 64 { - t.Errorf("RandomMinHash() returned slice of length %d; want 64", len(h)) - } - - // Check that all bit positions are represented - var allBits uint64 - for _, m := range h { - allBits |= m - } - if allBits != ^uint64(0) { - t.Errorf("RandomMinHash() doesn't cover all bit positions") - } - }) - - t.Run("MinHash_Hash1", func(t *testing.T) { - h := lsh.RandomMinHash() - testCases := []struct { - input uint64 - }{ - {0b1000}, - {0b1100}, - {0b1111}, - } - - for _, tc := range testCases { - got := h.Hash1(tc.input) - if got >= 64 { - t.Errorf("MinHash.Hash() returned %d for input %b; want value < 64", got, tc.input) - } - } - }) - - t.Run("MinHash_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - data := rapid.SliceOf(rapid.Uint64()).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomMinHash() - h.Hash(data, out) - for i, d := range data { - if out[i] != h.Hash1(d) { - t.Fatal() - } - } - }) - }) - - t.Run("MinHash_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomMinHash() - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} - -func TestBlur(t *testing.T) { - t.Run("Blur_Hash1", func(t *testing.T) { - h := &lsh.Blur{ - Masks: []uint64{0xF0F0F0F0, 0x0F0F0F0F}, - Threshold: 4, - } - - testCases := []struct { - input uint64 - want uint64 - }{ - {0xFFFFFFFF, 3}, - {0xF0F0F0F0, 2}, - {0x0F0F0F0F, 1}, - {0x00000000, 0}, - } - - for _, tc := range testCases { - got := h.Hash1(tc.input) - if got != tc.want { - t.Errorf("Blur.Hash1(%x) = %d; want %d", tc.input, got, tc.want) - } - } - }) - - t.Run("Blur_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - bits := rapid.IntRange(0, 64).Draw(t, "bits") - masks := rapid.IntRange(0, 10).Draw(t, "masks") - data := rapid.SliceOf(rapid.Uint64()).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomBlur(bits, masks) - h.Hash(data, out) - for i, d := range data { - if out[i] != h.Hash1(d) { - t.Fatal() - } - } - }) - }) - - t.Run("BoxBlur", func(t *testing.T) { - trials := 1000 - yCloser := 0 - zCloser := 0 - for range trials { - n := testrandom.Source.IntN(32) - dist := func(x, y uint64) int { - return bits.OnesCount64(x ^ y) - } - flipNBits := uint64(lsh.RandomBitSampleR(n, testrandom.Source)) - flip2NBits := uint64(lsh.RandomBitSampleR(2*n, testrandom.Source)) - x := testrandom.Query() - y := x ^ flipNBits - z := x ^ flip2NBits - h := lsh.BoxBlur(3, 3) - dy := dist(h.Hash1(x), h.Hash1(y)) - dz := dist(h.Hash1(x), h.Hash1(z)) - if dy < dz { - yCloser++ - } - if dy > dz { - zCloser++ - } - } - - if zCloser > yCloser { - t.Errorf("Expected Hash1(x) to be closer to Hash1(y) more often than Hash1(x) to be closer to Hash1(z), got %d and %d", yCloser, zCloser) - } - }) - - t.Run("Blur_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBlur(3, 50) - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) - - t.Run("BlurR_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 10_000 - - for range trials { - h := lsh.RandomBlurR(3, 50, testrandom.Source) - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} - -func TestBitSample(t *testing.T) { - t.Run("RandomBitSample", func(t *testing.T) { - for _, numBits := range []int{1, 32, 63} { - h := lsh.RandomBitSample(numBits) - count := bits.OnesCount64(uint64(h)) - if count != numBits { - t.Errorf("RandomBitSample(%d) set %d bits; want %d", numBits, count, numBits) - } - } - }) - - t.Run("BitSample_Hash1", func(t *testing.T) { - h := lsh.BitSample(0xF0F0F0F0) - - testCases := []struct { - input uint64 - want uint64 - }{ - {0xFFFFFFFF, 0xF0F0F0F0}, - {0x0F0F0F0F, 0x00000000}, - {0xAAAAAAAA, 0xA0A0A0A0}, - } - - for _, tc := range testCases { - got := h.Hash1(tc.input) - if got != tc.want { - t.Errorf("BitSample.Hash1(%x) = %x; want %x", tc.input, got, tc.want) - } - } - }) - - t.Run("BitSample_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - bits := rapid.IntRange(0, 64).Draw(t, "bits") - data := rapid.SliceOf(rapid.Uint64()).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomBitSample(bits) - h.Hash(data, out) - for i, d := range data { - if out[i] != h.Hash1(d) { - t.Fatal() - } - } - }) - }) - - t.Run("BitSample_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBitSample(48) - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) - - t.Run("BitSampleR_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBitSampleR(48, testrandom.Source) - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} - -func TestMinHashes(t *testing.T) { - t.Run("RandomMinHashes", func(t *testing.T) { - h := lsh.RandomMinHashes(3) - if len(h) != 3 { - t.Errorf("RandomMinHashes(3) returned slice of length %d; want 3", len(h)) - } - for _, h := range h { - if len(h) != 64 { - t.Errorf("RandomMinHashes(3) contains MinHash of length %d; want 64", len(h)) - } - } - }) - - t.Run("MinHashes_Hash1", func(t *testing.T) { - h := lsh.RandomMinHashes(3) - input := uint64(0xFFFFFFFF) - output := h.Hash1(input) - if output >= 1<<18 { // 3 * 6 bits - t.Errorf("MinHashes.Hash1(%x) = %d; want value < %d", input, output, 1<<18) - } - }) - - t.Run("MinHashes_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - masks := rapid.IntRange(0, 10).Draw(t, "masks") - data := rapid.SliceOf(rapid.Uint64()).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomMinHashes(masks) - h.Hash(data, out) - for i, d := range data { - if out[i] != h.Hash1(d) { - t.Fatal() - } - } - }) - }) - - t.Run("MinHashes_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomMinHashes(3) - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) - - t.Run("MinHashesR_LS_Property", func(t *testing.T) { - x := uint64(0b1110) - y := uint64(0b1100) - z := uint64(0b0001) - - xyEqual := 0 - xzEqual := 0 - trials := 10_000 - - for range trials { - h := lsh.RandomMinHashesR(3, testrandom.Source) - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ - } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} - -func TestHashFunc(t *testing.T) { - trials := 1000 - - for range trials { - data := testrandom.Data(16) - h := lsh.RandomBitSample(10) - hf := lsh.HashFunc(h.Hash1) - outh := make([]uint64, len(data)) - outhf := make([]uint64, len(data)) - x := testrandom.Query() - if h.Hash1(x) != hf.Hash1(x) { - t.Fatal() - } - h.Hash(data, outh) - hf.Hash(data, outhf) - if !reflect.DeepEqual(outh, outhf) { - t.Fatal() - } - - } -} - -func TestDummyHashes(t *testing.T) { - t.Run("NoHash", func(t *testing.T) { - var h lsh.NoHash - query := uint64(0x12345) - data := []uint64{0x12345, 0x54321} - out := make([]uint64, len(data)) - if h.Hash1(query) != query { - t.Fatal() - } - h.Hash(data, out) - if !reflect.DeepEqual(data, out) { - t.Fatal() - } - }) - t.Run("ConstantHash", func(t *testing.T) { - var h lsh.ConstantHash - q := uint64(0x12345) - data := []uint64{0x12345, 0x54321} - out := make([]uint64, len(data)) - if h.Hash1(q) != 0 { - t.Fatal() - } - h.Hash(data, out) - for i := range out { - if out[i] != 0 { - t.Fatal() - } - } - }) -} diff --git a/lsh/hashes_wide_test.go b/lsh/hashes_wide_test.go deleted file mode 100644 index e469d50..0000000 --- a/lsh/hashes_wide_test.go +++ /dev/null @@ -1,271 +0,0 @@ -package lsh_test - -import ( - "math/bits" - "reflect" - "testing" - - "github.com/keilerkonzept/bitknn/internal/testrandom" - "github.com/keilerkonzept/bitknn/lsh" - "pgregory.net/rapid" -) - -func TestBlurWide(t *testing.T) { - t.Run("BlurWide_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - dims := rapid.IntRange(1, 4).Draw(t, "dims") - data := rapid.SliceOf(rapid.SliceOfN(rapid.Uint64(), dims, dims)).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomBlurWide(dims, 4, 3) - h.HashWide(data, out) - for i, d := range data { - if out[i] != h.Hash1Wide(d) { - t.Fatal() - } - } - }) - }) - t.Run("RandomBlurWide", func(t *testing.T) { - blur := lsh.RandomBlurWide(2, 4, 3) - - if len(blur.Masks) != 3 { - t.Errorf("RandomBlurWide returned BlurWide with %d masks, expected 3", len(blur.Masks)) - } - - for i, mask := range blur.Masks { - totalBits := 0 - for _, b := range mask { - totalBits += bits.OnesCount64(b.M) - } - if totalBits != 4 { - t.Errorf("Mask %d has %d bits set, expected 4", i, totalBits) - } - } - - if blur.Threshold != 3 { - t.Errorf("RandomBlurWide returned BlurWide with threshold %d, expected 3", blur.Threshold) - } - }) - - t.Run("BlurWide_Hamming_LS_Property", func(t *testing.T) { - x := []uint64{0b11, 0b10} - y := []uint64{0b11, 0b00} - z := []uint64{0b00, 0b01} - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBlurWide(2, 3, 100) - if h.Hash1Wide(x) == h.Hash1Wide(y) { - xyEqual++ - } - if h.Hash1Wide(x) == h.Hash1Wide(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) - - t.Run("BlurWideR_Hamming_LS_Property", func(t *testing.T) { - x := []uint64{0b11, 0b10} - y := []uint64{0b11, 0b00} - z := []uint64{0b00, 0b01} - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBlurWideR(2, 3, 100, testrandom.Source) - if h.Hash1Wide(x) == h.Hash1Wide(y) { - xyEqual++ - } - if h.Hash1Wide(x) == h.Hash1Wide(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} - -func TestMinHashWide(t *testing.T) { - t.Run("MinHashWide_Hash1", func(t *testing.T) { - h := lsh.RandomMinHashWide(3) - input := []uint64{0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF} - output := h.Hash1Wide(input) - if output >= 1<<6 { // 6 bits - t.Errorf("MinHashWide.Hash1Wide(%x) = %d; want value < %d", input, output, 1<<18) - } - }) - - t.Run("MinHashWide_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - dims := rapid.IntRange(1, 4).Draw(t, "dims") - data := rapid.SliceOf(rapid.SliceOfN(rapid.Uint64(), dims, dims)).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomMinHashWide(dims) - h.HashWide(data, out) - for i, d := range data { - if out[i] != h.Hash1Wide(d) { - t.Fatal() - } - } - }) - }) - - t.Run("MinHashWide_Hamming_LS_Property", func(t *testing.T) { - x := []uint64{0b1110, 0b1110, 0b1100} - y := []uint64{0b1100, 0b1100, 0b1110} - z := []uint64{0b0001, 0b0001, 0b1100} - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomMinHashWide(3) - if h.Hash1Wide(x) == h.Hash1Wide(y) { - xyEqual++ - } - if h.Hash1Wide(x) == h.Hash1Wide(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) - - t.Run("MinHashWideR_Hamming_LS_Property", func(t *testing.T) { - x := []uint64{0b1110, 0b1110, 0b1100} - y := []uint64{0b1100, 0b1100, 0b1110} - z := []uint64{0b0001, 0b0001, 0b1100} - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomMinHashWideR(3, testrandom.Source) - if h.Hash1Wide(x) == h.Hash1Wide(y) { - xyEqual++ - } - if h.Hash1Wide(x) == h.Hash1Wide(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} - -func TestBitSampleWide(t *testing.T) { - t.Run("RandomBitSampleWide", func(t *testing.T) { - for _, numBits := range []int{1, 32, 63} { - h := lsh.RandomBitSampleWide(3, numBits) - count := 0 - for _, ms := range h { - for _, m := range ms { - count += bits.OnesCount64(m) - } - } - if count != numBits { - t.Errorf("RandomBitSampleWide(%d) set %d bits; want %d", numBits, count, numBits) - } - } - }) - - t.Run("BitSampleWide_Hash1", func(t *testing.T) { - h := lsh.BitSampleWide{{1 << 15, 1 << 0}} - - testCases := []struct { - input []uint64 - want uint64 - }{ - {[]uint64{0xFFFFFFFF}, 3}, - {[]uint64{0x0F0F0F0F}, 1}, - {[]uint64{0xAAAAAAAA}, 2}, - } - - for _, tc := range testCases { - got := h.Hash1Wide(tc.input) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("BitSampleWide.Hash1(%x) = %x; want %x", tc.input, got, tc.want) - } - } - }) - - t.Run("BitSampleWide_Hash1_Equiv_Hash", func(t *testing.T) { - rapid.Check(t, func(t *rapid.T) { - dims := rapid.IntRange(1, 4).Draw(t, "dims") - data := rapid.SliceOf(rapid.SliceOfN(rapid.Uint64(), dims, dims)).Draw(t, "data") - out := make([]uint64, len(data)) - h := lsh.RandomBitSampleWide(dims, 10) - h.HashWide(data, out) - for i, d := range data { - if out[i] != h.Hash1Wide(d) { - t.Fatal() - } - } - }) - }) - - t.Run("BitSampleWide_Hamming_LS_Property", func(t *testing.T) { - x := []uint64{0b1110, 0b1110, 0b1100} - y := []uint64{0b1100, 0b1100, 0b1110} - z := []uint64{0b0001, 0b0001, 0b1100} - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBitSampleWide(3, 48) - if h.Hash1Wide(x) == h.Hash1Wide(y) { - xyEqual++ - } - if h.Hash1Wide(x) == h.Hash1Wide(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) - - t.Run("BitSampleWideR_Hamming_LS_Property", func(t *testing.T) { - x := []uint64{0b1110, 0b1110, 0b1100} - y := []uint64{0b1100, 0b1100, 0b1110} - z := []uint64{0b0001, 0b0001, 0b1100} - - xyEqual := 0 - xzEqual := 0 - trials := 1000 - - for range trials { - h := lsh.RandomBitSampleWideR(3, 48, testrandom.Source) - if h.Hash1Wide(x) == h.Hash1Wide(y) { - xyEqual++ - } - if h.Hash1Wide(x) == h.Hash1Wide(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) - } - }) -} diff --git a/lsh/model.go b/lsh/model.go deleted file mode 100644 index e35f1fd..0000000 --- a/lsh/model.go +++ /dev/null @@ -1,115 +0,0 @@ -// Package lsh implements Locality-Sensitive Hashing (LSH) for efficient approximate nearest neighbor search in Hamming space. -// -// This package also provides several hash functions for use with binary feature vectors (`uint64`), as well as `Wide` variants of the hash functions that work with the `[]uint64`s and the [WideModel] model: -// -// - [MinHash]: A hashing scheme for similarity search based on common bits. -// - [Blur]: Hashes values based on thresholding the number of bits in common with predefined bitmasks. -// - [BitSample]: A random sampling of bits in a feature vector. -package lsh - -import ( - "cmp" - "slices" - - "github.com/keilerkonzept/bitknn" - "github.com/keilerkonzept/bitknn/internal/slice" -) - -// Model is an LSH k-NN model, mapping points to buckets based on a locality-sensitive hash function. -type Model struct { - *bitknn.Model - Hash Hash // LSH function mapping points to bucket IDs. - - BucketIDs []uint64 // Bucket IDs. - Buckets map[uint64]slice.IndexRange // Bucket contents for each hash (offset+length in Data). - - HeapBucketDistances []int - HeapBucketIDs []uint64 -} - -// PreallocateHeap allocates memory for the nearest neighbor heap. -func (me *Model) PreallocateHeap(k int) { - me.HeapBucketDistances = slice.OrAlloc(me.HeapBucketDistances, k+1) - me.HeapBucketIDs = slice.OrAlloc(me.HeapBucketIDs, k+1) - me.Model.PreallocateHeap(k) -} - -// Fit creates and fits an LSH k-NN model using the provided data, labels, and hash function. -// It groups points into buckets using the LSH hash function. -func Fit(data []uint64, labels []int, hash Hash, opts ...bitknn.Option) *Model { - knnModel := bitknn.Fit(data, labels, opts...) - values := knnModel.Values - buckets := make([]uint64, len(data)) - hash.Hash(data, buckets) - - indices := make([]int, len(data)) - for i := range indices { - indices[i] = i - } - - // Sort data by bucket id so that each bucket's data slice is contiguous. slices.SortStableFunc(indices, func(a, b int) int { - slices.SortStableFunc(indices, func(a, b int) int { - return cmp.Compare(buckets[a], buckets[b]) - }) - - // Reorder all data-indexed slices to match the bucket sort order. - slice.ReorderInPlace(func(i, j int) { - buckets[i], buckets[j] = buckets[j], buckets[i] - data[i], data[j] = data[j], data[i] - labels[i], labels[j] = labels[j], labels[i] - if values != nil { - values[i], values[j] = values[j], values[i] - } - }, indices) - - bucketData, bucketIDs := slice.GroupSorted(data, buckets) - - return &Model{ - Model: knnModel, - Hash: hash, - BucketIDs: bucketIDs, - Buckets: bucketData, - } -} - -// Finds the nearest neighbors of the given point. -// Writes their distances and indices in the dataset into the pre-allocated slices. -// Returns the distance and index slices, truncated to the actual number of neighbors found. -func (me *Model) Find(k int, x uint64) ([]int, []int) { - me.PreallocateHeap(k) - return me.FindInto(k, x, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices) -} - -// Finds the nearest neighbors of the given point. -// Writes their distances and indices in the dataset into the provided slices. -// The slices should be pre-allocated to length k+1. -// Returns the distance and index slices, truncated to the actual number of neighbors found. -func (me *Model) FindInto(k int, x uint64, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) ([]int, []int) { - xp := me.Hash.Hash1(x) - k, _ = Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) - return distances[:k], indices[:k] -} - -// Predict predicts the label for a single input using the LSH model. -func (me *Model) Predict(k int, x uint64, votes bitknn.VoteCounter) int { - me.PreallocateHeap(k) - return me.PredictInto(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices) -} - -// Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps. -func (me *Model) PredictAlloc(k int, x uint64, votes bitknn.VoteCounter) int { - bucketDistances := make([]int, k+1) - bucketIDs := make([]uint64, k+1) - distances := make([]int, k+1) - indices := make([]int, k+1) - - return me.PredictInto(k, x, votes, bucketDistances, bucketIDs, distances, indices) -} - -// PredictInto predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. -func (me *Model) PredictInto(k int, x uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { - xp := me.Hash.Hash1(x) - k, n := Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) - me.Vote(k, distances, indices, votes) - return n -} diff --git a/lsh/model_bench_test.go b/lsh/model_bench_test.go deleted file mode 100644 index 0b9cef8..0000000 --- a/lsh/model_bench_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package lsh_test - -import ( - "fmt" - "testing" - - "github.com/keilerkonzept/bitknn" - "github.com/keilerkonzept/bitknn/internal/testrandom" - "github.com/keilerkonzept/bitknn/lsh" -) - -func BenchmarkModel(b *testing.B) { - type bench struct { - hashes []lsh.Hash - dataSize []int - k []int - } - hashes := []lsh.Hash{ - lsh.ConstantHash{}, // should be only a bit slower than exact KNN - } - benches := []bench{ - {hashes: hashes, dataSize: []int{100}, k: []int{1, 3, 10}}, - {hashes: hashes, dataSize: []int{1_000_000}, k: []int{3, 10, 100}}, - } - for _, bench := range benches { - for _, dataSize := range bench.dataSize { - data := testrandom.Data(dataSize) - labels := testrandom.Labels(dataSize) - query := testrandom.Query() - for _, k := range bench.k { - for _, hash := range bench.hashes { - b.Run(fmt.Sprintf("Op=Predict_hash=%T_N=%d_k=%d", hash, dataSize, k), func(b *testing.B) { - model := lsh.Fit(data, labels, hash) - model.PreallocateHeap(k) - b.ResetTimer() - for n := 0; n < b.N; n++ { - model.Predict(k, query, bitknn.DiscardVotes) - } - }) - b.Run(fmt.Sprintf("Op=Find_hash=%T_N=%d_k=%d", hash, dataSize, k), func(b *testing.B) { - model := lsh.Fit(data, labels, hash) - model.PreallocateHeap(k) - b.ResetTimer() - for n := 0; n < b.N; n++ { - model.Find(k, query) - } - }) - } - } - } - } -} diff --git a/lsh/model_test.go b/lsh/model_test.go deleted file mode 100644 index 15e6f3d..0000000 --- a/lsh/model_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package lsh_test - -import ( - "math" - "reflect" - "slices" - "testing" - - "github.com/keilerkonzept/bitknn" - "github.com/keilerkonzept/bitknn/lsh" - "pgregory.net/rapid" -) - -func Test_Model_NoHash_IsExact(t *testing.T) { - var h lsh.NoHash - var h0 lsh.ConstantHash - id := func(a uint64) uint64 { return a } - rapid.Check(t, func(t *rapid.T) { - k := rapid.IntRange(1, 1001).Draw(t, "k") - data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data") - labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels") - values := rapid.SliceOfN(rapid.Float64(), len(data), len(data)).Draw(t, "values") - queries := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 64, id).Draw(t, "queries") - knnVotes := make([]float64, 4) - annVotes := make([]float64, 4) - type pair struct { - name string - KNN *bitknn.Model - ANN *lsh.Model - ANN0 *lsh.Model - } - pairs := []pair{ - { - "V", - bitknn.Fit(data, labels, bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithValues(values)), - }, - { - "LV", - bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), - }, - { - "QV", - bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), - }, - { - "CV", - bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), - }, - { - "0", - bitknn.Fit(data, labels), - lsh.Fit(data, labels, h), - lsh.Fit(data, labels, h0), - }, - { - "L", - bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting()), - lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting()), - lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()), - }, - { - "Q", - bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting()), - lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting()), - lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()), - }, - { - "C", - bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), - lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), - lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), - }, - } - const eps = 1e-8 - for _, pair := range pairs { - knn := pair.KNN - ann := pair.ANN - ann0 := pair.ANN0 - knn.PreallocateHeap(k) - ann.PreallocateHeap(k) - for _, q := range queries { - knn.Predict(k, q, bitknn.VoteSlice(knnVotes)) - ann.Predict(k, q, bitknn.VoteSlice(annVotes)) - slices.Sort(knn.HeapDistances[:k]) - slices.Sort(ann.HeapDistances[:k]) - if !reflect.DeepEqual(knn.HeapDistances[:k], ann.HeapDistances[:k]) { - t.Fatal("NoHash ANN should result in the same distances for the nearest neighbors: ", knn.HeapDistances[:k], ann.HeapDistances[:k], knn.HeapIndices[:k], ann.HeapIndices[:k]) - } - - kd, ki := knn.Find(k, q) - ad, ai := ann.Find(k, q) - slices.Sort(kd) - slices.Sort(ad) - if !reflect.DeepEqual(kd, ad) { - t.Fatal("NoHash ANN should result in the same distances for the nearest neighbors: ", kd, ad) - } - slices.Sort(ki) - slices.Sort(ai) - if !reflect.DeepEqual(ki, ai) { - t.Fatal("NoHash ANN should result in the same indices for the nearest neighbors: ", ki, ai) - } - - ann0.PredictAlloc(k, q, bitknn.VoteSlice(annVotes)) - for i, vk := range knnVotes { - va := annVotes[i] - if math.Abs(vk-va) > eps { - t.Fatalf("ANN: %s: %v: %v %v", pair.name, q, knnVotes, annVotes) - } - } - ann0.Predict(k, q, bitknn.VoteSlice(annVotes)) - for i, vk := range knnVotes { - va := annVotes[i] - if math.Abs(vk-va) > eps { - t.Fatalf("ANN0: %s: %v: %v %v", pair.name, q, knnVotes, annVotes) - } - } - } - } - - }) -} diff --git a/lsh/model_wide.go b/lsh/model_wide.go deleted file mode 100644 index 4092fb2..0000000 --- a/lsh/model_wide.go +++ /dev/null @@ -1,108 +0,0 @@ -package lsh - -import ( - "cmp" - "slices" - - "github.com/keilerkonzept/bitknn" - "github.com/keilerkonzept/bitknn/internal/slice" -) - -// WideModel is an LSH k-NN model, mapping points to buckets based on a locality-sensitive hash function. -type WideModel struct { - *bitknn.WideModel - Hash HashWide // LSH function mapping points to bucket IDs. - - BucketIDs []uint64 // Bucket IDs. - Buckets map[uint64]slice.IndexRange // Bucket contents for each hash (offset+length in Data). - - HeapBucketDistances []int - HeapBucketIDs []uint64 -} - -// PreallocateHeap allocates memory for the nearest neighbor heap. -func (me *WideModel) PreallocateHeap(k int) { - me.HeapBucketDistances = slice.OrAlloc(me.HeapBucketDistances, k+1) - me.HeapBucketIDs = slice.OrAlloc(me.HeapBucketIDs, k+1) - me.WideModel.PreallocateHeap(k) -} - -// Fit creates and fits an LSH k-NN model using the provided data, labels, and hash function. -// It groups points into buckets using the LSH hash function. -func FitWide(data [][]uint64, labels []int, hash HashWide, opts ...bitknn.Option) *WideModel { - knnModel := bitknn.FitWide(data, labels, opts...) - values := knnModel.Narrow.Values - buckets := make([]uint64, len(data)) - hash.HashWide(data, buckets) - - indices := make([]int, len(data)) - for i := range indices { - indices[i] = i - } - - // Sort data by bucket id so that each bucket's data slice is contiguous. slices.SortStableFunc(indices, func(a, b int) int { - slices.SortStableFunc(indices, func(a, b int) int { - return cmp.Compare(buckets[a], buckets[b]) - }) - - // Reorder all data-indexed slices to match the bucket sort order. - slice.ReorderInPlace(func(i, j int) { - buckets[i], buckets[j] = buckets[j], buckets[i] - data[i], data[j] = data[j], data[i] - labels[i], labels[j] = labels[j], labels[i] - if values != nil { - values[i], values[j] = values[j], values[i] - } - }, indices) - - bucketData, bucketIDs := slice.GroupSorted(data, buckets) - - return &WideModel{ - WideModel: knnModel, - Hash: hash, - BucketIDs: bucketIDs, - Buckets: bucketData, - } -} - -// Finds the nearest neighbors of the given point. -// Writes their distances and indices in the dataset into the pre-allocated slices. -// Returns the distance and index slices, truncated to the actual number of neighbors found. -func (me *WideModel) Find(k int, x []uint64) ([]int, []int) { - me.PreallocateHeap(k) - return me.FindInto(k, x, me.HeapBucketDistances, me.HeapBucketIDs, me.Narrow.HeapDistances, me.Narrow.HeapIndices) -} - -// Finds the nearest neighbors of the given point. -// Writes their distances and indices in the dataset into the provided slices. -// The slices should be pre-allocated to length k+1. -// Returns the distance and index slices, truncated to the actual number of neighbors found. -func (me *WideModel) FindInto(k int, x []uint64, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) ([]int, []int) { - xp := me.Hash.Hash1Wide(x) - k, _ = NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) - return distances[:k], indices[:k] -} - -// Predict predicts the label for a single input using the LSH model. -func (me *WideModel) Predict(k int, x []uint64, votes bitknn.VoteCounter) int { - me.PreallocateHeap(k) - return me.PredictInto(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.Narrow.HeapDistances, me.Narrow.HeapIndices) -} - -// Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps. -func (me *WideModel) PredictAlloc(k int, x []uint64, votes bitknn.VoteCounter) int { - bucketDistances := make([]int, k+1) - bucketIDs := make([]uint64, k+1) - distances := make([]int, k+1) - indices := make([]int, k+1) - - return me.PredictInto(k, x, votes, bucketDistances, bucketIDs, distances, indices) -} - -// PredictInto predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. -func (me *WideModel) PredictInto(k int, x []uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { - xp := me.Hash.Hash1Wide(x) - k0, _ := NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) - me.WideModel.Narrow.Vote(k0, distances, indices, votes) - return k0 -} diff --git a/lsh/model_wide_test.go b/lsh/model_wide_test.go deleted file mode 100644 index c98ae4e..0000000 --- a/lsh/model_wide_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package lsh_test - -import ( - "math" - "reflect" - "slices" - "testing" - - "github.com/keilerkonzept/bitknn" - "github.com/keilerkonzept/bitknn/lsh" - "pgregory.net/rapid" -) - -func Test_WideModel_64bit_Equal_To_Narrow(t *testing.T) { - id := func(a uint64) uint64 { return a } - rapid.Check(t, func(t *rapid.T) { - k := rapid.IntRange(1, 1001).Draw(t, "k") - data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data") - dataWide := make([][]uint64, len(data)) - for i := range data { - dataWide[i] = []uint64{data[i]} - } - labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels") - values := rapid.SliceOfN(rapid.Float64(), len(data), len(data)).Draw(t, "values") - queries := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 64, id).Draw(t, "queries") - wideVotes := make([]float64, 4) - narrowVotes := make([]float64, 4) - type pair struct { - name string - Narrow *lsh.Model - Wide *lsh.WideModel - } - pairs := []pair{ - { - "", - lsh.Fit(data, labels, lsh.ConstantHash{}, bitknn.WithValues(values)), - lsh.FitWide(dataWide, labels, lsh.ConstantHash{}, bitknn.WithValues(values)), - }, - } - const eps = 1e-9 - for _, pair := range pairs { - narrow := pair.Narrow - wide := pair.Wide - narrow.PreallocateHeap(k) - wide.PreallocateHeap(k) - for _, q := range queries { - nd, ni := narrow.Find(k, q) - wd, wi := wide.Find(k, []uint64{q}) - if !reflect.DeepEqual(nd, wd) { - t.Fatal("Wide KNN should result in the same distances for the nearest neighbors: ", nd, wd) - } - if !reflect.DeepEqual(ni, wi) { - t.Fatal("Wide ANN should result in the same indices for the nearest neighbors: ", ni, wi) - } - narrow.Predict(k, q, bitknn.VoteSlice(narrowVotes)) - wide.Predict(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) - if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) { - t.Fatal("Wide KNN should result in the same distances for the nearest neighbors: ", narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) - } - if !reflect.DeepEqual(narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) { - t.Fatal("Wide ANN should result in the same indices for the nearest neighbors: ", narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) - } - for i, vk := range narrowVotes { - va := wideVotes[i] - if math.Abs(vk-va) > eps { - t.Fatalf("%s: %v: %v %v", pair.name, q, narrowVotes, wideVotes) - } - } - wide.PredictAlloc(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) - slices.Sort(narrow.HeapDistances[:k]) - slices.Sort(wide.Narrow.HeapDistances[:k]) - if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) { - t.Fatal("Wide KNN should result in the same distances for the nearest neighbors: ", narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) - } - slices.Sort(narrow.HeapIndices[:k]) - slices.Sort(wide.Narrow.HeapIndices[:k]) - if !reflect.DeepEqual(narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) { - t.Fatal("Wide ANN should result in the same indices for the nearest neighbors: ", narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) - } - for i, vk := range narrowVotes { - va := wideVotes[i] - if math.Abs(vk-va) > eps { - t.Fatalf("%s: %v: %v %v", pair.name, q, narrowVotes, wideVotes) - } - } - } - } - - }) -} diff --git a/lsh/nearest.go b/lsh/nearest.go deleted file mode 100644 index c60cc6c..0000000 --- a/lsh/nearest.go +++ /dev/null @@ -1,136 +0,0 @@ -package lsh - -import ( - "math/bits" - - "github.com/keilerkonzept/bitknn/internal/heap" - "github.com/keilerkonzept/bitknn/internal/slice" -) - -// Nearest finds the nearest neighbors for a given data point within the nearest buckets by hash Hamming distance. -// -// Parameters: -// - data: The dataset. -// - bucketIDs: All bucket IDs (hashes of dataset points) -// - buckets: A map from bucket IDs to their index ranges in the dataset. -// - k: The number of neighbors to find. -// - xh: The hashed query point. -// - x: The original query point. -// - distances, heapBucketIDs, indices: Pre-allocated slices of length (k+1) for the neighbor heaps. -// -// Returns: -// - The number of nearest neighbors found. -// - The total number of data points examined. -func Nearest(data []uint64, bucketIDs []uint64, buckets map[uint64]slice.IndexRange, k int, xh uint64, x uint64, bucketDistances []int, heapBucketIDs []uint64, distances []int, indices []int) (int, int) { - dataHeap := heap.MakeMax[int](distances, indices) - exactBucket := buckets[xh] - numExamined := exactBucket.Length - nearestInBucket(data, exactBucket, k, x, &distances[0], &dataHeap) - - // stop early for 1-NN - if k == 1 && dataHeap.Len() == k { - return k, exactBucket.Length - } - - // otherwise, determine the k nearest buckets and find the k nearest neighbors in these buckets. - bucketHeap := heap.MakeMax[uint64](bucketDistances, heapBucketIDs) - nearestBuckets(bucketIDs, k, xh, &bucketDistances[0], &bucketHeap) - n := nearestInBuckets(data, heapBucketIDs[:bucketHeap.Len()], buckets, k, x, xh, &distances[0], &dataHeap) - - return dataHeap.Len(), numExamined + n -} - -func nearestInBucket(data []uint64, b slice.IndexRange, k int, x uint64, distance0 *int, heap *heap.Max[int]) { - if b.Length == 0 { - return - } - - end := b.Offset + b.Length - end0 := b.Offset + min(b.Length, k) - - for i := b.Offset; i < end0; i++ { - dist := bits.OnesCount64(x ^ data[i]) - heap.Push(dist, i) - } - - if b.Length < k { - return - } - - maxDist := *distance0 - for i := b.Offset + k; i < end; i++ { - dist := bits.OnesCount64(x ^ data[i]) - if dist >= maxDist { - continue - } - heap.PushPop(dist, i) - maxDist = *distance0 - } -} - -// nearestInBuckets finds the nearest neighbors within specific buckets. -// Returns the number of points examined. -func nearestInBuckets(data []uint64, inBuckets []uint64, buckets map[uint64]slice.IndexRange, k int, x, xh uint64, distance0 *int, heap *heap.Max[int]) int { - var maxDist int - j := heap.Len() - if j > 0 { - maxDist = *distance0 - } - t := 0 - for _, bid := range inBuckets { - if bid == xh { // skip exact bucket - continue - } - b := buckets[bid] - end := b.Offset + b.Length - t += b.Length - if j >= k { - for i := b.Offset; i < end; i++ { - dist := bits.OnesCount64(x ^ data[i]) - if dist >= maxDist { - continue - } - heap.PushPop(dist, i) - maxDist = *distance0 - } - continue - } - for i := b.Offset; i < end; i++ { - dist := bits.OnesCount64(x ^ data[i]) - if j < k { - heap.Push(dist, i) - maxDist = *distance0 - j++ - continue - } - if dist >= maxDist { - continue - } - heap.PushPop(dist, i) - maxDist = *distance0 - } - } - return t -} - -// nearestBuckets finds the buckets with IDs that are (Hamming-)nearest to a query point hash. -func nearestBuckets(bucketIDs []uint64, k int, x uint64, distance0 *int, heap *heap.Max[uint64]) { - k0 := min(k, len(bucketIDs)) - var maxDist int - for _, b := range bucketIDs[:k0] { - dist := bits.OnesCount64(x ^ b) - heap.Push(dist, b) - } - if k0 < k { - return - } - maxDist = *distance0 - for _, b := range bucketIDs[k0:] { - dist := bits.OnesCount64(x ^ b) - if dist >= maxDist { - continue - } - heap.PushPop(dist, b) - maxDist = *distance0 - } -} diff --git a/lsh/nearest_test.go b/lsh/nearest_test.go deleted file mode 100644 index 95fa33a..0000000 --- a/lsh/nearest_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package lsh_test - -import ( - "testing" - - "github.com/keilerkonzept/bitknn/internal/slice" - "github.com/keilerkonzept/bitknn/lsh" -) - -func TestNearest(t *testing.T) { - t.Run("Nearest_=k_buckets", func(t *testing.T) { - data := []uint64{1, 2, 3, 4, 5, 6, 7, 8} - bucketIDs := []uint64{0, 1, 2, 3} - buckets := map[uint64]slice.IndexRange{ - 0: {Offset: 0, Length: 2}, - 1: {Offset: 2, Length: 2}, - 2: {Offset: 4, Length: 2}, - 3: {Offset: 6, Length: 2}, - } - k := 3 - distances := make([]int, k+1) - bucketDistances := make([]int, k+1) - heapBucketIDs := make([]uint64, k+1) - indices := make([]int, k+1) - - { - x := uint64(5) - xh := uint64(1) - k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, bucketDistances, heapBucketIDs, distances, indices) - - if 3 != k { - t.Fatal(k) - } - if 6 != n { - t.Fatal(n) - } - } - { - x := uint64(4) - xh := uint64(2) - k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, bucketDistances, heapBucketIDs, distances, indices) - - if 3 != k { - t.Fatal(k) - } - if 6 != n { - t.Fatal(n) - } - } - }) -} diff --git a/lsh/nearest_wide.go b/lsh/nearest_wide.go deleted file mode 100644 index bf2ff20..0000000 --- a/lsh/nearest_wide.go +++ /dev/null @@ -1,115 +0,0 @@ -package lsh - -import ( - "math/bits" - - "github.com/keilerkonzept/bitknn/internal/heap" - "github.com/keilerkonzept/bitknn/internal/slice" -) - -// [Nearest], but for wide data. -func NearestWide(data [][]uint64, bucketIDs []uint64, buckets map[uint64]slice.IndexRange, k int, xh uint64, x []uint64, bucketDistances []int, heapBucketIDs []uint64, distances []int, indices []int) (int, int) { - dataHeap := heap.MakeMax[int](distances, indices) - exactBucket := buckets[xh] - numExamined := exactBucket.Length - nearestWideInBucket(data, exactBucket, k, x, &distances[0], &dataHeap) - - // stop early for 1-NN - if k == 1 && dataHeap.Len() == k { - return k, exactBucket.Length - } - - // otherwise, determine the k nearest buckets and find the k nearest neighbors in these buckets. - bucketHeap := heap.MakeMax[uint64](bucketDistances, heapBucketIDs) - nearestBuckets(bucketIDs, k, xh, &bucketDistances[0], &bucketHeap) - n := nearestWideInBuckets(data, heapBucketIDs[:bucketHeap.Len()], buckets, k, x, xh, &distances[0], &dataHeap) - - return dataHeap.Len(), numExamined + n -} - -func nearestWideInBucket(data [][]uint64, b slice.IndexRange, k int, x []uint64, distance0 *int, heap *heap.Max[int]) { - if b.Length == 0 { - return - } - - end := b.Offset + b.Length - end0 := b.Offset + min(b.Length, k) - - for i := b.Offset; i < end0; i++ { - d := data[i] - dist := 0 - for j, d := range d { - dist += bits.OnesCount64(x[j] ^ d) - } - heap.Push(dist, i) - } - - if b.Length < k { - return - } - - maxDist := *distance0 - for i := b.Offset + k; i < end; i++ { - d := data[i] - dist := 0 - for j, d := range d { - dist += bits.OnesCount64(x[j] ^ d) - } - if dist >= maxDist { - continue - } - heap.PushPop(dist, i) - maxDist = *distance0 - } -} - -func nearestWideInBuckets(data [][]uint64, inBuckets []uint64, buckets map[uint64]slice.IndexRange, k int, x []uint64, xh uint64, distance0 *int, heap *heap.Max[int]) int { - var maxDist int - j := heap.Len() - if j > 0 { - maxDist = *distance0 - } - t := 0 - for _, bid := range inBuckets { - if bid == xh { // skip exact bucket - continue - } - b := buckets[bid] - end := b.Offset + b.Length - t += b.Length - if j >= k { - for i := b.Offset; i < end; i++ { - d := data[i] - dist := 0 - for j1, d := range d { - dist += bits.OnesCount64(x[j1] ^ d) - } - if dist >= maxDist { - continue - } - heap.PushPop(dist, i) - maxDist = *distance0 - } - continue - } - for i := b.Offset; i < end; i++ { - d := data[i] - dist := 0 - for j1, d := range d { - dist += bits.OnesCount64(x[j1] ^ d) - } - if j < k { - heap.Push(dist, i) - maxDist = *distance0 - j++ - continue - } - if dist >= maxDist { - continue - } - heap.PushPop(dist, i) - maxDist = *distance0 - } - } - return t -} diff --git a/lsh/nearest_wide_test.go b/lsh/nearest_wide_test.go deleted file mode 100644 index 8d7b6d6..0000000 --- a/lsh/nearest_wide_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package lsh_test - -import ( - "reflect" - "testing" - - "github.com/keilerkonzept/bitknn/lsh" - "pgregory.net/rapid" -) - -func Test_Nearest_64bit_Equal_To_Narrow(t *testing.T) { - id := func(a uint64) uint64 { return a } - rapid.Check(t, func(t *rapid.T) { - k := rapid.IntRange(3, 2001).Draw(t, "k") - data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data") - labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels") - dataWide := make([][]uint64, len(data)) - for i := range data { - dataWide[i] = []uint64{data[i]} - } - - type hash struct { - narrow lsh.Hash - wide lsh.HashWide - } - hashes := []hash{ - { - narrow: lsh.ConstantHash{}, - wide: lsh.ConstantHash{}, - }, - { - narrow: lsh.BitSample(0xF0F0F0F0F0F0F0F0), - wide: &lsh.HashWide1{ - Single: lsh.BitSample(0xF0F0F0F0F0F0F0F0), - }, - }, - } - for _, h := range hashes { - m := lsh.Fit(data, labels, h.narrow) - m.PreallocateHeap(k) - mw := lsh.FitWide(dataWide, labels, h.wide) - mw.PreallocateHeap(k) - - x := rapid.Uint64().Draw(t, "query") - xw := []uint64{x} - xh := m.Hash.Hash1(x) - xwh := mw.Hash.Hash1Wide(xw) - - lsh.Nearest(data, m.BucketIDs, m.Buckets, k, xh, x, m.HeapBucketDistances, m.HeapBucketIDs, m.HeapDistances, m.HeapIndices) - lsh.NearestWide(dataWide, mw.BucketIDs, mw.Buckets, k, xwh, xw, mw.HeapBucketDistances, mw.HeapBucketIDs, mw.Narrow.HeapDistances, mw.Narrow.HeapIndices) - - if !reflect.DeepEqual(m.HeapIndices, mw.Narrow.HeapIndices) { - t.Fatal(m.HeapIndices, mw.Narrow.HeapIndices) - } - if !reflect.DeepEqual(m.HeapDistances, mw.Narrow.HeapDistances) { - t.Fatal(m.HeapDistances, mw.Narrow.HeapDistances) - } - } - }) -} diff --git a/model.go b/model.go index d6dbec2..4d7df24 100644 --- a/model.go +++ b/model.go @@ -1,5 +1,4 @@ -// Package bitknn provides a fast k-nearest neighbors (k-NN) implementation for binary feature vectors. -// The sub-package [github.com/keilerkonzept/bitknn/lsh] implements an approximate k-nearest neighbors (ANN) model using locality-sensitive hashing. +// Package bitknn provides a fast exact k-nearest neighbors (k-NN) implementation for binary feature vectors. package bitknn import ( diff --git a/model_wide_test.go b/model_wide_test.go index 7530da8..948ef3d 100644 --- a/model_wide_test.go +++ b/model_wide_test.go @@ -125,3 +125,27 @@ func TestModel_FindV_Equiv_Find(t *testing.T) { } }) } + +func TestModel_PredictV_Equiv_Predict(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + k := rapid.IntRange(0, 1000).Draw(t, "k") + dims := rapid.IntRange(1, 10_000).Draw(t, "dims") + data := rapid.SliceOf(rapid.SliceOfN(rapid.Uint64(), dims, dims)).Draw(t, "data") + batchSizes := []int{0, len(data), len(data) - 1, len(data) - 2, 2048, 100_000} + q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") + labels := rapid.SliceOfN(rapid.Int(), len(data), len(data)).Draw(t, "labels") + for _, batchSize := range batchSizes { + batchSize = max(k, batchSize) + m1 := bitknn.FitWide(data, labels) + m2 := bitknn.FitWide(data, labels) + batch := make([]uint32, batchSize) + vv := make(bitknn.VoteMap) + v := make(bitknn.VoteMap) + m1.PredictV(k, q, batch, vv) + m2.Predict(k, q, v) + if !reflect.DeepEqual(vv, v) { + t.Fatal(vv, v) + } + } + }) +}