From 56d016ba38825c789515615d189bc4669bb9c913 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sat, 22 Mar 2025 10:59:37 +0000 Subject: [PATCH 1/4] . --- src/lighteval/config/lighteval_config.py | 8 +++ src/lighteval/main_nanotron.py | 51 ++++++++++++------- src/lighteval/models/__init__.py | 21 ++++++++ src/lighteval/models/nanotron/__init__.py | 21 ++++++++ .../models/nanotron/nanotron_model.py | 20 +++++--- src/lighteval/models/nanotron_model.py | 26 ++++++++++ src/lighteval/pipeline.py | 7 ++- 7 files changed, 128 insertions(+), 26 deletions(-) create mode 100644 src/lighteval/models/__init__.py create mode 100644 src/lighteval/models/nanotron/__init__.py create mode 100644 src/lighteval/models/nanotron_model.py diff --git a/src/lighteval/config/lighteval_config.py b/src/lighteval/config/lighteval_config.py index f24a15184..0e8217afe 100644 --- a/src/lighteval/config/lighteval_config.py +++ b/src/lighteval/config/lighteval_config.py @@ -101,3 +101,11 @@ class LightEvalConfig: class FullNanotronConfig: lighteval_config: LightEvalConfig nanotron_config: "Config" + + @property + def generation_parameters(self): + # Return the generation parameters from the lighteval config + # or create default generation parameters if none are set + if self.lighteval_config.generation: + return self.lighteval_config.generation + return GenerationArgs() diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 94004c065..1b973a112 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -42,17 +42,17 @@ def nanotron( checkpoint_config_path: Annotated[ str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.") ], - lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")], + lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")] = None, cache_dir: Annotated[str, Option(help="Cache directory for datasets and models.")] = CACHE_DIR, ): """ Evaluate models using nanotron as backend. """ from nanotron.config import Config, get_config_from_file + from nanotron.config.parallelism_config import ParallelismArgs - from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig + from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig, LightEvalLoggingArgs, LightEvalTasksArgs from lighteval.logging.evaluation_tracker import EvaluationTracker - from lighteval.logging.hierarchical_logger import htrack_block from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available from lighteval.utils.utils import EnvConfig @@ -61,23 +61,38 @@ def nanotron( if not is_nanotron_available(): raise ImportError(NO_NANOTRON_ERROR_MSG) + + # Create nanotron config + if not checkpoint_config_path.endswith(".yaml"): + raise ValueError("The checkpoint path should point to a YAML file") + + model_config = get_config_from_file( + checkpoint_config_path, + config_class=Config, + model_config_class=None, + skip_unused_config_keys=True, + skip_null_keys=True, + ) - with htrack_block("Load nanotron config"): - # Create nanotron config - if not checkpoint_config_path.endswith(".yaml"): - raise ValueError("The checkpoint path should point to a YAML file") - - model_config = get_config_from_file( - checkpoint_config_path, - config_class=Config, - model_config_class=None, - skip_unused_config_keys=True, - skip_null_keys=True, - ) - - # We are getting an type error, because the get_config_from_file is not correctly typed, + # Create or use default lighteval config + if lighteval_config_path is not None: lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore - nanotron_config = FullNanotronConfig(lighteval_config, model_config) + else: + # Create default config with minimal required parameters + default_logging = LightEvalLoggingArgs( + output_dir="./eval_results" + ) + default_tasks = LightEvalTasksArgs( + tasks="lighteval|agieval:aqua-rat|5|0" + ) + default_parallelism = ParallelismArgs(dp=1, pp=1, tp=1) + lighteval_config = LightEvalConfig( + logging=default_logging, + tasks=default_tasks, + parallelism=default_parallelism + ) + + nanotron_config = FullNanotronConfig(lighteval_config, model_config) evaluation_tracker = EvaluationTracker( output_dir=lighteval_config.logging.output_dir, diff --git a/src/lighteval/models/__init__.py b/src/lighteval/models/__init__.py new file mode 100644 index 000000000..064e2842d --- /dev/null +++ b/src/lighteval/models/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. \ No newline at end of file diff --git a/src/lighteval/models/nanotron/__init__.py b/src/lighteval/models/nanotron/__init__.py new file mode 100644 index 000000000..064e2842d --- /dev/null +++ b/src/lighteval/models/nanotron/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. \ No newline at end of file diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 5f139174c..17b027438 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -343,7 +343,14 @@ def tok_decode(self, tokens: torch.LongTensor) -> List[str]: return self.tokenizer.batch_decode(tokens, skip_special_tokens=True) def _model_call(self, inputs: torch.Tensor) -> torch.Tensor: - return self.model(inputs) + position_ids = ( + torch.arange( + inputs.shape[1], device=inputs.device, dtype=torch.int32 + ) + .unsqueeze(0) + .repeat(inputs.shape[0], 1) + ) + return self.model(inputs, position_ids) def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]: """Ending conditions are submitted in several possible formats. @@ -711,14 +718,14 @@ def _loglikelihood_single_token( inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True ) # batched_inputs, batch_attention, input_lengths, truncated, padded - - out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask) + position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1) + out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids) if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank: # This process got outputs # Gather all the output accross TP - out = out.transpose(0, 1).contiguous() # [batch, seq_length, vocab] + out = out.view(*batch_model.input_ids.shape, -1).contiguous() # [batch, seq_length, vocab] gathered_out = [torch.zeros_like(out) for _ in range(self.parallel_context.tp_pg.size())] dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False) @@ -944,7 +951,8 @@ def _loglikelihood_tokens( ) # batched_inputs, batch_attention, input_lengths, truncated, padded with torch.no_grad(): - out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask) + position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1) + out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids) if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank: # This process got outputs @@ -954,7 +962,7 @@ def _loglikelihood_tokens( dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False) out = torch.cat(gathered_out, dim=-1) - out = out.transpose(0, 1) # [batch, seq_length, vocab] + out = out.view(*batch_model.input_ids.shape, -1) # [batch, seq_length, vocab] multi_logits = F.log_softmax(out, dim=-1) # [batch, padding_length, vocab] logits_sum = [] diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py new file mode 100644 index 000000000..4a1ed72c6 --- /dev/null +++ b/src/lighteval/models/nanotron_model.py @@ -0,0 +1,26 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Import and re-export the NanotronLightevalModel class from the nanotron module +from lighteval.models.nanotron.nanotron_model import NanotronLightevalModel + +__all__ = ["NanotronLightevalModel"] \ No newline at end of file diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index b83403755..fbb97da7f 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -72,7 +72,7 @@ from nanotron.parallel.context import ParallelContext from nanotron.utils import local_ranks_zero_first - from lighteval.models.nanotron_model import NanotronLightevalModel + # from lighteval.models.nanotron import NanotronLightevalModel import logging @@ -188,16 +188,19 @@ def _init_model(self, model_config, model): logger.info("--- LOADING MODEL ---") if model_config is not None: if self.parallel_context: + from lighteval.models.nanotron_model import NanotronLightevalModel + return NanotronLightevalModel( checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path) if self.pipeline_parameters.nanotron_checkpoint_path else "", - nanotron_config=self.model_config, + nanotron_config=model_config, parallel_context=self.parallel_context, debug_one_layer_model=False, model_class=None, env_config=self.pipeline_parameters.env_config, ) + # return None else: return load_model(config=model_config, env_config=self.pipeline_parameters.env_config) if isinstance(model, TransformersModel): From 0bb5d615276143becb6838ead47fd63d593cff6f Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sat, 22 Mar 2025 11:31:06 +0000 Subject: [PATCH 2/4] . --- src/lighteval/main_nanotron.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 1b973a112..22755997a 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -42,7 +42,7 @@ def nanotron( checkpoint_config_path: Annotated[ str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.") ], - lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")] = None, + lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")], cache_dir: Annotated[str, Option(help="Cache directory for datasets and models.")] = CACHE_DIR, ): """ @@ -74,23 +74,8 @@ def nanotron( skip_null_keys=True, ) - # Create or use default lighteval config - if lighteval_config_path is not None: - lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore - else: - # Create default config with minimal required parameters - default_logging = LightEvalLoggingArgs( - output_dir="./eval_results" - ) - default_tasks = LightEvalTasksArgs( - tasks="lighteval|agieval:aqua-rat|5|0" - ) - default_parallelism = ParallelismArgs(dp=1, pp=1, tp=1) - lighteval_config = LightEvalConfig( - logging=default_logging, - tasks=default_tasks, - parallelism=default_parallelism - ) + # Load lighteval config + lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore nanotron_config = FullNanotronConfig(lighteval_config, model_config) From 404c4890697e9c488cf9e057fe3015ad927e6917 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sat, 22 Mar 2025 11:37:11 +0000 Subject: [PATCH 3/4] . --- src/lighteval/models/__init__.py | 21 ------------------ src/lighteval/models/nanotron/__init__.py | 21 ------------------ src/lighteval/models/nanotron_model.py | 26 ----------------------- src/lighteval/pipeline.py | 7 ++---- 4 files changed, 2 insertions(+), 73 deletions(-) delete mode 100644 src/lighteval/models/__init__.py delete mode 100644 src/lighteval/models/nanotron/__init__.py delete mode 100644 src/lighteval/models/nanotron_model.py diff --git a/src/lighteval/models/__init__.py b/src/lighteval/models/__init__.py deleted file mode 100644 index 064e2842d..000000000 --- a/src/lighteval/models/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# MIT License - -# Copyright (c) 2024 The HuggingFace Team - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. \ No newline at end of file diff --git a/src/lighteval/models/nanotron/__init__.py b/src/lighteval/models/nanotron/__init__.py deleted file mode 100644 index 064e2842d..000000000 --- a/src/lighteval/models/nanotron/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# MIT License - -# Copyright (c) 2024 The HuggingFace Team - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. \ No newline at end of file diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py deleted file mode 100644 index 4a1ed72c6..000000000 --- a/src/lighteval/models/nanotron_model.py +++ /dev/null @@ -1,26 +0,0 @@ -# MIT License - -# Copyright (c) 2024 The HuggingFace Team - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -# Import and re-export the NanotronLightevalModel class from the nanotron module -from lighteval.models.nanotron.nanotron_model import NanotronLightevalModel - -__all__ = ["NanotronLightevalModel"] \ No newline at end of file diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index fbb97da7f..dec9c3cd0 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -72,7 +72,7 @@ from nanotron.parallel.context import ParallelContext from nanotron.utils import local_ranks_zero_first - # from lighteval.models.nanotron import NanotronLightevalModel + from lighteval.models.nanotron.nanotron_model import NanotronLightevalModel import logging @@ -187,9 +187,7 @@ def _init_parallelism_manager(self): def _init_model(self, model_config, model): logger.info("--- LOADING MODEL ---") if model_config is not None: - if self.parallel_context: - from lighteval.models.nanotron_model import NanotronLightevalModel - + if self.parallel_context: return NanotronLightevalModel( checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path) if self.pipeline_parameters.nanotron_checkpoint_path @@ -200,7 +198,6 @@ def _init_model(self, model_config, model): model_class=None, env_config=self.pipeline_parameters.env_config, ) - # return None else: return load_model(config=model_config, env_config=self.pipeline_parameters.env_config) if isinstance(model, TransformersModel): From 03c1195780e59003c00c0f1b934179bd0fd583bc Mon Sep 17 00:00:00 2001 From: Jason Stillerman Date: Wed, 26 Mar 2025 00:53:13 +0000 Subject: [PATCH 4/4] allow extra keywords in LightevalTaskConfig --- src/lighteval/tasks/lighteval_task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 3420480b2..e7761a206 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -107,6 +107,7 @@ class LightevalTaskConfig: few_shots_select: Optional[str] = None # Generation args + output_regex: Optional[str] = None generation_size: Optional[int] = None generation_grammar: Optional[TextGenerationInputGrammarType] = None stop_sequence: Optional[ListLike[str]] = None @@ -120,6 +121,7 @@ class LightevalTaskConfig: must_remove_duplicate_docs: bool = False version: int = 0 + frozen: bool = False def __post_init__(self): # If we got a Metrics enums instead of a Metric, we convert