diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index da09ec000..8d32525e1 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -56,7 +56,6 @@
 )
 from lighteval.utils.utils import ListLike, as_list, download_dataset_worker

-
 if TYPE_CHECKING:
     from lighteval.logging.evaluation_tracker import EvaluationTracker

@@ -97,7 +96,10 @@ class LightevalTaskConfig:
     # Additional hf dataset config
     hf_revision: Optional[str] = None
     hf_filter: Optional[Callable[[dict], bool]] = None
-    hf_avail_splits: Optional[ListLike[str]] = field(default_factory=lambda: ["train", "validation", "test"])
+    hf_avail_splits: Optional[ListLike[str]] = field(
+        default_factory=lambda: ["train", "validation", "test"]
+    )
+    hf_data_files: Optional[str] = None

     # We default to false, to reduce security issues
     trust_dataset: bool = False
@@ -123,14 +125,21 @@ class LightevalTaskConfig:

     def __post_init__(self):
         # If we got a Metrics enums instead of a Metric, we convert
-        self.metric = [metric.value if isinstance(metric, Metrics) else metric for metric in self.metric]
+        self.metric = [
+            metric.value if isinstance(metric, Metrics) else metric
+            for metric in self.metric
+        ]

         # Convert list to tuple for hashing
         self.metric = tuple(self.metric)
-        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+        self.hf_avail_splits = (
+            tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+        )
         self.evaluation_splits = tuple(self.evaluation_splits)
         self.suite = tuple(self.suite)
-        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else ()
+        self.stop_sequence = (
+            tuple(self.stop_sequence) if self.stop_sequence is not None else ()
+        )

     def print(self):
         md_writer = MarkdownTableWriter()
@@ -143,7 +152,9 @@ def print(self):
                 for ix, metrics in enumerate(v):
                     for metric_k, metric_v in metrics.items():
                         if inspect.ismethod(metric_v):
-                            values.append([f"{k} {ix}: {metric_k}", metric_v.__qualname__])
+                            values.append(
+                                [f"{k} {ix}: {metric_k}", metric_v.__qualname__]
+                            )
                         else:
                             values.append([f"{k} {ix}: {metric_k}", repr(metric_v)])

@@ -182,6 +193,7 @@ def __init__(  # noqa: C901
         self.dataset_config_name = cfg.hf_subset
         self.dataset_revision = cfg.hf_revision
         self.dataset_filter = cfg.hf_filter
+        self.dataset_files = cfg.hf_data_files
         self.trust_dataset = cfg.trust_dataset
         self.dataset: Optional[DatasetDict] = None  # Delayed download
         logger.info(f"{self.dataset_path} {self.dataset_config_name}")
@@ -194,19 +206,29 @@ def __init__(  # noqa: C901
         if cfg.few_shots_split is not None:
             self.fewshot_split = as_list(cfg.few_shots_split)
         else:
-            self.fewshot_split = self.get_first_possible_fewshot_splits(cfg.hf_avail_splits or [])
+            self.fewshot_split = self.get_first_possible_fewshot_splits(
+                cfg.hf_avail_splits or []
+            )
         self.fewshot_selection = cfg.few_shots_select

         # Metrics
         self.metrics = as_list(cfg.metric)
         self.suite = as_list(cfg.suite)
-        ignored = [metric for metric in self.metrics if metric.category == MetricCategory.IGNORED]
+        ignored = [
+            metric
+            for metric in self.metrics
+            if metric.category == MetricCategory.IGNORED
+        ]

         if len(ignored) > 0:
-            logger.warning(f"Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
+            logger.warning(
+                f"Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}."
+            )

         current_categories = [metric.category for metric in self.metrics]
-        self.has_metric_category = {category: (category in current_categories) for category in MetricCategory}
+        self.has_metric_category = {
+            category: (category in current_categories) for category in MetricCategory
+        }
         # We assume num_samples always contains 1 (for base generative evals)
         self.num_samples = [1]

@@ -244,20 +266,26 @@ def get_first_possible_fewshot_splits(
             list[str]: List of the first available fewshot splits.
         """
         # Possible few shot splits are the available splits not used for evaluation
-        possible_fewshot_splits = [k for k in available_splits if k not in self.evaluation_split]
+        possible_fewshot_splits = [
+            k for k in available_splits if k not in self.evaluation_split
+        ]
         stored_splits = []

         # We look at these keys in order (first the training sets, then the validation sets)
         allowed_splits = ["train", "dev", "valid", "default"]
         for allowed_split in allowed_splits:
             # We do a partial match of the allowed splits
-            available_splits = [k for k in possible_fewshot_splits if allowed_split in k]
+            available_splits = [
+                k for k in possible_fewshot_splits if allowed_split in k
+            ]
             stored_splits.extend(available_splits)

         if len(stored_splits) > 0:
             return stored_splits[:number_of_splits]

-        logger.warning(f"Careful, the task {self.name} is using evaluation data to build the few shot examples.")
+        logger.warning(
+            f"Careful, the task {self.name} is using evaluation data to build the few shot examples."
+        )
         return None

     def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
@@ -279,6 +307,7 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
                 self.trust_dataset,
                 self.dataset_filter,
                 self.dataset_revision,
+                self.dataset_files,
             )
         splits = as_list(splits)

@@ -319,9 +348,13 @@ def fewshot_docs(self) -> list[Doc]:

             # If we have no available few shot split, the few shot data is the eval data!
             if self.fewshot_split is None:
-                self._fewshot_docs = self._get_docs_from_split(self.evaluation_split, few_shots=True)
+                self._fewshot_docs = self._get_docs_from_split(
+                    self.evaluation_split, few_shots=True
+                )
             else:  # Normal case
-                self._fewshot_docs = self._get_docs_from_split(self.fewshot_split, few_shots=True)
+                self._fewshot_docs = self._get_docs_from_split(
+                    self.fewshot_split, few_shots=True
+                )
         return self._fewshot_docs

     def eval_docs(self) -> list[Doc]:
@@ -338,7 +371,11 @@ def eval_docs(self) -> list[Doc]:
         return self._docs

     def construct_requests(
-        self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
+        self,
+        formatted_doc: Doc,
+        context: str,
+        document_id_seed: str,
+        current_task_name: str,
     ) -> Dict[RequestType, List[Request]]:
         """
         Constructs a list of requests from the task based on the given parameters.
@@ -435,7 +472,10 @@ def construct_requests(
                     choice=choice,
                     metric_categories=[
                         c
-                        for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]
+                        for c in [
+                            MetricCategory.MULTICHOICE,
+                            MetricCategory.MULTICHOICE_PMI,
+                        ]
                         if self.has_metric_category[c]
                     ],
                 )
@@ -499,7 +539,9 @@ def construct_requests(

     def get_metric_method_from_category(self, metric_category):
         if not self.has_metric_category[metric_category]:
-            raise ValueError(f"Requested a metric category {metric_category} absent from the task list.")
+            raise ValueError(
+                f"Requested a metric category {metric_category} absent from the task list."
+            )

         return LightevalTask._get_metric_method_from_category(metric_category)

@@ -507,7 +549,10 @@ def get_metric_method_from_category(self, metric_category):
     def _get_metric_method_from_category(metric_category):
         if metric_category == MetricCategory.TARGET_PERPLEXITY:
             return apply_target_perplexity_metric
-        if metric_category in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]:
+        if metric_category in [
+            MetricCategory.MULTICHOICE,
+            MetricCategory.MULTICHOICE_PMI,
+        ]:
             return apply_multichoice_metric
         if metric_category == MetricCategory.MULTICHOICE_ONE_TOKEN:
             return apply_multichoice_metric_one_token
@@ -519,7 +564,10 @@ def _get_metric_method_from_category(metric_category):
             MetricCategory.GENERATIVE_LOGPROB,
         ]:
             return apply_generative_metric
-        if metric_category in [MetricCategory.LLM_AS_JUDGE_MULTI_TURN, MetricCategory.LLM_AS_JUDGE]:
+        if metric_category in [
+            MetricCategory.LLM_AS_JUDGE_MULTI_TURN,
+            MetricCategory.LLM_AS_JUDGE,
+        ]:
             return apply_llm_as_judge_metric

     def aggregation(self):
@@ -530,7 +578,9 @@ def aggregation(self):
         return Metrics.corpus_level_fns(self.metrics)

     @staticmethod
-    def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int = 1) -> None:
+    def load_datasets(
+        tasks: list["LightevalTask"], dataset_loading_processes: int = 1
+    ) -> None:
         """
         Load datasets from the HuggingFace Hub for the given tasks.

@@ -550,6 +600,7 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
                     task.trust_dataset,
                     task.dataset_filter,
                     task.dataset_revision,
+                    task.dataset_files,
                 )
                 for task in tasks
             ]
@@ -564,6 +615,7 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
                             task.trust_dataset,
                             task.dataset_filter,
                             task.dataset_revision,
+                            task.dataset_files,
                         )
                         for task in tasks
                     ],
@@ -615,13 +667,17 @@ def create_requests_from_tasks(  # noqa: C901
     requests: dict[RequestType, list[Request]] = collections.defaultdict(list)

     # Filter out tasks that don't have any docs
-    task_dict_items = [(name, task) for name, task in task_dict.items() if len(task.eval_docs()) > 0]
+    task_dict_items = [
+        (name, task) for name, task in task_dict.items() if len(task.eval_docs()) > 0
+    ]

     # Get lists of each type of request
     for task_name, task in task_dict_items:
         task_docs = list(task.eval_docs())
         n_samples = min(max_samples, len(task_docs)) if max_samples else len(task_docs)
-        evaluation_tracker.task_config_logger.log_num_docs(task_name, len(task_docs), n_samples)
+        evaluation_tracker.task_config_logger.log_num_docs(
+            task_name, len(task_docs), n_samples
+        )

         # logs out the different versions of the tasks for every few shot
         for num_fewshot, _ in fewshot_dict[task_name]:
@@ -655,7 +711,9 @@ def create_requests_from_tasks(  # noqa: C901
                 # Constructing the requests
                 cur_task_name = f"{task_name}|{num_fewshot}"
                 docs[SampleUid(cur_task_name, doc_id_seed)] = doc
-                req_type_reqs_dict = task.construct_requests(doc, doc.ctx, doc_id_seed, cur_task_name)
+                req_type_reqs_dict = task.construct_requests(
+                    doc, doc.ctx, doc_id_seed, cur_task_name
+                )
                 for req_type, reqs in req_type_reqs_dict.items():
                     requests[req_type].extend(reqs)

diff --git a/src/lighteval/utils/utils.py b/src/lighteval/utils/utils.py
index c4f2956ea..7ca70c3f9 100644
--- a/src/lighteval/utils/utils.py
+++ b/src/lighteval/utils/utils.py
@@ -23,7 +23,9 @@ def flatten_dict(nested: dict, sep="/") -> dict:
     """Flatten dictionary, list, tuple and concatenate nested keys with separator."""

     def clean_markdown(v: str) -> str:
-        return v.replace("|", "_").replace("\n", "_") if isinstance(v, str) else v  # Need this for markdown
+        return (
+            v.replace("|", "_").replace("\n", "_") if isinstance(v, str) else v
+        )  # Need this for markdown

     def rec(nest: dict, prefix: str, into: dict):
         for k, v in sorted(nest.items()):
@@ -37,9 +39,13 @@ def rec(nest: dict, prefix: str, into: dict):
                         rec(vv, prefix + k + sep + str(i) + sep, into)
                     else:
                         vv = (
-                            vv.replace("|", "_").replace("\n", "_") if isinstance(vv, str) else vv
+                            vv.replace("|", "_").replace("\n", "_")
+                            if isinstance(vv, str)
+                            else vv
                         )  # Need this for markdown
-                        into[prefix + k + sep + str(i)] = vv.tolist() if isinstance(vv, np.ndarray) else vv
+                        into[prefix + k + sep + str(i)] = (
+                            vv.tolist() if isinstance(vv, np.ndarray) else vv
+                        )
             elif isinstance(v, np.ndarray):
                 into[prefix + k + sep + str(i)] = v.tolist()
             else:
@@ -63,7 +69,9 @@ def clean_s3_links(value: str) -> str:
     s3_bucket, s3_prefix = str(value).replace("s3://", "").split("/", maxsplit=1)
     if not s3_prefix.endswith("/"):
         s3_prefix += "/"
-    link_str = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}"
+    link_str = (
+        f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}"
+    )
     value = f'<a href="{link_str}" target="_blank"> {value} </a>'
     return value

@@ -151,7 +159,11 @@ def flatten(item: list[Union[list, str]]) -> list[str]:
     """
     flat_item = []
    for sub_item in item:
-        flat_item.extend(sub_item) if isinstance(sub_item, list) else flat_item.append(sub_item)
+        (
+            flat_item.extend(sub_item)
+            if isinstance(sub_item, list)
+            else flat_item.append(sub_item)
+        )
     return flat_item


@@ -205,6 +217,7 @@ def download_dataset_worker(
     trust_dataset: bool,
     dataset_filter: Callable[[dict], bool] | None = None,
     revision: str | None = None,
+    data_files: str | None = None,
 ) -> DatasetDict:
     """
     Worker function to download a dataset from the HuggingFace Hub.
@@ -218,6 +231,7 @@ def download_dataset_worker(
         download_mode=None,
         trust_remote_code=trust_dataset,
         revision=revision,
+        data_files=data_files,
     )

     if dataset_filter is not None:
@@ -227,5 +241,7 @@ def download_dataset_worker(
     return dataset  # type: ignore


-def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray:
+def safe_divide(
+    numerator: np.ndarray, denominator: float, default_value: float = 0.0
+) -> np.ndarray:
     return np.where(denominator != 0, numerator / denominator, default_value)
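
For reviewers, a minimal usage sketch of the new hf_data_files field, assuming the Callable-based prompt_function API; the repo name, data file path, dataset columns, suite, and metric below are illustrative placeholders, not taken from this diff:

# Illustrative sketch only: placeholder values, not part of this PR.
from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


def prompt_fn(line: dict, task_name: str = None) -> Doc:
    # Adapt the column names to the actual dataset schema.
    return Doc(task_name=task_name, query=line["question"], choices=[line["answer"]], gold_index=0)


task = LightevalTaskConfig(
    name="mytask",
    prompt_function=prompt_fn,
    suite=["community"],
    hf_repo="my-org/my-dataset",
    hf_subset="default",
    hf_data_files="data/eval.jsonl",  # new field added by this PR
    evaluation_splits=["train"],
    metric=[Metrics.loglikelihood_acc],
)

With such a config, LightevalTask.__init__ stores the value as self.dataset_files and download_dataset_worker forwards it to datasets.load_dataset as data_files, as shown in the hunks above.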