forked from hliulab/atmtcr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfolders.py
97 lines (68 loc) · 2.48 KB
/
folders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch.utils.data as data
from PIL import Image
import os
import os.path
import random
# import cv2
import scipy.io
import numpy as np
import csv
# from openpyxl import load_workbook
import torchvision
import torch
# Fix the seed of Python's global `random` generator so that the shuffles
# performed when building a `Folder` are reproducible from run to run.
seed = 1
random.seed(a=seed)
class Folder(data.Dataset):
    """Build a shuffled train/test split of labelled feature pairs.

    Two feature matrices are loaded and combined row-wise:

    * one from a NumPy ``.npy`` file (``mhc_path``);
    * one from a ``torch.save`` file (``tcr_path``).

    The first half of the rows of each matrix is concatenated row-for-row
    along dim 1 and labelled ``1``. The second half is concatenated after
    randomly permuting the ``tcr_path`` rows, and labelled ``0``.

    NOTE(review): this assumes the first half of each file holds matching
    (positive) pairs in the same row order and the second half is used to
    build negatives. Confirm against whatever produced these files. Also
    note that the default MHC file is a ``*_test`` file while the default
    TCR file is a ``train_*`` file; verify that this pairing is intended.

    All combined samples are shuffled with the module-level ``random``
    generator, then split into ``train_ratio`` / ``1 - train_ratio``.

    Attributes:
        train_data, test_data: tuples of per-sample feature rows (the
            ``mhc_path`` features followed by the ``tcr_path`` features;
            1-D when both inputs are 2-D).
        train_label, test_label: tuples of ``int`` labels (``1`` or ``0``)
            aligned element-wise with the matching ``*_data`` tuple.

    Args:
        mhc_path: path of the ``.npy`` feature matrix.
        tcr_path: path of the ``torch.save``-d feature tensor. ``torch.load``
            unpickles this file, so only point it at trusted data.
        train_ratio: fraction of samples (rounded) placed in the train split.

    Raises:
        ValueError: if the two matrices have different row counts, or if
            ``train_ratio`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        mhc_path="MHC-encoder/HLA_antigen_encoded_result_test.npy",
        tcr_path="TCR-encoder/train_feature_graph.pt",
        train_ratio=0.8,
    ):
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

        mhc = torch.from_numpy(np.load(mhc_path))
        tcr = torch.load(tcr_path)
        print(mhc.shape)
        print(tcr.shape)
        # Any row-count mismatch would otherwise fail later inside torch.cat
        # with a less descriptive message.
        if mhc.shape[0] != tcr.shape[0]:
            raise ValueError(
                f"row count mismatch: {mhc_path} has {mhc.shape[0]} rows, "
                f"{tcr_path} has {tcr.shape[0]}"
            )

        half = mhc.shape[0] // 2
        pos = torch.cat((mhc[:half], tcr[:half]), 1)

        # Negatives: pair the second-half MHC rows with a random permutation
        # of the second-half TCR rows. random.shuffle is kept (rather than
        # torch.randperm) so that seeded runs reproduce the original split.
        neg_tcr = tcr[half:]
        perm = list(range(neg_tcr.shape[0]))
        random.shuffle(perm)
        neg_tcr = neg_tcr[torch.tensor(perm, dtype=torch.long)]
        neg = torch.cat((mhc[half:], neg_tcr), 1)

        samples = [(row, 1) for row in pos] + [(row, 0) for row in neg]
        random.shuffle(samples)

        n_train = int(round(train_ratio * len(samples)))
        train, test = samples[:n_train], samples[n_train:]
        # Build each tuple in a single pass; repeated `tup + (x,)` was
        # quadratic in the number of samples.
        self.train_data = tuple(feat for feat, _ in train)
        self.train_label = tuple(lab for _, lab in train)
        self.test_data = tuple(feat for feat, _ in test)
        self.test_label = tuple(lab for _, lab in test)
class Dataset(data.Dataset):
    """Map-style dataset over two parallel, indexable sequences.

    Item ``i`` is the pair ``(data[i], label[i])``. The dataset's length is
    taken from ``label``.
    """

    def __init__(self, data, label):
        self.data, self.label = data, label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return sample, target
if __name__ == "__main__":
    # Running this module directly only prints a placeholder marker; the
    # classes above are presumably meant to be imported by a training script.
    print("1")