Commit 7fd4469

pytorch DeepIV WIP
Signed-off-by: Keith Battocchi <[email protected]>
1 parent 06bb009 commit 7fd4469

File tree

14 files changed: +1876 −6 lines

.github/workflows/ci.yml

Lines changed: 5 additions & 5 deletions
@@ -118,7 +118,7 @@ jobs:
 kind: [except-customer-scenarios, customer-scenarios]
 include:
 - kind: "except-customer-scenarios"
-  extras: "[plt,ray]"
+  extras: "[nn,plt,ray]"
   pattern: "(?!CustomerScenarios)"
   install_graphviz: true
   version: '3.12'
@@ -223,16 +223,16 @@ jobs:
   extras: ""
 - kind: other
   opts: '-m "cate_api and not ray" -n auto'
-  extras: "[plt]"
+  extras: "[nn,plt]"
 - kind: dml
   opts: '-m "dml and not ray"'
-  extras: "[plt]"
+  extras: "[nn,plt]"
 - kind: main
   opts: '-m "not (notebook or automl or dml or serial or cate_api or treatment_featurization or ray)" -n 2'
-  extras: "[plt,dowhy]"
+  extras: "[nn,plt,dowhy]"
 - kind: treatment
   opts: '-m "treatment_featurization and not ray" -n auto'
-  extras: "[plt]"
+  extras: "[nn,plt]"
 - kind: ray
   opts: '-m "ray"'
   extras: "[ray]"

README.md

Lines changed: 30 additions & 0 deletions
@@ -415,6 +415,36 @@ lb, ub = est.effect_interval(X_test, alpha=0.05) # OLS confidence intervals
 ```
 </details>
 
+<details>
+<summary>Deep Instrumental Variables (click to expand)</summary>
+
+```Python
+import keras
+from econml.iv.nnet import DeepIV
+
+treatment_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(2,)),
+                                    keras.layers.Dropout(0.17),
+                                    keras.layers.Dense(64, activation='relu'),
+                                    keras.layers.Dropout(0.17),
+                                    keras.layers.Dense(32, activation='relu'),
+                                    keras.layers.Dropout(0.17)])
+response_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(2,)),
+                                   keras.layers.Dropout(0.17),
+                                   keras.layers.Dense(64, activation='relu'),
+                                   keras.layers.Dropout(0.17),
+                                   keras.layers.Dense(32, activation='relu'),
+                                   keras.layers.Dropout(0.17),
+                                   keras.layers.Dense(1)])
+est = DeepIV(n_components=10,  # Number of Gaussians in the mixture density network
+             m=lambda z, x: treatment_model(keras.layers.concatenate([z, x])),  # Treatment model
+             h=lambda t, x: response_model(keras.layers.concatenate([t, x])),  # Response model
+             n_samples=1  # Number of samples used to estimate the response
+             )
+est.fit(Y, T, X=X, Z=Z)  # Z -> instrumental variables
+treatment_effects = est.effect(X_test)
+```
+</details>
+
 See the <a href="#references">References</a> section for more details.
 
 ### Interpretability

doc/map.svg

Lines changed: 4 additions & 0 deletions

doc/reference.rst

Lines changed: 10 additions & 0 deletions
@@ -86,6 +86,16 @@ Doubly Robust (DR) IV
     econml.iv.dr.IntentToTreatDRIV
     econml.iv.dr.LinearIntentToTreatDRIV
 
+.. _deepiv_api:
+
+DeepIV
+^^^^^^
+
+.. autosummary::
+    :toctree: _autosummary
+
+    econml.iv.nnet.DeepIV
+
 .. _tsls_api:
 
 Sieve Methods

doc/spec/comparison.rst

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@ Detailed estimator comparison
 +=============================================+==============+==============+==================+=============+=================+============+==============+====================+
 | :class:`.SieveTSLS`                         | Any          | Yes          |                  | Yes         | Assumed         | Yes        | Yes          |                    |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
+| :class:`.DeepIV`                            | Any          | Yes          |                  |             |                 | Yes        | Yes          |                    |
++---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
 | :class:`.SparseLinearDML`                   | Any          |              | Yes              | Yes         | Assumed         | Yes        | Yes          | Yes                |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
 | :class:`.SparseLinearDRLearner`             | Categorical  |              | Yes              |             | Projected       |            | Yes          | Yes                |

doc/spec/estimation/deepiv.rst

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+Deep Instrumental Variables
+===========================
+
+Instrumental variables (IV) methods are an approach for estimating causal effects despite the presence of confounding latent variables.
+The assumptions made are weaker than the unconfoundedness assumption needed in DML.
+The cost is that when unconfoundedness holds, IV estimators will be less efficient than DML estimators.
+What is required is a vector of instruments :math:`Z`, assumed to causally affect the distribution of the treatment :math:`T`,
+and to have no direct causal effect on the expected value of the outcome :math:`Y`. The package offers two IV methods for
+estimating heterogeneous treatment effects: deep instrumental variables [Hartford2017]_
+and the two-stage basis expansion approach of [Newey2003]_.
+
+The setup of the model is as follows:
+
+.. math::
+
+    Y = g(T, X, W) + \varepsilon
+
+where :math:`\E[\varepsilon|X,W,Z] = h(X,W)`, so that the expected value of :math:`Y` depends only on :math:`(T,X,W)`.
+This is known as the *exclusion restriction*.
+We assume that the conditional distribution :math:`F(T|X,W,Z)` varies with :math:`Z`.
+This is known as the *relevance condition*.
+We want to learn the heterogeneous treatment effects:
+
+.. math::
+
+    \tau(\vec{t}_0, \vec{t}_1, \vec{x}) = \E[g(\vec{t}_1,\vec{x},W) - g(\vec{t}_0,\vec{x},W)]
+
+where the expectation is taken with respect to the conditional distribution of :math:`W|\vec{x}`.
+If the function :math:`g` is truly non-parametric, then in the special case where :math:`T`, :math:`Z` and :math:`X` are discrete,
+the probability matrix giving the distribution of :math:`T` for each value of :math:`Z` needs to be invertible pointwise at :math:`\vec{x}`
+in order for this quantity to be identified for arbitrary :math:`\vec{t}_0` and :math:`\vec{t}_1`.
+In practice, though, we will place some parametric structure on the function :math:`g`, which will make learning easier.
+In deep IV, this takes the form of assuming :math:`g` is a neural net with a given architecture; in the sieve-based approaches,
+this amounts to assuming that :math:`g` is a weighted sum of a fixed set of basis functions. [1]_
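[Editorial aside: to make the exclusion and relevance conditions concrete, here is a minimal simulated data-generating process — an illustrative sketch, not part of the commit — in which a latent confounder drives both treatment and outcome while the instrument moves only the treatment.]

```Python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
X = rng.normal(size=(n, 1))      # observed features driving effect heterogeneity
U = rng.normal(size=n)           # latent confounder: the source of epsilon
Z = rng.normal(size=(n, 1))      # instrument: shifts T, excluded from Y
T = Z[:, 0] + U + rng.normal(size=n)        # relevance: F(T|X,W,Z) varies with Z
Y = T * X[:, 0] + U + rng.normal(size=n)    # Y = g(T,X) + eps, with E[eps|X,Z] = 0
# Here g(t, x) = t * x, so tau(t0, t1, x) = (t1 - t0) * x.
# Regressing Y on (T, X) directly is biased because U moves both T and Y;
# IV estimators instead exploit the variation in T induced by Z.
```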
+
+As explained in [Hartford2017]_, the Deep IV module learns the heterogeneous causal effects by minimizing the "reduced-form" prediction error:
+
+.. math::
+
+    \hat{g}(T,X,W) \equiv \argmin_{g \in \mathcal{G}} \sum_i \left(y_i - \int g(t,x_i,w_i) dF(t|x_i,w_i,z_i)\right)^2
+
+where the hypothesis class :math:`\mathcal{G}` consists of neural nets with a given architecture.
+The distribution :math:`F(T|x_i,w_i,z_i)` is unknown, so to make the objective feasible it must be replaced by an estimate
+:math:`\hat{F}(T|x_i,w_i,z_i)`.
+This estimate is obtained by modeling :math:`F` as a mixture of normal distributions, where the parameters of the mixture model are
+the output of a "first-stage" neural net whose inputs are :math:`(x_i,w_i,z_i)`.
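[Editorial aside: a first stage of this kind can be sketched in a few lines of PyTorch. The commit's own class is `MixtureOfGaussiansModule` in `econml.iv.nnet._deepiv`; the class name, layer sizes, and scalar-treatment assumption below are ours, not the package's.]

```Python
import torch
import torch.nn as nn

class MixtureDensityNet(nn.Module):
    """Maps first-stage inputs (x, w, z) to mixture-of-normals parameters."""
    def __init__(self, d_in, n_components):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(d_in, 64), nn.ReLU())
        self.logits = nn.Linear(64, n_components)     # mixture weights, as logits
        self.mu = nn.Linear(64, n_components)         # component means
        self.log_sigma = nn.Linear(64, n_components)  # component log-std-devs

    def forward(self, xwz):
        h = self.hidden(xwz)
        return self.logits(h), self.mu(h), self.log_sigma(h)

def first_stage_nll(logits, mu, log_sigma, t):
    """Negative log-likelihood of scalar treatments t under the mixture."""
    log_comp = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(t.unsqueeze(-1))
    log_w = torch.log_softmax(logits, dim=-1)
    return -torch.logsumexp(log_w + log_comp, dim=-1).mean()
```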
+
+Optimization of the "first-stage" neural net is done by stochastic gradient descent on the (mixture-of-normals) likelihood,
+while optimization of the "second-stage" model for the treatment effects is done by stochastic gradient descent with
+three different options for the loss (each option is sketched in code after this list):
+
+* Estimating the two integrals that make up the true gradient calculation by independent averages over
+  mini-batches of data, which are unbiased estimates of the integral.
+* Using the modified objective function
+
+  .. math::
+
+      \sum_i \sum_d \left(y_i - g(t_d,x_i,w_i)\right)^2
+
+  where :math:`t_d \sim \hat{F}(t|x_i,w_i,z_i)` are draws from the estimated first-stage neural net. This modified
+  objective function is not guaranteed to lead to consistent estimates of :math:`g`, but has the advantage of requiring
+  only a single set of samples from the distribution, and can be interpreted as regularizing the loss with a
+  variance penalty. [2]_
+* Using a single set of samples to compute the gradient of the loss; this will only be an unbiased estimate of the
+  gradient in the limit as the number of samples goes to infinity.
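[Editorial aside: continuing the sketch above, the three losses can be written as follows, assuming a hypothetical helper `sample_t(x, w, z, k)` that draws `k` treatments per observation from the fitted first stage (returning a `(batch, k)` tensor) and a second-stage net `g(t, x, w)` that broadcasts over the draw dimension.]

```Python
def unbiased_loss(g, y, x, w, z, sample_t):
    # Option 1: with independent draws t1, t2, E[(y - g(t1)) * (y - g(t2))]
    # = (y - E[g(t)])^2, so this product's gradient is unbiased.
    t1 = sample_t(x, w, z, 1).squeeze(1)
    t2 = sample_t(x, w, z, 1).squeeze(1)
    return ((y - g(t1, x, w)) * (y - g(t2, x, w))).mean()

def variance_penalized_loss(g, y, x, w, z, sample_t, n_samples=1):
    # Option 2: plug each draw into the squared loss separately; in expectation
    # this is the reduced-form loss plus a Var_t g(t,x,w) penalty (footnote [2]).
    t = sample_t(x, w, z, n_samples)
    return ((y.unsqueeze(1) - g(t, x, w)) ** 2).mean()

def plug_in_loss(g, y, x, w, z, sample_t, n_samples):
    # Option 3: average g over a single set of draws, then square the residual;
    # the gradient is biased for finite n_samples and unbiased in the limit.
    g_bar = g(sample_t(x, w, z, n_samples), x, w).mean(dim=1)
    return ((y - g_bar) ** 2).mean()
```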
+
+Training proceeds by splitting the data into a training and test set, and training is stopped when test-set performance
+(on the reduced-form prediction error) starts to degrade.
+
+The output is an estimated function :math:`\hat{g}`. To obtain an estimate of :math:`\tau`, we difference the estimated
+function at :math:`\vec{t}_1` and :math:`\vec{t}_0`, replacing the expectation with the empirical average over all
+observations with the specified :math:`\vec{x}`.
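[Editorial aside: in code, this last step is just a difference of the fitted surface averaged over the sample — a sketch, with `g_hat` the trained second stage from the asides above.]

```Python
def tau_hat(g_hat, t0, t1, x, w):
    # Empirical analogue of tau(t0, t1, x): difference g_hat at the two
    # treatment values and average over the observed w's for the given x.
    return (g_hat(t1, x, w) - g_hat(t0, x, w)).mean()
```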
+
+.. rubric:: Footnotes
+
+.. [1]
+    Asymptotic arguments about non-parametric consistency require that the neural net architecture (respectively, the set of basis functions)
+    be allowed to grow at some rate so that arbitrary functions can be approximated, but this will not be our concern here.
+.. [2]
+    Writing :math:`\int \cdot\, dF` for the expectation over :math:`t \sim \hat{F}(t|x_i,w_i,z_i)`:
+
+    .. math::
+
+        & \int \left(y_i - g(t,x_i,w_i)\right)^2 dF \\
+        =~& y_i^2 - 2 y_i \int g(t,x_i,w_i)\,dF + \int g(t,x_i,w_i)^2\,dF \\
+        =~& y_i^2 - 2 y_i \int g(t,x_i,w_i)\,dF + \left(\int g(t,x_i,w_i)\,dF\right)^2 + \int g(t,x_i,w_i)^2\,dF - \left(\int g(t,x_i,w_i)\,dF\right)^2 \\
+        =~& \left(y_i - \int g(t,x_i,w_i)\,dF\right)^2 + \left(\int g(t,x_i,w_i)^2\,dF - \left(\int g(t,x_i,w_i)\,dF\right)^2\right) \\
+        =~& \left(y_i - \int g(t,x_i,w_i)\,dF\right)^2 + \Var_t\, g(t,x_i,w_i)

doc/spec/estimation_iv.rst

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ of [Newey2003]_.
 .. toctree::
     :maxdepth: 2
 
+    estimation/deepiv.rst
     estimation/two_sls.rst
     estimation/orthoiv.rst

econml/iv/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # Copyright (c) PyWhy contributors. All rights reserved.
 # Licensed under the MIT License.
 
-__all__ = ["dml", "dr", "sieve"]
+__all__ = ["dml", "dr", "nnet", "sieve"]

econml/iv/nnet/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Copyright (c) PyWhy contributors. All rights reserved.
+# Licensed under the MIT License.
+
+from ._deepiv import DeepIV, MixtureOfGaussiansModule
+
+__all__ = ["DeepIV", "MixtureOfGaussiansModule"]
