1+ import os
2+ import pytest
3+ import torch
4+ import torch .nn as nn
5+ import yaml
6+ from transformers import Trainer
7+ from protein_lm .modeling .getters .data_collator import get_data_collator
8+ from protein_lm .modeling .getters .model import get_model
9+ from protein_lm .modeling .getters .tokenizer import get_tokenizer
10+ from protein_lm .modeling .getters .training_args import get_training_args
11+ from datasets import Dataset , load_dataset
12+ from datasets .dataset_dict import DatasetDict
13+ from pydantic import BaseModel
14+ from protein_lm .modeling .getters .dataset import DatasetConfig ,get_csv_dataset ,set_input_ids ,set_labels ,batch_set_curriculum_learning_column
15+ ##data collator imports
16+ from dataclasses import dataclass
17+ from typing import Dict , Literal ,Any , Callable , Dict , List , NewType , Optional , Tuple , Union
18+ from transformers .tokenization_utils_base import BatchEncoding , PreTrainedTokenizerBase
19+ import pandas as pd
20+ import random
21+
# Relative path of the YAML training config loaded by the test.
CONFIG_PATH = "protein_lm/configs/train/toy_localcsv.yaml"

# Curriculum-learning strategies the test is parametrized over.
strategies = ["ppl"]
# Maps each strategy to the dataset column that stores its computed values.
strategy2col = {"ppl": "ppl"}

# Opaque type alias used to annotate the features handed to the collator.
InputDataClass = NewType("InputDataClass", Any)

# --- Mutable bookkeeping, updated by cl_data_collator on every batch ---
# Number of batches/steps seen so far.
total = 0
# Number of batches/steps whose curriculum values were not sorted.
unsorted = 0
# Last curriculum value of the previously collated batch (None before the first batch).
max_value_of_previous_batch = None
# One [first value of batch, last value of previous batch] pair per collated batch.
batch_comparison_values = []
33+
def cl_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """Collate a batch and record curriculum-ordering statistics as a side effect.

    The collation follows the usual "default collator" recipe: the first
    feature is used as a proxy for the attributes of the whole batch, with
    special handling for the keys

    - ``label``: a single value (int or float) per object
    - ``label_ids``: a list of values per object

    Every other non-``None``, non-string key is stacked (tensors) or converted
    with ``torch.tensor``. The ``'ppl'`` key is the exception: it is kept as a
    plain Python list of *negated* values and is only used for bookkeeping.

    Side effects on module-level state:

    - ``total`` is incremented once per call.
    - ``unsorted`` is incremented when the negated ``'ppl'`` values of the
      batch are not in ascending order.
    - ``batch_comparison_values`` receives one
      ``[first value of this batch, last value of previous batch]`` pair
      (the second entry is ``None`` for the very first batch).
    - ``max_value_of_previous_batch`` is set to the last value of this batch.

    Args:
        features: Non-empty list of dict-like objects (or objects whose
            ``vars()`` yields such a dict). Each must provide ``'ppl'`` as well
            as whatever ends up under ``'input_ids'`` and ``'labels'``.

    Returns:
        A dict with only the ``'input_ids'`` and ``'labels'`` entries of the
        collated batch.

    Raises:
        IndexError: If ``features`` is empty.
        KeyError: If the batch has no ``'ppl'``, ``'input_ids'`` or ``'labels'``.
    """
    global total
    global unsorted
    global max_value_of_previous_batch
    global batch_comparison_values

    # In this function we make the assumption that all `features` in the batch
    # have the same attributes, so the first element serves as a proxy for
    # which attributes exist on the whole batch.
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]

    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that the tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
        dtype = torch.long if isinstance(label, int) else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], torch.Tensor):
            batch["labels"] = torch.stack([f["label_ids"] for f in features])
        else:
            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, the first element decides which key/values are not None for this model.
    for k, v in first.items():
        if k in ("label", "label_ids") or v is None or isinstance(v, str):
            continue
        if isinstance(v, torch.Tensor):
            batch[k] = torch.stack([f[k] for f in features])
        elif k == 'ppl':
            # Kept as a plain list of negated values so it can be compared with
            # sorted() below. NOTE(review): the negation presumably compensates
            # for the length-grouped sampler ordering by descending value - confirm.
            batch[k] = [-f[k] for f in features]
        else:
            batch[k] = torch.tensor([f[k] for f in features])

    lens = batch['ppl']
    print('######lens(cl_data_collator)#########')
    print(lens)
    total = total + 1
    # Plain comparison instead of `try: assert ... except:` so the check still
    # runs under `python -O` and cannot swallow unrelated exceptions.
    if lens != sorted(lens):
        unsorted = unsorted + 1
        print('not sorted')

    # Compare the current batch with the previous one: store the first value of
    # this batch next to the last value of the previous batch (None for the
    # first batch). These equal the batch min/max only when the batch is sorted.
    batch_comparison_values.append([lens[0], max_value_of_previous_batch])
    max_value_of_previous_batch = lens[-1]

    return {'input_ids': batch['input_ids'], 'labels': batch['labels']}
107+
108+
def create_random_dataframe(sequence_column_name='sequence',
                            curriculum_learning_column_name='ppl',
                            curriculum_learning_strategy='ppl',
                            max_sequence_length=30,
                            max_perplexity=100.0,
                            n=5000):
    """Build a deterministic toy dataframe of random sequences and perplexities.

    A private ``random.Random(42)`` generator is used, so the output is
    reproducible (identical to seeding the module-level generator with 42)
    while the process-wide ``random`` state is left untouched.

    Args:
        sequence_column_name: Name of the column holding the random sequences.
        curriculum_learning_column_name: Name of the column holding the random
            perplexity values.
        curriculum_learning_strategy: Only ``'ppl'`` populates the dataframe;
            any other value yields an empty dataframe.
        max_sequence_length: Inclusive upper bound on the sequence length
            (the lower bound is 2). Must be greater than 2.
        max_perplexity: Upper bound of the uniform perplexity range, which
            starts at 1.0.
        n: Number of rows to generate.

    Returns:
        A ``pandas.DataFrame`` with ``n`` rows and the two columns above, or an
        empty dataframe for an unsupported strategy.

    Raises:
        AssertionError: If ``max_sequence_length`` is not greater than 2.
    """
    assert max_sequence_length > 2
    # Local generator: same stream as random.seed(42) without reseeding the
    # global RNG that other code in the process may rely on.
    rng = random.Random(42)
    df = pd.DataFrame()

    def create_sequence(length):
        # Random string of `length` letters from the 20-letter alphabet below.
        return ''.join(rng.choice('ACDEFGHIKLMNPQRSTVWY') for _ in range(length))

    if curriculum_learning_strategy == 'ppl':
        # All sequences are drawn first, then all perplexities, so the random
        # stream is consumed in the same order as before.
        df[sequence_column_name] = [
            create_sequence(rng.randint(2, max_sequence_length)) for _ in range(n)
        ]
        df[curriculum_learning_column_name] = [
            rng.uniform(1.0, max_perplexity) for _ in range(n)
        ]
    return df
126+
@pytest.mark.parametrize("strategy", strategies)
def test_curriculum_learning(strategy):
    """Train on a toy dataset and check the batch ordering seen by the collator.

    The test loads the YAML config at ``CONFIG_PATH``, overrides the dataset and
    training arguments so that batches are grouped by the curriculum-learning
    column, trains with ``cl_data_collator`` (which records per-batch ordering
    statistics in module-level variables) and finally asserts that

    - no batch starts at most ``threshold`` below the last value of the
      previous batch, and
    - fewer than 10% of the batches were internally unsorted.

    Args:
        strategy: Curriculum-learning strategy name, one of ``strategies``.
    """
    global total
    global unsorted
    global max_value_of_previous_batch
    global batch_comparison_values

    # Reset the collator bookkeeping so that statistics from another
    # parametrized run (or a rerun in the same process) cannot leak in.
    total = 0
    unsorted = 0
    max_value_of_previous_batch = None
    batch_comparison_values = []

    with open(CONFIG_PATH, "r") as cf:
        print('loading file.....')
        config_dict = yaml.safe_load(cf)

    # Aliases of the nested dicts; mutating them mutates config_dict itself.
    dataset_cfg = config_dict['dataset']
    training_cfg = config_dict["training_arguments"]

    dataset_cfg['max_sequence_length'] = 40
    dataset_cfg['do_curriculum_learning'] = True
    dataset_cfg['curriculum_learning_column_name'] = strategy2col[strategy]
    dataset_cfg['curriculum_learning_strategy'] = strategy
    dataset_cfg['val_size'] = 100
    dataset_cfg['test_size'] = 100
    dataset_cfg['subsample_size'] = 500
    training_cfg['group_by_length'] = True
    training_cfg['length_column_name'] = dataset_cfg['curriculum_learning_column_name']
    training_cfg['remove_unused_columns'] = False  # this is necessary to keep curriculum_learning_column_name
    training_cfg['per_device_train_batch_size'] = 20
    training_cfg['max_steps'] = -1
    training_cfg['num_train_epochs'] = 2

    print(config_dict)

    tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"])
    dataset = DatasetDict()

    # Build the three random splits (same creation order as before: val, test, train).
    frames = {}
    for split, size_key in (('val', 'val_size'), ('test', 'test_size'), ('train', 'subsample_size')):
        frames[split] = create_random_dataframe(
            sequence_column_name=dataset_cfg['sequence_column_name'],
            curriculum_learning_column_name=dataset_cfg['curriculum_learning_column_name'],
            max_sequence_length=dataset_cfg['max_sequence_length'],
            n=dataset_cfg[size_key],
        )
    for split in ('train', 'val', 'test'):
        dataset[split] = Dataset.from_pandas(frames[split])

    dataset = dataset.map(
        lambda e: set_input_ids(
            result=e,
            tokenizer=tokenizer,
            sequence_column_name=dataset_cfg['sequence_column_name'],
            max_sequence_length=dataset_cfg['max_sequence_length'],
        ),
        batched=True,
    )
    dataset = dataset.map(set_labels, batched=True)
    dataset = dataset.map(
        lambda e: batch_set_curriculum_learning_column(
            result=e,
            input_column_name=dataset_cfg['sequence_column_name'],
            curriculum_learning_column_name=dataset_cfg['curriculum_learning_column_name'],
            strategy=dataset_cfg['curriculum_learning_strategy'],
        ),
        batched=True,
    )
    # Keep only what the collator needs: model inputs plus the curriculum column.
    dataset = dataset.select_columns(['input_ids', 'labels', strategy2col[strategy]])

    model = get_model(
        config_dict=config_dict["model"],
    )

    training_args = get_training_args(
        config_dict=training_cfg,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("val", None),
        data_collator=cl_data_collator,
    )

    trainer.train()

    threshold = 10
    num = 0
    # Count batches whose first value lies below, but within `threshold` of,
    # the last value of the previous batch.
    print(batch_comparison_values)
    for pair in batch_comparison_values:
        print(pair)
        current_min_val, previous_max_val = pair
        if previous_max_val is None:
            continue  # first batch has nothing to compare against
        if current_min_val < previous_max_val and previous_max_val - current_min_val <= threshold:
            num += 1
    assert num == 0

    # Guard the division below: without any collated batch the test is meaningless.
    assert total > 0, "cl_data_collator was never called"
    # Percentage of batches whose values were not sorted inside the collator.
    # There are sometimes cases where the list is off by a few entries, as the
    # LengthGroupedSampler has a bit of randomness.
    percentage_unsorted = int((unsorted / total) * 100)
    print(f'percentage_unsorted:{ percentage_unsorted } ')
    assert percentage_unsorted < 10  # just a rough heuristic
0 commit comments