26template <
typename GridwiseGemm,
27 bool HasMainKBlockLoop,
32#if CK_USE_LAUNCH_BOUNDS
38#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
39 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
41 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
43 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg);
45 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
47 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
48 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
57template <
typename GridwiseGemm,
58 bool HasMainKBlockLoop,
63#if CK_USE_LAUNCH_BOUNDS
69#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
72 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
74 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
77 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg);
79 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
80 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
81 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
82 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
197template <
typename ALayout,
202 typename AccDataType,
203 typename CShuffleDataType,
205 typename AElementwiseOperation,
206 typename BElementwiseOperation,
207 typename CElementwiseOperation,
219 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
220 typename ABlockTransferThreadClusterArrangeOrder,
221 typename ABlockTransferSrcAccessOrder,
222 index_t ABlockTransferSrcVectorDim,
223 index_t ABlockTransferSrcScalarPerVector,
224 index_t ABlockTransferDstScalarPerVector_AK1,
225 bool AThreadTransferSrcResetCoordinateAfterRun,
227 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
228 typename BBlockTransferThreadClusterArrangeOrder,
229 typename BBlockTransferSrcAccessOrder,
230 index_t BBlockTransferSrcVectorDim,
231 index_t BBlockTransferSrcScalarPerVector,
232 index_t BBlockTransferDstScalarPerVector_BK1,
233 bool BThreadTransferSrcResetCoordinateAfterRun,
235 index_t CShuffleMXdlPerWavePerShuffle,
236 index_t CShuffleNXdlPerWavePerShuffle,
237 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
238 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
241 typename ComputeTypeA = CDataType,
242 typename ComputeTypeB = ComputeTypeA,
243 bool PermuteA =
false,
244 bool PermuteB =
false,
245 bool DoElementwiseBeforeCShuffle =
false>
270 KPerBlock < 128 && MPerXdl == 16))
321 auto K_t = K_Batch * KPerBlock;
322 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
327 auto K_t = K_Batch * KPerBlock;
328 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
333 auto K_t = K_Batch * KPerBlock;
334 return (K + K_t - 1) / K_t * KPerBlock;
340 auto K_t = K_Batch * KReadVec;
341 return (K + K_t - 1) / K_t * KReadVec;
354 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl,
typename TileDesc_K0_MN_K1>
372 const auto a_grid_desc_mraw_kraw = [&]() {
385 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386 GemmSpec == GemmSpecialization::MNKPadding)
389 const auto a_grid_desc_m_k =
403 return a_grid_desc_ak0_m_ak1;
405 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406 GemmSpec == GemmSpecialization::MNPadding)
410 a_grid_desc_mraw_kraw,
416 return a_grid_desc_ak0_m_ak1;
418 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419 GemmSpec == GemmSpecialization::NKPadding)
423 a_grid_desc_mraw_kraw,
435 return a_grid_desc_ak0_m_ak1;
441 a_grid_desc_mraw_kraw,
447 return a_grid_desc_ak0_m_ak1;
454 const auto b_grid_desc_nraw_kraw = [&]() {
468 GemmSpec != GemmSpecialization::Default),
469 "pk_i4_t does not support padding");
471 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
472 GemmSpec == GemmSpecialization::MNKPadding)
475 const auto b_grid_desc_n_k =
489 return b_grid_desc_bk0_n_bk1;
491 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
492 GemmSpec == GemmSpecialization::MNPadding)
496 b_grid_desc_nraw_kraw,
502 return b_grid_desc_bk0_n_bk1;
504 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
505 GemmSpec == GemmSpecialization::MKPadding)
509 b_grid_desc_nraw_kraw,
521 return b_grid_desc_bk0_n_bk1;
525 if constexpr(!PermuteB)
529 b_grid_desc_nraw_kraw,
535 return b_grid_desc_bk0_n_bk1;
541 constexpr index_t BK01 = KPerBlock / BK1Value;
542 const index_t BK0_ = StrideB / BK1Value;
543 const index_t BK00 = BK0_ / BK01;
545 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
549 b_grid_desc_bk00_n_bk01_bk1_permute,
556 return b_grid_desc_bk0_n_bk1_permute;
561 template <
typename ABlockDesc_AK0_M_AK1>
562 __host__ __device__
static constexpr auto
565 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
570 template <
typename BBlockDesc_BK0_N_BK1>
571 __host__ __device__
static constexpr auto
574 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
579 __host__ __device__
static auto
582 const auto c_grid_desc_mraw_nraw = [&]() {
602 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
603 GemmSpec == GemmSpecialization::MNKPadding)
612 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
613 GemmSpec == GemmSpecialization::MKPadding)
617 c_grid_desc_mraw_nraw,
622 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
623 GemmSpec == GemmSpecialization::NKPadding)
627 c_grid_desc_mraw_nraw,
635 return c_grid_desc_mraw_nraw;
649 AElementwiseOperation a_element_op,
650 BElementwiseOperation b_element_op,
651 CElementwiseOperation c_element_op)
676 std::cout <<
"problem {"
685 <<
"KRead:" <<
KRead <<
", "
687 <<
"AK0:" <<
AK0 <<
", "
688 <<
"BK0:" <<
BK0 <<
", "
689 <<
"MBlock: " <<
MBlock <<
", "
690 <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
718 const BDataType* p_b_grid_,
719 CDataType* p_c_grid_,
727 bool is_reduce_ =
false,
728 AElementwiseOperation
a_element_op = AElementwiseOperation{},
729 BElementwiseOperation
b_element_op = BElementwiseOperation{},
730 CElementwiseOperation
c_element_op = CElementwiseOperation{})
784 if constexpr(!PermuteB)
790 const int k0_offset = karg.
KRead * karg.
N;
821 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
822 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
823 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
838 constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
853 a_lds_block_desc_permuted,
861 a_lds_block_desc_ak0_mldslayer_m_ak1,
869 return a_lds_block_desc_ak0_m_ak1;
876 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
877 constexpr auto M1 = MPerBlock / M0;
879 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
880 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
881 constexpr auto KThreadRead = WaveSize / MPerXdl;
882 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
884 constexpr auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
886 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
887 constexpr auto KThreadReadPerm =
888 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
889 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
893 constexpr auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
895 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
897 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
903 Number<kfold * M0 / mpair>{},
922 a_lds_block_desc_permuted,
944 a_lds_block_desc_unmerged,
947 Number<KThreadWrite / kfold / KThreadReadPerm>{},
956 return a_lds_block_desc_ak0_m_ak1;
962 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
963 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
964 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
978 constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
993 b_lds_block_desc_permuted,
1001 b_lds_block_desc_bk0_nldslayer_n_bk1,
1009 return b_lds_block_desc_bk0_n_bk1;
1013 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
1014 constexpr auto N1 = NPerBlock / N0;
1016 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
1017 constexpr auto K0PerThreadWrite =
BK0Number / KThreadWrite;
1018 constexpr auto KThreadRead = WaveSize / NPerXdl;
1019 constexpr auto K0PerThreadRead =
BK0Number / KThreadRead;
1021 constexpr auto kfold = (
BK1Number * N0 *
sizeof(BDataType) > 128)
1023 : 128 / (
BK1Number * N0 *
sizeof(BDataType));
1024 constexpr auto KThreadReadPerm =
1025 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1026 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1030 constexpr auto npair = (
BK1Number * NPerXdl *
sizeof(BDataType) > 128)
1032 : ((128 / (
BK1Number * NPerXdl *
sizeof(BDataType))) > N0
1034 : 128 / (
BK1Number * NPerXdl *
sizeof(BDataType)));
1040 Number<kfold * N0 / npair>{},
1059 b_lds_block_desc_permuted,
1081 b_lds_block_desc_unmerged,
1084 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1093 return b_lds_block_desc_bk0_n_bk1;
1099 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1100 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1102 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1109 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1127 ABlockTransferSrcScalarPerVector,
1128 BBlockTransferSrcScalarPerVector,
1148 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1151 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1154 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1157 constexpr auto c_block_size =
1158 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1161 b_block_space_size_aligned *
sizeof(BDataType) /
BPackedSize),
1162 c_block_size *
sizeof(CShuffleDataType));
1165 template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
1168 enum struct Arch :
bool
1170#if defined(__gfx950__)
1171 is_gfx950_build =
true,
1173 is_gfx950_build =
false,
1178 if constexpr(
static_cast<bool>(Arch::is_gfx950_build) ||
1190 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
1199 CGlobalMemoryDataOperation>();
1204 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1205 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1206 "Invalid tuning param!");
1214 if(!(karg.M % MPerBlock == 0))
1218 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1219 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1232 if(!(karg.N % NPerBlock == 0))
1236 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1237 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1250 auto K_t = karg.KBatch * KPerBlock;
1251 if(!(karg.K % K_t == 0))
1255 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1256 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1257 <<
", in function: " << __func__ << std::endl;
1265 auto K_t = karg.KBatch * KReadVec;
1267 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1275 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1279 std::cout <<
"Arg K (" << karg.K
1280 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1281 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1282 << __LINE__ <<
", in function: " << __func__ << std::endl;
1289 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1293 std::cout <<
"Arg M (" << karg.M
1294 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1295 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1296 << __LINE__ <<
", in function: " << __func__ << std::endl;
1304 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1308 std::cout <<
"Arg N (" << karg.N
1309 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1310 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1311 << __LINE__ <<
", in function: " << __func__ << std::endl;
1318 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1322 std::cout <<
"Arg K (" << karg.K
1323 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1324 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1325 << __LINE__ <<
", in function: " << __func__ << std::endl;
1333 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1337 std::cout <<
"Arg N (" << karg.N
1338 <<
") value is not a multiple of "
1339 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1340 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1341 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1349 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1353 std::cout <<
"Arg M (" << karg.M
1354 <<
") value is not a multiple of "
1355 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1356 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1357 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1369 if(!karg.IsReduceAdd())
1373 std::cout <<
" KBatch: " << karg.KBatch <<
" > 1 is not support yet" << __FILE__
1374 <<
":" << __LINE__ <<
", in function: " << __func__ << std::endl;
1384 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1388 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1400 const index_t num_loop = K / KPerBlock;
1402 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1407 const index_t num_loop = K / KPerBlock;
1409 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1412 template <
typename CGr
idDesc>
1414 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1423 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1431 template <
typename AGridDesc_AK0_M_K1,
1432 typename BGridDesc_BK0_N_K1,
1433 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1434 bool HasMainKBlockLoop,
1437 __device__
static void Run(
const ADataType* p_a_grid,
1438 const BDataType* p_b_grid,
1439 CDataType* p_c_grid,
1441 const Problem& problem,
1442 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1443 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1444 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1445 c_grid_desc_mblock_mperblock_nblock_nperblock)
1448 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1450 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1452 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1455 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1457 const auto block_work_idx =
1460 if(!block_2_ctile_map.ValidCTileIndex(
1462 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1463 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1468 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1469 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1472 const index_t m_block_data_idx_on_grid =
1473 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1475 const index_t n_block_data_idx_on_grid =
1476 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1488 auto a_blockwise_copy =
1490 AElementwiseOperation,
1494 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1495 ABlockTransferThreadClusterArrangeOrder,
1498 decltype(a_grid_desc_ak0_m_ak1),
1499 decltype(a_block_desc_ak0_m_ak1),
1500 ABlockTransferSrcAccessOrder,
1502 ABlockTransferSrcVectorDim,
1504 ABlockTransferSrcScalarPerVector,
1505 ABlockTransferDstScalarPerVector_AK1,
1508 AThreadTransferSrcResetCoordinateAfterRun,
1510 BlockwiseGemmPipe::GlobalBufferNum>(
1511 a_grid_desc_ak0_m_ak1,
1513 problem.a_element_op_,
1514 a_block_desc_ak0_m_ak1,
1519 auto b_blockwise_copy =
1521 BElementwiseOperation,
1525 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1526 BBlockTransferThreadClusterArrangeOrder,
1529 decltype(b_grid_desc_bk0_n_bk1),
1530 decltype(b_block_desc_bk0_n_bk1),
1531 BBlockTransferSrcAccessOrder,
1533 BBlockTransferSrcVectorDim,
1535 BBlockTransferSrcScalarPerVector,
1536 BBlockTransferDstScalarPerVector_BK1,
1539 BThreadTransferSrcResetCoordinateAfterRun,
1541 BlockwiseGemmPipe::GlobalBufferNum>(
1542 b_grid_desc_bk0_n_bk1,
1544 problem.b_element_op_,
1545 b_block_desc_bk0_n_bk1,
1551 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1555 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1558 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) + a_block_space_size_aligned *
1561 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1567 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1569 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1571 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1572 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1576 a_block_desc_ak0_m_ak1,
1580 a_block_slice_copy_step,
1581 b_grid_desc_bk0_n_bk1,
1582 b_block_desc_bk0_n_bk1,
1586 b_block_slice_copy_step,
1588 num_k_block_main_loop);
1592 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1593 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1596 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1597 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1600 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1601 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1605 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1606 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1608 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1609 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1610 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1611 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1612 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1613 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1614 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1615 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1617 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1621 static_cast<CShuffleDataType*
>(p_shared),
1622 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1625 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1645 const auto c_thread_mtx_on_block =
1646 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1648 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1649 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1651 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1657 const auto m_thread_data_on_block_idx =
1658 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1661 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1667 const auto n_thread_data_on_block_idx =
1668 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1672 const auto& vpgr_to_lds_element_op = [&] {
1673 if constexpr(DoElementwiseBeforeCShuffle)
1675 return problem.c_element_op_;
1679 return pass_through;
1682 const auto& lds_to_global_element_op = [&] {
1683 if constexpr(!DoElementwiseBeforeCShuffle)
1685 return problem.c_element_op_;
1689 return pass_through;
1697 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1698 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1700 CElementwiseOperation,
1702 Sequence<CShuffleMXdlPerWavePerShuffle,
1703 CShuffleNXdlPerWavePerShuffle,
1715 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1718 m_thread_data_on_block_idx[
I1],
1719 n_thread_data_on_block_idx[
I1],
1720 m_thread_data_on_block_idx[
I2],
1721 m_thread_data_on_block_idx[
I3],
1722 m_thread_data_on_block_idx[
I4],
1723 n_thread_data_on_block_idx[
I2]),
1724 vpgr_to_lds_element_op()};
1730 CElementwiseOperation,
1732 CGlobalMemoryDataOperation,
1734 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1736 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1737 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1741 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1742 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1745 CShuffleBlockTransferScalarPerVector_NPerBlock,
1748 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1750 c_grid_desc_mblock_mperblock_nblock_nperblock,
1752 lds_to_global_element_op()};
1755 constexpr auto sfc_c_vgpr =
1758 Sequence<CShuffleMXdlPerWavePerShuffle,
1759 CShuffleNXdlPerWavePerShuffle,
1768 constexpr auto sfc_c_global =
1772 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1774 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1776 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1778 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
1785 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1786 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1788 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1789 c_shuffle_block_buf);
1795 c_shuffle_block_copy_lds_to_global.Run(
1796 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1797 c_shuffle_block_buf,
1798 c_grid_desc_mblock_mperblock_nblock_nperblock,
1801 if constexpr(access_id < num_access - 1)
1803 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1806 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1807 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1813 template <
bool HasMainKBlockLoop,
1816 __device__
static void Run(
const ADataType* p_a_grid,
1817 const BDataType* p_b_grid,
1818 CDataType* p_c_grid,
1820 const Problem& problem)
1823 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1825 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1827 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1828 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1830 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1832 Run<
decltype(a_grid_desc_ak0_m_ak1),
1833 decltype(b_grid_desc_bk0_n_bk1),
1834 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1836 CGlobalMemoryDataOperation,
1842 a_grid_desc_ak0_m_ak1,
1843 b_grid_desc_bk0_n_bk1,
1844 c_grid_desc_mblock_mperblock_nblock_nperblock);
1847 template <
typename AGridDesc_AK0_M_K1,
1848 typename BGridDesc_BK0_N_K1,
1849 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1850 bool HasMainKBlockLoop,
1853 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
1854 const BDataType* p_b_grid,
1855 CDataType* p_c_grid,
1858 const Problem& problem,
1859 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1860 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1861 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1862 c_grid_desc_mblock_mperblock_nblock_nperblock)
1865 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1867 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1869 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1872 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1874 const auto block_work_idx =
1877 if(!block_2_ctile_map.ValidCTileIndex(
1879 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1880 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1885 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1886 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1889 const index_t m_block_data_idx_on_grid =
1890 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1892 const index_t n_block_data_idx_on_grid =
1893 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1905 auto a_blockwise_copy =
1907 AElementwiseOperation,
1911 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1912 ABlockTransferThreadClusterArrangeOrder,
1915 decltype(a_grid_desc_ak0_m_ak1),
1916 decltype(a_block_desc_ak0_m_ak1),
1917 ABlockTransferSrcAccessOrder,
1919 ABlockTransferSrcVectorDim,
1921 ABlockTransferSrcScalarPerVector,
1922 ABlockTransferDstScalarPerVector_AK1,
1925 AThreadTransferSrcResetCoordinateAfterRun,
1927 BlockwiseGemmPipe::GlobalBufferNum>(
1928 a_grid_desc_ak0_m_ak1,
1930 problem.a_element_op_,
1931 a_block_desc_ak0_m_ak1,
1936 auto b_blockwise_copy =
1938 BElementwiseOperation,
1942 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1943 BBlockTransferThreadClusterArrangeOrder,
1946 decltype(b_grid_desc_bk0_n_bk1),
1947 decltype(b_block_desc_bk0_n_bk1),
1948 BBlockTransferSrcAccessOrder,
1950 BBlockTransferSrcVectorDim,
1952 BBlockTransferSrcScalarPerVector,
1953 BBlockTransferDstScalarPerVector_BK1,
1956 BThreadTransferSrcResetCoordinateAfterRun,
1958 BlockwiseGemmPipe::GlobalBufferNum>(
1959 b_grid_desc_bk0_n_bk1,
1961 problem.b_element_op_,
1962 b_block_desc_bk0_n_bk1,
1968 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1971 static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1975 a_block_space_size_aligned *
sizeof(ADataType)),
1976 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1979 static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1983 a_block_space_size_aligned *
sizeof(ADataType)),
1984 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1986 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
1987 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
1993 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1995 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1997 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1998 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
2002 a_block_desc_ak0_m_ak1,
2006 a_block_slice_copy_step,
2007 b_grid_desc_bk0_n_bk1,
2008 b_block_desc_bk0_n_bk1,
2012 b_block_slice_copy_step,
2014 num_k_block_main_loop);
2018 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2019 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2022 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2023 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2026 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2027 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2031 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2032 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2034 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2035 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2036 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2037 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2038 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2039 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2040 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2041 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2043 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2047 static_cast<CShuffleDataType*
>(p_shared_0),
2048 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2051 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2071 const auto c_thread_mtx_on_block =
2072 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2074 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2075 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2077 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2083 const auto m_thread_data_on_block_idx =
2084 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2087 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2093 const auto n_thread_data_on_block_idx =
2094 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2098 auto c_thread_copy_vgpr_to_lds =
2101 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2102 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2104 Sequence<CShuffleMXdlPerWavePerShuffle,
2105 CShuffleNXdlPerWavePerShuffle,
2118 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2121 m_thread_data_on_block_idx[
I1],
2122 n_thread_data_on_block_idx[
I1],
2123 m_thread_data_on_block_idx[
I2],
2124 m_thread_data_on_block_idx[
I3],
2125 m_thread_data_on_block_idx[
I4],
2126 n_thread_data_on_block_idx[
I2]),
2132 CElementwiseOperation,
2133 CGlobalMemoryDataOperation,
2135 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2137 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2138 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2142 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2143 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2146 CShuffleBlockTransferScalarPerVector_NPerBlock,
2149 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2151 c_grid_desc_mblock_mperblock_nblock_nperblock,
2153 problem.c_element_op_};
2156 constexpr auto sfc_c_vgpr =
2159 Sequence<CShuffleMXdlPerWavePerShuffle,
2160 CShuffleNXdlPerWavePerShuffle,
2169 constexpr auto sfc_c_global =
2173 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2175 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2177 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2179 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
2186 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2187 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2189 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2190 c_shuffle_block_buf);
2196 c_shuffle_block_copy_lds_to_global.Run(
2197 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2198 c_shuffle_block_buf,
2199 c_grid_desc_mblock_mperblock_nblock_nperblock,
2202 if constexpr(access_id < num_access - 1)
2204 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2207 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2208 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2214 template <
bool HasMainKBlockLoop,
2217 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
2218 const BDataType* p_b_grid,
2219 CDataType* p_c_grid,
2222 const Problem& problem)
2225 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2227 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2229 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2231 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2233 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2235 Run_2Lds<
decltype(a_grid_desc_ak0_m_ak1),
2236 decltype(b_grid_desc_bk0_n_bk1),
2237 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2239 CGlobalMemoryDataOperation,
2246 a_grid_desc_ak0_m_ak1,
2247 b_grid_desc_bk0_n_bk1,
2248 c_grid_desc_mblock_mperblock_nblock_nperblock);
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:716
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:717
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:759
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:760
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:748
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:753
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:758
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:761
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:702
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:700
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:697
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:642
CElementwiseOperation c_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:711
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:706
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:708
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:701
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:696
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:698
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:704
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:699
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:707
BElementwiseOperation b_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:710
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:705
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:703
AElementwiseOperation a_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:709
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:673
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:814
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:815
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:767
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:816
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBGridDescriptor_BK0_N_BK1 __host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:451
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKRead static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:337
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::is_scale_mfma static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:273
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeGemmMmaTileDescriptor __host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:355
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKPadded static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:314
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateMPadded static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:304
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BK1Number static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:261
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::APackedSize static constexpr index_t APackedSize
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:285
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::is_single_rate_mfma static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:264
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetSharedMemoryNumberOfByte static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1138
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BlockwiseGemmPipe remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, BlkGemmPipeSched, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1112
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1 static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:819
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:283
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateAK0Padded static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:319
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I2 static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:250
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::KPack static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:274
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::lcm_AK1_BK1 static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:263
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::IsValidCompilationParameter static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1166
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I7 static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:255
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKBlockLoopTailNum static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1405
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateHasMainKBlockLoop static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1398
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I5 static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:253
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::AK1Number static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:260
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1097
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1853
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:299
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateMBlock static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:344
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAMmaTileDescriptor_M0_M1_M2_K __host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:563
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1413
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:2217
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateNPadded static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:309
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1 static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:960
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateBK0Padded static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:325
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BPackedSize static constexpr index_t BPackedSize
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:292
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1437
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBMmaTileDescriptor_N0_N1_N2_K __host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:572
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I6 static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:254
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I1 static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:249
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I0 static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:248
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I3 static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:251
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1816
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::Block2CTileMap BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1428
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I4 static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:252
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKPadded static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:331
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAGridDescriptor_AK0_M_AK1 __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:369
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateNBlock static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:349
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BK0Number static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:259
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::AK0Number static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:258
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeCGridDescriptor_M_N __host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:580
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129