diff --git a/src/index.ts b/src/index.ts index 801e41da..51c8e8e2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -35,25 +35,37 @@ export async function load(base = BASE_PATH, options = { size: IMAGE_SIZE }) { return nsfwnet } +interface IOHandler { + load: () => any +} + export class NSFWJS { public endpoints: string[] private options: nsfwjsOptions - private path: string + private pathOrIOHandler: string | IOHandler private model: tf.LayersModel private intermediateModels: { [layerName: string]: tf.LayersModel } = {} private normalizationOffset: tf.Scalar - constructor(base: string, options: nsfwjsOptions) { + constructor( + modelPathBaseOrIOHandler: string | IOHandler, + options: nsfwjsOptions + ) { this.options = options - this.path = `${base}model.json` this.normalizationOffset = tf.scalar(255) + + if (typeof modelPathBaseOrIOHandler === 'string') { + this.pathOrIOHandler = `${modelPathBaseOrIOHandler}model.json` + } else { + this.pathOrIOHandler = modelPathBaseOrIOHandler + } } async load() { // this is a Layers Model - this.model = await tf.loadLayersModel(this.path) + this.model = await tf.loadLayersModel(this.pathOrIOHandler) this.endpoints = this.model.layers.map(l => l.name) const { size } = this.options