grouped_gemm_quant_kernel.hpp Source File

grouped_gemm_quant_kernel.hpp Source File

Composable Kernel: grouped_gemm_quant_kernel.hpp Source File
grouped_gemm_quant_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13#include "ck_tile/host.hpp"
14
15#include <hip/hip_runtime.h>
16
17namespace ck_tile {
18
// Host-side argument bundle describing one GEMM problem (one group) of a quantized
// grouped GEMM: operand/scale/output pointers, problem sizes, strides and k_batch.
// NOTE(review): this listing omits the struct header, the first constructor parameter
// line and several member declarations; confirm details against the real header.
27{
29 const void* b_ptr_,
30 void* e_ptr_,
31 const void* aq_ptr_,
32 const void* bq_ptr_,
33 index_t k_batch_,
34 index_t M_,
35 index_t N_,
36 index_t K_,
37 index_t QK_A_,
38 index_t QK_B_,
39 index_t stride_A_,
40 index_t stride_B_,
41 index_t stride_E_,
42 index_t stride_AQ_,
43 index_t stride_BQ_)
// Each argument is stored unchanged in the member of the same name; nothing is validated here.
// Note that the initializer order below differs from the parameter order above.
44 : a_ptr(a_ptr_),
45 b_ptr(b_ptr_),
46 aq_ptr(aq_ptr_),
47 bq_ptr(bq_ptr_),
48 e_ptr(e_ptr_),
49 M(M_),
50 N(N_),
51 K(K_),
52 QK_A(QK_A_),
53 QK_B(QK_B_),
54 stride_A(stride_A_),
55 stride_B(stride_B_),
56 stride_AQ(stride_AQ_),
57 stride_BQ(stride_BQ_),
58 stride_E(stride_E_),
59 k_batch(k_batch_)
60 {
61 }
62
// Type-erased input pointers. Presumably the A and B operands and their
// quantization data (AQ/BQ) -- TODO confirm against the kernel documentation.
63 const void* a_ptr;
64 const void* b_ptr;
65 const void* aq_ptr;
66 const void* bq_ptr;
// e_ptr and c_ptr share storage: both names denote the same (only mutable) pointer.
67 union
68 {
69 void* e_ptr;
70 void* c_ptr;
71 };
72
82
// NOTE(review): the members of this second union are missing from the listing.
// The symbol index of this page mentions stride_E and stride_C -- confirm.
83 union
84 {
87 };
88
90};
91
93
111
// Kernel class template for the quantized grouped GEMM.
//   TilePartitioner_   - provides GridSize(M, N), GetLoopNum(k), MPerBlock and NPerBlock.
//   GemmPipeline_      - main-loop pipeline; also supplies BlockSize and UsePersistentKernel.
//   EpiloguePipeline_  - invoked after the main loop with the output window and the result tile.
//   QuantType_         - quantization mode, exposed below as kQuantType.
// NOTE(review): the struct header, the `using` aliases and the conditions of the first
// three static_asserts are missing from this listing; only their messages remain.
112template <typename TilePartitioner_,
113 typename GemmPipeline_,
114 typename EpiloguePipeline_,
115 QuantType QuantType_>
117{
121
125
130
136
141
// Compile-time quantization mode; the Run* helpers branch on it with `if constexpr`.
142 static constexpr auto kQuantType = QuantType_;
143
145 static_assert(
147 "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
148
150 static_assert(
152 "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
153
157 "C/ELayout and C/EDataType must be scalars.");
158
160 using Kernel =
162
// Workgroup size and scheduling mode are taken from the GEMM pipeline.
163 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
164 static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
// Only the persistent variant is provided: any other pipeline setting fails to compile.
165 static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
166
167 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
168 {
169 // clang-format off
170 using P_ = GemmPipeline;
171
172 return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
173 concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
174 concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
175 concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
176 (UsePersistentKernel ? "Persistent" : "NonPersistent"));
177 // clang-format on
178 }
179
180 CK_TILE_HOST static auto
181 GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
182 {
183 return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
184 }
185
186 CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
187 {
188 return group_count * sizeof(QuantGemmTransKernelArg);
189 }
190
191 CK_TILE_HOST static auto BlockSize() -> dim3
192 {
193 if(is_wave32())
194 {
195 return dim3(kBlockSize / 2);
196 }
197 else
198 {
199 return dim3(kBlockSize);
200 }
201 }
202
/// @brief Grid size for a launch sized to the device rather than to the problem.
///
/// Multiplies the number of compute units reported for stream config `s` by the
/// per-multiprocessor active-block count that HIP reports for this kernel entry point
/// at kBlockSize threads and 0 bytes of dynamic shared memory.
///
/// @param s Stream configuration passed to get_available_compute_units().
/// @return dim3(compute_units * occupancy, 1, 1).
///
/// NOTE(review): occupancy is queried with kBlockSize, while BlockSize() halves the launch
/// size when is_wave32() is true -- confirm this is intentional.
/// NOTE(review): the line that opens the call wrapping the HIP query is missing from this
/// listing (presumably an error-check macro -- confirm).
209 CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
210 {
211 using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
212 const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
// Written by the HIP occupancy query on the next line.
213 int occupancy;
215 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
216 const int grid_size = get_available_compute_units(s) * occupancy;
217 return dim3(grid_size, 1, 1);
218 }
219
220 CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
221 {
222 index_t grid_size = 0;
223 for(const auto& it_desc : gemm_descs)
224 {
225 const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
226 grid_size += local_grid_size * it_desc.k_batch;
227 }
228 return dim3(grid_size, 1, 1);
229 }
230
/// @brief Convert host-side group descriptors into per-group kernel arguments.
///
/// Groups with M, N or K equal to zero are skipped. Every remaining group gets its kernel
/// arguments plus the range [block_start, block_end) it occupies in the 1D block numbering
/// formed by laying the groups' tiles (times k_batch) out one after another.
///
/// @param gemm_descs Host descriptors, one per group.
/// @return One QuantGemmTransKernelArg per non-empty group, in input order.
///
/// NOTE(review): because empty groups are dropped, the result can be shorter than
/// gemm_descs; the group count handed to the kernel presumably has to be the size of the
/// returned vector -- confirm at the call sites.
231 CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
232 -> std::vector<QuantGemmTransKernelArg>
233 {
234 std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
235 index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
// Running total of blocks assigned so far; doubles as the next group's block_start.
236 index_t grid_size = 0;
237 gemm_kernel_args_.reserve(group_count);
238
239 for(std::size_t i = 0; i < gemm_descs.size(); ++i)
240 {
241 const index_t M = gemm_descs[i].M;
242 const index_t N = gemm_descs[i].N;
243 const index_t K = gemm_descs[i].K;
244
// Empty problems contribute no blocks and no kernel-argument entry.
245 if(M == 0 || N == 0 || K == 0)
246 {
247 continue;
248 }
249
250 const index_t stride_a = gemm_descs[i].stride_A;
251 const index_t stride_b = gemm_descs[i].stride_B;
// Read through the stride_C name; presumably an alias of stride_E (the union members are
// not visible in this listing) -- confirm.
252 const index_t stride_e = gemm_descs[i].stride_C;
253
// Blocks for this group: 2D tile count times the number of split-K batches.
254 const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
255
256 const index_t block_start = grid_size;
257 const index_t block_end = grid_size + grid_size_grp;
258
259 grid_size += grid_size_grp;
260
// NOTE(review): the first line of this initializer (the type name and the a_ptr argument)
// is missing from the listing.
261 auto karg =
263 type_convert<const BDataType*>(gemm_descs[i].b_ptr),
264 type_convert<CDataType*>(gemm_descs[i].e_ptr),
265 type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
266 type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
267 gemm_descs[i].k_batch,
268 M,
269 N,
270 K,
271 gemm_descs[i].QK_A,
272 gemm_descs[i].QK_B,
273 stride_a,
274 stride_b,
275 stride_e,
276 gemm_descs[i].stride_AQ,
277 gemm_descs[i].stride_BQ};
278
279 gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
280 }
281
282 return gemm_kernel_args_;
283 }
284
285 CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
286 {
287 for(const auto& karg : kargs)
288 {
289 if(!Base::IsSupportedArgument(karg.group_karg))
290 {
291 return false;
292 }
293 }
294 return true;
295 }
296
297 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
298 {
299 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
300 }
301
// Run: executes one output tile of one group.
//   block_idx_2d - (iM, iN) tile coordinates of this workgroup inside the group.
//   block_idx_z  - forwarded to Base::SplitKBatchOffset together with the kernel arguments.
// Declares the shared-memory buffer(s) and forwards to one of two helpers depending on
// GemmPipeline::DoubleSmemBuffer.
// NOTE(review): the first line of the signature and the callee names of both calls below
// are missing from this listing.
303 const tuple<index_t, index_t>& block_idx_2d,
304 const index_t block_idx_z) const
305 {
306 const auto [iM, iN] = block_idx_2d;
307
// Scale tile coordinates by the per-block tile extents. The value is passed through
// amd_wave_read_first_lane (presumably to make it wave-uniform -- confirm).
308 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
309 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
310
311 const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
312
313 // typed views of the type-erased pointers stored in the kernel arguments
314 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
315 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
316 const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
317 const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
318 CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
319
320 // allocate LDS
321 __shared__ char smem_ptr_0[GetSmemSize()];
322
323 // Only for BQuantGrouped DoubleSmemBuffer is supported
// NOTE(review): the second half of this condition is missing from the listing.
324 if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
326 {
327
// Second shared-memory buffer of the same size, used only on this path.
328 __shared__ char smem_ptr_1[GetSmemSize()];
330 b_ptr,
331 aq_ptr,
332 bq_ptr,
333 c_ptr,
334 smem_ptr_0,
335 smem_ptr_1,
336 kargs,
337 splitk_batch_offset,
338 i_m,
339 i_n);
340 }
341 else
342 {
343
345 b_ptr,
346 aq_ptr,
347 bq_ptr,
348 c_ptr,
349 smem_ptr_0,
350 kargs,
351 splitk_batch_offset,
352 i_m,
353 i_n);
354 }
355 }
356
// Runs one tile of a single GEMM problem using two shared-memory buffers.
// Restricted to kQuantType == QuantType::BQuantGrouped by the static_assert below.
//   b_ptr, aq_ptr, bq_ptr, c_ptr - typed operand/scale/output pointers.
//   smem_ptr_0, smem_ptr_1       - both are handed to the GEMM pipeline; smem_ptr_0 is
//                                  also handed to the epilogue.
//   kargs, splitk_batch_offset   - forwarded to Base::MakeGemmTensorViews.
//   block_idx_m, block_idx_n     - forwarded to Base::MakeGemmTileWindows.
// NOTE(review): the line carrying the function name and first parameter is missing from
// this listing. DstInMemOp is not referenced in the visible body.
357 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
358 CK_TILE_DEVICE static void
360 const BDataType* b_ptr,
361 const AQDataType* aq_ptr,
362 const BQDataType* bq_ptr,
363 CDataType* c_ptr,
364 void* smem_ptr_0,
365 void* smem_ptr_1,
366 const QuantGroupedGemmKernelArgs& kargs,
367 const typename Base::SplitKBatchOffset& splitk_batch_offset,
368 const index_t block_idx_m,
369 const index_t block_idx_n)
370 {
371 static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
372 // Create Gemm tensor views, pad views and tile windows
373 const auto& gemm_tensor_views_tuple =
374 Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
375 a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
376
377 const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
378 auto gemm_tile_windows =
379 Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
380
// Main-loop iteration count derived from this split's K extent, and the matching tail number.
381 const index_t num_loop = __builtin_amdgcn_readfirstlane(
382 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
383 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
384
385 // Run GEMM cooperatively by whole workgroup.
// Windows used here: I0 -> A, I2 -> B, I3 -> BQ, I4 -> C (I1 is not used on this path).
386 const auto& a_block_window = gemm_tile_windows.at(Base::I0);
387 const auto& b_block_window = gemm_tile_windows.at(Base::I2);
388
389 const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
390 const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
391 b_block_window,
392 bq_block_window,
393 num_loop,
394 tail_num,
395 smem_ptr_0,
396 smem_ptr_1);
397
398 // Run Epilogue Pipeline
399 auto& c_block_window = gemm_tile_windows.at(Base::I4);
400
// The C window is passed twice (first and third argument), as in the single-buffer helper.
401 EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
402 }
403
// Runs one tile of a single GEMM problem using one shared-memory buffer, then the
// epilogue. Where the quantization data is applied depends on kQuantType:
//   BQuantGrouped - the BQ window is passed to the GEMM pipeline;
//   RowColQuant   - the AQ and BQ windows are passed to the epilogue;
//   TensorQuant   - *aq_ptr and *bq_ptr are read as single scales and passed to the epilogue.
// NOTE(review): for any other kQuantType the non-BQuantGrouped branch runs the GEMM
// pipeline but invokes no epilogue -- confirm that no further mode is expected here.
// NOTE(review): the line carrying the function name and first parameter is missing from
// this listing.
424 CK_TILE_DEVICE static void
426 const BDataType* b_ptr,
427 const AQDataType* aq_ptr,
428 const BQDataType* bq_ptr,
429 CDataType* c_ptr,
430 void* smem_ptr_0,
431 const QuantGroupedGemmKernelArgs& kargs,
432 const typename Base::SplitKBatchOffset& splitk_batch_offset,
433 const index_t block_idx_m,
434 const index_t block_idx_n)
435 {
436 // Create Gemm tensor views, pad views and tile windows
437 const auto& gemm_tensor_views_tuple =
438 Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
439 a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
440
441 const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
442 auto gemm_tile_windows =
443 Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Window indices used in this function: I0 -> A, I1 -> AQ, I2 -> B, I3 -> BQ, I4 -> C.
444 const auto& a_block_window = gemm_tile_windows.at(Base::I0);
445 const auto& b_block_window = gemm_tile_windows.at(Base::I2);
446
447 // Get hot-loop and tail configuration
448 const index_t num_loop = __builtin_amdgcn_readfirstlane(
449 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
450 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
451 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
452
453 if constexpr(kQuantType == QuantType::BQuantGrouped)
454 {
455 const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
456 // Run GEMM pipeline
457 const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
458 b_block_window,
459 bq_block_window,
460 num_loop,
461 has_hot_loop,
462 tail_num,
463 smem_ptr_0);
464
465 auto& c_block_window = gemm_tile_windows.at(Base::I4);
466
467 // Run Epilogue Pipeline
468 EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
469 }
470 else
471 {
472 // Run GEMM pipeline
// No quantization window is passed to the pipeline on this path.
473 const auto& c_block_tile = GemmPipeline{}.template operator()(
474 a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
475 // Run Epilogue Pipeline
476 auto& c_block_window = gemm_tile_windows.at(Base::I4);
477 if constexpr(kQuantType == QuantType::RowColQuant)
478 {
479 const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
480 const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
481 EpiloguePipeline{}(c_block_window,
482 c_block_tile,
483 c_block_window,
484 smem_ptr_0,
485 aq_block_window,
486 bq_block_window);
487 }
488 else if constexpr(kQuantType == QuantType::TensorQuant)
489 {
// One scale value per operand, read from the first element of each scale pointer.
490 const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
491 const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
// NOTE(review): the line that opens this call is missing from the listing.
493 c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
494 }
495 }
496 }
497
498 // For persistent kernels
// Kernel entry point. Each workgroup starts at its own 1D block id and, after every tile,
// advances that id by the total grid size. Groups are laid out back to back in one 1D block
// numbering; group g owns the ids in [block_start, cum_grid_size) computed below.
//   gemm_descs_const - array of QuantGemmTransKernelArg, passed as a constant-address-space
//                      pointer and cast to a generic pointer before use.
//   group_count      - number of entries to iterate over.
// Only each entry's group_karg is read; the ranges are recomputed from M, N and k_batch.
499 template <bool U = UsePersistentKernel,
500 typename = std::enable_if_t<U>,
501 typename = void> // extra template parameter to avoid redefinition
502 CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
503 const index_t group_count) const
504 {
505 const index_t grid_size = ck_tile::get_grid_size();
506 const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
507 cast_pointer_to_generic_address_space(gemm_descs_const));
508 index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
// Exclusive end of the current group's id range; also the next group's start.
509 index_t cum_grid_size = 0;
510 for(index_t group_id = 0; group_id < group_count; ++group_id)
511 {
512 const auto& kargs = gemm_desc_ptr[group_id].group_karg;
513 const auto& k_batch = kargs.k_batch;
514 const auto block_start = cum_grid_size;
515 cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
// block_id is never reset, so a workgroup that has moved past this group skips it.
516 while(block_id < cum_grid_size)
517 {
518 const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
// Group-local id: the remainder selects the 2D tile, the quotient is passed to Run as
// block_idx_z.
519 const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
520 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
521 Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
522 block_id = block_id + grid_size; // advance to next block
523 // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
524 if(block_id >= cum_grid_size)
525 {
526 break; // exit the loop if all blocks are processed
527 }
528 }
529 }
530 }
531};
532
533} // namespace ck_tile
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE index_t get_block_1d_id()
Definition arch.hpp:98
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
QuantType
Definition tile_gemm_quant_traits.hpp:12
@ BQuantGrouped
Definition tile_gemm_quant_traits.hpp:14
@ RowColQuant
Definition tile_gemm_quant_traits.hpp:15
@ TensorQuant
Definition tile_gemm_quant_traits.hpp:16
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition arch.hpp:307
QuantGemmKernelArgs QuantGroupedGemmKernelArgs
Definition grouped_gemm_quant_kernel.hpp:92
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE index_t get_grid_size()
Definition arch.hpp:89
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Struct used to calculate offset tile indices.
Definition gemm_tile_partitioner.hpp:184
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition gemm_tile_partitioner.hpp:192
Definition gemm_quant_kernel.hpp:272
index_t splitted_k
Definition gemm_quant_kernel.hpp:310
Definition gemm_quant_kernel.hpp:171
const void * b_ptr
Definition gemm_quant_kernel.hpp:173
void * c_ptr
Definition gemm_quant_kernel.hpp:176
const void * aq_ptr
Definition gemm_quant_kernel.hpp:174
const void * a_ptr
Definition gemm_quant_kernel.hpp:172
const void * bq_ptr
Definition gemm_quant_kernel.hpp:175
Definition gemm_quant_kernel.hpp:195
static constexpr auto I4
Definition gemm_quant_kernel.hpp:227
static constexpr auto I3
Definition gemm_quant_kernel.hpp:226
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition gemm_quant_kernel.hpp:729
static constexpr auto I0
Definition gemm_quant_kernel.hpp:223
static constexpr auto I1
Definition gemm_quant_kernel.hpp:224
static CK_TILE_HOST bool IsSupportedArgument(const QuantGemmKernelArgs &kargs)
Definition gemm_quant_kernel.hpp:313
static constexpr auto I2
Definition gemm_quant_kernel.hpp:225
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition gemm_quant_kernel.hpp:806
Definition grouped_gemm_quant_kernel.hpp:95
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs &&karg, index_t bl_start, index_t bl_end)
Definition grouped_gemm_quant_kernel.hpp:101
ck_tile::index_t block_end
Definition grouped_gemm_quant_kernel.hpp:98
ck_tile::index_t block_start
Definition grouped_gemm_quant_kernel.hpp:97
QuantGroupedGemmKernelArgs group_karg
Definition grouped_gemm_quant_kernel.hpp:96
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs &&karg)
Definition grouped_gemm_quant_kernel.hpp:106
index_t stride_BQ
Definition grouped_gemm_quant_kernel.hpp:81
const void * b_ptr
Definition grouped_gemm_quant_kernel.hpp:64
void * c_ptr
Definition grouped_gemm_quant_kernel.hpp:70
index_t QK_A
Definition grouped_gemm_quant_kernel.hpp:76
index_t M
Definition grouped_gemm_quant_kernel.hpp:73
const void * aq_ptr
Definition grouped_gemm_quant_kernel.hpp:65
index_t stride_B
Definition grouped_gemm_quant_kernel.hpp:79
index_t k_batch
Definition grouped_gemm_quant_kernel.hpp:89
index_t N
Definition grouped_gemm_quant_kernel.hpp:74
index_t stride_AQ
Definition grouped_gemm_quant_kernel.hpp:80
CK_TILE_HOST QuantGroupedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *e_ptr_, const void *aq_ptr_, const void *bq_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_, index_t stride_A_, index_t stride_B_, index_t stride_E_, index_t stride_AQ_, index_t stride_BQ_)
Definition grouped_gemm_quant_kernel.hpp:28
index_t K
Definition grouped_gemm_quant_kernel.hpp:75
index_t QK_B
Definition grouped_gemm_quant_kernel.hpp:77
void * e_ptr
Definition grouped_gemm_quant_kernel.hpp:69
index_t stride_A
Definition grouped_gemm_quant_kernel.hpp:78
const void * bq_ptr
Definition grouped_gemm_quant_kernel.hpp:66
index_t stride_C
Definition grouped_gemm_quant_kernel.hpp:86
index_t stride_E
Definition grouped_gemm_quant_kernel.hpp:85
const void * a_ptr
Definition grouped_gemm_quant_kernel.hpp:63
Definition grouped_gemm_quant_kernel.hpp:117
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize() -> index_t
Definition grouped_gemm_quant_kernel.hpp:297
QuantGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ > Base
Inject the QuantGemmKernel base class to support execution of all necessary functions.
Definition grouped_gemm_quant_kernel.hpp:120
static CK_TILE_DEVICE void RunGemmWithPipelineSelection(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, const QuantGroupedGemmKernelArgs &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition grouped_gemm_quant_kernel.hpp:425
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition grouped_gemm_quant_kernel.hpp:129
static CK_TILE_HOST auto GridSize(const std::vector< QuantGroupedGemmHostArgs > &gemm_descs)
Definition grouped_gemm_quant_kernel.hpp:220
static constexpr index_t kBlockSize
Definition grouped_gemm_quant_kernel.hpp:163
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, C/E.
Definition grouped_gemm_quant_kernel.hpp:132
static CK_TILE_HOST auto BlockSize() -> dim3
Definition grouped_gemm_quant_kernel.hpp:191
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition grouped_gemm_quant_kernel.hpp:209
CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
Definition grouped_gemm_quant_kernel.hpp:302
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition grouped_gemm_quant_kernel.hpp:128
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_gemm_quant_kernel.hpp:123
remove_cvref_t< typename EpiloguePipeline::AccDataType > AccDataType
Definition grouped_gemm_quant_kernel.hpp:135
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition grouped_gemm_quant_kernel.hpp:127
static CK_TILE_HOST const std::string GetName()
Definition grouped_gemm_quant_kernel.hpp:167
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition grouped_gemm_quant_kernel.hpp:133
static CK_TILE_HOST auto GetWorkSpaceSize(const std::vector< QuantGroupedGemmHostArgs > &gemm_descs) -> std::size_t
Definition grouped_gemm_quant_kernel.hpp:181
static CK_TILE_HOST auto GetWorkSpaceSize(index_t group_count) -> std::size_t
Definition grouped_gemm_quant_kernel.hpp:186
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_gemm_quant_kernel.hpp:122
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const
Definition grouped_gemm_quant_kernel.hpp:502
QuantGroupedGemmKernel< TilePartitioner, GemmPipeline, EpiloguePipeline, kQuantType > Kernel
Definition grouped_gemm_quant_kernel.hpp:160
static CK_TILE_DEVICE void RunGemmWithPipelineSelection2LDS(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, void *smem_ptr_1, const QuantGroupedGemmKernelArgs &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition grouped_gemm_quant_kernel.hpp:359
remove_cvref_t< typename detail::get_aq_data_type_or< GemmPipeline, AccDataType >::type > AQDataType
Definition grouped_gemm_quant_kernel.hpp:137
static CK_TILE_HOST auto MakeKargs(const std::vector< QuantGroupedGemmHostArgs > &gemm_descs) -> std::vector< QuantGemmTransKernelArg >
Definition grouped_gemm_quant_kernel.hpp:231
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_gemm_quant_kernel.hpp:124
static constexpr bool UsePersistentKernel
Definition grouped_gemm_quant_kernel.hpp:164
static constexpr auto kQuantType
Definition grouped_gemm_quant_kernel.hpp:142
static CK_TILE_HOST bool IsSupportedArgument(const std::vector< QuantGemmTransKernelArg > &kargs)
Definition grouped_gemm_quant_kernel.hpp:285
OffsettedTile1DPartitioner< TilePartitioner > OffsetTile1DPartitioner
ALayout and ADataType are expected to be scalars, not a tuple.
Definition grouped_gemm_quant_kernel.hpp:159
remove_cvref_t< typename detail::get_bq_data_type_or< GemmPipeline, AccDataType >::type > BQDataType
Definition grouped_gemm_quant_kernel.hpp:139
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition grouped_gemm_quant_kernel.hpp:134
Definition ck_tile/host/stream_config.hpp:30
Definition tile/core/container/tuple.hpp:192