Skip to content

Commit b7fec90

Browse files
committed
WIDE
1 parent a1f8dec commit b7fec90

26 files changed

Lines changed: 1776 additions & 151 deletions

README.md

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,23 @@ The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh)
1919

2020
**Contents**
2121
- [Usage](#usage)
22+
- [Basic usage](#basic-usage)
23+
- [LSH](#lsh)
24+
- [Packing wide data](#packing-wide-data)
2225
- [Options](#options)
2326
- [Benchmarks](#benchmarks)
2427
- [License](#license)
2528

2629
## Usage
2730

31+
### Basic usage
32+
2833
```go
2934
package main
3035

3136
import (
3237
"fmt"
38+
3339
"github.com/keilerkonzept/bitknn"
3440
)
3541

@@ -45,12 +51,121 @@ func main() {
4551
votes := make([]float64, 2)
4652

4753
k := 2
48-
model.Predict1(k, 0b101011, votes)
54+
model.Predict1(k, 0b101011, bitknn.VoteSlice(votes))
55+
56+
fmt.Println("Votes:", bitknn.VoteSlice(votes))
57+
58+
// you can also use a map for the votes.
59+
// this is good if you have a very large number of different labels:
60+
votesMap := make(map[int]float64)
61+
model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap))
62+
fmt.Println("Votes for 0:", votesMap[0])
63+
}
64+
```
65+
66+
### LSH
67+
68+
Locality-Sensitive Hashing (LSH) is a type of approximate k-NN search. It's faster at the expense of accuracy.
69+
70+
LSH works by hashing data points such that points that are close in Hamming space tend to land in the same bucket, and computing k-nearest neighbors only on the buckets with the k nearest hashes. In particular, for *k*=1 only one bucket needs to be examined.
71+
72+
```go
73+
package main
74+
75+
import (
76+
"fmt"
77+
"github.com/keilerkonzept/bitknn/lsh"
78+
"github.com/keilerkonzept/bitknn"
79+
)
80+
81+
func main() {
82+
// feature vectors packed into uint64s
83+
data := []uint64{0b101010, 0b111000, 0b000111}
84+
// class labels
85+
labels := []int{0, 1, 1}
86+
87+
// Define a hash function (e.g., MinHash)
88+
hash := lsh.RandomMinHash()
89+
90+
// Fit an LSH model
91+
model := lsh.Fit(data, labels, hash, bitknn.WithLinearDistanceWeighting())
92+
93+
// one vote counter per class
94+
votes := make([]float64, 2)
95+
96+
k := 2
97+
model.Predict1(k, 0b101011, bitknn.VoteSlice(votes))
98+
99+
fmt.Println("Votes:", bitknn.VoteSlice(votes))
100+
101+
// you can also use a map for the votes
102+
votesMap := make(map[int]float64)
103+
model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap))
104+
fmt.Println("Votes for 0:", votesMap[0])
105+
}
106+
```
107+
108+
The model accepts anything that implements the [`lsh.Hash` interface](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Hash) as a hash function. Several functions are pre-defined:
109+
110+
- [MinHash](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHash): An implementation of the [MinHash scheme](https://en.m.wikipedia.org/wiki/MinHash) for bit vectors.
111+
112+
Constructors: [RandomMinHash](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHash), [RandomMinHashR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashR).
113+
- [MinHashes](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHash): Concatenation of several *MinHash*es.
114+
115+
Constructors: [RandomMinHashes](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashes), [RandomMinHashesR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashesR).
116+
- [Blur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Blur): A threshold-based variation on bit sampling.
117+
118+
Constructors: [RandomBlur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBlur), [RandomBlurR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBlurR), [BoxBlur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BoxBlur), .
119+
- [BitSample](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BitSample): A random sampling of bits from the feature vector.
120+
121+
Constructors: [RandomBitSample](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBitSample), [RandomBitSampleR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBitSampleR).
122+
123+
For datasets of vectors longer than 64 bits, the `lsh` package also provides a [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide) function, and "wide" versions of the hash functions ([MinHashWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHashWide), [BlurWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BlurWide), [BitSampleWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BitSampleWide))
124+
125+
The [`lsh.Fit`/`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit) functions accept the same [Options](#options) as the others.
126+
127+
### Packing wide data
128+
129+
If your vectors are longer than 64 bits, you can still use `bitknn` if you [pack](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) them into `[]uint64`. The [`pack` package](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) defines helper functions to pack `string`s and `[]byte`s into `[]uint64`s.
130+
131+
The exact k-NN model in `bitknn` and the approximate-NN model in `lsh` each have a `Wide` variant that accepts slice-valued data points:
132+
133+
```go
134+
package main
135+
136+
import (
137+
"fmt"
138+
139+
"github.com/keilerkonzept/bitknn"
140+
"github.com/keilerkonzept/bitknn/pack"
141+
)
142+
143+
func main() {
144+
// feature vectors packed into uint64s
145+
data := [][]uint64{
146+
pack.String("foo"),
147+
pack.String("bar"),
148+
pack.String("baz"),
149+
}
150+
// class labels
151+
labels := []int{0, 1, 1}
152+
153+
// model := lsh.FitWide(data, labels, lsh.RandomMinHash(), bitknn.WithLinearDistanceWeighting())
154+
model := bitknn.FitWide(data, labels, bitknn.WithLinearDistanceWeighting())
155+
156+
// one vote counter per class
157+
votes := make([]float64, 2)
158+
159+
k := 2
160+
query := pack.String("fob")
161+
model.Predict1(k, query, bitknn.VoteSlice(votes))
49162

50-
fmt.Println("Votes:", votes)
163+
fmt.Println("Votes:", bitknn.VoteSlice(votes))
51164
}
52165
```
53166

167+
The wide model fitting function [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) accepts the same [Options](#options) as the "narrow" one.
168+
54169
## Options
55170

56171
- `WithLinearDistanceWeighting()`: Apply linear distance weighting (`1 / (1 + dist)`).

internal/testrandom/random.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ func Query() uint64 {
88
return Source.Uint64()
99
}
1010

11+
func WideQuery(dim int) []uint64 {
12+
return Data(dim)
13+
}
14+
1115
func Data(size int) []uint64 {
1216
data := make([]uint64, size)
1317
for i := range data {
@@ -16,6 +20,17 @@ func Data(size int) []uint64 {
1620
return data
1721
}
1822

23+
func WideData(dim int, size int) [][]uint64 {
24+
data := make([][]uint64, size)
25+
for i := range data {
26+
data[i] = make([]uint64, dim)
27+
for j := range dim {
28+
data[i][j] = Source.Uint64()
29+
}
30+
}
31+
return data
32+
}
33+
1934
func Labels(size int) []int {
2035
labels := make([]int, size)
2136
for i := range labels {

internal/testrandom/random_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,31 @@ import (
99
func TestQuery(t *testing.T) {
1010
_ = testrandom.Query()
1111
}
12+
13+
func TestWideQuery(t *testing.T) {
14+
q := testrandom.WideQuery(5)
15+
if len(q) != 5 {
16+
t.Fatal()
17+
}
18+
}
19+
1220
func TestData(t *testing.T) {
1321
data := testrandom.Data(123)
1422
if len(data) != 123 {
1523
t.Fatal()
1624
}
1725
}
26+
27+
func TestWideData(t *testing.T) {
28+
data := testrandom.WideData(3, 123)
29+
if len(data) != 123 {
30+
t.Fatal()
31+
}
32+
if len(data[0]) != 3 {
33+
t.Fatal()
34+
}
35+
}
36+
1837
func TestLabels(t *testing.T) {
1938
data := testrandom.Labels(123)
2039
if len(data) != 123 {

0 commit comments

Comments
 (0)