Skip to content

Commit

Permalink
Allow specifying the channel
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Apr 23, 2024
1 parent beff518 commit c9847a0
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 23 deletions.
50 changes: 42 additions & 8 deletions public/plugins/chatbot-extension.imjoy.html
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
<script lang="javascript">
const chatbot_url = "https://bioimage.io/chat/"
const ufish_url = "https://ufish-team.github.io/"
//const ufish_url = "http://localhost:5173/"
const utils_url = ufish_url + "/plugins/chatbot-utils.imjoy.html"
const py_console_url = "https://nanguage.github.io/web-python-console/"

Expand All @@ -38,7 +39,10 @@
if (chatbot) {
await this.registerExtensions(chatbot.registerExtension)
} else {
chatbot = await api.createWindow({src: chatbot_url, name: "BioImage.IO Chatbot"})
chatbot = await api.createWindow({
src: chatbot_url, name: "BioImage.IO Chatbot",
w: 20, h: 40
})
await this.registerExtensions(chatbot.registerExtension)
}
}
Expand All @@ -54,6 +58,7 @@
client = await api.createWindow({
src: ufish_url,
name: "ufish",
w: 25, h: 8
});
}
return client
Expand All @@ -75,6 +80,7 @@
plugin = await api.createWindow({
src: py_console_url,
name: "web-python-console",
w: 25, h: 30
});
}
return plugin
Expand All @@ -92,13 +98,28 @@
title: "RunUfish",
description: "Run U-FISH web client to detect spots in the image",
properties: {
},
};
"channel": {
type: "number",
description: "Channel to use",
default: 0,
},
"pThreshold": {
type: "number",
description: "Threshold for spot detection",
default: 0.5,
},
"viewEnhanced": {
type: "boolean",
description: "View enhanced image",
default: false,
},
}
}
},
execute: async (config) => {
const client = await this.getUfishClient()
await client.waitReady();
await client.runPredict();
await client.runPredict(config.channel, config.pThreshold, config.viewEnhanced);
const output = await client.getOutput();
this.ufish_spots = output.spots
return "Done"
Expand All @@ -119,26 +140,39 @@
"diameter": {
type: "number",
description: "Diameter of the cells",
default: 60,
default: 40,
},
"model_type": {
type: "string",
description: "Model type, cyto or nuclei",
default: "cyto",
description: "Model type, nuclei or cyto",
default: "nuclei",
},
"flow_threshold": {
type: "number",
description: "Flow threshold",
default: 0.4,
},
"channel": {
type: "number",
description: "Channel to use",
default: 2,
}
},
};
},
execute: async (config) => {
const client = await this.getUfishClient()
const utils = await this.getChatbotUtils()
const img = await client.getInputImage()
await utils.run_cellpose(img, {diameter: config.diameter})
await utils.run_cellpose(
img,
config.channel,
{
diameter: config.diameter,
model_type: config.model_type,
flow_threshold: config.flow_threshold,
}
)
await utils.compute_mask_coutours()
const coutours = await utils.get_coutours()
const kaibu = await api.getWindow("Kaibu")
Expand Down
10 changes: 6 additions & 4 deletions public/plugins/chatbot-utils.imjoy.html
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,17 @@
mask = await run_model(model, img, kwargs)
self.mask = mask

async def run_cellpose(self, img, channel, kwargs):
    """Run Cellpose segmentation on one channel of the input image.

    Parameters:
        img: input array, 2D (H, W) or 3D with a channel axis.
        channel: index of the channel to segment when ``img`` is 3D.
        kwargs: options forwarded to the model runner
            (e.g. diameter, model_type, flow_threshold).

    Stores the resulting mask on ``self.mask``; returns nothing.
    Raises ValueError for arrays with more than 3 dimensions.
    """
    img = img.astype("float32")
    if img.ndim == 3:
        # Multi-channel image: assume the channel axis is the smallest
        # dimension — TODO confirm this holds for all callers.
        min_channel = np.argmin(img.shape)
        img = np.moveaxis(img, min_channel, 0)
        img = img[channel]
    elif img.ndim > 3:
        raise ValueError("Image should be 2D or 3D")
    # run_model expects a leading singleton channel axis: (1, H, W).
    img = img[None, :, :]
    mask = await run_model("cellpose", img, kwargs)
    self.mask = mask

Expand Down
12 changes: 8 additions & 4 deletions public/plugins/ufish.imjoy.html
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
src="https://kaibu.org/#/app",
window_id="kaibu-container",
name="Kaibu",
w=38, h=20
)
self.image = None # input image
self.enhanced_image = None # enhanced image
Expand Down Expand Up @@ -159,10 +160,10 @@
await self.viewer.view_image(image, name=name)
return image.shape

async def process_enhanced(self, enh_img, view_enhanced=True):
async def process_enhanced(self, enh_img, view_enhanced=True, pThreshold=0.5):
if view_enhanced:
await self.viewer.view_image(enh_img, name="enhanced")
df = call_spots_local_maxima(enh_img)
df = call_spots_local_maxima(enh_img, intensity_threshold=pThreshold)
coords = df.values
self.enhanced_image = enh_img
self.spots = coords
Expand All @@ -178,11 +179,14 @@
async def get_spots(self):
return self.spots

async def scale_image(self):
async def scale_image(self, channel=None):
image = self.image
if len(image.shape) == 3:
channel_axis = np.argmin(image.shape)
image = image.mean(axis=channel_axis)
if channel is not None:
image = image.take(channel, axis=channel_axis)
else:
image = image.mean(axis=channel_axis)
elif len(image.shape) != 2:
raise ValueError(
f"Image has {len(image.shape)} dimensions. "
Expand Down
7 changes: 4 additions & 3 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ async function createApi() {
await runStore.waitRunable()
}

async function run() {
async function run(channel=null, pThreshold=0.5, viewEnhanced=true) {
const runStore = useRunStore()
runStore.setParams(channel, pThreshold, viewEnhanced)
await runStore.run()
}

Expand All @@ -60,10 +61,10 @@ async function createApi() {
return res
}

async function predict(inputImage: object) {
async function predict(inputImage: object, channel=null, pThreshold=0.5, viewEnhanced=true) {
await waitRunable()
await setInputImage(inputImage, "input")
await run()
await run(channel, pThreshold, viewEnhanced)
return await getOutput()
}

Expand Down
8 changes: 8 additions & 0 deletions src/stores/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ export const useRunStore = defineStore("run", {
imageUrl: null as string | null,
fetchedImage: null as object | null,
fetchGetable: false,
channel: null as (number | null),
pThreshold: 0.5,
viewEnhanced: true,
}),
actions: {
async waitRunable() {
Expand Down Expand Up @@ -87,6 +90,11 @@ export const useRunStore = defineStore("run", {
async getFetchedImage() {
await waitFetchGetable();
return this.fetchedImage;
},
// Record the detection parameters to be used by the next run.
setParams(channel: number | null, pThreshold: number, viewEnhanced: boolean) {
  Object.assign(this, { channel, pThreshold, viewEnhanced });
}
}
})
11 changes: 7 additions & 4 deletions src/views/PredictView.vue
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ export default {
})
function checkInputShape(shape: number[]) {
if (shape.length === 3 && shape[2] === 3) {
if (shape.length === 3 && ((shape[2] === 3) || (shape[2] === 4))) {
runInfoText.value = `Image loaded, shape: ${shape}. Will treat it as an RGB image.`
hasError.value = false
} else if (shape.length !== 2) {
Expand All @@ -120,8 +120,11 @@ export default {
async function run() {
running.value = true
const channel = runStore.channel
const pThreshold = runStore.pThreshold
const viewEnhanced = runStore.viewEnhanced
try {
const sImg = await plugin.value.scale_image()
const sImg = await plugin.value.scale_image(channel)
const f32data = new Float32Array(sImg._rvalue)
const input = new ort.Tensor(
sImg._rdtype, f32data, sImg._rshape);
Expand All @@ -132,8 +135,8 @@ export default {
_rshape: modelOut.dims,
_rvalue: (modelOut.data as Float32Array).buffer
}
const viewEnhanced = !isPluginMode()
const [enhBytes, coords, numSpots] = await plugin.value.process_enhanced(outImg, viewEnhanced)
const [enhBytes, coords, numSpots] = await plugin.value.process_enhanced(
outImg, viewEnhanced, pThreshold)
output.value = {
enhanced: enhBytes,
coords: coords
Expand Down

0 comments on commit c9847a0

Please sign in to comment.