Activity 16: Causal Trees

2025-04-24


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from econml.dml import CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter

Run the cell below to load in the data. We have the following variables:

  • age: age of the customer

  • income: income of the customer, in $100k

  • has_membership: whether the customer has a membership to the online music platform

  • avg_hours: average number of hours per week the customer has spent on the platform

  • demand: outcome variable – sales of songs on the platform

  • T: treatment (1 if the customer was given a discount, 0 otherwise)

customer_data = pd.read_csv('~/COMSC-341CD/data/customer_data.csv')
customer_data.head()
   age    income  has_membership  avg_hours  T     demand
0   53  0.960863               1   1.834234  0   3.917117
1   54  0.732487               0   7.171411  0  11.585706
2   33  1.130937               0   5.351920  0  24.675960
3   34  0.929197               0   6.723551  0   6.361776
4   30  0.533527               1   2.448247  1  12.624123

A more robust variant of causal tree prediction is to fit multiple trees and average their predictions. This is known as a causal forest, and it is analogous to a random forest model in machine learning. It is implemented in EconML as the CausalForestDML class. Initialize the model as follows:

causal_forest = CausalForestDML(criterion='mse', # use mean squared error as the loss function
                                honest=True, # use honest sample splitting
                                discrete_treatment=True # treatment is binary
                                )

Then call causal_forest.fit() with the following arguments (in EconML, Y and T are the first two positional parameters, while X must be passed by keyword):

  • X: the features given by customer_data[covariates]

  • T: the treatment column given by customer_data['T']

  • Y: the outcome column given by customer_data['demand']

The model may take ~1-2 minutes to fit.
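
To build intuition for the averaging step described above, here is a minimal sketch in plain Python. It is a toy, not EconML's algorithm: each "tree" is a single random split on one hypothetical feature x, the leaf effect is a treated-minus-control difference in means, and the data are synthetic. The names diff_in_means, stump_predict, and forest_predict are made up for illustration. CausalForestDML chooses its splits in a data-driven way and adds honest splitting and residualization, all of which this sketch omits.

```python
import random
from statistics import mean


def diff_in_means(rows):
    """Treated-minus-control mean outcome for rows of (x, t, y)."""
    treated = [y for _, t, y in rows if t == 1]
    control = [y for _, t, y in rows if t == 0]
    if not treated or not control:
        return None  # not estimable without both groups
    return mean(treated) - mean(control)


def stump_predict(rows, threshold, x):
    """Effect estimate from the leaf of a one-split tree that contains x."""
    leaf = [r for r in rows if (r[0] <= threshold) == (x <= threshold)]
    return diff_in_means(leaf)


def forest_predict(rows, x, n_trees=100, seed=0):
    """Average one-split effect estimates over bootstrap resamples."""
    rng = random.Random(seed)
    estimates = []
    for _ in range(n_trees):
        sample = rng.choices(rows, k=len(rows))  # bootstrap resample
        threshold = rng.uniform(0.2, 0.8)        # random split point (toy choice)
        estimate = stump_predict(sample, threshold, x)
        if estimate is not None:
            estimates.append(estimate)
    return mean(estimates)


# Hypothetical synthetic data: the true effect is 2 when x > 0.5 and 0 otherwise.
data_rng = random.Random(1)
rows = []
for _ in range(1000):
    x = data_rng.random()
    t = data_rng.randint(0, 1)
    y = 1.0 + (2.0 if x > 0.5 else 0.0) * t + data_rng.gauss(0, 1)
    rows.append((x, t, y))

effect_low = forest_predict(rows, x=0.1)
effect_high = forest_predict(rows, x=0.9)
print(f"estimated effect at x=0.1: {effect_low:.2f}")
print(f"estimated effect at x=0.9: {effect_high:.2f}")
```

Averaging over many resampled trees smooths out the noise of any single split, which is why a forest's estimates tend to be more stable than a single tree's.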

covariates = ['age', 'income', 'has_membership', 'avg_hours']

# TODO initialize causal forest model
causal_forest = None

# TODO fit the model with the appropriate parameters
#causal_forest.fit(TODO)

You can then check the average treatment effect (ATE) via the causal_forest.ate_ attribute.

print("ATE: ", causal_forest.ate_)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 print("ATE: ", causal_forest.ate_)

AttributeError: 'NoneType' object has no attribute 'ate_'
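
As a sanity check on what an ATE measures, the sketch below computes a simple difference in means on a handful of made-up (T, demand) records. It assumes the discount was assigned at random, which the activity does not state, and the numbers are hypothetical rather than taken from customer_data. The causal forest's ate_ is computed differently, so the two values need not match exactly on real data.

```python
from statistics import mean

# Hypothetical toy records of (T, demand); these values are illustrative only.
records = [(0, 4.0), (0, 6.0), (0, 5.0), (1, 9.0), (1, 7.0), (1, 8.0)]

treated = [y for t, y in records if t == 1]
control = [y for t, y in records if t == 0]

# Under randomized treatment, the ATE can be estimated by the
# difference between the treated and control group means.
naive_ate = mean(treated) - mean(control)
print("Difference-in-means ATE:", naive_ate)
```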

Next, we’ll summarize the causal forest model into a single causal tree using the SingleTreeCateInterpreter class. Initialize the interpreter with the parameter max_depth=2.

Then, call the interpret() method with the following arguments:

  • cate_estimator: our fitted causal forest model

  • X: the features given by customer_data[covariates]

# TODO initialize the causal tree interpreter
#cate_tree = SingleTreeCateInterpreter(TODO)

# TODO fit the interpreter with the causal forest model
#cate_tree.interpret(TODO)
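
The idea of summarizing a model's effect estimates with a shallow tree can be sketched without EconML. The toy below takes hypothetical per-customer effect estimates and finds the single split on one feature that best separates low-effect from high-effect customers, which is roughly what a depth-1 summary tree would do. The helper best_split and the numbers are assumptions for illustration; this is not the SingleTreeCateInterpreter implementation.

```python
from statistics import mean


def sse(values):
    """Sum of squared deviations from the mean."""
    m = mean(values)
    return sum((v - m) ** 2 for v in values)


def best_split(xs, effects):
    """Return (threshold, left_mean, right_mean) minimizing within-leaf squared error."""
    best = None
    candidates = sorted(set(xs))
    for lo, hi in zip(candidates, candidates[1:]):
        threshold = (lo + hi) / 2  # midpoint between neighboring feature values
        left = [e for x, e in zip(xs, effects) if x <= threshold]
        right = [e for x, e in zip(xs, effects) if x > threshold]
        score = sse(left) + sse(right)
        if best is None or score < best[0]:
            best = (score, threshold, mean(left), mean(right))
    return best[1:]


# Hypothetical per-customer effect estimates, standing in for a fitted model's output.
avg_hours = [1.0, 2.0, 3.0, 6.0, 7.0, 8.0]
cate = [0.5, 0.7, 0.6, 2.9, 3.1, 3.0]

threshold, left_effect, right_effect = best_split(avg_hours, cate)
print(f"split at avg_hours <= {threshold}")
print(f"mean effect in left leaf:  {left_effect:.2f}")
print(f"mean effect in right leaf: {right_effect:.2f}")
```

Each leaf reports the mean estimated effect for the customers that fall into it, so the leaf with the largest mean identifies the subgroup that responds most strongly.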

Finally, uncomment and run the cell below to plot the causal tree. According to the tree, which subgroup of customers shows the largest increase in demand after being given a discount?

Your response: pollev.com/tliu

# resizes the plot to fit the causal tree
# plt.figure(figsize=(25, 20))
# cate_tree.plot(feature_names=covariates)

Acknowledgements

This activity uses tutorial code and data provided by the EconML package.