27 changes: 24 additions & 3 deletions README.md
@@ -6,7 +6,6 @@ The goal of this project is to uncover the best approach to scale large protein
## Installing environment

```

conda env create -f protein_lm.yml
conda activate protein_lm_env
pip install -e .
@@ -22,7 +21,9 @@ pip install -e protein_lm/tokenizer/rust_trie

## Training

An example data file is provided in the `protein_lm/dataset/uniref` folder and an example toy training config yaml that uses this dataset is provided: `protein_lm/configs/train/toy_localcsv.yaml`. To use this config, at the root project directory (e.g., `protein_lm_scaling/`), run
### Toy training with a local dataset

We recommend using a tiny toy dataset for testing and debugging new changes that do not rely on having a large dataset. Such a small dataset is provided in the `protein_lm/dataset/uniref` folder, and an example toy training config yaml that uses it is provided in `protein_lm/configs/train/toy_localcsv.yaml`. To use this config, run the following from the root project directory (e.g., `protein_lm_scaling/`):

```
python protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_localcsv.yaml
@@ -34,12 +35,32 @@ This config is actually the default, so the above is equivalent to
python protein_lm/modeling/scripts/train.py
```

An example config yaml of using a dataset from huggingface is `protein_lm/configs/train/toy_hf.yaml`, which you can run with
### Toy training with a HuggingFace dataset

For testing with a HuggingFace dataset, we provide an example config yaml in `protein_lm/configs/train/toy_hf.yaml`. Note that training with this config is a little more involved than with the `protein_lm/configs/train/toy_localcsv.yaml` config above:

* When first run, the script will download the [processed uniref50 dataset](https://huggingface.co/datasets/zpn/uniref50), which could take some time.
* This config logs the loss values and other metrics to Weights and Biases (wandb), which requires a wandb account (see the note after this list).
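
If you have not used wandb on your machine before, you will typically also need to authenticate once. A minimal sketch, assuming the `wandb` dependency from `protein_lm.yml` provides the command-line client:

```
wandb login
```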

You can run with this config as follows:

```
python protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_hf.yaml
```

### Running on multiple GPUs

We can run on a single node with multiple GPUs with:

```
torchrun --standalone --nnodes=1 --nproc-per-node <num_gpus> protein_lm/modeling/scripts/train.py --config-file <config_file>
```

For example, to run on a single node with 3 GPUs using the provided `protein_lm/configs/train/toy_hf.yaml` config file, we can run:

```
torchrun --standalone --nnodes=1 --nproc-per-node 3 protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_hf.yaml
```
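
As a rough sanity check on batch sizing (assuming no gradient accumulation is configured), the effective batch size in this data-parallel setup is `per_device_train_batch_size` multiplied by the number of GPUs; with the `per_device_train_batch_size: 10` from `protein_lm/configs/train/toy_hf.yaml` and 3 GPUs, that is 10 × 3 = 30 sequences per optimizer step.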

## Getting involved
Your involvement is welcome! If you are interested, you can
1 change: 1 addition & 0 deletions protein_lm.yml
@@ -9,6 +9,7 @@ dependencies:
- numpy
- pytorch
- pydantic>=2.0
- wandb
- rust
- pip:
- transformers
28 changes: 19 additions & 9 deletions protein_lm/configs/train/toy_hf.yaml
@@ -2,24 +2,34 @@
dataset:
dataset_type: "huggingface"
dataset_loc: "zpn/uniref50"
train_sample_size: 100
subsample_size: 1000
split_seed: 2
val_size: 10
test_size: 10
sequence_column_name: "sequence"
max_sequence_length: 10

# corresponds to TrainingArgsConfig
training_args:
output_dir: "checkpoints/toy"
max_steps: 1
num_train_epochs: 1
# corresponds to HuggingFace's TrainingArguments
training_arguments:
output_dir: "checkpoints/toy_hf"
num_train_epochs: 2
learning_rate: 0.1
weight_decay: 0.1
save_strategy: "epoch"
per_device_train_batch_size: 1
save_steps: 1
report_to: "none"
per_device_train_batch_size: 10
save_steps: 5
evaluation_strategy: "steps"
eval_steps: 5
report_to: "wandb"
label_names:
- 'labels'
no_cuda: false
ddp_find_unused_parameters: false

# corresponds to WandBConfig
wandb:
name: "toy_hf"
dir: "wandb_files/"

# corresponds to TokenizerConfig
tokenizer:
9 changes: 6 additions & 3 deletions protein_lm/configs/train/toy_localcsv.yaml
@@ -2,12 +2,15 @@
dataset:
dataset_type: "csv"
dataset_loc: "protein_lm/dataset/uniref/uniref50_trimmed.csv"
train_sample_size: 100
subsample_size: 100
split_seed: 2
val_size: 10
test_size: 10
sequence_column_name: "sequence"
max_sequence_length: 10

# corresponds to TrainingArgsConfig
training_args:
# corresponds to HuggingFace's TrainingArguments
training_arguments:
output_dir: "checkpoints/toy"
max_steps: 1
num_train_epochs: 1
97 changes: 89 additions & 8 deletions protein_lm/modeling/getters/dataset.py
@@ -1,6 +1,7 @@
from typing import Dict, Literal, Optional

from datasets import Dataset, load_dataset
from datasets.dataset_dict import DatasetDict
from pydantic import BaseModel


@@ -10,8 +11,19 @@ class DatasetConfig(BaseModel):
# The path if local or the huggingface dataset name if huggingface
dataset_loc: str

# train sample size to limit to, if any
train_sample_size: Optional[int] = None
# sample size to limit to, if any, usually for debugging
subsample_size: Optional[int] = None

"""
Args for splitting into train, val, test
to be updated once we have more options
"""
# split seed
split_seed: Optional[int] = None
# size of validation dataset
val_size: int
# size of test dataset
test_size: int

# name of the column that contains the sequence
sequence_column_name: str
@@ -39,20 +51,89 @@ def set_labels(result):
return result


def get_local_dataset(config: DatasetConfig) -> Dataset:
train_ds = load_dataset("csv", data_files=config.dataset_loc)["train"]
return train_ds
def train_val_test_split(
dataset_dict: DatasetDict,
config: DatasetConfig,
) -> DatasetDict:
"""
Given a dictionary of datasets that only contains the split "train",
optionally subsamples it, and then splits it
so that it has potentially 3 splits: "train", "val", "test", where
"val" and "test" splits do not exist if the specified sizes are 0
"""
assert set(dataset_dict.keys()) == {
"train"
}, f"{train_val_test_split.__name__} expects its input to have the keys \
['train'] but the input has keys {list(dataset_dict.keys())}"

dataset = dataset_dict["train"]

val_size = config.val_size
test_size = config.test_size

assert isinstance(
dataset, Dataset
), f"Invalid dataset type {type(dataset)}, only datasets.Dataset allowed"

dataset = dataset.shuffle(seed=config.split_seed)

if config.subsample_size is not None:
dataset = dataset.select(range(config.subsample_size))

valtest_size = val_size + test_size

if valtest_size > 0:
train_valtest = dataset.train_test_split(
test_size=val_size + test_size,
shuffle=False,
)
split_dict = {
"train": train_valtest["train"],
}
if test_size > 0 and val_size > 0:
test_val = train_valtest["test"].train_test_split(
test_size=test_size,
shuffle=False,
)
split_dict["val"] = test_val["train"]
split_dict["test"] = test_val["test"]
elif val_size > 0:
split_dict["val"] = train_valtest["test"]
else:
split_dict["test"] = train_valtest["test"]
else:
split_dict = {
"train": dataset,
}

split_dataset_dict = DatasetDict(split_dict)
return split_dataset_dict


def get_csv_dataset(config: DatasetConfig) -> Dataset:
# note that a csv is read as having just one split "train"
dataset_dict = load_dataset("csv", data_files=config.dataset_loc)
return train_val_test_split(dataset_dict, config)


def get_huggingface_dataset(config: DatasetConfig) -> Dataset:
train_ds = load_dataset(config.dataset_loc, streaming=True, split="train")
return train_ds
dataset_dict = load_dataset(config.dataset_loc)
if set(dataset_dict.keys()) == {"train", "val", "test"}:
return dataset_dict

assert set(dataset_dict.keys()) == {
"train"
}, f"Huggingface DatasetDicts should have the keys {{'train'}} or \
{{'train', 'val', 'test'}} but this DatasetDict has keys \
{set(dataset_dict.keys())}"
return train_val_test_split(dataset_dict, config)


def get_dataset(config_dict: Dict, tokenizer) -> Dataset:
config = DatasetConfig(**config_dict)

if config.dataset_type == "csv":
train_ds = get_local_dataset(config)
train_ds = get_csv_dataset(config)
elif config.dataset_type == "huggingface":
train_ds = get_huggingface_dataset(config)
else:
3 changes: 0 additions & 3 deletions protein_lm/modeling/getters/model.py
@@ -33,7 +33,4 @@ def get_model(config_dict: Dict):
config=model_config,
)

if torch.cuda.is_available():
model.cuda()

return model
36 changes: 3 additions & 33 deletions protein_lm/modeling/getters/training_args.py
@@ -1,44 +1,14 @@
import os
from typing import Dict, List, Union
from typing import Dict

from pydantic import BaseModel, FieldValidationInfo, field_validator
from transformers import TrainingArguments


class TrainingArgsConfig(BaseModel):
per_device_train_batch_size: int
learning_rate: float
weight_decay: float
num_train_epochs: int
max_steps: int
save_steps: int
output_dir: str
save_strategy: str
report_to: str
label_names: List[str]
no_cuda: bool

@field_validator(
"per_device_train_batch_size",
"num_train_epochs",
"weight_decay",
"learning_rate",
"save_steps",
)
@classmethod
def check_gt_zero(cls, v: Union[int, float], info: FieldValidationInfo):
if v <= 0:
raise ValueError(f"trainer.{info.field_name} must be greater than 0")
return v


def get_training_args(config_dict: Dict) -> TrainingArguments:
config = TrainingArgsConfig(**config_dict)
config = TrainingArguments(**config_dict)

if not os.path.isdir(config.output_dir):
print(f"creating checkpoint directory at {config.output_dir}")
os.makedirs(config.output_dir)

return TrainingArguments(
**config_dict,
)
return config
23 changes: 23 additions & 0 deletions protein_lm/modeling/getters/wandb_log.py
@@ -0,0 +1,23 @@
import wandb
from pydantic import BaseModel
from typing import Dict, Optional
import os


class WandBConfig(BaseModel):
    project: str = "protein_lm_scaling"
    name: str
    # directory to save to
    dir: Optional[str] = None


def setup_wandb(config_dict: Dict) -> None:
    config = WandBConfig(**config_dict)
    if config.dir is not None:
        if not os.path.isdir(config.dir):
            print(f"creating wandb directory at {config.dir}")
            os.makedirs(config.dir)
        # only set WANDB_DIR when a directory is configured,
        # since os.environ values must be strings
        os.environ["WANDB_DIR"] = config.dir

    os.environ["WANDB_PROJECT"] = config.project
    os.environ["WANDB_NAME"] = config.name
27 changes: 13 additions & 14 deletions protein_lm/modeling/scripts/train.py
@@ -1,6 +1,7 @@
import argparse
import math


import yaml
from transformers import Trainer

@@ -9,6 +10,7 @@
from protein_lm.modeling.getters.model import get_model
from protein_lm.modeling.getters.tokenizer import get_tokenizer
from protein_lm.modeling.getters.training_args import get_training_args
from protein_lm.modeling.getters.wandb_log import setup_wandb


def train(
@@ -23,7 +25,7 @@ def train(

tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"])

train_ds = get_dataset(
dataset = get_dataset(
config_dict=config_dict["dataset"],
tokenizer=tokenizer,
)
@@ -38,27 +40,24 @@
)

training_args = get_training_args(
config_dict=config_dict["training_args"],
config_dict=config_dict["training_arguments"],
)

if "wandb" in training_args.report_to:
setup_wandb(
config_dict["wandb"],
)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
train_dataset=dataset["train"],
eval_dataset=dataset.get("val", None),
data_collator=data_collator,
)

train_result = trainer.train()
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
try:
perplexity = math.exp(metrics["train_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
print("metrics:", metrics)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.train()
trainer.save_model()
trainer.save_state()

