
Fix dtype leaks that break bfloat16 training #26

Open

tonyzdev wants to merge 1 commit into kyegomez:main from tonyzdev:fix/bf16-dtype-leak

Conversation

@tonyzdev

Summary

Two places in open_mythos/main.py allocate tensors without an explicit dtype= argument and therefore fall back to fp32. When the model is loaded in bf16 (model.to(torch.bfloat16)) these fp32 tensors get combined with bf16 ones via PyTorch's type promotion rules, silently upcasting intermediate activations to fp32. The upcasted tensor is then fed into a bf16 nn.Linear, which throws:

RuntimeError: expected mat1 and mat2 to have the same dtype,
              but got: float != c10::BFloat16

This makes it impossible to train, or even run inference with, any variant (mythos_1b, mythos_3b, etc.) in bf16 on the current main.
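For context, the promotion chain is easy to reproduce with plain tensors; nothing below depends on open_mythos and the shapes are illustrative:

# standalone illustration: mixing a default-dtype (fp32) tensor into bf16 math
# upcasts the result, and the next bf16 matmul rejects the mismatched operand
import torch

attn = torch.randn(1, 1, 4, 4, dtype=torch.bfloat16)
v    = torch.randn(1, 1, 4, 4, dtype=torch.bfloat16)
mask = torch.full((1, 1, 4, 4), float("-inf"))   # no dtype= -> fp32

print((attn + mask).dtype)      # torch.float32: bf16 + fp32 promotes to fp32
torch.matmul(attn + mask, v)    # RuntimeError: dtype mismatch (exact message varies by device)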

The two bugs

1. _causal_mask always builds fp32

# current
mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)

torch.full uses the default dtype (fp32) when none is given. In GQAttention/MLAttention:

attn = torch.matmul(q, k.transpose(-2, -1)) * scale   # bf16
attn = attn + mask                                    # bf16 + fp32 -> fp32 (!)
attn = F.softmax(attn, dim=-1)                        # fp32
out = torch.matmul(attn, v)                           # fp32 @ bf16 -> CRASH

The crash happens on out = torch.matmul(attn, v) at main.py:388 (MLA) / main.py:246 (GQA) on the first prelude layer.

2. RecurrentBlock.forward ACT accumulator is fp32

# current
cumulative_p = torch.zeros(B, T, device=h.device)
...
weight = weight * still_running.float()
h_out = h_out + weight.unsqueeze(-1) * h
cumulative_p = cumulative_p + p * still_running.float()

cumulative_p starts in fp32 and still_running.float() forces fp32, so weight becomes fp32 through the torch.where(cond, remainder_fp32, p_bf16) promotion; weight.unsqueeze(-1) * h is then fp32 * bf16 → fp32. The returned h_out is fp32, and the following Coda layer crashes at q_down(h) for the same reason as above.
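The promotion chain can be seen in isolation with a standalone sketch (made-up shapes and halting threshold; only the dtype behaviour mirrors the block above):

# standalone sketch of the accumulator promotion: a fp32 cumulative_p drags
# weight to fp32 through torch.where, and the weighted update then produces
# a fp32 hidden state even though h is bf16
import torch

B, T, D = 2, 8, 16
h = torch.randn(B, T, D, dtype=torch.bfloat16)
p = torch.rand(B, T, dtype=torch.bfloat16)

cumulative_p = torch.zeros(B, T)                              # fp32 (default dtype)
still_running = cumulative_p < 0.99
remainder = 1.0 - cumulative_p                                # fp32
weight = torch.where(cumulative_p + p > 0.99, remainder, p)   # promoted to fp32
weight = weight * still_running.float()                       # still fp32

h_out = weight.unsqueeze(-1) * h                              # fp32 * bf16 -> fp32
print(h_out.dtype)                                            # torch.float32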

(This is a distinct issue from the recently-landed ACT halting / weight-gating fix; that one was about which positions accumulate, this one is about what dtype the accumulation runs in.)

Fix

Thread the surrounding tensor's dtype through both sites:

  • _causal_mask(seq_len, device) → _causal_mask(seq_len, device, dtype=torch.float32); the caller passes dtype=x.dtype and the mask is built with torch.full(..., dtype=dtype).
  • cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
  • still_running.float() → still_running.to(h.dtype) (two sites; see the sketch below)

bf16 -inf is well-defined so the causal mask semantics are unchanged. ACT probabilities accumulating 4–16 sigmoids in bf16 stay well within bf16's precision range.
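For concreteness, a minimal self-contained sketch of the patched behaviour (helper body abbreviated, shapes made up; only the dtype handling mirrors the actual diff):

import torch

def _causal_mask(seq_len, device, dtype=torch.float32):
    # caller passes dtype=x.dtype; fp32 stays the default so the fp32 path is untouched
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"),
                      device=device, dtype=dtype)
    # ... rest of the mask construction unchanged ...
    return mask

x = torch.randn(1, 8, 16, dtype=torch.bfloat16)
assert _causal_mask(8, x.device, dtype=x.dtype).dtype == torch.bfloat16

# ACT accumulator allocated in h.dtype, .float() replaced by .to(h.dtype)
B, T = 1, 8
h = torch.randn(B, T, 16, dtype=torch.bfloat16)
p = torch.rand(B, T, dtype=torch.bfloat16)
cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
still_running = cumulative_p < 0.99
weight = p * still_running.to(h.dtype)
h_out = weight.unsqueeze(-1) * h
assert h_out.dtype == torch.bfloat16      # no silent upcast any more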

Test plan

Verified end-to-end on a fresh checkout of latest main + this patch:

  1. smoke_1b.py (the mythos_1b config): forward + backward passes without errors in bf16 on H100 SXM; a condensed sketch of this check follows the list. Peak memory 5.5 GB, ρ(A) stable at 0.367.
  2. smoke_1b_train.py (3-step AdamW training loop): three real optimizer steps complete, loss decreases, ρ(A) unchanged.
  3. generate() with n_loops=8 produces outputs.
  4. Full 15,000-step bf16 training of a 117M MLA+MoE variant on FineWeb-Edu converges cleanly (loss 11 → 4.0) with ρ(A) stable at 0.357 throughout; no dtype errors across ~10h of training.
  5. fp32 path (default dtype=torch.float32) is unchanged — _causal_mask default kwarg is fp32, h.dtype returns fp32 in that case.
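Condensed, hypothetical version of the step-1 check (the real script is smoke_1b.py; it is assumed here that the forward output is a tensor that can be reduced to a scalar for backward):

# hypothetical condensed smoke check, not the actual smoke_1b.py
import torch
from open_mythos import OpenMythos
from open_mythos.variants import mythos_1b

cfg = mythos_1b()
model = OpenMythos(cfg).to("cuda", torch.bfloat16)
ids = torch.randint(0, cfg.vocab_size, (1, 64), device="cuda")

out = model(ids, n_loops=2)      # forward in bf16, no dtype crash with the patch
out.float().sum().backward()     # backward also completes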

Repro without the patch:

import torch
from open_mythos import OpenMythos
from open_mythos.variants import mythos_1b

cfg = mythos_1b(); cfg.max_loop_iters = 2
m = OpenMythos(cfg).to("cuda", torch.bfloat16)
ids = torch.randint(0, cfg.vocab_size, (1, 64), device="cuda")
m(ids, n_loops=2)   # -> RuntimeError: expected mat1 and mat2 to have the same dtype

Made with Cursor

Two small dtype-promotion bugs silently upcast the hidden state from
bf16 to fp32 partway through the forward pass, producing

    RuntimeError: expected mat1 and mat2 to have the same dtype,
                  but got: float != c10::BFloat16

at the next nn.Linear call. Both places omit an explicit `dtype=`
argument and fall back to torch's default float32, poisoning any
tensor they subsequently combine with.

1) `OpenMythos._causal_mask` builds its additive mask with
   `torch.full(..., float("-inf"), device=device)`, i.e. fp32. When the
   attention adds this mask to a bf16 `attn` tensor, `attn` becomes
   fp32, then the subsequent `torch.matmul(attn, v)` crashes because
   `v` is still bf16.

2) `RecurrentBlock.forward` allocates the ACT accumulator
   `cumulative_p = torch.zeros(B, T, device=h.device)` and uses
   `still_running.float()` in the weight update. Both are fp32 regardless
   of `h.dtype`, so `h_out = h_out + weight.unsqueeze(-1) * h` silently
   upcasts the returned hidden state to fp32. The next Coda layer then
   fails at `q_down(h)` for the same dtype-mismatch reason.

The fix threads `h.dtype` / `x.dtype` through both sites. With both
patches applied, `mythos_1b` and a custom 150M MLA/MoE variant train
end-to-end in bf16 on H100 / A40 with no dtype errors, ρ(A) stable at
0.357, and zero impact on fp32 behaviour.

Tested on the latest main (torch 2.8.0+cu128, H100 SXM).
