-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
171 lines (141 loc) · 6.76 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import numpy as np
import pandas as pd
from copy import deepcopy
import torch
from torch.utils import data
from torchvision import transforms
from torchtoolbox.transform import Cutout
import cv2
import scanpy as sc
from utils import load_ST_file, adata_preprocess_pca, adata_preprocess_hvg, extract_wash_patches, build_her2st_data
class Dataset(data.Dataset):
def __init__(self, dataset, path, name, gene_preprocess='pca', n_genes=3000,
prob_mask=0.5, pct_mask=0.2, prob_noise=0.5, pct_noise=0.8, sigma_noise=0.5,
prob_swap=0.5, pct_swap=0.1, img_size=112, train=True):
super(Dataset, self).__init__()
# adata, label, image
if dataset == "SpatialLIBD":
adata = load_ST_file(os.path.join(path, name))
df_meta = pd.read_csv(os.path.join(path, name, 'metadata.tsv'), sep='\t')
self.label = pd.Categorical(df_meta['layer_guess']).codes
# image
full_image = cv2.imread(os.path.join(path, name, f'{name}_full_image.tif'))
full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)
patches = []
for x, y in adata.obsm['spatial']:
patches.append(full_image[y-img_size:y+img_size, x-img_size:x+img_size])
patches = np.array(patches)
self.image = patches
elif dataset == "Her2st":
adata, patches = build_her2st_data(path, name, img_size)
self.label = adata.obs['label']
self.image = patches
elif dataset in ["Mouse_brain_anterior"]:
adata = sc.read_h5ad(os.path.join(path, f"{name}.h5ad"))
adata.X = adata.X.A
x = adata.obs['x4'].values
y = adata.obs['x5'].values
adata.obsm['spatial'] = np.stack((y, x), 1)
# label
self.label = np.zeros(adata.shape[0], dtype=int)
# image
full_image = cv2.imread(os.path.join(path, f"{name}_histology.tif"))
full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)
patches = []
for x, y in adata.obsm['spatial']:
patches.append(full_image[y-img_size:y+img_size, x-img_size:x+img_size])
patches = np.array(patches)
self.image = patches
if dataset == "IDC":
adata = load_ST_file(os.path.join(path, name))
adata.X = adata.X.A
self.label = np.zeros(adata.shape[0], dtype=int)
# image
full_image = cv2.imread(os.path.join(path, name, f'{name}.tif'))
full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)
patches = []
for x, y in adata.obsm['spatial']:
patches.append(full_image[y - img_size:y + img_size, x - img_size:x + img_size])
patches = np.array(patches)
self.image = patches
self.n_clusters = self.label.max() + 1
self.spatial = adata.obsm['spatial']
self.n_pos = self.spatial.max() + 1
# preprocess
if gene_preprocess == 'pca':
self.gene = adata_preprocess_pca(adata, pca_n_comps=n_genes).astype(np.float32)
elif gene_preprocess == 'hvg':
self.gene = np.array(adata_preprocess_hvg(adata, n_top_genes=n_genes)).astype(np.float32)
self.train = train
self.img_train_transform = transforms.Compose([
Cutout(0.5),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.img_test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.gene_train_transform = GeneTransforms(n_genes,
prob_mask=0.5, pct_mask=0.2,
prob_noise=0.5, pct_noise=0.8, sigma_noise=0.5,
prob_swap=0.5, pct_swap=0.1)
self.gene = self.gene[self.label != -1]
self.image = self.image[self.label != -1]
self.label = self.label[self.label != -1]
def __len__(self):
return len(self.label)
def __getitem__(self, idx):
spatial = torch.from_numpy(self.spatial[idx])
y = self.label[idx]
if self.train:
xg = self.gene[idx]
xg_u = self.gene_train_transform(deepcopy(xg))
xg_v = self.gene_train_transform(deepcopy(xg))
xg = torch.from_numpy(xg)
xg_u = torch.from_numpy(xg_u)
xg_v = torch.from_numpy(xg_v)
xi_u = self.img_train_transform(self.image[idx])
xi_v = self.img_train_transform(self.image[idx])
return xg, xg_u, xg_v, xi_u, xi_v, spatial, y, idx
else:
xg = self.gene[idx]
xg = torch.from_numpy(xg)
xi = self.img_test_transform(self.image[idx])
return xg, xi, spatial, y, idx
class GeneTransforms(torch.nn.Module):
def __init__(self, n_genes,
prob_mask, pct_mask,
prob_noise, pct_noise, sigma_noise,
prob_swap, pct_swap):
super(GeneTransforms, self).__init__()
self.n_genes = n_genes
self.prob_mask = prob_mask
self.pct_mask = pct_mask
self.prob_noise = prob_noise
self.pct_noise = pct_noise
self.sigma_noise = sigma_noise
self.prob_swap = prob_swap
self.pct_swap = pct_swap
def build_mask(self, pct_mask):
mask = np.concatenate([np.ones(int(self.n_genes * pct_mask), dtype=bool),
np.zeros(self.n_genes - int(self.n_genes * pct_mask), dtype=bool)])
np.random.shuffle(mask)
return mask
def forward(self, xg):
if np.random.uniform(0, 1) < self.prob_mask:
mask = self.build_mask(self.pct_mask)
xg[mask] = 0
if np.random.uniform(0, 1) < self.prob_noise:
mask = self.build_mask(self.pct_noise)
noise = np.random.normal(0, self.sigma_noise, int(self.n_genes * self.pct_noise))
xg[mask] += noise
if np.random.uniform(0, 1) < self.prob_swap:
swap_pairs = np.random.randint(self.n_genes, size=(int(self.n_genes * self.pct_swap / 2), 2))
xg[swap_pairs[:, 0]], xg[swap_pairs[:, 1]] = xg[swap_pairs[:, 1]], xg[swap_pairs[:, 0]]
return xg