diff --git a/moba/moba_efficient.py b/moba/moba_efficient.py index 2eec1ba..2767777 100644 --- a/moba/moba_efficient.py +++ b/moba/moba_efficient.py @@ -314,14 +314,14 @@ def moba_attn_varlen( moba_topk = min(moba_topk - 1, num_filtered_chunk) need_moba_attn = moba_topk > 0 + self_attn_cu_seqlen = cu_chunk + # corner case: if no moba attn needed, just return self attn if not need_moba_attn: return flash_attn_varlen_func( - q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True + q, k, v, self_attn_cu_seqlen, self_attn_cu_seqlen, max_seqlen, max_seqlen, causal=True ) - self_attn_cu_seqlen = cu_chunk - # filtered_kv is a dense matrix that only contains filtered chunk of kv filtered_kv_indices = torch.arange( 0, moba_chunk_size, dtype=torch.int32, device=q.device