Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Generalized Adjustment Criterion #1292

Merged
merged 17 commits into from
Jan 21, 2025

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions dowhy/causal_identifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
AutoIdentifier,
BackdoorAdjustment,
EstimandType,
construct_backdoor_estimand,
GeneralizedAdjustment,
construct_adjustment_estimand,
construct_frontdoor_estimand,
construct_iv_estimand,
identify_effect_auto,
Expand All @@ -16,11 +17,12 @@
"identify_effect_auto",
"identify_effect_id",
"BackdoorAdjustment",
"GeneralizedAdjustment",
"EstimandType",
"IdentifiedEstimand",
"IDIdentifier",
"identify_effect",
"construct_backdoor_estimand",
"construct_adjustment_estimand",
"construct_frontdoor_estimand",
"construct_iv_estimand",
]
29 changes: 29 additions & 0 deletions dowhy/causal_identifier/adjustment_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class AdjustmentSet:
"""Class for storing an adjustment set."""

BACKDOOR = "backdoor"
# General adjustment sets generalize backdoor sets, but we will differentiate
# between the two given the ubiquity of the backdoor criterion.
GENERAL = "general"

def __init__(
self,
adjustment_type,
adjustment_variables,
num_paths_blocked_by_observed_nodes=None,
):
self.adjustment_type = adjustment_type
self.adjustment_variables = adjustment_variables
self.num_paths_blocked_by_observed_nodes = num_paths_blocked_by_observed_nodes

def get_adjustment_type(self):
"""Return the technique associated with this adjustment set (backdoor, etc.)"""
return self.adjustment_type

def get_adjustment_variables(self):
"""Return a list containing the adjustment variables"""
return self.adjustment_variables

def get_num_paths_blocked_by_observed_nodes(self):
"""Return the number of paths blocked by observed nodes (optional)"""
return self.num_paths_blocked_by_observed_nodes
187 changes: 142 additions & 45 deletions dowhy/causal_identifier/auto_identifier.py

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions dowhy/causal_identifier/backdoor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import networkx as nx

from dowhy.causal_identifier.adjustment_set import AdjustmentSet
from dowhy.utils.graph_operations import adjacency_matrix_to_adjacency_list


Expand Down Expand Up @@ -113,11 +114,13 @@ def get_backdoor_vars(self):
self._path_search(adjlist, node1, node2, path_dict)
if len(path_dict) != 0:
obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)

backdoor_set = {}
backdoor_set["backdoor_set"] = tuple(obj.find_set())
backdoor_set["num_paths_blocked_by_observed_nodes"] = obj.num_sets()
backdoor_sets.append(backdoor_set)
backdoor_sets.append(
AdjustmentSet(
adjustment_type=AdjustmentSet.BACKDOOR,
adjustment_variables=tuple(obj.find_set()),
num_paths_blocked_by_observed_nodes=obj.num_sets(),
)
)

return backdoor_sets

Expand Down
16 changes: 16 additions & 0 deletions dowhy/causal_identifier/identified_estimand.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ def __init__(
estimand_type=None,
estimands=None,
backdoor_variables=None,
general_adjustment_variables=None,
instrumental_variables=None,
frontdoor_variables=None,
mediator_variables=None,
mediation_first_stage_confounders=None,
mediation_second_stage_confounders=None,
default_backdoor_id=None,
default_adjustment_set_id=None,
identifier_method=None,
no_directed_path=False,
):
self.identifier = identifier
self.treatment_variable = parse_state(treatment_variable)
self.outcome_variable = parse_state(outcome_variable)
self.backdoor_variables = backdoor_variables
self.general_adjustment_variables = general_adjustment_variables
self.instrumental_variables = parse_state(instrumental_variables)
self.frontdoor_variables = parse_state(frontdoor_variables)
self.mediator_variables = parse_state(mediator_variables)
Expand All @@ -38,6 +41,7 @@ def __init__(
self.estimand_type = estimand_type
self.estimands = estimands
self.default_backdoor_id = default_backdoor_id
self.default_adjustment_set_id = default_adjustment_set_id
self.identifier_method = identifier_method
self.no_directed_path = no_directed_path

Expand Down Expand Up @@ -78,6 +82,13 @@ def get_instrumental_variables(self):
"""Return a list containing the instrumental variables (if present)"""
return self.instrumental_variables

def get_general_adjustment_variables(self, key: Optional[str] = None):
"""Return a list containing general adjustment variables."""
if key is None:
return self.general_adjustment_variables[self.default_adjustment_set_id]
else:
return self.general_adjustment_variables[key]

def __deepcopy__(self, memo):
return IdentifiedEstimand(
self.identifier, # not deep copied
Expand All @@ -86,10 +97,12 @@ def __deepcopy__(self, memo):
estimand_type=copy.deepcopy(self.estimand_type),
estimands=copy.deepcopy(self.estimands),
backdoor_variables=copy.deepcopy(self.backdoor_variables),
general_adjustment_variables=copy.deepcopy(self.general_adjustment_variables),
instrumental_variables=copy.deepcopy(self.instrumental_variables),
frontdoor_variables=copy.deepcopy(self.frontdoor_variables),
mediator_variables=copy.deepcopy(self.mediator_variables),
default_backdoor_id=copy.deepcopy(self.default_backdoor_id),
default_adjustment_set_id=copy.deepcopy(self.default_adjustment_set_id),
identifier_method=copy.deepcopy(self.identifier_method),
)

Expand All @@ -112,6 +125,9 @@ def __str__(self, only_target_estimand: bool = False, show_all_backdoor_sets: bo
# Just show the default backdoor set
if k.startswith("backdoor") and k != "backdoor":
continue
# Just show the default generalized adjustment set
if k.startswith("general") and k != "general_adjustment":
continue
if only_target_estimand and k != self.identifier_method:
continue
s += "\n### Estimand : {0}\n".format(i)
Expand Down
54 changes: 54 additions & 0 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module defines the fundamental interfaces and functions related to causal graphs."""

import copy
import itertools
import logging
import re
Expand Down Expand Up @@ -187,13 +188,66 @@ def is_blocked(graph: nx.DiGraph, path, conditioned_nodes):
return False


def get_ancestors(graph: nx.DiGraph, nodes):
ancestors = set()
for node_name in nodes:
ancestors = ancestors.union(set(nx.ancestors(graph, node_name)))
return ancestors


def get_descendants(graph: nx.DiGraph, nodes):
descendants = set()
for node_name in nodes:
descendants = descendants.union(set(nx.descendants(graph, node_name)))
return descendants


def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes):
"""Method to get the proper causal path nodes, as described in van der Zander et al. "Constructing Separators and
Adjustment Sets in Ancestral Graphs", Section 4.1. We cannot use do_surgery() since we require deep copies of the given graph.

:param graph: the causal graph in question
:param action_nodes: the action nodes
:param outcome_nodes: the outcome nodes

:returns: the set of nodes that lie on proper causal paths from X to Y
"""

# 1) Create a pair of modified graphs by removing inbound and outbound arrows from the action nodes, respectively.
graph_post_interv = copy.deepcopy(graph) # remove incoming arrows to our action nodes
edges_to_remove = [(u, v) for u, v in graph_post_interv.in_edges(action_nodes)]
graph_post_interv.remove_edges_from(edges_to_remove)
graph_with_action_nodes_as_sinks = copy.deepcopy(graph) # remove outbound arrows from our action nodes
edges_to_remove = [(u, v) for u, v in graph_with_action_nodes_as_sinks.out_edges(action_nodes)]
graph_with_action_nodes_as_sinks.remove_edges_from(edges_to_remove)

# 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the
# action nodes to the outcome nodes.
de_x = get_descendants(graph_post_interv, action_nodes).union(action_nodes)
an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes).union(outcome_nodes)
return (set(de_x) - set(action_nodes)) & an_y


def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
"""Method to get the proper backdoor graph from a causal graph, as described in van der Zander et al. "Constructing Separators and
Adjustment Sets in Ancestral Graphs", Section 4.1. We cannot use do_surgery() since we require deep copies of the given graph.

:param graph: the causal graph in question
:param action_nodes: the action nodes
:param outcome_nodes: the outcome nodes

:returns: a new graph which is the proper backdoor graph of the original
"""

# First we can just call get_proper_causal_path_nodes, then
# we remove edges from the action_nodes to the proper causal path nodes.
graph_pbd = copy.deepcopy(graph)
graph_pbd.remove_edges_from(
[(u, v) for u in action_nodes for v in get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)]
)
return graph_pbd


def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"):
if dseparation_algo == "default":
if new_graph is None:
Expand Down
59 changes: 24 additions & 35 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ pandas = [
{version = "<2.0", python = "<3.9"},
{version = ">1.0", python = ">=3.9"}
]
networkx = ">=2.8.5"
networkx = [
nparent1 marked this conversation as resolved.
Show resolved Hide resolved
{version = ">=3.3", python = ">=3.10"},
{version = ">=2.8.5", python = "<3.10"}
]
sympy = ">=1.10.1"
scikit-learn = ">1.0"
pydot = { version = "^1.4.2", optional = true }
Expand Down
41 changes: 38 additions & 3 deletions tests/causal_identifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from dowhy.graph import build_graph_from_str

from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS
from .example_graphs import (
TEST_FRONTDOOR_GRAPH_SOLUTIONS,
TEST_GRAPH_SOLUTIONS,
TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT,
)


class IdentificationTestGraphSolution(object):
Expand Down Expand Up @@ -34,15 +38,39 @@ def __init__(
observed_variables,
valid_frontdoor_sets,
invalid_frontdoor_sets,
action_nodes=None,
outcome_nodes=None,
):
if outcome_nodes is None:
outcome_nodes = ["Y"]
if action_nodes is None:
action_nodes = ["X"]
self.graph = build_graph_from_str(graph_str)
self.action_nodes = ["X"]
self.outcome_nodes = ["Y"]
self.action_nodes = action_nodes
self.outcome_nodes = outcome_nodes
self.observed_nodes = observed_variables
self.valid_frontdoor_sets = valid_frontdoor_sets
self.invalid_frontdoor_sets = invalid_frontdoor_sets


class IdentificationTestGeneralCovariateAdjustmentGraphSolution(object):
def __init__(
self,
graph_str,
observed_variables,
action_nodes,
outcome_nodes,
minimal_adjustment_sets,
exhaustive_adjustment_sets=None,
):
self.graph = build_graph_from_str(graph_str)
self.action_nodes = action_nodes
self.outcome_nodes = outcome_nodes
self.observed_nodes = observed_variables
self.minimal_adjustment_sets = minimal_adjustment_sets
self.exhaustive_adjustment_sets = exhaustive_adjustment_sets


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS.keys())
def example_graph_solution(request):
return IdentificationTestGraphSolution(**TEST_GRAPH_SOLUTIONS[request.param])
Expand All @@ -51,3 +79,10 @@ def example_graph_solution(request):
@pytest.fixture(params=TEST_FRONTDOOR_GRAPH_SOLUTIONS.keys())
def example_frontdoor_graph_solution(request):
return IdentificationTestFrontdoorGraphSolution(**TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param])


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT.keys())
def example_complete_adjustment_graph_solution(request):
return IdentificationTestGeneralCovariateAdjustmentGraphSolution(
**TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT[request.param]
)
Loading
Loading