Open
49 commits
cb17765
first commit
d-klee Jul 24, 2019
b9361c6
added dataset and vae that works on dataset
d-klee Jul 24, 2019
083d045
Update README
dmklee Jul 24, 2019
891ed95
removed unnecessary imports
d-klee Jul 24, 2019
766f279
updated README
d-klee Jul 24, 2019
1bb6b52
fixed some out of place comments
d-klee Jul 24, 2019
d9be1ae
Created using Colaboratory
ajeyamk Jul 24, 2019
5920ae7
Created using Colaboratory
ajeyamk Jul 24, 2019
c8ab5d2
Created using Colaboratory
ajeyamk Jul 25, 2019
b025170
Set up 2 remotes
Jul 25, 2019
e0f0713
Created using Colaboratory - Issues with cuda
ajeyamk Jul 28, 2019
da07cde
moved execution to gpu
ajeyamk Aug 1, 2019
5710f46
working version with y handled as one hots
d-klee Aug 1, 2019
5238d58
increased size of encoder, decoder
d-klee Aug 1, 2019
6a10319
Merge branch 'david' of https://github.ccs.neu.edu/dmklee/CausalVAE i…
Aug 2, 2019
67a40f2
Resolved merge conflicts between dev and integration
Aug 2, 2019
aa46ff7
Merge branch 'master' of https://github.com/ajeyamk/causalvae into de…
Aug 2, 2019
5f64aac
Updated code
Aug 2, 2019
e07ea52
Created using Colaboratory - Counterfactual code and plot density
ajeyamk Aug 4, 2019
20d305e
Created using Colaboratory-Fixed data type issue with PyTorch
ajeyamk Aug 6, 2019
cb4e306
Created using Colaboratory
ajeyamk Aug 7, 2019
0d4fb53
Created using Colaboratory - Fixed sampling issue worked on counterfa…
ajeyamk Aug 7, 2019
9894783
refactored code and svi incorporation
ajeyamk Aug 13, 2019
831c7ec
stable code - counterfactual
ajeyamk Aug 14, 2019
370c746
refactored code and added comments
ajeyamk Aug 14, 2019
6a03feb
latest code
ajeyamk Aug 14, 2019
db6ffbb
Final release
ajeyamk Aug 14, 2019
6ff736b
Merge pull request #1 from ajeyamk/dev-ajeya
ajeyamk Aug 15, 2019
dff7fd3
Update README.md
ajeyamk Aug 22, 2019
29677d8
Update README.md
ajeyamk Aug 22, 2019
5f9bb18
added dag file for md
ajeyamk Aug 22, 2019
c3211ae
Update README.md
ajeyamk Aug 22, 2019
1217e1d
Update README.md
ajeyamk Aug 22, 2019
2590414
Update README.md
ajeyamk Aug 22, 2019
96510dd
Added scm image
ajeyamk Aug 22, 2019
96e1870
updated readme
ajeyamk Aug 22, 2019
a0c538e
added deep fakes image
ajeyamk Aug 22, 2019
c914738
Delete dag.png
ajeyamk Aug 22, 2019
858749d
Delete scm.png
ajeyamk Aug 22, 2019
af59920
Added figs folder
ajeyamk Aug 22, 2019
e152dd5
Updated images
ajeyamk Aug 22, 2019
b6dfdba
Add files via upload
ajeyamk Aug 22, 2019
a281fcb
Add files via upload
ajeyamk Aug 22, 2019
a1eff0d
Update README.md
ajeyamk Aug 22, 2019
c04fb2e
Delete vae recons.png
ajeyamk Aug 22, 2019
2a32eb2
Add files via upload
ajeyamk Aug 22, 2019
698de65
Merge pull request #2 from ajeyamk/dev-ajeya
ajeyamk Aug 22, 2019
c4a6425
Added .py file from .ipynb file
ajeyamk Aug 23, 2019
1e37732
Merge pull request #3 from ajeyamk/dev-ajeya
ajeyamk Aug 23, 2019
62 changes: 61 additions & 1 deletion README.md
@@ -1 +1,61 @@
# causalvae

## Deep Causal Variational Autoencoder

This project trains a supervised variational autoencoder (VAE) on DeepMind's [dSprites](https://github.com/deepmind/dsprites-dataset) dataset.

dSprites is a dataset of sprites, which are 2D shapes procedurally generated from 5 varying ground truth independent "factors": the shape, scale, rotation, and x and y positions of a sprite.

All possible combinations of these variables are present exactly once, generating N = 737280 total images.

Factors and their values:

* Shape: 3 values {square, ellipse, heart}
* Scale: 6 values linearly spaced in [0.5, 1]
* Orientation: 40 values in [0, 2pi]
* Position X: 32 values in [0, 1]
* Position Y: 32 values in [0, 1]

There is a sixth factor for color, but it is white for every image in this dataset.
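
The image count follows directly from the factor grid. The short check below is only a sketch: the dictionary keys, the `linspace` helper, and `index_to_factors` are illustrative names, and the index ordering (last factor varying fastest) and the inclusion of both scale endpoints are assumptions rather than something this README states.

```python
import math

# Number of values per factor, taken from the list above.
FACTOR_SIZES = {"shape": 3, "scale": 6, "orientation": 40, "pos_x": 32, "pos_y": 32}

# Every combination appears exactly once, so N is the product of the sizes.
total_images = math.prod(FACTOR_SIZES.values())
print(total_images)

def linspace(lo, hi, n):
    """n evenly spaced values from lo to hi, both endpoints included (assumed)."""
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]

scales = linspace(0.5, 1.0, FACTOR_SIZES["scale"])

def index_to_factors(index):
    """Map a flat image index to per-factor indices.

    Assumes a mixed-radix layout in which the last factor (pos_y)
    varies fastest; check this against the dataset before relying on it.
    """
    out = {}
    for name in reversed(list(FACTOR_SIZES)):
        index, out[name] = divmod(index, FACTOR_SIZES[name])
    return out
```

Under these assumptions, `index_to_factors(32)` gives `pos_x = 1` with every other factor index at 0.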

The dataset was designed to evaluate disentanglement methods, which treat these factors as latent and then try to "disentangle" them in the learned latent representation.

In this project, however, these factors are not treated as latent but are included as labels during model training. Further, a causal story is invented that relates these factors and the images in a directed acyclic graph (DAG):

![vae_dag](figs/dag.png)

The structural causal model (SCM) has the form:

![scm_eq](figs/scm.png)

The image variable is a 64 x 64 array. Its noise term is the usual Gaussian random variable, and its structural assignment *g* is the decoder.


## Work:
* Built a structural causal model (SCM) that articulates a causal story relating shape, orientation, scale, X, Y, and the data.
* Resampled the dataset to obtain a new dataset whose empirical distribution is faithful to the DAG and entailed by the SCM.
* Implemented a causal VAE in [Pyro](http://pyro.ai/) by extending a basic VAE. The VAE is fully supervised.
* Used the trained model to answer counterfactual queries, for example, "given this image of a heart with this orientation, position, and scale, what would it have looked like if it were a square?"
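
A counterfactual query like the one in the last item is usually answered in three steps: abduction (infer the noise behind the observed image), action (intervene on a factor), and prediction (regenerate the image with the same noise). The sketch below shows those steps on a toy SCM and is not the project's Pyro model. `SHAPE_TEMPLATES`, `decode`, `abduct`, and `counterfactual` are hypothetical names, the "image" is a 4-element list, and the noise is additive so that abduction is an exact subtraction. In the VAE, the encoder would approximate this step instead.

```python
import random

# Hypothetical 4-pixel "templates" standing in for rendered sprites.
SHAPE_TEMPLATES = {
    "square": [1.0, 1.0, 1.0, 1.0],
    "ellipse": [0.5, 1.0, 1.0, 0.5],
    "heart": [1.0, 0.5, 0.5, 1.0],
}

def decode(shape, scale, noise):
    """Toy structural assignment for the image: scaled template plus noise."""
    return [scale * t + n for t, n in zip(SHAPE_TEMPLATES[shape], noise)]

def abduct(image, shape, scale):
    """Abduction: recover the noise consistent with the image and its labels."""
    return [x - scale * t for x, t in zip(image, SHAPE_TEMPLATES[shape])]

def counterfactual(image, shape, scale, new_shape):
    noise = abduct(image, shape, scale)     # 1. abduction
    return decode(new_shape, scale, noise)  # 2. action (set shape), 3. prediction

# Observe a "heart", then ask what it would have looked like as a square.
rng = random.Random(0)
noise = [rng.gauss(0.0, 0.05) for _ in range(4)]
observed = decode("heart", 0.8, noise)
cf = counterfactual(observed, "heart", 0.8, "square")
```

Because the noise is carried over unchanged, `cf` differs from `observed` only through the factor that was intervened on.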

## Optimization:
* The code runs on a GPU for faster processing.
* The learned weights are saved so the model does not have to be retrained on every run, which speeds up development.
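
One common way to get both behaviours in PyTorch is sketched below. It is an assumption about the general pattern, not the project's code: `build_decoder`, `load_or_train`, the layer sizes, and the weights path are hypothetical, and a Pyro-based model may persist its parameters differently.

```python
import os
import tempfile

import torch
from torch import nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_decoder():
    """Hypothetical small decoder producing a flattened 64 x 64 image."""
    return nn.Sequential(
        nn.Linear(16, 32),
        nn.ReLU(),
        nn.Linear(32, 64 * 64),
        nn.Sigmoid(),
    )

def load_or_train(path, train_fn):
    """Reuse saved weights if present; otherwise train once and save them."""
    model = build_decoder().to(device)
    if os.path.exists(path):
        # map_location lets weights saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(path, map_location=device))
    else:
        train_fn(model)
        torch.save(model.state_dict(), path)
    return model
```

Calling `load_or_train` a second time with the same path skips `train_fn` and loads the stored weights instead.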

## Results
* Achieved good reconstruction accuracy with a vanilla VAE:
![vanilla_vae](figs/vae-recons.png)

* Trained the VAE and verified that it recognises changes in the latent dimensions (latent variables were changed manually before training):
![vae_latent_manual](figs/manual-intervention.png)

* Built the structural causal model and verified its reconstruction accuracy:
![scm_vae](figs/scm-conditioned.png)

* Counterfactual queries (1): intervention on shape. Given an oval with certain (x, y) coordinates and orientation, how would it look if it were a square?
![scm_intervention_vae](figs/intervention.png)

* Counterfactual queries (2): intervention on shape and (x, y) position:
![scm_intervention_vae](figs/intervention-2.png)

## Applications
* DeepFakes: [Structured Disentangled Representations](https://arxiv.org/pdf/1804.02086.pdf)
![dp](deep-fakes.jpg)