Improve distributed training setup #13

Status: Open · wants to merge 1 commit into `main`

16 changes: 13 additions & 3 deletions README.md
@@ -41,23 +41,33 @@ The `run_experiments.py` script orchestrates the running of experiments and data
python run_experiments.py
```

### 2. Process Experiment Data
### 2. Distributed Training

The training script uses `torch.nn.parallel.DistributedDataParallel`. To run it manually, launch one process per GPU with `torchrun`:

```bash
torchrun --nproc_per_node=<num_gpus> llm_training.py
```

`run_experiments.py` automatically uses this launch method when calling `llm_training.py`.
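
For reference, the core pattern `llm_training.py` now follows is the standard `torchrun` + DDP recipe sketched below. This is a minimal, self-contained sketch only; the real script additionally loads GPT-2, collects NVML metrics, and writes CSV logs.

```python
# Minimal DDP sketch of the pattern used by llm_training.py.
# torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    # NCCL backend, rendezvous via the environment variables set by torchrun
    dist.init_process_group(backend="nccl", init_method="env://",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Stand-in model; the real script loads GPT-2 here
    model = torch.nn.Linear(512, 512).to(rank)
    model = DDP(model, device_ids=[rank])  # gradients are all-reduced across ranks

    # ... forward/backward/optimizer steps go here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launch it with `torchrun --nproc_per_node=<num_gpus>`, exactly as shown above.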

### 3. Process Experiment Data

The `process_experiment_data.py` script processes the raw experiment data into a more usable format. It will generate `training_stats.csv` and `inference_stats.csv`.

```bash
python process_experiment_data.py
```
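
The exact transformations live in `process_experiment_data.py`. As a rough illustration of how the resulting CSV can be consumed (the `max_watt` and `tokens_per_sec` column names come from the headers logged by `llm_training.py`, but the aggregation itself is only an assumed example):

```python
# Illustrative only: summarize training throughput per power cap.
# Assumes training_stats.csv contains the 'max_watt' and 'tokens_per_sec'
# columns written by llm_training.py.
import pandas as pd

stats = pd.read_csv("training_stats.csv")
throughput = stats.groupby("max_watt")["tokens_per_sec"].mean()
print(throughput)
```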

### 3. Generate Plots and Recommendations
### 4. Generate Plots and Recommendations

The `generate_report.py` script generates visualizations and highlights the recommended max power settings based on the provided data.

```bash
python generate_report.py
```
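
The actual plotting logic belongs to `generate_report.py`; the sketch below only shows the general shape of such a plot (throughput versus power cap) and is not the script's implementation.

```python
# Illustrative only: plot mean tokens/sec against the power cap.
# File name, columns, and output path are assumptions based on the training logs.
import pandas as pd
import matplotlib.pyplot as plt

stats = pd.read_csv("training_stats.csv")
throughput = stats.groupby("max_watt")["tokens_per_sec"].mean()

plt.plot(throughput.index, throughput.values, marker="o")
plt.xlabel("Max power cap (W)")
plt.ylabel("Mean tokens/sec")
plt.savefig("throughput_vs_power.png")
```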

### 4. Viewing the Output
### 5. Viewing the Output

The output image `report.png` contains the following:

92 changes: 65 additions & 27 deletions llm_training.py
@@ -1,5 +1,6 @@
import torch
from torch.nn import DataParallel
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW
import logging
import colorlog
@@ -10,7 +11,6 @@
from gpu_metrics_utils import initialize_nvml, shutdown_nvml, get_gpu_metrics

# Configuration Constants
GPU_IDS = [0, 1] # List of GPU IDs to use
BATCH_SIZE = 1975 # Batch size for training
SEQ_LENGTH = 2048 # Sequence length for training
EPOCHS = 20 # Number of epochs to train
@@ -36,16 +36,32 @@ def log_statistics(file_name, headers, data):
writer.writerow(headers)
writer.writerow(data)

def load_across_gpus(gpu_ids, batch_size, seq_length, epochs, learning_rate, callback=None):
def setup(rank, world_size):
"""Initialize the distributed process group."""
dist.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)

def cleanup():
"""Destroy the distributed process group."""
dist.destroy_process_group()

def load_across_gpus(rank, world_size, batch_size, seq_length, epochs, learning_rate, callback=None):
logger.info("Starting LLM training/fine-tuning")
logger.info(f"Using GPUs: {gpu_ids}")
logger.info(f"Rank {rank}/{world_size}")
logger.info(f"Batch size: {batch_size}")
logger.info(f"Sequence length: {seq_length}")
logger.info(f"Epochs: {epochs}")
logger.info(f"Learning rate: {learning_rate}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GPT2LMHeadModel.from_pretrained('gpt2')
setup(rank, world_size)

device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Set padding token
@@ -60,8 +76,8 @@ def load_across_gpus(gpu_ids, batch_size, seq_length, epochs, learning_rate, cal
inputs = {key: value.to(device) for key, value in inputs.items()}
labels = inputs['input_ids']

# Use DataParallel to wrap the model for multi-GPU usage
model = DataParallel(model, device_ids=gpu_ids).to(device)
# Wrap the model with DistributedDataParallel
model = DDP(model, device_ids=[rank])

initialize_nvml()

@@ -71,7 +87,7 @@ def load_across_gpus(gpu_ids, batch_size, seq_length, epochs, learning_rate, cal
log_file = 'training_stats.csv'

# Get sample GPU metrics to dynamically generate headers
sample_metrics = get_gpu_metrics()[0]
sample_metrics = get_gpu_metrics()[rank]
gpu_headers = list(sample_metrics.keys())
headers = ['timestamp', 'epoch', 'iteration', 'batch', 'loss', 'tokens_per_sec'] + gpu_headers + ['max_watt']

@@ -91,23 +107,37 @@ def load_across_gpus(gpu_ids, batch_size, seq_length, epochs, learning_rate, cal

batch_end_time = time.time()
batch_time = batch_end_time - batch_start_time
total_tokens += batch_labels.numel() # More accurate token count

logger.info(f"Epoch {epoch + 1}/{epochs}, Iteration {iteration}, Batch {i // batch_size + 1} completed, Loss: {loss.item()}")

# Log statistics after each batch
timestamp = datetime.now().isoformat()
tokens_per_sec = total_tokens / (time.time() - start_time)
from gpu_metrics_utils import collect_power_draw_all_gpus
total_power = collect_power_draw_all_gpus()
gpu_metrics = get_gpu_metrics()[0]
data = [timestamp, epoch + 1, iteration, i // batch_size + 1, loss.item(), tokens_per_sec] + list(gpu_metrics.values()) + [MAX_WATT, total_power]
if callback:
data = callback(data)
log_statistics(log_file, headers + ['total_power_draw'], data)
logger.info(f"Logged statistics: {data}")

tokens_this_batch = torch.tensor(batch_labels.numel(), device=device)
dist.all_reduce(tokens_this_batch, op=dist.ReduceOp.SUM)
total_tokens += tokens_this_batch.item()

if rank == 0:
logger.info(
f"Epoch {epoch + 1}/{epochs}, Iteration {iteration}, Batch {i // batch_size + 1} completed, Loss: {loss.item()}"
)

# Log statistics after each batch
timestamp = datetime.now().isoformat()
tokens_per_sec = total_tokens / (time.time() - start_time)
from gpu_metrics_utils import collect_power_draw_all_gpus
total_power = collect_power_draw_all_gpus()
gpu_metrics = get_gpu_metrics()[rank]
data = [
timestamp,
epoch + 1,
iteration,
i // batch_size + 1,
loss.item(),
tokens_per_sec,
] + list(gpu_metrics.values()) + [MAX_WATT, total_power]
if callback:
data = callback(data)
log_statistics(log_file, headers + ["total_power_draw"], data)
logger.info(f"Logged statistics: {data}")

shutdown_nvml()
cleanup()
logger.info("LLM training/fine-tuning completed")

if __name__ == "__main__":
@@ -116,7 +146,15 @@ def callback(data):
additional_data = {"example": "example"} # Replace with actual data
return data + list(additional_data.values())

if os.getenv('CALLBACK'):
load_across_gpus(GPU_IDS, BATCH_SIZE, SEQ_LENGTH, EPOCHS, LEARNING_RATE, callback)
# torchrun sets these environment variables when launching multiple processes
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))

# Provide defaults for local standalone runs
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

if os.getenv("CALLBACK"):
load_across_gpus(rank, world_size, BATCH_SIZE, SEQ_LENGTH, EPOCHS, LEARNING_RATE, callback)
else:
load_across_gpus(GPU_IDS, BATCH_SIZE, SEQ_LENGTH, EPOCHS, LEARNING_RATE)
load_across_gpus(rank, world_size, BATCH_SIZE, SEQ_LENGTH, EPOCHS, LEARNING_RATE)
10 changes: 7 additions & 3 deletions run_experiments.py
@@ -26,12 +26,16 @@
file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s:%(message)s'))
logger.addHandler(file_handler)

def run_script(script, max_watt):
def run_script(script, max_watt, nprocs=None):
"""Run a specified script with max_watt as an environment variable."""
try:
env = os.environ.copy()
env['MAX_WATT'] = str(max_watt)
subprocess.run(f"python {script}", shell=True, check=True, env=env)
if nprocs:
cmd = f"torchrun --nproc_per_node={nprocs} {script}"
else:
cmd = f"python {script}"
subprocess.run(cmd, shell=True, check=True, env=env)
logger.info(f"Completed running {script}")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to run script {script}: {e}")
@@ -47,7 +51,7 @@ def main():

# Run the training script
logger.info(f"Running the training script {TRAINING_SCRIPT}")
run_script(TRAINING_SCRIPT, max_watt)
run_script(TRAINING_SCRIPT, max_watt, nprocs=len(GPU_IDS))

# Run the inference script
logger.info(f"Running the inference script {INFERENCE_SCRIPT}")