@@ -47,29 +47,27 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
 
 __global__ void compute_expert_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
-    int32_t* atomic_buffer, const int num_experts, const int topk_length) {
+    int32_t* atomic_buffer, const int num_experts, const bool swap_ab) {
   int32_t tot_offset = 0;
   expert_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
     atomic_buffer[i] = tot_offset;
-    tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
-                                                  : problem_sizes1[i * 3 + 1];
+    tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
     expert_offsets[i + 1] = tot_offset;
   }
 }
 
 __global__ void compute_expert_blockscale_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
     int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
-    const int topk_length) {
+    const bool swap_ab) {
   int32_t tot_offset = 0;
   int32_t tot_offset_round = 0;
   expert_offsets[0] = 0;
   blockscale_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
-    int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
-                             ? problem_sizes1[i * 3]
-                             : problem_sizes1[i * 3 + 1];
+    int32_t cur_offset =
+        swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
     atomic_buffer[i] = tot_offset;
     tot_offset += cur_offset;
     expert_offsets[i + 1] = tot_offset;
@@ -119,15 +117,19 @@ void get_cutlass_moe_mm_data_caller(
 
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
 
-  if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+
+  if (may_swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
         static_cast<const int32_t*>(topk_ids.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
         k);
   } else {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
         static_cast<const int32_t*>(topk_ids.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
@@ -136,18 +138,19 @@ void get_cutlass_moe_mm_data_caller(
   }
 
   if (blockscale_offsets.has_value()) {
+    // fp4 path
     compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
-        topk_ids.numel());
+        may_swap_ab);
   } else {
     compute_expert_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
-        may_swap_ab);
+        may_swap_ab);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
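Note on the index change in the offsets kernels: with swap-AB, the A and B operands of each per-expert GEMM trade places, so the per-expert token count ends up in slot 1 of the problem-size triple instead of slot 0 (this assumes the swapped layout stores (N, M, K) rather than (M, N, K); that layout is inferred from the index change, not stated in the diff). The host-only sketch below is not the repository's code; `expert_offsets_host` and the shapes are made up for illustration. It mirrors the prefix-sum loop of `compute_expert_offsets` and checks that both layouts yield the same token offsets.

```cpp
// Illustrative host-side sketch (not vLLM code): the prefix sum over
// per-expert token counts reads slot 1 when swap-AB wrote (N, M, K) triples
// and slot 0 when the plain (M, N, K) layout was written.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int32_t> expert_offsets_host(
    const std::vector<int32_t>& problem_sizes, int num_experts, bool swap_ab) {
  std::vector<int32_t> offsets(num_experts + 1, 0);
  int32_t tot_offset = 0;
  for (int i = 0; i < num_experts; ++i) {
    tot_offset += swap_ab ? problem_sizes[i * 3 + 1] : problem_sizes[i * 3];
    offsets[i + 1] = tot_offset;
  }
  return offsets;
}

int main() {
  // Two experts with 5 and 3 routed tokens; N = 128, K = 256 (made-up shapes).
  std::vector<int32_t> plain   = {5, 128, 256, 3, 128, 256};   // (M, N, K)
  std::vector<int32_t> swapped = {128, 5, 256, 128, 3, 256};   // (N, M, K)

  // Both layouts must produce the same token offsets: [0, 5, 8].
  assert(expert_offsets_host(plain, 2, /*swap_ab=*/false) ==
         expert_offsets_host(swapped, 2, /*swap_ab=*/true));
  return 0;
}
```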