23 changes: 9 additions & 14 deletions open_mythos/main.py
@@ -381,11 +381,7 @@ def forward(
         k_rope = kv_raw[..., self.kv_lora_rank :]  # (B, T, rope_dim)
         # expand rope keys across heads and apply RoPE before caching so
         # retrieved keys are already positionally encoded
-        k_rope = (
-            k_rope.unsqueeze(2)
-            .expand(B, T, self.n_heads, self.qk_rope_dim)
-            .contiguous()
-        )
+        k_rope = k_rope.unsqueeze(2).repeat(1, 1, self.n_heads, 1)
         k_rope = apply_rope(k_rope, freqs_cis)  # (B, T, H, rope_dim) ← cached

         if kv_cache is not None:
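Note on the hunk above: the two formulations are numerically identical. `expand` returns a zero-copy view with stride 0 over the head axis, so the old code needed `.contiguous()` before the result could be written through RoPE and cached, while `repeat` allocates the copy in a single call. A standalone sketch, with illustrative shapes that are not taken from the model config:

```python
import torch

# Illustrative sizes; the real B, T, n_heads, rope_dim come from the config.
B, T, n_heads, rope_dim = 2, 4, 8, 16
k_rope = torch.randn(B, T, rope_dim)

# Old path: expand() is a zero-copy view (stride 0 on the head axis),
# so .contiguous() must materialize real memory before caching.
old = k_rope.unsqueeze(2).expand(B, T, n_heads, rope_dim).contiguous()

# New path: repeat() allocates and copies in one call.
new = k_rope.unsqueeze(2).repeat(1, 1, n_heads, 1)

assert torch.equal(old, new)  # same values, same (B, T, n_heads, rope_dim) shape
```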
@@ -517,14 +513,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # routed expert dispatch (token-level scatter)
         out = torch.zeros_like(flat)
-        for i in range(self.topk):
-            expert_ids = topk_idx[:, i]
-            token_scores = topk_scores[:, i].unsqueeze(-1)
-            for eid in range(self.n_experts):
-                mask = expert_ids == eid
-                if not mask.any():
-                    continue
-                out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask])
+        # More efficient dispatch
+        for i in range(self.topk):
+            expert_ids = topk_idx[:, i]
+            for eid in range(self.n_experts):
+                mask = expert_ids == eid
+                if mask.any():
+                    out[mask] += topk_scores[mask, i].unsqueeze(-1) * self.routed_experts[eid](flat[mask])
+        # Still O(n_experts * topk) but with fewer Python overheads

         # shared experts always fire for every token
         for shared in self.shared_experts:
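For comparison (not part of this PR): the same routed dispatch can drop the `topk × n_experts` double loop by flattening the (token, slot) pairs and accumulating with `index_add_`. A self-contained sketch with made-up sizes; `experts` stands in for `self.routed_experts`:

```python
import torch
import torch.nn as nn

# Made-up sizes for the demo; the real values come from the model config.
N, dim, n_experts, topk = 10, 32, 4, 2
flat = torch.randn(N, dim)
experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
topk_scores = torch.rand(N, topk)
topk_idx = torch.randint(0, n_experts, (N, topk))

out = torch.zeros_like(flat)
slot_expert = topk_idx.reshape(-1)                    # expert id per (token, slot)
slot_token = torch.arange(N).repeat_interleave(topk)  # owning token per slot
slot_score = topk_scores.reshape(-1, 1)               # (N*topk, 1)

for eid in range(n_experts):  # one pass over experts instead of topk passes
    slots = (slot_expert == eid).nonzero(as_tuple=True)[0]
    if slots.numel() == 0:
        continue
    toks = slot_token[slots]
    out.index_add_(0, toks, slot_score[slots] * experts[eid](flat[toks]))
```

`index_add_` accumulates over duplicate token indices, so a token whose top-k slots hit the same expert twice still receives both weighted contributions.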
@@ -821,7 +817,6 @@ def __init__(self, cfg: MythosConfig):
         self.loop_dim = (
             cfg.dim // 8
         )  # fraction of channels receiving loop-index embedding
-
     def forward(
         self,
         h: torch.Tensor,
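The `loop_dim` fragment above indicates that only the first `cfg.dim // 8` channels are tagged with a loop-index embedding. Since the diff shows only the `__init__` side, the following is a guess at the mechanism; `max_loops`, the shapes, and the additive write are all assumptions:

```python
import torch
import torch.nn as nn

# Hypothetical sketch; max_loops and the additive injection are assumptions,
# as the PR does not show the forward body that consumes loop_dim.
dim, max_loops = 64, 8
loop_dim = dim // 8  # fraction of channels receiving the loop-index embedding
loop_emb = nn.Embedding(max_loops, loop_dim)

h = torch.randn(2, 4, dim)               # (B, T, dim) hidden states
loop_idx = torch.tensor(3)               # current recurrence step
h[..., :loop_dim] += loop_emb(loop_idx)  # tag only the first loop_dim channels
```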