device_gemm_dl.hpp Source File

device_gemm_dl.hpp Source File

Composable Kernel: device_gemm_dl.hpp Source File
device_gemm_dl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <
25 typename ADataType,
26 typename BDataType,
27 typename CDataType,
28 typename AccDataType,
29 typename ALayout,
30 typename BLayout,
31 typename CLayout,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
35 GemmSpecialization GemmSpec,
36 index_t BlockSize,
37 index_t MPerBlock,
38 index_t NPerBlock,
39 index_t K0PerBlock,
40 index_t K1,
41 index_t M1PerThread,
42 index_t N1PerThread,
43 index_t KPerThread,
44 typename M1N1ThreadClusterM1Xs,
45 typename M1N1ThreadClusterN1Xs,
46 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
47 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
48 typename ABlockTransferThreadClusterArrangeOrder,
49 typename ABlockTransferSrcAccessOrder,
50 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
51 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
52 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
53 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
54 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
55 typename BBlockTransferThreadClusterArrangeOrder,
56 typename BBlockTransferSrcAccessOrder,
57 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
58 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
59 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
60 typename CThreadTransferSrcDstAccessOrder,
61 index_t CThreadTransferSrcDstVectorDim,
62 index_t CThreadTransferDstScalarPerVector,
67 bool> = false>
68struct DeviceGemmDl : public DeviceGemm<ALayout,
69 BLayout,
70 CLayout,
71 ADataType,
72 BDataType,
73 CDataType,
74 AElementwiseOperation,
75 BElementwiseOperation,
76 CElementwiseOperation>
77
78{
79 static constexpr auto I0 = Number<0>{};
80 static constexpr auto I1 = Number<1>{};
81 static constexpr auto I2 = Number<2>{};
82 static constexpr auto I3 = Number<3>{};
83 static constexpr auto I4 = Number<4>{};
84 static constexpr auto I5 = Number<5>{};
85
86 static constexpr auto K1Number = Number<K1>{};
87
89 {
90 assert(K % K1 == 0);
91
92 const index_t K0 = K / K1;
93
94 const auto a_grid_desc_m_k = [&]() {
96 {
98 }
100 {
101 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
102 }
103 }();
104
105 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
106 {
107 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
108
110 a_grid_desc_m_k,
112 make_right_pad_transform(M, PadM)),
115 }
116 else
117 {
119 a_grid_desc_m_k,
124 }
125 }
126
128 {
129 assert(K % K1 == 0);
130
131 const index_t K0 = K / K1;
132
133 const auto b_grid_desc_k_n = [&]() {
135 {
136 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
137 }
139 {
140 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
141 }
142 }();
143
144 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
145 {
146 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
147
149 b_grid_desc_k_n,
151 make_right_pad_transform(N, PadN)),
154 }
155 else
156 {
158 b_grid_desc_k_n,
163 }
164 }
165
167 {
168 const auto c_grid_desc_m_n = [&]() {
170 {
171 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
172 }
174 {
175 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
176 }
177 }();
178
179 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
180 {
181 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
182 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
183
185 c_grid_desc_m_n,
189 }
190 else
191 {
192
194 c_grid_desc_m_n,
198 }
199 }
200
203 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
204
205 // GridwiseGemm
208 ADataType,
209 AccDataType,
210 CDataType,
215 MPerBlock,
216 NPerBlock,
217 K0PerBlock,
218 K1,
219 M1PerThread,
220 N1PerThread,
221 KPerThread,
222 M1N1ThreadClusterM1Xs,
223 M1N1ThreadClusterN1Xs,
224 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
225 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
226 ABlockTransferThreadClusterArrangeOrder,
227 ABlockTransferSrcAccessOrder,
228 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
229 ABlockTransferSrcVectorTensorContiguousDimOrder,
230 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
231 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
232 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
233 BBlockTransferThreadClusterArrangeOrder,
234 BBlockTransferSrcAccessOrder,
235 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
236 BBlockTransferSrcVectorTensorContiguousDimOrder,
237 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
238 CThreadTransferSrcDstAccessOrder,
239 CThreadTransferSrcDstVectorDim,
240 CThreadTransferDstScalarPerVector>;
241
250
251 // Argument
252 struct Argument : public BaseArgument
253 {
254 Argument(const ADataType* p_a_grid,
255 const BDataType* p_b_grid,
256 CDataType* p_c_grid,
257 index_t M,
258 index_t N,
259 index_t K,
260 index_t StrideA,
261 index_t StrideB,
262 index_t StrideC,
263 index_t M01,
264 index_t N01,
265 AElementwiseOperation a_element_op,
266 BElementwiseOperation b_element_op,
267 CElementwiseOperation c_element_op)
268 : p_a_grid_{p_a_grid},
269 p_b_grid_{p_b_grid},
270 p_c_grid_{p_c_grid},
275 M01_{M01},
276 N01_{N01},
277 M_raw_{M},
278 N_raw_{N},
279 K_raw_{K},
280 a_element_op_{a_element_op},
281 b_element_op_{b_element_op},
282 c_element_op_{c_element_op}
283 {
287
290 {
297
299 }
300 }
301
302 // private:
303 const ADataType* p_a_grid_;
304 const BDataType* p_b_grid_;
305 CDataType* p_c_grid_;
306
310
314
316
317 // TODO: unused, but may be useful in future.
320
324
325 // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
326 AElementwiseOperation a_element_op_;
327 BElementwiseOperation b_element_op_;
328 CElementwiseOperation c_element_op_;
329 };
330
331 // Invoker
332 struct Invoker : public BaseInvoker
333 {
335
// Launches the DL GEMM kernel for `arg` and returns whatever
// launch_and_time_kernel reports for that launch.
//
// arg           - fully constructed Argument (pointers plus grid descriptors).
// stream_config - forwarded unchanged to launch_and_time_kernel.
//
// Throws std::runtime_error when the guard below fails; the guard's
// condition is not visible in this listing.
336 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
337 {
// Optional diagnostics, enabled through the CK_LOGGING environment switch.
// NOTE(review): the printed labels say "a_grid_desc_k0_m0_m1_k1_" and
// "b_grid_desc_k0_n0_n1_k1_", but the lengths are read from
// a_grid_desc_k0_m_k1_ / b_grid_desc_k0_n_k1_ (three dimensions each).
// Confirm which descriptor was meant to be logged.
// NOTE(review): std::endl flushes the stream on every line.
338 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
339 {
340 std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
341 << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
342 << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
343 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
344
345 std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
346 << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
347 << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
348 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
349
350 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
351 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
352 }
353
// Validity guard (its condition is on lines missing from this listing).
// NOTE(review): the message names "GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3",
// which does not match the GridwiseGemm type used by this class; it looks
// copied from another device op and may mislead whoever reads the exception.
356 {
357 throw std::runtime_error(
358 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
359 }
360
// Tail of the grid_size computation: it is derived from the two lengths of
// the C descriptor (the callee is on a line missing from this listing).
362 arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
363
// K0 is the first length of the 4-D A descriptor; the two flags below choose
// which kernel instantiation is launched.
364 const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
365 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
366 const bool has_double_tail_k_block_loop =
368
369 float ave_time = 0;
370
// Four-way dispatch: the last two template arguments of the kernel mirror
// (has_main_k_block_loop, has_double_tail_k_block_loop). Each branch launches
// dim3(grid_size) blocks of dim3(BlockSize) threads and passes 0 as the
// lds_byte argument of launch_and_time_kernel.
371 if(has_main_k_block_loop && has_double_tail_k_block_loop)
372 {
373 const auto kernel =
374 kernel_gemm_dl_v1r3<GridwiseGemm,
375 ADataType,
376 CDataType,
381 true,
382 true>;
383
384 ave_time = launch_and_time_kernel(stream_config,
385 kernel,
386 dim3(grid_size),
387 dim3(BlockSize),
388 0,
389 arg.p_a_grid_,
390 arg.p_b_grid_,
391 arg.p_c_grid_,
396 }
397 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
398 {
399 const auto kernel =
400 kernel_gemm_dl_v1r3<GridwiseGemm,
401 ADataType,
402 CDataType,
407 true,
408 false>;
409
410 ave_time = launch_and_time_kernel(stream_config,
411 kernel,
412 dim3(grid_size),
413 dim3(BlockSize),
414 0,
415 arg.p_a_grid_,
416 arg.p_b_grid_,
417 arg.p_c_grid_,
422 }
423 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
424 {
425 const auto kernel =
426 kernel_gemm_dl_v1r3<GridwiseGemm,
427 ADataType,
428 CDataType,
433 false,
434 true>;
435
436 ave_time = launch_and_time_kernel(stream_config,
437 kernel,
438 dim3(grid_size),
439 dim3(BlockSize),
440 0,
441 arg.p_a_grid_,
442 arg.p_b_grid_,
443 arg.p_c_grid_,
448 }
449 else
450 {
// Neither flag set (the kernel template name for this branch is on a line
// missing from this listing).
451 const auto kernel =
453 ADataType,
454 CDataType,
459 false,
460 false>;
461
462 ave_time = launch_and_time_kernel(stream_config,
463 kernel,
464 dim3(grid_size),
465 dim3(BlockSize),
466 0,
467 arg.p_a_grid_,
468 arg.p_b_grid_,
469 arg.p_c_grid_,
474 }
475
476 return ave_time;
477 }
478
479 // polymorphic
480 float Run(const BaseArgument* p_arg,
481 const StreamConfig& stream_config = StreamConfig{}) override
482 {
483 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
484 }
485 };
486
/// Reports whether this instance's compile-time parameters are usable.
/// Placeholder: every configuration is currently accepted.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
492
// Runtime check of `arg` against this instance's vector-access settings and
// the current device. Returns false as soon as one divisibility test fails.
//
// The two layout-selecting `if` conditions and the body of the final device
// check are on lines missing from this listing; comments below describe only
// what is visible.
493 static bool IsSupportedArgument(const Argument& arg)
494 {
495 // Make sure that the M, N, K dimensions before padding are divisible by respective vector
496 // lengths.
// A operand, first branch: the raw K must be a multiple of the product of
// the K0 (index 0) and K1 (index 3) source-vector lengths.
498 {
499 constexpr auto A_K_vec_length =
500 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
501 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
502 if(arg.K_raw_ % A_K_vec_length != 0)
503 {
504 return false;
505 }
506 }
507 else
508 {
// A operand, other branch: the raw M must be a multiple of the product of
// the M0 (index 1) and M1 (index 2) source-vector lengths.
// NOTE(review): "lenght" is a typo for "length" in this local name.
509 constexpr auto A_M_vec_lenght =
510 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
511 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
512 if(arg.M_raw_ % A_M_vec_lenght != 0)
513 {
514 return false;
515 }
516 }
517
// B operand, first branch: the raw N must be a multiple of the product of
// the N0 (index 1) and N1 (index 2) source-vector lengths.
// NOTE(review): "lenght" is a typo for "length" in this local name.
519 {
520 constexpr auto B_N_vec_lenght =
521 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
522 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
523 if(arg.N_raw_ % B_N_vec_lenght != 0)
524 {
525 return false;
526 }
527 }
528 else
529 {
// B operand, other branch: the raw K must be a multiple of the product of
// the K0 (index 0) and K1 (index 3) source-vector lengths.
530 constexpr auto B_K_vec_length =
531 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
532 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
533 if(arg.K_raw_ % B_K_vec_length != 0)
534 {
535 return false;
536 }
537 }
538
// Device gate: gfx906 and the gfx103 family are tested here; the rest of the
// condition and the branch body are not visible. Anything that falls through
// is reported as unsupported.
539 if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
541 {
544 }
545 return false;
546 }
547
548 // polymorphic
549 bool IsSupportedArgument(const BaseArgument* p_arg) override
550 {
551 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
552 }
553
554 static auto MakeArgument(const ADataType* p_a,
555 const BDataType* p_b,
556 CDataType* p_c,
557 index_t M,
558 index_t N,
559 index_t K,
560 index_t StrideA,
561 index_t StrideB,
562 index_t StrideC,
563 AElementwiseOperation a_element_op,
564 BElementwiseOperation b_element_op,
565 CElementwiseOperation c_element_op)
566 {
567 return Argument{p_a,
568 p_b,
569 p_c,
570 M,
571 N,
572 K,
573 StrideA,
574 StrideB,
575 StrideC,
576 1,
577 1,
578 a_element_op,
579 b_element_op,
580 c_element_op};
581 }
582
583 static auto MakeInvoker() { return Invoker{}; }
584
585 // polymorphic
586 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
587 const void* p_b,
588 void* p_c,
589 index_t M,
590 index_t N,
591 index_t K,
592 index_t StrideA,
593 index_t StrideB,
594 index_t StrideC,
595 AElementwiseOperation a_element_op,
596 BElementwiseOperation b_element_op,
597 CElementwiseOperation c_element_op) override
598 {
599 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
600 static_cast<const BDataType*>(p_b),
601 static_cast<CDataType*>(p_c),
602 M,
603 N,
604 K,
605 StrideA,
606 StrideB,
607 StrideC,
608 1,
609 1,
610 a_element_op,
611 b_element_op,
612 c_element_op);
613 }
614
615 // polymorphic
616 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
617 {
618 return std::make_unique<Invoker>(Invoker{});
619 }
620
621 // polymorphic
622 virtual std::string GetTypeString() const override
623 {
624 auto str = std::stringstream();
625
626 // clang-format off
627 str << "DeviceGemmDl"
628 << "<"
629 << BlockSize << ", "
630 << MPerBlock << ", "
631 << NPerBlock << ", "
632 << K0PerBlock << ", "
633 << K1 << ", "
634 << M1PerThread << ", "
635 << N1PerThread << ", "
636 << KPerThread
637 << ">";
638 // clang-format on
639
640 return str.str();
641 }
642};
643
644} // namespace device
645} // namespace tensor_operation
646} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:93
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
index_t M_raw_
Definition device_gemm_dl.hpp:321
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_gemm_dl.hpp:307
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_dl.hpp:309
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_gemm_dl.hpp:312
index_t M01_
Definition device_gemm_dl.hpp:318
index_t N01_
Definition device_gemm_dl.hpp:319
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_gemm_dl.hpp:313
index_t K_raw_
Definition device_gemm_dl.hpp:323
CDataType * p_c_grid_
Definition device_gemm_dl.hpp:305
index_t N_raw_
Definition device_gemm_dl.hpp:322
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_gemm_dl.hpp:308
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_dl.hpp:254
const BDataType * p_b_grid_
Definition device_gemm_dl.hpp:304
AElementwiseOperation a_element_op_
Definition device_gemm_dl.hpp:326
BElementwiseOperation b_element_op_
Definition device_gemm_dl.hpp:327
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_dl.hpp:315
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_gemm_dl.hpp:311
CElementwiseOperation c_element_op_
Definition device_gemm_dl.hpp:328
const ADataType * p_a_grid_
Definition device_gemm_dl.hpp:303
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_dl.hpp:336
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_dl.hpp:480
DeviceGemmDl::Argument Argument
Definition device_gemm_dl.hpp:334
Definition device_gemm_dl.hpp:78
static constexpr auto I0
Definition device_gemm_dl.hpp:79
static constexpr auto I2
Definition device_gemm_dl.hpp:81
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_gemm_dl.hpp:246
virtual std::string GetTypeString() const override
Definition device_gemm_dl.hpp:622
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_dl.hpp:549
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_gemm_dl.hpp:202
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_dl.hpp:493
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_gemm_dl.hpp:242
static constexpr auto I3
Definition device_gemm_dl.hpp:82
static auto MakeInvoker()
Definition device_gemm_dl.hpp:583
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_gemm_dl.hpp:201
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_gemm_dl.hpp:127
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_gemm_dl.hpp:206
static constexpr auto I5
Definition device_gemm_dl.hpp:84
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_dl.hpp:203
static constexpr auto I1
Definition device_gemm_dl.hpp:80
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_dl.hpp:487
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_dl.hpp:554
static constexpr auto I4
Definition device_gemm_dl.hpp:83
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_gemm_dl.hpp:248
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_gemm_dl.hpp:244
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_dl.hpp:616
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition device_gemm_dl.hpp:166
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_gemm_dl.hpp:88
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_dl.hpp:586
static constexpr auto K1Number
Definition device_gemm_dl.hpp:86
Definition device_gemm.hpp:22
#define CK_ENV(name)
Definition utility/env.hpp:129