# data_gen.py - random point generation, decision-tree labelling, and a Dataset wrapper.
import torch
from torch.utils.data import Dataset
def generate_points(num=40000, dim=18, type='normal', normal_var=1, radius=1):
    '''
    Sample `num` random points in `dim` dimensions.

    Every coordinate is first drawn i.i.d. from N(0, 1). Then:

    If type=='spherical', each point is divided by its Euclidean norm and
    multiplied by `radius`, so every returned row has norm `radius`.
    For any other value of `type` (including the default 'normal'), the
    points are scaled by sqrt(normal_var), i.e. coordinates follow
    N(0, normal_var).

    Args:
        num: number of points (rows of the result).
        dim: number of coordinates per point (columns of the result).
        type: 'spherical' or 'normal'. The name shadows the builtin but is
            kept because callers pass it by keyword.
        normal_var: variance of each coordinate; only used when type is not
            'spherical'. Must be non-negative.
        radius: norm of every returned point; only used when
            type=='spherical'.

    Returns:
        A tensor of shape [num, dim].

    Raises:
        ValueError: if type is not 'spherical' and normal_var is negative.
    '''
    X = torch.randn([num, dim])  # coordinates sampled from N(0,1)
    if type == 'spherical':
        norm = torch.norm(X, p=2, dim=1, keepdim=True)
        # Project onto the unit sphere, then rescale to the requested radius.
        return radius * (X / norm)
    if normal_var < 0:
        raise ValueError(f"normal_var must be non-negative, got {normal_var}")
    # Multiplying a standard normal by the standard deviation gives variance normal_var.
    return X * (normal_var ** 0.5)
class TreeNode:
    '''
    A node of a fixed, complete binary decision tree.

    Each node stores its own depth, the depth at which the tree stops
    (max_depth), and a feature index. An internal node with feature index f
    tests coordinate f of the input; its left child gets index 2*f+1 and its
    right child gets index 2*f+2. A node whose depth equals max_depth is a
    leaf, and its value (the predicted class) is float(f % 2).
    '''

    def __init__(self, depth, max_depth, feature_index):
        self.depth, self.max_depth = depth, max_depth
        self.feature = feature_index
        # Children stay None until build_tree() creates them.
        self.left = self.right = None
        # Only leaves receive a value; None marks an internal (or unbuilt) node.
        self.value = None

    def build_tree(self):
        '''Grow the subtree below this node in place until max_depth is reached.'''
        if self.depth != self.max_depth:
            child_depth = self.depth + 1
            self.left = TreeNode(child_depth, self.max_depth, 2 * self.feature + 1)
            self.right = TreeNode(child_depth, self.max_depth, 2 * self.feature + 2)
            # Expand the left subtree first, then the right one.
            for child in (self.left, self.right):
                child.build_tree()
        else:
            # Leaf: the class is the parity of the node's index.
            self.value = float(self.feature % 2)

    def predict(self, x):
        '''
        Return the leaf value reached by routing x through the tree.

        At every internal node the walk goes left when x[feature] > 0 and
        right otherwise (so a coordinate equal to 0 goes right).
        '''
        node = self
        while node.value is None:
            node = node.left if x[node.feature] > 0 else node.right
        return node.value
def gen_spherical_data(depth, dim_in, type_data, num_points, feat_index_start=0,radius=1):
    '''
    Generate random points and label each one with a fixed decision tree.

    A TreeNode tree with max_depth=depth is built, rooted at feature index
    feat_index_start. Its internal nodes split on single coordinates
    (for feat_index_start=0 the root tests whether x[0] > 0). The points come
    from generate_points, to which num_points, dim_in, type_data and radius
    are forwarded.

    Returns:
        (X, Y): X is the tensor of points returned by generate_points and Y is
        a tensor holding the tree's prediction for each row of X, in order.
    '''
    labelling_tree = TreeNode(depth=0, max_depth=depth, feature_index=feat_index_start)
    labelling_tree.build_tree()
    X = generate_points(num=num_points, dim=dim_in, type=type_data, radius=radius)
    # One tree walk per point; the resulting list of floats becomes the label tensor.
    Y = torch.tensor([labelling_tree.predict(point) for point in X])
    return X, Y
# Example (kept for reference): the same steps gen_spherical_data performs, written out inline.
# depth = 4
# dim_in = 18
# type_data = 'normal'
# feat_index_start = 0 #the index of the first feature in the tree
# num_points = 40000
# Tree = TreeNode(depth = 0,max_depth=depth,feature_index = feat_index_start)
# Tree.build_tree()
# X = generate_points(num=num_points,dim=dim_in,type=type_data)
# Y=[]
# for item in X:
# Y.append(Tree.predict(item))
# Y = torch.tensor(Y)
class CustomDataset(Dataset):
    '''
    Dataset over two parallel indexable collections, x and y.

    Item idx is the pair (x[idx], y[idx]). If a truthy transform is given, it
    is called as transform(sample_x, sample_y) and must return a pair, which
    is returned in place of the raw samples. The length is len(x).
    '''

    def __init__(self, x, y, transform=None):
        self.x, self.y, self.transform = x, y, transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        features, target = self.x[idx], self.y[idx]
        if not self.transform:
            return features, target
        # The transform sees both halves of the sample and returns both.
        features, target = self.transform(features, target)
        return features, target