Skip to content

Commit

Permalink
made obst immutable
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Feb 20, 2024
1 parent 4649137 commit e5affb5
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 38 deletions.
79 changes: 46 additions & 33 deletions src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import networkx as nx
import pandas as pd

from treedata._utils import get_leaves, get_root, subset_tree
from treedata._utils import subset_tree

if TYPE_CHECKING:
from anndata import AnnData
Expand All @@ -39,34 +39,45 @@ def __repr__(self):
def _ipython_key_completions_(self) -> list[str]:
    """Return this mapping's keys as a list.

    The method name follows IPython's key-completion hook convention, so
    the returned keys are what gets offered when completing ``obj[...]``.
    """
    return list(self.keys())

def _validate_value(self, value: nx.DiGraph, key: str) -> nx.DiGraph:
def _validate_tree(self, tree: nx.DiGraph, key: str) -> nx.DiGraph:
# Check value type
if not isinstance(value, nx.DiGraph):
raise ValueError(f"Tree for key {key} must be a nx.DiGraph")
# Check acyclic
if not nx.is_directed_acyclic_graph(value):
raise ValueError(f"Tree for key {key} cannot have cycles")
# Check fully connected
if not nx.is_weakly_connected(value):
raise ValueError(f"Tree for key {key} must be fully connected")
if not isinstance(tree, nx.DiGraph):
raise ValueError(f"Value for key {key} must be a nx.DiGraph")
# Check tree
if tree.number_of_nodes() != tree.number_of_edges() + 1:
raise ValueError(f"Value for key {key} must be a tree")
root_count = 0
leaves = set()
for node in tree.nodes:
if tree.in_degree(node) == 0:
root_count += 1
elif tree.in_degree(node) > 1:
raise ValueError(f"Value for key {key} must be a tree")
if tree.out_degree(node) == 0:
leaves.add(node)
if root_count != 1:
raise ValueError(f"Value for key {key} must be a tree")
# Check alignment
leaves = get_leaves(value)
if not set(leaves).issubset(self.dim_names):
raise ValueError(f"Leaf nodes of tree for key {key} must be in {self.dim}_names")
# Check root
_ = get_root(value)
if not leaves.issubset(self.dim_names):
raise ValueError(f"Leaf names in must be in {self.dim}_names")
# Check overlap
if not self.parent.allow_overlap:
if set(leaves).intersection(self._membership.keys()):
raise ValueError(f"Leaf nodes of tree for key {key} overlap with other trees")
return value
if key in self._tree_to_leaf:
new_leaves = leaves.difference(self._tree_to_leaf[key])
else:
new_leaves = leaves
if new_leaves.intersection(self._leaf_to_tree.keys()):
raise ValueError(
"Leaf names overlap with leaf names of other trees.", "Set `allow_overlap=True` to allow this"
)
return tree, leaves

def _update_tree_labels(self):
if self.parent._tree_label is not None:
if self.parent.allow_overlap:
mapping = self._membership
mapping = self._leaf_to_tree
else:
mapping = {k: v[0] for k, v in self._membership.items()}
mapping = {k: v[0] for k, v in self._leaf_to_tree.items()}
getattr(self.parent, self.dim)[self.parent._tree_label] = getattr(self.parent, f"{self.dim}_names").map(
mapping
)
Expand Down Expand Up @@ -117,32 +128,35 @@ def __init__(
raise ValueError()
self._axis = axis
self._data = {}
self._membership = defaultdict(list)
self._tree_to_leaf = defaultdict(set)
self._leaf_to_tree = defaultdict(list)
if vals is not None:
self.update(vals)

def __getitem__(self, key: str) -> nx.DiGraph:
return self._data[key]
return nx.graphviews.generic_graph_view(self._data[key])

def __setitem__(self, key: str, value: nx.DiGraph):
value = self._validate_value(value, key)
value, leaves = self._validate_tree(value, key)

leaves = get_leaves(value)
for leaf in leaves:
self._membership[leaf].append(key)
self._leaf_to_tree[leaf].append(key)
self._tree_to_leaf[key] = leaves

if not self.parent.is_view:
self._update_tree_labels()

self._data[key] = value

def __delitem__(self, key: str):
leaves = get_leaves(self._data[key])
for leaf in leaves:
self._membership[leaf].remove(key)
if not self._membership[leaf]:
del self._membership[leaf]
for leaf in self._tree_to_leaf[key]:
self._leaf_to_tree[leaf].remove(key)
if not self._leaf_to_tree[leaf]:
del self._leaf_to_tree[leaf]
del self._tree_to_leaf[key]

self._update_tree_labels()

del self._data[key]

def __len__(self) -> int:
Expand All @@ -169,9 +183,8 @@ def __init__(
self._axis = parent_mapping._axis

def __getitem__(self, key: str) -> nx.DiGraph:
# Consider caching the subset trees
leaves = get_leaves(self.parent_mapping[key])
subset_leaves = set(leaves).intersection(self.dim_names.values)
leaves = self.parent_mapping._tree_to_leaf[key]
subset_leaves = leaves.intersection(self.dim_names.values)
return subset_tree(self.parent_mapping[key], subset_leaves, asview=True)

def __setitem__(self, key: str, value: nx.DiGraph):
Expand Down
32 changes: 27 additions & 5 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,17 @@ def tree():
yield tree


def check_graph_equality(g1, g2):
    """Assert that two graphs expose equal ``nodes`` and equal ``edges``.

    Compares the ``nodes`` and ``edges`` attributes rather than the graph
    objects themselves.

    Raises:
        AssertionError: if the ``nodes`` or the ``edges`` of the graphs differ.
    """
    assert g1.nodes == g2.nodes
    assert g1.edges == g2.edges


def test_creation(X, adata, tree):
# Test creation with np array
tdata = td.TreeData(X, obst={"tree": tree}, vart={"tree": tree}, label=None)
assert tdata.obst["tree"] == tree
assert tdata.vart["tree"] == tree
print(type(tdata))
check_graph_equality(tdata.obst["tree"], tree)
check_graph_equality(tdata.vart["tree"], tree)
# Test creation with anndata
tdata = td.TreeData(adata)
assert tdata.X is adata.X
Expand All @@ -37,14 +43,14 @@ def test_creation(X, adata, tree):
@pytest.mark.parametrize("dim", ["obs", "var"])
def test_tree_keys(X, tree, dim):
tdata = td.TreeData(X, obst={"tree": tree}, vart={"tree": tree}, label=None)
assert getattr(tdata, f"{dim}t_keys")() == ["tree"]
check_graph_equality(getattr(tdata, f"{dim}t")["tree"], tree)


@pytest.mark.parametrize("dim", ["obs", "var"])
def test_tree_set(X, tree, dim):
tdata = td.TreeData(X)
setattr(tdata, f"{dim}t", {"tree": tree})
assert getattr(tdata, f"{dim}t")["tree"] == tree
check_graph_equality(getattr(tdata, f"{dim}t")["tree"], tree)


@pytest.mark.parametrize("dim", ["obs", "var"])
Expand Down Expand Up @@ -86,7 +92,8 @@ def test_tree_overlap(X, tree):
tdata = td.TreeData(X, obst={"0": tree, "1": second_tree}, allow_overlap=False)
# Test overlap allowed
tdata = td.TreeData(X, obst={"0": tree, "1": second_tree}, allow_overlap=True)
assert tdata.obst == {"0": tree, "1": second_tree}
check_graph_equality(tdata.obst["0"], tree)
check_graph_equality(tdata.obst["1"], second_tree)


def test_repr(X, tree):
Expand All @@ -101,6 +108,21 @@ def test_repr(X, tree):
assert repr(tdata.obst) == expected_repr


def test_mutability(X, tree):
    """Check that stored trees reject topology edits but accept attribute edits."""
    tdata = td.TreeData(X, obst={"tree": tree}, vart={"tree": tree}, label=None)
    # Topology is immutable: removing a node from the stored tree must raise
    with pytest.raises(nx.NetworkXError):
        tdata.obst["tree"].remove_node("0")
    # Attributes are mutable
    nx.set_node_attributes(tdata.obst["tree"], True, "test")
    assert all(tdata.obst["tree"].nodes[node]["test"] for node in tdata.obst["tree"].nodes)
    # Topology is mutable on a copy, which can then be assigned back
    # (named `pruned` so the `tree` fixture argument is not shadowed)
    pruned = tdata.obst["tree"].copy()
    pruned.remove_node("1")
    tdata.obst["tree"] = pruned
    assert list(tdata.obst["tree"].nodes) == ["root", "0"]


def test_bad_tree(X):
# Not directed graph
not_di_graph = nx.Graph()
Expand Down
4 changes: 4 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def test_views_subset_tree(tdata):
tdata_subset = tdata[["7", "8", "11"], :]
edges = list(tdata_subset.obst["tree"].edges)
assert edges == expected_edges
# now transition to actual object
tdata_subset = tdata_subset.copy()
edges = list(tdata_subset.obst["tree"].edges)
assert edges == expected_edges


def test_views_mutability(tdata):
Expand Down

0 comments on commit e5affb5

Please sign in to comment.