Skip to content

Commit

Permalink
Remove existing causal mechanisms when creating GCM
Browse files Browse the repository at this point in the history
Before, when a causal graph had causal mechanisms assigned, they were also used when creating a new GCM object based on it. Now, they are removed (from a copied version of the graph).

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Jul 29, 2024
1 parent 57fd684 commit 474c584
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion dowhy/gcm/causal_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@ class ProbabilisticCausalModel:
causal mechanisms can be any general stochastic models."""

def __init__(
self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph
self,
graph: Optional[DirectedGraph] = None,
graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph,
remove_existing_mechanisms: bool = False,
):
"""
:param graph: Optional graph object to be used as causal graph.
:param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph
constructor.
:param remove_existing_mechanisms: If True, removes existing causal mechanisms assigned to nodes if they exist.
Otherwise, does not modify graph.
"""
# Todo: Remove after https://github.com/py-why/dowhy/pull/943.
from dowhy.causal_graph import CausalGraph
Expand All @@ -50,6 +55,11 @@ def __init__(
elif isinstance(graph, CausalGraph):
graph = graph_copier(graph._graph)

if remove_existing_mechanisms:
for node in graph.nodes:
if CAUSAL_MECHANISM in graph.nodes[node]:
del graph.nodes[node][CAUSAL_MECHANISM]

self.graph = graph
self.graph_copier = graph_copier

Expand Down

0 comments on commit 474c584

Please sign in to comment.