Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MPIGDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
mode='easgd', sync_every=args.sync_every,
worker_optimizer=args.worker_optimizer,
worker_optimizer_params=args.worker_optimizer_params,
elastic_force=args.elastic_force/(comm.Get_size()-1),
elastic_force=args.elastic_force/(min(1,comm.Get_size()-1)),
elastic_lr=args.elastic_lr,
elastic_momentum=args.elastic_momentum)
else:
Expand Down
8 changes: 6 additions & 2 deletions models/get_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
print ("hum")
import numpy as np
import sys

import keras
def get_data(datafile):
#get data for training
#print ('Loading Data from .....', datafile)
Expand All @@ -20,7 +20,11 @@ def get_data(datafile):
X = X.astype(np.float32)
y = y.astype(np.float32)
y = y/100.
ecal = np.squeeze(np.sum(X, axis=(1, 2, 3)))
if keras.backend.image_data_format() !='channels_last':
X =np.moveaxis(X, -1, 1)
ecal = np.squeeze(np.sum(X, axis=(2, 3, 4)))
else:
ecal = np.squeeze(np.sum(X, axis=(1, 2, 3)))
print (X.shape)
print (y.shape)
print (ecal.shape)
Expand Down
34 changes: 26 additions & 8 deletions mpi_learn/train/GanModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,16 @@ def _Model(**args):
else:
return Model(**args)
def discriminator(fixed_bn = False, discr_drop_out=0.2):
if keras.backend.image_data_format() =='channels_last':
dshape=(25, 25, 25,1)
daxis=(1,2,3)
else:
dshape=(1, 25, 25, 25)
daxis=(2,3,4)

image = Input(shape=dshape, name='image')


image = Input(shape=( 25, 25, 25,1 ), name='image')

bnm=2 if fixed_bn else 0
f=(5,5,5)
Expand Down Expand Up @@ -164,19 +172,21 @@ def discriminator(fixed_bn = False, discr_drop_out=0.2):

fake = _Dense(1, activation='sigmoid', name='classification')(dnn_out)
aux = _Dense(1, activation='linear', name='energy')(dnn_out)
ecal = Lambda(lambda x: K.sum(x, axis=(1, 2, 3)), name='sum_cell')(image)
ecal = Lambda(lambda x: K.sum(x, daxis), name='sum_cell')(image)

return _Model(output=[fake, aux, ecal], input=image, name='discriminator_model')

def generator(latent_size=200, return_intermediate=False, with_bn=True):

if keras.backend.image_data_format() =='channels_last':
dim = (7,7,8,8)
else:
dim = (8, 7, 7,8)
latent = Input(shape=(latent_size, ))

bnm=0
x = _Dense(64 * 7* 7, init='glorot_normal',
name='gen_dense1'
)(latent)
x = Reshape((7, 7,8, 8))(x)
x = Reshape(dim)(x)
x = _Conv3D(64, 6, 6, 8, border_mode='same', init='he_uniform',
name='gen_c1'
)(x)
Expand Down Expand Up @@ -212,9 +222,14 @@ def generator(latent_size=200, return_intermediate=False, with_bn=True):
return _Model(input=[latent], output=fake_image, name='generator_model')

def get_sums(images):
sumsx = np.squeeze(np.sum(images, axis=(2,3)))
sumsy = np.squeeze(np.sum(images, axis=(1,3)))
sumsz = np.squeeze(np.sum(images, axis=(1,2)))
if keras.backend.image_data_format() =='channels_last':
sumsx = np.squeeze(np.sum(images, axis=(2,3)))
sumsy = np.squeeze(np.sum(images, axis=(1,3)))
sumsz = np.squeeze(np.sum(images, axis=(1,2)))
else:
sumsx = np.squeeze(np.sum(images, axis=(3,4)))
sumsy = np.squeeze(np.sum(images, axis=(2,4)))
sumsz = np.squeeze(np.sum(images, axis=(2,3)))
return sumsx, sumsy, sumsz

def get_moments(images, sumsx, sumsy, sumsz, totalE, m):
Expand Down Expand Up @@ -535,6 +550,9 @@ def make_opt(**args):
loss=['binary_crossentropy', 'mean_absolute_percentage_error', 'mean_absolute_percentage_error'],
loss_weights=self.discr_loss_weights
)
if kv2:
self.discriminator.trainable = True #workaround for keras 2 bug

self.combined.metrics_names = self.discriminator.metrics_names
print (self.discriminator.metrics_names)
print (self.combined.metrics_names)
Expand Down