
Commit 0e48cff

talkhanz and rich authored
Curriculum Learning for sequence_length (#56)
Co-authored-by: rich <[email protected]>
1 parent 359abff commit 0e48cff

File tree: 5 files changed (+212 −3 lines)

protein_lm/configs/train/toy_hf.yaml

Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,9 @@ dataset:
  test_size: 10
  sequence_column_name: "sequence"
  max_sequence_length: 10
+ do_curriculum_learning: false
+ curriculum_learning_strategy: 'sequence_length'
+ curriculum_learning_column_name: 'sequence_length'

# corresponds to HuggingFace's TrainingArguments
training_arguments:

protein_lm/configs/train/toy_localcsv.yaml

Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,9 @@ dataset:
  test_size: 10
  sequence_column_name: "sequence"
  max_sequence_length: 10
+ do_curriculum_learning: false
+ curriculum_learning_strategy: 'sequence_length'
+ curriculum_learning_column_name: 'sequence_length'

# corresponds to HuggingFace's TrainingArguments
training_arguments:
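Both toy configs ship with curriculum learning disabled. A minimal sketch of enabling it at run time, mirroring what the new test does with yaml.safe_load (only the keys added in this commit are assumed; everything else comes from the existing config):

import yaml

# Load the toy config and flip on the curriculum-learning keys introduced above.
with open("protein_lm/configs/train/toy_localcsv.yaml") as f:
    config_dict = yaml.safe_load(f)

config_dict["dataset"]["do_curriculum_learning"] = True
config_dict["dataset"]["curriculum_learning_strategy"] = "sequence_length"
config_dict["dataset"]["curriculum_learning_column_name"] = "sequence_length"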

protein_lm/modeling/getters/dataset.py

Lines changed: 25 additions & 2 deletions

@@ -27,8 +27,11 @@ class DatasetConfig(BaseModel):

    # name of the column that contains the sequence
    sequence_column_name: str
-
+
    max_sequence_length: int
+    do_curriculum_learning: bool
+    curriculum_learning_strategy: str
+    curriculum_learning_column_name: str


def set_input_ids(

@@ -45,7 +48,18 @@ def set_input_ids(
    )
    return result

-
+def batch_set_curriculum_learning_column(
+    result = None,
+    input_column_name = 'sequence',
+    curriculum_learning_column_name = 'sequence_length',
+    strategy = 'sequence_length'
+):
+    if strategy == 'sequence_length':
+        # LengthGroupedSampler sorts in descending order, so we make it ascending by multiplying by -1
+        result[curriculum_learning_column_name] = [-len(x) for x in result[input_column_name]]
+        return result
+    else:
+        raise Exception(f'invalid strategy {strategy} provided. Supported strategy values include sequence_length')
def set_labels(result):
    result["labels"] = result["input_ids"].copy()
    return result

@@ -149,4 +163,13 @@ def get_dataset(config_dict: Dict, tokenizer) -> Dataset:
        batched=True,
    )
    train_ds = train_ds.map(set_labels, batched=True)
+    if config.do_curriculum_learning:
+        train_ds = train_ds.map(lambda e: batch_set_curriculum_learning_column(
+            result = e,
+            input_column_name = config.sequence_column_name,
+            curriculum_learning_column_name = config.curriculum_learning_column_name,
+            strategy = config.curriculum_learning_strategy
+        ), batched=True)
+
    return train_ds
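To illustrate what the new helper computes: under the 'sequence_length' strategy it stores negated lengths, so that a sampler that groups longest-first effectively yields shortest-first batches. A minimal sketch with a hand-made batch dict (the toy sequences are made up):

from protein_lm.modeling.getters.dataset import batch_set_curriculum_learning_column

batch = {"sequence": ["AC", "ACDEFG", "ACDE"]}
batch = batch_set_curriculum_learning_column(
    result=batch,
    input_column_name="sequence",
    curriculum_learning_column_name="sequence_length",
    strategy="sequence_length",
)
# Negated lengths: descending sorting on this column puts short sequences first.
print(batch["sequence_length"])  # [-2, -6, -4]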

protein_lm/modeling/scripts/train.py

Lines changed: 6 additions & 1 deletion

@@ -38,7 +38,12 @@ def train(
    data_collator = get_data_collator(
        config_dict=config_dict["data_collator"],
    )
-
+    if config_dict['dataset']['do_curriculum_learning']:
+        # group_by_length uses the LengthGroupedSampler;
+        # we have precomputed the lengths (or any discrete column) to use as the sampling criterion
+        config_dict["training_arguments"]['group_by_length'] = config_dict['dataset']['do_curriculum_learning']
+        config_dict["training_arguments"]['length_column_name'] = config_dict['dataset']['curriculum_learning_column_name']
+
    training_args = get_training_args(
        config_dict=config_dict["training_arguments"],
    )
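The two keys injected here are standard HuggingFace TrainingArguments fields. A rough sketch of what get_training_args ends up constructing (output_dir and the batch size are placeholder values, not from this commit):

from transformers import TrainingArguments

# group_by_length makes the Trainer use LengthGroupedSampler,
# which reads per-example lengths from length_column_name.
training_args = TrainingArguments(
    output_dir="checkpoints/toy",          # placeholder
    group_by_length=True,
    length_column_name="sequence_length",  # the precomputed (negated) length column
    per_device_train_batch_size=20,
)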

protein_lm/tests/test_cl.py

Lines changed: 175 additions & 0 deletions (new file)

import os
import pytest
import torch
import torch.nn as nn
import yaml
from transformers import Trainer
from protein_lm.modeling.getters.data_collator import get_data_collator
from protein_lm.modeling.getters.model import get_model
from protein_lm.modeling.getters.tokenizer import get_tokenizer
from protein_lm.modeling.getters.training_args import get_training_args
from datasets import Dataset, load_dataset
from datasets.dataset_dict import DatasetDict
from pydantic import BaseModel
from protein_lm.modeling.getters.dataset import DatasetConfig, get_csv_dataset, set_input_ids, set_labels, batch_set_curriculum_learning_column
## data collator imports
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, NewType, Optional, Tuple, Union
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
import pandas as pd
import random

CONFIG_PATH = "protein_lm/configs/train/toy_localcsv.yaml"
strategies = ['sequence_length']
strategy2col = {'sequence_length': 'sequence_length'}  # mapping of strategy to the computed column name storing that strategy's values
total = 0  # number of batches/steps
unsorted = 0  # number of unsorted batches/steps
InputDataClass = NewType("InputDataClass", Any)

def cl_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    global total
    global unsorted
    """
    Very simple data collator that simply collates batches of dict-like objects and performs special handling for
    potential keys named:

    - ``label``: handles a single value (int or float) per object
    - ``label_ids``: handles a list of values per object

    Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
    to the model. See glue and ner for example of how it's useful.
    """

    # In this function we'll make the assumption that all `features` in the batch
    # have the same attributes.
    # So we will look at the first element as a proxy for what attributes exist
    # on the whole batch.
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]

    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
        dtype = torch.long if isinstance(label, int) else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], torch.Tensor):
            batch["labels"] = torch.stack([f["label_ids"] for f in features])
        else:
            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            else:
                if k == 'sequence_length':
                    batch[k] = [-f[k] for f in features]
                else:
                    batch[k] = torch.tensor([f[k] for f in features])
    lens = batch['sequence_length']
    print('######lens(cl_data_collator)#########')
    print(lens)
    total = total + 1
    try:
        assert lens == sorted(lens)
    except:
        unsorted = unsorted + 1
        print('not sorted')
    return {'input_ids': batch['input_ids'], 'labels': batch['labels']}


def create_random_dataframe(sequence_column_name='sequence', curriculum_learning_column_name='sequence_length', curriculum_learning_strategy='sequence_length', max_sequence_length=30, n=5000):
    assert max_sequence_length > 2
    random.seed(42)
    df = pd.DataFrame()

    def create_sequence(length):
        seq = ''.join(random.choice(['A', 'T', 'G', 'C']) for _ in range(length))
        return seq

    if curriculum_learning_strategy == 'sequence_length':
        df[sequence_column_name] = [create_sequence(random.randint(2, max_sequence_length)) for i in range(n)]
        df[curriculum_learning_column_name] = df[sequence_column_name].apply(lambda x: len(x))
    return df

@pytest.mark.parametrize("strategy", strategies)
def test_curriculum_learning(strategy):
    with open(CONFIG_PATH, "r") as cf:
        print('loading file.....')
        config_dict = yaml.safe_load(cf)

    config_dict['dataset']['max_sequence_length'] = 40
    config_dict['dataset']['do_curriculum_learning'] = True
    config_dict['dataset']['curriculum_learning_column_name'] = strategy2col[strategy]
    config_dict['dataset']['curriculum_learning_strategy'] = strategy
    config_dict['dataset']['val_size'] = 100
    config_dict['dataset']['test_size'] = 100
    config_dict['dataset']['subsample_size'] = 500
    config_dict["training_arguments"]['group_by_length'] = True
    config_dict["training_arguments"]['length_column_name'] = config_dict['dataset']['curriculum_learning_column_name']
    config_dict["training_arguments"]['remove_unused_columns'] = False  # this is necessary to keep curriculum_learning_column_name
    config_dict["training_arguments"]['per_device_train_batch_size'] = 20
    config_dict["training_arguments"]['max_steps'] = -1
    config_dict["training_arguments"]['num_train_epochs'] = 2

    print(config_dict)

    tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"])
    dataset = DatasetDict()
    val_df = create_random_dataframe(sequence_column_name=config_dict['dataset']['sequence_column_name'], curriculum_learning_column_name=config_dict['dataset']['curriculum_learning_column_name'], max_sequence_length=config_dict['dataset']['max_sequence_length'], n=config_dict['dataset']['val_size'])
    test_df = create_random_dataframe(sequence_column_name=config_dict['dataset']['sequence_column_name'], curriculum_learning_column_name=config_dict['dataset']["curriculum_learning_column_name"], max_sequence_length=config_dict['dataset']['max_sequence_length'], n=config_dict['dataset']['test_size'])
    train_df = create_random_dataframe(sequence_column_name=config_dict['dataset']['sequence_column_name'], curriculum_learning_column_name=config_dict['dataset']["curriculum_learning_column_name"], max_sequence_length=config_dict['dataset']['max_sequence_length'], n=config_dict['dataset']['subsample_size'])

    dataset['train'] = Dataset.from_pandas(train_df)
    dataset['val'] = Dataset.from_pandas(val_df)
    dataset['test'] = Dataset.from_pandas(test_df)
    dataset = dataset.map(
        lambda e: set_input_ids(
            result=e,
            tokenizer=tokenizer,
            sequence_column_name=config_dict['dataset']['sequence_column_name'],
            max_sequence_length=config_dict['dataset']['max_sequence_length'],
        ),
        batched=True,
    )
    dataset = dataset.map(set_labels, batched=True)
    dataset = dataset.map(lambda e: batch_set_curriculum_learning_column(
        result=e,
        input_column_name=config_dict['dataset']['sequence_column_name'],
        curriculum_learning_column_name=config_dict['dataset']['curriculum_learning_column_name'],
        strategy=config_dict['dataset']['curriculum_learning_strategy']
    ), batched=True)
    dataset = dataset.select_columns(['input_ids', 'labels', strategy2col[strategy]])
    model = get_model(
        config_dict=config_dict["model"],
    )

    training_args = get_training_args(
        config_dict=config_dict["training_arguments"],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("val", None),
        data_collator=cl_data_collator,
    )

    trainer.train()
    percentage_unsorted = int((unsorted / total) * 100)  # percentage of batches where the collated length list was not sorted
    # sometimes the list is off by a few entries as the LengthGroupedSampler has a bit of randomness
    print(f'percentage_unsorted:{percentage_unsorted}')
    assert percentage_unsorted < 10  # just a rough heuristic
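The test trains a toy model for two epochs and asserts that fewer than 10% of collated batches arrive unsorted. One hypothetical way to run it directly from Python (the -s flag just surfaces the collator's print output):

import pytest

# Run the curriculum-learning test and show the per-batch length printouts.
pytest.main(["protein_lm/tests/test_cl.py", "-s"])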
