
Remediation: router-bias wiring, EOS packing, numerical stability, trainer hardening, tests #41

Open
supernavyl wants to merge 6 commits into kyegomez:main from supernavyl:fixes/scan-remediation

Conversation

@supernavyl

Summary

Comprehensive remediation of issues found during a multi-dimensional audit of the codebase. Six logical commits that can be reviewed independently:

  1. fix(moe+attn) — vectorized dispatch, wired router bias, numerical stability
  2. fix(tokenizer) — correct vocab sizing + EOS-aware encode_with_eos
  3. refactor — move moda.py to open_mythos.experimental
  4. feat(training) — harden FineWeb-Edu pretraining script
  5. chore(deps) — reconcile pyproject / requirements.txt / training/requirements.txt
  6. test — cover config validation, router bias update, vectorized dispatch

Headline fixes

1. router_bias was silently inert (correctness bug)

Upstream defines MoEFFN.router_bias as a buffer of zeros, but no code path ever updates it — the advertised DeepSeek-V3 aux-loss-free load balancing was not happening. This PR adds:

  • MoEFFN.expert_load buffer + bincount accumulation on each training forward.
  • MoEFFN.update_bias(speed) with a sign-based update (magnitude bounded by speed regardless of imbalance, per DeepSeek-V3 Eq. 16). Flushes load after each call.
  • OpenMythos.update_router_biases(speed=None, ddp=False) that walks every MoE layer and all-reduces load across FSDP ranks before updating.
  • Trainer call site: model.update_router_biases(ddp=ddp) after every successful optimizer.step().
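
A minimal sketch of the expert_load bookkeeping and the sign-based update_bias(speed) described above, as a stripped-down module that carries only the two buffers (the real MoEFFN wraps the router and expert layers around this; names here are illustrative):

```python
import torch
import torch.nn as nn


class MoEFFNBalancingSketch(nn.Module):
    """Illustrative fragment: DeepSeek-V3-style aux-loss-free balancing state."""

    def __init__(self, n_experts: int):
        super().__init__()
        # Non-persistent buffers: bias nudges routing scores, load counts slots per expert.
        self.register_buffer("router_bias", torch.zeros(n_experts), persistent=False)
        self.register_buffer("expert_load", torch.zeros(n_experts), persistent=False)

    @torch.no_grad()
    def accumulate_load(self, topk_idx: torch.Tensor) -> None:
        # Called on training forwards: bincount of how many slots each expert received.
        self.expert_load += torch.bincount(
            topk_idx.reshape(-1), minlength=self.expert_load.numel()
        ).to(self.expert_load.dtype)

    @torch.no_grad()
    def update_bias(self, speed: float) -> None:
        # Overloaded experts get pushed down, underloaded experts up. sign() bounds the
        # per-step delta magnitude by `speed` regardless of how skewed the load is.
        error = self.expert_load.mean() - self.expert_load
        self.router_bias += speed * torch.sign(error)
        self.expert_load.zero_()  # flush after each call
```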

2. No EOS between packed documents (training-quality bug)

The FineWebEduDataset packer used encode() which does not append EOS, so concatenated documents flowed into each other with no boundary marker — the model learned spurious cross-document attention.

  • New MythosTokenizer.encode_with_eos() appends EOS (with eos -> bos -> all_special_ids[0] -> None fallback).
  • Dataset now uses it.
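
A minimal sketch of the EOS fallback chain and the packer-facing encode path, assuming a Hugging Face tokenizer object (helper signatures here are illustrative, not the exact MythosTokenizer API):

```python
from typing import List, Optional


def resolve_eos_id(tok) -> Optional[int]:
    """Fallback chain from above: eos -> bos -> first special id -> None."""
    if tok.eos_token_id is not None:
        return tok.eos_token_id
    if tok.bos_token_id is not None:
        return tok.bos_token_id
    if tok.all_special_ids:
        return tok.all_special_ids[0]
    return None


def encode_with_eos(tok, text: str) -> List[int]:
    """Append the resolved EOS id so packed documents get an explicit boundary."""
    ids = tok.encode(text, add_special_tokens=False)
    eos = resolve_eos_id(tok)
    if ids and eos is not None:  # empty input stays empty: no lone EOS emitted
        ids.append(eos)
    return ids
```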

3. Vectorized MoE dispatch (performance)

The naive per-token loop was O(topk * n_experts) Python iterations per forward pass. Replaced with a stable argsort + bincount + index_add_ pipeline — O(n_experts) dense ops. Semantic equivalence verified against a reference loop implementation at 1e-5 tolerance (see tests/test_moe_router.py::test_vectorized_matches_naive_dispatch).
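
A sketch of that pipeline under simplified assumptions (experts passed as a list of callables, top-k indices and gate weights already computed); the PR's actual version lives inside MoEFFN and additionally accumulates expert_load:

```python
import torch


def dispatch_vectorized(x, topk_idx, topk_w, experts):
    """Grouped MoE dispatch via stable argsort + bincount + index_add_ (illustrative).

    x:        (N, D) token activations
    topk_idx: (N, K) selected expert ids per token
    topk_w:   (N, K) gate weights per token
    experts:  list of callables, each mapping (M, D) -> (M, D)
    """
    N, K = topk_idx.shape
    flat_idx = topk_idx.reshape(-1)                        # (N*K,) expert id per slot
    flat_w = topk_w.reshape(-1, 1)                         # (N*K, 1) gate weight per slot
    token_of = torch.arange(N, device=x.device).repeat_interleave(K)

    # Stable sort groups every slot routed to the same expert contiguously.
    order = torch.argsort(flat_idx, stable=True)
    counts = torch.bincount(flat_idx, minlength=len(experts))

    out = torch.zeros_like(x)
    start = 0
    for e, cnt in enumerate(counts.tolist()):              # O(n_experts) dense ops
        if cnt == 0:
            continue
        slots = order[start:start + cnt]
        start += cnt
        toks = token_of[slots]
        y = experts[e](x[toks]) * flat_w[slots]             # one batched matmul per expert
        out.index_add_(0, toks, y)                          # scatter-add back per token
    return out
```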

4. Numerical stability

  • Softmax upcast to fp32 in GQAttention, MLAttention, and MoEFFN routers. bf16/fp16 softmax quantizes the tail and collapses attention to one-hot or uniform at long context.
  • clamp_min(1e-9) on the gate-weight renormalization denominator so a fully underflowed set of topk scores does not divide by zero.
  • LTI clamp tightened: log_dt + log_A clamp lower bound raised from -20 to -10. At -20, exp(-exp(-20)) rounds to exactly 1.0 in fp32 and breaks the strict ρ(A) < 1 invariant. At -10 it saturates at ~1 - 4.5e-5, safely below 1 in fp32.
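
Minimal sketches of the first two points (function names are illustrative):

```python
import torch


def softmax_fp32(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Upcast under bf16/fp16 autocast, then cast back: low-precision softmax
    # quantizes the tail of the distribution at long context.
    return torch.softmax(scores.float(), dim=dim).to(scores.dtype)


def renorm_gate_weights(topk_scores: torch.Tensor) -> torch.Tensor:
    # clamp_min keeps a fully underflowed top-k row from dividing by zero.
    denom = topk_scores.sum(dim=-1, keepdim=True).clamp_min(1e-9)
    return topk_scores / denom
```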

5. vocab_size trap (correctness bug)

MythosTokenizer.vocab_size was self.tokenizer.vocab_size, which for HF tokenizers is the base vocab excluding added specials. Any added special token would index past the model's embedding matrix and trigger a CUDA device-side assert deep into pretraining.

  • Now returns len(self.tokenizer), rounded up to vocab_multiple_of (default 128 for tensor-core-friendly widths).
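
A sketch of the sizing rule, assuming an HF tokenizer where len(tokenizer) counts the base vocab plus added specials:

```python
def padded_vocab_size(tokenizer, vocab_multiple_of: int = 128) -> int:
    # Round len(tokenizer) up to the next multiple for tensor-core-friendly widths;
    # any nn.Embedding sized from this value cannot be indexed out of range.
    n = len(tokenizer)
    return ((n + vocab_multiple_of - 1) // vocab_multiple_of) * vocab_multiple_of
```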

6. Config validation at construction time

MythosConfig.__post_init__ now validates 12+ invariants (attn_type enum, dim/n_heads divisibility, GQA grouping, MLA rope even, loop_dim even, MoE sizing, ACT/dropout ranges, LoRA positivity, vocab_size/max_seq_len positivity). Bad configs fail at MythosConfig(...) instead of mid-step hours into a run.
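
An illustrative __post_init__ in the same spirit, over a hypothetical reduced field set (not the real MythosConfig):

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    dim: int = 512
    n_heads: int = 8
    vocab_size: int = 50304
    dropout: float = 0.0

    def __post_init__(self):
        # Fail at construction time, not hours into a run.
        if self.dim % self.n_heads != 0:
            raise ValueError(f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})")
        if self.vocab_size <= 0:
            raise ValueError("vocab_size must be positive")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")
```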

7. Trainer hardening for long runs

  • Per-microstep NaN/Inf loss guard (poisoning Adam moments is unrecoverable).
  • Post-clip non-finite grad_norm guard.
  • ShardedGradScaler on the fp16 path.
  • Full RNG seeding (python/numpy/torch/cuda), persisted through the checkpoint.
  • Checkpoint also persists scaler state, torch/cuda versions, seed.
  • SIGTERM / SIGINT cooperative shutdown with final atomic save, then exit 130.
  • loguru rotating file sink (100 MB / 7-day retention, gz compressed, per-rank).
  • On-device loss accumulation (single .item() per log window).
  • Directory fsync after atomic rename for power-loss durability.
  • cfg.__post_init__() re-run after mutating vocab_size / max_seq_len.

8. moda.py moved to open_mythos.experimental

MoDA (Mixture-of-Depths Attention) is a parallel research line, not part of the canonical Prelude/Recurrent/Coda model. Subpackage boundary makes stability guarantees explicit. Commented-out smoke-test block deleted (dead scaffold, not a test).

9. Defensive freqs_cis[:T] slice in attention modules

Standalone callers (tests, ad-hoc scripts) no longer have to pre-slice the full precomputed RoPE buffer to match the current T. Fixes a pre-existing test_main.py::TestGQAttention::test_output_shape crash.

Test plan

  • pytest test_main.py tests/: 103 passed (0 failed).
  • test_moe_router.py::test_vectorized_matches_naive_dispatch verifies the fast dispatch path matches a reference loop at 1e-5.
  • test_moe_router.py::test_update_bias_magnitude_bounded_by_speed verifies the sign-based update stays bounded by speed even under a 1M-token load imbalance.
  • test_config_validation.py covers 11 distinct __post_init__ guards.
  • test_tokenizer.py::test_encode_ids_within_vocab — embedding-safety invariant.
  • python -m py_compile on every changed file.
  • Multi-node FSDP smoke test — not run (no cluster access on the author's side).
  • Full pretraining run — not run (30B-token scale).

Breaking changes

  • MythosConfig now validates in __post_init__. Configs that relied on silently invalid hyperparameter combinations will now raise ValueError at construction.
  • MythosTokenizer.vocab_size may return a larger value than before (base vocab + added specials, rounded to 128). Any code that hardcoded the old value will need to rebuild its embedding matrix.
  • open_mythos.moda import path → open_mythos.experimental.moda (or from open_mythos.experimental import MoDAModel). The old import path no longer works.

Honest caveats

  • Tests run CPU-only. I did not have GPU access for a real FSDP run, so the trainer changes are validated by static review + unit tests on the model code, not by a live pretraining run.
  • The vectorized MoE dispatch is verified equivalent at 1e-5 on small shapes in eval mode; at train time expert_load bookkeeping differs (intentionally, since the fast path accumulates it and the naive reference does not). That's tested separately in test_expert_load_accumulates_on_forward.

…ility

MoEFFN
------
- Vectorized routed-expert dispatch via stable argsort + bincount offsets +
  index_add_. Replaces the O(topk * n_experts) Python loop with one matmul
  per expert (O(n_experts) dense ops). Preserves exact semantics - tested
  against a naive reference loop with 1e-5 tolerance.
- Added router_bias (non-persistent buffer) and expert_load (non-persistent
  buffer) to implement DeepSeek-V3 aux-loss-free load balancing. Upstream
  had router_bias but nothing ever updated it - balancing was silently inert.
- update_bias(speed) applies a sign-based update so per-step delta magnitude
  is bounded by speed regardless of how skewed the load is. Flushes
  expert_load after each call.
- Softmax upcast to fp32 inside fp16/bf16 autocast; clamp_min(1e-9) on the
  gate-weight renormalization denominator prevents division by underflow.

OpenMythos
----------
- New update_router_biases(speed=None, ddp=False) method walks every MoEFFN
  submodule, optionally all-reduces expert_load across ranks, and applies
  the bias update. Must be called AFTER optimizer.step().

GQAttention / MLAttention
-------------------------
- Softmax upcast to fp32 before cast back to attention dtype. Long-sequence
  bf16 softmax quantizes the tail and collapses attention toward one-hot or
  uniform.
- Defensive freqs_cis[:T] slice so standalone callers do not have to
  pre-slice before passing the full precomputed RoPE buffer.

LTIInjection
------------
- Lower clamp tightened from -20 to -10 in get_A(). At -20, exp(-exp(-20))
  rounds to exactly 1.0 in float32, breaking the strict spectral-radius<1
  guarantee under adversarial gradient steps. -10 gives a 4.5e-5 margin
  below 1.0 that comfortably survives fp32 rounding.
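
A quick fp32 check of the two bounds (illustrative):

```python
import torch

for lo in (-20.0, -10.0):
    a = torch.exp(-torch.exp(torch.tensor(lo, dtype=torch.float32)))
    print(f"clamp {lo}: A = {a.item():.7f}, strictly < 1: {bool(a < 1.0)}")
# clamp -20.0: A = 1.0000000, strictly < 1: False  (rounds onto the boundary)
# clamp -10.0: A = 0.9999546, strictly < 1: True
```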

MythosConfig
------------
- __post_init__ now validates every hyperparameter at construction time.
  Bad configs fail now instead of mid-step in a pretraining run.
- Added fields: bias_update_speed, loop_rope_theta, lti_b_init, init_std
  with sensible defaults.
- _init_weights uses cfg.init_std; router init scaled by 0.1.

MythosTokenizer changes that the trainer and model both depend on:

- vocab_size now returns len(self.tokenizer) (base vocab + added specials),
  rounded up to vocab_multiple_of (default 128 for tensor-core-friendly
  embedding widths). HF's tokenizer.vocab_size silently excludes added
  specials, so a token in that excluded range caused a CUDA device-side
  assert deep into pretraining. Any nn.Embedding sized from the new property
  cannot index out of range.
- eos_token_id property with a fallback chain: eos -> bos -> all_special_ids[0]
  -> None. Used by the FineWeb-Edu packer to inject an explicit boundary
  token between concatenated documents, so the model never sees cross-document
  attention without a marker.
- encode() now silently rejects None and non-str inputs (returns []) and
  truncates at MAX_CHARS_PER_DOC = 4_000_000 before tokenizing. FineWeb-Edu
  has pathological outliers that stalled DataLoader workers and OOM'd the
  tokenizer.
- encode_with_eos() method appends the EOS id when defined. Intended for
  the document packer path.
- trust_remote_code=False pinned explicitly so future transformers versions
  cannot silently start running remote Python.

MoDA (Mixture-of-Depths Attention) is a parallel research-line architecture,
not part of the canonical OpenMythos Prelude/Recurrent/Coda model. Moving
it to open_mythos.experimental/ makes that boundary explicit:

- Public API at the package root stays the canonical architecture.
- Experimental components (MoDAConfig, MoDAModel, MoDAAttention, DeepSeekMoE,
  DeepSeekGate, DeepSeekExpert, RMSNorm, RotaryEmbedding) are importable
  from open_mythos.experimental with a loud docstring stating no stability
  guarantees.
- The commented-out smoke test block at the bottom of moda.py is deleted -
  it was a dead debug scaffold, not a test, and encouraging smoke tests to
  live as commented-out __main__ blocks teaches the wrong pattern.

Comprehensive hardening of training/3b_fine_web_edu.py for long multi-day
FSDP runs where the cost of a crash at step 50k is days of wasted compute.

Correctness and numerics
------------------------
- Per-microstep NaN/Inf loss guard: non-finite micro-losses are skipped
  (no backward), so one bad sample cannot poison Adam moment buffers.
  If every microstep in the accumulation window is non-finite, the whole
  optimizer.step() is skipped but the step counter still ticks so LR
  schedule stays monotonic.
- Non-finite grad_norm guard after clipping (ShardedGradScaler handles
  this for the fp16 path but we enforce it uniformly for bf16 too).
- Aux-loss-free load balancing is finally driven: after every successful
  optimizer.step() we call model.update_router_biases(ddp=ddp), which
  all-reduces expert_load across ranks and applies the DeepSeek-V3 bias
  update. Without this call the balancing mechanism was silently inert.
- EOS injection: tokenization uses encoding.encode_with_eos() so packed
  documents get a boundary token instead of flowing into each other.
- Micro-batch loss accumulated on-device; single .item() per logging
  window instead of every microstep.
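
An illustrative version of the guarded accumulation step (the function name and signature are hypothetical, not the trainer's actual code):

```python
import torch


def guarded_step(model, optimizer, micro_losses, ddp: bool = False):
    """micro_losses: iterable of already-scaled per-microbatch loss tensors."""
    window_loss = torch.zeros((), device=next(model.parameters()).device)
    finite = 0
    for loss in micro_losses:
        if not torch.isfinite(loss):
            continue                          # no backward: don't poison Adam moments
        loss.backward()
        window_loss += loss.detach()          # stays on device until log time
        finite += 1

    if finite > 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if torch.isfinite(grad_norm):         # post-clip non-finite guard
            optimizer.step()
            if hasattr(model, "update_router_biases"):
                model.update_router_biases(ddp=ddp)  # drive the DeepSeek-V3 bias update
    optimizer.zero_grad(set_to_none=True)
    return window_loss                        # caller calls .item() once per log window
```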

Mixed precision
---------------
- ShardedGradScaler wired up on the fp16 path (Volta/Pascal). bf16 path
  runs with FSDP MixedPrecision and no scaler, which is the officially
  supported combination.
- Scaler state round-trips through checkpoints.

Reproducibility
---------------
- All RNGs seeded (python / numpy / torch / cuda) with per-rank offset
  for in-process uniqueness. Seed persists through the checkpoint so a
  resume on a different node draws the same data stream (given the
  shard is still at the same position).
- Checkpoint carries RNG state, scaler state, torch and cuda versions.
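
A minimal sketch of the per-rank seeding (helper name hypothetical):

```python
import random

import numpy as np
import torch


def seed_everything(seed: int, rank: int = 0) -> int:
    effective = seed + rank                # per-rank offset for in-process uniqueness
    random.seed(effective)
    np.random.seed(effective)
    torch.manual_seed(effective)
    torch.cuda.manual_seed_all(effective)  # silently ignored when CUDA is absent
    return effective                       # persist this in the checkpoint for resume
```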

Graceful shutdown
-----------------
- SIGTERM / SIGINT handler marks a cooperative shutdown flag. Main loop
  polls it between microbatches, breaks cleanly, writes a final atomic
  checkpoint, barriers, and exits 130. A second signal falls through to
  default handling so a stuck rank can always be force-killed.
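
A sketch of the two-stage handling with a module-level flag (names hypothetical):

```python
import signal

_shutdown_requested = False


def _request_shutdown(signum, frame):
    global _shutdown_requested
    if _shutdown_requested:
        # Second signal: restore the default action and re-raise so a stuck
        # rank can always be force-killed.
        signal.signal(signum, signal.SIG_DFL)
        signal.raise_signal(signum)
        return
    _shutdown_requested = True         # main loop polls this between microbatches


signal.signal(signal.SIGTERM, _request_shutdown)
signal.signal(signal.SIGINT, _request_shutdown)
```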

Logging
-------
- loguru rotating file sink: 100 MB per file, 7-day retention, gz
  compressed, per-rank file. Non-master ranks silence stderr to avoid
  interleaving chaos but still log to file for post-mortem. Main rank
  keeps the default stderr sink.
- Exception path captures tracebacks via logger.exception; final save
  runs in a finally block so a crash still writes the latest state.

Misc
----
- cfg.__post_init__() re-run after mutating vocab_size and max_seq_len
  so an operator who edits them at the CLI gets a clean error early.
- Directory fsync after atomic rename so the checkpoint is durable
  across power loss.
- persistent_workers=True on the DataLoader so workers survive between
  epoch boundaries instead of respawning and re-opening the stream.
- zero_grad(set_to_none=True) to decouple grad lifetime from param
  lifetime under FSDP.
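
A sketch of the durable atomic-rename sequence (POSIX; helper name hypothetical):

```python
import os


def atomic_replace_durable(tmp_path: str, final_path: str) -> None:
    os.replace(tmp_path, final_path)   # atomic within one filesystem
    # fsync the parent directory so the rename itself survives power loss.
    dir_fd = os.open(os.path.dirname(final_path) or ".", os.O_DIRECTORY)
    try:
        os.fsync(dir_fd)
    finally:
        os.close(dir_fd)
```
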
…ents.txt

Three dep manifests existed with conflicting constraints (torch "2.11.0"
exact vs >=2.1.0 vs >=2.11.0) and missing entries (numpy / loguru used by
the trainer but declared nowhere).

pyproject.toml (Poetry, library-facing)
  - torch >=2.3.0,<3.0.0 (floor set by ShardedGradScaler import path +
    torch.amp.autocast device_type= signature)
  - transformers >=4.40.0,<5.0.0
  - datasets >=2.18.0,<4.0.0
  - New [tool.poetry.group.training] group with numpy and loguru for
    users who want the pretraining scripts but not a minimal inference
    install.

requirements.txt (inference / library use)
  - Same ranges as pyproject but in pip-compatible syntax.

training/requirements.txt (pretraining runs)
  - Exact pins (torch==2.11.0, transformers==4.46.3, datasets==3.2.0,
    loguru==0.7.3) for node-to-node reproducibility.
  - Includes the CUDA 12.4 wheel index.
  - Documents: when bumping torch here, bump pyproject too and confirm
    the FSDP / autocast APIs the trainer uses still exist.

New pytest modules under tests/ that pin the behavior introduced in the
accompanying fix commits so regressions surface in CI instead of at
step 50k of pretraining.

tests/test_config_validation.py (11 tests)
  - Every MythosConfig.__post_init__ guard exercised one axis at a time.
  - Baseline config is a known-good fixture other tests can override.

tests/test_moe_router.py (9 tests)
  - router_bias and expert_load are buffers, not Parameters.
  - expert_load accumulates on training forward; not on eval forward.
  - update_bias shifts bias toward underused experts and away from
    overused ones.
  - update_bias flushes expert_load.
  - update_bias with speed=0 is a no-op on bias, still flushes load.
  - Sign-based update magnitude is bounded by `speed` even for massive
    imbalance (spec invariant).
  - OpenMythos.update_router_biases walks every MoE layer.
  - Vectorized dispatch matches a naive per-token loop at 1e-5 tolerance -
    this is the safety net for the argsort + index_add_ optimization.

tests/test_tokenizer.py (rewritten, 16 tests)
  - vocab_size >= len(tokenizer) invariant.
  - vocab_size rounded to multiple of 128 by default; configurable.
  - encode rejects None/non-str; truncates oversized inputs at max_chars.
  - encode_with_eos appends EOS when defined, plain encode otherwise.
  - encode_with_eos on empty returns empty (no lone EOS emitted).
  - Every emitted id is < vocab_size (embedding safety invariant).