Commit d025471

feat: enable knowledge distillation
Many different forms of model training exist. One popular form is knowledge distillation, where a student model learns the output distributions of a teacher model. This commit introduces support for knowledge distillation in the training library. It also exposes the `weight_decay` hyperparameter, which is often used to help deep learning models generalize. Lastly, this commit changes usage of `torch.distributed` to the `dist` alias, as it is a common module used throughout the codebase.

Signed-off-by: Oleg S <[email protected]>
1 parent: 8e6be0d
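To make the commit message concrete, below is a minimal sketch of what a knowledge-distillation loss of this kind typically looks like: the student is trained to match the teacher's temperature-softened output distribution, blended with the usual cross-entropy term by a weight `alpha`. The function name, signature, and blending scheme are illustrative assumptions mirroring the `temperature` and `alpha` fields added in the diff below, not the training library's actual implementation.

import torch
import torch.nn.functional as F


def distillation_loss(
    student_logits: torch.Tensor,  # shape (batch, vocab)
    teacher_logits: torch.Tensor,  # shape (batch, vocab)
    labels: torch.Tensor,          # shape (batch,), hard class labels
    temperature: float = 1.0,      # mirrors DistillationConfig.temperature
    alpha: float = 1.0,            # mirrors DistillationConfig.alpha
) -> torch.Tensor:
    # Soften both output distributions with the temperature.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    # KL divergence pulls the student's distribution toward the teacher's;
    # the temperature**2 factor keeps gradient magnitudes comparable
    # across temperature settings.
    kd_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2

    # Standard cross-entropy against the hard labels.
    ce_loss = F.cross_entropy(student_logits, labels)

    # alpha = 1.0 -> pure distillation, alpha = 0.0 -> pure cross-entropy.
    return alpha * kd_loss + (1.0 - alpha) * ce_loss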

File tree

2 files changed: +227, -52 lines

src/instructlab/training/config.py

Lines changed: 15 additions & 0 deletions

@@ -121,6 +121,16 @@ class DeepSpeedOptions(BaseModel):
     save_samples: int | None = None


+class DistillationConfig(BaseModel):
+    """
+    Config to use when performing knowledge distillation during training.
+    """
+
+    temperature: float = Field(1.0, gt=0.0)
+    alpha: float = Field(1.0, le=1.0, ge=0.0)
+    teacher_path: str
+
+
 # public API
 class ShardingStrategies(Enum):
     FULL_SHARD = "FULL_SHARD"
@@ -179,6 +189,11 @@ class TrainingArgs(BaseModel):
     is_padding_free: bool = False # TODO: deprecate
     checkpoint_at_epoch: bool = True
     accelerate_full_state_at_epoch: bool = True
+    weight_decay: float = Field(0.0, ge=0.0)
+
+    # settings for knowledge distillation
+    distillation_options: Optional[DistillationConfig] = None
+    use_distillation: bool = False

     mock_data: Optional[bool] = False
     mock_data_len: int = 0
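As a hypothetical usage sketch of the options added above (assuming the package layout implied by src/instructlab/training/config.py), a caller might construct the distillation config as follows; the concrete values and the teacher path are made up, and the other required TrainingArgs fields are omitted:

from instructlab.training.config import DistillationConfig

distillation = DistillationConfig(
    temperature=2.0,                 # must be > 0.0; softens teacher/student logits
    alpha=0.9,                       # must be in [0.0, 1.0]; weight of the distillation term
    teacher_path="/models/teacher",  # hypothetical path to the teacher model
)

# The new TrainingArgs fields would then be set alongside the existing ones, e.g.:
#   weight_decay=0.01,                  # typically forwarded to the optimizer (e.g. AdamW)
#   use_distillation=True,
#   distillation_options=distillation,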

0 commit comments