5 changes: 5 additions & 0 deletions changelog.md
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

### Added
- New parameter `pruning_params` in `edsnlp.tune` to control trial pruning during hyperparameter tuning.

## v0.19.0 (2025-10-04)

📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 in the next major release (v0.20.0), in October 2025. Please upgrade to Python 3.10 or later.
35 changes: 32 additions & 3 deletions edsnlp/metrics/span_attribute.py
@@ -41,7 +41,7 @@

import warnings
from collections import defaultdict
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence, Union

from edsnlp import registry
from edsnlp.metrics import Examples, average_precision, make_examples, prf
@@ -57,6 +57,7 @@ def span_attribute_metric(
default_values: Dict = {},
micro_key: str = "micro",
filter_expr: Optional[str] = None,
split_by_values: Optional[Union[str, Sequence[str]]] = None,
**kwargs: Any,
):
if "qualifiers" in kwargs:
@@ -80,6 +81,8 @@
if filter_expr is not None:
filter_fn = eval(f"lambda doc: {filter_expr}")
examples = [eg for eg in examples if filter_fn(eg.reference)]
if isinstance(split_by_values, str):
split_by_values = [split_by_values]
labels = defaultdict(lambda: (set(), set(), dict()))
labels["micro"] = (set(), set(), dict())
total_pred_count = 0
@@ -108,9 +111,15 @@
if (top_val or include_falsy) and default_values[attr] != top_val:
labels[attr][2][(eg_idx, beg, end, attr, top_val)] = top_p
labels[micro_key][2][(eg_idx, beg, end, attr, top_val)] = top_p
if split_by_values and attr in split_by_values:
key = f"{attr}:{top_val}"
labels[key][2][(eg_idx, beg, end, attr, top_val)] = top_p
if (value or include_falsy) and default_values[attr] != value:
labels[micro_key][0].add((eg_idx, beg, end, attr, value))
labels[attr][0].add((eg_idx, beg, end, attr, value))
if split_by_values and attr in split_by_values:
key = f"{attr}:{value}"
labels[key][0].add((eg_idx, beg, end, attr, value))

doc_spans = get_spans(eg.reference, span_getter)
for span in doc_spans:
@@ -124,6 +133,9 @@
if (value or include_falsy) and default_values[attr] != value:
labels[micro_key][1].add((eg_idx, beg, end, attr, value))
labels[attr][1].add((eg_idx, beg, end, attr, value))
if split_by_values and attr in split_by_values:
key = f"{attr}:{value}"
labels[key][1].add((eg_idx, beg, end, attr, value))

if total_pred_count != total_gold_count:
raise ValueError(
@@ -133,14 +145,25 @@
"predicted by another NER pipe in your model."
)

return {
metrics = {
name: {
**prf(pred, gold),
"ap": average_precision(pred_with_prob, gold),
}
for name, (pred, gold, pred_with_prob) in labels.items()
}

if split_by_values:
for attr in split_by_values:
submetrics = {"micro": metrics[attr]}
for key in list(metrics.keys()):
if key.startswith(f"{attr}:"):
val = key.split(":", 1)[1]
submetrics[val] = metrics.pop(key)
metrics[attr] = submetrics

return metrics


@registry.metrics.register(
"eds.span_attribute",
@@ -230,7 +253,10 @@ class SpanAttributeMetric:
Key under which to store the micro‐averaged results across all attributes.
filter_expr : Optional[str]
A Python expression (using `doc`) to filter which examples are scored.

split_by_values : Optional[Union[str, Sequence[str]]]
One or more attributes for which metrics should be reported separately for each
attribute value. If `None` (default), metrics are only computed at the attribute level.
Useful when attributes are multiclass.

Returns
-------
Dict[str, Dict[str, float]]
@@ -258,6 +284,7 @@ def __init__(
include_falsy: bool = False,
micro_key: str = "micro",
filter_expr: Optional[str] = None,
split_by_values: Optional[Union[str, Sequence[str]]] = None,
):
if qualifiers is not None:
warnings.warn(
Expand All @@ -270,6 +297,7 @@ def __init__(
self.include_falsy = include_falsy
self.micro_key = micro_key
self.filter_expr = filter_expr
self.split_by_values = split_by_values

__init__.__doc__ = span_attribute_metric.__doc__

@@ -296,6 +324,7 @@ def __call__(self, *examples: Any):
include_falsy=self.include_falsy,
micro_key=self.micro_key,
filter_expr=self.filter_expr,
split_by_values=self.split_by_values,
)


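To illustrate what the new `split_by_values` option produces, here is a hedged sketch based on the restructuring loop above. The attribute name, its values and the commented-out call are invented for the example; `span_getter` and `attributes` are assumed from the surrounding code in this file, and only `split_by_values` and the nested output shape come from this diff:

```python
from edsnlp.metrics.span_attribute import SpanAttributeMetric

# Hypothetical multiclass attribute "status" on the matched spans; span_getter and
# attributes are assumed constructor parameters, the attribute values are made up.
metric = SpanAttributeMetric(
    span_getter={"ents": True},
    attributes=["status"],
    split_by_values="status",  # may also be a sequence, e.g. ["status", "negation"]
)

# scores = metric(gold_docs, predicted_docs)
# Per the post-processing block above, the "status" entry becomes a nested dict:
# {
#     "micro": {...},               # micro average across all attributes
#     "status": {
#         "micro": {"p": ..., "r": ..., "f": ..., "ap": ...},  # attribute-level scores
#         "present": {...},         # one sub-entry per observed attribute value
#         "absent": {...},
#     },
# }
```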
41 changes: 35 additions & 6 deletions edsnlp/tune.py
@@ -271,7 +271,7 @@ def update_config(
return config


def objective_with_param(config, tuned_parameters, trial, metric):
def objective_with_param(config, tuned_parameters, trial, metric, pruning_params):
kwargs, _ = update_config(config, tuned_parameters, trial=trial)
seed = random.randint(0, 2**32 - 1)
set_seed(seed)
@@ -282,8 +282,9 @@ def on_validation_callback(all_metrics):
for key in metric:
score = score[key]
trial.report(score, step)
if trial.should_prune():
raise optuna.TrialPruned()
if pruning_params:
if trial.should_prune():
raise optuna.TrialPruned()

try:
nlp = train(**kwargs, on_validation_callback=on_validation_callback)
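For readers unfamiliar with how this pruning hook works: the callback reports each intermediate validation score with `trial.report(score, step)`, and `trial.should_prune()` asks the study's pruner (a `MedianPruner` below) whether the trial looks unpromising compared to other trials at the same step; raising `optuna.TrialPruned` then stops it early. A minimal, self-contained Optuna sketch of that pattern, with a dummy objective standing in for EDS-NLP's training loop:

```python
import optuna
from optuna.pruners import MedianPruner


def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)  # dummy hyperparameter
    score = 0.0
    for step in range(10):  # stands in for the validation callbacks during training
        score = 1.0 - abs(lr - 1e-3) * (10 - step)  # fake intermediate metric
        trial.report(score, step)       # make the score visible to the pruner
        if trial.should_prune():        # compares against the median of other trials
            raise optuna.TrialPruned()  # abort this trial early
    return score


study = optuna.create_study(
    direction="maximize",
    pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=5),
)
study.optimize(objective, n_trials=20)
```

In the diff, the guard on `pruning_params` makes this behaviour opt-in, whereas the pruner was previously always created with fixed settings.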
@@ -299,15 +300,30 @@ def on_validation_callback(all_metrics):


def optimize(
config_path, tuned_parameters, n_trials, metric, checkpoint_dir, study=None
config_path,
tuned_parameters,
n_trials,
metric,
checkpoint_dir,
pruning_params,
study=None,
):
def objective(trial):
return objective_with_param(config_path, tuned_parameters, trial, metric)
return objective_with_param(
config_path, tuned_parameters, trial, metric, pruning_params
)

if not study:
pruner = None
if pruning_params:
n_startup_trials = pruning_params.get("n_startup_trials", 5)
n_warmup_steps = pruning_params.get("n_warmup_steps", 5)
pruner = MedianPruner(
n_startup_trials=n_startup_trials, n_warmup_steps=n_warmup_steps
)
study = optuna.create_study(
direction="maximize",
pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=2),
pruner=pruner,
sampler=TPESampler(seed=random.randint(0, 2**32 - 1)),
)
study.optimize(
@@ -444,6 +460,7 @@ def tune_two_phase(
is_fixed_n_trials: bool = False,
gpu_hours: float = 1.0,
skip_phase_1: bool = False,
pruning_params: Optional[Dict[str, int]] = None,
) -> None:
"""
Perform two-phase hyperparameter tuning using Optuna.
@@ -505,6 +522,7 @@
n_trials_1,
metric,
checkpoint_dir,
pruning_params,
study,
)
best_params_phase_1, importances = process_results(
@@ -551,6 +569,7 @@
n_trials_2,
metric,
checkpoint_dir,
pruning_params,
study,
)

@@ -612,6 +631,7 @@ def tune(
seed: int = 42,
metric="ner.micro.f",
keep_checkpoint: bool = False,
pruning_params: Optional[Dict[str, int]] = None,
):
"""
Perform hyperparameter tuning for a model using Optuna.
@@ -652,6 +672,11 @@
Metric used to evaluate trials. Default is "ner.micro.f".
keep_checkpoint : bool, optional
If True, keeps the checkpoint file after tuning. Default is False.
pruning_params : dict, optional
A dictionary configuring the `MedianPruner`:
- "n_startup_trials": number of trials to complete before pruning is enabled.
- "n_warmup_steps": number of reported validation steps within a trial before it can be pruned.
Default is None, meaning no pruning.
"""
setup_logging()
viz = is_plotly_install()
@@ -679,6 +704,7 @@
n_trials=1,
metric=metric,
checkpoint_dir=checkpoint_dir,
pruning_params=pruning_params,
)
n_trials = compute_n_trials(gpu_hours, compute_time_per_trial(study)) - 1
else:
@@ -708,6 +734,7 @@
is_fixed_n_trials=is_fixed_n_trials,
gpu_hours=gpu_hours,
skip_phase_1=skip_phase_1,
pruning_params=pruning_params,
)
else:
logger.info("Starting single-phase tuning.")
@@ -717,6 +744,7 @@
n_trials,
metric,
checkpoint_dir,
pruning_params,
study,
)
if not is_fixed_n_trials:
@@ -732,6 +760,7 @@
n_trials,
metric,
checkpoint_dir,
pruning_params,
study,
)
process_results(study, output_dir, viz, config, config_path, hyperparameters)
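Finally, a hedged sketch of how the new `pruning_params` argument might be passed to `edsnlp.tune`. Only `metric`, `seed`, `keep_checkpoint` and `pruning_params` are visible in the signature above; the other argument names (`config_path`, `hyperparameters`, `output_dir`, `gpu_hours`) and the search-space format are assumptions inferred from names appearing elsewhere in this diff, so check the actual signature before copying this:

```python
from edsnlp.tune import tune  # assumption: tune is importable from edsnlp.tune

tune(
    config_path="configs/config.yml",  # placeholder path to the training config
    hyperparameters={                  # placeholder search space, format assumed
        "optimizer.lr": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True},
    },
    output_dir="results/tuning",       # placeholder output directory
    gpu_hours=4.0,
    metric="ner.micro.f",
    # New in this PR: forwarded down to optimize() to build a MedianPruner.
    # Leaving it as None (the default) disables pruning entirely.
    pruning_params={"n_startup_trials": 5, "n_warmup_steps": 5},
)
```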