FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <fmha_fwd_v3_kernel.hpp>

Classes

struct  FmhaFwdEmptyKargs
struct  FmhaFwdCommonKargs
struct  FmhaFwdMaskKargs
struct  FmhaFwdCommonLSEKargs
struct  FmhaFwdBatchModeKargs
struct  FmhaFwdGroupModeKargs

Public Types

using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const

Static Public Member Functions

template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const ck_tile::index_t *cu_seqlen_q_ptr=nullptr, const ck_tile::index_t *cu_seqlen_kv_ptr=nullptr)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *seqstart_padded_q_ptr=nullptr, const void *seqstart_padded_k_ptr=nullptr)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
static CK_TILE_DEVICE constexpr auto RemapTileIndices (int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
static CK_TILE_DEVICE constexpr auto GetTileIndex (const Kargs &)
static CK_TILE_HOST constexpr auto BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr bool kHasMask = FmhaMask::IsMasking

Member Typedef Documentation

◆ EpiloguePipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>

◆ FmhaMask

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>

◆ FmhaPipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>

◆ Kargs

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>

◆ KDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>

◆ LSEDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>

◆ ODataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>

◆ QDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>

◆ SaccDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>

◆ VDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>

Member Function Documentation

◆ BlockSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::BlockSize ( )
inline static constexpr

◆ GetSmemSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::GetTileIndex ( const Kargs & )
inline static constexpr

◆ GridSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::GridSize ( ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_ )
inline static constexpr

◆ MakeKargs() [1/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
void * lse_ptr,
void * o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const ck_tile::index_t * cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t * cu_seqlen_kv_ptr = nullptr )
inline static constexpr

◆ MakeKargs() [2/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
void * lse_ptr,
void * o_ptr,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const void * seqstart_padded_q_ptr = nullptr,
const void * seqstart_padded_k_ptr = nullptr )
inline static constexpr

◆ operator()()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE void ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::operator() ( Kargs kargs) const
inline

◆ RemapTileIndices()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::RemapTileIndices ( int32_t tg_idx,
int32_t tg_idy,
int32_t remap_option )
inline static constexpr

Member Data Documentation

◆ kBlockPerCu

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize
static constexpr

◆ kHasMask

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kHasMask = FmhaMask::IsMasking
static constexpr

◆ kIsGroupMode

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr

◆ kPadHeadDimQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr

◆ kPadSeqLenK

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr

◆ kStoreLSE

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr

The documentation for this struct was generated from the following file: fmha_fwd_v3_kernel.hpp