Description
Checkpoints saved by train_binary.py cannot be loaded by inference_PPI_singleGPU.py because of an architecture mismatch.
Problem
train_binary.py uses AutoModelForSequenceClassification:
```python
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config, **automodel_args)
```
But inference_PPI_singleGPU.py uses PLMinteract, which wraps AutoModelForMaskedLM:
```python
self.esm_mask = AutoModelForMaskedLM.from_pretrained(model_name, config=config)
```
This causes a RuntimeError when loading checkpoints:
```
RuntimeError: Error(s) in loading state_dict for PLMinteract:
	Missing key(s) in state_dict: "esm_mask.esm.embeddings.word_embeddings.weight", ...
	Unexpected key(s) in state_dict: "esm.embeddings.word_embeddings.weight", ...
```
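The error shows that the backbone weights are identical and differ only in prefix: the checkpoint stores them under `esm.` while PLMinteract expects `esm_mask.esm.`. A minimal sketch of a key-renaming workaround, assuming that this prefix difference is the only mismatch for the backbone, is shown below. The function name `remap_esm_keys` is hypothetical and not part of the repository. Keys outside the backbone (for example the sequence-classification head, assumed here to live under `classifier.`) are returned separately, because PLMinteract presumably has its own head with different names.

```python
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def remap_esm_keys(
    state_dict: Dict[str, T],
    src_prefix: str = "esm.",
    dst_prefix: str = "esm_mask.esm.",
) -> Tuple[Dict[str, T], List[str]]:
    """Rename backbone keys from the train_binary.py layout to the layout
    PLMinteract expects (assumed from the error message above).

    Returns the renamed dict plus the keys that were not renamed, e.g. the
    classification head, which has no obvious counterpart in PLMinteract.
    """
    remapped: Dict[str, T] = {}
    unmapped: List[str] = []
    for key, value in state_dict.items():
        if key.startswith(src_prefix):
            # "esm.embeddings..." -> "esm_mask.esm.embeddings..."
            remapped[dst_prefix + key[len(src_prefix):]] = value
        else:
            unmapped.append(key)
    return remapped, unmapped
```

With a PyTorch model, the renamed dict could be passed to `model.load_state_dict(remapped, strict=False)`, whose return value lists the keys that are still missing or unexpected. Those should be inspected: any PLMinteract head weights would stay at their initial values, so this only recovers the fine-tuned backbone, not a trained PPI classifier.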
Question
What is the intended inference script for checkpoints trained with train_binary.py? Is there a separate inference script I should be using, or should I modify inference_PPI_singleGPU.py to handle both checkpoint formats?
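If the answer is to make inference_PPI_singleGPU.py accept both formats, one possible first step is to tell them apart by their key prefixes before loading. The sketch below assumes the prefixes seen in the error message (`esm_mask.` for PLMinteract checkpoints, `esm.` for train_binary.py checkpoints). The function `detect_checkpoint_format` and its return labels are hypothetical.

```python
from typing import Iterable


def detect_checkpoint_format(keys: Iterable[str]) -> str:
    """Guess which script produced a checkpoint from its state_dict keys.

    Assumption: PLMinteract checkpoints prefix backbone weights with
    "esm_mask.", while AutoModelForSequenceClassification checkpoints
    from train_binary.py start with "esm.".
    """
    keys = list(keys)
    if any(k.startswith("esm_mask.") for k in keys):
        return "plminteract"
    if any(k.startswith("esm.") for k in keys):
        return "sequence_classification"
    return "unknown"
```

A checkpoint classified as `"sequence_classification"` could then be routed to AutoModelForSequenceClassification directly rather than forced into PLMinteract, since the two models end in different heads.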