Skip to content

Commit

Permalink
add model file
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenbin24 committed Nov 2, 2023
1 parent 67c853c commit fa3ff70
Show file tree
Hide file tree
Showing 14 changed files with 2,155 additions and 0 deletions.
1 change: 1 addition & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from stereo.image.segmentation.seg_utils.v1_pro.cell_seg_pipeline_v2 import CellSegPipeV1Pro # noqa
176 changes: 176 additions & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/cell_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import multiprocessing as mp
import os
import time

import cv2
import glog
import numpy as np
import torch
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
from skimage import filters
from tqdm import tqdm

from stereo import logger
from .dataset import data_batch2
from .resnet_unet import EpsaResUnet
from .utils import split_preproc

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def get_transforms():
    """Build the inference-time preprocessing pipeline.

    Returns an ``albumentations.Compose`` whose only step is
    ``ToTensorV2`` (numpy HWC uint8 -> torch CHW tensor); no
    augmentation is applied at inference time.
    """
    # The original assembled the list via two no-op extend() calls;
    # a single literal is equivalent and clearer.
    return Compose([ToTensorV2()])


def cellInfer(file, size, overlap=100):
    """Run cell-segmentation inference over one image or a list of images.

    Per image: remove background with a large median filter (GPU path via
    cupy/cucim when CUDA is available, otherwise a multi-process CPU path),
    tile the image into 256x256 patches, run batched forward passes through
    ``EpsaResUnet``, and stitch the per-tile binary masks back together.

    :param file: a single 2D uint8 image array, or a list of such arrays.
    :param size: unused in this implementation; kept for caller compatibility.
    :param overlap: tile overlap in pixels (note: overridden to 100 below).
    :return: list of stitched uint8 mask arrays, one per input image.

    NOTE(review): the stitch buffer aliases the input array (``merge_label =
    image``) and zero-fills it, so the caller's image is destroyed in place —
    presumably intentional to save memory; confirm with callers.
    """
    # split -> predict -> merge
    if isinstance(file, list):
        file_list = file
    else:
        file_list = [file]

    result = []

    model_path = os.path.join(os.path.split(__file__)[0], 'model')
    model_dir = os.path.join(model_path, 'best_model.pth')
    logger.info(f'CellCut_model infer path {model_dir}...')
    model = EpsaResUnet(out_channels=6)
    glog.info('Load model from: {}'.format(model_dir))
    # map_location keeps the checkpoint loadable on CPU-only machines.
    model.load_state_dict(torch.load(model_dir, map_location=lambda storage, loc: storage), strict=True)
    model.eval()
    logger.info('Load model ok.')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        glog.info('GPU type is {}'.format(torch.cuda.get_device_name(0)))
    glog.info(f"using device: {device}")
    model.to(device)
    for idx, image in enumerate(file_list):
        logger.info(image.shape)

        t1 = time.time()
        if torch.cuda.is_available():
            import cupy
            # BUG FIX: was `from utils import cuda_kernel` (absolute), which
            # fails inside the installed package; use the package-relative
            # form, matching `from .utils import split_preproc` above.
            from .utils import cuda_kernel
            from cucim.skimage.morphology import disk
            logger.info('median filter using gpu')
            image_cp = cupy.asarray(image)
            # Accelerate median using specific cuda kernel function
            median_image = cupy.empty(image.shape, dtype=cupy.uint8)
            (height, width) = image.shape
            cuda_kernel.median_filter_kernel(
                ((width + 15) // 16, (height + 15) // 16),
                (16, 16),
                (image_cp, median_image, width, height, disk(50))
            )

            median_image = np.asarray(median_image.get())
            images = cv2.subtract(image, median_image)
        else:
            logger.info('median filter using cpu')

            # Tile with 100px overlap, filter tiles in a process pool, and
            # paste the filtered interiors into a fresh buffer.
            image_list, m_x_list, m_y_list = split_preproc(image, 1000, 100)
            images = np.zeros(image.shape, dtype=np.uint8)
            median_filter_in_pool_parallel(image_list, images, m_x_list, m_y_list)

        t2 = time.time()
        logger.info('median filter: {}'.format(t2 - t1))

        # accelerate data loader
        overlap = 100  # NOTE(review): overrides the `overlap` parameter.
        dataset = data_batch2(images, 256, overlap)

        # Reuse the input array as the stitch buffer (see docstring NOTE).
        merge_label = image
        merge_label.fill(0)
        # get_list() returns live references; `ori_size` is filled lazily by
        # data_batch2.__getitem__ as the dataloader iterates.
        x_list, y_list, ori_size = dataset.get_list()
        test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)
        img_idx = 0
        for batch in tqdm(test_dataloader, ncols=80):
            img = batch
            img = img.type(torch.FloatTensor)
            img = img.to(device)

            pred_mask = model(img)
            batch_size = len(pred_mask)  # typo fix: was `bacth_size`
            pred_mask = torch.sigmoid(pred_mask).detach().cpu().numpy()
            pred = pred_mask[:, 0, :, :]
            # Channel 0 is background probability: below-threshold -> cell (255).
            pred[:] = (pred[:] < 0.55) * 255
            pred1 = pred.astype(np.uint8)
            for i in range(batch_size):
                # Crop the 256x256 prediction back to the tile's true size.
                temp_img = pred1[i][:ori_size[i + img_idx][0], :ori_size[i + img_idx][1]]
                info = [x_list[i + img_idx], y_list[i + img_idx]]
                h, w = temp_img.shape
                if int(info[0]) == 0 or int(info[1]) == 0:
                    # Border tile: trim only a 1px rim.
                    x_begin = int(info[0])
                    y_begin = int(info[1])
                    temp_data = temp_img[1: - 1, 1: - 1]
                    merge_label[int(x_begin): int(x_begin) + h - 2, int(y_begin): int(y_begin) + w - 2] = temp_data
                else:
                    # Interior tile: drop half the overlap on every side.
                    x_begin = int(info[0]) + overlap // 2
                    y_begin = int(info[1]) + overlap // 2
                    temp_data = temp_img[overlap // 2: - overlap // 2, overlap // 2: - overlap // 2]
                    merge_label[int(x_begin): int(x_begin) + h - overlap,
                                int(y_begin): int(y_begin) + w - overlap] = temp_data  # noqa
            # BUG FIX: advance by the actual batch length rather than the
            # hard-coded 20, so a final partial batch cannot desync indices.
            img_idx += batch_size

        result.append(merge_label)

    return result


def s_median_filter(image):
    """Estimate the background of *image* with a radius-50 disk median
    filter and subtract it (clipped at zero by ``cv2.subtract``)."""
    from skimage.morphology import disk
    background = filters.median(image, disk(50))
    return cv2.subtract(image, background)


def median_filter_in_pool(image_list, images):
    """Median-filter every tile of *image_list* in a 20-worker process pool,
    appending the filtered arrays to *images* (in input order).

    BUG FIX: the original appended the raw ``AsyncResult`` handles to
    ``images``, so callers received unresolved promises instead of arrays
    (the parallel variant below correctly calls ``.get()``). We now wait
    for the pool to finish and append the actual results.
    """
    with mp.Pool(processes=20) as p:
        pending = [p.apply_async(s_median_filter, (tile,)) for tile in image_list]
        p.close()
        p.join()
        # All tasks are complete after join(); get() returns immediately.
        images.extend(res.get() for res in pending)


def median_filter_in_pool_parallel(image_list, images, x_list, y_list):
    """Median-filter tiles in a process pool while a consumer thread pastes
    each filtered tile's interior into the shared output array *images*.

    ``x_list[i]``/``y_list[i]`` give the top-left corner of tile *i*; a 1px
    rim is trimmed from every pasted tile. The consumer relies on tiles
    arriving on the queue in submission order, which matches x_list/y_list.

    :param image_list: list of 2D uint8 tiles (order must match x/y lists).
    :param images: pre-allocated 2D uint8 output array, mutated in place.
    :param x_list: per-tile row offsets into *images*.
    :param y_list: per-tile column offsets into *images*.
    """
    import queue
    q = queue.Queue()

    def worker():
        # Consumes AsyncResult handles in FIFO order; idx tracks which
        # tile offset the next result corresponds to.
        idx = 0
        while True:
            item = q.get()
            if item == 'STOP':  # sentinel: producer is done
                q.task_done()
                break
            # Block until this tile's pool task finishes.
            item = item.get()

            x = x_list[idx]
            y = y_list[idx]
            h, w = item.shape
            # Paste the tile interior, dropping a 1px rim on every side.
            images[x: x + h - 2, y: y + w - 2] = item[1:-1, 1:-1]
            idx += 1
            # del item

            q.task_done()

    import threading
    threading.Thread(target=worker, daemon=True).start()

    with mp.Pool(processes=20) as p:
        for i in image_list:
            # Submit asynchronously; the handle (not the result) is queued
            # so filtering and pasting overlap.
            median_image = p.apply_async(s_median_filter, (i,))
            q.put(median_image)
        p.close()
        p.join()
    q.put('STOP')

    # Wait until the consumer has pasted every tile.
    q.join()
122 changes: 122 additions & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/cell_seg_pipeline_v1_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# import image
import os
import time
from os.path import join

import numpy as np
import tifffile
from skimage import measure

from stereo.image.segmentation.seg_utils.base_cell_seg_pipe.cell_seg_pipeline import CellSegPipe
from stereo.image.segmentation.seg_utils.v1_pro import grade
from stereo.log_manager import logger
from .cell_infer import cellInfer
from .utils import transfer_16bit_to_8bit


class CellSegPipeV1Pro(CellSegPipe):
    """V1-pro cell segmentation pipeline.

    Specializes :class:`CellSegPipe` with a neural-network tissue-area
    inference step (``cellInfer``) and an edge-smoothing post-process
    (``grade.edgeSmooth``). Relies on state initialized by the base class:
    ``img_list``, ``file``, ``file_name``, ``tissue_mask``, ``tissue_num``,
    ``tissue_bbox``, ``post_mask_list``, ``out_path``, ``deep_crop_size``,
    ``overlap``, ``is_water`` — TODO confirm against base-class __init__.
    """

    def tissue_cell_infer(self):
        """cell segmentation in tissue area by neural network"""
        self.tissue_cell_label = []
        for idx, img in enumerate(self.img_list):
            # Crop each tissue bounding box (row0, col0, row1, col1) and
            # run the network only on those regions.
            tissue_bbox = self.tissue_bbox[idx]
            tissue_img = [img[p[0]: p[2], p[1]: p[3]] for p in tissue_bbox]
            label_list = cellInfer(tissue_img, self.deep_crop_size, self.overlap)
            self.tissue_cell_label.append(label_list)
        return 0

    def tissue_label_filter(self, tissue_cell_label):
        """filter cell mask in tissue area

        Zeroes out predicted cell pixels that fall outside the tissue mask;
        when no tissue mask exists, labels pass through unchanged.
        """
        tissue_cell_label_filter = []
        for idx, label in enumerate(tissue_cell_label):
            tissue_bbox = self.tissue_bbox[idx]
            label_filter_list = []
            for i in range(self.tissue_num[idx]):
                if len(self.tissue_mask) != 0:
                    tiss_bbox_tep = tissue_bbox[i]
                    # Elementwise multiply keeps cell pixels only where the
                    # tissue mask (cropped to this bbox) is nonzero.
                    label_filter = np.multiply(
                        label[i],
                        self.tissue_mask[idx][tiss_bbox_tep[0]: tiss_bbox_tep[2], tiss_bbox_tep[1]: tiss_bbox_tep[3]]
                    ).astype(np.uint8)
                    label_filter_list.append(label_filter)
                else:
                    label_filter_list.append(label[i])
            tissue_cell_label_filter.append(label_filter_list)
        return tissue_cell_label_filter

    def run(self):
        """Full pipeline: infer -> tissue-filter -> mosaic -> smooth -> save."""
        logger.info('Start do cell mask, this will take some minutes.')
        t1 = time.time()

        self.tissue_cell_infer()
        t2 = time.time()
        logger.info('Cell inference : %.2f' % (t2 - t1))

        # filter by tissue mask
        tissue_cell_label_filter = self.tissue_label_filter(self.tissue_cell_label)
        t3 = time.time()
        logger.info('Filter by tissue mask : %.2f' % (t3 - t2))

        # mosaic tissue roi
        cell_mask = self.mosaic(tissue_cell_label_filter)
        del tissue_cell_label_filter
        t4 = time.time()
        logger.info('Mosaic tissue roi : %.2f' % (t4 - t3))

        # post process
        self.watershed_score(cell_mask)
        t5 = time.time()
        logger.info('Post-processing : %.2f' % (t5 - t4))

        self.save_cell_mask()
        logger.info('Result saved : %s ' % (self.out_path))

    def save_each_file_result(self, file_name, idx):
        """Write one post-processed mask as <file_name>_mask.tif (or
        _watershed_mask.tif when watershed was used)."""
        mask_name = r'_watershed_mask.tif' if self.is_water else r'_mask.tif'
        # NOTE(review): tifffile.imsave is deprecated in newer tifffile
        # releases in favor of imwrite — confirm pinned version.
        tifffile.imsave(join(self.out_path, file_name + mask_name), self.post_mask_list[idx])

    def save_cell_mask(self):
        """save cell mask from network or watershed"""
        for idx, file in enumerate(self.file):
            file_name, _ = os.path.splitext(file)
            self.save_each_file_result(file_name, idx)

    def watershed_score(self, cell_mask):
        """watershed and score on cell mask by neural network

        Despite the name, this implementation only edge-smooths each mask
        (grade.edgeSmooth) and stores the result in post_mask_list.
        """
        # NOTE: the loop variable deliberately shadows the parameter.
        for idx, cell_mask in enumerate(cell_mask):
            post_mask = grade.edgeSmooth(cell_mask)
            self.post_mask_list.append(post_mask)

    def get_roi(self):
        """Populate tissue_num/tissue_bbox from the tissue masks; with no
        mask, treat the whole first image as a single ROI."""
        if len(self.tissue_mask) == 0:
            self.tissue_num.append(1)
            self.tissue_bbox.append([(0, 0, self.img_list[0].shape[0], self.img_list[0].shape[1])])
        else:
            for idx, tissue_mask in enumerate(self.tissue_mask):
                label_image = measure.label(tissue_mask, connectivity=2)
                props = measure.regionprops(label_image, intensity_image=self.img_list[idx])

                # remove noise tissue mask
                # (no filtering is currently applied, so the rebuild branch
                # below is effectively dead code)
                filtered_props = props
                if len(props) != len(filtered_props):
                    tissue_mask_filter = np.zeros((tissue_mask.shape), dtype=np.uint8)
                    for tissue_tile in filtered_props:
                        bbox = tissue_tile['bbox']
                        tissue_mask_filter[bbox[0]: bbox[2], bbox[1]: bbox[3]] += tissue_tile['image']
                    self.tissue_mask[idx] = np.uint8(tissue_mask_filter > 0)
                self.tissue_num.append(len(filtered_props))
                self.tissue_bbox.append([p['bbox'] for p in filtered_props])

    def trans16to8(self):
        """Convert any uint16 input image to uint8 in place (img_list)."""
        for idx, img in enumerate(self.img_list):
            assert img.dtype in ['uint16', 'uint8']
            if img.dtype != 'uint8':
                logger.info('%s transfer to 8bit' % self.file[idx])
                self.img_list[idx] = transfer_16bit_to_8bit(img)

    def get_tissue_mask(self, tissue_seg_model_path, tissue_seg_method):
        """Load a precomputed tissue-cut mask from out_path if present;
        otherwise leave tissue_mask empty (whole-image ROI fallback).
        Both parameters are currently unused in this override."""
        try:
            self.tissue_mask = [tifffile.imread(os.path.join(self.out_path, self.file_name[0] + '_tissue_cut.tif'))]
        except Exception:
            self.tissue_mask = []
82 changes: 82 additions & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import math

import cv2
import numpy as np
import torch
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset


def get_transforms():
    """Build the dataset preprocessing pipeline.

    Returns an ``albumentations.Compose`` whose only step is
    ``ToTensorV2`` (numpy HWC uint8 -> torch CHW tensor).
    """
    # The original assembled the list via a no-op extend([]) plus a
    # one-element extend; a single literal is equivalent and clearer.
    return Compose([ToTensorV2()])


class data_batch(Dataset):
    """Dataset over a list of pre-cut image tiles.

    Each item is converted to a tensor and stacked with itself along the
    channel axis, doubling the channel count for the network input.
    """

    def __init__(self, img_list):
        self.transforms = get_transforms()
        self.img_list = img_list

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        tile = self.img_list[idx]
        tensor = self.transforms(image=tile)['image']
        # Duplicate along the channel axis: (C, H, W) -> (2C, H, W).
        return torch.cat((tensor, tensor), 0)


class data_batch2(Dataset):
    """Dataset that tiles one large grayscale image into overlapping
    cut_size x cut_size patches for batched inference.

    WARNING: ``ori_size`` is filled lazily inside ``__getitem__`` while the
    dataloader iterates, and ``get_list()`` returns the live list reference
    — callers (see cellInfer) must call get_list() first and only index
    ori_size for items already yielded. This also breaks with num_workers>0
    (mutation happens in worker processes) — TODO confirm single-process use.
    """

    def __init__(self, raw_img, cut_size, overlap):
        self.transforms = get_transforms()

        # Tile the image on a (cut_size - overlap) stride; edge tiles are
        # clipped to the image bounds and may be smaller than cut_size.
        shapes = raw_img.shape
        x_nums = math.ceil(shapes[0] / (cut_size - overlap))
        y_nums = math.ceil(shapes[1] / (cut_size - overlap))
        self.x_list = []  # per-tile row offsets in the source image
        self.y_list = []  # per-tile column offsets in the source image
        self.img_list = []  # the tiles themselves (views into raw_img)
        for x_temp in range(x_nums):
            for y_temp in range(y_nums):
                x_begin = max(0, x_temp * (cut_size - overlap))
                y_begin = max(0, y_temp * (cut_size - overlap))
                x_end = min(x_begin + cut_size, shapes[0])
                y_end = min(y_begin + cut_size, shapes[1])
                i = raw_img[x_begin: x_end, y_begin: y_end]
                self.x_list.append(x_begin)
                self.y_list.append(y_begin)
                self.img_list.append(i)

        # Filled lazily by __getitem__ with each tile's true (h, w),
        # used by the caller to crop padded predictions back down.
        self.ori_size = []

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img = self.img_list[idx]
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        # Record the pre-padding size (side effect — see class docstring).
        self.ori_size.append([img.shape[0], img.shape[1]])

        # Zero-pad edge tiles up to a fixed 256x256 network input.
        pad_img = np.full((256, 256, 3), 0, dtype='uint8')
        pad_img[:img.shape[0], :img.shape[1], :] = img

        augmented = self.transforms(image=pad_img)
        pad_img = augmented['image']

        # Duplicate along the channel axis for the network input.
        image = torch.cat((pad_img, pad_img), 0)

        return image

    def get_list(self):
        """Return (x_list, y_list, ori_size) — live references, not copies."""
        return (self.x_list, self.y_list, self.ori_size)
Loading

0 comments on commit fa3ff70

Please sign in to comment.