This will train a system on some test data and calculate an average treatment effect (ATE).
Description
As input this system expects data where each row consists of:
Freeform text
A categorical variable (numerically coded) representing a confound
A binary treatment variable
A binary outcome variable
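Before training, it can help to check that the four columns match the format listed above. The helper below is a hypothetical sanity check, not part of this repository. It assumes the columns are passed as plain Python lists (for a pandas DataFrame, use `df['text'].tolist()` and so on), and it assumes the numerically coded confound arrives as a Python `int`.

```python
from typing import Sequence


def check_rows(texts: Sequence[str], confounds: Sequence[int],
               treatments: Sequence[int], outcomes: Sequence[int]) -> None:
    """Hypothetical check that each row is (text, categorical C, binary T, binary Y)."""
    n = len(texts)
    # Every column must describe the same number of rows.
    if not (len(confounds) == len(treatments) == len(outcomes) == n):
        raise ValueError("all four columns must have the same length")
    for i in range(n):
        if not isinstance(texts[i], str):
            raise TypeError(f"row {i}: text must be a string")
        # The confound is categorical but numerically coded, e.g. 0, 1, 2, ...
        if not isinstance(confounds[i], int):
            raise TypeError(f"row {i}: confound must be an integer category code")
        if treatments[i] not in (0, 1):
            raise ValueError(f"row {i}: treatment must be 0 or 1")
        if outcomes[i] not in (0, 1):
            raise ValueError(f"row {i}: outcome must be 0 or 1")
```

For example, `check_rows(["good movie", "bad movie"], [0, 2], [1, 0], [1, 0])` returns without raising.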
The system then passes the text to BERT and trains it on four objectives, using the BERT embeddings together with the confound for the three prediction heads:
P(T | C, text)
P(Y | T = 1, C, text)
P(Y | T = 0, C, text)
BERT's original masked language modeling (MLM) objective
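The objectives above are combined into one weighted training loss, with the weights set by `g_weight`, `Q_weight` and `mlm_weight` (see Usage). The sketch below shows that combination for a single example. It is an illustration under stated assumptions, not the repository's code: it assumes binary cross-entropy for each head, takes the MLM loss as an already computed number, and only scores the outcome head that matches the observed treatment, since the other potential outcome is never observed. The function names are hypothetical.

```python
import math


def bce(p: float, label: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy of predicted probability p against a 0/1 label."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


def combined_loss(g: float, q0: float, q1: float, t: int, y: int,
                  mlm_loss: float, g_weight: float = 1.0,
                  Q_weight: float = 0.1, mlm_weight: float = 1.0) -> float:
    """Hypothetical weighted loss for one example.

    g  : predicted P(T=1 | C, text)
    q0 : predicted P(Y=1 | T=0, C, text)
    q1 : predicted P(Y=1 | T=1, C, text)
    """
    g_loss = bce(g, t)                    # treatment (propensity) head
    q_loss = bce(q1 if t == 1 else q0, y)  # only the observed-treatment outcome head
    return g_weight * g_loss + Q_weight * q_loss + mlm_weight * mlm_loss
```

Raising `Q_weight` relative to `g_weight` pushes the embeddings to focus more on predicting outcomes than treatments; `mlm_weight` keeps them close to ordinary BERT representations.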
Once trained, the resulting BERT embeddings are intended to be sufficient for some causal inferences, such as estimating the ATE.
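For reference, the quantity being estimated can be written as follows. The first line is the standard definition of the ATE for a binary outcome Y with potential outcomes Y(1) and Y(0). The second equality holds only under the usual assumptions that C and the text together capture all confounding and that 0 < P(T = 1 | C, text) < 1. The last line is a simple plug-in estimate using the model's predicted probabilities (written with hats); this notation is ours, and the wrapper's exact estimator may differ.

```latex
\begin{aligned}
\mathrm{ATE} &= \mathbb{E}[Y(1) - Y(0)] \\
  &= \mathbb{E}_{C,\,\text{text}}\big[ P(Y=1 \mid T=1, C, \text{text}) - P(Y=1 \mid T=0, C, \text{text}) \big], \\
\widehat{\mathrm{ATE}} &= \frac{1}{n} \sum_{i=1}^{n} \big[ \hat{P}(Y=1 \mid T=1, C_i, \text{text}_i) - \hat{P}(Y=1 \mid T=0, C_i, \text{text}_i) \big].
\end{aligned}
```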
Example
import pandas as pd
from CausalBert import CausalBertWrapper # assumes this repo's CausalBert.py is on the import path

df = pd.read_csv('testdata.csv')
cb = CausalBertWrapper(batch_size=2, # init a model wrapper
g_weight=0.1, Q_weight=0.1, mlm_weight=1)
cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1) # train the model
print(cb.ATE(df['C'], df['text'], platt_scaling=True)) # use the model to get an average treatment effect
Usage
Initialize the model wrapper (handles training and inference):
cb = CausalBertWrapper(
batch_size=2, # batch size for training
g_weight=1.0, # loss weight for P(T | C, text) prediction head
Q_weight=0.1, # loss weight for P(Y | T, C, text) prediction heads
mlm_weight=1) # loss weight for original MLM objective
Then train
cb.train(
df['text'], # list of texts
df['C'], # list of confounds
df['T'], # list of treatments
df['Y'], # list of outcomes
epochs=1) # training epochs
Perform inference
# per row: ((P(Y=1|T=1), P(Y=0|T=1)), (P(Y=1|T=0), P(Y=0|T=0))), ...
outputs = cb.inference(
df['text'], # list of texts
df['C']) # list of confounds
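Given per-row predictions of P(Y=1 | T=1, C, text) and P(Y=1 | T=0, C, text), the simplest ATE estimate is the average of their differences. The helper below is a hypothetical sketch of that plug-in calculation, not the wrapper's own method. It assumes you have already pulled the two probability lists out of the inference output; it does not depend on the exact structure `cb.inference` returns.

```python
from typing import Sequence


def plug_in_ate(q1: Sequence[float], q0: Sequence[float]) -> float:
    """Hypothetical plug-in ATE: mean over rows of P(Y=1|T=1,...) - P(Y=1|T=0,...)."""
    if len(q1) != len(q0) or len(q1) == 0:
        raise ValueError("need two non-empty probability lists of equal length")
    return sum(a - b for a, b in zip(q1, q0)) / len(q1)
```

For instance, with `q1 = [0.9, 0.6]` and `q0 = [0.4, 0.5]` the per-row effects are 0.5 and 0.1, so the estimate is about 0.3.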
Or estimate an average treatment effect
ATE = cb.ATE(
df['C'], # list of confounds
df['text'], # list of texts
platt_scaling=False) # https://en.wikipedia.org/wiki/Platt_scaling
About
PyTorch implementation of "Adapting Text Embeddings for Causal Inference"