
Commit
change predictview setup style
Nanguage committed Feb 12, 2024
1 parent 7e56057 commit 4bc1fd2
Showing 2 changed files with 133 additions and 93 deletions.
src/stores/run.ts (31 additions, 0 deletions)
@@ -0,0 +1,31 @@
import { defineStore } from "pinia";
import { computed } from "vue";

export async function waitState(stateFn: () => any, value: any) {
const stateValue = computed(stateFn); // reactive wrapper around the watched state

while (stateValue.value !== value) {
await new Promise((resolve) => setTimeout(resolve, 100)); // polling interval
}

console.log("State changed to the specific value");
}

export async function waitRunable() {
const store = useRunStore();
await waitState(() => store.runable, true);
}

export const useRunStore = defineStore("run", {
state: () => ({
queryCount: 0,
runable: false,
}),
actions: {
async run() {
await waitRunable();
this.queryCount += 1;
await waitRunable();
}
}
})
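
For context, a minimal usage sketch of the new run store (not part of this commit; it assumes Pinia is registered on the app and that an @/stores/run path alias resolves to src/stores/run.ts):

import { useRunStore } from "@/stores/run";

// Hypothetical caller, e.g. invoked from PredictView once the plugin and
// ONNX session have finished loading (the function name is illustrative).
async function queueRun() {
  const runStore = useRunStore();
  runStore.runable = true;          // open the gate so waitRunable() can resolve
  await runStore.run();             // waits on runable, then bumps queryCount
  console.log(runStore.queryCount); // 1 after the first call
}
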
src/views/PredictView.vue (102 additions, 93 deletions)
@@ -33,20 +33,116 @@
<script lang="ts">
import * as ort from 'onnxruntime-web';
import { getImjoyApi, isPluginMode, downloadBlob } from '@/utils';
import { ref, onMounted } from 'vue';
export default {
setup() {
const modelUrl = window.location.origin + "/model/v1.0-alldata-ufish_c32.onnx"
const running = ref(false)
const plugin = ref(null as any)
const ortSession = ref(null as ort.InferenceSession | null)
const runInfoText = ref("loading...")
const hasError = ref(false)
async function loadPlugin() {
const imjoy_api = await getImjoyApi()
imjoy_api.log("Hello from Imjoy!")
const url = window.location.origin + "/plugins/ufish.imjoy.html"
plugin.value = await imjoy_api.loadPlugin({ src: url, })
}
async function infer_2d(input: ort.Tensor) {
if (input.type !== 'float32') {
throw new Error('Input tensor must be of type float32')
}
if (input.dims.length !== 2) {
throw new Error('Input tensor must be 2D')
}
const input4d = input.reshape([1, 1, input.dims[0], input.dims[1]])
const session = ortSession
if (session.value === null) {
throw new Error('ONNX session not loaded')
}
const feeds = {'input': input4d}
const { output } = await session.value.run(feeds)
const out = output.reshape([output.dims[2], output.dims[3]])
return out
}
async function loadOrtSession() {
ort.env.wasm.numThreads = 4
ort.env.wasm.simd = true
const session = await ort.InferenceSession.create(
modelUrl,
{ executionProviders: ['wasm'] })
ortSession.value = session
}
onMounted(() => {
const plugP = loadPlugin()
plugP.then(() => {
runInfoText.value = "Plugin loaded"
})
plugP.catch((e) => {
runInfoText.value = "Failed to load plugin, see console for more detail."
console.log(e)
hasError.value = true
})
loadOrtSession().catch((e) => {
runInfoText.value = "Failed to load ONNX session, see console for more detail."
console.log(e)
hasError.value = true
})
})
async function run() {
running.value = true
const sImg = await plugin.value.scale_image()
const f32data = new Float32Array(sImg._rvalue)
const input = new ort.Tensor(
sImg._rdtype, f32data, sImg._rshape);
let output;
try {
output = await infer_2d(input)
const outImg = {
_rtype: "ndarray",
_rdtype: output.type,
_rshape: output.dims,
_rvalue: (output.data as Float32Array).buffer
}
const [enhBytes, coords, numSpots] = await plugin.value.process_enhanced(outImg)
output = {
enhanced: enhBytes,
coords: coords
}
runInfoText.value = `Done, ${numSpots} spots detected.`
running.value = false
} catch (error) {
console.log(error)
runInfoText.value = "Failed to run inference, see console for more detail."
running.value = false
hasError.value = true
return
}
}
return {
plugin,
run,
running,
hasError,
runInfoText,
ortSession,
}
},
data: () => ({
loadingData: false,
hasData: false,
hasError: false,
running: false,
plugin: null as any,
modelLoaded: false,
modelUrl: window.location.origin + "/model/v1.0-alldata-ufish_c32.onnx",
ortSession: null as ort.InferenceSession | null,
test_data_url: "https://huggingface.co/datasets/NaNg/TestData/resolve/main/FISH_spots/MERFISH_1.tif",
output: null as any,
runInfoText: "loading...",
showViewer: !isPluginMode(),
}),
computed: {
@@ -62,40 +158,7 @@ export default {
return this.output !== null
}
},
created() {
const loadPluginPromise = this.loadPlugin()
loadPluginPromise.then(() => {
this.runInfoText = "Plugin loaded"
})
loadPluginPromise.catch((e) => {
this.runInfoText = "Failed to load plugin, see console for more detail."
console.log(e)
this.hasError = true
})
const loadSessionPromise = this.loadOrtSession()
loadSessionPromise.catch((e) => {
this.runInfoText = "Failed to load ONNX session, see console for more detail."
console.log(e)
this.hasError = true
})
},
methods: {
async loadPlugin() {
const imjoy_api = await getImjoyApi()
imjoy_api.log("Hello from Imjoy!")
const url = window.location.origin + "/plugins/ufish.imjoy.html"
this.plugin = await imjoy_api.loadPlugin({ src: url, })
},
async loadOrtSession() {
ort.env.wasm.numThreads = 4
ort.env.wasm.simd = true
const session = await ort.InferenceSession.create(
this.modelUrl,
{ executionProviders: ['wasm'] })
this.ortSession = session
},
async loadExample() {
this.loadingData = true
if (this.plugin !== null) {
@@ -152,60 +215,6 @@ export default {
fileInput.click()
},
async scaleImage() {
const image = await this.plugin.scale_image()
return image
},
async infer_2d(input: ort.Tensor) {
if (input.type !== 'float32') {
throw new Error('Input tensor must be of type float32')
}
if (input.dims.length !== 2) {
throw new Error('Input tensor must be 2D')
}
const input4d = input.reshape([1, 1, input.dims[0], input.dims[1]])
const session = this.ortSession
if (session === null) {
throw new Error('ONNX session not loaded')
}
const feeds = {'input': input4d}
const { output } = await session.run(feeds)
const out = output.reshape([output.dims[2], output.dims[3]])
return out
},
async run() {
this.running = true
const sImg = await this.scaleImage()
const f32data = new Float32Array(sImg._rvalue)
const input = new ort.Tensor(
sImg._rdtype, f32data, sImg._rshape);
let output;
try {
output = await this.infer_2d(input)
const outImg = {
_rtype: "ndarray",
_rdtype: output.type,
_rshape: output.dims,
_rvalue: (output.data as Float32Array).buffer
}
const [enhBytes, coords, numSpots] = await this.plugin.process_enhanced(outImg)
this.output = {
enhanced: enhBytes,
coords: coords
}
this.runInfoText = `Done, ${numSpots} spots detected.`
this.running = false
} catch (error) {
console.log(error)
this.runInfoText = "Failed to run inference, see console for more detail."
this.running = false
this.hasError = true
return
}
},
async download() {
downloadBlob(this.output.enhanced, "enhanced.tif", "image/tiff")
downloadBlob(this.output.coords, "coords.csv", "text/csv;charset=utf-8;")
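
As a standalone illustration of the shape handling in infer_2d above (a sketch only, not part of the commit; it reuses the Tensor constructor and reshape calls that appear in the diff and assumes the model takes single-channel NCHW input):

import * as ort from 'onnxruntime-web';

// A 2D float32 image gets batch and channel dimensions added before inference,
// and the [1, 1, H, W] result is flattened back to a 2D image afterwards.
const h = 4, w = 5;
const img2d = new ort.Tensor('float32', new Float32Array(h * w), [h, w]);
const input4d = img2d.reshape([1, 1, h, w]);                        // dims [1, 1, 4, 5]
const back2d = input4d.reshape([input4d.dims[2], input4d.dims[3]]); // dims [4, 5]
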
