Skip to content

Commit 31dcd8f

Browse files
Leo-T-ZangTianlai ChenTianlai ChenTianlai Chen
authored
Add continous value (plDDT and PPL) support for Curriculum Learning (#58)
Co-authored-by: Tianlai Chen <[email protected]> Co-authored-by: Tianlai Chen <[email protected]> Co-authored-by: Tianlai Chen <[email protected]>
1 parent c92083b commit 31dcd8f

File tree

5 files changed

+241
-13
lines changed

5 files changed

+241
-13
lines changed

protein_lm/configs/train/toy_hf.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ dataset:
99
sequence_column_name: "sequence"
1010
max_sequence_length: 10
1111
do_curriculum_learning: false
12-
curriculum_learning_strategy: 'sequence_length'
13-
curriculum_learning_column_name: 'sequence_length'
12+
curriculum_learning_strategy:
13+
- 'sequence_length'
14+
- 'ppl'
15+
- 'plddt'
16+
curriculum_learning_column_name:
17+
- 'sequence_length'
18+
- 'ppl'
19+
- 'plddt'
1420

1521
# corresponds to HuggingFace's TrainingArguments
1622
training_arguments:

protein_lm/configs/train/toy_localcsv.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ dataset:
99
sequence_column_name: "sequence"
1010
max_sequence_length: 10
1111
do_curriculum_learning: false
12-
curriculum_learning_strategy: 'sequence_length'
13-
curriculum_learning_column_name: 'sequence_length'
12+
curriculum_learning_strategy:
13+
- 'sequence_length'
14+
- 'ppl'
15+
- 'plddt'
16+
curriculum_learning_column_name:
17+
- 'sequence_length'
18+
- 'ppl'
19+
- 'plddt'
1420

1521
# corresponds to HuggingFace's TrainingArguments
1622
training_arguments:

protein_lm/modeling/getters/dataset.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,24 @@ def set_input_ids(
4949
return result
5050

5151
def batch_set_curriculum_learning_column(
52-
result = None,
53-
input_column_name = 'sequence',
54-
curriculum_learning_column_name = 'sequence_length',
55-
strategy = 'sequence_length'
52+
result=None,
53+
input_column_name='sequence',
54+
curriculum_learning_column_name='sequence_length',
55+
strategy='sequence_length'
5656
):
57+
supported_strategies = ['sequence_length', 'ppl', 'plddt']
58+
59+
if strategy not in supported_strategies:
60+
raise Exception(f'Invalid {strategy} provided. Supported strategy values include {", ".join(supported_strategies)}')
61+
5762
if strategy == 'sequence_length':
58-
#LengthGroupedSampler sorts in descending so we make it ascending by multplying with -1
63+
# LengthGroupedSampler sorts in descending so we make it ascending by multiplying with -1
5964
result[curriculum_learning_column_name] = [-len(x) for x in result[input_column_name]]
60-
return result
61-
else:
62-
raise Exception(f'invalid {strategy} provided. Supported strategy values include sequence_length')
65+
elif strategy in ['ppl', 'plddt']:
66+
result[curriculum_learning_column_name] = [-x for x in result[strategy]]
67+
68+
return result
69+
6370
def set_labels(result):
6471
result["labels"] = result["input_ids"].copy()
6572
return result

protein_lm/tests/test_cl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def create_random_dataframe(sequence_column_name = 'sequence',curriculum_learnin
9292
random.seed(42)
9393
df = pd.DataFrame()
9494
def create_sequence(length):
95-
seq = ''.join(random.choice(['A','T','G','C']) for _ in range(length))
95+
seq = ''.join(random.choice('ACDEFGHIKLMNPQRSTVWY') for _ in range(length))
9696
return seq
9797

9898
if curriculum_learning_strategy == 'sequence_length':
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import os
2+
import pytest
3+
import torch
4+
import torch.nn as nn
5+
import yaml
6+
from transformers import Trainer
7+
from protein_lm.modeling.getters.data_collator import get_data_collator
8+
from protein_lm.modeling.getters.model import get_model
9+
from protein_lm.modeling.getters.tokenizer import get_tokenizer
10+
from protein_lm.modeling.getters.training_args import get_training_args
11+
from datasets import Dataset, load_dataset
12+
from datasets.dataset_dict import DatasetDict
13+
from pydantic import BaseModel
14+
from protein_lm.modeling.getters.dataset import DatasetConfig,get_csv_dataset,set_input_ids,set_labels,batch_set_curriculum_learning_column
15+
##data collator imports
16+
from dataclasses import dataclass
17+
from typing import Dict, Literal,Any, Callable, Dict, List, NewType, Optional, Tuple, Union
18+
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
19+
import pandas as pd
20+
import random
21+
22+
CONFIG_PATH = "protein_lm/configs/train/toy_localcsv.yaml"
23+
strategies = ['ppl']
24+
strategy2col = {'ppl': 'ppl'} #mapping of strategy to the computed column name storing the values of respective strategy
25+
total = 0 #number of batches/steps
26+
unsorted = 0 #number of unsorted batches/steps
27+
InputDataClass = NewType("InputDataClass", Any)
28+
29+
global max_value_of_previous_batch
30+
max_value_of_previous_batch = None
31+
global batch_comparison_values
32+
batch_comparison_values = []
33+
34+
def cl_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
35+
global total
36+
global unsorted
37+
global max_value_of_previous_batch
38+
global batch_comparison_values
39+
"""
40+
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
41+
potential keys named:
42+
43+
- ``label``: handles a single value (int or float) per object
44+
- ``label_ids``: handles a list of values per object
45+
46+
Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
47+
to the model. See glue and ner for example of how it's useful.
48+
"""
49+
50+
# In this function we'll make the assumption that all `features` in the batch
51+
# have the same attributes.
52+
# So we will look at the first element as a proxy for what attributes exist
53+
# on the whole batch.
54+
if not isinstance(features[0], (dict, BatchEncoding)):
55+
features = [vars(f) for f in features]
56+
57+
first = features[0]
58+
batch = {}
59+
60+
# Special handling for labels.
61+
# Ensure that tensor is created with the correct type
62+
# (it should be automatically the case, but let's make sure of it.)
63+
if "label" in first and first["label"] is not None:
64+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
65+
dtype = torch.long if isinstance(label, int) else torch.float
66+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
67+
elif "label_ids" in first and first["label_ids"] is not None:
68+
if isinstance(first["label_ids"], torch.Tensor):
69+
batch["labels"] = torch.stack([f["label_ids"] for f in features])
70+
else:
71+
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
72+
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
73+
74+
# Handling of all other possible keys.
75+
# Again, we will use the first element to figure out which key/values are not None for this model.
76+
for k, v in first.items():
77+
78+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
79+
if isinstance(v, torch.Tensor):
80+
batch[k] = torch.stack([f[k] for f in features])
81+
else:
82+
if k == 'ppl':
83+
batch[k] = [-f[k] for f in features]
84+
else:
85+
batch[k] = torch.tensor([f[k] for f in features])
86+
lens = batch['ppl']
87+
print('######lens(cl_data_collator)#########')
88+
print(lens)
89+
total = total + 1
90+
try:
91+
assert lens == sorted(lens)
92+
except:
93+
unsorted = unsorted + 1
94+
print('not sorted')
95+
96+
# Compare between currect batch and previous one
97+
# Append min of current batch and placeholder
98+
batch_comparison_values.append([lens[0], None])
99+
100+
if max_value_of_previous_batch is not None:
101+
# Append max of the previous batch
102+
batch_comparison_values[-1][1] = max_value_of_previous_batch
103+
104+
max_value_of_previous_batch = lens[-1]
105+
106+
return {'input_ids':batch['input_ids'],'labels': batch['labels']}
107+
108+
109+
def create_random_dataframe(sequence_column_name = 'sequence',
110+
curriculum_learning_column_name = 'ppl',
111+
curriculum_learning_strategy = 'ppl',
112+
max_sequence_length = 30,
113+
max_perplexity = 100.0,
114+
n = 5000):
115+
assert max_sequence_length > 2
116+
random.seed(42)
117+
df = pd.DataFrame()
118+
def create_sequence(length):
119+
seq = ''.join(random.choice('ACDEFGHIKLMNPQRSTVWY') for _ in range(length))
120+
return seq
121+
122+
if curriculum_learning_strategy == 'ppl':
123+
df[sequence_column_name] = [create_sequence(random.randint(2, max_sequence_length)) for i in range(n)]
124+
df[curriculum_learning_column_name] = [random.uniform(1.0, max_perplexity) for _ in range(n)]
125+
return df
126+
127+
@pytest.mark.parametrize("strategy",strategies)
128+
def test_curriculum_learning(strategy):
129+
130+
with open(CONFIG_PATH, "r") as cf:
131+
print('loading file.....')
132+
config_dict = yaml.safe_load(cf)
133+
134+
config_dict['dataset']['max_sequence_length'] = 40
135+
config_dict['dataset']['do_curriculum_learning'] = True
136+
config_dict['dataset']['curriculum_learning_column_name'] = strategy2col[strategy]
137+
config_dict['dataset']['curriculum_learning_strategy'] = strategy
138+
config_dict['dataset']['val_size'] = 100
139+
config_dict['dataset']['test_size'] = 100
140+
config_dict['dataset']['subsample_size'] = 500
141+
config_dict["training_arguments"]['group_by_length'] = True
142+
config_dict["training_arguments"]['length_column_name'] = config_dict['dataset']['curriculum_learning_column_name']
143+
config_dict["training_arguments"]['remove_unused_columns'] = False # this is necessary to keep curriculum_learning_column_name
144+
config_dict["training_arguments"]['per_device_train_batch_size'] = 20
145+
config_dict["training_arguments"]['max_steps'] = -1
146+
config_dict["training_arguments"]['num_train_epochs'] = 2
147+
148+
print(config_dict)
149+
150+
tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"])
151+
dataset = DatasetDict()
152+
val_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']['curriculum_learning_column_name'],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['val_size'] )
153+
test_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']["curriculum_learning_column_name"],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['test_size'] )
154+
train_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']["curriculum_learning_column_name"],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['subsample_size'] )
155+
156+
dataset['train'] = Dataset.from_pandas(train_df)
157+
dataset['val'] = Dataset.from_pandas(val_df)
158+
dataset['test'] = Dataset.from_pandas(test_df)
159+
dataset = dataset.map(
160+
lambda e: set_input_ids(
161+
result=e,
162+
tokenizer=tokenizer,
163+
sequence_column_name=config_dict['dataset']['sequence_column_name'],
164+
max_sequence_length=config_dict['dataset']['max_sequence_length'],
165+
),
166+
batched=True,
167+
)
168+
dataset = dataset.map(set_labels, batched=True)
169+
dataset = dataset.map(lambda e: batch_set_curriculum_learning_column(
170+
result = e,
171+
input_column_name = config_dict['dataset']['sequence_column_name'],
172+
curriculum_learning_column_name = config_dict['dataset']['curriculum_learning_column_name'],
173+
strategy = config_dict['dataset']['curriculum_learning_strategy']
174+
175+
),batched=True)
176+
dataset = dataset.select_columns(['input_ids', 'labels', strategy2col[strategy]])
177+
model = get_model(
178+
config_dict=config_dict["model"],
179+
)
180+
181+
training_args = get_training_args(
182+
config_dict=config_dict["training_arguments"],
183+
)
184+
185+
trainer = Trainer(
186+
model=model,
187+
args=training_args,
188+
train_dataset=dataset["train"],
189+
eval_dataset=dataset.get("val", None),
190+
data_collator=cl_data_collator,
191+
)
192+
193+
trainer.train()
194+
195+
threshold = 10
196+
num = 0
197+
# Iterate over the list
198+
print(batch_comparison_values)
199+
for i in batch_comparison_values:
200+
print(i)
201+
current_min_val, previous_max_val = i
202+
if previous_max_val is not None:
203+
if current_min_val < previous_max_val and previous_max_val - current_min_val <= threshold:
204+
num += 1
205+
assert num == 0
206+
percentage_unsorted = int((unsorted / total) * 100) #computing the number of times the list in collator was not sorted
207+
#there are sometimes cases where the list is off by a few entries aa the LengthGroupedSampler has a bit of randomness
208+
print(f'percentage_unsorted:{percentage_unsorted}')
209+
assert percentage_unsorted < 10 # just a rough heuristic

0 commit comments

Comments
 (0)