gridwise_sparse_embeddings_forward_layernorm.hpp Source File

gridwise_sparse_embeddings_forward_layernorm.hpp Source File#

Composable Kernel: gridwise_sparse_embeddings_forward_layernorm.hpp Source File
gridwise_sparse_embeddings_forward_layernorm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
// Kernel entry point: a thin wrapper that forwards every argument unchanged to the
// static GridwiseSparseEmbedding::Run() of the gridwise implementation type that is
// supplied as the first template parameter.
//
// NOTE(review): this listing is missing original lines 24, 26 and 28-29. According to
// the signature summary at the end of this page, the missing declaration is
//   __global__ void kernel_sparse_embeddings_forward_layernorm(OutType* p_out,
//       const ck::Array<EmbType*, NumEmbeddings> p_embs,
//       const ck::Array<IndexType*, NumEmbeddings> p_indexes, ...)
// so p_embs / p_indexes used in the body are kernel parameters, not globals.
13template <typename GridwiseSparseEmbedding,
14 typename EmbType,
15 typename IndexType,
16 typename GammaDataType,
17 typename BetaDataType,
18 typename AccDataType,
19 typename OutType,
20 typename OutGridDesc,
21 typename EmbElementwiseOperation,
22 ck::index_t NumEmbeddings>
// NOTE(review): the line guarded by this conditional (original line 24) is missing
// here; presumably a __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// attribute, since both macros are listed in the page summary -- confirm.
23#if CK_USE_LAUNCH_BOUNDS
25#endif
27 OutType* p_out,
30 const GammaDataType* p_gamma,
31 const BetaDataType* p_beta,
32 const OutGridDesc out_grid_desc,
33 const AccDataType epsilon,
34 const EmbElementwiseOperation emb_elementwise_op)
35{
// All work is delegated; the kernel itself performs no indexing or bounds checks.
36 GridwiseSparseEmbedding::Run(
37 p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
38}
39
// Gridwise implementation of "gather NumEmbeddings embedding rows by index, combine
// them with emb_elementwise_op, then layer-normalize each combined row with
// gamma/beta/epsilon and write it to p_out" (see Run() below).
//
// NOTE(review): the struct declaration line (original line 56) and several alias /
// buffer declarations (original lines 78-90, 105, 126, 129, 131-145, ...) are missing
// from this listing. Comments below only describe what the visible lines show and
// hedge everything else.
40template <typename EmbType,
41 typename IndexType,
42 typename GammaDataType,
43 typename BetaDataType,
44 typename AccDataType,
45 typename OutType,
46 typename OutGridDesc,
47 typename EmbElementwiseOperation,
48 ck::index_t BlockSize,
49 ck::index_t DimClusterSize,
50 ck::index_t RowClusterSize,
51 ck::index_t DimPerBlock, // Row x Dim, along Dim
52 ck::index_t RowPerBlock, // Row x Dim, along Row
53 ck::index_t DimThreadSize, // this is actually not a vector size, but a number of registers
54 ck::index_t RowVectorSize,
55 ck::index_t NumEmbeddings>
57{
58 static constexpr auto I0 = Number<0>{};
59 static constexpr auto I1 = Number<1>{};
60 static constexpr auto I2 = Number<2>{};
61 static constexpr auto I3 = Number<3>{};
// NOTE(review): the wave size is hard-coded to 64 lanes; confirm this struct is never
// instantiated for targets that run with a different wavefront width.
62 static constexpr index_t WaveSize = 64;
63
// The thread block is treated as a DimClusterSize x RowClusterSize arrangement of
// threads, and the row-direction cluster must be a whole number of waves.
64 static_assert(BlockSize == RowClusterSize * DimClusterSize,
65 "Invalid cluster distribution within block");
66 static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise");
67
// The per-block tile must split evenly into per-thread pieces along both axes.
68 static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
69 static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");
70
// Number of sub-blocks every thread iterates over along each axis (these are the trip
// counts of the static_for loops in Run()).
71 static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
72 static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
73
74 static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
// Extent of one sub-block along each axis.
75 static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
76 static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;
77
80
83
86
88
91
// Device-side body of the kernel.
//
// Visible steps, in order:
//   1. load the indexes for this block and the gamma/beta vectors,
//   2. zero the accumulator and the mean/variance buffers,
//   3. per dim sub-block: load the indexed embedding rows, combine them through
//      emb_elementwise_op, update running mean/variance, run the block-level
//      Welford step, then normalize and store.
//
// NOTE(review): original lines 93-94 (the p_embs / p_indexes parameters, see the
// signature summary at the end of this page) are missing from this listing.
// The OutGridDesc argument is unnamed, so it is not used inside this function.
92 __device__ static void Run(OutType* p_out,
95 const GammaDataType* p_gamma,
96 const BetaDataType* p_beta,
97 const OutGridDesc,
98 const AccDataType epsilon,
99 const EmbElementwiseOperation emb_elementwise_op)
100 {
101 const index_t thread_local_id = get_thread_local_1d_id();
102 const index_t block_global_id = get_block_1d_id();
103
// NOTE(review): the initializer of thread_cluster_desc (original line 105) is missing
// here; presumably a make_cluster_descriptor(...) call -- confirm.
104 constexpr auto thread_cluster_desc =
106
// Split the flat thread id into its (dim, row) position inside the thread cluster.
107 const auto thread_cluster_idx =
108 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
109
110 const auto thread_dim_cluster_id = thread_cluster_idx[I0];
111 const auto thread_row_cluster_id = thread_cluster_idx[I1];
112
// readfirstlane takes the value from the first active lane, so wave_dim_id is the
// same for every lane of a wave.
113 const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
114
// First "dim" position (i.e. first entry of the index arrays / first output row)
// handled by this thread.
115 const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
116
117 auto threadwise_welford = ThreadwiseWelford();
// One thread visits RowSubBlocks * RowVectorSize row elements per dim position.
118 threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;
119
// Per-thread buffers are addressed through packed descriptors:
//   thread_buf_desc     : (DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize)
//   mean_var_buf_desc   : indexed below with (i_dim_sub_, i_dim_vec_)
//   gamma_beta_buf_desc : indexed below with (i_row_sub_, i_row_vec_)
// NOTE(review): the declarations of the buffers themselves (in_thread_bufs,
// index_bufs, acc_thread_buf, gamma/beta_thread_buf, mean/var_thread_buf,
// emb_vectors) are only partially present in this listing.
120 constexpr auto thread_buf_size =
121 DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
122 constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
123 make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
124 constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
125 constexpr auto mean_var_buf_desc =
127 constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
128 constexpr auto gamma_beta_buf_desc =
130
132 NumEmbeddings>
133 in_thread_bufs;
135 index_bufs;
136
138
140 gamma_thread_buf;
142 beta_thread_buf;
143
146
// Loads, for one (dim sub-block, row sub-block) pair, one RowVectorSize-wide vector
// from each of the NumEmbeddings tables and stores the elements into in_thread_bufs.
147 auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
149 auto emb_a = emb_vectors[0];
150 using src_vector_t = typename decltype(emb_a)::type;
151 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
152 constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
153
// Offset of this thread's vector inside the row, scaled by sizeof(EmbType)
// (i.e. expressed in bytes).
154 auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
155 sizeof(EmbType) * RowVectorSize;
156 static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
157 IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
158
// The selected embedding row starts at p_embs[i] + index * RowPerBlock.
// NOTE(review): the line that builds emb_res (original line 159) is missing here;
// no range check on index is visible in this listing.
160 p_embs[i_embedding_] + index * RowPerBlock);
161 emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
162 amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
163 });
164
// Scatter the loaded vector elements into the per-thread input buffers.
// NOTE(review): original line 170 (head of the assigned expression) is missing.
165 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
166 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
167 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
168 static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
169 in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
171 emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
172 });
173 });
174 });
175 };
176
// Combines the NumEmbeddings loaded values of each element through
// emb_elementwise_op; the single output reference is the matching acc_thread_buf slot.
177 auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
178 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
179 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
180 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
181 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
182 auto in_data_refs = generate_tie(
183 [&](auto i_embedding_) -> const auto& {
184 return in_thread_bufs(i_embedding_)(Number<register_offset>{});
185 },
187 auto out_data_refs = generate_tie(
188 [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
189 Number<1>{});
190 unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
191 });
192 });
193 };
194
// Feeds the accumulated values of one sub-row into the per-thread running
// mean/variance for the corresponding dim position.
195 auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
196 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
197 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
198 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
199 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
200 constexpr auto mean_var_offset =
201 mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
202
// NOTE(review): cur_count_ is one counter shared by every i_dim_vec_ iteration and
// is never reset in the visible code; confirm the counts are still correct when
// DimThreadSize > 1 or DimSubBlocks > 1.
203 threadwise_welford.cur_count_++;
204 threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
205 var_thread_buf(Number<mean_var_offset>{}),
206 acc_thread_buf(Number<register_offset>{}));
207 });
208 });
209 };
210
// Normalizes one sub-row as (x - mean) / sqrt(var + epsilon) * gamma + beta, converts
// it to OutType and stores it through the output buffer resource.
211 auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
212 int32x4_t out_res =
213 make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock);
214 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
// NOTE(review): the declaration of out_vector (original line 215) is missing here.
216 using dst_vector_t = typename decltype(out_vector)::type;
217
218 constexpr auto mean_var_offset =
219 mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
// Reciprocal standard deviation, computed once per dim position.
220 auto divisor =
221 1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
222 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
223 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
224 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
225 constexpr auto gamma_beta_offset =
226 gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
227
228 auto acc_val = acc_thread_buf[Number<register_offset>{}];
229 acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
230 acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
231 beta_thread_buf[Number<gamma_beta_offset>{}];
232
233 out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
234 type_convert<OutType>(acc_val);
235 });
236
// Same per-thread offset formula as for the loads, scaled by sizeof(OutType).
237 index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
238 sizeof(OutType) * RowVectorSize;
239
// NOTE(review): the head of the store call (original line 240) is missing here;
// presumably amd_buffer_store_impl<OutType, RowVectorSize>( -- confirm.
241 out_vector.template AsType<dst_vector_t>()[Number<0>{}],
242 out_res,
243 thread_offset,
244 0);
245 });
246 };
247
// Step 1a: every thread reads DimPerBlock indexes per embedding table, starting at
// index_start, into index_bufs.
248 // first load index
249 ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
250 // prefer to use s_load
251 ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
252 index_bufs(i_embedding_)(i_idx_) =
253 p_indexes[i_embedding_][index_start + i_idx_.value];
254 });
255 });
256
// Step 1b: gamma and beta are loaded once per row sub-block and converted to
// AccDataType; they are reused for every dim position.
// NOTE(review): the declarations of gamma_vector / beta_vector and of the
// gamma_res / beta_res buffer resources (original lines 259-260, 267-268) and the
// head of the gamma load (line 271) are missing from this listing.
257 // load gamma/beta
258 static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
261
262 index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
263 sizeof(GammaDataType) * RowVectorSize;
264 index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
265 sizeof(BetaDataType) * RowVectorSize;
266
269
270 gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
272 gamma_res, thread_offset_gamma, 0);
273 beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
274 amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);
275
276 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
277 constexpr auto offset =
278 gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
279 gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
280 gamma_vector.template AsType<GammaDataType>()[Number<i_row_vec_>{}]);
281 beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
282 beta_vector.template AsType<BetaDataType>()[Number<i_row_vec_>{}]);
283 });
284 });
285
// Step 2: zero the accumulator and the mean/variance buffers.
// NOTE(review): the static_for heads of both loops (original lines 286, 289) are
// missing from this listing.
287 [&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
288
290 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
291 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
292 });
293
// Step 3: main loop over dim sub-blocks. The row loop issues the load for sub-row
// i_row + 1 before accumulating sub-row i_row, so each load is issued one step
// ahead of its use; the last sub-row is drained after the loop.
294 static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
295 load_current_sub_row(i_dim_sub, Number<0>{});
296 static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
297 load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
298 accumulate_current_sub_row(i_dim_sub, i_row);
299 threadwise_welford_sub_row(i_dim_sub, i_row);
300 });
301 accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
302 threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
303
// NOTE(review): original lines 305 and 307-308 are missing here. Judging by the page
// summary they are the static_for head, a block_sync_lds() executed for I > 0, and
// the head of a BlockwiseWelford::Run(mean, var, count) call -- confirm.
304 // blockwise welford
306 if constexpr(I > 0)
309 mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
310 });
311
// NOTE(review): the static_for head over the row sub-blocks (original line 313) is
// missing from this listing.
312 // store
314 [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
315 });
316 }
317};
318
319} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
Definition ck.hpp:268
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T *p_wave)
Definition utility/amd_buffer_addressing.hpp:38
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__device__ void amd_buffer_store_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:544
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< int32_t, 4 >::type int32x4_t
Definition dtype_vector.hpp:2168
__global__ void kernel_sparse_embeddings_forward_layernorm(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc out_grid_desc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:26
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ vector_type< T, N >::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:419
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition utility/array.hpp:14
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:57
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< DimSubBlocks *DimThreadSize >{}, Number< RowSubBlocks *RowVectorSize >{}))) ThreadwiseWolfordDesc2D
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:78
static __device__ void Run(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:92
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition functional2.hpp:33