Skip to content

Commit

Permalink
【Fix】 fix cell_segment_v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenbin24 committed Dec 4, 2023
1 parent 007fb7f commit ec99a42
Show file tree
Hide file tree
Showing 16 changed files with 2,366 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ def convert_gray(self):
logger.info('Image %s convert to gray!' % self.file[idx])
self.img_list[idx] = img[:, :, 0]

def get_img_filter(self):
"""get tissue image by tissue mask"""
for img, tissue_mask in zip(self.img_list, self.tissue_mask):
img_filter = np.multiply(img, tissue_mask).astype(np.uint8)
self.img_filter.append(img_filter)

@staticmethod
def transfer_32bit_to_8bit(image_32bit):
min_32bit = np.min(image_32bit)
Expand Down
1 change: 1 addition & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from stereo.image.segmentation.seg_utils.v1_pro.cell_seg_pipeline_v1_pro import CellSegPipeV1Pro # noqa
159 changes: 159 additions & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/cell_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import multiprocessing as mp
import os
import time

import cv2
import glog
import numpy as np
# import tensorflow as tf
import torch
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
from skimage import filters
from tqdm import tqdm

from stereo import logger
from .dataset import data_batch2
from .resnet_unet import EpsaResUnet
from .utils import (
split_preproc
)

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def get_transforms():
list_transforms = []

list_transforms.extend([])
list_transforms.extend([ToTensorV2(), ])
list_trfms = Compose(list_transforms)
return list_trfms


def cellInfer(file, size, overlap=100):
if isinstance(file, list):
file_list = file
else:
file_list = [file]

result = []

model_path = os.path.join(os.path.split(__file__)[0], 'model')
model_dir = os.path.join(model_path, 'best_model.pth')
logger.info(f'CellCut_model infer path {model_dir}...')
model = EpsaResUnet(out_channels=6)
glog.info('Load model from: {}'.format(model_dir))
model.load_state_dict(torch.load(model_dir, map_location=lambda storage, loc: storage), strict=True)
model.eval()
logger.info('Load model ok.')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
glog.info('GPU type is {}'.format(torch.cuda.get_device_name(0)))
glog.info(f"using device: {device}")
model.to(device)
for idx, image in enumerate(file_list):
logger.info(image.shape)

t1 = time.time()
logger.info('median filter using cpu')

image_list, m_x_list, m_y_list = split_preproc(image, 1000, 100)
images = np.zeros(image.shape, dtype=np.uint8)
images.fill(0)
median_filter_in_pool_parallel(image_list, images, m_x_list, m_y_list)

t2 = time.time()
logger.info('median filter: {}'.format(t2 - t1))

# accelerate data loader
overlap = 100
dataset = data_batch2(images, 256, overlap)

merge_label = image
merge_label.fill(0)
x_list, y_list, ori_size = dataset.get_list()
test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=20)
img_idx = 0
for batch in tqdm(test_dataloader, ncols=80):
img = batch
img = img.type(torch.FloatTensor)
img = img.to(device)

pred_mask = model(img)
bacth_size = len(pred_mask)
pred_mask = torch.sigmoid(pred_mask).detach().cpu().numpy()
pred = pred_mask[:, 0, :, :]
pred[:] = (pred[:] < 0.55) * 255
pred1 = pred.astype(np.uint8)
for i in range(bacth_size):
temp_img = pred1[i][:ori_size[i + img_idx][0], :ori_size[i + img_idx][1]]
info = [x_list[i + img_idx], y_list[i + img_idx]]
h, w = temp_img.shape
if int(info[0]) == 0 or int(info[1]) == 0:
x_begin = int(info[0])
y_begin = int(info[1])
merge_label[int(x_begin): int(x_begin) + h - 2, int(y_begin): int(y_begin) + w - 2] = \
temp_img[1: - 1, 1: - 1]
else:
x_begin = int(info[0]) + overlap // 2
y_begin = int(info[1]) + overlap // 2
merge_label[int(x_begin): int(x_begin) + h - overlap, int(y_begin): int(y_begin) + w - overlap] = \
temp_img[overlap // 2: - overlap // 2, overlap // 2: - overlap // 2]
img_idx += 20

result.append(merge_label)

return result


def s_median_filter(image):
from skimage.morphology import disk
m_image = filters.median(image, disk(50))
m_image = cv2.subtract(image, m_image)
return m_image


def median_filter_in_pool(image_list, images):
with mp.Pool(processes=20) as p:
for i in image_list:
median_image = p.apply_async(s_median_filter, (i,))
images.append(median_image)
p.close()
p.join()


def median_filter_in_pool_parallel(image_list, images, x_list, y_list):
import queue
q = queue.Queue()

def worker():
idx = 0
while True:
item = q.get()
if item == 'STOP':
q.task_done()
break
item = item.get()

x = x_list[idx]
y = y_list[idx]
h, w = item.shape
images[x: x + h - 2, y: y + w - 2] = item[1:-1, 1:-1]
idx += 1

q.task_done()

import threading
threading.Thread(target=worker, daemon=True).start()

with mp.Pool(processes=20) as p:
for i in image_list:
median_image = p.apply_async(s_median_filter, (i,))
q.put(median_image)
p.close()
p.join()
q.put('STOP')

q.join()
108 changes: 108 additions & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/cell_seg_pipeline_v1_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# import image
import os
import time
from os.path import join

import numpy as np
import tifffile
from skimage import measure

from stereo.image.segmentation.seg_utils.base_cell_seg_pipe.cell_seg_pipeline import CellSegPipe
from stereo.image.segmentation.seg_utils.v1_pro import grade
from stereo.log_manager import logger
from .cell_infer import cellInfer


class CellSegPipeV1Pro(CellSegPipe):

def tissue_cell_infer(self):

"""cell segmentation in tissue area by neural network"""
self.tissue_cell_label = []
for idx, img in enumerate(self.img_filter):
tissue_bbox = self.tissue_bbox[idx]
tissue_img = [img[p[0]: p[2], p[1]: p[3]] for p in tissue_bbox]

label_list = cellInfer(tissue_img, self.deep_crop_size, self.overlap)
self.tissue_cell_label.append(label_list)
return 0

def tissue_label_filter(self, tissue_cell_label):
"""filter cell mask in tissue area"""
tissue_cell_label_filter = []
for idx, label in enumerate(tissue_cell_label):
tissue_bbox = self.tissue_bbox[idx]
label_filter_list = []
for i in range(self.tissue_num[idx]):
if len(self.tissue_mask) != 0:
tiss_bbox_tep = tissue_bbox[i]
label_filter = np.multiply(
label[i],
self.tissue_mask[idx][tiss_bbox_tep[0]: tiss_bbox_tep[2], tiss_bbox_tep[1]: tiss_bbox_tep[3]]
).astype(np.uint8)
label_filter_list.append(label_filter)
else:
label_filter_list.append(label[i])
tissue_cell_label_filter.append(label_filter_list)
return tissue_cell_label_filter

def run(self):
logger.info('Start do cell mask, the method is v1_pro, this will take some minutes.')
self.get_img_filter()

t1 = time.time()

self.tissue_cell_infer()
t2 = time.time()
logger.info('Cell inference : %.2f' % (t2 - t1))

# filter by tissue mask
tissue_cell_label_filter = self.tissue_label_filter(self.tissue_cell_label)
t3 = time.time()
logger.info('Filter by tissue mask : %.2f' % (t3 - t2))

# mosaic tissue roi
cell_mask = self.mosaic(tissue_cell_label_filter)
del tissue_cell_label_filter
t4 = time.time()
logger.info('Mosaic tissue roi : %.2f' % (t4 - t3))

# post process
self.watershed_score(cell_mask)
t5 = time.time()
logger.info('Post-processing : %.2f' % (t5 - t4))

self.save_cell_mask()
logger.info('Result saved : %s ' % (self.out_path))

def save_each_file_result(self, file_name, idx):
mask_name = r'_watershed_mask.tif' if self.is_water else r'_mask.tif'
tifffile.imsave(join(self.out_path, file_name + mask_name), self.post_mask_list[idx])

def save_cell_mask(self):
"""save cell mask from network or watershed"""
for idx, file in enumerate(self.file):
file_name, _ = os.path.splitext(file)
self.save_each_file_result(file_name, idx)

def watershed_score(self, cell_mask):
"""watershed and score on cell mask by neural network"""
for idx, cell_mask in enumerate(cell_mask):
post_mask = grade.edgeSmooth(cell_mask)
self.post_mask_list.append(post_mask)

def get_roi(self):
for idx, tissue_mask in enumerate(self.tissue_mask):
label_image = measure.label(tissue_mask, connectivity=2)
props = measure.regionprops(label_image, intensity_image=self.img_list[idx])

# remove noise tissue mask
filtered_props = props # self.__filter_roi(props)
if len(props) != len(filtered_props):
tissue_mask_filter = np.zeros((tissue_mask.shape), dtype=np.uint8)
for tissue_tile in filtered_props:
bbox = tissue_tile['bbox']
tissue_mask_filter[bbox[0]: bbox[2], bbox[1]: bbox[3]] += tissue_tile['image']
self.tissue_mask[idx] = np.uint8(tissue_mask_filter > 0)
self.tissue_num.append(len(filtered_props))
self.tissue_bbox.append([p['bbox'] for p in filtered_props])
86 changes: 86 additions & 0 deletions stereo/image/segmentation/seg_utils/v1_pro/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import math

import cv2
import numpy as np
import torch
from albumentations import (
Compose
)
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset


def get_transforms():
list_transforms = []

list_transforms.extend([])

list_transforms.extend([ToTensorV2(), ])
list_trfms = Compose(list_transforms)
return list_trfms


class data_batch(Dataset):

def __init__(self, img_list):
self.transforms = get_transforms()
self.img_list = img_list

def __len__(self):
return len(self.img_list)

def __getitem__(self, idx):
img = self.img_list[idx]

augmented = self.transforms(image=img)
img = augmented['image']

image = torch.cat((img, img), 0)

return image


class data_batch2(Dataset):

def __init__(self, raw_img, cut_size, overlap):
self.transforms = get_transforms()

shapes = raw_img.shape
x_nums = math.ceil(shapes[0] / (cut_size - overlap))
y_nums = math.ceil(shapes[1] / (cut_size - overlap))
self.x_list = []
self.y_list = []
self.img_list = []
for x_temp in range(x_nums):
for y_temp in range(y_nums):
x_begin = max(0, x_temp * (cut_size - overlap))
y_begin = max(0, y_temp * (cut_size - overlap))
x_end = min(x_begin + cut_size, shapes[0])
y_end = min(y_begin + cut_size, shapes[1])
i = raw_img[x_begin: x_end, y_begin: y_end]
self.x_list.append(x_begin)
self.y_list.append(y_begin)
self.img_list.append(i)

self.ori_size = []

def __len__(self):
return len(self.img_list)

def __getitem__(self, idx):
img = self.img_list[idx]
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
self.ori_size.append([img.shape[0], img.shape[1]])

pad_img = np.full((256, 256, 3), 0, dtype='uint8')
pad_img[:img.shape[0], :img.shape[1], :] = img

augmented = self.transforms(image=pad_img)
pad_img = augmented['image']

image = torch.cat((pad_img, pad_img), 0)

return image

def get_list(self):
return (self.x_list, self.y_list, self.ori_size)
Loading

0 comments on commit ec99a42

Please sign in to comment.