gridwise_normalization_welford_variance.hpp Source File

gridwise_normalization_welford_variance.hpp Source File

Composable Kernel: gridwise_normalization_welford_variance.hpp Source File
gridwise_normalization_welford_variance.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14// Y = Normalization(X, Beta, Gamma)
15template <typename XDataType,
16 typename GammaDataType,
17 typename BetaDataType,
18 typename YDataType,
19 typename SaveMeanInvStdDataType,
20 typename ComputeDataType,
21 typename YElementwiseOperation,
22 typename GridDesc_M_K,
23 typename GridDesc_M,
24 index_t BlockSize,
25 index_t MThreadClusterSize,
26 index_t KThreadClusterSize,
27 index_t MThreadSliceSize,
28 index_t KThreadSliceSize,
29 index_t XSrcVectorDim,
30 index_t XSrcVectorSize,
31 index_t GammaSrcVectorDim,
32 index_t GammaSrcVectorSize,
33 index_t BetaSrcVectorDim,
34 index_t BetaSrcVectorSize,
35 index_t YDstVectorDim,
36 index_t YDstVectorSize,
37 index_t SaveMeanInvStdDstVectorSize,
38 bool SweepOnce>
40{
// Compile-time checks on the template configuration.
// X is read in vectors of XSrcVectorSize along XSrcVectorDim, which must be 0 or 1.
// Dim 0 is checked against MThreadSliceSize and dim 1 against KThreadSliceSize: the
// per-thread slice length along the vectorized dimension must be a multiple of the
// vector size.
41 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
42 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
43 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
44
// Same divisibility rule for the Y store vector (YDstVectorDim / YDstVectorSize).
45 static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
46 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
47 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
48
// The saved mean / inverse-std stores are vectorized along M only.
49 static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
50 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
51 "configuration, please check!");
52
// X, Y, Gamma and Beta must all use the same vector size.
// Presumably required because Run() aliases the Y buffers to the X buffers and the Beta
// buffers to the Gamma buffers, and indexes all of them with one descriptor - confirm.
53 static_assert(XSrcVectorSize == YDstVectorSize);
54 static_assert(XSrcVectorSize == GammaSrcVectorSize);
55 static_assert(XSrcVectorSize == BetaSrcVectorSize);
56
// True when X is vectorized along dimension 0 (M). Per the definitions listed in this
// page's symbol index, it selects Sequence<1, 0> instead of Sequence<0, 1> for
// ThreadBufferDimAccessOrder and ThreadClusterArrangeOrder.
57 static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
58
60
63
66
67 static constexpr auto thread_cluster_desc =
69
73
75 static constexpr auto thread_buffer_desc_m =
77
82
85
86 using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
87 BlockSize,
90
92
// Compile-time index constants used to subscript tuples / multi-indices.
93 static constexpr auto I0 = Number<0>{};
94 static constexpr auto I1 = Number<1>{};
95 static constexpr auto I2 = Number<2>{};
96
// Extent along M and K covered by one thread block:
// (threads in the cluster along that dimension) * (elements each thread handles).
97 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
98 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// K distance covered when every K-cluster thread processes one vector of XSrcVectorSize;
// Run() uses it as the forward step of the load/store slice windows.
99 static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
100
// Number of XSrcVectorSize-wide vectors in a thread's K slice (integer division), so
// ThreadBufferNumber * K_BlockTileStepSize == K_BlockTileSize when the division is exact.
// NOTE(review): the static_asserts only guarantee KThreadSliceSize % XSrcVectorSize == 0
// when XSrcVectorDim == 1 - confirm the XSrcVectorDim == 0 configurations are exact too.
101 static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
102
// Returns how many elements along K the thread with the given K-cluster id accounts for.
// Run() assigns the result to threadwise_welford.max_count_.
//
// x_grid_desc_m_k     : descriptor of X; the K length is read out of its transform list.
// thread_k_cluster_id : this thread's index along K within the thread cluster.
103 __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
104 int thread_k_cluster_id)
105 {
// K length taken from upper length 0 of the descriptor's transform at index 2. This depends
// on how the descriptor was built by the caller (see the FIXME); presumably it is the
// unpadded K length - confirm against the device op.
106 // FIXME: Should not hack the transform from deviceOP
107 int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
// Contribution of the complete K tiles: KThreadSliceSize elements per full K_BlockTileSize.
108 int kPerThread =
109 kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
// Elements left over after the complete tiles
// (kPerThread * KThreadClusterSize == number of full tiles * K_BlockTileSize).
110 int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
111
112 if(kPerBlockTail > 0)
113 {
// NOTE: the loop header (listing line 114) is missing from this listing; `i` is its index.
// Presumably a static_for over the ThreadBufferNumber vector slots - confirm in the
// original file.
// thread_max_len is the exclusive end, within the tail tile, of this thread's i-th vector.
115 int thread_max_len =
116 (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
// The initial value of delta below is overwritten by the next statement (redundant).
117 int delta = thread_max_len - kPerBlockTail;
// delta = how many elements of that vector lie past the tail, limited to [0, XSrcVectorSize].
118 delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
// Add only the part of the vector that lies inside the tail.
119 kPerThread += XSrcVectorSize - delta;
120 });
121 }
122
123 return kPerThread;
124 }
125
// Normalizes X along K and applies gamma and beta. As computed below:
//   y = (x - mean) * inv_std * gamma + beta,   inv_std = 1 / sqrt(var + epsilon)
// mean / var are accumulated per thread with ThreadwiseWelford and then combined with
// BlockwiseWelford::Run. Y is written through threadwise_y_store, which is constructed with
// y_elementwise_op. mean and inv_std are also stored when their pointers are non-null.
//
// Parameters:
//   x/gamma/beta/y_grid_desc_m_k  - (M, K) descriptors of the corresponding tensors.
//   save_mean/inv_std_grid_desc_m - (M) descriptors of the optional saved statistics.
//   num_k_block_tile_iteration    - number of K tiles looped over in the two-sweep path;
//                                   not read in the SweepOnce path.
//   epsilon                       - added to the variance before the square root.
//   p_*_global                    - raw pointers wrapped as global-address-space buffers;
//                                   p_save_mean_global / p_save_inv_std_global may be nullptr.
//   y_elementwise_op              - forwarded to the Y store transfer.
//
// NOTE: this listing omits a number of original lines (static_for loop headers, buffer and
// transfer type names). The comments below describe only what is visible.
126 __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
127 const GridDesc_M_K& gamma_grid_desc_m_k,
128 const GridDesc_M_K& beta_grid_desc_m_k,
129 const GridDesc_M_K& y_grid_desc_m_k,
130 const GridDesc_M& save_mean_grid_desc_m,
131 const GridDesc_M& save_inv_std_grid_desc_m,
132 index_t num_k_block_tile_iteration,
133 ComputeDataType epsilon,
134 const XDataType* const __restrict__ p_x_global,
135 const GammaDataType* const __restrict__ p_gamma_global,
136 const BetaDataType* const __restrict__ p_beta_global,
137 YDataType* const __restrict__ p_y_global,
138 SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
139 SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
140 const YElementwiseOperation y_elementwise_op)
141 {
// Tuple of per-thread X buffers, each holding MThreadSliceSize * XSrcVectorSize values of
// ComputeDataType (the buffer type name and the tuple length are on elided lines).
142 auto x_thread_buf = generate_tuple(
143 [&](auto) {
145 ComputeDataType,
146 MThreadSliceSize * XSrcVectorSize,
147 true>{};
148 },
150
// Same layout for gamma.
151 auto gamma_thread_buf = generate_tuple(
152 [&](auto) {
154 ComputeDataType,
155 MThreadSliceSize * GammaSrcVectorSize,
156 true>{};
157 },
159
// Beta reuses the gamma buffers and Y reuses the X buffers (references, not copies).
// Gamma must therefore be applied before beta is loaded, and Y overwrites X in place.
160 auto& beta_thread_buf = gamma_thread_buf;
161 auto& y_thread_buf = x_thread_buf;
162
// Per-thread mean / variance accumulators (their type declarations are on elided lines).
164 mean_thread_buf;
166 var_thread_buf;
// The inverse std is written over the variance in the same buffer.
167 auto& inv_std_thread_buf = var_thread_buf;
168
169 const index_t thread_local_id = get_thread_local_1d_id();
170 const index_t block_global_id = get_block_1d_id();
171
// Map the flat thread id to its (M, K) position in the thread cluster.
172 const auto thread_cluster_idx =
173 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
174
175 const auto thread_m_cluster_id = thread_cluster_idx[I0];
176 const auto thread_k_cluster_id = thread_cluster_idx[I1];
177
// Every 2-D transfer below starts at
//   row    = block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
//   column = thread_k_cluster_id * <vector size of that tensor>
178 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
179 ComputeDataType,
180 GridDesc_M_K,
181 decltype(thread_buffer_desc_m_k),
184 XSrcVectorDim,
185 XSrcVectorSize,
186 1,
187 true>(
188 x_grid_desc_m_k,
189 make_multi_index(block_global_id * M_BlockTileSize +
190 thread_m_cluster_id * MThreadSliceSize,
191 thread_k_cluster_id * XSrcVectorSize));
192
193 auto threadwise_gamma_load =
195 ComputeDataType,
196 GridDesc_M_K,
197 decltype(thread_buffer_desc_m_k),
200 GammaSrcVectorDim,
201 GammaSrcVectorSize,
202 1,
203 true>(
204 gamma_grid_desc_m_k,
205 make_multi_index(block_global_id * M_BlockTileSize +
206 thread_m_cluster_id * MThreadSliceSize,
207 thread_k_cluster_id * GammaSrcVectorSize));
208
209 auto threadwise_beta_load =
211 ComputeDataType,
212 GridDesc_M_K,
213 decltype(thread_buffer_desc_m_k),
216 BetaSrcVectorDim,
217 BetaSrcVectorSize,
218 1,
219 true>(
220 beta_grid_desc_m_k,
221 make_multi_index(block_global_id * M_BlockTileSize +
222 thread_m_cluster_id * MThreadSliceSize,
223 thread_k_cluster_id * BetaSrcVectorSize));
224
// Y store; y_elementwise_op is handed to the transfer object.
225 auto threadwise_y_store =
227 YDataType,
228 decltype(thread_buffer_desc_m_k),
229 GridDesc_M_K,
230 YElementwiseOperation,
233 YDstVectorDim,
234 YDstVectorSize,
236 1,
237 true>(
238 y_grid_desc_m_k,
239 make_multi_index(block_global_id * M_BlockTileSize +
240 thread_m_cluster_id * MThreadSliceSize,
241 thread_k_cluster_id * YDstVectorSize),
242 y_elementwise_op);
243
// 1-D stores (along M) for the saved mean and inverse std, using a pass-through op.
244 auto threadwise_mean_store =
246 SaveMeanInvStdDataType,
247 decltype(thread_buffer_desc_m),
248 GridDesc_M,
251 Sequence<0>, // DimAccessOrder
252 0, // SrcVectorDim
253 SaveMeanInvStdDstVectorSize, // ScalarPerVector
255 1,
256 true>(
257 save_mean_grid_desc_m,
258 make_multi_index(block_global_id * M_BlockTileSize +
259 thread_m_cluster_id * MThreadSliceSize),
260 PassThroughOp{});
261
262 auto threadwise_inv_std_store =
264 SaveMeanInvStdDataType,
265 decltype(thread_buffer_desc_m),
266 GridDesc_M,
269 Sequence<0>, // DimAccessOrder
270 0, // SrcVectorDim
271 SaveMeanInvStdDstVectorSize, // ScalarPerVector
273 1,
274 true>(
275 save_inv_std_grid_desc_m,
276 make_multi_index(block_global_id * M_BlockTileSize +
277 thread_m_cluster_id * MThreadSliceSize),
278 PassThroughOp{});
279
// Slice-window steps along K: forward by one vector step of the whole K cluster,
// backward by one full K tile (0 when SweepOnce).
280 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
281 constexpr auto thread_copy_bwd_step_m_k =
282 make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
283
// Wrap the raw pointers as global-address-space buffers sized by each descriptor's
// element space.
284 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
285 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
286
287 const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
288 p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
289
290 const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
291 p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
292
293 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
294 p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
295
296 auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
297 p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
298
299 auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
300 p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
301
// Per-thread Welford accumulator; max_count_ is set to this thread's K element count
// from GetKPerThread().
302 auto threadwise_welford = ThreadwiseWelford();
303 threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
304
// Zero the accumulators (the loop header over I is on an elided line).
306 mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
307 var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
308 });
309
310 // Separate sweep once and sweep twice pipeline
311 if constexpr(SweepOnce)
312 {
// SweepOnce: X and gamma are read from global memory once and kept in the thread buffers.
// For each slot i: load X and gamma, feed X into the Welford update, then advance both
// windows unless this is the last slot.
314 threadwise_x_load.Run(x_grid_desc_m_k,
315 x_global_val_buf,
317 make_tuple(I0, I0),
318 x_thread_buf(i));
319
320 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
321 gamma_global_val_buf,
323 make_tuple(I0, I0),
324 gamma_thread_buf(i));
325
326 threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
327
328 if constexpr(i != ThreadBufferNumber - 1)
329 {
330 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
331 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
332 thread_copy_fwd_step_m_k);
333 }
334 });
335
// Combine the per-thread statistics per M entry, then turn variance into inverse std.
// The statement guarded by `if constexpr(I > 0)` (listing line 338) is not shown here;
// presumably a block_sync_lds() between successive reductions - confirm.
337 if constexpr(I > 0)
339
340 int count = threadwise_welford.cur_count_;
341 BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
342 inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
343 ck::math::sqrt(var_thread_buf(I) + epsilon);
344 });
345
// Only the thread with K-cluster id 0 writes the saved statistics.
346 // save mean and inverse std for backward (optional)
347 if(thread_k_cluster_id == 0)
348 {
349 if(p_save_mean_global != nullptr)
350 {
351 threadwise_mean_store.Run(thread_buffer_desc_m,
352 make_tuple(I0),
353 mean_thread_buf,
354 save_mean_grid_desc_m,
355 save_mean_global_val_buf);
356 }
357 if(p_save_inv_std_global != nullptr)
358 {
359 threadwise_inv_std_store.Run(thread_buffer_desc_m,
360 make_tuple(I0),
361 inv_std_thread_buf,
362 save_inv_std_grid_desc_m,
363 save_inv_std_global_val_buf);
364 }
365 }
366
// Element-wise pass over (iK0, iM, iK1); the two outer loop headers are on elided lines.
367 // normalization
370 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
371 constexpr auto offset_m_k =
372 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
373
374 // normalize
375 y_thread_buf(iK0)(Number<offset_m_k>{}) =
376 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
377 inv_std_thread_buf(iM);
378
// Only gamma is applied here; beta is added further below, after it has been loaded
// into the buffers shared with gamma.
379 // gamma & beta
380 y_thread_buf(iK0)(Number<offset_m_k>{}) =
381 y_thread_buf(iK0)(Number<offset_m_k>{}) *
382 gamma_thread_buf(iK0)(Number<offset_m_k>{});
383 });
384 });
385 });
386
// Load beta (this overwrites gamma, which is no longer needed).
388 threadwise_beta_load.Run(beta_grid_desc_m_k,
389 beta_global_val_buf,
391 make_tuple(I0, I0),
392 beta_thread_buf(i));
393
394 if constexpr(i != ThreadBufferNumber - 1)
395 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
396 thread_copy_fwd_step_m_k);
397 });
398
401 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
402 constexpr auto offset_m_k =
403 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
404
405 // beta
406 y_thread_buf(iK0)(Number<offset_m_k>{}) =
407 y_thread_buf(iK0)(Number<offset_m_k>{}) +
408 beta_thread_buf(iK0)(Number<offset_m_k>{});
409 });
410 });
411 });
412
// Write Y slot by slot, advancing the destination window between slots.
414 threadwise_y_store.Run(thread_buffer_desc_m_k,
415 make_tuple(I0, I0),
416 y_thread_buf(i),
417 y_grid_desc_m_k,
418 y_global_val_buf);
419
420 if constexpr(i != ThreadBufferNumber - 1)
421 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
422 thread_copy_fwd_step_m_k);
423 });
424 } // end of sweep once
425 else
426 {
// Two-sweep path, first sweep: walk all K tiles forward, accumulating mean / variance.
// The X window is advanced after every slot, including the last one.
427 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
428 {
430 threadwise_x_load.Run(x_grid_desc_m_k,
431 x_global_val_buf,
433 make_tuple(I0, I0),
434 x_thread_buf(i));
435 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
436 threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
437 });
438 }
439
// Block-level combine and inverse std, as in the SweepOnce path (listing line 442, the
// body of the `if constexpr`, is likewise not shown).
441 if constexpr(I > 0)
443
444 int count = threadwise_welford.cur_count_;
445 BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
446 inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
447 });
448
// Only the thread with K-cluster id 0 writes the optional saved statistics.
449 if(thread_k_cluster_id == 0)
450 {
451 if(p_save_mean_global != nullptr)
452 {
453 threadwise_mean_store.Run(thread_buffer_desc_m,
454 make_tuple(I0),
455 mean_thread_buf,
456 save_mean_grid_desc_m,
457 save_mean_global_val_buf);
458 }
459 if(p_save_inv_std_global != nullptr)
460 {
461 threadwise_inv_std_store.Run(thread_buffer_desc_m,
462 make_tuple(I0),
463 inv_std_thread_buf,
464 save_inv_std_grid_desc_m,
465 save_inv_std_global_val_buf);
466 }
467 }
468
// Position every window at the start of the LAST K tile. Gamma / beta / Y have not moved
// yet, so they jump forward by (num_k_block_tile_iteration - 1) tiles. X currently sits
// just past the last tile, so it steps back by one tile. (This assumes the elided slot
// loops run ThreadBufferNumber times, i.e. one K_BlockTileSize per tile.)
469 auto thread_copy_tail_m_k =
470 (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
471
472 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
473 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
474 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
475 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
476
// Second sweep: process the K tiles from last to first, re-reading X for each tile.
477 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
478 {
480 threadwise_x_load.Run(x_grid_desc_m_k,
481 x_global_val_buf,
483 make_tuple(I0, I0),
484 x_thread_buf(i));
485 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
486 });
487
489 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
490 gamma_global_val_buf,
492 make_tuple(I0, I0),
493 gamma_thread_buf(i));
494
495 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
496 thread_copy_fwd_step_m_k);
497 });
498
501 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
502 constexpr auto offset_m_k =
503 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
504
505 // normalize
506 y_thread_buf(iK0)(Number<offset_m_k>{}) =
507 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
508 inv_std_thread_buf(iM);
509
510 // gamma
511 y_thread_buf(iK0)(Number<offset_m_k>{}) =
512 y_thread_buf(iK0)(Number<offset_m_k>{}) *
513 gamma_thread_buf(iK0)(Number<offset_m_k>{});
514 });
515 });
516 });
517
// Beta is loaded only after gamma has been applied, because both share the same buffers.
519 threadwise_beta_load.Run(beta_grid_desc_m_k,
520 beta_global_val_buf,
522 make_tuple(I0, I0),
523 beta_thread_buf(i));
524 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
525 thread_copy_fwd_step_m_k);
526 });
527
530 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
531 constexpr auto offset_m_k =
532 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
533
534 // beta
535 y_thread_buf(iK0)(Number<offset_m_k>{}) =
536 y_thread_buf(iK0)(Number<offset_m_k>{}) +
537 beta_thread_buf(iK0)(Number<offset_m_k>{});
538 });
539 });
540 });
541
543 threadwise_y_store.Run(thread_buffer_desc_m_k,
544 make_tuple(I0, I0),
545 y_thread_buf(i),
546 y_grid_desc_m_k,
547 y_global_val_buf);
548 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
549 thread_copy_fwd_step_m_k);
550 });
551
// Each window advanced by one tile while processing this tile; stepping back by two
// tiles therefore lands on the start of the previous tile.
552 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
553 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
554 2 * thread_copy_bwd_step_m_k);
555 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
556 2 * thread_copy_bwd_step_m_k);
557 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
558 2 * thread_copy_bwd_step_m_k);
559 }
560 } // end of sweep twice
561 }
562};
563
564} // namespace ck
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition utility/math.hpp:148
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_normalization_welford_variance.hpp:40
static constexpr auto thread_cluster_desc
Definition gridwise_normalization_welford_variance.hpp:67
static __device__ int GetKPerThread(const GridDesc_M_K &x_grid_desc_m_k, int thread_k_cluster_id)
Definition gridwise_normalization_welford_variance.hpp:103
static constexpr bool reorder_thread_cluster
Definition gridwise_normalization_welford_variance.hpp:57
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< XSrcVectorSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_normalization_welford_variance.hpp:78
static constexpr auto I0
Definition gridwise_normalization_welford_variance.hpp:93
static constexpr index_t K_BlockTileSize
Definition gridwise_normalization_welford_variance.hpp:98
static constexpr index_t M_BlockTileSize
Definition gridwise_normalization_welford_variance.hpp:97
ThreadwiseWelford< ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_normalization_welford_variance.hpp:83
Sequence< MThreadSliceSize > ThreadBufferLengths_M
Definition gridwise_normalization_welford_variance.hpp:74
static constexpr index_t K_BlockTileStepSize
Definition gridwise_normalization_welford_variance.hpp:99
static constexpr auto I1
Definition gridwise_normalization_welford_variance.hpp:94
static __device__ void Run(const GridDesc_M_K &x_grid_desc_m_k, const GridDesc_M_K &gamma_grid_desc_m_k, const GridDesc_M_K &beta_grid_desc_m_k, const GridDesc_M_K &y_grid_desc_m_k, const GridDesc_M &save_mean_grid_desc_m, const GridDesc_M &save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition gridwise_normalization_welford_variance.hpp:126
BlockwiseWelford< ComputeDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder > BlockwiseWelford
Definition gridwise_normalization_welford_variance.hpp:86
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_normalization_welford_variance.hpp:59
static constexpr auto ThreadBufferNumber
Definition gridwise_normalization_welford_variance.hpp:101
Sequence< MThreadSliceSize, XSrcVectorSize > ThreadBufferLengths_M_K
Definition gridwise_normalization_welford_variance.hpp:70
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_normalization_welford_variance.hpp:91
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_normalization_welford_variance.hpp:61
static constexpr auto thread_buffer_desc_m
Definition gridwise_normalization_welford_variance.hpp:75
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_normalization_welford_variance.hpp:80
static constexpr auto thread_buffer_desc_m_k
Definition gridwise_normalization_welford_variance.hpp:71
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_normalization_welford_variance.hpp:64
static constexpr auto I2
Definition gridwise_normalization_welford_variance.hpp:95
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:173
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340