Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Feb 20, 2024
1 parent 84a7b11 commit 56bf1d6
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 58 deletions.
8 changes: 3 additions & 5 deletions src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def _validate_tree(self, tree: nx.DiGraph, key: str) -> nx.DiGraph:
for node in tree.nodes:
if tree.in_degree(node) == 0:
root_count += 1
if tree.out_degree(node) == 0:
raise ValueError(f"Value for key {key} must be fully connected")
elif tree.in_degree(node) > 1:
raise ValueError(f"Value for key {key} must be a tree")
if tree.out_degree(node) == 0:
Expand Down Expand Up @@ -196,8 +198,6 @@ def __setitem__(self, key: str, value: nx.DiGraph):
)
with view_update(self.parent, self.attrname, ()) as new_mapping:
new_mapping[key] = value
print("here2")
print(key)

def __delitem__(self, key: str):
if key not in self:
Expand All @@ -222,7 +222,7 @@ def __len__(self) -> int:

@contextmanager
def view_update(tdata_view: TreeData, attr_name: str, keys: tuple[str, ...]):
"""Context manager for updating a view of an AnnData object.
"""Context manager for updating a view of an TreeData object.
Contains logic for "actualizing" a view. Yields the object to be modified in-place.
Expand All @@ -241,8 +241,6 @@ def view_update(tdata_view: TreeData, attr_name: str, keys: tuple[str, ...]):
"""
new = tdata_view.copy()
attr = getattr(new, attr_name)
for key in attr:
print(key)
container = reduce(lambda d, k: d[k], keys, attr)
yield container
tdata_view._init_as_actual(new)
2 changes: 1 addition & 1 deletion src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _init_as_actual(
self._vart = AxisTrees(self, 1, vals=vart)

def _init_as_view(self, tdata_ref: TreeData, oidx: Index, vidx: Index):
super()._init_as_view(tdata_ref, oidx, vidx)
super()._init_as_view(tdata_ref, oidx=oidx, vidx=vidx)

# view of obst and vart
self._obst = tdata_ref.obst._view(self, (oidx,))
Expand Down
14 changes: 0 additions & 14 deletions src/treedata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,6 @@
import networkx as nx


def get_leaves(tree: nx.DiGraph) -> list[str]:
"""Get the leaves of a tree."""
leaves = [n for n in tree.nodes if tree.out_degree(n) == 0]
return leaves


def get_root(tree: nx.DiGraph) -> str:
"""Get the root of a tree."""
roots = [n for n in tree.nodes if tree.in_degree(n) == 0]
if len(roots) != 1:
raise ValueError(f"Tree must have exactly one root, found {len(roots)}.")
return roots[0]


def subset_tree(tree: nx.DiGraph, leaves: list[str], asview: bool) -> nx.DiGraph:
"""Subset tree."""
keep_nodes = set(leaves)
Expand Down
16 changes: 12 additions & 4 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,17 @@ def test_creation(X, adata, tree):
assert tdata.X is adata.X


@pytest.mark.parametrize("dim", ["obs", "var"])
def test_tree_keys(X, tree, dim):
@pytest.mark.parametrize("axis", [0, 1])
def test_attributes(X, tree, axis):
dim = ["obs", "var"][axis]
tdata = td.TreeData(X, obst={"tree": tree}, vart={"tree": tree}, label=None)
check_graph_equality(getattr(tdata, f"{dim}t")["tree"], tree)
assert getattr(tdata, f"{dim}t").axes == (axis,)
assert getattr(tdata, f"{dim}t").attrname == (f"{dim}t")
assert getattr(tdata, f"{dim}t").dim == dim
assert getattr(tdata, f"{dim}t").parent is tdata
assert list(getattr(tdata, f"{dim}t").dim_names) == ["0", "1", "2"]
assert tdata.allow_overlap is False
assert tdata.label is None


@pytest.mark.parametrize("dim", ["obs", "var"])
Expand Down Expand Up @@ -131,6 +138,7 @@ def test_bad_tree(X):
# Has cycle
has_cycle = nx.DiGraph()
has_cycle.add_edges_from([("0", "1"), ("1", "0")])
has_cycle.add_node("2")
with pytest.raises(ValueError):
_ = td.TreeData(X, obst={"tree": has_cycle})
# Not fully connected
Expand All @@ -145,7 +153,7 @@ def test_bad_tree(X):
_ = td.TreeData(X, obst={"tree": bad_leaves})
# Multiple roots
multi_root = nx.DiGraph()
multi_root.add_edges_from([("root", "0"), ("bad", "0")])
multi_root.add_edges_from([("0", "1"), ("1", "0"), ("2", "3")])
with pytest.raises(ValueError):
_ = td.TreeData(X, obst={"tree": multi_root})

Expand Down
29 changes: 2 additions & 27 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,18 @@
import networkx as nx
import pytest

from treedata._utils import get_leaves, get_root, subset_tree
from treedata._utils import subset_tree


@pytest.fixture
def tree():
tree = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph)
root = get_root(tree)
root = [n for n, d in tree.in_degree() if d == 0][0]
depths = nx.single_source_shortest_path_length(tree, root)
nx.set_node_attributes(tree, values=depths, name="depth")
yield tree


def test_get_leaves():
tree = nx.DiGraph()
tree.add_edges_from([("root", "0"), ("root", "1")])
assert get_leaves(tree) == ["0", "1"]


def test_get_root():
tree = nx.DiGraph()
tree.add_edges_from([("root", "0"), ("root", "1")])
assert get_root(tree) == "root"


def test_get_root_raises():
# Has cycle
has_cycle = nx.DiGraph()
has_cycle.add_edges_from([("root", "0"), ("0", "root")])
with pytest.raises(ValueError):
get_root(has_cycle)
# Multiple roots
multi_root = nx.DiGraph()
multi_root.add_edges_from([("root", "0"), ("bad", "0")])
with pytest.raises(ValueError):
get_root(multi_root)


def test_subset_tree(tree):
# copy
subtree = subset_tree(tree, [7, 8, 9], asview=False)
Expand Down
23 changes: 16 additions & 7 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
import pytest

import treedata as td
from treedata._utils import get_root


@pytest.fixture
def tree():
tree = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph)
tree = nx.relabel_nodes(tree, {i: str(i) for i in tree.nodes})
root = get_root(tree)
depths = nx.single_source_shortest_path_length(tree, root)
depths = nx.single_source_shortest_path_length(tree, "0")
nx.set_node_attributes(tree, values=depths, name="depth")
yield tree

Expand All @@ -36,6 +34,16 @@ def test_views(tdata):
assert tdata_subset.obs["test"].tolist() == list(range(2))


# this test should pass once anndata bug is fixed
# See https://github.com/scverse/anndata/issues/1382
@pytest.mark.xfail
def test_views_creation(tdata):
tdata_view = td.TreeData(tdata, asview=True)
assert tdata_view.is_view
with pytest.raises(ValueError):
_ = td.TreeData(np.zeros(shape=(3, 3)), asview=False)


def test_views_subset_tree(tdata):
expected_edges = [
("0", "1"),
Expand All @@ -58,6 +66,7 @@ def test_views_subset_tree(tdata):
tdata_subset = tdata_subset.copy()
edges = list(tdata_subset.obst["tree"].edges)
assert edges == expected_edges
assert len(tdata.obst["tree"].edges) == 14


def test_views_mutability(tdata):
Expand All @@ -76,7 +85,7 @@ def test_views_mutability(tdata):
tdata_subset.obst["tree"].remove_node("8")


def test_set(tdata):
def test_views_set(tdata):
tdata_subset = tdata[[0, 1, 4], :]
# bad assignment
bad_tree = nx.DiGraph()
Expand All @@ -94,7 +103,7 @@ def test_set(tdata):
assert list(tdata_subset.obst["new_tree"].edges) == [("0", "8")]


def test_del(tdata):
def test_views_del(tdata):
tdata_subset = tdata[[0, 1, 4], :]
# bad deletion
with pytest.raises(KeyError):
Expand All @@ -107,12 +116,12 @@ def test_del(tdata):
assert list(tdata_subset.obst.keys()) == []


def test_contains(tdata):
def test_views_contains(tdata):
tdata_subset = tdata[[0, 1, 4], :]
assert "tree" in tdata_subset.obst
assert "bad" not in tdata_subset.obst


def test_len(tdata):
def test_views_len(tdata):
tdata_subset = tdata[[0, 1, 4], :]
assert len(tdata_subset.obst) == 1

0 comments on commit 56bf1d6

Please sign in to comment.