gridwise_softmax.hpp Source File

gridwise_softmax.hpp Source File

Composable Kernel: gridwise_softmax.hpp Source File
gridwise_softmax.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
14
15namespace ck {
16
// Kernel entry point for the gridwise softmax. It is a thin __global__ wrapper that forwards
// every argument unchanged to the static __device__ GridwiseReduction::Run(), which does all of
// the actual work (thread/block indexing, loads, reductions and stores).
//
// Template parameters:
//   GridwiseReduction - gridwise implementation type exposing a static Run() with exactly this
//                       argument list (e.g. the gridwise softmax struct defined in this file).
//   InDataType / OutDataType - element types of the input / output tensors in global memory.
//   AccDataType       - type used for accumulation and for the alpha/beta scaling factors.
//   GridDesc_M_K      - tensor descriptor type used for both input and output (M x K view).
//
// Parameters:
//   in_grid_desc_m_k / out_grid_desc_m_k - descriptors of the input / output tensors.
//   block_group_size           - number of consecutive blocks that share one M tile.
//   num_k_block_tile_iteration - number of K tiles each block iterates over.
//   alpha, beta                - output scaling factors, forwarded as-is to Run().
//   p_in_value_global          - input values (global memory, read-only).
//   p_out_value_global         - output values (global memory).
template <typename GridwiseReduction,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K>
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
                               const GridDesc_M_K out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           out_grid_desc_m_k,
                           block_group_size,
                           num_k_block_tile_iteration,
                           alpha,
                           p_in_value_global,
                           beta,
                           p_out_value_global);
}
40
// Gridwise softmax over the K dimension of an M x K tensor, computed independently for each
// M row. Run() makes three passes over the block's K range:
//   1. running max of x along K,
//   2. running sum of exp(x - max) along K,
//   3. write the output:
//        beta == 0 : out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
//        otherwise : out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
//
// Template parameters (as used by the visible code):
//   BlockSize            - threads per block; also the size of the LDS reduction buffer.
//   M/KThreadClusterSize - shape of the thread cluster (see thread_cluster_desc).
//   M/KThreadSliceSize   - per-thread tile shape; M/K_BlockTileSize = cluster size * slice size.
//   InSrcVectorDim/Size  - dimension (0 = M, 1 = K) and width of vectorized input loads.
//   OutDstVectorSize     - vector width of the output store.
//   SweepOnce            - when true, each block processes a single K tile: the input is loaded
//                          once in pass 1 and the same in_thread_buf is reused in passes 2 and 3.
//
// NOTE(review): this listing comes from a documentation page and is missing several source
// lines. These include the struct name line, the type aliases (ThreadClusterLengths_M_K,
// access/arrange orders, PassThroughOp, ...), the buffer declarations and the static_for
// loop headers. It is not compilable as shown.
41template <typename InDataType,
42 typename OutDataType,
43 typename AccDataType,
44 typename GridDesc_M_K,
45 index_t BlockSize,
46 index_t MThreadClusterSize,
47 index_t KThreadClusterSize,
48 index_t MThreadSliceSize,
49 index_t KThreadSliceSize,
50 index_t InSrcVectorDim,
51 index_t InSrcVectorSize,
52 index_t OutDstVectorSize,
53 bool SweepOnce>
55{
// The input vector width must divide the thread slice along the vectorized dimension, and the
// output vector width must divide the K slice.
56 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
57 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
58 (KThreadSliceSize % OutDstVectorSize == 0),
59 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
60
// True when input loads are vectorized along M; its uses fall on lines missing from this listing.
61 static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
62
64
67
70
// Maps a flat thread id to its (m, k) position in the thread cluster (used in Run()).
71 static constexpr auto thread_cluster_desc =
73
78
80
81 static constexpr auto I0 = Number<0>{};
82 static constexpr auto I1 = Number<1>{};
83
// Extent of one block tile along M and K.
84 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
85 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
86
// Executed by every thread of a block. The block handles M_BlockTileSize rows, starting at
// blkgroup_id * M_BlockTileSize. Along K it handles reduceSizePerBlock =
// K_BlockTileSize * num_k_block_tile_iteration elements, starting at
// block_local_id * reduceSizePerBlock.
//   in_grid_desc_m_k / out_grid_desc_m_k - input / output descriptors, indexed as (m, k).
//   block_group_size           - number of consecutive blocks sharing one M tile.
//   num_k_block_tile_iteration - K tiles processed per block (forced to 1 when SweepOnce).
//   alpha, beta                - output scaling, see the formulas above.
// NOTE(review): max and sum are reduced only within this block. No cross-block combination is
// visible, so a correct softmax over a full row presumably requires block_group_size == 1
// (or a single block covering all of K). Confirm against the launcher.
87 __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
88 const GridDesc_M_K& out_grid_desc_m_k,
89 index_t block_group_size,
90 index_t num_k_block_tile_iteration,
91 AccDataType alpha,
92 const InDataType* const __restrict__ p_in_value_global,
93 AccDataType beta,
94 OutDataType* const __restrict__ p_out_value_global)
95 {
// With SweepOnce the block's K range is exactly one tile.
96 if constexpr(SweepOnce)
97 {
98 num_k_block_tile_iteration = 1;
99 }
100
101 // LDS scratch for the blockwise (cross-thread) max and sum reductions; reused by both
102 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
103
104 auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
105 p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
106
107 auto reduce_work_buf =
108 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
109
// Per-thread input / output tiles (declarations partially missing from this listing).
111 in_thread_buf;
112
114 out_thread_buf;
115
// Initialize the per-row running max to the Max identity and the running sum to the Add
// identity.
117
119 max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
120 });
121
123
125 accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
126 });
127
// blkgroup_id selects the M tile; block_local_id selects this block's K chunk within it.
128 const index_t thread_local_id = get_thread_local_1d_id();
129 const index_t block_global_id = get_block_1d_id();
130 const index_t blkgroup_id = block_global_id / block_group_size;
131 const index_t block_local_id = block_global_id % block_group_size;
132
// (m, k) coordinates of this thread inside the thread cluster.
133 const auto thread_cluster_idx =
134 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
135
136 const auto thread_m_cluster_id = thread_cluster_idx[I0];
137 const auto thread_k_cluster_id = thread_cluster_idx[I1];
138
// Number of K elements covered by this block.
139 const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
140
// Packed MThreadSliceSize x KThreadSliceSize descriptor for the per-thread buffers.
141 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
142 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
144
145 // Normally, 0 is an adequate value for out-of-bounds (invalid) elements because 0 makes no
146 // contribution to the accumulated result. In numerically stable softmax, however, every
147 // value, including those padding 0s, has max(x) subtracted from it. The padding then becomes
148 // non-zero and would slip through and contribute to the accumulated result.
149 //
150 // The trick is that many math functions (add, sub, exp, ...) propagate NaNs when any
151 // operand is NaN. By initializing invalid elements with NaN, an invalid value stays NaN
152 // through such math manipulations and can therefore still be identified as invalid.
153 // The values that originally failed the bound check can then be discarded during
154 // accumulation. This lets us ignore out-of-bounds values even after multiple
155 // math manipulations.
156 //
157 // NOTE: reset the coordinate after every Run because the same threadwise copy sweeps
158 // through global memory 3 times, back and forth.
// Loads input tiles into in_thread_buf, converted from InDataType to AccDataType.
159 auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
160 AccDataType,
161 GridDesc_M_K,
162 decltype(thread_buffer_desc),
163 ThreadBufferLengths,
165 InSrcVectorDim,
166 InSrcVectorSize,
167 1,
168 true /* ResetCoordAfterRun */,
169 true /* InvalidElementAsNaN */>(
170 in_grid_desc_m_k,
171 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
172 block_local_id * reduceSizePerBlock +
173 thread_k_cluster_id * KThreadSliceSize));
174
// Loads the prior output tile (only used when beta != 0); the coordinate is not reset after Run.
// NOTE(review): this load uses InSrcVectorDim/InSrcVectorSize on out_grid_desc_m_k, while the
// store below uses OutDstVectorSize. Confirm that the input vector width is always valid for the
// output tensor.
175 auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
176 AccDataType,
177 GridDesc_M_K,
178 decltype(thread_buffer_desc),
179 ThreadBufferLengths,
181 InSrcVectorDim,
182 InSrcVectorSize,
183 1,
184 false>(
185 out_grid_desc_m_k,
186 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
187 block_local_id * reduceSizePerBlock +
188 thread_k_cluster_id * KThreadSliceSize));
189
// Writes out_thread_buf to global memory at the same (m, k) origin, applying PassThroughOp.
190 auto threadwise_dst_store =
192 OutDataType,
193 decltype(thread_buffer_desc),
194 GridDesc_M_K,
196 ThreadBufferLengths,
198 InSrcVectorDim,
199 OutDstVectorSize,
201 1,
202 true>(
203 out_grid_desc_m_k,
205 blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
206 block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
207 PassThroughOp{});
208
// Window step between consecutive K tiles. It is zero under SweepOnce, so the single tile stays
// in place.
209 constexpr auto in_thread_copy_fwd_step =
210 make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
211 constexpr auto in_thread_copy_bwd_step =
212 make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
213
// ---- Pass 1: max(x) along K ----
217 using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
218 AccDataType,
219 BlockSize,
223 false, // param ignored
225
226 using ThreadwiseMaxReduce =
227 ThreadwiseReduction<AccDataType,
231 false, // param ignored
233
234 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
235 p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());
236
// Per-thread running max over all of this block's K tiles, walking forward.
237 index_t reducedTiles = 0;
238 do
239 {
240 threadwise_src_load.Run(in_grid_desc_m_k,
241 in_global_val_buf,
242 thread_buffer_desc,
243 make_tuple(I0, I0),
244 in_thread_buf);
245
246 ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
247
248 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
249
250 reducedTiles++;
251 } while(reducedTiles < num_k_block_tile_iteration);
252
// Combine the per-thread maxima of each row across the thread cluster through LDS.
254 BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
256 });
257
// The window now sits one tile past the last tile; step back so pass 2 starts at the last tile.
258 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
259
// ---- Pass 2: sum(exp(x - max)) along K, visiting tiles from last to first ----
263 using BlockwiseSumReduce = PartitionedBlockwiseReduction<
264 AccDataType,
265 BlockSize,
269 false, // ignored
271
272 using ThreadwiseSumReduce =
273 ThreadwiseReduction<AccDataType,
277 false, // ignored
279
280 reducedTiles = 0;
281 do
282 {
// Under SweepOnce, in_thread_buf still holds the tile loaded in pass 1.
283 if constexpr(!SweepOnce)
284 {
285 threadwise_src_load.Run(in_grid_desc_m_k,
286 in_global_val_buf,
287 thread_buffer_desc,
288 make_tuple(I0, I0),
289 in_thread_buf);
290 }
291
292 // do element-wise pre-reduction operation: out_thread_buf temporarily holds exp(x - max)
295 constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
296 out_thread_buf(Number<offset>{}) =
297 math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
298 });
299 });
300
301 ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
302
303 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
304
305 reducedTiles++;
306 } while(reducedTiles < num_k_block_tile_iteration);
307
308 block_sync_lds(); // reduce_work_buf is reused: wait until all reads of it are done before writing
310 BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
312 });
313
// The window now sits one tile before the first tile; step forward so pass 3 starts at the first.
314 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
315
// ---- Pass 3: normalize, scale and store, walking forward ----
319 reducedTiles = 0;
// beta == 0: the prior output is not needed, so it is never read.
320 if(float_equal_zero{}(beta))
321 {
322 do
323 {
324 if constexpr(!SweepOnce)
325 {
326 threadwise_src_load.Run(in_grid_desc_m_k,
327 in_global_val_buf,
328 thread_buffer_desc,
329 make_tuple(I0, I0),
330 in_thread_buf);
331 }
332
334 // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
336 constexpr auto offset =
337 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
338 out_thread_buf(Number<offset>{}) =
339 alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
340 accu_value_buf(iM);
341 });
342 });
343
344 threadwise_dst_store.Run(thread_buffer_desc,
345 make_tuple(I0, I0),
346 out_thread_buf,
347 out_grid_desc_m_k,
348 out_global_val_buf);
349
350 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
351 threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
352
353 reducedTiles++;
354 } while(reducedTiles < num_k_block_tile_iteration);
355 }
356 else
357 {
// Per-thread tile of the prior output, blended in with weight beta.
359 AccDataType,
360 MThreadSliceSize * KThreadSliceSize,
361 true>
362 in_prior_dst_buf;
363 do
364 {
365 if constexpr(!SweepOnce)
366 {
367 threadwise_src_load.Run(in_grid_desc_m_k,
368 in_global_val_buf,
369 thread_buffer_desc,
370 make_tuple(I0, I0),
371 in_thread_buf);
372 }
373 threadwise_dst_load.Run(out_grid_desc_m_k,
374 out_global_val_buf,
375 thread_buffer_desc,
376 make_tuple(I0, I0),
377 in_prior_dst_buf);
378
380 // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
382 constexpr auto offset =
383 thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
384 out_thread_buf(Number<offset>{}) =
385 alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
386 accu_value_buf(iM) +
387 beta * in_prior_dst_buf(Number<offset>{});
388 });
389 });
390
391 threadwise_dst_store.Run(thread_buffer_desc,
392 make_tuple(I0, I0),
393 out_thread_buf,
394 out_grid_desc_m_k,
395 out_global_val_buf);
396
// Advance all three windows in lockstep to the next K tile.
397 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
398 threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
399 threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
400
401 reducedTiles++;
402 } while(reducedTiles < num_k_block_tile_iteration);
403 }
404 }
405};
406
407} // namespace ck
__host__ T exp(T x)
Definition math_v2.hpp:391
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, const GridDesc_M_K out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_softmax.hpp:22
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_softmax.hpp:55
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_softmax.hpp:74
static __device__ void Run(const GridDesc_M_K &in_grid_desc_m_k, const GridDesc_M_K &out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_softmax.hpp:87
Definition reduction_functions_blockwise.hpp:28
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_functions_accumulate.hpp:17
Definition reduction_common.hpp:20
Definition reduction_operator.hpp:37
Definition reduction_operator.hpp:163
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340