UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor > Struct Template Reference


Composable Kernel: ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor > Struct Template Reference

The GEMM kernel device arguments.

#include <universal_gemm_kernel.hpp>

Inheritance diagram for ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor > (derived structs):
ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BatchedGemmKernelArgs
ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::StreamKKernelArgs
ck_tile::reboot::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::StreamKKernelArgs

Public Attributes

const std::array< const void *, NumATensor > as_ptr
 Pointers to the As input tensors in device memory, one per A tensor.
const std::array< const void *, NumBTensor > bs_ptr
 Pointers to the Bs input tensors in device memory, one per B tensor.
const std::array< const void *, NumDTensor > ds_ptr
 Pointers to the Ds input tensors in device memory, one per D tensor.
void * e_ptr
 Pointer to the E output tensor in device memory.
index_t M
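 Size of the GEMM's M dimension (rows of A and E).
index_t N
 Size of the GEMM's N dimension (columns of B and E).
index_t K
 Size of the GEMM's K dimension (the reduction dimension shared by A and B).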
std::array< index_t, NumATensor > stride_As
 Stride of each As tensor: the distance in memory between consecutive elements along its non-contiguous dimension.
std::array< index_t, NumBTensor > stride_Bs
 Stride of each Bs tensor: the distance in memory between consecutive elements along its non-contiguous dimension.
std::array< index_t, NumDTensor > stride_Ds
 Stride of each Ds tensor: the distance in memory between consecutive elements along its non-contiguous dimension.
index_t stride_E
 Stride of the E tensor: the distance in memory between consecutive elements along its non-contiguous dimension.
index_t k_batch
 Number of splits of the K dimension (split-K factor).

Detailed Description

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >

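The GEMM kernel device arguments: device pointers to the A, B, D and E tensors, the problem sizes M, N and K, a stride for every tensor, and the split-K factor k_batch. The template parameters NumATensor, NumBTensor and NumDTensor set how many A, B and D tensors are passed; by default there is one A tensor, one B tensor and no D tensor.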
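A minimal sketch of how these arguments might be filled in on the host. Because as_ptr, bs_ptr and ds_ptr are declared const, the struct has to be initialized in one step (aggregate initialization) rather than assigned field by field, and it is not copy-assignable. To keep the example self-contained it does not include the ck_tile headers: GemmArgsSketch and make_args are hypothetical, GemmArgsSketch mirrors the documented members in the order listed under Public Attributes, and index_t is assumed to be a 32-bit signed integer. The strides assume row-major A (M×K), column-major B (K×N) and row-major E (M×N).

```cpp
#include <array>
#include <cstdint>
#include <type_traits>

using index_t = std::int32_t; // assumption: stands in for ck_tile::index_t

// Hypothetical stand-in for ck_tile::UniversalGemmKernelArgs, with members in the
// order documented under "Public Attributes".
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct GemmArgsSketch
{
    const std::array<const void*, NumATensor> as_ptr;
    const std::array<const void*, NumBTensor> bs_ptr;
    const std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    index_t M;
    index_t N;
    index_t K;
    std::array<index_t, NumATensor> stride_As;
    std::array<index_t, NumBTensor> stride_Bs;
    std::array<index_t, NumDTensor> stride_Ds;
    index_t stride_E;
    index_t k_batch;
};

// Hypothetical helper: arguments for one A, one B and no D tensors.
// a, b and e stand in for device buffers.
GemmArgsSketch<> make_args(const void* a, const void* b, void* e,
                           index_t M, index_t N, index_t K, index_t k_batch = 1)
{
    return GemmArgsSketch<>{
        {a},     // as_ptr
        {b},     // bs_ptr
        {},      // ds_ptr: empty because NumDTensor defaults to 0
        e,       // e_ptr
        M, N, K,
        {K},     // stride_As: row-major A (M x K), rows are K elements apart
        {K},     // stride_Bs: column-major B (K x N), columns are K elements apart
        {},      // stride_Ds: empty, no D tensors
        N,       // stride_E: row-major E (M x N), rows are N elements apart
        k_batch};
}

// The const members delete copy assignment, while copy construction still works.
static_assert(!std::is_copy_assignable_v<GemmArgsSketch<>>);
static_assert(std::is_copy_constructible_v<GemmArgsSketch<>>);
```

The one-shot initialization mirrors the const-qualified pointer arrays in the declaration: once built, the set of tensor pointers cannot be changed, only copied into a new object.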

Member Data Documentation

◆ as_ptr

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
const std::array<const void*, NumATensor> ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::as_ptr

Pointers to the As input tensors in device memory, one per A tensor.

◆ bs_ptr

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
const std::array<const void*, NumBTensor> ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::bs_ptr

Pointers to the Bs input tensors in device memory, one per B tensor.

◆ ds_ptr

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
const std::array<const void*, NumDTensor> ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::ds_ptr

Pointers to the Ds input tensors in device memory, one per D tensor. The array is empty when NumDTensor is 0 (the default).

◆ e_ptr

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
void* ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::e_ptr

Pointer to the E output tensor in device memory.

◆ K

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
index_t ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::K

Size of the GEMM's K dimension: the reduction dimension shared by A (M×K) and B (K×N).

◆ k_batch

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
index_t ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::k_batch

The number of splits of the K dimension (split-K factor).
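With split-K, the K dimension is divided into k_batch chunks; each chunk yields a partial product for every output element, and the partials are summed to form E. The sketch below shows that arithmetic on the host for a single output element. The ceiling-based partition is an assumption chosen for illustration (the kernel may round split sizes to multiples of its tile size), and split_k_range and split_k_dot are hypothetical helpers, not ck_tile functions.

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using index_t = std::int32_t; // assumption: stands in for ck_tile::index_t

// Half-open range [begin, end) of the K dimension handled by split `id`.
// Assumption: every split gets ceil(K / k_batch) elements; the last split may get
// fewer, or none if K is small relative to k_batch.
std::pair<index_t, index_t> split_k_range(index_t K, index_t k_batch, index_t id)
{
    const index_t chunk = (K + k_batch - 1) / k_batch;
    const index_t begin = std::min(K, id * chunk);
    const index_t end   = std::min(K, begin + chunk);
    return {begin, end};
}

// One element of E (a row of A dotted with a column of B, both of length K),
// computed as k_batch partial sums that are then added together.
float split_k_dot(const std::vector<float>& a_row, const std::vector<float>& b_col,
                  index_t k_batch)
{
    const index_t K = static_cast<index_t>(a_row.size());
    float e = 0.0f;
    for (index_t id = 0; id < k_batch; ++id) {
        const auto [begin, end] = split_k_range(K, k_batch, id);
        float partial = 0.0f; // this split's contribution
        for (index_t k = begin; k < end; ++k)
            partial += a_row[k] * b_col[k];
        e += partial; // add this split's partial result to the output element
    }
    return e;
}
```

For example, with K = 10 and k_batch = 3 the splits cover [0, 4), [4, 8) and [8, 10). Setting k_batch to 1 reduces to an ordinary dot product over the whole K range.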

◆ M

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
index_t ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::M

Size of the GEMM's M dimension: the number of rows of A and E.

◆ N

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
index_t ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::N

Size of the GEMM's N dimension: the number of columns of B and E.
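M, N and K fix the operand shapes: A is M×K, B is K×N and E is M×N, with E[m][n] = Σ_k A[m][k]·B[k][n] when D tensors and elementwise operations are ignored. A naive host reference like the one below is a common way to check a kernel's output. It is a sketch that assumes row-major A, column-major B and row-major E addressed through leading-dimension strides counted in elements; reference_gemm is a hypothetical name, not part of ck_tile.

```cpp
#include <cstdint>
#include <vector>

using index_t = std::int32_t; // assumption: stands in for ck_tile::index_t

// Naive reference: E (M x N, row-major) = A (M x K, row-major) * B (K x N, column-major).
// stride_A, stride_B and stride_E are leading-dimension strides in elements.
void reference_gemm(const std::vector<float>& A, const std::vector<float>& B,
                    std::vector<float>& E, index_t M, index_t N, index_t K,
                    index_t stride_A, index_t stride_B, index_t stride_E)
{
    for (index_t m = 0; m < M; ++m) {
        for (index_t n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (index_t k = 0; k < K; ++k)
                acc += A[m * stride_A + k]   // row-major: row m starts at m * stride_A
                     * B[n * stride_B + k];  // column-major: column n starts at n * stride_B
            E[m * stride_E + n] = acc;       // row-major: row m starts at m * stride_E
        }
    }
}
```

For example, with A = [[1, 2, 3], [4, 5, 6]] stored as {1, 2, 3, 4, 5, 6} (stride 3) and B = [[7, 8], [9, 10], [11, 12]] stored column-major as {7, 9, 11, 8, 10, 12} (stride 3), a call with M = N = 2 and K = 3 writes E = {58, 64, 139, 154} (stride 2).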

◆ stride_As

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
std::array<index_t, NumATensor> ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::stride_As

Stride of each As tensor, one entry per A tensor: the distance in memory between consecutive elements along its non-contiguous dimension (the leading dimension).

◆ stride_Bs

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
std::array<index_t, NumBTensor> ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::stride_Bs

Stride of each Bs tensor, one entry per B tensor: the distance in memory between consecutive elements along its non-contiguous dimension (the leading dimension).

◆ stride_Ds

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
std::array<index_t, NumDTensor> ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::stride_Ds

Stride of each Ds tensor, one entry per D tensor: the distance in memory between consecutive elements along its non-contiguous dimension (the leading dimension).

◆ stride_E

template<index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
index_t ck_tile::UniversalGemmKernelArgs< NumATensor, NumBTensor, NumDTensor >::stride_E

Stride of the E tensor: the distance in memory between consecutive elements along its non-contiguous dimension (the leading dimension).
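A stride of this kind is the number of elements between the starts of consecutive rows (row-major) or consecutive columns (column-major). For a densely packed tensor it equals the length of the contiguous dimension; a larger value describes a view into a wider, padded allocation. The helpers below, row_major_offset and col_major_offset, are hypothetical and only spell out that arithmetic, assuming strides are counted in elements rather than bytes.

```cpp
#include <cstdint>

using index_t = std::int32_t; // assumption: stands in for ck_tile::index_t

// Element offset of (row, col) in a row-major tensor: rows are `stride` elements apart.
constexpr index_t row_major_offset(index_t row, index_t col, index_t stride)
{
    return row * stride + col;
}

// Element offset of (row, col) in a column-major tensor: columns are `stride` elements apart.
constexpr index_t col_major_offset(index_t row, index_t col, index_t stride)
{
    return row + col * stride;
}

// A densely packed 4 x 8 row-major tensor has stride 8 (its column count) ...
static_assert(row_major_offset(1, 0, 8) == 8);
// ... while a 4 x 8 view into a 4 x 16 row-major buffer keeps stride 16.
static_assert(row_major_offset(1, 0, 16) == 16);
// A densely packed 4 x 8 column-major tensor has stride 4 (its row count).
static_assert(col_major_offset(0, 1, 4) == 4);
```

Because only one stride is stored per tensor, the other dimension is implicitly assumed to be contiguous (unit stride).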
