diff --git a/llm_training.py b/llm_training.py
index d7ec619..bf2e35c 100644
--- a/llm_training.py
+++ b/llm_training.py
@@ -60,7 +60,7 @@ def load_across_gpus(gpu_ids, batch_size, seq_length, epochs, learning_rate, cal
     inputs = {key: value.to(device) for key, value in inputs.items()}
     labels = inputs['input_ids']

-    # Use DataParallel to wrap the model for multi-GGPU usage
+    # Use DataParallel to wrap the model for multi-GPU usage
     model = DataParallel(model, device_ids=gpu_ids).to(device)
     initialize_nvml()