Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions finetune/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def __init__(self):
# Gradient accumulation to simulate a larger batch size.
self.accumulation_steps = 1

# Mixed-precision training. Set to "bfloat16" on Ampere-class
# or newer GPUs (RTX 30/40-series, A100, H100) to reduce step
# time and activation memory. ``None`` keeps full FP32 training
# and is the default for parity with earlier runs.
self.amp_dtype = None

# AdamW optimizer parameters.
self.adam_beta1 = 0.9
self.adam_beta2 = 0.95
Expand Down
37 changes: 22 additions & 15 deletions finetune/train_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
cleanup_ddp,
set_seed,
get_model_size,
format_time
format_time,
resolve_amp_dtype,
)


Expand Down Expand Up @@ -68,6 +69,10 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_

train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)

amp_dtype, amp_enabled = resolve_amp_dtype(config.get('amp_dtype'))
if rank == 0 and amp_enabled:
print(f"AMP enabled: autocast on cuda with dtype={amp_dtype}.")

optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['predictor_learning_rate'],
Expand Down Expand Up @@ -96,17 +101,18 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_
batch_x = batch_x.to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

# Tokenize input data on-the-fly
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled):
# Tokenize input data on-the-fly
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

# Prepare inputs and targets for the language model
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
# Prepare inputs and targets for the language model
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

# Forward pass and loss calculation
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
# Forward pass and loss calculation
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

# Backward pass and optimization
optimizer.zero_grad()
Expand Down Expand Up @@ -140,12 +146,13 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_
batch_x = batch_x.to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled):
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

tot_val_loss_sum_rank += val_loss.item()
val_batches_processed_rank += 1
Expand Down
29 changes: 18 additions & 11 deletions finetune/train_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
set_seed,
get_model_size,
format_time,
resolve_amp_dtype,
)


Expand Down Expand Up @@ -95,6 +96,10 @@ def train_model(model, device, config, save_dir, logger, rank, world_size):

train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)

amp_dtype, amp_enabled = resolve_amp_dtype(config.get('amp_dtype'))
if rank == 0 and amp_enabled:
print(f"AMP enabled: autocast on cuda with dtype={amp_dtype}.")

optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['tokenizer_learning_rate'],
Expand Down Expand Up @@ -133,15 +138,16 @@ def train_model(model, device, config, save_dir, logger, rank, world_size):
end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps'])
batch_x = ori_batch_x[start_idx:end_idx]

# Forward pass
zs, bsq_loss, _, _ = model(batch_x)
z_pre, z = zs
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled):
# Forward pass
zs, bsq_loss, _, _ = model(batch_x)
z_pre, z = zs

# Loss calculation
recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1
# Loss calculation
recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1

loss_scaled = loss / config['accumulation_steps']
current_batch_total_loss += loss.item()
Expand Down Expand Up @@ -177,9 +183,10 @@ def train_model(model, device, config, save_dir, logger, rank, world_size):
with torch.no_grad():
for ori_batch_x, _ in val_loader:
ori_batch_x = ori_batch_x.to(device, non_blocking=True)
zs, _, _, _ = model(ori_batch_x)
_, z = zs
val_loss_item = F.mse_loss(z, ori_batch_x)
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled):
zs, _, _, _ = model(ori_batch_x)
_, z = zs
val_loss_item = F.mse_loss(z, ori_batch_x)

tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)
Expand Down
26 changes: 26 additions & 0 deletions finetune/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,29 @@ def format_time(seconds: float) -> str:



def resolve_amp_dtype(amp_dtype):
"""
Resolves the configured AMP dtype string into the arguments expected by
`torch.autocast`.

Currently only "bfloat16" is supported. Passing ``None`` disables
mixed-precision training; the returned dtype is then irrelevant and
autocast becomes a no-op.

Args:
amp_dtype (str | None): "bfloat16" or None.

Returns:
tuple[torch.dtype, bool]: (dtype, enabled) suitable for
``torch.autocast(device_type=..., dtype=dtype, enabled=enabled)``.

Raises:
ValueError: If ``amp_dtype`` is set to an unsupported value.
"""
if amp_dtype is None:
return torch.float32, False
if amp_dtype == "bfloat16":
return torch.bfloat16, True
raise ValueError(
f"Unsupported amp_dtype {amp_dtype!r}; expected 'bfloat16' or None."
)