diff --git a/bergson/collection.py b/bergson/collection.py
index 7370f17..bb7050e 100644
--- a/bergson/collection.py
+++ b/bergson/collection.py
@@ -77,6 +77,8 @@ def callback(name: str, g: torch.Tensor, indices: list[int]):
         attention_cfgs=attention_cfgs,
     )
 
+    validate_batch_size(model, cfg.token_batch_size, collector)
+
     # Allocate space ahead of time for the gradients
     grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
 
@@ -252,3 +254,27 @@ def process_preconditioners(
         preconditioners_eigen[name] = (eigval, eigvec)
     if rank == 0:
         processor.preconditioners_eigen = preconditioners_eigen
+
+
+def validate_batch_size(
+    model: PreTrainedModel,
+    token_batch_size: int | None,
+    collector: GradientCollector,
+):
+    """Validate that the specified token batch size fits on device."""
+    if token_batch_size is None:
+        return
+
+    random_tokens = torch.randint(
+        0, 10, (1, token_batch_size), device=model.device, dtype=torch.long
+    )
+    try:
+        with collector:
+            loss = model(random_tokens).logits[0, 0, 0].float()
+            loss.backward()
+            model.zero_grad()
+    except Exception as e:
+        raise ValueError(
+            f"Token batch size {token_batch_size} is too large for the device. "
+            f"Try reducing the batch size or use --fsdp to shard the model."
+        ) from e
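
Note (not part of the patch): a minimal, hedged sketch of how the new validate_batch_size probe could be smoke-tested outside the collection pipeline. The tiny model checkpoint and the null-context stand-in for GradientCollector are assumptions for illustration, not something this diff provides.

import contextlib

from transformers import AutoModelForCausalLM

from bergson.collection import validate_batch_size

# Any small causal LM works as a stand-in; the checkpoint name is an assumption.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# validate_batch_size only uses the collector as a context manager around the
# probe forward/backward, so a null context is enough for a standalone check.
collector = contextlib.nullcontext()

# None skips the check entirely; a small value should run the probe without error.
validate_batch_size(model, token_batch_size=None, collector=collector)
validate_batch_size(model, token_batch_size=16, collector=collector)

# An oversized value is expected to surface as a ValueError chained from the
# underlying device error (e.g. CUDA OOM), before the real collection run starts.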