Skip to content

Commit 531fc70

Browse files
committed
fix: slow types
1 parent 8e31438 commit 531fc70

File tree

3 files changed

+35
-30
lines changed

3 files changed

+35
-30
lines changed

packages/core/src/core/mod.ts

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import type { Tensor } from "./tensor/tensor.ts";
1212
import type { NeuralNetwork } from "./api/network.ts";
1313
import { SGDOptimizer } from "./api/optimizer.ts";
1414
import { PostProcess, type PostProcessor } from "./api/postprocess.ts";
15+
import type { DenseLayerConfig } from "./api/layer.ts";
1516

1617
/**
1718
* Sequential Neural Network
@@ -47,21 +48,22 @@ export class Sequential implements NeuralNetwork {
4748
*/
4849
async predict(
4950
data: Tensor<Rank>,
50-
config?: { postProcess?: PostProcessor; layers?: [number, number] }
51+
config?: { postProcess?: PostProcessor; layers?: [number, number] },
5152
): Promise<Tensor<Rank>> {
52-
if (!config)
53+
if (!config) {
5354
config = {
5455
postProcess: PostProcess("none"),
5556
};
57+
}
5658
if (config.layers) {
5759
if (
5860
config.layers[0] < 0 ||
5961
config.layers[1] > this.config.layers.length
6062
) {
6163
throw new RangeError(
62-
`Execution range should be within (0, ${
63-
this.config.layers.length
64-
}). Received (${(config.layers[0], config.layers[1])})`
64+
`Execution range should be within (0, ${this.config.layers.length}). Received (${(config
65+
.layers[0],
66+
config.layers[1])})`,
6567
);
6668
}
6769
const lastLayer = this.config.layers[config.layers[1] - 1];
@@ -77,9 +79,12 @@ export class Sequential implements NeuralNetwork {
7779
data,
7880
{
7981
postProcess: config.postProcess || PostProcess("none"),
80-
outputShape: lastLayer.config.size,
82+
outputShape: (lastLayer as {
83+
type: LayerType.Dense;
84+
config: DenseLayerConfig;
85+
}).config.size,
8186
},
82-
layerList
87+
layerList,
8388
);
8489
} else if (lastLayer.type === LayerType.Activation) {
8590
const penultimate = this.config.layers[config.layers[1] - 2];
@@ -91,26 +96,29 @@ export class Sequential implements NeuralNetwork {
9196
data,
9297
{
9398
postProcess: config.postProcess || PostProcess("none"),
94-
outputShape: penultimate.config.size,
99+
outputShape: (penultimate as {
100+
type: LayerType.Dense;
101+
config: DenseLayerConfig;
102+
}).config.size,
95103
},
96-
layerList
104+
layerList,
97105
);
98106
} else {
99107
throw new Error(
100-
`The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.`
108+
`The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.`,
101109
);
102110
}
103111
} else {
104112
throw new Error(
105-
`The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.`
113+
`The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.`,
106114
);
107115
}
108116
}
109117
return await this.backend.predict(
110118
data,
111119
config.postProcess
112120
? (config as { postProcess: PostProcessor; layers?: [number, number] })
113-
: { ...config, postProcess: PostProcess("none") }
121+
: { ...config, postProcess: PostProcess("none") },
114122
);
115123
}
116124

packages/utilities/src/utils/misc/argmax.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export function argmax(mat: ArrayLike<number | bigint>) {
1+
export function argmax(mat: ArrayLike<number | bigint>): number {
22
let max = mat[0];
33
let index = 0;
44
for (let i = 0; i < mat.length; i++) {

packages/utilities/src/utils/misc/matrix.ts

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ export type MatrixLike<DT extends DataType> = {
2727
* This is a collection of row vectors.
2828
* A special case of Tensor for 2D data.
2929
*/
30-
export class Matrix<DT extends DataType>
31-
extends Tensor<DT, 2>
32-
implements Sliceable, MatrixLike<DT>
33-
{
30+
export class Matrix<DT extends DataType> extends Tensor<DT, 2>
31+
implements Sliceable, MatrixLike<DT> {
3432
/**
3533
* Create a matrix from a typed array
3634
* @param data Data to move into the matrix.
@@ -42,15 +40,15 @@ export class Matrix<DT extends DataType>
4240
constructor(dType: DT, shape: Shape<2>);
4341
constructor(
4442
data: NDArray<DT>[2] | DType<DT> | DT | TensorLike<DT, 2>,
45-
shape?: Shape<2> | DT
43+
shape?: Shape<2> | DT,
4644
) {
4745
// @ts-ignore This call will work
4846
super(data, shape);
4947
}
50-
get head() {
48+
get head(): Matrix<DT> {
5149
return this.slice(0, Math.min(this.nRows, 10));
5250
}
53-
get tail() {
51+
get tail(): Matrix<DT> {
5452
return this.slice(Math.max(this.nRows - 10, 0), this.nRows);
5553
}
5654
/** Convert the Matrix into a HTML table */
@@ -87,7 +85,7 @@ export class Matrix<DT extends DataType>
8785
/** Get the transpose of the matrix. This method clones the matrix. */
8886
get T(): Matrix<DT> {
8987
const resArr = new (this.data.constructor as DTypeConstructor<DT>)(
90-
this.nRows * this.nCols
88+
this.nRows * this.nCols,
9189
) as DType<DT>;
9290
let i = 0;
9391
for (const col of this.cols()) {
@@ -114,7 +112,7 @@ export class Matrix<DT extends DataType>
114112
col(n: number): DType<DT> {
115113
let i = 0;
116114
const col = new (this.data.constructor as DTypeConstructor<DT>)(
117-
this.nRows
115+
this.nRows,
118116
) as DType<DT>;
119117
let offset = 0;
120118
while (i < this.nRows) {
@@ -139,7 +137,7 @@ export class Matrix<DT extends DataType>
139137
/** Get a column array of all column sums in the matrix */
140138
colSum(): DType<DT> {
141139
const sum = new (this.data.constructor as DTypeConstructor<DT>)(
142-
this.nRows
140+
this.nRows,
143141
) as DType<DT>;
144142
let i = 0;
145143
while (i < this.nCols) {
@@ -169,8 +167,7 @@ export class Matrix<DT extends DataType>
169167
while (j < this.nCols) {
170168
let i = 0;
171169
while (i < this.nRows) {
172-
const adder =
173-
(this.item(i, j) as DTypeValue<DT>) *
170+
const adder = (this.item(i, j) as DTypeValue<DT>) *
174171
(rhs.item(i, j) as DTypeValue<DT>);
175172
// @ts-ignore I'll fix this later
176173
res += adder as DTypeValue<DT>;
@@ -182,7 +179,7 @@ export class Matrix<DT extends DataType>
182179
}
183180
/** Filter the matrix by rows */
184181
override filter(
185-
fn: (value: DType<DT>, row: number, _: DType<DT>[]) => boolean
182+
fn: (value: DType<DT>, row: number, _: DType<DT>[]) => boolean,
186183
): Matrix<DT> {
187184
const satisfying: number[] = [];
188185
let i = 0;
@@ -224,7 +221,7 @@ export class Matrix<DT extends DataType>
224221
/** Compute the sum of all rows */
225222
rowSum(): DType<DT> {
226223
const sum = new (this.data.constructor as DTypeConstructor<DT>)(
227-
this.nCols
224+
this.nCols,
228225
) as DType<DT>;
229226
let i = 0;
230227
let offset = 0;
@@ -271,9 +268,9 @@ export class Matrix<DT extends DataType>
271268
return new Matrix<DT>(
272269
this.data.slice(
273270
start ? start * this.nCols : 0,
274-
end ? end * this.nCols : undefined
271+
end ? end * this.nCols : undefined,
275272
) as DType<DT>,
276-
[end ? end - start : this.nRows - start, this.nCols]
273+
[end ? end - start : this.nRows - start, this.nCols],
277274
);
278275
}
279276
/** Iterate through rows */
@@ -290,7 +287,7 @@ export class Matrix<DT extends DataType>
290287
while (i < this.nCols) {
291288
let j = 0;
292289
const col = new (this.data.constructor as DTypeConstructor<DT>)(
293-
this.nRows
290+
this.nRows,
294291
) as DType<DT>;
295292
while (j < this.nRows) {
296293
col[j] = this.data[j * this.nCols + i];

0 commit comments

Comments
 (0)