device_cgemm_4gemm_xdl_cshuffle.hpp Source File

device_cgemm_4gemm_xdl_cshuffle.hpp Source File

Composable Kernel: device_cgemm_4gemm_xdl_cshuffle.hpp Source File
device_cgemm_4gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
26template <
27 typename ALayout,
28 typename BLayout,
29 typename CLayout,
30 typename ADataType,
31 typename BDataType,
32 typename CDataType,
33 typename GemmAccDataType,
34 typename CShuffleDataType,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CElementwiseOperation,
38 GemmSpecialization GemmSpec,
39 index_t NumGemmKPrefetchStage,
40 index_t BlockSize,
41 index_t MPerBlock,
42 index_t NPerBlock,
43 index_t KPerBlock,
44 index_t AK1,
45 index_t BK1,
46 index_t MPerXDL,
47 index_t NPerXDL,
48 index_t MXdlPerWave,
49 index_t NXdlPerWave,
50 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51 typename ABlockTransferThreadClusterArrangeOrder,
52 typename ABlockTransferSrcAccessOrder,
53 index_t ABlockTransferSrcVectorDim,
54 index_t ABlockTransferSrcScalarPerVector,
55 index_t ABlockTransferDstScalarPerVector_AK1,
56 bool ABlockLdsExtraM,
57 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58 typename BBlockTransferThreadClusterArrangeOrder,
59 typename BBlockTransferSrcAccessOrder,
60 index_t BBlockTransferSrcVectorDim,
61 index_t BBlockTransferSrcScalarPerVector,
62 index_t BBlockTransferDstScalarPerVector_BK1,
63 bool BBlockLdsExtraN,
64 index_t CShuffleMXdlPerWavePerShuffle,
65 index_t CShuffleNXdlPerWavePerShuffle,
66 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73 bool> = false>
75 : public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
76{
79 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
80 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
81
82 static constexpr auto I0 = Number<0>{};
83 static constexpr auto I1 = Number<1>{};
84 static constexpr auto I2 = Number<2>{};
85
86 static constexpr index_t MPerThread =
87 MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
88 static constexpr index_t NPerThread =
89 NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
90
91 static constexpr auto AScalarPerVector = Number<4>{};
92 static constexpr auto BScalarPerVector = Number<4>{};
93 static constexpr auto CScalarPerVector = Number<4>{};
94
95 template <typename Desc_M_N>
96 static auto PadDescriptor_M_N(Desc_M_N desc)
97 {
98 const auto M = desc.GetLength(I0);
99 const auto N = desc.GetLength(I1);
100 const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
101 const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;
102
103 const auto padded_desc = transform_tensor_descriptor(
104 desc,
108
109 return padded_desc;
110 }
111
112 static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
113 const std::vector<index_t>& strides)
114 {
115 auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
116 auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
117
118 // nd desc - [s0, s1, s2, ...]
119 const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
120 return PadDescriptor_M_N(desc);
121 }
122
123 // GridwiseGemm
124 template <index_t NXdlPerWave_>
126 ALayout,
127 BLayout,
128 CLayout,
129 ADataType,
130 BDataType,
131 GemmAccDataType,
132 CShuffleDataType,
133 CDataType,
134 AElementwiseOperation,
135 BElementwiseOperation,
136 CElementwiseOperation,
137 GemmSpec,
139 NumGemmKPrefetchStage,
140 BlockSize,
141 MPerBlock,
142 NPerBlock,
143 KPerBlock,
144 AK1,
145 BK1,
146 MPerXDL,
147 NPerXDL,
148 MXdlPerWave,
149 NXdlPerWave_,
150 ABlockTransferThreadClusterLengths_AK0_M_AK1,
151 ABlockTransferThreadClusterArrangeOrder,
152 ABlockTransferSrcAccessOrder,
153 ABlockTransferSrcVectorDim,
154 ABlockTransferSrcScalarPerVector,
155 ABlockTransferDstScalarPerVector_AK1,
156 false,
157 ABlockLdsExtraM,
158 BBlockTransferThreadClusterLengths_BK0_N_BK1,
159 BBlockTransferThreadClusterArrangeOrder,
160 BBlockTransferSrcAccessOrder,
161 BBlockTransferSrcVectorDim,
162 BBlockTransferSrcScalarPerVector,
163 BBlockTransferDstScalarPerVector_BK1,
164 false,
165 BBlockLdsExtraN,
166 CShuffleMXdlPerWavePerShuffle,
167 CShuffleNXdlPerWavePerShuffle,
168 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
169 CShuffleBlockTransferScalarPerVector_NPerBlock,
170 LoopSched>;
173
174 using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));
175
176 // Argument
177 struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm64::Problem
178 {
179 using Problem = typename GridwiseGemm64::Problem;
180
181 Argument(const ADataType* p_a_grid_real_,
182 const ADataType* p_a_grid_imag_,
183 const BDataType* p_b_grid_real_,
184 const BDataType* p_b_grid_imag_,
185 CDataType* p_c_grid_real_,
186 CDataType* p_c_grid_imag_,
187 CDataType* p_workspace,
188 index_t M_,
189 index_t N_,
190 index_t K_,
191 index_t StrideA_,
192 index_t StrideB_,
193 index_t StrideC_)
194 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
195 p_a_grid_real{p_a_grid_real_},
196 p_a_grid_imag{p_a_grid_imag_},
197 p_b_grid_real{p_b_grid_real_},
198 p_b_grid_imag{p_b_grid_imag_},
199 p_c_grid_real{p_c_grid_real_},
200 p_c_grid_imag{p_c_grid_imag_},
201 p_aux_grid{p_workspace}
202 {
204 {
205 c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
206 }
208 {
209 c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
210 }
211
212 p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
213 }
214
215 // private:
216 const ADataType* p_a_grid_real;
217 const ADataType* p_a_grid_imag;
218 const BDataType* p_b_grid_real;
219 const BDataType* p_b_grid_imag;
220 CDataType* p_c_grid_real;
221 CDataType* p_c_grid_imag;
222 CDataType* p_aux_grid;
223 CDataType* p_aux_2_grid;
225 };
226
227 // Invoker
228 struct Invoker : public BaseInvoker
229 {
230 template <typename GridwiseGemm>
231 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
232 {
233 if(stream_config.log_level_ > 0)
234 {
235 arg.Print();
236 }
237
238 typename GridwiseGemm::Problem problem(
239 arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC);
240 if(!GridwiseGemm::CheckValidity(problem))
241 {
242 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
243 }
244
245 index_t gdx, gdy, gdz;
246 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
247
248 const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
249
250 float ave_time = 0;
251
254
256
261 Block2TileMap,
262 Add,
263 BlockSize,
264 MPerBlock,
265 NPerBlock,
271 I1,
272 I1>;
273
274 using GridwiseBinSubtract =
279 Block2TileMap,
280 Subtract,
281 BlockSize,
282 MPerBlock,
283 NPerBlock,
289 I1,
290 I1>;
291
292 const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
293 const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
294 const auto block_2_tile_map = Block2TileMap(M, N);
295
296 const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
301 Block2TileMap,
302 Add>;
303
304 const auto subtract_kernel =
305 kernel_elementwise<GridwiseBinSubtract,
310 Block2TileMap,
311 Subtract>;
312
313 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
314 {
315 const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
316 ADataType,
317 BDataType,
318 CDataType,
319 true>;
320
321 ave_time += launch_and_time_kernel(stream_config,
322 kernel,
323 dim3(gdx, gdy, gdz),
324 dim3(BlockSize),
325 0,
326 arg.p_a_grid_real,
327 arg.p_b_grid_real,
328 arg.p_aux_grid,
329 problem);
330
331 ave_time += launch_and_time_kernel(stream_config,
332 kernel,
333 dim3(gdx, gdy, gdz),
334 dim3(BlockSize),
335 0,
336 arg.p_a_grid_imag,
337 arg.p_b_grid_imag,
338 arg.p_aux_2_grid,
339 problem);
340
341 // c_real = aux - aux_2
342 ave_time += launch_and_time_kernel(
343 stream_config,
344 subtract_kernel,
345 dim3(gdx, gdy, gdz),
346 dim3(BlockSize),
347 0,
350 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
351 const_cast<const CDataType*>(arg.p_aux_2_grid)),
353 block_2_tile_map,
354 Subtract{});
355
356 ave_time += launch_and_time_kernel(stream_config,
357 kernel,
358 dim3(gdx, gdy, gdz),
359 dim3(BlockSize),
360 0,
361 arg.p_a_grid_real,
362 arg.p_b_grid_imag,
363 arg.p_aux_grid,
364 problem);
365
366 ave_time += launch_and_time_kernel(stream_config,
367 kernel,
368 dim3(gdx, gdy, gdz),
369 dim3(BlockSize),
370 0,
371 arg.p_a_grid_imag,
372 arg.p_b_grid_real,
373 arg.p_aux_2_grid,
374 problem);
375
376 // c_imag = aux + aux_2
377 ave_time += launch_and_time_kernel(
378 stream_config,
379 add_kernel,
380 dim3(gdx, gdy, gdz),
381 dim3(BlockSize),
382 0,
385 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
386 const_cast<const CDataType*>(arg.p_aux_2_grid)),
388 block_2_tile_map,
389 Add{});
390 }
391 else
392 {
393 const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
394 ADataType,
395 BDataType,
396 CDataType,
397 false>;
398
399 ave_time += launch_and_time_kernel(stream_config,
400 kernel,
401 dim3(gdx, gdy, gdz),
402 dim3(BlockSize),
403 0,
404 arg.p_a_grid_real,
405 arg.p_b_grid_real,
406 arg.p_aux_grid,
407 problem);
408
409 ave_time += launch_and_time_kernel(stream_config,
410 kernel,
411 dim3(gdx, gdy, gdz),
412 dim3(BlockSize),
413 0,
414 arg.p_a_grid_imag,
415 arg.p_b_grid_imag,
416 arg.p_aux_2_grid,
417 problem);
418
419 // c_real = aux - aux_2
420 ave_time += launch_and_time_kernel(
421 stream_config,
422 subtract_kernel,
423 dim3(gdx, gdy, gdz),
424 dim3(BlockSize),
425 0,
428 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
429 const_cast<const CDataType*>(arg.p_aux_2_grid)),
431 block_2_tile_map,
432 Subtract{});
433
434 ave_time += launch_and_time_kernel(stream_config,
435 kernel,
436 dim3(gdx, gdy, gdz),
437 dim3(BlockSize),
438 0,
439 arg.p_a_grid_real,
440 arg.p_b_grid_imag,
441 arg.p_aux_grid,
442 problem);
443
444 ave_time += launch_and_time_kernel(stream_config,
445 kernel,
446 dim3(gdx, gdy, gdz),
447 dim3(BlockSize),
448 0,
449 arg.p_a_grid_imag,
450 arg.p_b_grid_real,
451 arg.p_aux_2_grid,
452 problem);
453
454 // c_imag = aux + aux_2
455 ave_time += launch_and_time_kernel(
456 stream_config,
457 add_kernel,
458 dim3(gdx, gdy, gdz),
459 dim3(BlockSize),
460 0,
463 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
464 const_cast<const CDataType*>(arg.p_aux_2_grid)),
466 block_2_tile_map,
467 Add{});
468 }
469
470 return ave_time;
471 }
472
474
475 // polymorphic
476 float Run(const BaseArgument* p_arg,
477 const StreamConfig& stream_config = StreamConfig{}) override
478 {
479 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
480 }
481 };
482
/// @brief Compile-time validity hook for this instantiation.
///
/// Currently a placeholder that accepts every parameter combination.
/// @return Always true.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
488
489 static bool IsSupportedArgument(const Argument& arg)
490 {
492 {
493 return false;
494 }
495 if(get_warp_size() == 64)
496 {
497 if constexpr(NXdlPerWave64 > 0)
498 {
500 }
501 }
502 else
503 {
504 if constexpr(NXdlPerWave32 > 0)
505 {
506 typename GridwiseGemm32::Problem problem(
507 arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC);
508 return GridwiseGemm32::CheckValidity(problem);
509 }
510 }
511 return false;
512 }
513
514 // polymorphic
515 bool IsSupportedArgument(const BaseArgument* p_arg) override
516 {
517 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
518 }
519
520 static auto MakeArgument(const ADataType* p_a_real,
521 const ADataType* p_a_imag,
522 const BDataType* p_b_real,
523 const BDataType* p_b_imag,
524 CDataType* p_c_real,
525 CDataType* p_c_imag,
526 CDataType* p_workspace,
527 index_t M,
528 index_t N,
529 index_t K,
530 index_t StrideA,
531 index_t StrideB,
532 index_t StrideC,
533 AElementwiseOperation,
534 BElementwiseOperation,
535 CElementwiseOperation)
536 {
537 return Argument{p_a_real,
538 p_a_imag,
539 p_b_real,
540 p_b_imag,
541 p_c_real,
542 p_c_imag,
543 p_workspace,
544 M,
545 N,
546 K,
547 StrideA,
548 StrideB,
549 StrideC};
550 }
551
552 static auto MakeInvoker() { return Invoker{}; }
553
554 // polymorphic
555 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
556 const void* p_a_imag,
557 const void* p_b_real,
558 const void* p_b_imag,
559 void* p_c_real,
560 void* p_c_imag,
561 void* p_workspace,
562 index_t M,
563 index_t N,
564 index_t K,
565 index_t StrideA,
566 index_t StrideB,
567 index_t StrideC,
568 AElementwiseOperation,
569 BElementwiseOperation,
570 CElementwiseOperation,
571 index_t /* KBatch */ = 1) override
572 {
573 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
574 static_cast<const ADataType*>(p_a_imag),
575 static_cast<const BDataType*>(p_b_real),
576 static_cast<const BDataType*>(p_b_imag),
577 static_cast<CDataType*>(p_c_real),
578 static_cast<CDataType*>(p_c_imag),
579 static_cast<CDataType*>(p_workspace),
580 M,
581 N,
582 K,
583 StrideA,
584 StrideB,
585 StrideC);
586 }
587
588 // polymorphic
589 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
590 {
591 return std::make_unique<Invoker>(Invoker{});
592 }
593
594 // polymorphic
595 std::string GetTypeString() const override
596 {
597 auto str = std::stringstream();
598
599 // clang-format off
600 str << "DeviceCGemm_4Gemm_Xdl_CShuffle"
601 << "<"
602 << BlockSize << ", "
603 << MPerBlock << ", "
604 << NPerBlock << ", "
605 << KPerBlock << ", "
606 << AK1 << ", "
607 << BK1
608 << ">";
609 // clang-format on
610
611 return str.str();
612 }
613
614 static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
615 {
616 const auto c_grid_desc_m_n =
619 N,
621 StrideC);
622
623 return c_grid_desc_m_n.GetElementSpaceSize();
624 }
625
627 index_t N,
628 [[maybe_unused]] index_t K,
629 [[maybe_unused]] index_t StrideA,
630 [[maybe_unused]] index_t StrideB,
631 index_t StrideC) const override
632 {
633 return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
634 }
635
636 std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
637 {
638 const auto* parg = dynamic_cast<const Argument*>(base_arg);
639
640 if(!parg)
641 {
642 std::ostringstream err;
643 err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__
644 << ":" << __LINE__ << ", in function: " << __func__;
645 throw std::runtime_error(err.str());
646 }
647
648 return GetWorkspaceSize(
649 parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
650 }
651};
652
653} // namespace device
654} // namespace tensor_operation
655} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
@ Add
Definition ck.hpp:281
__global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:25
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:121
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:178
CGridDesc_M_N c_grid_desc_m_n
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:224
CDataType * p_c_grid_real
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:220
const BDataType * p_b_grid_imag
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:219
typename GridwiseGemm64::Problem Problem
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:179
const ADataType * p_a_grid_imag
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:217
const ADataType * p_a_grid_real
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:216
CDataType * p_aux_grid
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:222
CDataType * p_aux_2_grid
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:223
CDataType * p_c_grid_imag
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:221
Argument(const ADataType *p_a_grid_real_, const ADataType *p_a_grid_imag_, const BDataType *p_b_grid_real_, const BDataType *p_b_grid_imag_, CDataType *p_c_grid_real_, CDataType *p_c_grid_imag_, CDataType *p_workspace, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:181
const BDataType * p_b_grid_real
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:218
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:229
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:231
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:476
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:76
static constexpr auto I2
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:84
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a_real, const void *p_a_imag, const void *p_b_real, const void *p_b_imag, void *p_c_real, void *p_c_imag, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t=1) override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:555
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:614
static auto MakeDescriptor_M_N(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:112
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:515
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:125
static constexpr bool IsValidCompilationParameter()
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:483
static auto MakeInvoker()
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:552
static constexpr auto NXdlPerWave32
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:80
std::size_t GetWorkspaceSize(index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC) const override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:626
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:171
static constexpr index_t MPerThread
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:86
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:79
decltype(MakeDescriptor_M_N({1, 1}, {1, 1})) CGridDesc_M_N
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:174
static constexpr auto I1
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:83
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:172
static constexpr auto I0
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:82
static constexpr auto CScalarPerVector
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:93
std::string GetTypeString() const override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:595
static auto PadDescriptor_M_N(Desc_M_N desc)
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:96
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:589
static constexpr auto BScalarPerVector
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:92
static auto MakeArgument(const ADataType *p_a_real, const ADataType *p_a_imag, const BDataType *p_b_real, const BDataType *p_b_imag, CDataType *p_c_real, CDataType *p_c_imag, CDataType *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:520
static bool IsSupportedArgument(const Argument &arg)
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:489
std::size_t GetWorkSpaceSize(const BaseArgument *base_arg) const override
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:636
static constexpr index_t NPerThread
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:88
DeviceCGemm_4Gemm_Xdl_CShuffle DeviceOp
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:77
static constexpr auto AScalarPerVector
Definition device_cgemm_4gemm_xdl_cshuffle.hpp:91
Definition device_cgemm.hpp:15
Definition binary_element_wise_operation.hpp:14
Definition binary_element_wise_operation.hpp:247