gridwise_gemm_wmma_cshuffle_v3.hpp Source File

gridwise_gemm_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/env.hpp"
19
20namespace ck {
21
28// operations that could be applied on each tensor respectively. The CDE_op is an
29// elementwise operation applied to the C and all D tensors.
129template <typename ALayout,
130 typename BLayout,
131 typename DsLayout,
132 typename ELayout,
133 typename AsDataType,
134 typename BsDataType,
135 typename AccDataType,
136 typename CShuffleDataType,
137 typename DsDataType,
138 typename EDataType,
139 typename AElementwiseOperation,
140 typename BElementwiseOperation,
141 typename CDEElementwiseOperation,
143 index_t BlockSize,
144 index_t MPerBlock,
145 index_t NPerBlock,
146 index_t KPerBlock,
147 index_t AK1Value,
148 index_t BK1Value,
149 index_t MPerWmma,
150 index_t NPerWmma,
151 index_t MRepeat,
152 index_t NRepeat,
153 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154 typename ABlockTransferThreadClusterArrangeOrder,
155 typename ABlockTransferSrcAccessOrder,
156 index_t ABlockTransferSrcVectorDim,
157 index_t ABlockTransferSrcScalarPerVector,
158 index_t ABlockTransferDstScalarPerVector_AK1,
159 bool AThreadTransferSrcResetCoordinateAfterRun,
160 index_t ABlockLdsExtraM,
161 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
162 typename BBlockTransferThreadClusterArrangeOrder,
163 typename BBlockTransferSrcAccessOrder,
164 index_t BBlockTransferSrcVectorDim,
165 index_t BBlockTransferSrcScalarPerVector,
166 index_t BBlockTransferDstScalarPerVector_BK1,
167 bool BThreadTransferSrcResetCoordinateAfterRun,
168 index_t BBlockLdsExtraN,
169 index_t CShuffleMRepeatPerShuffle,
170 index_t CShuffleNRepeatPerShuffle,
171 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172 typename CDEShuffleBlockTransferScalarPerVectors,
173 BlockGemmPipelineScheduler BlkGemmPipeSched,
174 BlockGemmPipelineVersion BlkGemmPipelineVer,
175 typename ComputeTypeA,
176 typename ComputeTypeB,
177 bool PermuteA,
178 bool PermuteB,
179 bool ForceThreadTileTransfer = false>
182 ALayout,
183 BLayout,
184 DsLayout,
185 ELayout,
186 AsDataType,
187 BsDataType,
188 AccDataType,
189 CShuffleDataType,
190 DsDataType,
191 EDataType,
192 AElementwiseOperation,
193 BElementwiseOperation,
194 CDEElementwiseOperation,
195 GemmSpec,
196 BlockSize,
197 MPerBlock,
198 NPerBlock,
199 KPerBlock,
200 AK1Value,
201 BK1Value,
202 MPerWmma,
203 NPerWmma,
204 MRepeat,
205 NRepeat,
206 ABlockTransferThreadClusterLengths_AK0_M_AK1,
207 ABlockTransferThreadClusterArrangeOrder,
208 ABlockTransferSrcAccessOrder,
209 ABlockTransferSrcVectorDim,
210 ABlockTransferSrcScalarPerVector,
211 ABlockTransferDstScalarPerVector_AK1,
212 AThreadTransferSrcResetCoordinateAfterRun,
213 ABlockLdsExtraM,
214 BBlockTransferThreadClusterLengths_BK0_N_BK1,
215 BBlockTransferThreadClusterArrangeOrder,
216 BBlockTransferSrcAccessOrder,
217 BBlockTransferSrcVectorDim,
218 BBlockTransferSrcScalarPerVector,
219 BBlockTransferDstScalarPerVector_BK1,
220 BThreadTransferSrcResetCoordinateAfterRun,
221 BBlockLdsExtraN,
222 CShuffleMRepeatPerShuffle,
223 CShuffleNRepeatPerShuffle,
224 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
225 CDEShuffleBlockTransferScalarPerVectors,
226 BlkGemmPipeSched,
227 BlkGemmPipelineVer,
228 ComputeTypeA,
229 ComputeTypeB,
230 PermuteA,
231 PermuteB,
232 ForceThreadTileTransfer>
233{
235 ALayout,
236 BLayout,
237 DsLayout,
238 ELayout,
239 AsDataType,
240 BsDataType,
241 AccDataType,
242 CShuffleDataType,
243 DsDataType,
244 EDataType,
245 AElementwiseOperation,
246 BElementwiseOperation,
247 CDEElementwiseOperation,
248 GemmSpec,
249 BlockSize,
250 MPerBlock,
251 NPerBlock,
252 KPerBlock,
253 AK1Value,
254 BK1Value,
255 MPerWmma,
256 NPerWmma,
257 MRepeat,
258 NRepeat,
259 ABlockTransferThreadClusterLengths_AK0_M_AK1,
260 ABlockTransferThreadClusterArrangeOrder,
261 ABlockTransferSrcAccessOrder,
262 ABlockTransferSrcVectorDim,
263 ABlockTransferSrcScalarPerVector,
264 ABlockTransferDstScalarPerVector_AK1,
265 AThreadTransferSrcResetCoordinateAfterRun,
266 ABlockLdsExtraM,
267 BBlockTransferThreadClusterLengths_BK0_N_BK1,
268 BBlockTransferThreadClusterArrangeOrder,
269 BBlockTransferSrcAccessOrder,
270 BBlockTransferSrcVectorDim,
271 BBlockTransferSrcScalarPerVector,
272 BBlockTransferDstScalarPerVector_BK1,
273 BThreadTransferSrcResetCoordinateAfterRun,
274 BBlockLdsExtraN,
275 CShuffleMRepeatPerShuffle,
276 CShuffleNRepeatPerShuffle,
277 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
278 CDEShuffleBlockTransferScalarPerVectors,
279 BlkGemmPipeSched,
280 BlkGemmPipelineVer,
281 ComputeTypeA,
282 ComputeTypeB,
283 PermuteA,
284 PermuteB,
285 ForceThreadTileTransfer>;
286
287 using Base::I0;
288 using Base::I1;
289 using Base::I2;
290 using Base::I3;
291 using Base::I4;
292 using Base::I5;
293 using Base::I6;
294 using Base::I7;
295
296 using Base::AK0Number;
297 using Base::AK1Number;
298 using Base::BK0Number;
299 using Base::BK1Number;
300
301 using Base::APackedSize;
302 using Base::BPackedSize;
303
317
319
321
322 using Base::NumATensor;
323 using Base::NumBTensor;
324 using Base::NumDTensor;
325 using typename Base::AsGridPointer;
326 using typename Base::BsGridPointer;
327 using typename Base::DsGridPointer;
328 using AsDataType_ = AsDataType;
329 using BsDataType_ = BsDataType;
330
331 struct Problem
332 {
// Host-side constructor for the GEMM problem description: stores the sizes
// (M, N, K), the per-tensor strides for the A/B/D tensors, the E stride and
// the split-K batch count, then derives KRead, KPadded, AK0 and BK0 from
// (K_, KBatch_) through the Calculate* helpers.
// NOTE(review): this listing jumps from 348 to 351 and from 354 to 357, so
// some member initializers are not visible here (presumably MPadded/NPadded
// and MBlock/NBlock, which Print() reads) -- confirm against the real header.
333 __host__ Problem(index_t M_,
334 index_t N_,
335 index_t K_,
336 std::array<index_t, NumATensor> StrideAs_,
337 std::array<index_t, NumBTensor> StrideBs_,
338 std::array<index_t, NumDTensor> StrideDs_,
339 index_t StrideE_,
340 index_t KBatch_)
341 : M{M_},
342 N{N_},
343 K{K_},
344 StrideAs{StrideAs_},
345 StrideBs{StrideBs_},
346 StrideDs{StrideDs_},
347 StrideE{StrideE_},
348 KBatch{KBatch_},
// Quantities below depend on both K and the number of split-K batches.
351 KRead{CalculateKRead(K_, KBatch_)},
352 KPadded{CalculateKPadded(K_, KBatch_)},
353 AK0{CalculateAK0Padded(K_, KBatch_)},
354 BK0{CalculateBK0Padded(K_, KBatch_)},
357 {
358 }
359
360 __host__ void Print() const
361 {
362 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
363 << "SAs: {";
364 static_for<0, NumATensor, 1>{}([&](auto i) {
365 std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
366 });
367 std::cout << "}, " << "SBs: {";
368 static_for<0, NumBTensor, 1>{}([&](auto i) {
369 std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
370 });
371 std::cout << "}, ";
372 if constexpr(NumDTensor > 0)
373 {
374 std::cout << "SDs: { ";
375 static_for<0, NumDTensor, 1>{}([&](auto i) {
376 std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
377 });
378 std::cout << " }, ";
379 }
380 std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
381 << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
382 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
383 << ", " << "NBlock: " << NBlock << "}" << std::endl;
384 }
385
389 std::array<index_t, NumATensor> StrideAs;
390 std::array<index_t, NumBTensor> StrideBs;
391 std::array<index_t, NumDTensor> StrideDs;
402 };
403
404 // Argument
406 {
407 __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
408 std::array<const void*, NumBTensor> p_bs_grid_,
409 std::array<const void*, NumDTensor> p_ds_grid_,
410 EDataType* p_e_grid_,
411 index_t M_,
412 index_t N_,
413 index_t K_,
414 std::array<index_t, NumATensor> StrideAs_,
415 std::array<index_t, NumBTensor> StrideBs_,
416 std::array<index_t, NumDTensor> StrideDs_,
417 index_t StrideE_,
418 index_t k_batch_,
419 AElementwiseOperation a_element_op_,
420 BElementwiseOperation b_element_op_,
421 CDEElementwiseOperation cde_element_op_,
422 bool is_reduce_ = false)
423 : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
424 p_as_grid{},
425 p_bs_grid{},
426 p_ds_grid{},
427 p_e_grid{p_e_grid_},
428 a_element_op{a_element_op_},
429 b_element_op{b_element_op_},
430 cde_element_op{cde_element_op_},
431 is_reduce(is_reduce_)
432 {
433 // populate pointer, desc for As
434 static_for<0, NumATensor, 1>{}([&](auto i) {
435 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
436
437 // A pointer
438 p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
439 });
440
441 // populate pointer, desc for Bs
442 static_for<0, NumBTensor, 1>{}([&](auto i) {
443 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
444
445 // B pointer
446 p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
447 });
448
449 // populate pointer, desc for Ds
450 static_for<0, NumDTensor, 1>{}([&](auto i) {
451 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
452
453 // D pointer
454 p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
455 });
456 }
457
458 __host__ __device__ inline bool IsReduceAdd() const
459 {
460 return (Problem::KBatch > 1) && is_reduce;
461 }
462
463 __host__ __device__ inline bool IsAtomicAdd() const
464 {
465 return (Problem::KBatch > 1) && (!is_reduce);
466 }
467
471 EDataType* p_e_grid;
472
473 const AElementwiseOperation a_element_op;
474 const BElementwiseOperation b_element_op;
475 const CDEElementwiseOperation cde_element_op;
476
477 // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
479 };
480
482 {
483
// Computes, for split-K batch `k_id`, the per-tensor element offsets into the
// A and B inputs (a_k_split_offset / b_k_split_offset) and the offset into
// the output buffer (c_reduce_offset).
// Side effect: `karg` is taken by non-const reference and karg.K is
// overwritten with the K extent handled by this batch.
// NOTE(review): the conditions selecting each branch below and the heads of
// the per-tensor loops (listing lines 491, 493, 496, 498, 502, 504, 507, 511,
// 517) are missing from this extraction -- presumably they dispatch on the
// A/B layouts; confirm against the real header.
484 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
485 {
486 // Note: in xdl implementation multiple AB supports one layout
487 // but multiple strides, so we create an array of offsets with
488 // the same values.
489 // It should be fixed later on. Once we will have a thread transfer
490 // more flexible.
492 {
// A offset without a stride factor, scaled down by APackedSize.
494 [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
495 }
497 {
// A offset multiplied by the i-th A stride.
499 [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
500 }
501
503 {
// B offset multiplied by the i-th B stride.
505 [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
506 }
508 {
509 if constexpr(!PermuteB)
510 {
// B offset without a stride factor, scaled down by BPackedSize.
512 [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
513 }
514 else
515 {
// With PermuteB each batch advances by KRead * N elements.
516 const int k0_offset = karg.KRead * karg.N;
518 [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
519 }
520 }
521
// Every batch except the last handles KRead elements of K; the last batch
// takes whatever remains of the original K.
522 if(k_id < karg.KBatch - 1)
523 {
524 karg.K = karg.KRead;
525 }
526 else
527 {
528 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
529 }
530
// In reduce mode each batch gets its own M*N slice of the output; otherwise
// all batches share offset 0.
// NOTE(review): the product is evaluated in index_t (listed as int32_t in
// ck.hpp); k_id * M * N may overflow for large outputs -- confirm the
// supported size range.
531 if(karg.IsReduceAdd())
532 {
533 c_reduce_offset = k_id * karg.M * karg.N;
534 }
535 else
536 {
537 c_reduce_offset = 0;
538 }
539 }
540
541 std::array<index_t, NumATensor> a_k_split_offset;
542 std::array<index_t, NumBTensor> b_k_split_offset;
544 };
545
547
548 // return block_id to C matrix tile idx (m0, n0) mapping
549 // if arch = gfx942
551 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
552
553 __device__ static index_t GetKBlockPerScale() { return 1; }
554
// Per-block entry point operating on the supplied grid pointers: builds the
// A/B/Ds/E grid descriptors from `problem`, maps this block's 1-D id to a C
// tile index, returns early when that tile index is not valid, and otherwise
// forwards everything to Base::Run together with an empty B-scale object.
555 template <bool HasMainKBlockLoop,
556 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
557 TailNumber TailNum,
558 typename EpilogueArgument>
559 __device__ static void Run(AsGridPointer& p_as_grid,
560 BsGridPointer& p_bs_grid,
561 DsGridPointer& p_ds_grid,
562 EDataType* p_e_grid,
563 void* p_shared,
564 const Problem& problem,
565 AElementwiseOperation a_element_op,
566 BElementwiseOperation b_element_op,
567 CDEElementwiseOperation cde_element_op,
568 EpilogueArgument& epilogue_args)
569 {
// Grid descriptors for the inputs and the output, all derived from the
// sizes, padded sizes and strides stored in `problem`.
570 const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
571 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
572 const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
573 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
574 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
575 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
576 const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
577 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
// NOTE(review): listing lines 579 and 582 (the calls that turn the M_N
// descriptors into MBlock/MPerBlock/NBlock/NPerBlock form) are missing from
// this extraction; only their argument lists are visible.
578 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
580 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
581 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
583 e_grid_desc_m_n, problem.MBlock, problem.NBlock);
584
585 // divide block work by [M, N]
// NOTE(review): the meaning of the literal 4 passed to Block2CTileMap is not
// visible here -- confirm against block_to_ctile_map.hpp.
586 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
587
588 const auto block_work_idx =
589 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
590
// Blocks whose tile index is rejected by the map do no work.
591 if(!block_2_ctile_map.ValidCTileIndex(
592 block_work_idx,
593 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
594 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
595 {
596 return;
597 }
598
// presumably readfirstlane is used so the tile ids are wave-uniform values --
// TODO confirm.
599 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
600 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
601
602 // BScale struct (Empty)
603 using BScale = typename BlockwiseGemmPipe::Empty;
604 auto b_scale_struct = BScale{};
605
606 const index_t num_k_block_per_scale = GetKBlockPerScale();
607
// Hand over to the shared base implementation; the descriptor, scale and
// epilogue-argument types are passed explicitly as template arguments.
608 Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
609 decltype(bs_grid_desc_bk0_n_bk1),
610 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
611 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
612 decltype(b_scale_struct),
613 decltype(epilogue_args),
614 HasMainKBlockLoop,
615 EGlobalMemoryDataOperation,
616 TailNum>(p_as_grid,
617 p_bs_grid,
618 p_ds_grid,
619 p_e_grid,
620 p_shared,
621 as_grid_desc_ak0_m_ak1,
622 bs_grid_desc_bk0_n_bk1,
623 ds_grid_desc_mblock_mperblock_nblock_nperblock,
624 e_grid_desc_mblock_mperblock_nblock_nperblock,
625 a_element_op,
626 b_element_op,
627 cde_element_op,
628 block_m_id,
629 block_n_id,
630 num_k_block_per_scale,
631 b_scale_struct,
632 epilogue_args);
633 }
634
635 // Wrapper function to have __global__ function in common
636 // between gemm_universal, b_scale, ab_scale, etc.
// Split-K wrapper around the pointer-based Run overload: advances every A and
// B base pointer by the per-tensor element offset stored in
// `splitk_batch_offset`, advances the E pointer by c_reduce_offset, and passes
// `karg` (as the problem description) and its elementwise operations along.
637 template <bool HasMainKBlockLoop,
638 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
639 TailNumber TailNum,
640 typename EpilogueArgument>
641 __device__ static void Run(void* p_shared,
642 const SplitKBatchOffset& splitk_batch_offset,
643 Argument& karg,
644 EpilogueArgument& epilogue_args)
645 {
646 // shift A matrices pointer for splitk
647 AsGridPointer p_as_grid_splitk;
648 static_for<0, NumATensor, 1>{}([&](auto i) {
649 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// The offset is added after the cast, so it counts elements of ADataType_.
650 p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
651 splitk_batch_offset.a_k_split_offset[i];
652 });
653
654 // shift B matrices pointer for splitk
655 BsGridPointer p_bs_grid_splitk;
656 static_for<0, NumBTensor, 1>{}([&](auto i) {
657 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// The offset is added after the cast, so it counts elements of BDataType_.
658 p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
659 splitk_batch_offset.b_k_split_offset[i];
660 });
661
// NOTE(review): listing line 662 (the head of the forwarding call) is missing
// from this extraction; the arguments below match the parameter order of the
// pointer-based Run overload.
663 p_as_grid_splitk,
664 p_bs_grid_splitk,
665 karg.p_ds_grid,
666 karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
667 p_shared,
668 karg,
669 karg.a_element_op,
670 karg.b_element_op,
671 karg.cde_element_op,
672 epilogue_args);
673 }
674};
675
676} // namespace ck
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition block_to_ctile_map.hpp:271
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:406
const BElementwiseOperation b_element_op
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:474
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:407
const AElementwiseOperation a_element_op
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:473
bool is_reduce
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:478
EDataType * p_e_grid
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:471
const CDEElementwiseOperation cde_element_op
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:475
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:463
BsGridPointer p_bs_grid
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:469
AsGridPointer p_as_grid
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:468
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:458
DsGridPointer p_ds_grid
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:470
std::array< index_t, NumBTensor > StrideBs
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:390
index_t AK0
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:398
index_t N
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:387
index_t K
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:388
index_t BK0
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:399
index_t NPadded
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:395
index_t KPadded
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:397
index_t StrideE
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:392
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:391
index_t NBlock
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:401
std::array< index_t, NumATensor > StrideAs
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:389
__host__ void Print() const
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:360
index_t KBatch
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:393
index_t M
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:386
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:333
index_t MBlock
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:400
index_t MPadded
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:394
index_t KRead
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:396
index_t c_reduce_offset
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:543
std::array< index_t, NumATensor > a_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:541
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:484
std::array< index_t, NumBTensor > b_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:542
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static constexpr auto I2
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
static constexpr auto I3
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:127
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:546
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static constexpr auto I1
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
static constexpr index_t NumATensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
static constexpr auto AK1Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:151
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr auto I6
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:130
static constexpr auto AK0Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:149
static constexpr index_t NumBTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static constexpr auto I0
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static constexpr auto I7
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:131
static constexpr auto I4
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static constexpr index_t BPackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
static constexpr auto BK1Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto BK0Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:150
decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static constexpr index_t NumDTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
static constexpr auto I5
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:129
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
static constexpr index_t NumATensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:641
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false > Base
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:234
static constexpr index_t NumBTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static constexpr index_t BPackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:559
decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static __device__ index_t GetKBlockPerScale()
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:553
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static constexpr index_t NumDTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition functional2.hpp:33
Definition device_base.hpp:197