@@ -360,12 +360,12 @@ MlasDequantizeBlockwise(
360360 );
361361
362362/* *
363- * @brief Blockwise 2 bits or 4 bits quantization. After quantization, the weights and zero points
364- * are packed row-wise. In terms of the qbits type, dst and src have the same shape, and
365- * scales and zero_points have the same shape .
366- * columns must be multiple of 8 / qbits .
363+ * @brief Blockwise 4 bits quantization. After quantization, the weights and zero points
364+ * are packed row-wise. If zero_points is null, quantized type is int4 with default
365+ * zero point 0, to align with DQ schema. Otherwise, quantized type is uint4 .
366+ * In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales .
367367 * @tparam Tin
368- * @tparam qbits number of bits used for quantization, 2 or 4
368+ * @tparam qbits number of bits used for quantization, only 4 is supported
369369 * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
370370 * @param scales points to the scales matrix, row major
371371 * @param zero_points points to the zero_points matrix, row major
@@ -376,9 +376,10 @@ MlasDequantizeBlockwise(
376376 * @param columns
377377 * @param quant_block_size number of elements in a quantize block
378378 * @param thread_pool
379+ * @return the quantized type is signed.
379380 */
380381template <typename Tin, int qbits>
381- void
382+ bool
382383MlasQDQQuantizeBlockwise (
383384 const Tin* src,
384385 Tin* scales,
@@ -395,8 +396,17 @@ MlasQDQQuantizeBlockwise(
395396 * @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero
396397 * points are packed row-wise. The dst tensors are column major. dst weights and zero points
397398 * are packed column-wise.
399+ * dst_weights and dst_zero_points are in uint4.
400+ * If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are
401+ * converted to uint4 by adding 8.
402+ * If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8.
403+ * src_zero_points is 0 and dst_zero_points is 8.
404+ * If src_weights is uint4 and has src_zero_points, just transpose.
405+ * If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with
406+ * 0 values. Otherwise exception is thrown.
398407 * @tparam Tin
399- * @tparam qbits number of bits used for quantization, 2 or 4
408+ * @tparam qbits number of bits used for quantization, only 4 is supported
409+ * @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned
400410 * @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type.
401411 * In uint8_t type, shape is [rows, columns * qbits / 8].
402412 * @param src_scales points to the scales matrix, row major
@@ -410,7 +420,7 @@ MlasQDQQuantizeBlockwise(
410420 * @param quant_block_size number of elements in a quantize block
411421 * @param thread_pool
412422 */
413- template <typename Tin, int qbits>
423+ template <typename Tin, int qbits, bool signed_quant >
414424void
415425MlasQDQTransposeBlockwiseQuantized (
416426 const uint8_t * src_weights,
0 commit comments