diff --git a/csrc/ops.h b/csrc/ops.h
index f8bdc61aaa8e..928d434a37b6 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -103,13 +103,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
                                  const torch::Tensor& output_mask,
                                  const torch::Tensor& repetition_penalties);
 
-void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
-                   const torch::Tensor& rowEnds, torch::Tensor& indices,
-                   int64_t numRows, int64_t stride0, int64_t stride1);
+void top_k_per_row_prefill(const torch::Tensor& logits,
+                           const torch::Tensor& rowStarts,
+                           const torch::Tensor& rowEnds, torch::Tensor& indices,
+                           int64_t numRows, int64_t stride0, int64_t stride1,
+                           int64_t topK);
 
 void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
-                          const torch::Tensor& seq_lens, torch::Tensor& indices,
-                          int64_t numRows, int64_t stride0, int64_t stride1);
+                          const torch::Tensor& seqLens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1,
+                          int64_t topK);
 
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
diff --git a/csrc/sampler.cu b/csrc/sampler.cu
index 410b8988f493..30d968027a69 100644
--- a/csrc/sampler.cu
+++ b/csrc/sampler.cu
@@ -44,41 +44,301 @@ __global__ void apply_repetition_penalties_kernel(
   }
 }
 
-static inline __device__ uint16_t extractBinIdx(float x) {
-  union {
-    __half h;
-    uint16_t u16;
-  } tmp;
-  tmp.h = __float2half_rn(x);
-  tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
-  return 511 - (tmp.u16 >> 7);
+__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
+  uint32_t bits = __float_as_uint(x);
+  return (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
 }
 
-template <int kTopK, int kNumBins, int kNumThreadsPerBlock>
-__device__ void topKPerRowJob(const float* logits, const int rowStart,
-                              const int rowEnd, const int rowIdx,
-                              int* outIndices, int stride0, int stride1) {
-  // The number of elements per thread for the final top-k sort.
-  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
-  // The class to sort the elements during the final top-k sort.
-  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
-                                       kNumTopKItemsPerThread, int>;
+template <int step>
+static inline __device__ uint32_t extractBinIdx(float x) {
+  if constexpr (step == 0) {
+    __half hx = __float2half(x);
+    uint16_t bits = __half_as_ushort(hx);
+    bits = (bits & 0x8000) ? bits : ~bits & 0x7fff;
+    return bits >> 5;
+  } else {
+    uint32_t bits = __float_as_uint(x);
+    bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
+
+    if constexpr (step == 1) {
+      return bits >> 21;
+    } else if constexpr (step == 2) {
+      return (bits >> 10) & 0x7ff;
+    } else if constexpr (step == 3) {
+      return bits & 0x3ff;
+    }
+  }
+}
+
+template <int shift>
+static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
+  if constexpr (shift == 0) {
+    return true;
+  }
+  uint32_t bits = __float_as_uint(x);
+  bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
+  return (bits ^ pattern) >> shift == 0;
+}
+
+/**
+ * Map a Func over the input data, using vectorized load instructions if
+ * possible.
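+ * Elements are read as full float4 words where pointer alignment allows it;
+ * a scalar prologue and epilogue cover the misaligned head and the tail of
+ * the input.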
+ *
+ * @tparam T element type
+ * @tparam idxT indexing type
+ * @tparam Func void (T x, idxT idx)
+ *
+ * @param thread_rank rank of the calling thread among all participating
+ * threads
+ * @param num_threads number of the threads that participate in processing
+ * @param in the input data
+ * @param len the number of elements to read
+ * @param f the lambda taking two arguments (T x, idxT idx)
+ */
+template <typename T, typename idxT, typename Func>
+__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
+                                   const T* in, idxT len, Func f) {
+  constexpr int WARP_SIZE = 32;
+  using WideT = float4;
+  if constexpr (sizeof(T) >= sizeof(WideT)) {
+    for (idxT i = thread_rank; i < len; i += num_threads) {
+      f(in[i], i);
+    }
+  } else {
+    static_assert(sizeof(WideT) % sizeof(T) == 0);
+    constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
+    // TODO: it's UB
+    union {
+      WideT scalar;
+      T array[items_per_scalar];
+    } wide;
+
+    int skip_cnt =
+        (reinterpret_cast<size_t>(in) % sizeof(WideT))
+            ? ((sizeof(WideT) - reinterpret_cast<size_t>(in) % sizeof(WideT)) /
+               sizeof(T))
+            : 0;
+    if (skip_cnt > len) {
+      skip_cnt = len;
+    }
+    const WideT* in_cast = reinterpret_cast<const WideT*>(in + skip_cnt);
+    const idxT len_cast = (len - skip_cnt) / items_per_scalar;
+
+    for (idxT i = thread_rank; i < len_cast; i += num_threads) {
+      wide.scalar = in_cast[i];
+      const idxT real_i = skip_cnt + i * items_per_scalar;
+#pragma unroll
+      for (int j = 0; j < items_per_scalar; ++j) {
+        f(wide.array[j], real_i + j);
+      }
+    }
+
+    static_assert(WARP_SIZE >= items_per_scalar);
+    // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
+    // no need to use loop
+    if (thread_rank < skip_cnt) {
+      f(in[thread_rank], thread_rank);
+    }
+    // because len_cast = (len - skip_cnt) / items_per_scalar,
+    // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
+    // and so
+    // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
+    // WARP_SIZE, no need to use loop
+    const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
+    if (remain_i < len) {
+      f(in[remain_i], remain_i);
+    }
+  }
+}
+
+template <int step, int kNumThreadsPerBlock, int kNumBins, int kNumFinalItems,
+          bool multipleBlocksPerRow, bool mergeBlocks, typename SmemOutputType,
+          typename SmemFinalType>
+__device__ bool processHistogramStep(
+    const int* indices, const float* logits, int rowEnd, uint32_t& logitPattern,
+    int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx,
+    int* smemFinalDstIdx, int* smemFinalBinSize, int* smemFoundTopKValues,
+    SmemFinalType& smemFinal, int stride1, int rowStart, int topK) {
+  // Clear the histogram.
+#pragma unroll
+  for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) {
+    smemFinal.histo.data[idx] = 0;
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Update pattern
+  constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10;
+  if constexpr (step == 2) {
+    logitPattern = static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
+                   << patternShift;
+  } else if constexpr (step == 3) {
+    logitPattern |= static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
+                    << patternShift;
+  }
+
+  auto distributeToBins = [&](float logit, int /* idx */ = 0) {
+    if (isPartialMatch<patternShift>(logit, logitPattern)) {
+      uint32_t binIdx = extractBinIdx<step>(logit);
+      atomicAdd(&smemFinal.histo.data[binIdx], 1);
+    }
+  };
+  // Distribute the elements to the histogram bins.
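+  // Rows that are contiguous in memory (stride1 == 1) can use the
+  // float4-vectorized path; strided rows fall back to one scalar load per
+  // element.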
+  if (stride1 == 1) {
+    vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
+                       rowEnd - rowStart, distributeToBins);
+  } else {
+    for (int idx = rowStart + threadIdx.x; idx < rowEnd;
+         idx += kNumThreadsPerBlock) {
+      float logit = logits[idx * stride1];
+      distributeToBins(logit, idx);
+    }
+  }
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Read the number of top-k values found by the previous steps; it is the
+  // starting position in the smemOutput array.
+  int lastValue = smemFoundTopKValues[0];
+
+  for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) {
+    // Read the values from SMEM.
+    int idx = threadIdx.x + kNumThreadsPerBlock * round;
+    int binCount = smemFinal.histo.data[idx];
+
+    // Make sure each thread has read its value.
+    __syncthreads();
+
+    // Compute the prefix sum.
+    int prefixSum{0}, totalSum{0};
+    using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
+    Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);
+
+    // Update the histogram with the prefix sums.
+    prefixSum += lastValue;
+    totalSum += lastValue;
+    smemFinal.histo.data[idx] = prefixSum;
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // Find the last valid bin.
+    bool foundThreshold = false;
+    if (prefixSum < topK) {
+      int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1
+                              ? totalSum
+                              : smemFinal.histo.data[idx + 1];
+
+      if (nextPrefixSum >= topK) {
+        smemThresholdBinIdx[0] = idx;
+        smemFinalBinSize[0] = nextPrefixSum - prefixSum;
+        foundThreshold = true;
+      }
+    }
+
+    // Early exit: if any thread found the threshold, we can skip the
+    // remaining rounds.
+    if (__syncthreads_or(foundThreshold)) {
+      break;
+    }
+
+    lastValue = totalSum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The threshold bin.
+  thresholdBinIdx = smemThresholdBinIdx[0];
+
+  auto processBins = [&](float logit, int idx) {
+    if (isPartialMatch<patternShift>(logit, logitPattern)) {
+      uint32_t binIdx = extractBinIdx<step>(logit);
+      if (binIdx < thresholdBinIdx) {
+        // The element is part of the top-k selection.
+        int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
+
+        if constexpr (mergeBlocks) {
+          smemOutput[dstIdx] = indices[idx];
+        } else if constexpr (multipleBlocksPerRow) {
+          smemOutput[dstIdx] = idx + rowStart;
+          reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
+        } else {
+          smemOutput[dstIdx] = idx;
+        }
+      }
+      if constexpr (step < 3) {
+        // Only fill the final items for sorting if the threshold bin fits.
+        if (binIdx == thresholdBinIdx &&
+            smemFinalBinSize[0] <= kNumFinalItems) {
+          int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
+          smemFinal.items.logits[dstIdx] = logit;
+          if constexpr (mergeBlocks) {
+            smemFinal.items.indices[dstIdx] = indices[idx];
+          } else if constexpr (multipleBlocksPerRow) {
+            smemFinal.items.indices[dstIdx] = idx + rowStart;
+          } else {
+            smemFinal.items.indices[dstIdx] = idx;
+          }
+        }
+      } else {
+        if (binIdx == thresholdBinIdx) {
+          // The elements in the threshold bin share the same 32 bits at step 3.
+          int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1);
+          if (dstIdx < topK) {
+            if constexpr (mergeBlocks) {
+              smemOutput[dstIdx] = indices[idx];
+            } else if constexpr (multipleBlocksPerRow) {
+              smemOutput[dstIdx] = idx + rowStart;
+              reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
+            } else {
+              smemOutput[dstIdx] = idx;
+            }
+          }
+        }
+      }
+    }
+  };
+
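+  // processBins emits elements from bins below the threshold bin straight to
+  // smemOutput. Threshold-bin elements are buffered for the final sort while
+  // the bin fits into kNumFinalItems (steps 0-2); at step 3 all 32 bits are
+  // resolved, so remaining ties are bit-identical and are written directly.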
+  if (stride1 == 1) {
+    vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
+                       rowEnd - rowStart, processBins);
+  } else {
+    for (int idx = rowStart + threadIdx.x; idx < rowEnd;
+         idx += kNumThreadsPerBlock) {
+      float logit = logits[idx * stride1];
+      processBins(logit, idx);
+    }
+  }
+
+  // Make sure the elements are in shared memory.
+  __syncthreads();
+
+  // Check if we should continue to the next step.
+  return smemFinalBinSize[0] > kNumFinalItems;
+}
+
+// Performs up to four histogram passes: 11 bits of the half representation,
+// then bits 31-21, 20-10 and 9-0 of the float representation.
+template <int kNumThreadsPerBlock, int kNumBins, bool useRadixSort,
+          bool multipleBlocksPerRow, bool mergeBlocks>
+static __device__ void topKPerRowJob(const int* indices, const float* logits,
+                                     int rowStart, int rowEnd, int* outIndices,
+                                     float* outLogits, int stride1, int topK) {
   // The number of slots for the final pass.
-  static constexpr int kNumFinalItems = 3072;
+  static constexpr int kNumFinalItems = 2048;
   // The number of elements per thread for the final sort.
   static constexpr int kNumFinalItemsPerThread =
       kNumFinalItems / kNumThreadsPerBlock;
   // The class to sort the elements during the final pass.
   using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                         kNumFinalItemsPerThread, int>;
-
+  using FinalSortTempStorage =
+      std::conditional_t<useRadixSort, typename FinalSort::TempStorage,
+                         cub::NullType>;
   // The class to compute the inclusive prefix-sum over the histogram.
   using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
 
-  // Shared memory to compute the block scan.
-  __shared__ typename Scan::TempStorage smemScan;
-
   // The structure to store the final items (for the final pass).
   struct FinalItems {
     // Shared memory to store the indices for the final pass.
@@ -87,200 +347,219 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart,
     float logits[kNumFinalItems];
   };
 
+  struct Histogram {
+    typename Scan::TempStorage scan;
+    int data[kNumBins];
+  };
+
   // Shared memory to compute the block sort.
   __shared__ union {
     FinalItems items;
-    typename FinalSort::TempStorage finalSort;
-    typename TopKSort::TempStorage topKSort;
+    FinalSortTempStorage finalSort;
+    Histogram histo;
   } smemFinal;
 
-  // Shared memory to store the histogram.
-  __shared__ int smemHistogram[kNumBins];
   // Shared memory to store the selected indices.
-  __shared__ int smemIndices[kTopK];
+  // If we are processing using multiple blocks, we also need to store the
+  // logits next to the indices.
+  extern __shared__ int32_t smemOutput[];
+
   // Shared memory to store the threshold bin.
   __shared__ int smemThresholdBinIdx[1];
   // Shared memory counter to register the candidates for the final phase.
   __shared__ int smemFinalDstIdx[1];
+  // Shared memory to determine if the threshold bin fits in the final items.
+  __shared__ int smemFinalBinSize[1];
+  // Shared memory to keep track of the top-k values found so far by the
+  // previous iterations.
+  __shared__ int smemFoundTopKValues[1];
 
   // The length of the row.
   int rowLen = rowEnd - rowStart;
 
   // Shortcut if the length of the row is smaller than Top-K. Indices are not
   // sorted by their corresponding logit.
-  if (rowLen <= kTopK) {
+  if (rowLen <= topK) {
     for (int rowIt = threadIdx.x; rowIt < rowLen;
          rowIt += kNumThreadsPerBlock) {
-      int idx = rowStart + rowIt;
-      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
+      if constexpr (multipleBlocksPerRow) {
+        outIndices[rowIt] = rowIt + rowStart;
+        outLogits[rowIt] = logits[rowIt + rowStart];
+      } else {
+        outIndices[rowIt] = rowIt;
+      }
     }
-    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
+    for (int rowIt = rowLen + threadIdx.x; rowIt < topK;
          rowIt += kNumThreadsPerBlock) {
-      outIndices[rowIdx * kTopK + rowIt] = -1;
+      outIndices[rowIt] = -1;
+      if constexpr (multipleBlocksPerRow) {
+        outLogits[rowIt] = -FLT_MAX;
+      }
     }
-    return;
-  }
-
-  // Clear the histogram.
-  if (threadIdx.x < kNumBins) {
-    smemHistogram[threadIdx.x] = 0;
-  }
-
-  // Make sure the histogram is ready.
-  __syncthreads();
-
-  // Fetch elements one-by-one.
-  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
-       rowIt += kNumThreadsPerBlock) {
-    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
-    atomicAdd(&smemHistogram[idx], 1);
-  }
-  // Make sure the histogram is ready.
-  __syncthreads();
-
-  // Read the values from SMEM.
-  int binCount{0};
-  if (threadIdx.x < kNumBins) {
-    binCount = smemHistogram[threadIdx.x];
-  }
-
-  // Make sure each thread has read its value.
-  __syncthreads();
-
-  // Compute the prefix sum.
-  int prefixSum{0}, totalSum{0};
-  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
-
-  // Update the histogram with the prefix sums.
-  if (threadIdx.x < kNumBins) {
-    smemHistogram[threadIdx.x] = prefixSum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // Find the last valid bin.
-  if (threadIdx.x < kNumBins) {
-    int nextPrefixSum =
-        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
-    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
-      smemThresholdBinIdx[0] = threadIdx.x;
-    }
+    return;
   }
-
-  // Clear the counter to store the items for the final phase.
+  // Initialize values.
   if (threadIdx.x == 0) {
     smemFinalDstIdx[0] = 0;
+    smemFoundTopKValues[0] = 0;
   }
-
-  // Make sure the data is in shared memory.
   __syncthreads();
 
+  int thresholdBinIdx = -1;
+  uint32_t logitPattern = 0;
+
+  // Step 0: Process the first 11 bits of the half representation.
+  bool continueToNextStep =
+      processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
+                           multipleBlocksPerRow, mergeBlocks>(
+          indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
+          smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
+          smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
+
+  if (continueToNextStep) {
+    // Step 1: Process the next 11 bits.
+    continueToNextStep =
+        processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
+                             multipleBlocksPerRow, mergeBlocks>(
+            indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
+            smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
+            smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
+  }
 
-  // The threshold bin.
-  int thresholdBinIdx = smemThresholdBinIdx[0];
-
-  // Fetch elements one-by-one and populate the shared memory buffers.
-  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
-       rowIt += kNumThreadsPerBlock) {
-    float logit = logits[rowIdx * stride0 + rowIt * stride1];
-    uint16_t idx = extractBinIdx(logit);
-    if (idx < thresholdBinIdx) {
-      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
-      smemIndices[dstIdx] = rowIt;
-    } else if (idx == thresholdBinIdx) {
-      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
-      if (dstIdx < kNumFinalItems) {
-        smemFinal.items.logits[dstIdx] = logit;
-        smemFinal.items.indices[dstIdx] = rowIt;
-      }
-    }
+  if (continueToNextStep) {
+    // Step 2: Process the next 11 bits.
+    continueToNextStep =
+        processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
+                             multipleBlocksPerRow, mergeBlocks>(
+            indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
+            smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
+            smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
   }
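+  // Steps 2 and 3 use logitPattern to restrict the histogram to elements
+  // whose already-resolved high bits still match the threshold bin; step 1
+  // switches from the half to the float representation.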
-  // Make sure the elements are in shared memory.
-  __syncthreads();
+  if (continueToNextStep) {
+    // Step 3: Process the last 10 bits.
+    processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
+                         multipleBlocksPerRow, mergeBlocks>(
+        indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
+        smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
+        smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
+  }
 
-  // The logits of the elements to be sorted in the final pass.
-  float finalLogits[kNumFinalItemsPerThread];
-  // The indices of the elements to be sorted in the final pass.
-  int finalIndices[kNumFinalItemsPerThread];
+  if (!continueToNextStep) {
+    // The histogram did not proceed to the final 10-bit step, so the items
+    // collected from the threshold bin still need to be sorted.
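+    // useRadixSort selects between a cub block radix sort and a quadratic
+    // insertion-style ranking over at most kNumFinalItems candidates; the
+    // host side picks the variant based on the row length.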
+    if constexpr (useRadixSort) {
+      // Sorting with radix sort.
+      float finalLogits[kNumFinalItemsPerThread];
+      // The indices of the elements to be sorted in the final pass.
+      int finalIndices[kNumFinalItemsPerThread];
 
-// Init.
 #pragma unroll
-  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
-    finalLogits[ii] = -FLT_MAX;
-  }
+      for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+        finalLogits[ii] = -FLT_MAX;
+      }
 
-// Read the elements from SMEM.
+      // Read the elements from SMEM.
 #pragma unroll
-  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
-    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
-    if (srcIdx < smemFinalDstIdx[0]) {
-      finalLogits[ii] = smemFinal.items.logits[srcIdx];
-      finalIndices[ii] = smemFinal.items.indices[srcIdx];
-    }
-  }
+      for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+        int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+        if (srcIdx < smemFinalDstIdx[0]) {
+          finalLogits[ii] = smemFinal.items.logits[srcIdx];
+          finalIndices[ii] = smemFinal.items.indices[srcIdx];
+        }
+      }
+      // Make sure the shared memory has been read.
+      __syncthreads();
 
-  // Make sure the shared memory has been read.
-  __syncthreads();
+      // Sort the elements.
+      FinalSort(smemFinal.finalSort)
+          .SortDescendingBlockedToStriped(finalLogits, finalIndices);
 
-  // Sort the elements.
-  FinalSort(smemFinal.finalSort)
-      .SortDescendingBlockedToStriped(finalLogits, finalIndices);
+      // Copy the data back to the shared memory storage.
+      int baseIdx = smemFoundTopKValues[0];
 
-  // Copy the data back to the shared memory storage.
-  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
 #pragma unroll
-  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
-    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
-    int dstIdx = baseIdx + srcIdx;
-    if (dstIdx < kTopK) {
-      smemIndices[dstIdx] = finalIndices[ii];
-    }
-  }
+      for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+        int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+        int dstIdx = baseIdx + srcIdx;
+
+        if (dstIdx < topK) {
+          smemOutput[dstIdx] = finalIndices[ii];
+          if constexpr (multipleBlocksPerRow) {
+            reinterpret_cast<float*>(smemOutput + topK)[dstIdx] =
+                finalLogits[ii];
+          }
+        }
+      }
+    } else {
+      // Sorting with insertion sort.
+      auto baseIdx = smemFoundTopKValues[0];
+      for (int i = threadIdx.x; i < smemFinalDstIdx[0];
+           i += kNumThreadsPerBlock) {
+        int outIndex = 0;
+        auto logit = smemFinal.items.logits[i];
+        for (int j = 0; j < smemFinalDstIdx[0]; j++) {
+          auto otherLogit = smemFinal.items.logits[j];
+          if (logit < otherLogit || (logit == otherLogit && i < j)) {
+            outIndex++;
+          }
+        }
+        // Store if outIndex is in bounds.
+        if (outIndex + baseIdx < topK) {
+          smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i];
+          if constexpr (multipleBlocksPerRow) {
+            reinterpret_cast<float*>(smemOutput + topK)[outIndex + baseIdx] =
+                smemFinal.items.logits[i];
+          }
+        }
+      }
+    }
+    __syncthreads();
+  }
 
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-// Store to global memory.
-#pragma unroll
-  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
-    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
-    outIndices[offset] =
-        smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
+  // Store to global memory.
+  for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock) {
+    if constexpr (multipleBlocksPerRow) {
+      outIndices[i] = smemOutput[i];
+      outLogits[i] = reinterpret_cast<float*>(smemOutput + topK)[i];
+    } else {
+      outIndices[i] = smemOutput[i] - rowStart;
+    }
   }
 }
 
-template <int kNumThreadsPerBlock>
-static __global__ void topKPerRow(const float* logits, const int* rowStarts,
-                                  const int* rowEnds, int* outIndices,
-                                  int stride0, int stride1) {
+template <int kNumThreadsPerBlock, bool useRadixSort>
+static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
+    const float* logits, const int* rowStarts, const int* rowEnds,
+    int* outIndices, int stride0, int stride1, const int topK,
+    const int offsetIndex) {
   // The number of bins in the histogram.
-  static constexpr int kNumBins = 512;
-
-  // The top-k width.
-  static constexpr int kTopK = 2048;
+  static constexpr int kNumBins = 2048;
 
   // The row computed by this block.
-  int rowIdx = blockIdx.x;
+  int rowIdx = blockIdx.x + offsetIndex;
 
   // The range of logits within the row.
   int rowStart = rowStarts[rowIdx];
   int rowEnd = rowEnds[rowIdx];
 
-  topKPerRowJob<kTopK, kNumBins, kNumThreadsPerBlock>(
-      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+  // Local pointers to this block.
+  outIndices += rowIdx * topK;
+  logits += rowIdx * stride0;
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, false, false>(
+      nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
 }
 
-template <int kNumThreadsPerBlock>
-static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
-                                        int* outIndices, int stride0,
-                                        int stride1, int next_n) {
+template <int kNumThreadsPerBlock, bool useRadixSort,
+          bool multipleBlocksPerRow = false, bool mergeBlocks = false>
+static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
+    const float* logits, const int* seqLens, int* outIndices, int stride0,
+    int stride1, const int topK, int next_n, float* outLogits = nullptr,
+    const int numBlocksToMerge = 0, const int* indices = nullptr) {
   // The number of bins in the histogram.
-  static constexpr int kNumBins = 512;
-
-  // The top-k width.
-  static constexpr int kTopK = 2048;
+  static constexpr int kNumBins = 2048;
 
   // The row computed by this block.
   int rowIdx = blockIdx.x;
@@ -290,8 +569,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
   int seq_len = seqLens[rowIdx / next_n];
   int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
 
-  topKPerRowJob<kTopK, kNumBins, kNumThreadsPerBlock>(
-      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+  // Local pointers to this block.
+  if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
+    outIndices += rowIdx * topK;
+  } else if constexpr (multipleBlocksPerRow) {
+    const auto blockSize = rowEnd / gridDim.y;  // 16384 / 2 = 8192
+    rowStart = blockSize * blockIdx.y;          // 8192 * 1 = 8192
+    rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
+    outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
+    outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
+  } else if constexpr (mergeBlocks) {
+    rowEnd = numBlocksToMerge * topK;
+    indices += rowIdx * numBlocksToMerge * topK;
+    outIndices += rowIdx * topK;
+  }
+  logits += rowIdx * stride0;
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
+                multipleBlocksPerRow, mergeBlocks>(
+      indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);
 }
 
 }  // namespace vllm
@@ -339,28 +635,84 @@ void apply_repetition_penalties_(
 
 void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
                           const torch::Tensor& seqLens, torch::Tensor& indices,
-                          int64_t numRows, int64_t stride0, int64_t stride1) {
-  // Compute the results on the device.
+                          int64_t numRows, int64_t stride0, int64_t stride1,
+                          int64_t topK) {
+  constexpr int kSortingAlgorithmThreshold = 12288;
+  constexpr int kSplitWorkThreshold = 200 * 1000;
   constexpr int kNumThreadsPerBlock = 512;
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  vllm::topKPerRowDecode<kNumThreadsPerBlock>
-      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
-          logits.data_ptr<float>(), seqLens.data_ptr<int>(),
-          indices.data_ptr<int>(), static_cast<int>(stride0),
-          static_cast<int>(stride1), static_cast<int>(next_n));
+  const auto numColumns = logits.size(1);
+
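+  // Three regimes (the thresholds are tuning heuristics): short rows use the
+  // insertion-sort variant, medium rows the radix-sort variant, and very long
+  // rows are split column-wise across multipleBlocksPerRowConfig blocks whose
+  // per-block top-k candidates are merged by a second kernel launch.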
+  if (numColumns < kSortingAlgorithmThreshold) {
+    // Use insertion sort.
+    vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
+        <<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
+            logits.data_ptr<float>(), seqLens.data_ptr<int>(),
+            indices.data_ptr<int>(), static_cast<int>(stride0),
+            static_cast<int>(stride1), static_cast<int>(topK),
+            static_cast<int>(next_n));
+  } else if (numColumns < kSplitWorkThreshold) {
+    // From this threshold on, use radix sort instead.
+    vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
+        <<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
+            logits.data_ptr<float>(), seqLens.data_ptr<int>(),
+            indices.data_ptr<int>(), static_cast<int>(stride0),
+            static_cast<int>(stride1), static_cast<int>(topK),
+            static_cast<int>(next_n));
+  } else {
+    // Long sequences are run in two steps.
+    constexpr auto multipleBlocksPerRowConfig = 10;
+
+    const auto outIndicesAux =
+        torch::empty({numRows, multipleBlocksPerRowConfig, topK},
+                     torch::dtype(torch::kInt32).device(logits.device()));
+    const auto outLogitsAux =
+        torch::empty({numRows, multipleBlocksPerRowConfig, topK},
+                     torch::dtype(torch::kFloat).device(logits.device()));
+
+    vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
+        <<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
+           2 * topK * sizeof(int32_t), stream>>>(
+            logits.data_ptr<float>(), seqLens.data_ptr<int>(),
+            outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
+            static_cast<int>(stride1), static_cast<int>(topK),
+            static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
+
+    constexpr int kNumThreadsPerBlockMerge = 1024;
+    vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
+        <<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t),
+           stream>>>(
+            outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
+            indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
+            static_cast<int>(topK), static_cast<int>(next_n), nullptr,
+            multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
+  }
 }
 
-void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
-                   const torch::Tensor& rowEnds, torch::Tensor& indices,
-                   int64_t numRows, int64_t stride0, int64_t stride1) {
-  // Compute the results on the device.
+void top_k_per_row_prefill(const torch::Tensor& logits,
+                           const torch::Tensor& rowStarts,
+                           const torch::Tensor& rowEnds, torch::Tensor& indices,
+                           int64_t numRows, int64_t stride0, int64_t stride1,
+                           int64_t topK) {
+  constexpr int kSortingAlgorithmThreshold = 12288;
   constexpr int kNumThreadsPerBlock = 512;
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  vllm::topKPerRow<kNumThreadsPerBlock>
-      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
-          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
-          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
-          static_cast<int>(stride0), static_cast<int>(stride1));
+  int numInsertionBlocks =
+      std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
+  vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
+      <<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
+         stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
+                   rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
+                   static_cast<int>(stride0), static_cast<int>(stride1),
+                   static_cast<int>(topK), 0);
+
+  if (numRows > kSortingAlgorithmThreshold) {
+    int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
+    vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
+        <<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
+           stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
+                     rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
+                     static_cast<int>(stride0), static_cast<int>(stride1),
+                     static_cast<int>(topK), kSortingAlgorithmThreshold);
+  }
 }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 5af74c2c2a6b..b6e34658f7c6 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -180,15 +180,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Optimized top-k per row operation
   ops.def(
-      "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
+      "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
       "Tensor! indices, int numRows, int stride0, "
-      "int stride1) -> ()");
-  ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
+      "int stride1, int topK) -> ()");
+  ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
 
   ops.def(
       "top_k_per_row_decode(Tensor logits, int next_n, "
-      "Tensor seq_lens, Tensor! indices, int numRows, "
-      "int stride0, int stride1) -> ()");
+      "Tensor seq_lens, Tensor! indices, "
+      "int numRows, int stride0, int stride1, int topK) -> ()");
   ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
 
   // Layernorm-quant
diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py
index cadda27b49e9..3bf69389753e 100644
--- a/tests/kernels/test_top_k_per_row.py
+++ b/tests/kernels/test_top_k_per_row.py
@@ -9,23 +9,45 @@
 
 # Test parameters
 NUM_ROWS = [1, 32, 2050]
-TOP_K_VALUES = [2048]
-BATCH_SIZE = [1, 2, 4, 2048, 4096]
-NEXT_N = [1, 2, 4, 8]
+TOP_K_VALUES = [2048, 3000]
+BATCH_SIZE = [1, 2, 2048]
+NEXT_N = [1, 8]
+DATA_GENERATION = ["random", "10LSBits"]
 
 
 def create_random_logits(
     row_starts: torch.Tensor,
     row_ends: torch.Tensor,
-    vocab_size: int,
     dtype: torch.dtype,
     seed: int,
+    data_generation: str,
 ) -> torch.Tensor:
     """Create random logits tensor for testing."""
     torch.manual_seed(seed)
     np.random.seed(seed)
 
     # Generate logits with some structure to make testing more meaningful
-    logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda")
+    if data_generation == "random":
+        logits = torch.randn(
+            row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda"
+        )
+    elif data_generation == "10LSBits":
+        top_22_bits_mask = 0xFFFFFC00
+        last_10_bits_mask = 0x000003FF
+        fixed_top_22_bits = 0x3F900000
+        # Generate random bits for the last 10 bits
+        random_bottom_bits = torch.randint(
+            0,
+            2**10,
+            (row_starts.shape[0], max(row_ends)),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        # Combine: fixed top 22 bits with random last 10 bits
+        logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
+            random_bottom_bits & last_10_bits_mask
+        )
+        logits = logits_bits.view(dtype)
+
     for i, end in enumerate(row_ends):
         logits[i, end:] = float("-inf")
     return logits
@@ -113,13 +135,13 @@ def test_top_k_per_row(
     # Create test data
     vocab_size = 20000
     row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
-    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
+    logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
 
     # Create output tensors
     indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
 
     # Run CUDA implementation
-    torch.ops._C.top_k_per_row(
+    torch.ops._C.top_k_per_row_prefill(
         logits,
         row_starts,
         row_ends,
@@ -127,6 +149,7 @@ def test_top_k_per_row(
         num_rows,
         logits.stride(0),
         logits.stride(1),
+        top_k,
     )
 
     # Run reference implementation
@@ -139,27 +162,23 @@ def test_top_k_per_row(
     # Compare results
     assert compare_top_k_results(
         logits, indices, torch_indices, row_starts, row_ends, top_k
-    ), "CUDA top_k_per_row results don't match torch.topk"
+    ), "CUDA top_k_per_row_prefill results don't match torch.topk"
 
 
-@pytest.mark.parametrize("top_k", TOP_K_VALUES)
-@pytest.mark.parametrize("batch_size", BATCH_SIZE)
-@pytest.mark.parametrize("next_n", NEXT_N)
-@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
-@torch.inference_mode()
-def test_top_k_per_row_decode(
+def _run_top_k_per_row_decode_test(
     top_k: int,
     batch_size: int,
     next_n: int,
+    vocab_size: int,
+    data_generation: str,
 ) -> None:
     """
-    Test top_k_per_row with seq_lens tensor.
+    Helper function to run top_k_per_row_decode test with given parameters.
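+    Shared by the parametrized decode tests below and the large-vocab test;
+    the vocab size and the logits bit-pattern generator are supplied by the
+    callers.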
""" torch.set_default_device("cuda:0") # Create test data num_rows = batch_size * next_n - vocab_size = 20000 seq_lens = torch.randint( vocab_size, (batch_size,), dtype=torch.int32, device="cuda" ) @@ -167,7 +186,9 @@ def test_top_k_per_row_decode( row_indices = torch.arange(num_rows, device="cuda") // next_n next_n_offset = torch.arange(num_rows, device="cuda") % next_n row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 - logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) + logits = create_random_logits( + row_starts, row_ends, torch.float32, 42, data_generation + ) # Create output tensors indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") @@ -181,6 +202,7 @@ def test_top_k_per_row_decode( num_rows, logits.stride(0), logits.stride(1), + top_k, ) torch.cuda.synchronize() @@ -195,4 +217,41 @@ def test_top_k_per_row_decode( # Compare results assert compare_top_k_results( logits, indices, torch_indices, row_starts, row_ends, top_k - ), "CUDA top_k_per_row results don't match torch.topk" + ), "CUDA top_k_per_row_decode results don't match torch.topk" + + +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("next_n", NEXT_N) +@pytest.mark.parametrize("data_generation", DATA_GENERATION) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_top_k_per_row_decode( + top_k: int, + batch_size: int, + next_n: int, + data_generation: str, +) -> None: + """ + Test top_k_per_row with seq_lens tensor. + """ + vocab_size = 20000 + _run_top_k_per_row_decode_test( + top_k, batch_size, next_n, vocab_size, data_generation + ) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_top_k_per_row_decode_large_vocab_size() -> None: + """ + Test top_k_per_row_decode with large vocabulary size. + """ + top_k = 2048 + batch_size = 2 + next_n = 2 + vocab_size = 300000 + data_generation = "random" + _run_top_k_per_row_decode_test( + top_k, batch_size, next_n, vocab_size, data_generation + ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7cfd381592b4..b2a5d89c6488 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -661,11 +661,10 @@ def sparse_attn_indexer( chunk.cu_seqlen_ke, ) num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row( + torch.ops._C.top_k_per_row_prefill( logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, @@ -673,6 +672,7 @@ def sparse_attn_indexer( num_rows, logits.stride(0), logits.stride(1), + topk_tokens, ) if has_decode: @@ -715,7 +715,6 @@ def sparse_attn_indexer( max_model_len=max_model_len, ) num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] torch.ops._C.top_k_per_row_decode( @@ -726,6 +725,7 @@ def sparse_attn_indexer( num_rows, logits.stride(0), logits.stride(1), + topk_tokens, ) if decode_metadata.requires_padding: # if padded, we need to unpack