-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpyroexample.py
More file actions
28 lines (24 loc) · 881 Bytes
/
pyroexample.py
File metadata and controls
28 lines (24 loc) · 881 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
from functools import partial
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
# for CI testing: True whenever a `CI` variable is present in the environment.
smoke_test = ('CI' in os.environ)

# This script is pinned to the Pyro 1.8.4 release line; fail fast otherwise.
assert pyro.__version__.startswith('1.8.4')

# Fix the random seed so repeated runs are reproducible.
pyro.set_rng_seed(1)
plt.style.use('default')

# Download the dataset (the CSV is read as ISO-8859-1, not UTF-8).
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

# Keep only the three columns used below and drop rows whose `rgdppc_2000`
# is NaN or infinite.  The explicit `.copy()` gives `df` its own data so the
# column assignments that follow write to an independent frame instead of a
# possible view of `data` (which triggers pandas' SettingWithCopyWarning).
df = data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)].copy()

# Replace `rgdppc_2000` with its natural log and add the
# `cont_africa` x `rugged` interaction term as a new column.
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]

# Convert to a float tensor with columns ordered as
# [cont_africa, rugged, cont_africa_x_rugged, rgdppc_2000].
# NOTE: `data` is rebound here from the raw DataFrame to this tensor.
data = torch.tensor(
    df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
    dtype=torch.float,
)

# Features are every column except the last; the target is the last column
# (the log-transformed `rgdppc_2000`).
x_data, y_data = data[:, :-1], data[:, -1]

# Drop into an interactive IPython shell so the prepared objects can be
# inspected by hand.
import IPython as ip
ip.embed()