Commit cf73146

feat: Enable DPO training with HuggingFace inline provider (#2825)
What does this PR do?

This PR adds support for Direct Preference Optimization (DPO) training via the existing HuggingFace inline provider. It introduces a new DPO training recipe, config schema updates, dataset integration, and end-to-end testing to support preference-based fine-tuning with TRL.

Test Plan

- Added integration test: tests/integration/post_training/test_post_training.py::TestPostTraining::test_preference_optimize
- Ran tests on both CPU and CUDA environments

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ashwin Bharambe <[email protected]>
1 parent 2665f00 commit cf73146
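Since the description centers on preference-based fine-tuning with TRL, below is a minimal, hedged sketch of what a standalone DPO run with TRL's `DPOTrainer` looks like. The model id, dataset rows, and hyperparameters are illustrative assumptions rather than values taken from this commit, and argument names can differ slightly across TRL versions.

```python
# Hypothetical standalone DPO run with TRL; illustrative only, not the recipe
# added in this commit. Assumes a recent TRL where DPOConfig carries beta/loss_type.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "HuggingFaceTB/SmolLM-135M"  # assumed small model for a quick demo
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# DPO consumes preference triples: a prompt plus chosen and rejected completions.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?"],
        "chosen": ["The capital of France is Paris."],
        "rejected": ["France has no capital city."],
    }
)

args = DPOConfig(
    output_dir="./checkpoints/dpo",  # matches the provider's dpo_output_dir default
    beta=0.1,                        # matches the provider's dpo_beta default
    loss_type="sigmoid",             # matches the provider's dpo_loss_type default
    per_device_train_batch_size=1,
    max_steps=1,
    report_to="none",
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # with None, TRL keeps a frozen copy of the policy as reference
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # older TRL versions name this argument `tokenizer`
)
trainer.train()
```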

File tree

7 files changed: +913 -215 lines changed


docs/source/providers/post_training/inline_huggingface.md

Lines changed: 4 additions & 0 deletions

@@ -24,6 +24,10 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
 | `weight_decay` | `<class 'float'>` | No | 0.01 | |
 | `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
 | `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
+| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
+| `use_reference_model` | `<class 'bool'>` | No | True | |
+| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
+| `dpo_output_dir` | `<class 'str'>` | No | ./checkpoints/dpo | |
 
 ## Sample Configuration
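These four options surface in the provider's run configuration. The snippet below is a hedged sketch of a config payload that overrides them, using only the field names and defaults from the table above; the surrounding keys (`checkpoint_format`, `device`) follow this provider's sample configuration.

```python
# Hypothetical provider-config payload exercising the new DPO options; key names
# and defaults come from the documentation table above, everything else is illustrative.
dpo_provider_config = {
    "checkpoint_format": "huggingface",
    "device": "cpu",
    "dpo_beta": 0.1,                        # DPO beta coefficient
    "use_reference_model": True,            # keep a frozen reference policy
    "dpo_loss_type": "sigmoid",             # one of: sigmoid, hinge, ipo, kto_pair
    "dpo_output_dir": "./checkpoints/dpo",  # where DPO checkpoints are written
}
```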

llama_stack/providers/inline/post_training/huggingface/config.py

Lines changed: 6 additions & 0 deletions

@@ -67,6 +67,12 @@ class HuggingFacePostTrainingConfig(BaseModel):
     # Can improve data transfer speed to GPU but uses more memory
     dataloader_pin_memory: bool = True
 
+    # DPO-specific parameters
+    dpo_beta: float = 0.1
+    use_reference_model: bool = True
+    dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+    dpo_output_dir: str = "./checkpoints/dpo"
+
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}

llama_stack/providers/inline/post_training/huggingface/post_training.py

Lines changed: 40 additions & 5 deletions

@@ -25,6 +25,9 @@
 from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
     HFFinetuningSingleDevice,
 )
+from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
+    HFDPOAlignmentSingleDevice,
+)
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack.schema_utils import webmethod

@@ -36,6 +39,7 @@ class TrainingArtifactType(Enum):
 
 
 _JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
+_JOB_TYPE_DPO_TRAINING = "dpo-training"
 
 
 class HuggingFacePostTrainingImpl:

@@ -119,12 +123,37 @@ async def preference_optimize(
         hyperparam_search_config: dict[str, Any],
         logger_config: dict[str, Any],
     ) -> PostTrainingJob:
-        raise NotImplementedError("DPO alignment is not implemented yet")
+        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+            on_log_message_cb("Starting HF DPO alignment")
 
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(
-            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
-        )
+            recipe = HFDPOAlignmentSingleDevice(
+                job_uuid=job_uuid,
+                datasetio_api=self.datasetio_api,
+                datasets_api=self.datasets_api,
+            )
+
+            resources_allocated, checkpoints = await recipe.train(
+                model=finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
+                job_uuid=job_uuid,
+                dpo_config=algorithm_config,
+                config=training_config,
+                provider_config=self.config,
+            )
+
+            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
+            if checkpoints:
+                for checkpoint in checkpoints:
+                    artifact = self._checkpoint_to_artifact(checkpoint)
+                    on_artifact_collected_cb(artifact)
+            else:
+                on_log_message_cb("Warning: No checkpoints were saved during DPO training")
+
+            on_status_change_cb(SchedulerJobStatus.completed)
+            on_log_message_cb("HF DPO alignment completed")
+
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        return PostTrainingJob(job_uuid=job_uuid)
 
     @staticmethod
     def _get_artifacts_metadata_by_type(job, artifact_type):

@@ -174,3 +203,9 @@ async def cancel_training_job(self, job_uuid: str) -> None:
     async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
         job = self._scheduler.get_job(job_uuid)
         return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+
+    @webmethod(route="/post-training/jobs", method="GET")
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        return ListPostTrainingJobsResponse(
+            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
+        )
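To show how the new path is exercised, here is a hedged sketch of driving `preference_optimize` on an already-constructed `HuggingFacePostTrainingImpl` (written here as `impl`). The argument names mirror the ones used in the handler above, while the model identifier and the algorithm/training config objects are placeholders rather than the exact llama-stack schemas.

```python
# Hypothetical driver for the new DPO path; `impl`, `dpo_algorithm_config`, and
# `training_config` are assumed to exist and are not defined by this diff.
async def run_dpo_job(impl, dpo_algorithm_config, training_config):
    job = await impl.preference_optimize(
        job_uuid="dpo-job-0001",
        finetuned_model="my-base-model",        # placeholder model identifier
        algorithm_config=dpo_algorithm_config,  # preference-optimization settings object
        training_config=training_config,        # epochs, batch size, dataset binding, ...
        hyperparam_search_config={},
        logger_config={},
    )

    # The scheduler runs `handler` in the background; progress and outputs are
    # retrieved through the job APIs in this file rather than the return value.
    jobs = await impl.get_training_jobs()
    artifacts = await impl.get_training_job_artifacts(job_uuid=job.job_uuid)
    return jobs, artifacts
```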
