Skip to content

Commit 0497856

Browse files
authored
Merge pull request #143 from mlaco/accept-iohandler
Accept path to model or a TF IOHandler
2 parents ab7c1e9 + 6841bb7 commit 0497856

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/index.ts

+16-4
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,37 @@ export async function load(base = BASE_PATH, options = { size: IMAGE_SIZE }) {
3535
return nsfwnet
3636
}
3737

38+
interface IOHandler {
39+
load: () => any
40+
}
41+
3842
export class NSFWJS {
3943
public endpoints: string[]
4044

4145
private options: nsfwjsOptions
42-
private path: string
46+
private pathOrIOHandler: string | IOHandler
4347
private model: tf.LayersModel
4448
private intermediateModels: { [layerName: string]: tf.LayersModel } = {}
4549

4650
private normalizationOffset: tf.Scalar
4751

48-
constructor(base: string, options: nsfwjsOptions) {
52+
constructor(
53+
modelPathBaseOrIOHandler: string | IOHandler,
54+
options: nsfwjsOptions
55+
) {
4956
this.options = options
50-
this.path = `${base}model.json`
5157
this.normalizationOffset = tf.scalar(255)
58+
59+
if (typeof modelPathBaseOrIOHandler === 'string') {
60+
this.pathOrIOHandler = `${modelPathBaseOrIOHandler}model.json`
61+
} else {
62+
this.pathOrIOHandler = modelPathBaseOrIOHandler
63+
}
5264
}
5365

5466
async load() {
5567
// this is a Layers Model
56-
this.model = await tf.loadLayersModel(this.path)
68+
this.model = await tf.loadLayersModel(this.pathOrIOHandler)
5769
this.endpoints = this.model.layers.map(l => l.name)
5870
const { size } = this.options
5971

0 commit comments

Comments
 (0)