Skip to content

Commit 1923c00

Browse files
feat(bayes): add CategoricalNB
1 parent 7a9c67c commit 1923c00

8 files changed

Lines changed: 212 additions & 0 deletions

File tree

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
---
2+
title: CategoricalNB
3+
description: API reference for CategoricalNB
4+
---
5+
6+
# Bayes.CategoricalNB
7+
8+
```ts
9+
interface CategoricalNBProps {
10+
alpha?: number;
11+
forceAlpha?: boolean;
12+
fitPrior?: boolean;
13+
classPrior?: number[] | null;
14+
minCategories?: number | number[] | null;
15+
}
16+
constructor(props: CategoricalNBProps = {})
17+
```
18+
19+
### Example
20+
```ts
21+
const clf = new CategoricalNB();
22+
clf.fit(trainX, trainY);
23+
const result = clf.predict(testX);
24+
```

docs/content/docs/apis/bayes/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ description: API reference for Naive Bayes classifiers
44
---
55

66
- [BernoulliNB](bernoulliNB.md)
7+
- [CategoricalNB](categoricalNB.md)

scripts/gen_all.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def run(script):
1818
'gen_optics.py',
1919
'gen_logistic_regression.py',
2020
'gen_bernoulli_nb.py',
21+
'gen_categorical_nb.py',
2122
'gen_svc.py',
2223
'gen_pca.py',
2324
'gen_spectral_embedding.py',

scripts/gen_categorical_nb.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from sklearn.naive_bayes import CategoricalNB
2+
import numpy as np
3+
import json, os
4+
5+
rng = np.random.RandomState(0)
6+
X = rng.randint(3, size=(50, 5))
7+
y = rng.randint(2, size=50)
8+
X_test = rng.randint(3, size=(10, 5))
9+
clf = CategoricalNB()
10+
clf.fit(X, y)
11+
pred = clf.predict(X_test)
12+
13+
os.makedirs('test_data', exist_ok=True)
14+
with open('test_data/categorical_nb.json', 'w') as f:
15+
json.dump({
16+
'trainX': X.tolist(),
17+
'trainY': y.tolist(),
18+
'testX': X_test.tolist(),
19+
'expected': pred.tolist()
20+
}, f)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { CategoricalNB } from '../categoricalNB';
2+
import fs from 'fs';
3+
import path from 'path';
4+
5+
test('compare with sklearn', () => {
6+
const p = path.join(__dirname, '../../../test_data/categorical_nb.json');
7+
const data = JSON.parse(fs.readFileSync(p, 'utf8'));
8+
const clf = new CategoricalNB();
9+
clf.fit(data.trainX, data.trainY);
10+
const pred = clf.predict(data.testX);
11+
expect(pred).toEqual(data.expected);
12+
});
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { CategoricalNB } from '../categoricalNB';
2+
3+
test('init', () => {
4+
const nb = new CategoricalNB();
5+
expect(nb).toBeDefined();
6+
});
7+
8+
test('simple classification', () => {
9+
const X = [
10+
[0, 0],
11+
[1, 0],
12+
[0, 1],
13+
[1, 1]
14+
];
15+
const Y = [0, 0, 1, 1];
16+
const nb = new CategoricalNB();
17+
nb.fit(X, Y);
18+
const pred = nb.predict(X);
19+
expect(pred).toEqual(Y);
20+
});

src/bayes/categoricalNB.ts

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import { ClassifierBase } from '../base';
2+
3+
export interface CategoricalNBProps {
4+
alpha?: number;
5+
forceAlpha?: boolean;
6+
fitPrior?: boolean;
7+
classPrior?: number[] | null;
8+
minCategories?: number | number[] | null;
9+
}
10+
11+
export class CategoricalNB extends ClassifierBase {
12+
private alpha: number;
13+
private forceAlpha: boolean;
14+
private fitPrior: boolean;
15+
private classPrior: number[] | null;
16+
private minCategories: number | number[] | null;
17+
18+
private classes: number[] = [];
19+
private classCount: number[] = [];
20+
private categoryCount: number[][][] = [];
21+
private classLogPrior: number[] = [];
22+
private featureLogProb: number[][][] = [];
23+
private nCategories: number[] = [];
24+
25+
constructor(props: CategoricalNBProps = {}) {
26+
super();
27+
const {
28+
alpha = 1.0,
29+
forceAlpha = true,
30+
fitPrior = true,
31+
classPrior = null,
32+
minCategories = null
33+
} = props;
34+
this.alpha = forceAlpha ? alpha : Math.max(alpha, 1e-10);
35+
this.forceAlpha = forceAlpha;
36+
this.fitPrior = fitPrior;
37+
this.classPrior = classPrior;
38+
this.minCategories = minCategories;
39+
}
40+
41+
private initCounters(X: number[][]): void {
42+
const nFeatures = X[0].length;
43+
this.nCategories = new Array(nFeatures).fill(0);
44+
for (let j = 0; j < nFeatures; j++) {
45+
let maxVal = 0;
46+
for (let i = 0; i < X.length; i++) {
47+
if (X[i][j] > maxVal) maxVal = X[i][j];
48+
}
49+
let minCat = 0;
50+
if (this.minCategories === null) {
51+
minCat = 0;
52+
} else if (typeof this.minCategories === 'number') {
53+
minCat = this.minCategories;
54+
} else {
55+
minCat = this.minCategories[j];
56+
}
57+
this.nCategories[j] = Math.max(maxVal + 1, minCat);
58+
}
59+
const nClasses = this.classes.length;
60+
this.categoryCount = [];
61+
this.featureLogProb = [];
62+
for (let j = 0; j < nFeatures; j++) {
63+
const cats = this.nCategories[j];
64+
const mat = Array.from({ length: nClasses }, () => new Array(cats).fill(0));
65+
this.categoryCount.push(mat.map(row => row.slice()));
66+
this.featureLogProb.push(mat.map(row => row.slice()));
67+
}
68+
this.classCount = new Array(nClasses).fill(0);
69+
}
70+
71+
public fit(trainX: number[][], trainY: number[]): void {
72+
this.classes = Array.from(new Set(trainY)).sort((a, b) => a - b);
73+
const classIndex = new Map<number, number>();
74+
this.classes.forEach((c, i) => classIndex.set(c, i));
75+
this.initCounters(trainX);
76+
const nFeatures = trainX[0].length;
77+
78+
for (let i = 0; i < trainX.length; i++) {
79+
const ci = classIndex.get(trainY[i])!;
80+
this.classCount[ci] += 1;
81+
for (let j = 0; j < nFeatures; j++) {
82+
const v = trainX[i][j];
83+
if (v >= this.nCategories[j]) continue;
84+
this.categoryCount[j][ci][v] += 1;
85+
}
86+
}
87+
88+
const nClasses = this.classes.length;
89+
if (this.classPrior) {
90+
this.classLogPrior = this.classPrior.map(p => Math.log(p));
91+
} else if (this.fitPrior) {
92+
const totalCount = this.classCount.reduce((a, b) => a + b, 0);
93+
this.classLogPrior = this.classCount.map(c => Math.log((c + this.alpha) / (totalCount + nClasses * this.alpha)));
94+
} else {
95+
this.classLogPrior = new Array(nClasses).fill(Math.log(1 / nClasses));
96+
}
97+
98+
for (let j = 0; j < nFeatures; j++) {
99+
for (let c = 0; c < nClasses; c++) {
100+
for (let k = 0; k < this.nCategories[j]; k++) {
101+
const count = this.categoryCount[j][c][k];
102+
const denom = this.classCount[c] + this.nCategories[j] * this.alpha;
103+
this.featureLogProb[j][c][k] = Math.log((count + this.alpha) / denom);
104+
}
105+
}
106+
}
107+
}
108+
109+
public predict(testX: number[][]): number[] {
110+
const nFeatures = testX[0].length;
111+
const nClasses = this.classes.length;
112+
const preds: number[] = [];
113+
for (const row of testX) {
114+
let bestIdx = 0;
115+
let bestScore = -Infinity;
116+
for (let c = 0; c < nClasses; c++) {
117+
let score = this.classLogPrior[c];
118+
for (let j = 0; j < nFeatures; j++) {
119+
const v = row[j];
120+
if (v < this.nCategories[j]) {
121+
score += this.featureLogProb[j][c][v];
122+
}
123+
}
124+
if (score > bestScore) {
125+
bestScore = score;
126+
bestIdx = c;
127+
}
128+
}
129+
preds.push(this.classes[bestIdx]);
130+
}
131+
return preds;
132+
}
133+
}

src/bayes/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
export { BernoulliNB } from './bernoulliNB';
2+
export { CategoricalNB } from './categoricalNB';

0 commit comments

Comments
 (0)