tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ > Struct Template Reference

tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ > Struct Template Reference

Composable Kernel: ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ > Struct Template Reference
ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ > Struct Template Reference

#include <gemm_group_quant_utils.hpp>

Inheritance diagram for ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >:
ck_tile::tile_distribution_encoding_pattern

Static Public Member Functions

static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution ()
 Creates a 2D tile distribution for BQ (B-matrix quantization scales).

Static Public Attributes

static constexpr index_t warp_size = get_warp_size()
static constexpr index_t num_warps = BlockSize / get_warp_size()
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{})
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{})
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{})
static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN)

Member Function Documentation

◆ make_2d_static_tile_distribution()

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
CK_TILE_HOST_DEVICE constexpr auto ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::make_2d_static_tile_distribution ( )
inline static constexpr

Creates a 2D tile distribution for BQ (B-matrix quantization scales).

This function determines the optimal thread distribution pattern for loading and applying quantization scales to the B matrix based on the quantization group size (XPerQ) relative to warp dimensions.

Three distinct distribution patterns are handled:

  1. Fine-grained quantization (XPerQ < WarpGemm::kN):
    • Multiple quantization groups exist within a single warp's N-dimension
    • Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
    • Distribution includes an explicit replication factor (XR = XPerQ) for scale broadcast
    • Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
  2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
    • Each warp handles exactly one quantization scale
    • Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
    • Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
  3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
    • Quantization group spans multiple warps
    • All warps share the same scale value
    • Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use the same scale
Returns
A static tile distribution encoding for the BQ scale tensor

Member Data Documentation

◆ KWarps

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
index_t ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::KWarps = BlockGemmShape::BlockWarps::at(number<2>{})
static constexpr

◆ MWarps

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
index_t ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::MWarps = BlockGemmShape::BlockWarps::at(number<0>{})
static constexpr

◆ NIterPerWarp

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
index_t ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN)
static constexpr

◆ num_warps

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
index_t ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::num_warps = BlockSize / get_warp_size()
static constexpr

◆ NWarps

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
index_t ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::NWarps = BlockGemmShape::BlockWarps::at(number<1>{})
static constexpr

◆ warp_size

template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
index_t ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ >::warp_size = get_warp_size()
static constexpr

The documentation for this struct was generated from the following file: