Skip to content

Commit

Permalink
clean dataset ops
Browse files Browse the repository at this point in the history
  • Loading branch information
SevenLJY committed Apr 4, 2024
1 parent ffb9ebb commit 8b1063b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
21 changes: 10 additions & 11 deletions datamodules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,8 @@ def _random_permute(self, graph, nodes):
'''
N = len(nodes)
order = np.random.permutation(N)
mapping = {i: order[i] for i in range(N)}
mapping.update({i: i for i in range(len(nodes), self.hparams.K)})
graph_permuted = self._reorder_nodes(graph, mapping)
graph_permuted = self._reorder_nodes(graph, order)
nodes_permuted = nodes[order, :]
graph_permuted['parents'] = graph['parents'][order]
return graph_permuted, nodes_permuted

def _build_graph(self, nodes):
Expand All @@ -49,7 +46,6 @@ def _build_graph(self, nodes):
Args:
nodes: list of nodes
K: size of the adjacency matrix
Returns:
adj: adjacency matrix, records the 1-ring relationship (parent+children) between nodes
edge_list: list of edges, for visualization
Expand Down Expand Up @@ -77,24 +73,27 @@ def _build_graph(self, nodes):
'parents': np.array(parents, dtype=np.int8)
}

def _reorder_nodes(self, graph, mapping):
def _reorder_nodes(self, graph, order):
'''
Function to reorder nodes in the graph and
update the adjacency matrix, edge list, and root node.
Args:
graph: a dictionary containing the adjacency matrix, edge list, and root node
mapping: a dictionary mapping the old node id to the new node id
order: a list of indices for reordering
Returns:
new_graph: a dictionary containing the updated adjacency matrix, edge list, and root node
'''
N = len(order)
mapping = {i: order[i] for i in range(N)}
mapping.update({i: i for i in range(N, self.hparams.K)})
G = nx.from_numpy_array(graph['adj'], create_using=nx.Graph)
G_ = nx.relabel_nodes(G, mapping)
new_adj = nx.adjacency_matrix(G_, G.nodes).todense()
return {
'adj': new_adj.astype(np.float32),
'root': mapping[graph['root']],
'parents': graph['parents'],
'parents': graph['parents'][order],
}

def _prepare_node_data(self, node):
Expand All @@ -109,9 +108,9 @@ def _prepare_node_data(self, node):
aabb_min = aabb_center - aabb_size / 2
# joint axis and range
if node['joint']['type'] == 'fixed':
axis_dir = np.zeros((3,))
axis_ori = np.zeros((3,))
joint_range = np.zeros((2,))
axis_dir = np.zeros((3,), dtype=np.float32)
axis_ori = np.zeros((3,), dtype=np.float32)
joint_range = np.zeros((2,), dtype=np.float32)
else:
if node['joint']['type'] == 'revolute' or node['joint']['type'] == 'continuous':
joint_range = np.array([node['joint']['range'][1]], dtype=np.float32) / 360.
Expand Down
2 changes: 1 addition & 1 deletion datamodules/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _load_graph_cat(self, idx):
return data, cond

def __getitem__(self, idx):
if self.hparams.pred_mode == 'uncond' or self.hparams.pred_mode == 'cond_graph':
if self.hparams.pred_mode in ('uncond', 'cond_graph'):
data, cond = self._load_graph_cat(idx)
else: # conditional on node attributes
data, cond = self._prepare_item(idx)
Expand Down

0 comments on commit 8b1063b

Please sign in to comment.