Commit 44ba438

Hangzhi and okhat authored

Fix BootstrapFinetune example in index doc and add basic tests for bootstrap_finetune. (#8435)

* done
* done
* done
* format
* done
* Revert specific files to commit c396217
* fix set_lm
* fix
* fix testing
* lint
* link the tutorial to index sessions for clarification
* resolve comments
* format

Co-authored-by: Omar Khattab <[email protected]>

1 parent f875dc5 · commit 44ba438

File tree

4 files changed: +99 −4 lines changed

* docs/docs/index.md
* docs/docs/learn/optimization/optimizers.md
* dspy/teleprompt/bootstrap_finetune.py
* tests/teleprompt/test_bootstrap_finetune.py

docs/docs/index.md

Lines changed: 4 additions & 1 deletion

@@ -403,17 +403,20 @@ Given a few tens or hundreds of representative _inputs_ of your task and a _metr
 
 ```python linenums="1"
 import dspy
-dspy.configure(lm=dspy.LM("openai/gpt-4o-mini-2024-07-18"))
+lm=dspy.LM('openai/gpt-4o-mini-2024-07-18')
 
 # Define the DSPy module for classification. It will use the hint at training time, if available.
 signature = dspy.Signature("text, hint -> label").with_updated_fields("label", type_=Literal[tuple(CLASSES)])
 classify = dspy.ChainOfThought(signature)
+classify.set_lm(lm)
 
 # Optimize via BootstrapFinetune.
 optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)
 optimized = optimizer.compile(classify, trainset=trainset)
 
 optimized(text="What does a pending cash withdrawal mean?")
+
+# For a complete fine-tuning tutorial, see: https://dspy.ai/tutorials/classification_finetuning/
 ```
 
 **Possible Output (from the last line):**
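Why the doc change was needed: `dspy.configure(lm=...)` sets a process-wide default that predictors fall back to at inference time, but it does not populate each predictor's own `lm` attribute, which `BootstrapFinetune` reads when grouping finetuning jobs. Below is a self-contained sketch of the corrected pattern; `CLASSES` and `trainset` are defined earlier on the actual docs page, so the values here are placeholders.

```python
from typing import Literal

import dspy

CLASSES = ["pending_cash_withdrawal", "card_arrival", "other"]  # placeholder label set
trainset = [  # placeholder training data; the real page builds this from a dataset
    dspy.Example(text="Why hasn't my cash withdrawal gone through yet?",
                 hint="pending", label="pending_cash_withdrawal").with_inputs("text", "hint"),
]

lm = dspy.LM('openai/gpt-4o-mini-2024-07-18')

signature = dspy.Signature("text, hint -> label").with_updated_fields("label", type_=Literal[tuple(CLASSES)])
classify = dspy.ChainOfThought(signature)
classify.set_lm(lm)  # assigns the LM to every predictor, so pred.lm is no longer None

optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)
# optimized = optimizer.compile(classify, trainset=trainset)  # kicks off an actual finetuning job
```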

docs/docs/learn/optimization/optimizers.md

Lines changed: 7 additions & 3 deletions

@@ -58,17 +58,18 @@ These optimizers produce optimal instructions for the prompt and, in the case of
 
 6. [**`MIPROv2`**](../../api/optimizers/MIPROv2.md): Generates instructions *and* few-shot examples in each step. The instruction generation is data-aware and demonstration-aware. Uses Bayesian Optimization to effectively search over the space of generation instructions/demonstrations across your modules.
 
+7. [**`SIMBA`**](../../api/optimizers/SIMBA.md)
 
 ### Automatic Finetuning
 
 This optimizer is used to fine-tune the underlying LLM(s).
 
-7. [**`BootstrapFinetune`**](../../api/optimizers/BootstrapFinetune.md): Distills a prompt-based DSPy program into weight updates. The output is a DSPy program that has the same steps, but where each step is conducted by a finetuned model instead of a prompted LM.
+8. [**`BootstrapFinetune`**](/api/optimizers/BootstrapFinetune): Distills a prompt-based DSPy program into weight updates. The output is a DSPy program that has the same steps, but where each step is conducted by a finetuned model instead of a prompted LM. [See the classification fine-tuning tutorial](https://dspy.ai/tutorials/classification_finetuning/) for a complete example.
 
 
 ### Program Transformations
 
-8. [**`Ensemble`**](../../api/optimizers/Ensemble.md): Ensembles a set of DSPy programs and either uses the full set or randomly samples a subset into a single program.
+9. [**`Ensemble`**](../../api/optimizers/Ensemble.md): Ensembles a set of DSPy programs and either uses the full set or randomly samples a subset into a single program.
 
 
 ## Which optimizer should I use?

@@ -176,17 +177,20 @@ optimized_program = teleprompter.compile(YOUR_PROGRAM_HERE, trainset=YOUR_TRAINS
 
 ```python linenums="1"
 import dspy
-dspy.configure(lm=dspy.LM('openai/gpt-4o-mini-2024-07-18'))
+lm=dspy.LM('openai/gpt-4o-mini-2024-07-18')
 
 # Define the DSPy module for classification. It will use the hint at training time, if available.
 signature = dspy.Signature("text, hint -> label").with_updated_fields('label', type_=Literal[tuple(CLASSES)])
 classify = dspy.ChainOfThought(signature)
+classify.set_lm(lm)
 
 # Optimize via BootstrapFinetune.
 optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)
 optimized = optimizer.compile(classify, trainset=trainset)
 
 optimized(text="What does a pending cash withdrawal mean?")
+
+# For a complete fine-tuning tutorial, see: https://dspy.ai/tutorials/classification_finetuning/
 ```
 
 **Possible Output (from the last line):**
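As an aside, the renumbered `Ensemble` entry is the one optimizer above without an inline example. A minimal sketch of its usage, assuming `programs` is a list of already-compiled DSPy programs for the same task and that `dspy.majority` fits their output field:

```python
import dspy
from dspy.teleprompt import Ensemble

# `programs` is assumed to be a list of compiled DSPy programs for the same task.
teleprompter = Ensemble(reduce_fn=dspy.majority, size=2)  # vote over a random subset of 2
ensembled_program = teleprompter.compile(programs)
# prediction = ensembled_program(text="What does a pending cash withdrawal mean?")
```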

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 7 additions & 0 deletions

@@ -81,7 +81,14 @@ def compile(
         key_to_data = {}
         for pred_ind, pred in enumerate(student.predictors()):
             data_pred_ind = None if self.multitask else pred_ind
+            if pred.lm is None:
+                raise ValueError(
+                    f"Predictor {pred_ind} does not have an LM assigned. "
+                    f"Please ensure the module's predictors have their LM set before fine-tuning. "
+                    f"You can set it using: your_module.set_lm(your_lm)"
+                )
             training_key = (pred.lm, data_pred_ind)
+
             if training_key not in key_to_data:
                 train_data, data_format = self._prepare_finetune_data(
                     trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
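One detail worth noting in this hunk: `training_key = (pred.lm, data_pred_ind)` is what groups predictors into finetuning jobs. With `multitask=True` (the default), `data_pred_ind` is `None`, so all predictors sharing an LM collapse into a single job, which is also why the new test below mocks `finetune_lms` to return `{(lm, None): lm}`. A toy sketch of that grouping, with strings standing in for LM objects:

```python
# Toy illustration of the (lm, data_pred_ind) keying used above.
multitask = True
predictor_lms = ["lm_a", "lm_a", "lm_b"]  # three predictors, two distinct LMs

keys = {(lm, None if multitask else i) for i, lm in enumerate(predictor_lms)}
print(sorted(keys))  # [('lm_a', None), ('lm_b', None)] -> two jobs, not three
```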
tests/teleprompt/test_bootstrap_finetune.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+from unittest.mock import patch
+
+import dspy
+from dspy import Example
+from dspy.predict import Predict
+from dspy.teleprompt import BootstrapFinetune
+from dspy.utils.dummies import DummyLM
+
+
+# Define a simple metric function for testing
+def simple_metric(example, prediction, trace=None):
+    return example.output == prediction.output
+
+
+examples = [
+    Example(input="What is the color of the sky?", output="blue").with_inputs("input"),
+    Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!").with_inputs("input"),
+]
+trainset = [examples[0]]
+
+
+def test_bootstrap_finetune_initialization():
+    """Test BootstrapFinetune initialization with various parameters."""
+    bootstrap = BootstrapFinetune(metric=simple_metric)
+    assert bootstrap.metric == simple_metric, "Metric not correctly initialized"
+    assert bootstrap.multitask == True, "Multitask should default to True"
+
+
+class SimpleModule(dspy.Module):
+    def __init__(self, signature):
+        super().__init__()
+        self.predictor = Predict(signature)
+
+    def forward(self, **kwargs):
+        return self.predictor(**kwargs)
+
+
+def test_compile_with_predict_instances():
+    """Test BootstrapFinetune compilation with Predict instances."""
+    # Create SimpleModule instances for student and teacher
+    student = SimpleModule("input -> output")
+    teacher = SimpleModule("input -> output")
+
+    lm = DummyLM([{"output": "blue"}, {"output": "Ring-ding-ding-ding-dingeringeding!"}])
+    dspy.settings.configure(lm=lm)
+
+    # Set LM for both student and teacher
+    student.set_lm(lm)
+    teacher.set_lm(lm)
+
+    bootstrap = BootstrapFinetune(metric=simple_metric)
+
+    # Mock the fine-tuning process since DummyLM doesn't support it
+    with patch.object(bootstrap, "finetune_lms") as mock_finetune:
+        mock_finetune.return_value = {(lm, None): lm}
+        compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset)
+
+        assert compiled_student is not None, "Failed to compile student"
+        assert hasattr(compiled_student, "_compiled") and compiled_student._compiled, "Student compilation flag not set"
+
+        mock_finetune.assert_called_once()
+
+
+def test_error_handling_missing_lm():
+    """Test error handling when predictor doesn't have an LM assigned."""
+
+    lm = DummyLM([{"output": "test"}])
+    dspy.settings.configure(lm=lm)
+
+    student = SimpleModule("input -> output")
+    # Intentionally NOT setting LM for the student module
+
+    bootstrap = BootstrapFinetune(metric=simple_metric)
+
+    # This should raise ValueError about missing LM and hint to use set_lm
+    try:
+        bootstrap.compile(student, trainset=trainset)
+        assert False, "Should have raised ValueError for missing LM"
+    except ValueError as e:
+        assert "does not have an LM assigned" in str(e)
+        assert "set_lm" in str(e)
