universal_gemm_kernel.hpp Source File
universal_gemm_kernel.hpp
Go to the documentation of this file.
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition tile/host/hip_check_error.hpp:13
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
The Universal GEMM kernel host arguments.
Definition universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition universal_gemm_kernel.hpp:33
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:71
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:61
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
index_t splitted_k
Definition universal_gemm_kernel.hpp:370
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition universal_gemm_kernel.hpp:326
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
Definition universal_gemm_kernel.hpp:206
static constexpr bool value
Definition universal_gemm_kernel.hpp:210
decltype(T::UsePersistentKernel) has_persistent_type
Definition universal_gemm_kernel.hpp:208
Definition universal_gemm_kernel.hpp:221
decltype(T::GetOutputOffset(std::declval< KernelArgs >(), std::declval< index_t >())) has_get_output_offset_t
Definition universal_gemm_kernel.hpp:223
static constexpr bool value
Definition universal_gemm_kernel.hpp:226
The GEMM kernel device arguments.
Definition universal_gemm_kernel.hpp:86
void * e_ptr
Definition universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:88
index_t k_batch
Definition universal_gemm_kernel.hpp:113
index_t N
Definition universal_gemm_kernel.hpp:98
index_t stride_E
Definition universal_gemm_kernel.hpp:112
index_t K
Definition universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:109
index_t M
Definition universal_gemm_kernel.hpp:96
The Universal GEMM kernel template.
Definition universal_gemm_kernel.hpp:154
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition universal_gemm_kernel.hpp:1053
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition universal_gemm_kernel.hpp:156
static CK_TILE_HOST const std::string GetName()
Definition universal_gemm_kernel.hpp:260
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > > > DsLayout
Definition universal_gemm_kernel.hpp:179
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition universal_gemm_kernel.hpp:155
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition universal_gemm_kernel.hpp:1127
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition universal_gemm_kernel.hpp:955
static constexpr bool BDataTypeIsTuple
Definition universal_gemm_kernel.hpp:161
static constexpr bool BLayoutIsTuple
Definition universal_gemm_kernel.hpp:167
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BsDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > > > BsDataType
Definition universal_gemm_kernel.hpp:187
remove_cvref_t< typename GemmPipeline::BElementWise > BElementWise
Definition universal_gemm_kernel.hpp:200
static constexpr index_t NumATensor
Definition universal_gemm_kernel.hpp:241
static constexpr bool ALayoutIsTuple
Definition universal_gemm_kernel.hpp:165
static CK_TILE_DEVICE void RunGemm2LDS(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition universal_gemm_kernel.hpp:1010
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition universal_gemm_kernel.hpp:853
static constexpr bool ADataTypeIsTuple
Definition universal_gemm_kernel.hpp:159
static constexpr bool has_tile_partitioner_output_offset
Definition universal_gemm_kernel.hpp:233
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
Definition universal_gemm_kernel.hpp:267
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > > > DsDataType
Definition universal_gemm_kernel.hpp:191
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition universal_gemm_kernel.hpp:754
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BsLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > > > BsLayout
Definition universal_gemm_kernel.hpp:175
static constexpr index_t NumDTensor
Definition universal_gemm_kernel.hpp:243
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition universal_gemm_kernel.hpp:245
static constexpr bool DDataTypeIsTuple
Definition universal_gemm_kernel.hpp:163
static constexpr bool PersistentKernel
Definition universal_gemm_kernel.hpp:217
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition universal_gemm_kernel.hpp:319
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition universal_gemm_kernel.hpp:196
static CK_TILE_HOST auto BlockSize()
Definition universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition universal_gemm_kernel.hpp:278
static constexpr index_t NumBTensor
Definition universal_gemm_kernel.hpp:242
static CK_TILE_DEVICE auto MakeGemmTensorViews(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const index_t k_size)
Definition universal_gemm_kernel.hpp:580
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::AsLayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > > > AsLayout
Definition universal_gemm_kernel.hpp:172
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr bool DLayoutIsTuple
Definition universal_gemm_kernel.hpp:169
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition universal_gemm_kernel.hpp:157
static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition universal_gemm_kernel.hpp:303
remove_cvref_t< typename GemmPipeline::AElementWise > AElementWise
Definition universal_gemm_kernel.hpp:199
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::AsDataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > > > AsDataType
Definition universal_gemm_kernel.hpp:183
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition universal_gemm_kernel.hpp:257
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition universal_gemm_kernel.hpp:197
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition universal_gemm_kernel.hpp:246
Definition type_traits.hpp:115
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition ck_tile/host/stream_config.hpp:30