device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File

device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File#

Composable Kernel: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
// Device kernel for the batched GEMM + softmax + GEMM pipeline (per the file/kernel name).
// Each block works out which batch it serves from its 1-D block id, shifts the A, B, B1, C and
// D0 base pointers to that batch, and then forwards everything to GridwiseGemm::Run.
// The body is only compiled for gfx9 / gfx11 / gfx12 targets; on any other target every
// parameter is merely assigned to `ignore` and the kernel does no work.
// HasMainKBlockLoop is forwarded unchanged as the template argument of GridwiseGemm::Run.
26template <typename GridwiseGemm,
27 typename FloatAB,
28 typename FloatC,
29 typename D0sPointer,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename C0DEElementwiseOperation,
33 typename B1ElementwiseOperation,
34 typename C1DEElementwiseOperation,
35 typename AGridDesc_AK0_M_AK1,
36 typename BGridDesc_BK0_N_BK1,
37 typename B1GridDesc_BK0_N_BK1,
38 typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
39 typename D0sGridDesc_M_N,
40 typename Block2CTileMap,
41 typename ComputeBasePtrOfStridedBatch,
42 typename C0MatrixMask,
43 bool HasMainKBlockLoop>
44__global__ void
45#if CK_USE_LAUNCH_BOUNDS
47#endif
49 const FloatAB* __restrict__ p_a_grid,
50 const FloatAB* __restrict__ p_b_grid,
51 const FloatAB* __restrict__ p_b1_grid,
52 FloatC* __restrict__ p_c_grid,
53 D0sPointer p_d0s_grid,
54 const AElementwiseOperation a_element_op,
55 const BElementwiseOperation b_element_op,
56 const C0DEElementwiseOperation c0de_element_op,
57 const B1ElementwiseOperation b1_element_op,
58 const C1DEElementwiseOperation c1de_element_op,
59 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
60 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
61 const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
62 const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
63 c1_grid_desc_mblock_mperblock_nblock_nperblock,
64 const D0sGridDesc_M_N d0s_griddesc_m_n,
65 const Block2CTileMap block_2_ctile_map,
66 const index_t batch_count,
67 const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
68 const C0MatrixMask c0_matrix_mask)
69{
70#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
71 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
72 {
 // Shared-memory scratch buffer; its byte size is supplied by the gridwise GEMM at compile time.
73 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 // The grid is divided evenly between batches: each run of num_blocks_per_batch consecutive
 // block ids maps to one batch index g_idx. readfirstlane takes the value from the first
 // active lane, so these quantities are uniform across the wave.
74 const index_t num_blocks_per_batch =
75 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
76 const index_t g_idx =
77 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
78
 // Per-tensor element offsets of batch g_idx, widened to long_index_t before being added
 // to the base pointers below.
79 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
80 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
81 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
82 static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
83 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
84 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
85 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
86 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
87
 // Advance every D0 pointer to this block's batch. p_d0s_grid is a by-value parameter, so
 // the adjusted pointers are local to this invocation.
88 static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
89 const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
90 static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
91 p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
92 });
93
 // Run the fused computation on the batch-local pointers. batch_count and
 // compute_base_ptr_of_batch are consumed above and are not forwarded.
94 GridwiseGemm::template Run<HasMainKBlockLoop>(
95 p_a_grid + a_batch_offset,
96 p_b_grid + b_batch_offset,
97 p_b1_grid + b1_batch_offset,
98 p_c_grid + c_batch_offset,
99 p_d0s_grid,
100 p_shared,
101 a_element_op,
102 b_element_op,
103 c0de_element_op,
104 b1_element_op,
105 c1de_element_op,
106 a_grid_desc_ak0_m_ak1,
107 b_grid_desc_bk0_n_bk1,
108 b1_grid_desc_bk0_n_bk1,
109 c1_grid_desc_mblock_mperblock_nblock_nperblock,
110 d0s_griddesc_m_n,
111 block_2_ctile_map,
112 c0_matrix_mask);
113 }
114#else
 // Other targets: touch every parameter through `ignore` (presumably to avoid
 // unused-parameter diagnostics) and do nothing else.
115 ignore = p_a_grid;
116 ignore = p_b_grid;
117 ignore = p_b1_grid;
118 ignore = p_c_grid;
119 ignore = p_d0s_grid;
120 ignore = a_element_op;
121 ignore = b_element_op;
122 ignore = c0de_element_op;
123 ignore = b1_element_op;
124 ignore = c1de_element_op;
125 ignore = a_grid_desc_ak0_m_ak1;
126 ignore = b_grid_desc_bk0_n_bk1;
127 ignore = b1_grid_desc_bk0_n_bk1;
128 ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
129 ignore = d0s_griddesc_m_n;
130 ignore = block_2_ctile_map;
131 ignore = batch_count;
132 ignore = compute_base_ptr_of_batch;
133 ignore = c0_matrix_mask;
134#endif // defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
135}
136
137// Computes C = A * B0 * B1
138// ^^^^^^ (Acc0)
139// ^^^^^^^^^^^ (Acc1)
140template <index_t NumDimG,
141 index_t NumDimM,
142 index_t NumDimN,
143 index_t NumDimK,
144 index_t NumDimO, // NumDimGemm1N
145 typename ADataType,
146 typename BDataType,
147 typename B1DataType,
148 typename CDataType,
149 typename D0sDataType,
150 typename D1sDataType,
151 typename GemmAccDataType,
152 typename CShuffleDataType,
153 typename AElementwiseOperation,
154 typename BElementwiseOperation,
155 typename C0DEElementwiseOperation,
156 typename B1ElementwiseOperation,
157 typename C1DEElementwiseOperation,
158 GemmSpecialization GemmSpec,
163 index_t NumGemmKPrefetchStage,
164 index_t BlockSize,
165 index_t MPerBlock,
166 index_t NPerBlock, // Gemm0NPerBlock
167 index_t KPerBlock, // Gemm0KPerBlock
168 index_t Gemm1NPerBlock,
169 index_t Gemm1KPerBlock,
170 index_t AK1,
171 index_t BK1,
172 index_t B1K1,
173 index_t MPerXDL,
174 index_t NPerXDL,
175 index_t MXdlPerWave,
176 index_t NXdlPerWave,
177 index_t Gemm1NXdlPerWave,
178 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
179 typename ABlockTransferThreadClusterArrangeOrder,
180 typename ABlockTransferSrcAccessOrder,
181 index_t ABlockTransferSrcVectorDim,
182 index_t ABlockTransferSrcScalarPerVector,
183 index_t ABlockTransferDstScalarPerVector_AK1,
184 bool ABlockLdsExtraM,
185 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
186 typename BBlockTransferThreadClusterArrangeOrder,
187 typename BBlockTransferSrcAccessOrder,
188 index_t BBlockTransferSrcVectorDim,
189 index_t BBlockTransferSrcScalarPerVector,
190 index_t BBlockTransferDstScalarPerVector_BK1,
191 bool BBlockLdsExtraN,
192 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
193 typename B1BlockTransferThreadClusterArrangeOrder,
194 typename B1BlockTransferSrcAccessOrder,
195 index_t B1BlockTransferSrcVectorDim,
196 index_t B1BlockTransferSrcScalarPerVector,
197 index_t B1BlockTransferDstScalarPerVector_BK1,
198 bool B1BlockLdsExtraN,
199 index_t CShuffleMXdlPerWavePerShuffle,
200 index_t CShuffleNXdlPerWavePerShuffle,
201 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
202 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
203 MaskingSpecialization MaskingSpec,
204 int D0sTransferSrcScalarPerVector = 4,
208 NumDimM,
209 NumDimN,
210 NumDimK,
211 NumDimO,
212 ADataType,
213 BDataType,
214 B1DataType,
215 CDataType,
216 D0sDataType,
217 D1sDataType,
218 AElementwiseOperation,
219 BElementwiseOperation,
220 C0DEElementwiseOperation,
221 B1ElementwiseOperation,
222 C1DEElementwiseOperation,
223 MaskingSpec>
224{
// Per-wave M repeat counts obtained from GetNXdlPerWave2 for the two wave-size variants
// (last template argument true -> "64", false -> "32"). A value that is not positive disables
// the matching path in Invoker::Run and IsSupportedArgument.
// NOTE(review): the N/M block and XDL sizes are passed in swapped order relative to the
// helper's name — presumably intentional so that it yields an M count; confirm.
225 static constexpr auto MXdlPerWave64 =
226 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
227 static constexpr auto MXdlPerWave32 =
228 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
229
230 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
231 "Number of dimension must be greater than 0");
232
233 static constexpr index_t NumD0Tensor = D0sDataType::Size();
234 static constexpr index_t NumD1Tensor = D1sDataType::Size();
235
236 // TODO ANT: implement bias combination
237 static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented");
238
239#if 0
240 // TODO ANT: use alias
241 static constexpr index_t NumDimGemm0M = NumDimM;
242 static constexpr index_t NumDimGemm0N = NumDimN;
243 static constexpr index_t NumDimGemm0K = NumDimK;
244 static constexpr index_t NumDimGemm1M = NumDimM;
245 static constexpr index_t NumDimGemm1N = NumDimO;
246 static constexpr index_t NumDimGemm1K = NumDimN;
247#endif
248
250
251 static constexpr auto I0 = Number<0>{};
252 static constexpr auto I1 = Number<1>{};
253 static constexpr auto I2 = Number<2>{};
254
258 GemmSpec,
259 ASpec,
260 BSpec,
261 B1Spec,
262 CSpec>;
263
// Builds the A grid descriptor in AK0/M/AK1 form (per the function name) from the raw
// per-dimension lengths and strides of A. The M x K descriptor produced by
// Transform::MakeAGridDescriptor_M_K is passed on together with Number<AK1>{}.
264 static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
265 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
266 {
268 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
269 Number<AK1>{});
270 }
271
// Builds the B (Gemm0 B operand) grid descriptor in BK0/N/BK1 form (per the function name)
// from the raw per-dimension lengths and strides of B. The N x K descriptor produced by
// Transform::MakeB0GridDescriptor_N_K is passed on together with Number<BK1>{}.
272 static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
273 const std::vector<index_t>& b_gs_ns_ks_strides_vec)
274 {
276 Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
277 Number<BK1>{});
278 }
279
// Builds the B1 (Gemm1 B operand) grid descriptor in BK0/N/BK1 form (per the function name)
// from the raw per-dimension lengths and strides of B1. The N x K descriptor produced by
// Transform::MakeB1GridDescriptor_N_K is passed on together with Number<B1K1>{}.
280 static auto
281 MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
282 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
283 {
285 Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
286 b1_gs_gemm1ns_gemm1ks_strides_vec),
287 Number<B1K1>{});
288 }
289
291 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
292 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
293 {
294 return generate_tuple(
295 [&](auto i) {
296 return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
297 acc0_biases_gs_ms_ns_strides[i]);
298 },
300 }
301
303 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
304 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
305 {
306 return generate_tuple(
307 [&](auto i) {
308 return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
309 acc0_biases_gs_ms_ns_strides[i]);
310 },
312 }
313
324
// Chooses the mask predicate object at compile time from MaskingSpec.
// MaskDisabled yields MaskDisabledPredicate{}; MaskOutUpperTriangle is handled by the second
// branch.
// NOTE(review): there is no trailing else, so any other MaskingSpec value would make this
// function return void — confirm these two enumerators are the only ones.
325 constexpr static auto make_MaskOutPredicate()
326 {
327 if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
328 {
329 return MaskDisabledPredicate{};
330 }
331 else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
332 {
334 }
335 }
337
339 {
341 const BGridDesc_G_N_K& b_grid_desc_g_n_k,
342 const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
343 const C1GridDesc_G_M_N& c1_grid_desc_g_m_n,
344 const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n)
345 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
346 b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
347 b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
348 c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n),
349 d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n)
350 {
351 }
352
353 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
354 {
355 return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
356 }
357
358 __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
359 {
360 return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
361 }
362
363 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
364 {
365 return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
366 }
367
368 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
369 {
370 return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
371 }
372
373 template <index_t I>
374 __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
375 Number<I> d0_idx) const
376 {
377 return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
378 }
379
380 private:
381 AGridDesc_G_M_K a_grid_desc_g_m_k_;
382 BGridDesc_G_N_K b_grid_desc_g_n_k_;
383 B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
384 C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
385 D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
386 };
387
388 // GridwiseGemm
389 template <index_t MXdlPerWave_>
391 ADataType, // TODO: distinguish A/B datatype
392 GemmAccDataType,
393 CShuffleDataType,
394 CDataType,
395 D0sDataType,
396 AElementwiseOperation,
397 BElementwiseOperation,
398 C0DEElementwiseOperation,
399 B1ElementwiseOperation,
400 C1DEElementwiseOperation,
407 NumGemmKPrefetchStage,
408 BlockSize,
409 MPerBlock,
410 NPerBlock,
411 KPerBlock,
412 Gemm1NPerBlock,
413 Gemm1KPerBlock,
414 AK1,
415 BK1,
416 B1K1,
417 MPerXDL,
418 NPerXDL,
419 MXdlPerWave_,
420 NXdlPerWave,
421 Gemm1NXdlPerWave,
422 ABlockTransferThreadClusterLengths_AK0_M_AK1,
423 ABlockTransferThreadClusterArrangeOrder,
424 ABlockTransferSrcAccessOrder,
425 ABlockTransferSrcVectorDim,
426 ABlockTransferSrcScalarPerVector,
427 ABlockTransferDstScalarPerVector_AK1,
428 true,
429 ABlockLdsExtraM,
430 BBlockTransferThreadClusterLengths_BK0_N_BK1,
431 BBlockTransferThreadClusterArrangeOrder,
432 BBlockTransferSrcAccessOrder,
433 BBlockTransferSrcVectorDim,
434 BBlockTransferSrcScalarPerVector,
435 BBlockTransferDstScalarPerVector_BK1,
436 true,
437 BBlockLdsExtraN,
438 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
439 B1BlockTransferThreadClusterArrangeOrder,
440 B1BlockTransferSrcAccessOrder,
441 B1BlockTransferSrcVectorDim,
442 B1BlockTransferSrcScalarPerVector,
443 B1BlockTransferDstScalarPerVector_BK1,
444 false,
445 B1BlockLdsExtraN,
446 CShuffleMXdlPerWavePerShuffle,
447 CShuffleNXdlPerWavePerShuffle,
448 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
449 CShuffleBlockTransferScalarPerVector_NPerBlock,
450 LoopSched,
453 D0sTransferSrcScalarPerVector>;
456
457 // Argument
458 // FIXME: constness
459 struct Argument : public BaseArgument
460 {
462 const ADataType* p_a_grid,
463 const BDataType* p_b_grid,
464 const B1DataType* p_b1_grid,
465 CDataType* p_c_grid,
466 const std::array<void*, NumD0Tensor> p_acc0_biases,
467 const std::array<void*, NumD1Tensor> p_acc1_biases,
468 const std::vector<index_t>& a_gs_ms_ks_lengths,
469 const std::vector<index_t>& a_gs_ms_ks_strides,
470 const std::vector<index_t>& b_gs_ns_ks_lengths,
471 const std::vector<index_t>& b_gs_ns_ks_strides,
472 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
473 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
474 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
475 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
476 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
477 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
478 const std::array<std::vector<ck::index_t>, NumD1Tensor>&
479 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
480 const std::array<std::vector<ck::index_t>, NumD1Tensor>&
481 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
482 AElementwiseOperation a_element_op,
483 BElementwiseOperation b_element_op,
484 C0DEElementwiseOperation c0de_element_op,
485 B1ElementwiseOperation b1_element_op,
486 C1DEElementwiseOperation c1de_element_op)
487 : p_a_grid_{p_a_grid},
488 p_b_grid_{p_b_grid},
489 p_b1_grid_{p_b1_grid},
490 p_c_grid_{p_c_grid},
491 p_d0s_grid_{},
493 DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
495 DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
497 b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
498 c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
499 c_gs_ms_gemm1ns_strides)},
500 d0s_grid_desc_m_n_{DeviceOp::MakeD0sGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
501 acc0_biases_gs_ms_ns_strides)},
503 Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
505 Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
506 b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
507 b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
508 c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
509 c_gs_ms_gemm1ns_strides)},
511 acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
512 block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)},
513 a_element_op_{a_element_op},
514 b_element_op_{b_element_op},
515 c0de_element_op_{c0de_element_op},
516 b1_element_op_{b1_element_op},
517 c1de_element_op_{c1de_element_op},
519 raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
520 b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
521 b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
522 b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
523 a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
524 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
525 b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
526 b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
527 b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
528 b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
529 c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
530 c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
537 {
538 // TODO ANT: implement bias addition
539 ignore = p_acc1_biases;
540 ignore = acc1_biases_gs_ms_gemm1ns_lengths;
541 ignore = acc1_biases_gs_ms_gemm1ns_strides;
542
543 static_for<0, NumD0Tensor, 1>{}([&](auto i) {
544 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
545 // D0 pointer
546 p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
547 // for check
548 d0s_nl_ns_lengths_strides_[i].push_back(
549 acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
550 d0s_nl_ns_lengths_strides_[i].push_back(
551 acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
552 });
553 }
554
555 void Print() const
556 {
557 std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
558 << a_grid_desc_g_m_k_.GetLength(I1) << ", "
559 << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
560 std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
561 << b_grid_desc_g_n_k_.GetLength(I1) << ", "
562 << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
563 std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
564 << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
565 << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
566 std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", "
567 << c1_grid_desc_g_m_n_.GetLength(I1) << ", "
568 << c1_grid_desc_g_m_n_.GetLength(I2) << '\n';
569 }
570
571 // pointers
572 const ADataType* p_a_grid_;
573 const BDataType* p_b_grid_;
574 const B1DataType* p_b1_grid_;
575 CDataType* p_c_grid_;
577
578 // tensor descriptor
589
590 // block-to-c-tile map
592
593 // element-wise op
594 AElementwiseOperation a_element_op_;
595 BElementwiseOperation b_element_op_;
596 C0DEElementwiseOperation c0de_element_op_;
597 B1ElementwiseOperation b1_element_op_;
598 C1DEElementwiseOperation c1de_element_op_;
599
600 // check C0 masking and padding
602
603 // For robust IsSupportedArgument() check
604 std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
605 std::vector<index_t> a_mz_kz_strides_;
606 std::vector<index_t> b_nz_kz_strides_;
607 std::vector<index_t> b1_nz_kz_strides_;
608 std::vector<index_t> c_mz_gemm1nz_strides_;
609 std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
610
613 };
614
615 // Invoker
616 struct Invoker : public BaseInvoker
617 {
619
// Launches the kernel for one GridwiseGemm variant and returns the value produced by
// launch_and_time_kernel.
// Throws std::runtime_error("wrong! unsupported argument") when the check at the top of the
// body fails.
620 template <typename GridwiseGemm>
621 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
622 {
624 {
625 throw std::runtime_error("wrong! unsupported argument");
626 }
627 auto c1_grid_desc_mblock_mperblock_nblock_nperblock =
628 GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
630
 // Total block count: the tile map's grid size for one C1 problem, times the batch count.
631 const index_t grid_size =
632 arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_;
633
634 // Gemm0_K
 // Recovered from the AK0/M/AK1 descriptor as length(I0) * length(I2).
635 const auto K =
636 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
637
638 float ave_time = 0;
639
 // Generic lambda so that the main-K-loop flag reaches the kernel as a compile-time
 // template argument (it is passed in as an integral_constant below).
640 auto launch_kernel = [&](auto has_main_k_block_loop_) {
642 GridwiseGemm,
643 ADataType, // TODO: distinguish A/B datatype
644 CDataType,
645 typename GridwiseGemm::D0sGridPointer,
646 AElementwiseOperation,
647 BElementwiseOperation,
648 C0DEElementwiseOperation,
649 B1ElementwiseOperation,
650 C1DEElementwiseOperation,
654 typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
656 typename GridwiseGemm::DefaultBlock2CTileMap,
657 ComputeBasePtrOfStridedBatch,
659 has_main_k_block_loop_>;
660
 // Launch arguments after the kernel are: grid dimensions, block dimensions, the
 // lds_byte value (0 here), and then the kernel's own arguments.
661 return launch_and_time_kernel(stream_config,
662 kernel,
663 dim3(grid_size),
664 dim3(BlockSize),
665 0,
666 arg.p_a_grid_,
667 arg.p_b_grid_,
668 arg.p_b1_grid_,
669 arg.p_c_grid_,
670 arg.p_d0s_grid_,
671 arg.a_element_op_,
672 arg.b_element_op_,
674 arg.b1_element_op_,
679 c1_grid_desc_mblock_mperblock_nblock_nperblock,
682 arg.batch_count_,
684 arg.c0_matrix_mask_);
685 };
686
687 // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
688 // to concern Gemm0's loop
689 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
690 {
691 ave_time = launch_kernel(integral_constant<bool, true>{});
692 }
693 else
694 {
695 ave_time = launch_kernel(integral_constant<bool, false>{});
696 }
697
698 return ave_time;
699 }
700
701 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
702 {
703 if(get_warp_size() == 64)
704 {
705 if constexpr(MXdlPerWave64 > 0)
706 {
707 return RunImp<GridwiseGemm64>(arg, stream_config);
708 }
709 }
710 else
711 {
712 if constexpr(MXdlPerWave32 > 0)
713 {
714 return RunImp<GridwiseGemm32>(arg, stream_config);
715 }
716 }
717 return 0;
718 }
719
720 // polymorphic
721 float Run(const BaseArgument* p_arg,
722 const StreamConfig& stream_config = StreamConfig{}) override
723 {
724 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
725 }
726 };
727
// Compile-time validity hook for this template instantiation.
// Currently a placeholder: it accepts every parameter combination unconditionally.
728 static constexpr bool IsValidCompilationParameter()
729 {
730 // TODO: properly implement this check
731 return true;
732 }
733
// Host-side check of whether this instance can run the problem described by `arg`.
// Returns false as soon as one of the checks below fails: C shape vs. the GEMM shapes and
// batch count, divisibility of the lowest raw extents by the per-tensor vector widths,
// the stride test, and the D0 length/stride test. The final outcome is decided in the
// warp-size-specific branches at the end.
734 static bool IsSupportedArgument(const Argument& arg)
735 {
 // Optional dump of the batched descriptor lengths when CK_LOGGING is enabled.
736 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
737 {
738 arg.Print();
739 }
740
742 {
743 return false;
744 }
745 // TODO ANT: Check if tensor specialization & strides mismatch
746
747 // Check if C permute dimension matches GEMM + GEMM shape
748 const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
749 const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
750 const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
751 const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
752 const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
753
 // C must have one slice per batch, A's M rows, and B1's Gemm1-N columns.
754 if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
755 {
756 return false;
757 }
758
759 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
760 // vector is out of bounds
761 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
762 const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
763 const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
764 const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
765 const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
766
767 // Check scalar per vector requirement
 // A source vector dim of 2 selects the K-side extent (N-side for B1, whose K is Gemm0's N);
 // any other value selects the M/N/Gemm1-N extent.
768 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
769 const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
770 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
771 const auto c_extent_lowest = Gemm1NzRaw;
772
773 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
774 b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
775 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
776 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
777 {
778 return false;
779 }
780
781 // Check vector load/store requirement
782 const auto a_stride_lowest =
783 ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
784 const auto b_stride_lowest =
785 BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
786 const auto b1_stride_lowest =
787 B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
788 const auto c_stride_lowest =
789 arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
790
 // NOTE(review): with `||` this rejects only when none of the four strides equals 1. If
 // every tensor has to be contiguous along its vector dimension, the operators should be
 // `&&` — confirm the intended requirement before changing.
791 if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
792 c_stride_lowest == 1))
793 {
794 return false;
795 }
 // D0 tensors: entry [0] is the recorded length and entry [1] the recorded stride (in the
 // order the Argument constructor pushes them). A unit stride requires the length to be a
 // multiple of D0sTransferSrcScalarPerVector; a non-unit stride is accepted only when
 // D0sTransferSrcScalarPerVector is 1.
796 for(int i = 0; i < NumD0Tensor; i++)
797 {
798 if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
799 arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
800 {
801 return false;
802 }
803 if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
804 {
805 return false;
806 }
807 }
 // Same wave-size dispatch as Invoker::Run: only the variant whose MXdlPerWave constant
 // is positive is considered.
808 if(get_warp_size() == 64)
809 {
810 if constexpr(MXdlPerWave64 > 0)
811 {
817 }
818 }
819 else
820 {
821 if constexpr(MXdlPerWave32 > 0)
822 {
828 }
829 }
 // Fallback when the warp-size-specific branch above did not return.
830 return false;
831 }
832
833 // polymorphic
834 bool IsSupportedArgument(const BaseArgument* p_arg) override
835 {
836 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
837 }
838
839 static auto MakeArgument(
840 const ADataType* p_a,
841 const BDataType* p_b,
842 const B1DataType* p_b1,
843 CDataType* p_c,
844 const std::array<void*, NumD0Tensor> p_acc0_biases,
845 const std::array<void*, NumD1Tensor> p_acc1_biases,
846 const std::vector<index_t>& a_gs_ms_ks_lengths,
847 const std::vector<index_t>& a_gs_ms_ks_strides,
848 const std::vector<index_t>& b_gs_ns_ks_lengths,
849 const std::vector<index_t>& b_gs_ns_ks_strides,
850 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
851 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
852 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
853 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
854 const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
855 const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
856 const std::array<std::vector<ck::index_t>, NumD1Tensor>
857 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
858 const std::array<std::vector<ck::index_t>, NumD1Tensor>
859 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
860 AElementwiseOperation a_element_op,
861 BElementwiseOperation b_element_op,
862 C0DEElementwiseOperation c0de_element_op,
863 B1ElementwiseOperation b1_element_op,
864 C1DEElementwiseOperation c1de_element_op)
865 {
866 return Argument{p_a,
867 p_b,
868 p_b1,
869 p_c,
870 p_acc0_biases,
871 p_acc1_biases,
872 a_gs_ms_ks_lengths,
873 a_gs_ms_ks_strides,
874 b_gs_ns_ks_lengths,
875 b_gs_ns_ks_strides,
876 b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
877 b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
878 c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
879 c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
880 acc0_biases_gs_ms_ns_lengths,
881 acc0_biases_gs_ms_ns_strides,
882 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
883 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
884 a_element_op,
885 b_element_op,
886 c0de_element_op,
887 b1_element_op,
888 c1de_element_op};
889 }
890
891 static auto MakeInvoker() { return Invoker{}; }
892
893 // polymorphic
894 // FIXME: constness
895 std::unique_ptr<BaseArgument> MakeArgumentPointer(
896 const void* p_a,
897 const void* p_b,
898 const void* p_b1,
899 void* p_c,
900 const std::array<void*, NumD0Tensor> p_acc0_biases,
901 const std::array<void*, NumD1Tensor> p_acc1_biases,
902 const std::vector<index_t>& a_gs_ms_ks_lengths,
903 const std::vector<index_t>& a_gs_ms_ks_strides,
904 const std::vector<index_t>& b_gs_ns_ks_lengths,
905 const std::vector<index_t>& b_gs_ns_ks_strides,
906 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
907 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
908 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
909 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
910 const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
911 const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
912 const std::array<std::vector<ck::index_t>, NumD1Tensor>
913 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
914 const std::array<std::vector<ck::index_t>, NumD1Tensor>
915 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
916 AElementwiseOperation a_element_op,
917 BElementwiseOperation b_element_op,
918 C0DEElementwiseOperation c0de_element_op,
919 B1ElementwiseOperation b1_element_op,
920 C1DEElementwiseOperation c1de_element_op) override
921 {
922 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
923 static_cast<const BDataType*>(p_b),
924 static_cast<const B1DataType*>(p_b1),
925 static_cast<CDataType*>(p_c),
926 p_acc0_biases, // cast in struct Argument
927 p_acc1_biases, // cast in struct Argument
928 a_gs_ms_ks_lengths,
929 a_gs_ms_ks_strides,
930 b_gs_ns_ks_lengths,
931 b_gs_ns_ks_strides,
932 b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
933 b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
934 c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
935 c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
936 acc0_biases_gs_ms_ns_lengths,
937 acc0_biases_gs_ms_ns_strides,
938 acc1_biases_gs_ms_gemm1ns_lengths,
939 acc1_biases_gs_ms_gemm1ns_strides,
940 a_element_op,
941 b_element_op,
942 c0de_element_op,
943 b1_element_op,
944 c1de_element_op);
945 }
946
947 // polymorphic
948 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
949 {
950 return std::make_unique<Invoker>(Invoker{});
951 }
952
953 // polymorphic
954 std::string GetTypeString() const override
955 {
956 auto str = std::stringstream();
957
958 // clang-format off
959 str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
960 << "<"
961 << BlockSize << ", "
962 << MPerBlock << ", "
963 << NPerBlock << ", "
964 << KPerBlock << ", "
965 << AK1 << ", "
966 << BK1 << ", "
967 << MPerBlock << ", "
968 << Gemm1NPerBlock << ", "
969 << Gemm1KPerBlock << ", "
970 << B1K1 << ", "
971 << getGemmSpecializationString(GemmSpec) << ", "
972 << "ASpec" << getTensorSpecializationString(ASpec) << ", "
973 << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
974 << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
975 << "CSpec" << getTensorSpecializationString(CSpec) << ", "
976 << getMaskingSpecializationString(MaskingSpec) << ">";
977 // clang-format on
978
979 return str.str();
980 }
981};
982
983} // namespace device
984} // namespace tensor_operation
985} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
__global__ void kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, D0sPointer p_d0s_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const C0DEElementwiseOperation c0de_element_op, const B1ElementwiseOperation b1_element_op, const C1DEElementwiseOperation c1de_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDesc_M_N d0s_griddesc_m_n, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:48
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
int64_t long_index_t
Definition ck.hpp:300
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:86
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::D0sGridPointer
decltype(MakeD0sGridPointer()) D0sGridPointer
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:399
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const C1GridDesc_M_N &c1_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:232
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:408
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
static auto MakeB0GridDescriptor_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:198
static auto MakeAGridDescriptor_G_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:154
static auto MakeB0GridDescriptor_G_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:193
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm.hpp:248
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:159
static auto MakeCGridDescriptor_G_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:274
static auto MakeB1GridDescriptor_G_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:233
static auto MakeB1GridDescriptor_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:238
static auto MakeCGridDescriptor_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:279
Definition device_base.hpp:197
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:339
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:353
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:358
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, Number< I > d0_idx) const
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:374
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const BGridDesc_G_N_K &b_grid_desc_g_n_k, const B1GridDesc_G_N_K &b1_grid_desc_g_n_k, const C1GridDesc_G_M_N &c1_grid_desc_g_m_n, const D0sGridDesc_G_M_N &d0s_grid_desc_g_m_n)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:340
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:368
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:363
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:460
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:574
D0sGridDesc_M_N d0s_grid_desc_m_n_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:583
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:591
AElementwiseOperation a_element_op_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:594
C1DEElementwiseOperation c1de_element_op_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:598
std::vector< index_t > c_mz_gemm1nz_strides_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:608
std::vector< index_t > b_nz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:606
BElementwiseOperation b_element_op_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:595
B1GridDesc_G_N_K b1_grid_desc_g_n_k_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:586
index_t batch_count_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:611
C0DEElementwiseOperation c0de_element_op_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:596
C1GridDesc_M_N c1_grid_desc_m_n_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:582
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:579
std::vector< index_t > raw_lengths_mz_nz_kz_gemm1nz_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:604
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:588
std::vector< index_t > b1_nz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:607
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:597
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:612
C1GridDesc_G_M_N c1_grid_desc_g_m_n_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:587
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:584
const BDataType * p_b_grid_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:573
std::vector< index_t > a_mz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:605
BGridDesc_G_N_K b_grid_desc_g_n_k_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:585
void Print() const
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:555
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumD0Tensor > p_acc0_biases, const std::array< void *, NumD1Tensor > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumD1Tensor > &acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< ck::index_t >, NumD1Tensor > &acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:461
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:572
C0MatrixMask c0_matrix_mask_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:601
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:580
GridwiseGemm64::D0sGridPointer p_d0s_grid_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:576
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:575
std::array< std::vector< ck::index_t >, NumD0Tensor > d0s_nl_ns_lengths_strides_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:609
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:581
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:617
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:721
DeviceOp::Argument Argument
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:618
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:701
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:621
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:224
decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})) BGridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:315
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:834
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_K
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:320
static auto MakeD0sGridDescriptor_G_M_N(const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_strides)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:302
static constexpr auto MXdlPerWave64
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:225
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:728
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:454
static constexpr auto make_MaskOutPredicate()
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:325
static constexpr index_t NumD1Tensor
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:234
static constexpr index_t NumD0Tensor
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:233
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_b1, void *p_c, const std::array< void *, NumD0Tensor > p_acc0_biases, const std::array< void *, NumD1Tensor > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op) override
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:895
static auto MakeB1GridDescriptor_BK0_N_BK1(const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths_vec, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:281
TransformBatchedContractionContractionToBatchedGemmGemm< Sequence< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO >, Sequence< MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock >, GemmSpec, ASpec, BSpec, B1Spec, CSpec > Transform
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:255
std::string GetTypeString() const override
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:954
static constexpr auto I1
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:252
decltype(MakeD0sGridDescriptor_M_N({}, {})) D0sGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:322
static auto MakeInvoker()
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:891
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:249
static constexpr auto I2
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:253
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const B1DataType *p_b1, CDataType *p_c, const std::array< void *, NumD0Tensor > p_acc0_biases, const std::array< void *, NumD1Tensor > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:839
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:264
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:272
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:948
static constexpr auto I0
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:251
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:455
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:318
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) BGridDesc_G_N_K
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:319
static auto MakeD0sGridDescriptor_M_N(const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_strides)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:290
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:734
decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})) AGridDesc_AK0_M_AK1
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:314
GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector > GridwiseGemmBase
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:390
static constexpr auto MXdlPerWave32
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:227
decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})) B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:316
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) C1GridDesc_G_M_N
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:321
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:336
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) C1GridDesc_M_N
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:317
decltype(MakeD0sGridDescriptor_G_M_N({}, {})) D0sGridDesc_G_M_N
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:323
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43
#define CK_ENV(name)
Definition utility/env.hpp:129