Fix dtype leaks that break bfloat16 training #26
Open
tonyzdev wants to merge 1 commit into kyegomez:main from
Conversation
Two small dtype-promotion bugs silently upcast the hidden state from
bf16 to fp32 partway through the forward pass, producing
RuntimeError: expected mat1 and mat2 to have the same dtype,
but got: float != c10::BFloat16
at the next nn.Linear call. Both places omit an explicit `dtype=`
argument and fall back to torch's default float32, poisoning any
tensor they subsequently combine with.
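This is just PyTorch's standard type-promotion behaviour at work; a tiny standalone illustration (not code from this repo):

```python
import torch

h = torch.zeros(4, dtype=torch.bfloat16)  # bf16 activation
mask = torch.zeros(4)                      # no dtype= -> default float32
print((h + mask).dtype)                    # torch.float32: the bf16 tensor is silently upcast
```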
1) `OpenMythos._causal_mask` builds its additive mask with
`torch.full(..., float("-inf"), device=device)`, i.e. fp32. When the
attention adds this mask to a bf16 `attn` tensor, `attn` becomes
fp32, then the subsequent `torch.matmul(attn, v)` crashes because
`v` is still bf16 (see the standalone sketch after this list).
2) `RecurrentBlock.forward` allocates the ACT accumulator
`cumulative_p = torch.zeros(B, T, device=h.device)` and uses
`still_running.float()` in the weight update. Both are fp32 regardless
of `h.dtype`, so `h_out = h_out + weight.unsqueeze(-1) * h` silently
upcasts the returned hidden state to fp32. The next Coda layer then
fails at `q_down(h)` for the same dtype-mismatch reason.
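A self-contained sketch of the first failure path, with illustrative shapes and names rather than the repo's actual attention code:

```python
import torch

B, H, T, D = 1, 2, 8, 16
q = torch.randn(B, H, T, D, dtype=torch.bfloat16)
k = torch.randn(B, H, T, D, dtype=torch.bfloat16)
v = torch.randn(B, H, T, D, dtype=torch.bfloat16)

# Buggy mask: no dtype= argument, so torch.full falls back to float32.
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

attn = torch.matmul(q, k.transpose(-2, -1)) / D**0.5  # bf16 scores
attn = attn + mask                                    # bf16 + fp32 -> fp32 (silent upcast)
attn = torch.softmax(attn, dim=-1)                    # still fp32

out = torch.matmul(attn, v)  # RuntimeError: dtype mismatch, fp32 attn vs bf16 v
```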
The fix threads `h.dtype` / `x.dtype` through both sites. With both
patches applied, `mythos_1b` and a custom 150M MLA/MoE variant train
end-to-end in bf16 on H100 / A40 with no dtype errors, ρ(A) stable at
0.357, and zero impact on fp32 behaviour.
Tested on the latest main (torch 2.8.0+cu128, H100 SXM).
Summary
Two places in `open_mythos/main.py` allocate tensors without an explicit `dtype=` argument and therefore fall back to fp32. When the model is loaded in bf16 (`model.to(torch.bfloat16)`), these fp32 tensors get combined with bf16 ones via PyTorch's type-promotion rules, silently upcasting intermediate activations to fp32. The upcast tensor is then fed into a bf16 `nn.Linear`, which throws the RuntimeError quoted above. This makes it impossible to train, or even run inference with, any variant (including `mythos_1b`, `mythos_3b`, etc.) in bf16 on the current `main`.

The two bugs
1. `_causal_mask` always builds fp32

`torch.full` uses the default dtype (fp32) when none is given. In `GQAttention` / `MLAttention` the crash happens on `out = torch.matmul(attn, v)` at `main.py:388` (MLA) / `main.py:246` (GQA) on the first prelude layer.

2. `RecurrentBlock.forward` ACT accumulator is fp32

`cumulative_p` starts in fp32; `still_running.float()` forces fp32; so `weight` becomes fp32 via `torch.where(cond, remainder_fp32, p_bf16)` promotion, then `weight.unsqueeze(-1) * h` is fp32 * bf16 → fp32. The returned `h_out` is fp32, and the following Coda layer crashes at `q_down(h)` for the same reason as above. (A minimal sketch of this promotion chain follows the note below.)

(This is a distinct issue from the recently landed ACT halting / weight-gating fix; that one was about which positions accumulate, this one is about what dtype the accumulation runs in.)
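A self-contained sketch of that promotion chain, assuming the variable names from the description above; the actual `RecurrentBlock.forward` body is not reproduced here:

```python
import torch

# Illustrative only: mirrors the dtype flow described above, not the repo's code.
B, T, D = 2, 8, 16
h = torch.randn(B, T, D, dtype=torch.bfloat16)
h_out = torch.zeros_like(h)                                 # bf16
p = torch.sigmoid(torch.randn(B, T, dtype=torch.bfloat16))  # per-position halting prob, bf16

cumulative_p = torch.zeros(B, T)                            # BUG: no dtype= -> float32
still_running = cumulative_p < 1.0
remainder = 1.0 - cumulative_p                              # fp32, inherited from cumulative_p
weight = torch.where(~still_running, remainder, p)          # fp32/bf16 mix promotes weight to fp32
cumulative_p = cumulative_p + p * still_running.float()     # accumulator stays fp32

h_out = h_out + weight.unsqueeze(-1) * h                    # fp32 * bf16 -> fp32
print(h_out.dtype)                                          # torch.float32, no longer bf16
```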
Fix
Thread the surrounding tensor's dtype through both sites:
- `_causal_mask(seq_len, device)` → `_causal_mask(seq_len, device, dtype=torch.float32)`; the caller passes `dtype=x.dtype`.
- `torch.full(..., dtype=dtype)`.
- `cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)`
- `still_running.float()` → `still_running.to(h.dtype)` (two sites)

bf16 `-inf` is well-defined, so the causal mask semantics are unchanged. ACT probabilities accumulating 4–16 sigmoids in bf16 stay well within bf16's precision range.
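Put together, the shape of the change looks roughly like this (the helper body is paraphrased from the bullets above, not copied from `main.py`):

```python
import torch

def _causal_mask(seq_len: int, device, dtype=torch.float32):
    # Build the additive causal mask in the caller's dtype instead of the fp32 default.
    mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype)
    return torch.triu(mask, diagonal=1)

# Attention call sites pass the activation dtype explicitly:
#   mask = _causal_mask(T, x.device, dtype=x.dtype)
#
# RecurrentBlock.forward allocates and casts in the hidden state's dtype:
#   cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
#   ... still_running.to(h.dtype) instead of still_running.float() (both sites)
```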
Test plan
Verified end-to-end on a fresh checkout of latest `main` + this patch:

- `smoke_1b.py` (the `mythos_1b` config): forward + backward passes without errors in bf16 on H100 SXM. Peak memory 5.5 GB, ρ(A) stable at 0.367.
- `smoke_1b_train.py` (3-step AdamW training loop): three real optimizer steps complete, loss decreases, ρ(A) unchanged.
- `generate()` with `n_loops=8` produces outputs.
- The fp32 path (`dtype=torch.float32`) is unchanged: the `_causal_mask` default kwarg is fp32, and `h.dtype` returns fp32 in that case.

Repro without the patch:
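The repro and smoke scripts themselves aren't reproduced in this description; as a stand-in, a generic bf16 forward/backward check of the same shape would look roughly like the following (`build_model` is a placeholder for however the `mythos_1b` config is instantiated, not the repo's actual API):

```python
import torch

def bf16_smoke_test(build_model, vocab_size=32000, seq_len=256):
    # Illustrative check only; the PR's actual smoke scripts are not shown here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_model().to(device).to(torch.bfloat16)

    tokens = torch.randint(0, vocab_size, (2, seq_len), device=device)
    logits = model(tokens)                   # fails mid-forward on unpatched main, per the report above
    assert logits.dtype == torch.bfloat16    # no silent fp32 upcast anywhere in the stack

    loss = logits.float().mean()             # reduce in fp32 for a stable scalar
    loss.backward()                          # backward also stays dtype-consistent
```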
Made with Cursor