Description
Below is the discussion post I made about being unable to run inference with a ginet.GINet model that I trained. Specifically, I could not load a new, separately constructed GraphDataset and run Trainer.test() on it with my pre-trained model. My data had no clustering, even though clustering_method='mcl' was set when constructing the GraphDataset.
I found that this happens because Trainer._precluster(GraphDataset) is not called when the Trainer is constructed with a pre-trained model. The relevant lines of code are linked in my discussion post, pasted below.
Hi everyone,
I recently started using DeepRank2 to see whether it can build a capable predictor for some PPI structure data that I have. I successfully installed the GPU version using Conda (and pip), and I have trained a GINet model by following the tutorial closely. The code I used is below. I know the data generation was successful because I can visualize the data with GraphDataset.hdf5_to_pandas().
```python
from warnings import warn

import torch

from deeprank2.dataset import GraphDataset
from deeprank2.neuralnets.gnn.ginet import GINet
from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection
from deeprank2.trainer import Trainer
from deeprank2.utils.exporters import HDF5OutputExporter

### DATA GENERATION ###
# input_data_path, output_root and designate_binary_class() are defined elsewhere in my script
input_pdbs = input_data_path.glob(pattern = "*grouped.pdb")
queries = QueryCollection()
influence_radius = 8
max_edge_length = 8
print("Loading structures to be processed...")
for structure in input_pdbs:
    binary_target = designate_binary_class(structure.stem)
    if binary_target is None:
        warn(message = f"Class annotation not recognized for {structure.stem}. Excluding structure")
        continue
    queries.add(
        ProteinProteinInterfaceQuery(pdb_path = str(structure),
                                     resolution = 'residue',
                                     chain_ids = ['T', 'B'],
                                     influence_radius = influence_radius,
                                     max_edge_length = max_edge_length,
                                     targets = {'binary': binary_target})
    )
print(len(queries), "structures loaded")
print("Generating graph representations...")
graph_file = queries.process(
    prefix = str(output_root),
    combine_output = True
)

### TRAINING CODE ###
# df_train, df_val, df_test, data_files, target, task and output_dir are defined elsewhere in my script
transformations = {'all': {'transform': None, 'standardize': True}}
dataset_train = GraphDataset(
    hdf5_path = data_files, # data_files points to the same file path as graph_file above
    subset = list(df_train.entry),
    features_transform = transformations,
    target = target,
    task = task,
    clustering_method = 'mcl'
)
dataset_val = GraphDataset(
    hdf5_path = data_files,
    subset = list(df_val.entry),
    train_source = dataset_train,
    clustering_method = 'mcl'
)
dataset_test = GraphDataset(
    hdf5_path = data_files,
    subset = list(df_test.entry),
    train_source = dataset_train,
    clustering_method = 'mcl'
)
model = Trainer(
    GINet,
    dataset_train,
    dataset_val,
    dataset_test,
    cuda = True,
    ngpu = 3,
    output_exporters = [HDF5OutputExporter(output_dir)]
)
model.configure_optimizers(torch.optim.Adamax, lr = 0.001, weight_decay = 1e-04)
model.train(
    nepoch = 50,
    batch_size = 64,
    earlystop_patience = 10,
    earlystop_maxgap = 0.1,
    min_epoch = 5,
    validate = True,
)
result = model.test()
```

Not shown: loading the data and splitting it with train_test_split(), done essentially as in the tutorial.
My issue arises when I run an inference script I wrote. The script loads a completely separate dataset (not derived from the training data) and predicts classes with the pre-trained GINet. I construct a single GraphDataset, passed to the Trainer as dataset_test, with clustering_method and train_source set to match my pre-trained model. However, Trainer.test() still raises the error 'GlobalStorage' object has no attribute 'cluster0'. I know this means there is no clustering on my dataset, even though I have that option set. Code below:
```python
import os

from deeprank2.dataset import GraphDataset
from deeprank2.neuralnets.gnn.ginet import GINet
from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection
from deeprank2.trainer import Trainer
from deeprank2.utils.exporters import HDF5OutputExporter

# input_files, input_dir and data_output_dir are defined elsewhere in my script
queries = QueryCollection()
for pdb in input_files:
    queries.add(ProteinProteinInterfaceQuery(
        pdb_path = pdb,
        resolution = 'residue',
        chain_ids = ['B', 'T'],
    ), verbose = True)
inference_data = queries.process(
    prefix = os.path.join(data_output_dir, os.path.basename(input_dir)),
    log_error_traceback = True,
    combine_output = True
)
input_data = GraphDataset(hdf5_path = inference_data,
                          clustering_method = 'mcl', # I thought this would address any clustering errors
                          train_source = "model.pth.tar",
                          use_tqdm = True)
model = Trainer(
    GINet,
    dataset_test = input_data,
    pretrained_model = "model.pth.tar",
    output_exporters = [HDF5OutputExporter("Predictions")]
)
# No errors raised up to this point
results = model.test() # Raises 'GlobalStorage' object has no attribute 'cluster0' error
```

So I'm wondering whether there is an issue with how I generate the data for my new dataset. Perhaps the clustering step isn't happening correctly and the model therefore can't use it. Alternatively, the problem may be with the data itself, which I can't open with GraphDataset.hdf5_to_pandas(). It could also be something else entirely.
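Judging from the error message, GINet's forward pass appears to read a cluster0 attribute from each graph. PyTorch Geometric raises AttributeError when that attribute was never stored.

The pure-Python sketch below mimics that failure mode. GlobalStorageSketch and first_level_cluster_ids are hypothetical stand-ins, not PyG or DeepRank2 API, and the real storage class and model are considerably more involved.

```python
class GlobalStorageSketch:
    """Hypothetical stand-in for PyG's graph storage: unknown keys raise AttributeError."""

    def __init__(self, **attrs):
        self._mapping = dict(attrs)

    def __getattr__(self, key):
        # Only called when normal lookup fails, i.e. for graph attributes like cluster0.
        try:
            return self.__dict__["_mapping"][key]
        except KeyError:
            raise AttributeError(
                f"'GlobalStorage' object has no attribute '{key}'"
            ) from None


def first_level_cluster_ids(graph):
    """Hypothetical step that, like GINet's pooling, needs precomputed clusters."""
    clusters = graph.cluster0  # fails if clustering was never run on this graph
    return sorted(set(clusters))


clustered = GlobalStorageSketch(x=[[0.1], [0.2], [0.3]], cluster0=[0, 0, 1])
unclustered = GlobalStorageSketch(x=[[0.1], [0.2], [0.3]])

print(first_level_cluster_ids(clustered))  # [0, 1]
try:
    first_level_cluster_ids(unclustered)
except AttributeError as err:
    print(err)  # 'GlobalStorage' object has no attribute 'cluster0'
```

The point is that setting clustering_method on the dataset does not by itself put cluster0 on the graphs. Some step has to compute and store the clusters before the model reads them.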
For reference, I am running DeepRank2 3.1.0 on Python 3.10.0 with torch 2.1.1 and PyG 2.4.0. The OS is Ubuntu 22.04, and I edit and run the code in VS Code's interactive code cells. The GPU is an A100.
I am only passingly familiar with Torch and PyG (mostly through using this package). Any help is appreciated. Thanks!
UPDATE:
I think I have figured out why and a way around it, but this may warrant an addition to the source code.
deeprank2/deeprank2/trainer.py, lines 131 to 184 in 78e2773:

```python
if self.pretrained_model is None:
    if self.dataset_train is None:
        msg = "No training data specified. Training data is required if there is no pretrained model."
        raise ValueError(msg)
    if self.neuralnet is None:
        msg = "No neural network specified. Specifying a model framework is required if there is no pretrained model."
        raise ValueError(msg)
    self._init_from_dataset(self.dataset_train)
    self.optimizer = None
    self.class_weights = class_weights
    self.subset = self.dataset_train.subset
    self.epoch_saved_model = None
    if self.target is None:
        msg = "No target set. You need to choose a target (set in the dataset) for training."
        raise ValueError(msg)
    self._load_model()
    # clustering the datasets
    if self.clustering_method is not None:
        if self.clustering_method in ("mcl", "louvain"):
            _log.info("Loading clusters")
            self._precluster(self.dataset_train)
            if self.dataset_val is not None:
                self._precluster(self.dataset_val)
            else:
                _log.warning("No validation dataset given. Randomly splitting training set in training set and validation set.")
                self.dataset_train, self.dataset_val = _divide_dataset(self.dataset_train, splitsize=self.val_size)
            if self.dataset_test is not None:
                self._precluster(self.dataset_test)
        else:
            msg = f"Invalid node clustering method: {self.clustering_method}. Please set clustering_method to 'mcl', 'louvain' or None."
            raise ValueError(msg)
else:
    if self.neuralnet is None:
        msg = "No neural network class found. Please add it to complete loading the pretrained model."
        raise ValueError(msg)
    if self.dataset_test is None:
        msg = "No dataset_test found. Please add it to evaluate the pretrained model."
        raise ValueError(msg)
    if self.dataset_train is not None:
        self.dataset_train = None
        _log.warning("Pretrained model loaded: dataset_train will be ignored.")
    if self.dataset_val is not None:
        self.dataset_val = None
        _log.warning("Pretrained model loaded: dataset_val will be ignored.")
    self._init_from_dataset(self.dataset_test)
    self._load_params()
    self._load_pretrained_model()
```
This code block only computes clusters if not loading a pre-trained model. When loading only a dataset_test with a pre-trained model, the self._precluster() method is not run even if clustering_method is set on the GraphDataset. This behaviour is incompatible with ginet.GINet since the architecture requires community pooling to function.
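The branching described above can be modelled in a few lines. DatasetSketch, TrainerSketch and has_clusters are hypothetical, heavily simplified stand-ins, not DeepRank2 API.

In the sketch, the same kind of test set with clustering_method='mcl' ends up clustered when the Trainer is built for training, but not when it is built from a pre-trained model. Two simplifications are assumptions on my part:

- the placeholder clustering (every node in cluster 0);
- that clustering_method is copied from the test set in the pretrained branch.

```python
class DatasetSketch:
    """Hypothetical stand-in for GraphDataset: graphs are dicts of node data."""

    def __init__(self, graphs, clustering_method=None):
        self.graphs = graphs
        self.clustering_method = clustering_method


class TrainerSketch:
    """Hypothetical, simplified model of the clustering logic in Trainer.__init__.

    Training, parameter loading and model weights are omitted.
    """

    def __init__(self, dataset_train=None, dataset_test=None, pretrained_model=None):
        self.dataset_test = dataset_test
        if pretrained_model is None:
            # training branch: every dataset passed in is preclustered
            self.clustering_method = dataset_train.clustering_method
            if self.clustering_method in ("mcl", "louvain"):
                self._precluster(dataset_train)
                if dataset_test is not None:
                    self._precluster(dataset_test)
        else:
            # pretrained branch: assumed to copy settings from the test set,
            # but _precluster is never called
            self.clustering_method = dataset_test.clustering_method

    def _precluster(self, dataset):
        # placeholder clustering: put every node of a graph into cluster 0
        for graph in dataset.graphs:
            graph["cluster0"] = [0] * len(graph["x"])


def has_clusters(dataset):
    return all("cluster0" in graph for graph in dataset.graphs)


train = DatasetSketch([{"x": [0.5, 0.7, 0.9]}], clustering_method="mcl")
test_a = DatasetSketch([{"x": [1.0, 2.0]}], clustering_method="mcl")
TrainerSketch(dataset_train=train, dataset_test=test_a)
print(has_clusters(test_a))  # True

test_b = DatasetSketch([{"x": [1.0, 2.0]}], clustering_method="mcl")
TrainerSketch(dataset_test=test_b, pretrained_model="model.pth.tar")
print(has_clusters(test_b))  # False
```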
For now, adding the line model._precluster(input_data) to my script (reusing the definitions from my code above) forces the Trainer instance to compute clusters on my dataset. If I'm not mistaken, this line can be run after constructing the Trainer without issue, because the clustering information is saved to the .hdf5 file.
I may open an issue referencing this behavior, as it would be nice to have this step added to the source code.
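One way the step could be added upstream is to repeat the training branch's clustering block at the end of the pretrained branch, after the model is loaded. Below is a minimal sketch of that idea using hypothetical stand-ins (DatasetSketch, PatchedTrainerSketch). It is not DeepRank2 code. Both the placement of the call and the assumption that the clustering method is read from dataset_test are my own guesses.

```python
VALID_METHODS = ("mcl", "louvain")


class DatasetSketch:
    """Hypothetical stand-in for GraphDataset: graphs are dicts of node data."""

    def __init__(self, graphs, clustering_method=None):
        self.graphs = graphs
        self.clustering_method = clustering_method


class PatchedTrainerSketch:
    """Hypothetical pretrained-model branch with the proposed preclustering step.

    Parameter and weight loading are omitted; only clustering is modelled.
    """

    def __init__(self, dataset_test=None, pretrained_model=None):
        if dataset_test is None:
            raise ValueError("No dataset_test found. Please add it to evaluate the pretrained model.")
        self.dataset_test = dataset_test
        self.pretrained_model = pretrained_model
        # assumption: the clustering method is taken from the test dataset
        self.clustering_method = dataset_test.clustering_method
        # (self._load_params() and self._load_pretrained_model() would run here)

        # proposed addition: mirror the clustering step of the training branch
        if self.clustering_method is not None:
            if self.clustering_method not in VALID_METHODS:
                msg = (
                    f"Invalid node clustering method: {self.clustering_method}. "
                    "Please set clustering_method to 'mcl', 'louvain' or None."
                )
                raise ValueError(msg)
            self._precluster(self.dataset_test)

    def _precluster(self, dataset):
        # placeholder clustering: put every node of a graph into cluster 0
        for graph in dataset.graphs:
            graph["cluster0"] = [0] * len(graph["x"])


test_set = DatasetSketch([{"x": [1.0, 2.0]}, {"x": [3.0]}], clustering_method="mcl")
PatchedTrainerSketch(dataset_test=test_set, pretrained_model="model.pth.tar")
print(all("cluster0" in graph for graph in test_set.graphs))  # True
```

With clustering_method=None the test set is left unclustered in this sketch, matching how the training branch treats None.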
Still, I'd like confirmation that calling model._precluster(input_data) myself is a valid solution and won't introduce any silent bugs, particularly in my classification results.
Thanks