1717import sys
1818from tensorflow import compat , config , dtypes
1919
20- import dnngior .NN_Predictor
20+ from dnngior .NN_Predictor import NN
2121
2222# Tensorflow; please consider: https://www.tensorflow.org/api_docs/python/tf/compat/v1/disable_eager_execution
2323compat .v1 .disable_eager_execution ()
@@ -31,7 +31,7 @@ def noise_data(i, noise_0, noise_1, del_p, con_p):
3131 ----------
3232 i : numpy array, required
3333 an array of 0s and 1s which you want to noisify
34- noise_0 : numpy array, required
34+ noise_0 : numpy array, required
3535 fraction of 0s to change to 1s
3636 noise_1 : numpy array, required
3737 fraction of 1s to change to 0s
@@ -73,7 +73,7 @@ def noise_data(i, noise_0, noise_1, del_p, con_p):
7373 o = temp
7474 return o
7575
76- def generate_training_set (data ,nuplo , min_con , max_con , min_for , max_for , del_p , con_p ):
76+ def generate_feature (data , nuplo , min_con , max_con , min_for , max_for , del_p , con_p ):
7777 """
7878 Function to generate the dataset for training (feature).
7979 PARAMETERS:
@@ -150,7 +150,7 @@ def custom_loss(y_true, y_pred): #y_true is the label, y_pred is the prediction
150150 return bias * (1 - y_true )* loss + (1 - bias )* y_true * loss # return the biased loss: y_true marks all cases where the prediction should be 1, 1-y_true all cases where the prediction should be 0; bias scales between these two classes
151151 return custom_loss
152152
153- def train (data , modeltype ,rxn_keys = None ,labels = None ,validation_split = 0.0 ,nuplo = 30 , min_con = 0 , max_con = 0 , min_for = 0.05 , max_for = 0.3 , con_p = None , del_p = None , nlayers = 1 , nnodes = 256 , nepochs = 10 , b_size = 32 , dropout = 0.1 , bias_0 = 0.3 , maskI = True , save = False , name = 'noname ' , output_path = '' , return_history = False ):
153+ def train (data , modeltype ,rxn_keys = None ,labels = None ,validation_split = 0.0 ,nuplo = 30 , min_con = 0 , max_con = 0 , min_for = 0.05 , max_for = 0.3 , con_p = None , del_p = None , nlayers = 1 , nnodes = 256 , nepochs = 10 , b_size = 32 , dropout = 0.1 , bias_0 = 0.3 , maskI = True , save = True , output_path = 'dnngior_predictor.npz ' , return_history = False , return_full_network = False ):
154154 """
155155 Most important function, creates actual NN, there are many optional parameters
156156
@@ -168,7 +168,7 @@ def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo
168168
169169 TRAINING PARAMETERS:
170170 -------
171- see generate_training_set () ^
171+ see generate_feature () ^
172172
173173 NETWORK PARAMETERS
174174 -------------
@@ -200,39 +200,57 @@ def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo
200200 SAVING PARAMETERS:
201201
202202 save: boolean, optional
203- Whether you want to save the network, default = False
204- name: string, optional
205- name of your network, default='noname'
203+ Whether you want to save the network, default = True
206204 output_path: string,
207- where output, default=''
205+ Where to save the network; file extensions that work are .h5 and .npz,
206+ all other file extensions default to .npz (lite network)
207+ default='dnngior_predictor.npz'
208+
209+ OPTIONAL RETURNS:
210+
208211 return_history: boolean, optional
209212 If you want training history
210-
213+ default = False
214+ return_full_network: boolean, optional
215+ if True, return the full tensorflow network instead of the lite network
216+ default = False
211217 Returns:
212218 -------------
213219 trainedNN
214220 NN class containing network, rxn_keys and modeltype
215- history: type, if history=True
221+ history: if return_history=True
216222 history of training
217223 """
218224
219225 print ("Num GPUs Available: " , len (config .list_physical_devices ('GPU' )))
220226
227+ if os .path .exists (output_path ):
228+ print ("# WARNING: overwriting savefile" )
229+ elif os .access (os .path .dirname (output_path ), os .W_OK ):
230+ print ("Saving network at: {}" .format (output_path ))
231+ else :
232+ Exception ("Can not save at: {}" .format (output_path ))
233+
221234 if (isinstance (data , pd .DataFrame )):
222235 rxn_keys = data .index
223236 ndata = np .asarray (data , dtype = np .float32 ).T
224237 elif rxn_keys is None :
225238 raise (Exception ('Provide DataFrame or rxn_keys' ))
226239
227- #create feature and labels from training data
240+ #create feature from training data
228241 if (labels is None ):
229- labels = np .repeat (np .copy (ndata ), nuplo , axis = 0 ).astype (np .float32 )
242+ feature = np .repeat (np .copy (ndata ), nuplo , axis = 0 ).astype (np .float32 )
230243 print ('using data as labels' )
231244 else :
232- labels = np .repeat (np .copy (labels ), nuplo , axis = 0 ).astype (np .float32 )
245+ if (isinstance (labels , pd .DataFrame )):
246+ rxn_keys = data .index
247+ nlabels = np .asarray (labels , dtype = np .float32 ).T
248+ else :
249+ nlabels = labels .astype (np .float32 ).T
250+ feature = np .repeat (np .copy (nlabels ), nuplo , axis = 0 ).astype (np .float32 )
233251 print ("using user provided labels" )
234252
235- train_data = generate_training_set (ndata , nuplo , min_con , max_con , min_for , max_for , del_p , con_p )
253+ train_data = generate_feature (ndata , nuplo , min_con , max_con , min_for , max_for , del_p , con_p )
236254
237255 print ('dataset created' )
238256 nmodels , nreactions = ndata .shape
@@ -253,23 +271,28 @@ def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo
253271 #print summary of model
254272 network .summary ()
255273 #train model, history can be used to observe training
256- history = network .fit (train_data , labels , validation_split = validation_split , epochs = nepochs , shuffle = True , batch_size = b_size , verbose = 1 )
274+ history = network .fit (train_data , feature , validation_split = validation_split , epochs = nepochs , shuffle = True , batch_size = b_size , verbose = 1 )
257275 pseudo_network = []
258276 for i in range (0 , len (network .layers ),2 ):
259277 pseudo_network .append (network .layers [i ].get_weights ())
260278 pseudo_network = np .asarray (pseudo_network , dtype = object )
261279 #save Network
262280 if (save ):
263- if (save == 'h5' ):
264- network_path = os .path .join (output_path , "{}.h5" .format (name ))
265- with h5py .File (model_path , mode = 'w' ) as f :
281+ if (output_path .endswith ('.h5' )):
282+ with h5py .File (output_path , mode = 'w' ) as f :
266283 network .save (f )
267284 f .attrs ['modeltype' ] = modeltype
268285 f .create_dataset ("rxn_keys" , data = [n .encode ("ascii" , "ignore" ) for n in rxn_keys ])
269286 else :
270- network_path = os .path .join (output_path , "{}.npz" .format (name ))
271- np .savez (network_path ,network = pseudo_network , modeltype = modeltype ,rxn_keys = rxn_keys )
272- trainedNN = NN_Predictor .NN (custom = [pseudo_network ,rxn_keys ,modeltype ])
287+ if not output_path .endswith ('.npz' ):
288+ file_extension = output_path .split ('.' )[- 1 ]
289+ print ('{} not recognized, saving as .npz (lite) instead' .format (file_extension ))
290+ output_path .replace (file_extension , '.npz' )
291+ np .savez (output_path ,network = pseudo_network , modeltype = modeltype ,rxn_keys = rxn_keys )
292+ if return_full_network :
293+ trainedNN = NN (custom = [network ,modeltype ,rxn_keys ])
294+ else :
295+ trainedNN = NN (custom = [pseudo_network ,modeltype ,rxn_keys ])
273296 if return_history :
274297 return trainedNN , history
275298 else :
0 commit comments