FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ > Struct Template Reference

FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ > Struct Template Reference
ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ > Struct Template Reference

#include <fmha_bwd_kernel.hpp>

Classes

struct  t2s
struct  t2s< float >
struct  t2s< ck_tile::fp16_t >
struct  t2s< ck_tile::bf16_t >
struct  FmhaBwdEmptyKargs
struct  FmhaBwdCommonKargs
struct  FmhaBwdCommonBiasKargs
struct  FmhaBwdBatchModeBiasKargs
struct  FmhaBwdAlibiKargs
struct  FmhaBwdCommonBiasGradKargs
struct  FmhaBwdBatchModeBiasGradKargs
struct  FmhaBwdMaskKargs
struct  FmhaBwdDropoutSeedOffset
struct  FmhaBwdCommonDropoutKargs
struct  FmhaBwdBatchModeDropoutKargs
struct  FmhaBwdDeterministicKargs
struct  FmhaBwdBatchModeKargs
struct  FmhaBwdGroupModeKargs

Public Types

using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>
using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>
using GemmDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::GemmDataType>
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>
using AccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::AccDataType>
using DDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::DDataType>
using RandValOutputDataType
using OGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::OGradDataType>
using QGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QGradDataType>
using KGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KGradDataType>
using VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const
CK_TILE_DEVICE void run_ (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
template<typename... Ts>
static CK_TILE_HOST constexpr Kargs MakeKargs (Ts... args, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
template<typename... Ts>
static CK_TILE_HOST constexpr Kargs MakeKargs (Ts... args, const std::tuple< const void *, const void * > &drop_seed_offset)
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *cu_seqlen_q_ptr, const void *cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
static CK_TILE_DEVICE constexpr auto GetTileIndex ()
static CK_TILE_HOST dim3 BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr bool kUseQrQtrDorPipeline
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad
static constexpr bool kHasMask = FmhaMask::IsMasking
static constexpr bool kHasDropout = FmhaDropout::IsDropout
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ
static constexpr bool kIsAvailable = !kUseTrLoad

Member Typedef Documentation

◆ AccDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::AccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::AccDataType>

◆ BiasDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>

◆ BiasGradDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>

◆ DDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::DDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::DDataType>

◆ FmhaDropout

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>

◆ FmhaMask

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>

◆ FmhaPipeline

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>

◆ GemmDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::GemmDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::GemmDataType>

◆ Kargs

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>

◆ KDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>

◆ KGradDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::KGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KGradDataType>

◆ KGradEpiloguePipeline

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>

◆ LSEDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>

◆ OGradDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::OGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::OGradDataType>

◆ QDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>

◆ QGradDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::QGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QGradDataType>

◆ QGradEpiloguePipeline

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>

◆ RandValOutputDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::RandValOutputDataType
Initial value:
= ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>
Note: remove_cvref_t is remove_cv_t< std::remove_reference_t< T > > (Definition type_traits.hpp:21).

◆ VDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>

◆ VGradDataType

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>

◆ VGradEpiloguePipeline

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
using ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>

Member Function Documentation

◆ BlockSize()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_HOST dim3 ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::BlockSize ( )
inline static

◆ GetName()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_HOST std::string ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::GetTileIndex ( )
inline static constexpr

◆ GridSize()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_HOST constexpr auto ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::GridSize ( ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_k_ )
inline static constexpr

◆ MakeKargs() [1/2]

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
template<typename... Ts>
CK_TILE_HOST constexpr Kargs ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::MakeKargs ( Ts... args,
const std::tuple< const void *, const void * > & drop_seed_offset )
inline static constexpr

◆ MakeKargs() [2/2]

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
template<typename... Ts>
CK_TILE_HOST constexpr Kargs ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::MakeKargs ( Ts... args,
const std::tuple< uint64_t, uint64_t > & drop_seed_offset )
inline static constexpr

◆ MakeKargsImpl() [1/2]

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::MakeKargsImpl ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
const void * lse_ptr,
const void * do_ptr,
const void * d_ptr,
void * rand_val_ptr,
void * dk_ptr,
void * dv_ptr,
void * dbias_ptr,
void * dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset )
inline static constexpr

◆ MakeKargsImpl() [2/2]

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::MakeKargsImpl ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
const void * lse_ptr,
const void * do_ptr,
const void * d_ptr,
void * rand_val_ptr,
void * dk_ptr,
void * dv_ptr,
void * dbias_ptr,
void * dq_acc_ptr,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_q_ptr,
const void * seqlen_k_ptr,
const void * cu_seqlen_q_ptr,
const void * cu_seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset )
inline static constexpr

◆ operator()()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_DEVICE void ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::operator() ( Kargs kargs) const
inline

◆ run_()

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
CK_TILE_DEVICE void ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::run_ ( Kargs kargs) const
inline

FIXME: Before C++20, capturing structured binding variables is not supported. Remove the following copy capture of 'i_nhead' once C++20 is in use.

Member Data Documentation

◆ BiasEnum

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
auto ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::BiasEnum = FmhaPipeline::BiasEnum
static constexpr

◆ kBlockPerCu

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
ck_tile::index_t ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
ck_tile::index_t ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize
static constexpr

◆ kHasBiasGrad

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kHasBiasGrad = FmhaPipeline::kHasBiasGrad
static constexpr

◆ kHasDropout

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kHasDropout = FmhaDropout::IsDropout
static constexpr

◆ kHasMask

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kHasMask = FmhaMask::IsMasking
static constexpr

◆ kIsAvailable

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kIsAvailable = !kUseTrLoad
static constexpr

◆ kIsDeterministic

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kIsDeterministic = FmhaPipeline::kIsDeterministic
static constexpr

◆ kIsGroupMode

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr

◆ kIsStoreRandval

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kIsStoreRandval = FmhaDropout::IsStoreRandval
static constexpr

◆ kMaxSeqLenQ

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
index_t ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ
static constexpr

◆ kPadHeadDimQ

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
index_t ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
index_t ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr

◆ kUseQrQtrDorPipeline

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kUseQrQtrDorPipeline
static constexpr
Initial value:
= (initializer expression not preserved in this extract; it refers to an entity with: Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:777)

◆ kUseTrLoad

template<typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_, typename QGradEpiloguePipeline_ = void>
bool ck_tile::FmhaBwdDQDKDVKernel< FmhaPipeline_, KGradEpiloguePipeline_, VGradEpiloguePipeline_, QGradEpiloguePipeline_ >::kUseTrLoad = FmhaPipeline::kUseTrLoad
static constexpr

The documentation for this struct was generated from the following file: