Skip to content

Commit

Permalink
fixed issue with actualizing view
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Feb 20, 2024
1 parent e5affb5 commit 84a7b11
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 18 deletions.
19 changes: 13 additions & 6 deletions src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import abc as cabc
from collections import defaultdict
from collections.abc import Iterator, Mapping, Sequence
from contextlib import contextmanager
from functools import reduce
from typing import (
TYPE_CHECKING,
Expand All @@ -22,7 +23,8 @@
from anndata import AnnData
from anndata.raw import Raw

from .treedata import TreeData
from treedata._core.treedata import TreeData


OneDIdx = Union[Sequence[int], Sequence[bool], slice]
TwoDIdx = tuple[OneDIdx, OneDIdx]
Expand Down Expand Up @@ -98,7 +100,7 @@ def parent(self) -> AnnData | Raw:

@property
def attrname(self) -> str:
return f"{self.dim}m"
return f"{self.dim}t"

@property
def axes(self) -> tuple[Literal[0, 1]]:
Expand Down Expand Up @@ -188,20 +190,22 @@ def __getitem__(self, key: str) -> nx.DiGraph:
return subset_tree(self.parent_mapping[key], subset_leaves, asview=True)

def __setitem__(self, key: str, value: nx.DiGraph):
value = self._validate_value(value, key) # Validate before mutating
value, _ = self._validate_tree(value, key) # Validate before mutating
warnings.warn(
f"Setting element `.{self.attrname}['{key}']` of view, " "initializing view as actual.", stacklevel=2
f"Setting element `.{self.attrname}['{key}']` of view, initializing view as actual.", stacklevel=2
)
with view_update(self.parent, self.attrname, ()) as new_mapping:
new_mapping[key] = value
print("here2")
print(key)

def __delitem__(self, key: str):
if key not in self:
raise KeyError(
"'{key!r}' not found in view of {self.attrname}"
) # Make sure it exists before bothering with a copy
warnings.warn(
f"Removing element `.{self.attrname}['{key}']` of view, " "initializing view as actual.", stacklevel=2
f"Removing element `.{self.attrname}['{key}']` of view, initializing view as actual.", stacklevel=2
)
with view_update(self.parent, self.attrname, ()) as new_mapping:
del new_mapping[key]
Expand All @@ -216,6 +220,7 @@ def __len__(self) -> int:
return len(self.parent_mapping)


@contextmanager
def view_update(tdata_view: TreeData, attr_name: str, keys: tuple[str, ...]):
"""Context manager for updating a view of an AnnData object.
Expand All @@ -234,8 +239,10 @@ def view_update(tdata_view: TreeData, attr_name: str, keys: tuple[str, ...]):
------
`adata.attr[key1][key2][keyn]...`
"""
new = TreeData.copy()
new = tdata_view.copy()
attr = getattr(new, attr_name)
for key in attr:
print(key)
container = reduce(lambda d, k: d[k], keys, attr)
yield container
tdata_view._init_as_actual(new)
29 changes: 18 additions & 11 deletions src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,24 @@ def _init_as_actual(
filemode=filemode,
)

if label is not None:
for attr in ["obs", "var"]:
if label in getattr(self, attr).columns:
warnings.warn(f"label {label} already present in .{attr} overwriting it", stacklevel=2)
getattr(self, attr)[label] = pd.NA
self._tree_label = label

self._allow_overlap = allow_overlap

self._obst = AxisTrees(self, 0, vals=obst)
self._vart = AxisTrees(self, 1, vals=vart)
# init from TreeData
if isinstance(X, TreeData):
self._tree_label = X.label
self._allow_overlap = X.allow_overlap
self._obst = X.obst
self._vart = X.vart

# init from scratch
else:
if label is not None:
for attr in ["obs", "var"]:
if label in getattr(self, attr).columns:
warnings.warn(f"label {label} already present in .{attr} overwriting it", stacklevel=2)
getattr(self, attr)[label] = pd.NA
self._tree_label = label
self._allow_overlap = allow_overlap
self._obst = AxisTrees(self, 0, vals=obst)
self._vart = AxisTrees(self, 1, vals=vart)

def _init_as_view(self, tdata_ref: TreeData, oidx: Index, vidx: Index):
super()._init_as_view(tdata_ref, oidx, vidx)
Expand Down
44 changes: 43 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def tree():
@pytest.fixture
def tdata(tree):
df = pd.DataFrame({"anno": range(8)}, index=[str(i) for i in range(7, 15)])
yield td.TreeData(X=np.zeros((8, 8)), obst={"tree": tree}, vart={"tree": tree}, obs=df, var=df)
yield td.TreeData(X=np.zeros((8, 8)), obst={"tree": tree}, vart={"tree": tree}, obs=df, var=df, allow_overlap=True)


def test_views(tdata):
Expand Down Expand Up @@ -74,3 +74,45 @@ def test_views_mutability(tdata):
# cannot mutate structure of graph
with pytest.raises(nx.NetworkXError):
tdata_subset.obst["tree"].remove_node("8")


def test_set(tdata):
tdata_subset = tdata[[0, 1, 4], :]
# bad assignment
bad_tree = nx.DiGraph()
bad_tree.add_edge("0", "bad")
with pytest.raises(ValueError):
tdata_subset.obst["new_tree"] = bad_tree
assert tdata_subset.is_view
# good assignment actualizes object
new_tree = nx.DiGraph()
new_tree.add_edge("0", "8")
with pytest.warns(UserWarning):
tdata_subset.obst["new_tree"] = new_tree
assert not tdata_subset.is_view
assert list(tdata_subset.obst.keys()) == ["tree", "new_tree"]
assert list(tdata_subset.obst["new_tree"].edges) == [("0", "8")]


def test_del(tdata):
tdata_subset = tdata[[0, 1, 4], :]
# bad deletion
with pytest.raises(KeyError):
del tdata_subset.obst["bad"]
assert tdata_subset.is_view
# good deletion actualizes object
with pytest.warns(UserWarning):
del tdata_subset.obst["tree"]
assert not tdata_subset.is_view
assert list(tdata_subset.obst.keys()) == []


def test_contains(tdata):
tdata_subset = tdata[[0, 1, 4], :]
assert "tree" in tdata_subset.obst
assert "bad" not in tdata_subset.obst


def test_len(tdata):
tdata_subset = tdata[[0, 1, 4], :]
assert len(tdata_subset.obst) == 1

0 comments on commit 84a7b11

Please sign in to comment.