Skip to content

Commit

Permalink
rework
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreben committed Oct 8, 2024
1 parent 6cb2e59 commit 9b8c8d9
Show file tree
Hide file tree
Showing 12 changed files with 697 additions and 437 deletions.
55 changes: 38 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ If you need to classify **binary feature vectors that fit into `uint64`s**, this

You can optionally weigh class votes by distance, or specify different vote values per data point.


**Contents**
- [Usage](#usage)
- [Options](#options)
Expand All @@ -38,21 +37,23 @@ func main() {
// class labels
labels := []int{0, 1, 1}

model := bitknn.Fit(data, labels, 2, bitknn.WithLinearDecay())
model := bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting())

// one vote counter per class
votes := make([]float64, 2)
model.Predict1(0b101011, votes)

k := 2
model.Predict1(k, 0b101011, votes)

fmt.Println("Votes:", votes)
}
```

## Options

- `WithLinearDecay()`: Apply linear distance weighting (`1 / (1 + dist)`).
- `WithQuadraticDecay()`: Apply quadratic distance weighting (`1 / (1 + dist^2)`).
- `WithDistanceWeightFunc(f func(dist int) float64)`: Use a custom distance weighting function.
- `WithLinearDistanceWeighting()`: Apply linear distance weighting (`1 / (1 + dist)`).
- `WithQuadraticDistanceWeighting()`: Apply quadratic distance weighting (`1 / (1 + dist^2)`).
- `WithDistanceWeightingFunc(f func(dist int) float64)`: Use a custom distance weighting function.
- `WithValues(values []float64)`: Assign vote values for each data point.

## Benchmarks
Expand All @@ -64,17 +65,37 @@ pkg: github.com/keilerkonzept/bitknn
cpu: Apple M1 Pro
```

| op | N | k | iters | ns/op | B/op | allocs/op |
|------------|---------|-----|---------|--------------|------|-----------|
| `Predict1` | 100 | 3 | 8308794 | 121.4 ns/op | 0 | 0 |
| `Predict1` | 100 | 10 | 4707778 | 269.7 ns/op | 0 | 0 |
| `Predict1` | 100 | 100 | 2255380 | 549.2 ns/op | 0 | 0 |
| `Predict1` | 1000 | 3 | 1693364 | 659.3 ns/op | 0 | 0 |
| `Predict1` | 1000 | 10 | 1220426 | 1005 ns/op | 0 | 0 |
| `Predict1` | 1000 | 100 | 345151 | 3560 ns/op | 0 | 0 |
| `Predict1` | 1000000 | 3 | 2076 | 566647 ns/op | 0 | 0 |
| `Predict1` | 1000000 | 10 | 2112 | 568787 ns/op | 0 | 0 |
| `Predict1` | 1000000 | 100 | 2066 | 587827 ns/op | 0 | 0 |
| Op | N | k | Distance weighting | Vote values | sec / op | B/op | allocs/op |
|------------|---------|-----|--------------------|-------------|--------------|------|-----------|
| `Predict1` | 100 | 3 | | | 138.7n ± 22% | 0 | 0 |
| `Predict1` | 100 | 3 | | ☑️ | 127.8n ± 11% | 0 | 0 |
| `Predict1` | 100 | 3 | linear | | 137.0n ± 11% | 0 | 0 |
| `Predict1` | 100 | 3 | linear | ☑️ | 136.7n ± 10% | 0 | 0 |
| `Predict1` | 100 | 3 | quadratic | | 137.2n ± 7% | 0 | 0 |
| `Predict1` | 100 | 3 | quadratic | ☑️ | 130.4n ± 4% | 0 | 0 |
| `Predict1` | 100 | 3 | custom | | 140.6n ± 7% | 0 | 0 |
| `Predict1` | 100 | 3 | custom | ☑️ | 134.9n ± 13% | 0 | 0 |
| `Predict1` | 100 | 10 | | | 307.4n ± 11% | 0 | 0 |
| `Predict1` | 100 | 10 | | ☑️ | 297.8n ± 15% | 0 | 0 |
| `Predict1` | 100 | 10 | linear | | 288.2n ± 18% | 0 | 0 |
| `Predict1` | 100 | 10 | linear | ☑️ | 302.9n ± 14% | 0 | 0 |
| `Predict1` | 100 | 10 | quadratic | | 283.7n ± 15% | 0 | 0 |
| `Predict1` | 100 | 10 | quadratic | ☑️ | 290.0n ± 13% | 0 | 0 |
| `Predict1` | 100 | 10 | custom | | 313.1n ± 17% | 0 | 0 |
| `Predict1` | 100 | 10 | custom | ☑️ | 316.2n ± 11% | 0 | 0 |
| `Predict1` | 100 | 100 | | ☑️ | 545.4n ± 4% | 0 | 0 |
| `Predict1` | 100 | 100 | linear | | 542.4n ± 4% | 0 | 0 |
| `Predict1` | 100 | 100 | linear | ☑️ | 577.5n ± 4% | 0 | 0 |
| `Predict1` | 100 | 100 | quadratic | | 553.1n ± 3% | 0 | 0 |
| `Predict1` | 100 | 100 | quadratic | ☑️ | 582.4n ± 6% | 0 | 0 |
| `Predict1` | 100 | 100 | custom | | 683.8n ± 4% | 0 | 0 |
| `Predict1` | 100 | 100 | custom | ☑️ | 748.5n ± 2% | 0 | 0 |
| `Predict1` | 1000 | 3 | | | 669.5n ± 6% | 0 | 0 |
| `Predict1` | 1000 | 10 | | | 930.3n ± 7% | 0 | 0 |
| `Predict1` | 1000 | 100 | | | 3.762µ ± 5% | 0 | 0 |
| `Predict1` | 1000000 | 3 | | | 532.1µ ± 1% | 0 | 0 |
| `Predict1` | 1000000 | 10 | | | 534.5µ ± 1% | 0 | 0 |
| `Predict1` | 1000000 | 100 | | | 551.7µ ± 1% | 0 | 0 |

## License

Expand Down
39 changes: 19 additions & 20 deletions heap.go → internal/heap/heap.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,40 @@
package bitknn
package heap

import "unsafe"

// neighborHeap is a max-heap that stores data point's distances together with their indices in the training set.
// The heap is used to keep track of nearest neighbors.
type neighborHeap struct {
// Max is a max-heap used to keep track of nearest neighbors.
type Max[T int | uint64] struct {
distances []int
lastDistance *int
indices []int
lastIndex *int
values []T
lastValue *T
len int
}

const unsafeSizeofInt = unsafe.Sizeof(int(0))

func makeNeighborHeap(distances, indices []int) neighborHeap {
return neighborHeap{
func MakeMax[T int | uint64](distances []int, value []T) Max[T] {
return Max[T]{
distances: distances,
lastDistance: (*int)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(distances)), unsafeSizeofInt*uintptr(len(distances)-1))),
indices: indices,
lastIndex: (*int)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(indices)), unsafeSizeofInt*uintptr(len(indices)-1))),
values: value,
lastValue: (*T)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(value)), unsafe.Sizeof(T(0))*uintptr(len(value)-1))),
}
}

func (me *neighborHeap) swap(i, j int) {
func (me *Max[T]) swap(i, j int) {
me.distances[i], me.distances[j] = me.distances[j], me.distances[i]
me.indices[i], me.indices[j] = me.indices[j], me.indices[i]
me.values[i], me.values[j] = me.values[j], me.values[i]
}

func (me *neighborHeap) less(i, j int) bool {
func (me *Max[T]) less(i, j int) bool {
return me.distances[i] > me.distances[j]
}

func (me *neighborHeap) pushpop(value int, index int) {
func (me *Max[T]) PushPop(dist int, value T) {
n := me.len
*me.lastDistance = value
*me.lastIndex = index
*me.lastDistance = dist
*me.lastValue = value
me.up(n)
me.swap(0, n)

Expand All @@ -58,15 +57,15 @@ func (me *neighborHeap) pushpop(value int, index int) {
}
}

func (me *neighborHeap) push(value int, index int) {
func (me *Max[T]) Push(dist int, value T) {
n := me.len
me.distances[n] = value
me.indices[n] = index
me.distances[n] = dist
me.values[n] = value
me.len = n + 1
me.up(n)
}

func (me *neighborHeap) up(i int) {
func (me *Max[T]) up(i int) {
for {
p := (i - 1) / 2 // Parent index
if p == i || !me.less(i, p) { // If parent is larger or i is root, stop
Expand Down
54 changes: 27 additions & 27 deletions heap_test.go → internal/heap/heap_test.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,43 @@
package bitknn
package heap

import (
"testing"
)

func TestMakeNeighborHeap(t *testing.T) {
distances := []int{10, 20, 30}
indices := []int{1, 2, 3}
heap := makeNeighborHeap(distances, indices)
values := []int{1, 2, 3}
heap := MakeMax(distances, values)

// Check if lastDistance and lastIndex are pointing to the correct elements
// Check if lastDistance and lastValue are pointing to the correct elements
if *heap.lastDistance != 30 {
t.Errorf("Expected lastDistance to be 30, got %d", *heap.lastDistance)
}
if *heap.lastIndex != 3 {
t.Errorf("Expected lastIndex to be 3, got %d", *heap.lastIndex)
if *heap.lastValue != 3 {
t.Errorf("Expected lastValue to be 3, got %d", *heap.lastValue)
}
}

func TestNeighborHeapSwap(t *testing.T) {
heap := neighborHeap{
heap := Max[int]{
distances: []int{10, 20, 30},
indices: []int{1, 2, 3},
values: []int{1, 2, 3},
}

heap.swap(0, 2)

if heap.distances[0] != 30 || heap.distances[2] != 10 {
t.Errorf("Swap failed on distances, got %v", heap.distances)
}
if heap.indices[0] != 3 || heap.indices[2] != 1 {
t.Errorf("Swap failed on indices, got %v", heap.indices)
if heap.values[0] != 3 || heap.values[2] != 1 {
t.Errorf("Swap failed on values, got %v", heap.values)
}
}

func TestNeighborHeapLess(t *testing.T) {
heap := neighborHeap{
heap := Max[int]{
distances: []int{10, 20, 30},
indices: []int{1, 2, 3},
values: []int{1, 2, 3},
}

if !heap.less(2, 0) {
Expand All @@ -51,11 +51,11 @@ func TestNeighborHeapLess(t *testing.T) {

func TestNeighborHeapPushPop(t *testing.T) {
distances := []int{30, 20, 10, 0}
indices := []int{1, 2, 3, 0}
heap := makeNeighborHeap(distances, indices)
values := []int{1, 2, 3, 0}
heap := MakeMax(distances, values)
heap.len = 3

heap.pushpop(25, 4)
heap.PushPop(25, 4)

// Check if heap is reordered correctly
expectedDistances := []int{25, 20, 10,
Expand All @@ -68,30 +68,30 @@ func TestNeighborHeapPushPop(t *testing.T) {
if heap.distances[i] != expectedDistances[i] {
t.Errorf("Expected distance at %d to be %d, got %d", i, expectedDistances[i], heap.distances[i])
}
if heap.indices[i] != expectedIndices[i] {
t.Errorf("Expected index at %d to be %d, got %d", i, expectedIndices[i], heap.indices[i])
if heap.values[i] != expectedIndices[i] {
t.Errorf("Expected value at %d to be %d, got %d", i, expectedIndices[i], heap.values[i])
}
}
}

func TestNeighborHeapPush(t *testing.T) {
heap := makeNeighborHeap(
heap := MakeMax(
make([]int, 4),
make([]int, 4),
)

heap.push(10, 3)
heap.push(15, 5)
heap.push(25, 6)
heap.pushpop(9, 3)
heap.pushpop(7, 2)
heap.pushpop(8, 1)
heap.pushpop(6, 0)
heap.Push(10, 3)
heap.Push(15, 5)
heap.Push(25, 6)
heap.PushPop(9, 3)
heap.PushPop(7, 2)
heap.PushPop(8, 1)
heap.PushPop(6, 0)

if heap.distances[0] != 8 {
t.Errorf("Expected root distance to be 25, got %d", heap.distances[0])
}
if heap.indices[0] != 1 {
t.Errorf("Expected root index to be 6, got %d", heap.indices[0])
if heap.values[0] != 1 {
t.Errorf("Expected root value to be 6, got %d", heap.values[0])
}
}
Loading

0 comments on commit 9b8c8d9

Please sign in to comment.