
formatting fixes
sheiksadique committed Dec 5, 2023
1 parent 0325c80 commit 53109c3
Showing 10 changed files with 47 additions and 31 deletions.
2 changes: 1 addition & 1 deletion nirtorch/__init__.py
@@ -1,5 +1,5 @@
-from .graph import extract_torch_graph # noqa F401
 from .from_nir import load # noqa F401
+from .graph import extract_torch_graph # noqa F401
 from .to_nir import extract_nir_graph # noqa F401
 
 __version__ = version = "0.2.1"
25 changes: 13 additions & 12 deletions nirtorch/from_nir.py
@@ -1,6 +1,6 @@
 import dataclasses
 import inspect
-from typing import Callable, Dict, List, Optional, Any, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import nir
 import torch
@@ -13,8 +13,8 @@
 
 @dataclasses.dataclass
 class GraphExecutorState:
-    """State for the GraphExecutor that keeps track of both the state of hidden
-    units and caches the output of previous modules, for use in (future) recurrent
+    """State for the GraphExecutor that keeps track of both the state of hidden units
+    and caches the output of previous modules, for use in (future) recurrent
     computations."""
 
     state: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -24,14 +24,14 @@ class GraphExecutorState:
 class GraphExecutor(nn.Module):
     """Executes the NIR graph in PyTorch.
 
-    By default the graph executor is stateful, since there may be recurrence or
-    stateful modules in the graph. Specifically, that means accepting and returning a
-    state object (`GraphExecutorState`). If that is not desired,
+    By default the graph executor is stateful, since there may be recurrence or
+    stateful modules in the graph. Specifically, that means accepting and returning a
+    state object (`GraphExecutorState`). If that is not desired,
     set `return_state=False` in the constructor.
 
     Arguments:
         graph (Graph): The graph to execute
-        return_state (bool, optional): Whether to return the state object.
+        return_state (bool, optional): Whether to return the state object.
             Defaults to True.
 
     Raises:
@@ -92,6 +92,7 @@ def _apply_module(
         data: Optional[torch.Tensor] = None,
     ):
         """Applies a module and keeps track of its state.
+
         TODO: Use pytree to recursively construct the state
         """
         inputs = []
@@ -205,7 +206,7 @@ def load(
     """Load a NIR graph and convert it to a torch module using the given model map.
 
     Because the graph can contain recurrence and stateful modules, the execution accepts
-    a secondary state argument and returns a tuple of [output, state], instead of just
+    a secondary state argument and returns a tuple of [output, state], instead of just
     the output as follows
 
     >>> executor = nirtorch.load(nir_graph, model_map)
@@ -216,13 +217,13 @@ def load(
     If you do not wish to operate with state, set `return_state=False`.
 
     Args:
-        nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string
+        nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string
             representing the path to the NIR object.
         model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a torch
             module that corresponds to each NIR node.
-        return_state (bool): If True, the execution of the loaded graph will return a
-            tuple of [output, state], where state is a GraphExecutorState object.
-            If False, only the NIR graph output will be returned. Note that state is
+        return_state (bool): If True, the execution of the loaded graph will return a
+            tuple of [output, state], where state is a GraphExecutorState object.
+            If False, only the NIR graph output will be returned. Note that state is
             required for recurrence to work in the graphs.
 
     Returns:
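For reference, a minimal sketch of how the stateful executor documented above is driven. This is not part of the commit: the NIR graph and the model_map below are assumptions for illustration (the tests further down use a similar _torch_model_map helper that is not shown in this diff).

import nir
import numpy as np
import torch
import torch.nn as nn
import nirtorch

# A two-node NIR graph, Input -> Linear, mirroring the graphs built in the tests below.
graph = nir.NIRGraph(
    {"in": nir.Input(np.ones(1)), "fc": nir.Linear(np.array([[2.0]]))},
    [("in", "fc")],
)

def model_map(node: nir.NIRNode):
    # Assumed mapping: convert nir.Linear nodes to torch; Input/Output nodes are
    # assumed to be handled by nirtorch itself (adjust if your graph needs more).
    if isinstance(node, nir.Linear):
        layer = nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False)
        layer.weight.data = torch.as_tensor(node.weight, dtype=torch.float32)
        return layer
    return None

executor = nirtorch.load(graph, model_map)       # stateful by default
out, state = executor(torch.ones(1))             # returns (output, GraphExecutorState)
out, state = executor(torch.ones(1), state)      # feed the state back in on later calls

stateless = nirtorch.load(graph, model_map, return_state=False)
out = stateless(torch.ones(1))                   # only the output tensor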
8 changes: 4 additions & 4 deletions nirtorch/graph_utils.py
@@ -65,12 +65,12 @@ def find_all_ancestors(
 # return execution_order
 #
 
+
 def trace_execution(
     node: T, edge_fn: Callable[[T], List[T]], visited: Set[T] = None
 ) -> List[T]:
-    """Traces the execution of a node by listing them in order, coloring recursive
-    nodes to avoid adding the same node twice.
-    """
+    """Traces the execution of a node by listing them in order, coloring recursive nodes
+    to avoid adding the same node twice."""
     if visited is None:
         visited = set()
 
@@ -83,4 +83,4 @@ def trace_execution(
     for child in edge_fn(node):
         if child not in visited:
             successors += trace_execution(child, edge_fn, visited)
-    return [node] + successors
+    return [node] + successors
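The reflowed trace_execution docstring above is easier to follow with a small example. This is not part of the commit; the edge map is made up for illustration.

from nirtorch.graph_utils import trace_execution

# A tiny graph with a back-edge: a -> b, b -> c, b -> a.
edges = {"a": ["b"], "b": ["c", "a"], "c": []}
order = trace_execution("a", lambda n: edges[n])
print(order)  # expected ['a', 'b', 'c']; the recursive b -> a edge is visited only once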
9 changes: 5 additions & 4 deletions nirtorch/to_nir.py
@@ -1,5 +1,5 @@
-from typing import Any, Callable, Optional, Sequence
 import logging
+from typing import Any, Callable, Optional, Sequence
 
 import nir
 import numpy as np
@@ -15,13 +15,14 @@ def extract_nir_graph(
     model_name: Optional[str] = "model",
     ignore_submodules_of=None,
     model_fwd_args=[],
-    ignore_dims: Optional[Sequence[int]]=None,
+    ignore_dims: Optional[Sequence[int]] = None,
 ) -> nir.NIRNode:
     """Given a `model`, generate an NIR representation using the specified `model_map`.
 
     Assumptions and known issues:
-    - Cannot deal with layers like torch.nn.Identity(), since the input tensor and output
-      tensor will be the same object, and therefore lead to cyclic connections.
+    - Cannot deal with layers like torch.nn.Identity(), since the input tensor and
+      output tensor will be the same object, and therefore lead to cyclic
+      connections.
 
     Args:
         model (nn.Module): The model of interest
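For reference, a minimal sketch of the extraction call documented above, following the pattern used in tests/test_to_nir.py further down. This is not part of the commit; the extractor is an assumption and only handles nn.Linear.

import nir
import torch
import torch.nn as nn
from nirtorch import extract_nir_graph

def extractor(module: nn.Module):
    # Assumed mapping: torch Linear layers become NIR Affine nodes.
    if isinstance(module, nn.Linear):
        return nir.Affine(module.weight, module.bias)
    return None

model = nn.Linear(3, 1)
# ignore_dims=[0] drops the batch dimension from the recorded input type, as in the tests.
g = extract_nir_graph(model, extractor, torch.ones(1, 3), ignore_dims=[0])
print(tuple(g.nodes["input"].input_type["input"]))  # expected (3,), per tests/test_to_nir.py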
2 changes: 1 addition & 1 deletion tests/test_bidirectional.py
@@ -1,8 +1,8 @@
 import nir
 import numpy as np
 import torch
-import nirtorch
 
+import nirtorch
 
 use_snntorch = False
 # use_snntorch = True
2 changes: 1 addition & 1 deletion tests/test_conversion.py
@@ -1,7 +1,7 @@
-import nir
 import torch
 import torch.nn as nn
 
+import nir
 import nirtorch
 
 
8 changes: 6 additions & 2 deletions tests/test_from_nir.py
@@ -1,7 +1,7 @@
 import nir
 import numpy as np
-import torch
 import pytest
+import torch
 
 from nirtorch.from_nir import load
 
@@ -56,7 +56,10 @@ def test_extract_empty():
 
 
 def test_extract_illegal_name():
-    graph = nir.NIRGraph({"a.b": nir.Input(np.ones(1)), "a.c": nir.Linear(np.array([[1.]]))}, [("a.b", "a.c")])
+    graph = nir.NIRGraph(
+        {"a.b": nir.Input(np.ones(1)), "a.c": nir.Linear(np.array([[1.0]]))},
+        [("a.b", "a.c")],
+    )
     torch_graph = load(graph, _torch_model_map)
     assert "a_c" in torch_graph._modules
 
@@ -131,6 +134,7 @@ def _map_stateful(node):
     m = load(g, _map_stateful, return_state=False)
     assert not isinstance(m(torch.ones(10)), tuple)
 
+
 def test_execute_recurrent():
     w = np.ones((1, 1))
     g = nir.NIRGraph(
8 changes: 6 additions & 2 deletions tests/test_graph.py
@@ -1,10 +1,10 @@
-import nir
 import pytest
 import torch
 import torch.nn as nn
 from norse.torch import LIBoxCell, LIFCell, SequentialState
 from sinabs.layers import Merge
 
+import nir
 from nirtorch import extract_nir_graph, extract_torch_graph
 
 
@@ -238,6 +238,7 @@ def test_root_has_no_source():
         len(graph.find_source_nodes_of(graph.find_node(my_branched_model.relu1))) == 0
     )
 
+
 @pytest.mark.skip(reason="Root tracing is broken")
 def test_get_root():
     graph = extract_torch_graph(my_branched_model, sample_data=data, model_name=None)
@@ -282,7 +283,9 @@ def test_sequential_flatten():
     assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
 
     d = torch.empty(2, 3, 4)
-    g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d, ignore_dims=[0])
+    g = extract_nir_graph(
+        torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d, ignore_dims=[0]
+    )
     assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
 
 
@@ -311,6 +314,7 @@ def forward(self, x, state=None):
     assert d.nodes.keys() == {"input", "l", "r", "output"}
     assert set(d.edges) == {("input", "r"), ("r", "l"), ("l", "output"), ("r", "r")}
 
+
 @pytest.mark.skip(reason="Subgraphs are currently flattened")
 def test_captures_recurrence_manually():
     def export_affine_rec_gru(module):
2 changes: 1 addition & 1 deletion tests/test_graph_utils.py
@@ -21,7 +21,7 @@ def from_string(graph):
 
     def __hash__(self) -> int:
         return self.name.__hash__()
-
+
     def __eq__(self, other: object) -> bool:
         return self.name == other.name
 
12 changes: 9 additions & 3 deletions tests/test_to_nir.py
@@ -115,7 +115,9 @@ def extractor(module: nn.Module):
             return nir.Affine(module.weight, module.bias)
 
     raw_input_shape = (1, 3)
-    g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0])
+    g = extract_nir_graph(
+        model, extractor, torch.ones(raw_input_shape), ignore_dims=[0]
+    )
     exp_input_shape = (3,)
     assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
     assert g.nodes["model"].weight.shape == (1, 3)
@@ -129,13 +131,17 @@ def extractor(module: nn.Module):
             return nir.Affine(module.weight, module.bias)
 
     raw_input_shape = (1, 10, 3)
-    g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2])
+    g = extract_nir_graph(
+        model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2]
+    )
     exp_input_shape = (3,)
     assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
     assert g.nodes["model"].weight.shape == (1, 3)
 
     raw_input_shape = (1, 10, 3)
-    g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1])
+    g = extract_nir_graph(
+        model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1]
+    )
     exp_input_shape = (3,)
     assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
 
