
Commit 17e08a8

Merge pull request #26 from MGXlab/develop
Develop
2 parents c246436 + 0dd1995 commit 17e08a8

22 files changed: +32793 / -2400 lines changed

README.md

Lines changed: 4 additions & 3 deletions
@@ -27,7 +27,7 @@ import dnngior.gapfill_class.Gapfill
 Gapfill(path_to_model)
 ```

-You may find examples of gap-filling a genome scale reconstruction (GEM) with `dnngior` with a complete or a defined medium in this [example notebook](tutorials/example.ipynb). `dnngior` can gapfill both ModelSEED and BiGG models, to gapfill BiGG models you need to specify modeltype.
+You may find examples of gap-filling a genome scale reconstruction (GEM) with `dnngior` with a complete or a defined medium in this [example notebook](tutorials/gapfilling_example.ipynb). `dnngior` can gapfill both ModelSEED and BiGG models, to gapfill BiGG models you need to specify modeltype.

 ```python
 Gapfill(path_to_BiGG_model, modeltype='BiGG')
@@ -48,12 +48,13 @@ Alternatively you can find additional custom Neural Networks for several taxonom

 ## License

+
 Please see [License](LICENSE)


 ## Cite

-The paper that will accompany the tool is currrently available as preprint:\
-https://www.biorxiv.org/content/10.1101/2023.07.10.548314v2
+The paper that will accompany the tool is can be found here:\
+https://www.cell.com/iscience/fulltext/S2589-0042(24)02574-4

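A minimal usage sketch of the two calls quoted in the README hunk above. The top-level import and the model file names are assumptions for illustration; the README context line shows `dnngior.gapfill_class.Gapfill`, so the exact import path may differ.

```python
# Sketch only: import path and model file names are assumed, not verified against the package.
from dnngior import Gapfill  # assumed public import

# ModelSEED reconstruction: modeltype defaults to ModelSEED
gapfilled_seed = Gapfill("my_modelseed_model.sbml")              # hypothetical path

# BiGG reconstruction: modeltype must be given explicitly, as the README notes
gapfilled_bigg = Gapfill("my_bigg_model.xml", modeltype='BiGG')  # hypothetical path
```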
dnngior/MSEED_reactions.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def parseStoichOnt(self, stoichiometry):

         #For empty reaction
         if(stoichiometry == ""):
-            return rxn_cpds_array
+            return rxn_cpds_dict

         for rgt in stoichiometry.split(";"):
             (coeff, cpd, cpt, index, name) = rgt.split(":", 4)
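For context, a small stand-alone sketch of the parsing logic this hunk touches: on an empty stoichiometry string the method now returns the (empty) `rxn_cpds_dict`, the dictionary type built by the loop below, instead of `rxn_cpds_array`. The helper name and return layout here are hypothetical simplifications, not the actual `MSEED_reactions` code.

```python
# Hypothetical, simplified stand-in for parseStoichOnt(), for illustration only.
def parse_stoichiometry(stoichiometry):
    rxn_cpds_dict = {}
    # Empty reaction: return the empty dict (the type callers expect), per the fix above.
    if stoichiometry == "":
        return rxn_cpds_dict
    # ModelSEED-style reagents: "coeff:compound:compartment:index:name" joined by ";"
    for rgt in stoichiometry.split(";"):
        coeff, cpd, cpt, index, name = rgt.split(":", 4)
        rxn_cpds_dict[(cpd, cpt, index)] = float(coeff)
    return rxn_cpds_dict

print(parse_stoichiometry(""))                                              # {}
print(parse_stoichiometry("-1:cpd00001:0:0:H2O;1:cpd00009:0:0:Phosphate"))  # two reagents
```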

dnngior/NN_Predictor.py

Lines changed: 32 additions & 6 deletions
@@ -7,19 +7,40 @@
 import pandas as pd
 import cobra.core.model as cobra_model
 from dnngior.reaction_class import Reaction
+from dnngior.variables import *
 import os
 import sys
 from math import exp
 from pathlib import Path

 class NN:
-    def __init__(self, modeltype=None, path=None, custom=None):
+    def __init__(self, path=None, modeltype=None, custom=None):
         '''
         Light version of the model, saves space, uses only numpy and cobra and no tensorflow
         '''

-        self.path=path
-        self.__get_pseudo_network()
+        if custom:
+            self.network = custom[0]
+            self.modeltype = custom[1]
+            self.rxn_keys = custom[2]
+        else:
+            if path:
+                self.path=path
+
+            elif modeltype:
+                if modeltype == 'ModelSEED':
+                    self.path = TRAINED_NN_MSEED
+                elif modeltype == 'BiGG':
+                    self.path = TRAINED_NN_BIGG
+                else:
+                    print("Modeltype: {} not recognized, defaulting to ModelSEED".format(modeltype))
+                    self.path = TRAINED_NN_MSEED
+            else:
+                print("No path or modeltype provided, defaulting to ModelSEED")
+                self.path = TRAINED_NN_MSEED
+            self.__get_pseudo_network()
+
+


     #Function that loads the Neural network; path is path to .h5 file
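Based on the hunk above, the constructor now resolves its source in the order custom > path > modeltype, falling back to the bundled ModelSEED network. A short usage sketch (the .npz file name is hypothetical):

```python
from dnngior.NN_Predictor import NN

nn_default = NN()                           # no arguments: falls back to TRAINED_NN_MSEED
nn_bigg    = NN(modeltype='BiGG')           # bundled BiGG network (TRAINED_NN_BIGG)
nn_custom  = NN(path='my_trained_nn.npz')   # hypothetical path to a user-trained lite network
```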
@@ -53,10 +74,10 @@ def predict(self, input):
             #check if reaction class
             input2 = self.__convert_reaction_list(set(input.reactions))
         elif isinstance(input, pd.DataFrame):
-            input.reindex(self.rxn_keys)
+            input1b = input.reindex(self.rxn_keys).fillna(0.0)
             df_columns = input.columns
             #Transpose because rows need to be different models for the network
-            input2 = np.asarray(input.T)
+            input2 = np.asarray(input1b.T)
         elif isinstance(input, dict):
             #check if dictionary, get list of reactions and convert
             input2 = self.__convert_reaction_list([i for i in input if input[i]==1])
@@ -75,6 +96,9 @@ def predict(self, input):
         else:
             single_input=False

+        if not input2.shape[1] == self.network[0][0].shape[0]:
+            raise Exception("Input size ({}) does not match network ({})".format(input2.shape[1], len(self.rxn_keys)))
+
         a = input2
         for layer in self.network:
             a = a.clip(0)
@@ -85,6 +109,8 @@
         prediction = dict(zip(self.rxn_keys, np.squeeze(prediction)))
         if isinstance(input, pd.DataFrame):
             prediction = pd.DataFrame(index=self.rxn_keys, columns=df_columns, data=prediction.T)
+            if len(prediction.index) != len(input.index):
+                print('Warning mismatch input vs prediction ({})'.format(len(prediction.index) - len(input.index)))
         return prediction

     #function that generates a binary input based on a list of reaction ids
@@ -106,7 +132,7 @@ def __convert_reaction_list(self, reaction_set):
                     b_input.append(1)
                 else:
                     b_input.append(0)
-            print("#reactions not found in keys: ", len(set(reaction_set)) - sum(b_input), '/', len(reaction_set))
+            print("#reactions not found in NN-keys: ", len(set(reaction_set)) - sum(b_input), '/', len(reaction_set))
         except:
             raise Exception("Conversion failed")

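The predict() hunks above add a reindex(...).fillna(0.0) step for DataFrame input and a shape check against the first network layer. A sketch of calling it on a presence/absence table; the reaction id and genome names are made up for illustration.

```python
import pandas as pd
from dnngior.NN_Predictor import NN

nn = NN(modeltype='ModelSEED')

# rows = reaction ids, columns = genomes; reactions missing from the index are
# filled with 0.0 by the reindex step added in this commit
presence = pd.DataFrame({'genome_A': [1], 'genome_B': [0]},
                        index=['rxn00001_c0'])   # hypothetical reaction id
predictions = nn.predict(presence)               # DataFrame indexed by nn.rxn_keys
```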
dnngior/NN_Trainer.py

Lines changed: 45 additions & 22 deletions
@@ -17,7 +17,7 @@
 import sys
 from tensorflow import compat, config, dtypes

-import dnngior.NN_Predictor
+from dnngior.NN_Predictor import NN

 # Tensorflow; please consider: https://www.tensorflow.org/api_docs/python/tf/compat/v1/disable_eager_execution
 compat.v1.disable_eager_execution()
@@ -31,7 +31,7 @@ def noise_data(i, noise_0, noise_1, del_p, con_p):
     ----------
     i : numpy array, required
         an array of 0s and 1s which you want to noisify
-    noise_0 : numpy array, required
+    noise_0 : numpy array, requiredimport dnngior.NN_Predictor
         fraction of 0s to change to 1s
     noise_1 : numpy array, required
         fraction of 1s to change to 0s
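The docstring in this hunk describes noise_0 as the fraction of 0s flipped to 1s and noise_1 as the fraction of 1s flipped to 0s. A hypothetical illustration of that idea (not dnngior's actual noise_data(), and ignoring the del_p/con_p options):

```python
import numpy as np

def add_noise(binary, frac_0_to_1, frac_1_to_0, seed=None):
    """Flip a fraction of the 0s to 1s and a fraction of the 1s to 0s (illustrative only)."""
    rng = np.random.default_rng(seed)
    noisy = binary.copy()
    zeros = np.flatnonzero(binary == 0)
    ones = np.flatnonzero(binary == 1)
    noisy[rng.choice(zeros, size=int(len(zeros) * frac_0_to_1), replace=False)] = 1
    noisy[rng.choice(ones, size=int(len(ones) * frac_1_to_0), replace=False)] = 0
    return noisy

print(add_noise(np.array([0, 0, 1, 1, 1, 0]), frac_0_to_1=0.5, frac_1_to_0=0.3, seed=0))
```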
@@ -73,7 +73,7 @@ def noise_data(i, noise_0, noise_1, del_p, con_p):
         o = temp
     return o

-def generate_training_set(data,nuplo, min_con, max_con, min_for, max_for, del_p, con_p):
+def generate_feature(data, nuplo, min_con, max_con, min_for, max_for, del_p, con_p):
     """
     Function to generate the dataset for training (feature).
     PARAMETERS:
@@ -150,7 +150,7 @@ def custom_loss(y_true, y_pred): #y_true is the label, y_pred is the prediction
         return bias*(1-y_true)*loss+(1-bias)*y_true*loss # return the biased loss y_true are all cases where prediction shouold be 1, 1-y_true all cases where prediction should be one, can scale between these two classes
     return custom_loss

-def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo=30, min_con=0, max_con=0, min_for=0.05, max_for=0.3, con_p=None, del_p = None, nlayers=1, nnodes=256, nepochs=10, b_size=32, dropout=0.1, bias_0=0.3, maskI=True, save=False, name='noname', output_path='', return_history=False):
+def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo=30, min_con=0, max_con=0, min_for=0.05, max_for=0.3, con_p=None, del_p = None, nlayers=1, nnodes=256, nepochs=10, b_size=32, dropout=0.1, bias_0=0.3, maskI=True, save=True, output_path='dnngior_predictor.npz', return_history=False, return_full_network=False):
     """
     Most important function, creates actual NN, there are many optional parameters

@@ -168,7 +168,7 @@ def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo

     TRAINING PARAMETERS:
     -------
-    see generate_training_set() ^
+    see generate_feature() ^

     NETWORK PARAMETERS
     -------------
@@ -200,39 +200,57 @@ def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo
     SAVING PARAMETERS:

     save: boolean, optional
-        Whether you want to save the network, default = False
-    name: string, optional
-        name of your network, default='noname'
+        Whether you want to save the network, default = True
     output_path: string,
-        where output, default=''
+        Where to save the network, file_extension that work are .h5 and .npz
+        all other file_extensions defailt to npz (lite network)
+        default='dnngior_predictor.npz'
+
+    OPTIONAL RETURNS:
+
     return_history: boolean, optional
         If you want training history
-
+        default = False
+    return_full_network: boolean, optional
+        if you want to return the lite_network or full tensorflow object
+        default = False
     Returns:
     -------------
     trainedNN
         NN class containing network, rxn_keys and modeltype
-    history: type, if history=True
+    history: if history=True
         history of training
     """

     print("Num GPUs Available: ", len(config.list_physical_devices('GPU')))

+    if os.path.exists(output_path):
+        print("# WARNING: overwriting savefile")
+    elif os.access(os.path.dirname(output_path), os.W_OK):
+        print("Saving network at: {}".format(output_path))
+    else:
+        Exception("Can not save at: {}".format(output_path))
+
     if(isinstance(data, pd.DataFrame)):
         rxn_keys = data.index
         ndata = np.asarray(data, dtype=np.float32).T
     elif rxn_keys is None:
         raise(Exception('Provide DataFrame or rxn_keys'))

-    #create feature and labels from training data
+    #create feature from training data
     if(labels is None):
-        labels = np.repeat(np.copy(ndata), nuplo, axis=0).astype(np.float32)
+        feature = np.repeat(np.copy(ndata), nuplo, axis=0).astype(np.float32)
         print('using data as labels')
     else:
-        labels = np.repeat(np.copy(labels), nuplo, axis=0).astype(np.float32)
+        if(isinstance(labels, pd.DataFrame)):
+            rxn_keys = data.index
+            nlabels = np.asarray(labels, dtype=np.float32).T
+        else:
+            nlabels = labels.astype(np.float32).T
+        feature = np.repeat(np.copy(nlabels), nuplo, axis=0).astype(np.float32)
         print("using user provided labels")

-    train_data = generate_training_set(ndata, nuplo, min_con, max_con, min_for, max_for, del_p, con_p)
+    train_data = generate_feature(ndata, nuplo, min_con, max_con, min_for, max_for, del_p, con_p)

     print('dataset created')
     nmodels, nreactions = ndata.shape
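A tiny numpy sketch of the duplication step in this hunk: each row of the (transposed) reaction table is repeated `nuplo` times so that the noised training inputs produced by generate_feature() line up row-for-row with un-noised target rows. The toy array is made up for illustration.

```python
import numpy as np

ndata = np.array([[1, 0, 1],     # genome 1 (rows are genomes after the .T above)
                  [0, 1, 1]],    # genome 2
                 dtype=np.float32)
nuplo = 3
targets = np.repeat(np.copy(ndata), nuplo, axis=0)  # each genome row repeated 3 times
print(targets.shape)                                # (6, 3)
```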
@@ -253,23 +271,28 @@ def train(data, modeltype,rxn_keys=None,labels = None,validation_split=0.0,nuplo
     #print summary of model
     network.summary()
     #train model, history can be used to observe training
-    history = network.fit(train_data, labels, validation_split = validation_split, epochs = nepochs, shuffle=True, batch_size = b_size, verbose=1)
+    history = network.fit(train_data, feature, validation_split = validation_split, epochs = nepochs, shuffle=True, batch_size = b_size, verbose=1)
     pseudo_network = []
     for i in range(0, len(network.layers),2):
         pseudo_network.append(network.layers[i].get_weights())
     pseudo_network = np.asarray(pseudo_network, dtype=object)
     #save Network
     if(save):
-        if(save == 'h5'):
-            network_path = os.path.join(output_path, "{}.h5".format(name))
-            with h5py.File(model_path, mode='w') as f:
+        if(output_path.endswith('.h5')):
+            with h5py.File(output_path, mode='w') as f:
                 network.save(f)
                 f.attrs['modeltype'] = modeltype
                 f.create_dataset("rxn_keys", data =[n.encode("ascii", "ignore") for n in rxn_keys])
         else:
-            network_path = os.path.join(output_path, "{}.npz".format(name))
-            np.savez(network_path,network=pseudo_network, modeltype=modeltype,rxn_keys=rxn_keys)
-    trainedNN = NN_Predictor.NN(custom=[pseudo_network,rxn_keys,modeltype])
+            if not output_path.endswith('.npz'):
+                file_extension = output_path.split('.')[-1]
+                print('{} not recognized, saving as .npz (lite) instead'.format(file_extension))
+                output_path.replace(file_extension, '.npz')
+            np.savez(output_path,network=pseudo_network, modeltype=modeltype,rxn_keys=rxn_keys)
+    if return_full_network:
+        trainedNN = NN(custom=[network,modeltype,rxn_keys])
+    else:
+        trainedNN = NN(custom=[pseudo_network,modeltype,rxn_keys])
     if return_history:
         return trainedNN, history
     else:
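Putting the new saving parameters together, a usage sketch: per this hunk, an output_path ending in .h5 keeps the full tensorflow model, anything else is written as a lite .npz. The toy DataFrame, file names, and the tiny nepochs are illustrative assumptions, not a recommended training setup.

```python
import numpy as np
import pandas as pd
from dnngior.NN_Trainer import train

# toy presence/absence table: rows = reaction ids, columns = genomes (all hypothetical)
data = pd.DataFrame(np.random.randint(0, 2, size=(50, 10)),
                    index=['rxn%05d_c0' % i for i in range(50)],
                    columns=['genome_%d' % i for i in range(10)])

# lite network written to .npz (the default style)
trained = train(data, modeltype='ModelSEED', nepochs=1, output_path='toy_network.npz')

# full tensorflow model saved to .h5 and returned wrapped in the NN class
trained_full, history = train(data, modeltype='ModelSEED', nepochs=1,
                              output_path='toy_network.h5',
                              return_full_network=True, return_history=True)
```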
