causal_ml module
This module contains a single class, CausalTree, which builds a causal tree based on Reguly (2021). Treatment effects are local treatment effects, estimated by a regression discontinuity design within the leaves of the tree.
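For intuition, the estimate inside each leaf is a sharp-RDD regression of the outcome on a treatment indicator and the running variable. The snippet below is a minimal, illustrative sketch of that specification only; the use of statsmodels and the toy data are assumptions for illustration, not the module's internal code:

import pandas as pd
import statsmodels.formula.api as smf

# Hypothetical leaf subsample: 'wage' is the outcome, 'time' the running variable.
leaf_data = pd.DataFrame({'time': [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0],
                          'wage': [1.00, 1.10, 1.20, 1.90, 2.00, 2.20]})
leaf_data['treated'] = (leaf_data['time'] >= 0).astype(int)  # sharp cutoff at 0

# With a first-order polynomial the fit is linear on each side of the cutoff;
# the local treatment effect is the coefficient on the treatment indicator.
fit = smf.ols('wage ~ treated * time', data=leaf_data).fit()
print(fit.params['treated'])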
Usage Steps:
1. Download the causal_ml module into your working directory.
2. Split your data into training and estimation samples.
3. Create a CausalTree object.
4. Grow the tree (both training and estimation samples are used in the process).
5. Prune the tree.
6. Estimate unbiased treatment effects with the estimation split.
7. Print the tree and return the leaf information.
Example:
Your main script may look like this:
import pandas as pd
from causal_ml import CausalTree
from sklearn.model_selection import train_test_split
# Load data
data = pd.read_csv('your_path_to_the_data')
# Split data into training and estimation sets
d_train, d_est = train_test_split(data, test_size=0.5, random_state=42)
# Initialize CausalTree
tree = CausalTree(split_steps=20, max_depth=4, min_leaf_size=100)
# Grow the tree
tree.grow_tree(d_train, d_est, 'wage', 'time', {'age': 'continuous', 'education': 'discrete'})
# Prune the tree
pruned_tree = tree.prune_tree(d_train, d_est, 3)
# Estimate treatment effects
pruned_tree.estimate_tree(d_est)
# Print tree and retrieve leaf information
leaves = pruned_tree.print_tree()
- class causal_ml.CausalTree(depth=0, max_depth=2, split_steps=10, min_leaf_size=50, tol=0.005, alpha=0.0)[source]
Bases: object
A class for building and pruning regression trees for heterogeneous treatment effects.
Treatment effects are estimated through a Regression Discontinuity Design (RDD). See Reguly (2021), ‘Heterogeneous Treatment Effects in Regression Discontinuity Design’.
- Variables:
depth – The depth of this node relative to the root node
max_depth – The maximum permissible depth of the tree
is_leaf – Whether the node is a leaf
left – The node branching to the left
right – The node branching to the right
tau – The treatment effect in this node, estimated during training
tau_est – The treatment effect in this node, estimated on the estimation sample
v – Variance of tau_est
- estimate_tree(data)[source]
Estimates unbiased treatment effects in the leaves
This method estimates the treatment effect in each leaf of the tree. If you use the estimation sample (as you should), then the treatment effect estimates will be unbiased. These estimates are attached to the tau_est attribute.
- Parameters:
data (pandas.DataFrame) – The estimation sample
>>> my_tree.estimate_tree(data_estimation)
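Because every node exposes the is_leaf, left, right, tau_est, and v attributes documented above, the leaf estimates can also be collected by hand. A sketch, where collect_leaf_effects is a hypothetical helper (it assumes every internal node has both children):

def collect_leaf_effects(node):
    """Gather (tau_est, v) pairs from the leaves of a CausalTree."""
    if node.is_leaf:
        return [(node.tau_est, node.v)]
    # Internal node: recurse into both branches.
    return collect_leaf_effects(node.left) + collect_leaf_effects(node.right)

effects = collect_leaf_effects(my_tree)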
- grow_tree(train_data: DataFrame, est_data: DataFrame, dep_var: str, run_var: str, split_var: dict, indep_var: list = [], poly: int = 1, cutoff: float = 0.0)[source]
Grows the full tree
Recursively splits the training sample until the maximum tree depth is reached, or until there are no further information gains from growing the tree deeper
- Parameters:
train_data (pandas.DataFrame) – The training sample
est_data (pandas.DataFrame) – The estimation sample
dep_var (str) – Name of the dependent variable
run_var (str) – Name of the running variable
split_var (dict) – Names of the splitting variables
indep_var (list, optional) – Names of the independent variables, defaults to []
poly (int, optional) – Polynomial order for running variable, defaults to 1
cutoff (float, optional) – Cutoff value for the running variable, defaults to 0.0
>>> my_tree.grow_tree(data_training, data_estimation, dep_var='wage', run_var='time', split_var={'age': 'continuous', 'education': 'discrete'})
Note
For split_var, pass each variable as a key-value pair where the key is the variable name (str) and the value is either ‘discrete’ or ‘continuous’ (str).
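The optional arguments follow the signature above. A hedged variant of the call with additional controls, a quadratic running-variable polynomial, and a non-zero cutoff (the column names gender and region are hypothetical):

>>> my_tree.grow_tree(data_training, data_estimation, dep_var='wage',
...                   run_var='time',
...                   split_var={'age': 'continuous', 'education': 'discrete'},
...                   indep_var=['gender', 'region'], poly=2, cutoff=1.5)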
- print_tree()[source]
Prints out all nodes of the tree along with some of their attributes, and returns a list of the tree leaves if they have the tau_est attribute.
Will print all tree nodes, their depth, boundaries, and estimated treatment effects. It will also return a list of leaves, provided those leaves already have an unbiased treatment effect estimate (tau_est) attached to them.
- Returns:
Tree leaves with treatment effect, variance, and boundaries for each splitting variable
- Return type:
list or None
>>> my_tree.print_tree()
[[0.100, 0.050, [2.1, 2.9], [100, 167]], [0.120, 0.081, [2.1, 2.9], [168, 250]]]
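The returned list can be tabulated directly. A sketch assuming the leaf layout shown above, i.e. [tau_est, v, bounds_age, bounds_education], with a normal-approximation 95% interval tau_est ± 1.96·sqrt(v):

import math
import pandas as pd

leaves = my_tree.print_tree()  # e.g. [[0.100, 0.050, [2.1, 2.9], [100, 167]], ...]
rows = []
for tau_est, v, age_bounds, edu_bounds in leaves:
    half = 1.96 * math.sqrt(v)  # 95% half-width from the leaf variance
    rows.append({'tau_est': tau_est,
                 'ci_low': tau_est - half,
                 'ci_high': tau_est + half,
                 'age_bounds': age_bounds,
                 'education_bounds': edu_bounds})
summary = pd.DataFrame(rows)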
- prune_tree(train_data, est_data, cv_folds=5)[source]
Prunes the tree using cost-complexity pruning
Uses k-fold cross-validation to select the complexity penalty for cost-complexity pruning
- Parameters:
train_data (pandas.DataFrame) – The training sample
est_data (pandas.DataFrame) – The estimation sample
cv_folds (int, optional) – Number of folds for cross-validation, defaults to 5
- Returns:
The pruned tree
- Return type:
causal_ml.CausalTree
>>> my_tree.prune_tree(data_training, data_estimation)
<causal_ml.CausalTree at 0x1ff3c9984a0>
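For reference, cost-complexity pruning in the CART tradition selects the subtree minimizing a penalized loss; the exact within-leaf loss used here follows Reguly (2021), so the criterion below is only a generic sketch:

C_\alpha(T) = \sum_{\ell \in \mathrm{leaves}(T)} R(\ell) + \alpha \, |\mathrm{leaves}(T)|

where R(ℓ) is the estimated loss in leaf ℓ and the penalty α is chosen by k-fold cross-validation.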