Commit 9e5dc94 (parent: 93552a6)

docs: fix docs tutorials and warn in training func for mismatched pipe names

3 files changed: +44, -28 lines

docs/tutorials/training-ner.md (11 additions, 8 deletions)
```diff
@@ -233,7 +233,7 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
 import edsnlp
 from edsnlp.training import train, ScheduledOptimizer, TrainingData
 from edsnlp.metrics.ner import NerExactMetric
-from edsnlp.training.loggers import CSVLogger, RichLogger, WandbLogger
+from edsnlp.training.loggers import CSVLogger, RichLogger, WandBLogger
 import edsnlp.pipes as eds
 import torch
 
@@ -242,6 +242,7 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
 nlp.add_pipe(
     # The NER pipe will be a CRF model
     eds.ner_crf(
+        name="ner",
         mode="joint",
         target_span_getter="gold_spans",
         # Set spans both as ents and in separate `ent.label` groups
@@ -280,19 +281,21 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
     optim=torch.optim.Adam,
     module=nlp,
     total_steps=max_steps,
-    groups={
-        "^transformer": {
-            "lr": {"@schedules": "linear", "warmup_rate": 0.1, "start_value": 0 "max_value": 5e-5,},
+    groups=[
+        {
+            "selector": "transformer",
+            "lr": {"@schedules": "linear", "warmup_rate": 0.1, "start_value": 0, "max_value": 5e-5,},
         },
-        "": {
-            "lr": {"@schedules": "linear", "warmup_rate": 0.1, "start_value": 3e-4 "max_value": 3e-4,},
+        {
+            "selector": ".*",
+            "lr": {"@schedules": "linear", "warmup_rate": 0.1, "start_value": 3e-4, "max_value": 3e-4,},
         },
-    },
+    ],
 )
 
 #
 loggers = [
-    CSVLogger(),
+    CSVLogger.draft(),  # draft as we will let the train function specify the logging_dir
     RichLogger(
         fields={
             "step": {},
```

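The `groups` argument changes from a dict keyed by regex patterns to a list of dicts with an explicit `"selector"` key, which makes the matching order part of the data. Below is a minimal sketch of the implied first-match-wins rule; the parameter names are made up, and the rule itself is inferred from the fact that the catch-all `".*"` group comes last, so this is not edsnlp's exact implementation:

```python
import re

# Hypothetical parameter names, like those yielded by module.named_parameters()
param_names = [
    "ner.embedding.transformer.encoder.layer.0.weight",
    "ner.classifier.linear.weight",
]

groups = [
    {"selector": "transformer", "lr": 5e-5},  # specific group first
    {"selector": ".*", "lr": 3e-4},           # catch-all last
]

for name in param_names:
    # Assumed first-match-wins: otherwise the trailing ".*" selector
    # would also capture the transformer parameters.
    group = next(g for g in groups if re.search(g["selector"], name))
    print(f"{name} -> lr={group['lr']}")
```
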
docs/tutorials/training-span-classifier.md (6 additions, 4 deletions)
```diff
@@ -265,24 +265,26 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
 # 🎛️ OPTIMIZER (here it will be the same as the default one)
 optimizer = ScheduledOptimizer.draft(  # (2)!
     optim=torch.optim.AdamW,
-    groups={
-        "biopsy_classifier[.]embedding": {
+    groups=[
+        {
+            "selector": "biopsy_classifier[.]embedding",
             "lr": {
                 "@schedules": "linear",
                 "warmup_rate": 0.1,
                 "start_value": 0.,
                 "max_value": 5e-5,
             },
         },
-        ".*": {
+        {
+            "selector": ".*",
             "lr": {
                 "@schedules": "linear",
                 "warmup_rate": 0.1,
                 "start_value": 3e-4,
                 "max_value": 3e-4,
             },
         },
-    }
+    ]
 )
 
 # 🚀 TRAIN
```

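Note that the optimizer here is built with `ScheduledOptimizer.draft(...)` rather than `ScheduledOptimizer(...)`. Judging by the `CSVLogger.draft()` comment in the first file, `.draft()` appears to create a partially-specified object whose missing arguments (presumably the module to optimize and `total_steps`, which the NER tutorial passes explicitly) are filled in later by `train`. A hedged sketch of that pattern, with the deferred-argument behavior treated as an assumption:

```python
import torch
from edsnlp.training import ScheduledOptimizer

# Draft: required arguments omitted here are assumed to be supplied by
# `train` when it completes the draft before the first optimization step.
optimizer = ScheduledOptimizer.draft(
    optim=torch.optim.AdamW,
    groups=[
        {
            "selector": "biopsy_classifier[.]embedding",
            "lr": {"@schedules": "linear", "warmup_rate": 0.1,
                   "start_value": 0.0, "max_value": 5e-5},
        },
        {
            "selector": ".*",
            "lr": {"@schedules": "linear", "warmup_rate": 0.1,
                   "start_value": 3e-4, "max_value": 3e-4},
        },
    ],
)
```
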
edsnlp/training/trainer.py (27 additions, 16 deletions)
```diff
@@ -676,6 +676,14 @@ def train(
         total_steps=max_steps,
     )
 
+    for td in train_data:
+        if not (td.pipe_names is None or td.pipe_names <= trainable_pipe_names):
+            raise ValueError(
+                f"Training data pipe names {td.pipe_names} should be a subset of "
+                f"the trainable pipe names {trainable_pipe_names}, or left to None "
+                f"to use this dataset for all trainable components."
+            )
+
     for phase_i, pipe_names in enumerate(phases):
         trained_pipes_local: Dict[str, TorchComponent] = {
             n: nlp.get_pipe(n) for n in pipe_names
@@ -688,6 +696,14 @@ def train(
             if td.pipe_names is None or set(td.pipe_names) & set(pipe_names)
         ]
 
+        if len(phase_training_data) == 0:
+            raise ValueError(
+                f"No training data found for phase {phase_i + 1} with components "
+                f"{', '.join(pipe_names)}. Make sure that these components are "
+                f"listed in the 'pipe_names' attribute of at least one of the "
+                f"provided training data."
+            )
+
         with nlp.select_pipes(disable=trainable_pipe_names - set(pipe_names)):
             accelerator.print(f"Phase {phase_i + 1}: training {', '.join(pipe_names)}")
             set_seed(seed)
@@ -700,37 +716,32 @@ def train(
                     grad_params.add(param)
                 param.requires_grad_(has_grad_param)
 
-            accelerator.print(
-                "Optimizing groups:"
-                + "".join(
-                    "\n - {} weight tensors ({:,} parameters){}".format(
+            accelerator.print("Optimizing groups:")
+            for g in optim.param_groups:
+                accelerator.print(
+                    " - {} weight tensors ({:,} parameters){}".format(
                         len([p for p in g["params"] if p in grad_params]),
                         sum([p.numel() for p in g["params"] if p in grad_params]),
                         ": " + " & ".join(g.get("selectors", "*"))
                         if "selectors" in g
                         else "",
                     )
-                    for g in optim.param_groups
                 )
-            )
             accelerator.print(
                 f"Keeping frozen {len(all_params - grad_params):} weight tensors "
                 f"({sum(p.numel() for p in all_params - grad_params):,} parameters)"
             )
 
             nlp.train(True)
 
-            iterator = iter(
-                zip(
-                    *(
-                        td(nlp, device).set_processing(
-                            num_cpu_workers=num_workers,
-                            process_start_method="spawn",
-                        )
-                        for td in phase_training_data
-                    )
+            phase_datasets = [
+                td(nlp, device).set_processing(
+                    num_cpu_workers=num_workers,
+                    process_start_method="spawn",
                 )
-            )
+                for td in phase_training_data
+            ]
+            iterator = iter(zip(*(phase_datasets)))
             (accel_optim, trained_pipes) = accelerator.prepare(optim, trained_pipes)
             if hasattr(accel_optim.optimizer, "initialize"):
                 accel_optim.optimizer.initialize()
```

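The first added block validates `pipe_names` eagerly, before any training phase starts, using set inclusion (`<=`) against the trainable pipe names. A standalone illustration of the rule follows; the pipe names here are hypothetical:

```python
trainable_pipe_names = {"ner", "qualifier"}

def check(pipe_names):
    # Mirrors the new guard: None means "use this dataset for all
    # trainable components"; otherwise the names must be a subset
    # of the trainable pipe names.
    if not (pipe_names is None or pipe_names <= trainable_pipe_names):
        raise ValueError(
            f"Training data pipe names {pipe_names} should be a subset "
            f"of the trainable pipe names {trainable_pipe_names}"
        )

check(None)             # ok: dataset feeds every trainable component
check({"ner"})          # ok: proper subset
check({"ner", "typo"})  # ValueError: "typo" is not a trainable pipe
```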