11#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
12#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
14#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
15#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
18#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
19#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
24__device__
inline float
27#if(defined(__gfx90a__) || defined(__gfx94__)) && \
28 (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
29 CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
31 float result, numerator, denominator;
33 "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
34 "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
35 "v_rcp_f32_e32 %[denominator], %[denominator]\n"
36 "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
37 "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
38 : [numerator]
"=&v"(numerator), [denominator]
"=&v"(denominator), [result]
"=v"(result)
39 : [softmax_scale]
"s"(softmax_scale),
41 [logits_soft_cap_rcp]
"v"(logits_soft_cap_rcp));
44 return softmax_scale * logits *
rcp<float>(1.f +
abs(logits * logits_soft_cap_rcp));
49template <
typename ImplMask>
61template <
typename ImplMask,
bool UseExp2 = false>
101 if constexpr(UseExp2)
110 float logits_soft_cap_,
111 float logits_soft_cap_rcp_)
119 if constexpr(UseExp2)
136 template <
typename Params,
typename T>
144 template <
typename Params,
typename T>
147 [[maybe_unused]]
uint32_t batch_idx,
149 [[maybe_unused]]
uint32_t qo_head_idx,
150 [[maybe_unused]]
uint32_t kv_head_idx)
const
155 template <
typename Params>
156 __device__ __forceinline__
bool LogitsMask(
const Params& params,
157 [[maybe_unused]]
uint32_t batch_idx,
160 [[maybe_unused]]
uint32_t qo_head_idx,
161 [[maybe_unused]]
uint32_t kv_head_idx)
const
163 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
167template <
bool UseExp2 = false>
172 template <
typename Params,
typename T>
175 if constexpr(UseExp2)
187 template <
typename Params,
typename T>
190 [[maybe_unused]]
uint32_t batch_idx,
192 [[maybe_unused]]
uint32_t qo_head_idx,
193 [[maybe_unused]]
uint32_t kv_head_idx)
const
195 if constexpr(UseExp2)
197#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
198 return params.logits_soft_cap *
200#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
207#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
208 return params.logits_soft_cap *
210#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
217 template <
typename Params>
218 __device__ __forceinline__
bool LogitsMask(
const Params& params,
219 [[maybe_unused]]
uint32_t batch_idx,
222 [[maybe_unused]]
uint32_t qo_head_idx,
223 [[maybe_unused]]
uint32_t kv_head_idx)
const
225 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
234template <u
int32_t VARIANT_CODE,
bool UseExp2 = false>
243 template <
typename Params,
typename T>
255 template <
typename Params,
typename T>
258 [[maybe_unused]]
uint32_t batch_idx,
260 [[maybe_unused]]
uint32_t qo_head_idx,
261 [[maybe_unused]]
uint32_t kv_head_idx)
const
265 if constexpr(UseExp2)
267#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
268 return params.logits_soft_cap *
270#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
277#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
278 return params.logits_soft_cap *
280#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
290 template <
typename Params>
291 __device__ __forceinline__
bool LogitsMask(
const Params& params,
292 [[maybe_unused]]
uint32_t batch_idx,
295 [[maybe_unused]]
uint32_t qo_head_idx,
296 [[maybe_unused]]
uint32_t kv_head_idx)
const
298 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
__device__ float exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
Definition variants.hpp:25
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
constexpr uint32_t ALIBI
Definition variants.hpp:232
CK_TILE_DEVICE float tanh_fast< float >(float x)
Definition tile/core/numeric/math.hpp:1394
constexpr uint32_t LOGITS_SOFT_CAP
Definition variants.hpp:231
constexpr uint32_t CUSTOM_MASK
Definition variants.hpp:229
constexpr T log2e_rcp_v
Definition tile/core/numeric/math.hpp:491
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
constexpr uint32_t SLIDING_WINDOW
Definition variants.hpp:230
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST T rcp(T x)
Definition tile/core/numeric/math.hpp:896
Definition allocators.h:459
unsigned int uint32_t
Definition stdint.h:126
__device__ __host__ ComposedAttention()=default
__device__ __forceinline__ bool LogitsMask(const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:291
static constexpr bool use_exp2
Definition variants.hpp:237
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:256
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition variants.hpp:244
static constexpr bool use_logits_soft_cap
Definition variants.hpp:239
__device__ __host__ LogitsSoftCap()=default
__device__ __forceinline__ bool LogitsMask(const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:218
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:188
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition variants.hpp:173
float logits_soft_cap_rcp
Definition variants.hpp:129
__host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition variants.hpp:87
const ImplMask & impl_mask
Definition variants.hpp:126
__device__ __host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_, float logits_soft_cap_rcp_)
Definition variants.hpp:108
__device__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition variants.hpp:65
float sm_scale
Definition variants.hpp:127
float logits_soft_cap
Definition variants.hpp:128
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition variants.hpp:137
__device__ __host__ StandardAttention()=default
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:145
__device__ __forceinline__ bool LogitsMask(const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:156
const ImplMask & impl_mask
Definition variants.hpp:57
__device__ __host__ StandardAttentionParams(const ImplMask &impl_mask_, float sm_scale_)
Definition variants.hpp:52
float sm_scale
Definition variants.hpp:58