12template <
typename T,
typename =
void>
22template <
typename T,
typename =
void>
32template <
typename Derived>
35#if defined(__gfx950__)
36 template <
typename Problem>
39 template <
typename Problem>
43 template <
typename Problem>
45 template <
typename Problem>
60 return Derived::ATileAccessPattern;
68 return Derived::BTileAccessPattern;
73 template <
typename Problem>
78 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
79 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
89 return a_lds_block_desc_0;
95 constexpr auto DataTypeSize =
sizeof(ADataType);
96 constexpr auto MLdsLayer =
97 max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
101 number<MPerBlock / MLdsLayer>{},
110 number<KPerBlock / KPack * MLdsLayer>{})),
116 a_lds_block_desc_permuted,
125 a_lds_block_desc_xk0_mnldslayer_mn_xk1,
133 return a_lds_block_desc;
143 template <
typename Problem>
148 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
149 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
160 return b_lds_block_desc_0;
166 constexpr auto BK0 =
number<KPerBlock / KPack>{};
167 constexpr auto DataTypeSize =
sizeof(BDataType);
168 constexpr auto NLdsLayer =
169 max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
187 b_lds_block_desc_permuted,
195 b_lds_block_desc_bk0_nldslayer_n_bk1,
201 return b_lds_block_desc;
206 constexpr index_t BlockSize = Problem::kBlockSize;
208 using TileEncodingPattern =
218 constexpr auto N0 = TileEncodingPattern::X0;
219 constexpr auto N1 = NPerBlock / N0;
221 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
222 constexpr auto NPerXdl =
number<WarpTile::at(
I1)>{};
226 constexpr auto KThreadWrite = TileEncodingPattern::Y2;
227 constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
228 constexpr auto KThreadRead = 64 / NPerXdl;
229 constexpr auto K0PerThreadRead = BK0 / KThreadRead;
231 constexpr auto kfold =
232 (BK1 * N0 *
sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 *
sizeof(BDataType));
233 constexpr auto KThreadReadPerm =
234 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
235 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
239 constexpr auto npair = (BK1 * NPerXdl *
sizeof(BDataType) > 128)
241 : ((128 / (BK1 * NPerXdl *
sizeof(BDataType))) > N0
243 : 128 / (BK1 * NPerXdl *
sizeof(BDataType)));
249 number<kfold * N0 / npair>{},
268 b_lds_block_desc_permuted,
303 b_lds_block_desc_unmerged,
306 number<KThreadWrite / kfold / KThreadReadPerm>{},
316 return b_lds_block_desc_kn;
346 template <
typename Problem,
353 constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize;
354 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
355 constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
360 if constexpr(XPerTile % (PackedSize * 32 /
sizeof(DataType)) == 0 &&
361 elements_per_thread % (PackedSize * 32 /
sizeof(DataType)) == 0 &&
364 return (PackedSize * 32 /
sizeof(DataType));
366 else if constexpr(XPerTile % (PackedSize * 16 /
sizeof(DataType)) == 0 &&
367 elements_per_thread % (PackedSize * 16 /
sizeof(DataType)) == 0)
369 return (PackedSize * 16 /
sizeof(DataType));
371 else if constexpr(XPerTile % (PackedSize * 8 /
sizeof(DataType)) == 0 &&
372 elements_per_thread % (PackedSize * 8 /
sizeof(DataType)) == 0)
374 return (PackedSize * 8 /
sizeof(DataType));
376 else if constexpr(
sizeof(DataType) >= PackedSize * 4 &&
377 XPerTile % (PackedSize * 4 /
sizeof(DataType)) == 0 &&
378 elements_per_thread % (PackedSize * 4 /
sizeof(DataType)) == 0)
380 return (PackedSize * 4 /
sizeof(DataType));
382 else if constexpr(
sizeof(DataType) >= PackedSize * 2 &&
383 XPerTile % (PackedSize * 2 /
sizeof(DataType)) == 0 &&
384 elements_per_thread % (PackedSize * 2 /
sizeof(DataType)) == 0)
386 return (PackedSize * 2 /
sizeof(DataType));
394 template <
typename Problem,
bool IsWave32Host = false>
399 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
400 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
405 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
423 template <
typename Problem,
bool IsWave32Host = false>
428 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
429 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
434 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
464 template <
typename Problem>
468 using WG =
typename BlockGemm::WarpGemm;
470 constexpr bool TransposeC = Problem::TransposeC;
471 using CLayout =
typename Problem::CLayout;
472 using CWarpDstr =
typename WG::CWarpDstr;
475 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
477 if constexpr(TransposeC)
481 constexpr index_t NDimY = CWarpDstr::NDimY;
482 constexpr auto c_warp_y_lengths =
483 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
484 static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
491 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
495 else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
497 if constexpr(TransposeC)
500 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
506 constexpr index_t NDimY = CWarpDstr::NDimY;
507 constexpr auto c_warp_y_lengths =
508 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
509 static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
516 static_assert(
false,
"Unsupported CLayout!");
520 template <
typename Problem>
523 return Problem::TransposeC;
526 template <
typename Problem>
529 constexpr index_t BlockSize = Problem::kBlockSize;
530 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
531 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
532 constexpr index_t VecLoadSize =
534 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
539 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
541 using TileEncodingPattern =
548 return TileEncodingPattern::make_2d_static_tile_distribution();
553 using TileEncodingPattern =
560 return TileEncodingPattern::make_2d_static_tile_distribution();
564 template <
typename Problem>
567 constexpr index_t BlockSize = Problem::kBlockSize;
568 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
569 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
570 constexpr index_t VecLoadSize =
572 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
577 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
579 using TileEncodingPattern =
586 return TileEncodingPattern::make_2d_static_tile_distribution();
591 using TileEncodingPattern =
598 return TileEncodingPattern::make_2d_static_tile_distribution();
602 template <
typename Problem>
607 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
608 constexpr index_t BlockSize = Problem::kBlockSize;
609 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
610 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
612 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
620 return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
623 template <
typename Problem>
628 static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
629 constexpr index_t BlockSize = Problem::kBlockSize;
630 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
631 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
633 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
641 return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
644 template <
typename Problem>
648 constexpr index_t KPack = BlockGemm::Traits::KPack;
652 template <
typename Problem>
656 constexpr index_t KPack = BlockGemm::Traits::KPack;
660 template <
typename Problem>
663 constexpr index_t smem_size_a =
665 Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
670 template <
typename Problem>
673 constexpr index_t smem_size_b =
675 Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
680 template <
typename Problem>
686 return smem_size_a + smem_size_b;
694 template <
typename Problem>
697 using BlockWarps =
typename Problem::BlockGemmShape::BlockWarps;
698 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
700 constexpr index_t vector_size =
703 constexpr auto wg_attr_num_access =
711 typename Problem::ComputeDataType,
712 typename Problem::CDataType,
718 Problem::UseStructuredSparsity,
722 typename Problem::BDataType,
723 typename Problem::CDataType,
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Invalid
Definition warp_gemm_attribute_mfma.hpp:17
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
@ Quad
Definition warp_gemm_attribute_mfma.hpp:16
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
@ thread_raked
Thread raked pattern.
Definition static_encoding_pattern.hpp:94
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
constexpr int DS_READ_TR_SIZE()
Definition load_tile_transpose.hpp:20
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_universal_gemm_as_bs_cr.hpp:21
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:603
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static CK_TILE_DEVICE constexpr index_t GetSmemSizeB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:671
static constexpr auto getATileAccessPattern()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:57
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:645
static constexpr auto DefaultATileAccessPattern
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:54
static constexpr auto DefaultBTileAccessPattern
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:55
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:653
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:527
static CK_TILE_DEVICE constexpr auto MakeBLdsBlockDescriptor()
Create LDS block descriptor for B tensor.
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:144
static constexpr bool is_a_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:44
static constexpr bool is_b_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:46
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeC()
Get the vector store size for C tensor.
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:465
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:424
static CK_TILE_HOST_DEVICE constexpr auto GetGlobalVectorLoadSize()
Get the maximum global memory vector load size.
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:351
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBRegTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:624
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto getBTileAccessPattern()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:65
static CK_TILE_DEVICE constexpr index_t GetSmemSizeA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:661
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:395
static CK_TILE_HOST_DEVICE constexpr auto MakeBDramTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:565
static CK_TILE_HOST_DEVICE constexpr auto IsTransposeC()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:521
static CK_TILE_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:74
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
static CK_TILE_DEVICE constexpr index_t GetSmemSize()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:681
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:693
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:695
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:14
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:24
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17
Class creating 2D static tile distribution with different load/store patterns.
Definition static_encoding_pattern.hpp:130