4 changes: 2 additions & 2 deletions experiments/data.py
@@ -1,6 +1,6 @@
 import numpy as np
 import os
-import urllib
+import urllib.request, urllib.parse, urllib.error
 import gzip
 import struct
 import array
@@ -13,7 +13,7 @@ def download(url, filename):
         os.makedirs('data')
     out_file = os.path.join('data', filename)
     if not os.path.isfile(out_file):
-        urllib.urlretrieve(url, out_file)
+        urllib.request.urlretrieve(url, out_file)
 
 
 def mnist():
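Note: Python 3 splits the old urllib module into urllib.request, urllib.parse, and urllib.error, which is why both the import and the urlretrieve call change above. A minimal version-agnostic sketch of the download pattern (the helper name fetch is illustrative, not from the repo):

    import os
    try:
        from urllib.request import urlretrieve   # Python 3
    except ImportError:
        from urllib import urlretrieve            # Python 2

    def fetch(url, out_file):
        # skip the network round trip if the file is already cached
        if not os.path.isfile(out_file):
            urlretrieve(url, out_file)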
1 change: 1 addition & 0 deletions experiments/gmm_svae_synth.py
@@ -1,4 +1,5 @@
 from __future__ import division, print_function
+
 import matplotlib.pyplot as plt
 import autograd.numpy as np
 import autograd.numpy.random as npr
11 changes: 6 additions & 5 deletions experiments/gmm_svae_synth_plot_pdfs.py
@@ -1,7 +1,8 @@
 from __future__ import division
+
 import numpy as np
 import numpy.random as npr
-import cPickle as pickle
+import pickle as pickle
 import matplotlib.pyplot as plt
 from itertools import count
 from cycler import cycler
@@ -10,7 +11,7 @@
 from operator import itemgetter
 from functools import partial
 
-from gmm_svae_synth import decode as gmm_decode, make_pinwheel_data, normalize, \
+from .gmm_svae_synth import decode as gmm_decode, make_pinwheel_data, normalize, \
     dirichlet, niw, encode_mean, decode_mean
 from svae.forward_models import mlp_decode
 mlp_decode = partial(mlp_decode, tanh_scale=1000., sigmoid_output=False)
@@ -95,7 +96,7 @@ def plot_or_update(idx, ax, x, y, alpha=1, **kwargs):
 
     dir_hypers, all_niw_hypers = natparam
     weights = normalize(np.exp(dirichlet.expectedstats(dir_hypers)))
-    components = map(niw.expected_standard_params, all_niw_hypers)
+    components = list(map(niw.expected_standard_params, all_niw_hypers))
 
     latent_locations = encode_mean(data, natparam, psi)
     reconstruction = decode_mean(latent_locations, phi)
@@ -144,7 +145,7 @@ def make_figure():
 def save_figure(fig, filename):
     fig.savefig(filename + '.png', dpi=300, bbox_inches='tight', pad_inches=0)
     fig.savefig(filename + '.pdf', bbox_inches='tight', pad_inches=0)
-    print 'saved {}'.format(filename)
+    print('saved {}'.format(filename))
 
 def plot_data(data):
     fig, ax = make_figure()
@@ -204,7 +205,7 @@ def load_gmm_svae_params(filename):
     try:
         for _ in range(20000): gmm_svae_params = pickle.load(f)
     except EOFError: pass
-    else: print 'did not finish loading {}'.format(filename)
+    else: print('did not finish loading {}'.format(filename))
     return gmm_svae_params
 
 gmm_svae_params = load_gmm_svae_params(filename)
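Note: cPickle no longer exists in Python 3; the stdlib pickle module uses the C implementation automatically. The 2to3 rewrite `import pickle as pickle` works but plain `import pickle` is the idiomatic spelling. For code that must still run on Python 2, a guarded import keeps the fast path there too; a sketch of that plus the repeated-load idiom used in load_gmm_svae_params (the helper load_all is illustrative):

    try:
        import cPickle as pickle   # Python 2: C implementation
    except ImportError:
        import pickle              # Python 3: C-accelerated by default

    def load_all(f):
        # collect every object pickled back-to-back into one open file handle
        objs = []
        try:
            while True:
                objs.append(pickle.load(f))
        except EOFError:
            pass
        return objs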
21 changes: 11 additions & 10 deletions experiments/load.py
@@ -1,7 +1,8 @@
 from __future__ import division
+
 import numpy as np
 import numpy.random as npr
-import cPickle as pickle
+import pickle as pickle
 import gzip
 import h5py
 import operator as op
@@ -10,15 +11,15 @@
 partial_flatten = lambda x: np.reshape(x, (x.shape[0], -1))
 
 dmap = lambda dct, f=lambda x: x, keep=lambda x: True: \
-    {k:f(v) for k, v in dct.iteritems() if keep(k)}
+    {k:f(v) for k, v in dct.items() if keep(k)}
 
 def standardize(d):
     recenter = lambda d: d - np.percentile(d, 0.01)
     rescale = lambda d: d / np.percentile(d, 99.99)
     return rescale(recenter(d))
 
 def flatten_dict(dct):
-    data = map(op.itemgetter(1), sorted(dct.items(), key=op.itemgetter(0)))
+    data = list(map(op.itemgetter(1), sorted(list(dct.items()), key=op.itemgetter(0))))
     return np.concatenate(data)
 
 def load(filename):
@@ -28,7 +29,7 @@ def load(filename):
     return datadict
 
 def load_mice(N, file, labelfile=None, addnoise=True, keep=lambda x: True):
-    print 'loading data from {}...'.format(file)
+    print('loading data from {}...'.format(file))
     if labelfile is None:
         data = _load_mice(N, file, keep)
     else:
@@ -37,7 +38,7 @@ def load_mice(N, file, labelfile=None, addnoise=True, keep=lambda x: True):
     if addnoise:
         data += 1e-3 * npr.normal(size=data.shape)
 
-    print '...done loading {} frames!'.format(len(data))
+    print('...done loading {} frames!'.format(len(data)))
     if labelfile:
         return data, labels
     return data
@@ -58,8 +59,8 @@ def truncate(a, b):
 
     merged_dict = {name: truncate(datadict[name], stateseqs_dict[name][-1])
                    for name in stateseqs_dict}
-    pairs = map(op.itemgetter(1), sorted(merged_dict.items(), key=op.itemgetter(0)))
-    data, labels = map(np.concatenate, zip(*pairs))
+    pairs = list(map(op.itemgetter(1), sorted(list(merged_dict.items()), key=op.itemgetter(0))))
+    data, labels = list(map(np.concatenate, list(zip(*pairs))))
     data, labels = partial_flatten(data[:N]), labels[:N]
 
     _, labels = np.unique(labels, return_inverse=True)
@@ -74,7 +75,7 @@ def load_vae_init(zdim, file, eps=1e-5):
     (W_1, b_1), decoder_nnet_params = decoder_params[0], decoder_params[1:]
 
     if zdim < W_h.shape[1]:
-        raise ValueError, 'initialization zdim must not be greater than svae model zdim'
+        raise ValueError('initialization zdim must not be greater than svae model zdim')
     elif zdim > W_h.shape[1]:
         padsize = zdim - W_h.shape[1]
         pad = lambda W, b: \
@@ -86,8 +87,8 @@ def load_vae_init(zdim, file, eps=1e-5):
         pad = lambda W, b: (np.vstack((eps*npr.randn(padsize, W.shape[1]), W)), b)
         decoder_params = [pad(W_1, b_1)] + decoder_nnet_params
 
-        print 'loaded init from {} and padded by {} dimensions'.format(file, padsize)
+        print('loaded init from {} and padded by {} dimensions'.format(file, padsize))
         return encoder_params, decoder_params
 
-    print 'loaded init from {}'.format(file)
+    print('loaded init from {}'.format(file))
     return encoder_params, decoder_params
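Note: the fixes in this file bundle three Python 3 changes: dict.iteritems() is gone (items() now returns a lightweight view), the comma form of raise is a syntax error, and sorted() accepts any iterable, so the list(...) that 2to3 wraps around dct.items() is redundant though harmless. A small sketch:

    d = {'b': 2, 'a': 1}
    squared = {k: v ** 2 for k, v in d.items()}        # iteritems() removed
    ordered = sorted(d.items(), key=lambda kv: kv[0])  # no list() needed

    # raise now takes a constructed exception object:
    try:
        raise ValueError('zdim mismatch')  # the Python 2 comma form is a SyntaxError
    except ValueError:
        pass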
1 change: 1 addition & 0 deletions svae/distributions/categorical.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 from autograd.scipy.misc import logsumexp
 from svae.util import softmax
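Note: the distribution modules gain only a blank line after `from __future__ import division`, which they already carried. On Python 2 that import makes `/` behave as in Python 3 (true division), so the numeric code agrees across versions; floor division is spelled `//` in both. For illustration:

    from __future__ import division   # no-op on Python 3

    assert 7 / 2 == 3.5    # true division (would be 3 in bare Python 2)
    assert 7 // 2 == 3     # floor division in both versions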
1 change: 1 addition & 0 deletions svae/distributions/dirichlet.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 from autograd.scipy.special import digamma, gammaln
 
1 change: 1 addition & 0 deletions svae/distributions/gaussian.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 import autograd.numpy.random as npr
 from functools import partial
1 change: 1 addition & 0 deletions svae/distributions/mniw.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 import autograd.numpy.random as npr
 from autograd.scipy.special import multigammaln, digamma
5 changes: 3 additions & 2 deletions svae/distributions/niw.py
@@ -1,12 +1,13 @@
 from __future__ import division
+
 import autograd.numpy as np
 from autograd.scipy.special import multigammaln, digamma
 from autograd import grad
 from autograd.util import make_tuple
 
 from svae.util import symmetrize, outer
-from gaussian import pack_dense, unpack_dense
-import mniw # niw is a special case of mniw
+from .gaussian import pack_dense, unpack_dense
+from . import mniw # niw is a special case of mniw
 
 # NOTE: can compute Cholesky then avoid the other cubic computations,
 # but numpy/scipy has no dpotri or solve_triangular that broadcasts
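Note: Python 3 removes implicit relative imports (PEP 328), so a module inside svae/distributions can no longer say `import mniw` to reach a sibling; the target must be named explicitly. A sketch of the accepted spellings, assuming this repo's package layout:

    # svae/distributions/niw.py
    from . import mniw                                # explicit relative: sibling module
    from .gaussian import pack_dense                  # explicit relative: one name
    from svae.distributions import mniw as mniw_abs   # absolute form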
3 changes: 2 additions & 1 deletion svae/hmm/hmm_inference.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 from autograd import grad
 from autograd import value_and_grad as vgrad
@@ -20,7 +21,7 @@ def make_grad_hmm_logZ(intermediates, ans, hmm):
 
 def hmm_estep(natparam):
     C = lambda x: np.require(x, np.double, 'C')
-    init_params, pair_params, node_params = map(C, natparam)
+    init_params, pair_params, node_params = list(map(C, natparam))
 
     # compute messages
     alphal = messages_forwards_log(
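Note: in Python 3, map returns a lazy, single-pass iterator, which is why 2to3 wraps calls in list(...) wherever the result is indexed, len()-ed, or traversed more than once. Tuple unpacking, as in hmm_estep above, would actually work on the bare iterator; the list() is belt-and-braces. For illustration:

    xs = [-1.0, 2.0, -3.0]
    a, b, c = map(abs, xs)   # unpacking a lazy map works in Python 3
    ys = list(map(abs, xs))  # materialize when the result is reused
    assert len(ys) == 3      # len(map(abs, xs)) would raise TypeError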
1 change: 1 addition & 0 deletions svae/lds/gaussian.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 from numpy.random.mtrand import _rand as rng
 
1 change: 1 addition & 0 deletions svae/lds/gaussian_nochol.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 
 from svae.util import solve_symmetric
30 changes: 16 additions & 14 deletions svae/lds/lds_inference.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 import autograd.numpy.random as npr
 from autograd import grad
@@ -9,11 +10,11 @@
 
 from svae.util import monad_runner, rand_psd, interleave, depth, uninterleave, \
     add, shape, zeros_like
-from gaussian import mean_to_natural, pair_mean_to_natural, natural_sample, \
+from .gaussian import mean_to_natural, pair_mean_to_natural, natural_sample, \
     natural_condition_on, natural_condition_on_general, natural_to_mean, \
     natural_rts_backward_step
 # from gaussian import natural_predict, natural_lognorm
-from gaussian_nochol import natural_predict, natural_lognorm
+from .gaussian_nochol import natural_predict, natural_lognorm
 
 from cython_lds_inference import \
     natural_filter_forward_general as cython_natural_filter_forward, \
@@ -22,6 +23,7 @@
     natural_sample_backward_grad as cython_natural_sample_backward_grad, \
     natural_smoother_general as cython_natural_smoother_general, \
     natural_smoother_general_grad as cython_natural_smoother_grad
+from functools import reduce
 
 cython_natural_filter_forward = primitive_with_aux(cython_natural_filter_forward)
 def make_natural_filter_grad_arg2(intermediates, ans, init_params, pair_params, node_params):
@@ -55,7 +57,7 @@ def _repeat_param(param, length):
     if depth(param) == 1:
         param = [param]*length
     elif len(param) != length:
-        param = zip(*param)
+        param = list(zip(*param))
     assert depth(param) == 2 and len(param) == length
     return param
 
@@ -85,7 +87,7 @@ def is_tuple_of_ndarrays(obj):
 
 def natural_filter_forward_general(init_params, pair_params, node_params):
     init_params = _canonical_init_params(init_params)
-    node_params = zip(*_canonical_node_params(node_params))
+    node_params = list(zip(*_canonical_node_params(node_params)))
     pair_params = _repeat_param(pair_params, len(node_params) - 1)
 
     def unit(J, h, logZ):
@@ -96,9 +98,9 @@ def bind(result, step):
         new_message, term = step(messages[-1])
         return messages + [new_message], lognorm + term
 
-    condition = lambda node_param: lambda (J, h): natural_condition_on_general(J, h, *node_param)
-    predict = lambda pair_param: lambda (J, h): natural_predict(J, h, *pair_param)
-    steps = interleave(map(condition, node_params), map(predict, pair_params))
+    condition = lambda node_param: lambda J_h: natural_condition_on_general(J_h[0], J_h[1], *node_param)
+    predict = lambda pair_param: lambda J_h1: natural_predict(J_h1[0], J_h1[1], *pair_param)
+    steps = interleave(list(map(condition, node_params)), list(map(predict, pair_params)))
 
     messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
     lognorm += natural_lognorm(*messages[-1])
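Note: the rewritten lambdas above work around PEP 3113, which removed tuple parameters from Python 3 signatures: `lambda (J, h): ...` is now a syntax error. The mechanical fix is to take one argument and index it, as the diff does; unpacking inside the body reads a little better when a full def is available. For illustration:

    pair = (1.0, 2.0)
    f = lambda J_h: J_h[0] + J_h[1]   # Python 3 rewrite of lambda (J, h): J + h

    def g(J_h):
        J, h = J_h                    # equivalent, with a named unpack
        return J + h

    assert f(pair) == g(pair) == 3.0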
@@ -109,14 +111,14 @@ def bind(result, step):
 def natural_sample_backward_general(forward_messages, pair_params, num_samples=None):
     filtered_messages = forward_messages[1::2]
     pair_params = _repeat_param(pair_params, len(filtered_messages) - 1)
-    pair_params = map(itemgetter(0, 1), pair_params)
+    pair_params = list(map(itemgetter(0, 1), pair_params))
 
     unit = lambda sample: [sample]
     bind = lambda result, step: [step(result[0])] + result
 
     sample = lambda (J11, J12), (J_filt, h_filt): lambda next_sample: \
         natural_sample(*natural_condition_on(J_filt, h_filt, next_sample, J11, J12))
-    steps = reversed(map(sample, pair_params, filtered_messages[:-1]))
+    steps = reversed(list(map(sample, pair_params, filtered_messages[:-1])))
 
     last_sample = natural_sample(*filtered_messages[-1], num_samples=num_samples)
     samples = monad_runner(bind)(unit(last_sample), steps)
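Note: the `sample = lambda (J11, J12), (J_filt, h_filt): ...` context line above still uses Python 2 tuple parameters, so this function would not parse on Python 3 as committed; this hunk appears to have missed one lambda. A sketch of the same PEP 3113-style rewrite applied here (untested against the repo):

    sample = lambda pair_param, msg: lambda next_sample: \
        natural_sample(*natural_condition_on(
            msg[0], msg[1], next_sample, pair_param[0], pair_param[1]))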
@@ -129,7 +131,7 @@ def natural_smoother_general(forward_messages, init_params, pair_params, node_params):
     inhomog = depth(pair_params) == 2
     T = len(prediction_messages)
     pair_params, orig_pair_params = _repeat_param(pair_params, T-1), pair_params
-    node_params = zip(*_canonical_node_params(node_params))
+    node_params = list(zip(*_canonical_node_params(node_params)))
 
     def unit(filtered_message):
         J, h = filtered_message
@@ -144,7 +146,7 @@ def bind(result, step):
 
     rts = lambda next_pred, filtered, pair_param: lambda next_smooth: \
         natural_rts_backward_step(next_smooth, next_pred, filtered, pair_param)
-    steps = reversed(map(rts, prediction_messages[1:], filter_messages[:-1], pair_params))
+    steps = reversed(list(map(rts, prediction_messages[1:], filter_messages[:-1], pair_params)))
 
     _, expected_stats = monad_runner(bind)(unit(filter_messages[-1]), steps)
 
@@ -167,13 +169,13 @@ def make_node_stats(a):
         return ExxT, mu, 1.
 
     E_init_stats = make_init_stats(expected_stats[0])
-    E_pair_stats = map(make_pair_stats, expected_stats[:-1], expected_stats[1:])
-    E_node_stats = map(make_node_stats, expected_stats)
+    E_pair_stats = list(map(make_pair_stats, expected_stats[:-1], expected_stats[1:]))
+    E_node_stats = list(map(make_node_stats, expected_stats))
 
     if not inhomog:
         E_pair_stats = reduce(add, E_pair_stats, zeros_like(orig_pair_params))
 
-    E_node_stats = map(np.array, zip(*E_node_stats))
+    E_node_stats = list(map(np.array, list(zip(*E_node_stats))))
 
     return E_init_stats, E_pair_stats, E_node_stats
 
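Note: `reduce` stopped being a builtin in Python 3, which is what the new `from functools import reduce` at the top of this file restores; the homogeneous-parameter branch folds the per-timestep pair statistics with it. A minimal sketch of that fold:

    from functools import reduce
    from operator import add

    stats = [1.0, 2.5, 3.5]
    total = reduce(add, stats, 0.0)   # left fold with an initial value
    assert total == 7.0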
3 changes: 2 additions & 1 deletion svae/lds/synthetic_data.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import numpy as np
 import numpy.random as npr
 
@@ -37,7 +38,7 @@ def generate_data(T, mu_init, sigma_init, A, sigma_states, C, sigma_obs):
     D = np.linalg.cholesky(sigma_obs)
 
     broadcast = lambda X, T: X if X.ndim == 3 else [X]*T
-    As, Bs, Cs, Ds = map(broadcast, [A, B, C, D], [T-1, T-1, T, T])
+    As, Bs, Cs, Ds = list(map(broadcast, [A, B, C, D], [T-1, T-1, T, T]))
 
     states[0] = mu_init + np.dot(np.linalg.cholesky(sigma_init), npr.randn(n))
     data[0] = np.dot(Cs[0], states[0]) + np.dot(Ds[0], npr.randn(p))
11 changes: 6 additions & 5 deletions svae/models/gmm.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 import autograd.numpy.random as npr
 from itertools import repeat
@@ -90,7 +91,7 @@ def local_meanfield(global_natparam, node_potentials):
 def meanfield_fixed_point(label_global, gaussian_globals, node_potentials, tol=1e-3, max_iter=100):
     kl = np.inf
     label_stats = initialize_meanfield(label_global, node_potentials)
-    for i in xrange(max_iter):
+    for i in range(max_iter):
         gaussian_natparam, gaussian_stats, gaussian_kl = \
             gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
         label_natparam, label_stats, label_kl = \
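Note: xrange is gone in Python 3, where range itself is a lazy sequence, so the rename above is free: the loop never materializes max_iter integers. For illustration:

    max_iter = 100
    for i in range(max_iter):   # constant memory, like Python 2's xrange
        if i == 3:
            break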
@@ -100,7 +101,7 @@ def meanfield_fixed_point(label_global, gaussian_globals, node_potentials, tol=1e-3, max_iter=100):
         if abs(kl - prev_kl) < tol:
             break
     else:
-        print 'iteration limit reached'
+        print('iteration limit reached')
 
     return label_stats
 
@@ -126,7 +127,7 @@ def initialize_meanfield(label_global, node_potentials):
 
 def make_plotter_2d(recognize, decode, data, num_clusters, params, plot_every):
     import matplotlib.pyplot as plt
-    if data.shape[1] != 2: raise ValueError, 'make_plotter_2d only works with 2D data'
+    if data.shape[1] != 2: raise ValueError('make_plotter_2d only works with 2D data')
 
     fig, (observation_axis, latent_axis) = plt.subplots(1, 2, figsize=(8,4))
     encode_mean, decode_mean = make_encoder_decoder(recognize, decode)
@@ -169,13 +170,13 @@ def plot_components(ax, params):
         dirichlet_natparams, niw_natparams = pgm_params
         normalize = lambda arr: np.minimum(1., arr / np.sum(arr) * num_clusters)
         weights = normalize(np.exp(dirichlet.expectedstats(dirichlet_natparams)))
-        components = map(get_component, niw.expectedstats(niw_natparams))
+        components = list(map(get_component, niw.expectedstats(niw_natparams)))
         lines = repeat(None) if isinstance(ax, plt.Axes) else ax
         for weight, (mu, Sigma), line in zip(weights, components, lines):
            plot_ellipse(ax, weight, mu, Sigma, line)
 
     def plot(i, val, params, grad):
-        print('{}: {}'.format(i, val))
+        print(('{}: {}'.format(i, val)))
         if True or (i % plot_every) == (-1 % plot_every):
             plot_encoded_means(latent_axis.lines[0], params)
             plot_components(latent_axis.lines[1:], params)
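Note: the `print(('{}: {}'.format(i, val)))` change above shows a 2to3 quirk: it parenthesizes every print statement, even ones that already used function-call syntax. With a single argument the extra parentheses are harmless; with two they would silently change output by printing a tuple. For illustration:

    print(('{}: {}'.format(0, 1.5)))   # same output as the next line
    print('{}: {}'.format(0, 1.5))     # idiomatic Python 3

    print((0, 1.5))   # prints the tuple: (0, 1.5)
    print(0, 1.5)     # prints: 0 1.5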
1 change: 1 addition & 0 deletions svae/models/iid_gausisan.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 import autograd.numpy.random as npr
 
1 change: 1 addition & 0 deletions svae/models/lds.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import autograd.numpy as np
 import autograd.numpy.random as npr
 