BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ > Struct Template Reference

BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ > Struct Template Reference
ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ > Struct Template Reference

#include <block_fmha_pipeline_qr_ks_vs_fp8.hpp>

Public Types

using Problem = remove_cvref_t<Problem_>
using Policy = remove_cvref_t<Policy_>
using QDataType = remove_cvref_t<typename Problem::QDataType>
using KDataType = remove_cvref_t<typename Problem::KDataType>
using VDataType = remove_cvref_t<typename Problem::VDataType>
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>
using PDataType = remove_cvref_t<typename Problem::PDataType>
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>
using ODataType = remove_cvref_t<typename Problem::ODataType>
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>

Public Member Functions

template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding>
CK_TILE_HOST_DEVICE auto operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &, LSEDramBlockWindowTmp &, FmhaMask mask, PositionEncoding, float scale_s, float descale_qk, float descale_sv, void *smem_ptr, BlockDropout &) const

Static Public Member Functions

static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr bool kQLoadOnce = true
static constexpr index_t kBlockSize = Problem::kBlockSize
static constexpr index_t kM0 = BlockFmhaShape::kM0
static constexpr index_t kN0 = BlockFmhaShape::kN0
static constexpr index_t kK0 = BlockFmhaShape::kK0
static constexpr index_t kN1 = BlockFmhaShape::kN1
static constexpr index_t kK1 = BlockFmhaShape::kK1
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim
static constexpr bool kIsGroupMode = Problem::kIsGroupMode
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV
static constexpr auto BiasEnum = Problem::BiasEnum
static constexpr bool kStoreLSE = Problem::kStoreLSE
static constexpr bool kHasDropout = Problem::kHasDropout
static constexpr index_t kAlignmentQ
static constexpr index_t kAlignmentK
static constexpr index_t kAlignmentV
static constexpr index_t kAlignmentO
static constexpr index_t kAlignmentBias
static constexpr index_t kBlockPerCu
static constexpr const char * name = "qr_fp8"

Member Typedef Documentation

◆ BiasDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType>

◆ BlockFmhaShape

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>

◆ FmhaMask

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask>

◆ KDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType>

◆ LSEDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType>

◆ OaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType>

◆ ODataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType>

◆ PDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType>

◆ Policy

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_>

◆ Problem

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_>

◆ QDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType>

◆ RandValOutputDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>

◆ SaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType>

◆ SMPLComputeDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>

◆ VDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType>

◆ VLayout

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>

Member Function Documentation

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ operator()()

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding>
CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::operator() ( const QDramBlockWindowTmp & q_dram_block_window_tmp,
const KDramBlockWindowTmp & k_dram_block_window_tmp,
const VDramBlockWindowTmp & v_dram_block_window_tmp,
const BiasDramBlockWindowTmp & bias_dram_block_window_tmp,
RandValDramBlockWindowTmp & ,
LSEDramBlockWindowTmp & ,
FmhaMask mask,
PositionEncoding ,
float scale_s,
float descale_qk,
float descale_sv,
void * smem_ptr,
BlockDropout &  ) const
inline

NOTICE: the bias may be a materialized mask that contains -inf values; this needs to be taken into account.

Member Data Documentation

◆ BiasEnum

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
auto ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::BiasEnum = Problem::BiasEnum
static constexpr

◆ kAlignmentBias

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentBias
static constexpr
Initial value:
=
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64

◆ kAlignmentK

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentK
static constexpr
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52

◆ kAlignmentO

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentO
static constexpr
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24

◆ kAlignmentQ

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentQ
static constexpr
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>()

◆ kAlignmentV

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentV
static constexpr
Initial value:
= []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_appendkv_pipeline.hpp:16
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_appendkv_pipeline.hpp:15

◆ kBlockPerCu

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kBlockPerCu
static constexpr
Initial value:
= []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
}
}()
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:46
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:51
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:44

◆ kBlockSize

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize
static constexpr

◆ kHasDropout

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kHasDropout = Problem::kHasDropout
static constexpr

◆ kIsGroupMode

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kIsGroupMode = Problem::kIsGroupMode
static constexpr

◆ kK0

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kK0 = BlockFmhaShape::kK0
static constexpr

◆ kK1

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kK1 = BlockFmhaShape::kK1
static constexpr

◆ kM0

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kM0 = BlockFmhaShape::kM0
static constexpr

◆ kN0

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kN0 = BlockFmhaShape::kN0
static constexpr

◆ kN1

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kN1 = BlockFmhaShape::kN1
static constexpr

◆ kPadHeadDimQ

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadHeadDimQ = Problem::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadHeadDimV = Problem::kPadHeadDimV
static constexpr

◆ kPadSeqLenK

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadSeqLenK = Problem::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadSeqLenQ = Problem::kPadSeqLenQ
static constexpr

◆ kQKHeaddim

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kQKHeaddim = BlockFmhaShape::kQKHeaddim
static constexpr

◆ kQLoadOnce

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kQLoadOnce = true
static constexpr

◆ kStoreLSE

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kStoreLSE = Problem::kStoreLSE
static constexpr

◆ name

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
const char* ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::name = "qr_fp8"
static constexpr

The documentation for this struct was generated from the following file: block_fmha_pipeline_qr_ks_vs_fp8.hpp