-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
signalflow_visualisation: Add plot_node_output()
- Loading branch information
Showing
3 changed files
with
114 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,93 +1,2 @@ | ||
import os | ||
import json | ||
import networkx as nx | ||
from IPython.display import SVG | ||
from signalflow import Patch | ||
|
||
def visualise_patch_structure(patch: Patch, filename: str = None, dpi: int = None): | ||
""" | ||
Renders the structure of a patch as a directed graph. | ||
Requires: | ||
- networkx | ||
- pygraphviz (https://github.com/pygraphviz/pygraphviz/issues/11) | ||
Args: | ||
patch (Patch): The patch to diagram. | ||
filename (str): If specified, writes the output to a file (can be .svg, .pdf, .png) | ||
dpi (int): If specified, overwrites the default DPI (which is 72 for screen, 300 for file) | ||
Returns: | ||
An IPython SVG object that can be rendered in a notebook. | ||
TODO: Implement support for cyclical graphs (requires reformulating JSON using JSON pointers) | ||
""" | ||
|
||
G = nx.DiGraph() | ||
|
||
def label_map(label): | ||
lookup = { | ||
"add": "+", | ||
"multiply": "×", | ||
"subtract": "-", | ||
"divide": "÷" | ||
} | ||
if label in lookup.keys(): | ||
return lookup[label] | ||
else: | ||
return label | ||
|
||
spec = patch.to_spec() | ||
structure = json.loads(spec.to_json()) | ||
nodes = structure["nodes"] | ||
for node in nodes: | ||
node_label = node["node"] | ||
node_label = label_map(node_label) | ||
node_label = "<b>%s</b>" % node_label | ||
|
||
node_label += "<font point-size='8'><br />" | ||
for input_key, input_value in node["inputs"].items(): | ||
if not isinstance(input_value, dict): | ||
node_label += "<br /><font point-size='2'><br /></font>%s = %s" % (input_key, round(input_value, 7)) | ||
node_label += "</font>" | ||
|
||
# special graphviz syntax for enabling HTML formatting in node labels | ||
node_label = "<%s>" % node_label | ||
G.add_node(node["id"], label=node_label) | ||
for node in nodes: | ||
for input_key, input_value in node["inputs"].items(): | ||
if isinstance(input_value, dict): | ||
label = "" | ||
if not input_key.startswith("input"): | ||
label = input_key | ||
# white background | ||
# label = "<<table border='0' cellborder='0' cellspacing='0'><tr><td bgcolor='white'>%s</td></tr></table>>" % label | ||
G.add_edge(input_value["id"], node["id"], label=label) | ||
|
||
ag = nx.nx_agraph.to_agraph(G) | ||
ag.graph_attr["splines"] = "polyline" | ||
ag.node_attr["penwidth"] = 0.5 | ||
ag.node_attr["fontname"] = "helvetica" | ||
ag.node_attr["fontsize"] = 9 | ||
ag.node_attr["margin"] = 0.12 | ||
ag.node_attr["height"] = 0.3 | ||
ag.edge_attr["fontname"] = "helvetica" | ||
ag.edge_attr["fontsize"] = 8 | ||
ag.edge_attr["penwidth"] = 0.5 | ||
ag.edge_attr["arrowsize"] = 0.5 | ||
ag.edge_attr["labelfloat"] = False | ||
ag.edge_attr["labeldistance"] = 0 | ||
ag.node_attr["shape"] = "rectangle" | ||
ag.layout(prog='dot') | ||
|
||
if filename is not None: | ||
_, format = os.path.splitext(filename) | ||
format = format[1:] | ||
assert format in ["svg", "pdf", "png"] | ||
ag.graph_attr["dpi"] = dpi if dpi is not None else 300 | ||
ag.draw(format=format, | ||
path=filename) | ||
else: | ||
ag.graph_attr["dpi"] = dpi if dpi is not None else 72 | ||
svg = ag.draw(format='svg') | ||
return SVG(svg) | ||
from .patch_structure import visualise_patch_structure | ||
from .node import plot_node_output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from signalflow import Node | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
def plot_node_output(node: Node, duration: float = 1.0): | ||
# TODO: Replace with graph.render_subgraph_to_new_buffer(node) when this is implemented | ||
num_frames = int(duration * node.graph.sample_rate) | ||
output = np.zeros(num_frames) | ||
num_chunks = int(np.ceil(num_frames / node.graph.output_buffer_size)) | ||
for n in range(0, num_chunks): | ||
offset_start = n * node.graph.output_buffer_size | ||
offset_end = (n + 1) * node.graph.output_buffer_size | ||
offset_end = min(num_frames, offset_end) | ||
chunk_length = offset_end - offset_start | ||
node.graph.reset_subgraph(node) | ||
node.graph.render_subgraph(node) | ||
output[offset_start:offset_end] = node.output_buffer[0][:chunk_length] | ||
plt.plot(output) | ||
plt.show() |
93 changes: 93 additions & 0 deletions
93
auxiliary/libs/signalflow_visualisation/patch_structure.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
import json | ||
import networkx as nx | ||
from IPython.display import SVG | ||
from signalflow import Patch | ||
|
||
def visualise_patch_structure(patch: Patch, filename: str = None, dpi: int = None): | ||
""" | ||
Renders the structure of a patch as a directed graph. | ||
Requires: | ||
- networkx | ||
- pygraphviz (https://github.com/pygraphviz/pygraphviz/issues/11) | ||
Args: | ||
patch (Patch): The patch to diagram. | ||
filename (str): If specified, writes the output to a file (can be .svg, .pdf, .png) | ||
dpi (int): If specified, overwrites the default DPI (which is 72 for screen, 300 for file) | ||
Returns: | ||
An IPython SVG object that can be rendered in a notebook. | ||
TODO: Implement support for cyclical graphs (requires reformulating JSON using JSON pointers) | ||
""" | ||
|
||
G = nx.DiGraph() | ||
|
||
def label_map(label): | ||
lookup = { | ||
"add": "+", | ||
"multiply": "×", | ||
"subtract": "-", | ||
"divide": "÷" | ||
} | ||
if label in lookup.keys(): | ||
return lookup[label] | ||
else: | ||
return label | ||
|
||
spec = patch.to_spec() | ||
structure = json.loads(spec.to_json()) | ||
nodes = structure["nodes"] | ||
for node in nodes: | ||
node_label = node["node"] | ||
node_label = label_map(node_label) | ||
node_label = "<b>%s</b>" % node_label | ||
|
||
node_label += "<font point-size='8'><br />" | ||
for input_key, input_value in node["inputs"].items(): | ||
if not isinstance(input_value, dict): | ||
node_label += "<br /><font point-size='2'><br /></font>%s = %s" % (input_key, round(input_value, 7)) | ||
node_label += "</font>" | ||
|
||
# special graphviz syntax for enabling HTML formatting in node labels | ||
node_label = "<%s>" % node_label | ||
G.add_node(node["id"], label=node_label) | ||
for node in nodes: | ||
for input_key, input_value in node["inputs"].items(): | ||
if isinstance(input_value, dict): | ||
label = "" | ||
if not input_key.startswith("input"): | ||
label = input_key | ||
# white background | ||
# label = "<<table border='0' cellborder='0' cellspacing='0'><tr><td bgcolor='white'>%s</td></tr></table>>" % label | ||
G.add_edge(input_value["id"], node["id"], label=label) | ||
|
||
ag = nx.nx_agraph.to_agraph(G) | ||
ag.graph_attr["splines"] = "polyline" | ||
ag.node_attr["penwidth"] = 0.5 | ||
ag.node_attr["fontname"] = "helvetica" | ||
ag.node_attr["fontsize"] = 9 | ||
ag.node_attr["margin"] = 0.12 | ||
ag.node_attr["height"] = 0.3 | ||
ag.edge_attr["fontname"] = "helvetica" | ||
ag.edge_attr["fontsize"] = 8 | ||
ag.edge_attr["penwidth"] = 0.5 | ||
ag.edge_attr["arrowsize"] = 0.5 | ||
ag.edge_attr["labelfloat"] = False | ||
ag.edge_attr["labeldistance"] = 0 | ||
ag.node_attr["shape"] = "rectangle" | ||
ag.layout(prog='dot') | ||
|
||
if filename is not None: | ||
_, format = os.path.splitext(filename) | ||
format = format[1:] | ||
assert format in ["svg", "pdf", "png"] | ||
ag.graph_attr["dpi"] = dpi if dpi is not None else 300 | ||
ag.draw(format=format, | ||
path=filename) | ||
else: | ||
ag.graph_attr["dpi"] = dpi if dpi is not None else 72 | ||
svg = ag.draw(format='svg') | ||
return SVG(svg) |