-
Notifications
You must be signed in to change notification settings - Fork 944
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Effect estimation over timeseries data (#1218)
* Library functions for temporal causal functionality * shifting plotter function * printing graph: best practices * added docstrings * moved datasets * updated tutorial notebook * sphinx documentation * updated shifting columns with 0,1,..,max_lag * support for dot format * tigramite support * updated filter to be a hidden function * black and isort utils * black and isort timeseries * updated notebook text Signed-off-by: Amit Sharma <amit_sharma@live.com> * integer range fix * correction in timestamp : notebook text * time lagged causal estimation * removed cell outputs * find ancestors * include ancestors in notebook * formatting changes * comments : notebook * multiple time lags : csv graph' * multiple time lags * unrolled graph using bfs * cleanup of functions * removed find parents and ancestors * tests for causal graph creation * tests for adding lagged edges * tests for shifting columns * tigramite dependency added --------- Signed-off-by: Amit Sharma <amit_sharma@live.com> Co-authored-by: Amit Sharma <amit_sharma@live.com>
- Loading branch information
1 parent
e783e37
commit becbf7f
Showing
9 changed files
with
1,055 additions
and
0 deletions.
There are no files selected for viewing
15 changes: 15 additions & 0 deletions
15
docs/source/example_notebooks/datasets/temporal_dataset.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
V1,V2,V3,V4,V5,V6,V7 | ||
1,2,3,4,5,6,7 | ||
2,3,4,5,6,7,8 | ||
3,4,5,6,7,8,9 | ||
4,5,6,7,8,9,10 | ||
0,1,5,7,8,9,7 | ||
3,5,4,1,2,6,5 | ||
6,7,1,2,4,5,9 | ||
12,3,5,7,3,8,9 | ||
3,2,1,6,3,8,9 | ||
4,6,3,5,8,9,1 | ||
3,5,9,6,2,1,3 | ||
5,2,6,8,11,3,4 | ||
2,2,4,1,1,4,6 | ||
5,6,4,3,4,6,2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
node1,node2,time_lag | ||
V1,V2,3 | ||
V2,V3,4 | ||
V5,V6,1 | ||
V4,V7,4 | ||
V4,V5,2 | ||
V7,V6,3 | ||
V7,V6,5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
digraph G { | ||
V1 -> V2 [label="(3)"]; | ||
V2 -> V3 [label="(4)"]; | ||
V5 -> V6 [label="(1)"]; | ||
V4 -> V7 [label="(4)"]; | ||
V4 -> V5 [label="(2)"]; | ||
V7 -> V6 [label="(3, 5)"]; | ||
} |
306 changes: 306 additions & 0 deletions
306
docs/source/example_notebooks/timeseries/effect_inference_timeseries_data.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Effect inference with timeseries data\n", | ||
"\n", | ||
"In this notebook, we will look at an example of causal effect inference from timeseries data. We will use DoWhy's functionality to add temporal dependencies to a causal graph and estimate causal effect based on the augmented graph. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import networkx as nx\n", | ||
"import pandas as pd\n", | ||
"from dowhy.utils.timeseries import create_graph_from_csv,create_graph_from_user\n", | ||
"from dowhy.utils.plotting import plot, pretty_print_graph" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loading timeseries data and causal graph" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataset_path=\"../datasets/temporal_dataset.csv\"\n", | ||
"\n", | ||
"dataframe=pd.read_csv(dataset_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In temporal causal inference, accurately estimating causal effects often requires accounting for time lags between nodes in a graph. For instance, if $node_1$ influences $node_2$ with a time lag of 5 timestamps, we represent this dependency as $node_1^{t-5}$ -> $node_2^{t}$.\n", | ||
"\n", | ||
"We can provide the causal graph as a networkx DAG or as a dot file. The edge attributes should mention the exact `time_lag` that is associated with each edge (if any)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dowhy.utils.timeseries import create_graph_from_dot_format\n", | ||
"\n", | ||
"file_path = \"../datasets/temporal_graph.dot\"\n", | ||
"\n", | ||
"graph = create_graph_from_dot_format(file_path)\n", | ||
"plot(graph)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can also create a csv file with the edges in the temporal graph. The columns in the csv are node1, node2, time_lag which represents an directed edge node1 -> node2 with the time lag of time_lag. Let us consider the following graph as the input:\n", | ||
"\n", | ||
"| node1 | node2 | time_lag |\n", | ||
"|--------|--------|----------|\n", | ||
"| V1 | V2 | 3 |\n", | ||
"| V2 | V3 | 4 |\n", | ||
"| V5 | V6 | 1 |\n", | ||
"| V4 | V7 | 4 |\n", | ||
"| V4 | V5 | 2 |\n", | ||
"| V7 | V6 | 3 |\n", | ||
"| V7 | V6 | 5 |" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Input a csv file with the edges in the graph with the columns: node_1,node_2,time_lag\n", | ||
"file_path = \"../datasets/temporal_graph.csv\"\n", | ||
"\n", | ||
"# Create the graph from the CSV file\n", | ||
"graph = create_graph_from_csv(file_path)\n", | ||
"plot(graph)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Dataset Shifting and Filtering\n", | ||
"\n", | ||
"To prepare the dataset for temporal causal inference, we need to shift the columns by the given time lag.\n", | ||
"\n", | ||
"For example, in the causal graph above, $node_1^{t-5}$ -> $node_2^{t}$ with a lag of 5. When considering $node_2$ as the target node, the data for $node_1$ should be shifted down by 5 timestamps. This adjustment ensures that the edge $node_1$ -> $node_2$ accurately represents the lagged dependency. Shifting the data in this manner creates additional columns and allows downstream estimators to acccess the correct values in the same row of a dataframe. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dowhy.timeseries.temporal_shift import shift_columns_by_lag_using_unrolled_graph, add_lagged_edges" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# the outcome node for which effect estimation has to be done, node:6\n", | ||
"target_node = 'V6'\n", | ||
"unrolled_graph = add_lagged_edges(graph, target_node)\n", | ||
"plot(unrolled_graph)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"time_shifted_df = shift_columns_by_lag_using_unrolled_graph(dataframe, unrolled_graph)\n", | ||
"time_shifted_df.head()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Causal Effect Estimation\n", | ||
"\n", | ||
"Once you have the new dataframe, causal effect estimation can be performed on the target node with respect to the action nodes." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"target_node = 'V6_0'\n", | ||
"# include all the treatments\n", | ||
"treatment_columns = list(time_shifted_df.columns)\n", | ||
"treatment_columns.remove(target_node)\n", | ||
"treatment_columns" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# perform causal effect estimation on this new dataset\n", | ||
"import dowhy\n", | ||
"from dowhy import CausalModel\n", | ||
"\n", | ||
"model = CausalModel(\n", | ||
" data=time_shifted_df,\n", | ||
" treatment='V5_-1',\n", | ||
" outcome=target_node,\n", | ||
" graph = unrolled_graph\n", | ||
")\n", | ||
"\n", | ||
"identified_estimand = model.identify_effect()\n", | ||
"\n", | ||
"estimate = model.estimate_effect(identified_estimand,\n", | ||
" method_name=\"backdoor.linear_regression\",\n", | ||
" test_significance=True)\n", | ||
"\n", | ||
"\n", | ||
"print(estimate)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Importing temporal causal graph from Tigramite library\n", | ||
"\n", | ||
"Tigramite is a popular temporal causal discovery library. In this section, we highlight how the causal graph can be obtained by applying PCMCI+ algorithm from tigramite and imported into DoWhy." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install tigramite" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import tigramite\n", | ||
"import tigramite.data_processing as pp\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"dataframe = dataframe.astype(float)\n", | ||
"var_names = dataframe.columns\n", | ||
"# convert the dataframe values to float\n", | ||
"dataframe = pp.DataFrame(dataframe.values, var_names=var_names)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tigramite import plotting as tp\n", | ||
"tp.plot_timeseries(dataframe, figsize=(15, 5)); plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tigramite.pcmci import PCMCI\n", | ||
"from tigramite.independence_tests.parcorr import ParCorr\n", | ||
"import numpy as np\n", | ||
"parcorr = ParCorr(significance='analytic')\n", | ||
"pcmci = PCMCI(\n", | ||
" dataframe=dataframe, \n", | ||
" cond_ind_test=parcorr,\n", | ||
" verbosity=1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"correlations = pcmci.run_bivci(tau_max=3, val_only=True)['val_matrix']\n", | ||
"matrix_lags = np.argmax(np.abs(correlations), axis=2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tau_max = 3\n", | ||
"pc_alpha = None\n", | ||
"pcmci.verbosity = 2\n", | ||
"\n", | ||
"results = pcmci.run_pcmciplus(tau_min=0, tau_max=tau_max, pc_alpha=pc_alpha)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dowhy.utils.timeseries import create_graph_from_networkx_array\n", | ||
"\n", | ||
"graph = create_graph_from_networkx_array(results['graph'], var_names)\n", | ||
"\n", | ||
"plot(graph)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.