layernorm2d_fwd_pipeline_two_pass.hpp Source File

layernorm2d_fwd_pipeline_two_pass.hpp Source File

Composable Kernel: layernorm2d_fwd_pipeline_two_pass.hpp Source File
layernorm2d_fwd_pipeline_two_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
11namespace ck_tile {
12
13template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
15{
18
27
30
// Compile-time switches for this pipeline. Everything here is derived from Problem /
// Problem::Traits (or hard-coded) and never changes at run time.
//
// gamma / beta count as "present" unless their data type is ck_tile::null_type.
31 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
32 static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
// Both flags read the same trait (kSaveMeanInvStd), so mean and inv-std are always saved together.
33 static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
34 static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
35
// Also selects the string stored in `name` ("bpr_2p" vs "wpr_2p").
36 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
// Hard-coded to false; unlike kPadN it is not taken from the Problem traits yet.
37 static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
38 static constexpr bool kPadN = Problem::Traits::kPadN;
// When set (and ComputeDataType is float) operator() uses __builtin_amdgcn_rcpf for 1/sqrt.
39 static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
// operator() static_asserts that this is true.
40 static constexpr bool kWelford = Problem::Traits::kWelford;
// Fusion selectors, declared `auto` so they keep whatever type the traits use.
// NOTE(review): the cross-reference section of this page lists ADD_BIAS, PRE_ADD_STORE, PRE_ADD
// and DYNAMIC_QUANT from layernorm2d_fwd_traits.hpp - presumably the enumerators these are
// compared against; confirm in that header.
41 static constexpr auto kXbias = Problem::Traits::kXbias;
42 static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
43 static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
44
// Short identifier string for this pipeline variant, picked at compile time by an
// immediately-invoked lambda: "bpr_2p" when kNeedCrossWarpSync is true, "wpr_2p" otherwise.
45 static constexpr const char* name = []() {
46 if constexpr(kNeedCrossWarpSync)
47 return "bpr_2p"; // block per row
48 else
49 return "wpr_2p"; // warp per row
50 }();
51
// Body of GetSmemSize(): forwards to the policy's GetSmemSize for this Problem.
// NOTE(review): the declaration (original line 52) is missing from this listing; the
// cross-reference section of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()`.
53 {
54 return Policy::template GetSmemSize<Problem>();
55 }
56
// Two-pass layernorm forward.
//
// Pass 1 walks the row left to right in Block_N-wide tiles, optionally adds x_bias and the
// residual input, and feeds every tile into the policy's block_norm_reduce to accumulate
// mean / variance. After the reductions are synchronised, inv_std = 1 / sqrt(var + epsilon)
// is computed and mean / inv_std are optionally stored.
// Pass 2 walks the same tiles right to left, rebuilds (or reloads) the pre-norm value,
// applies (x - mean) * inv_std * gamma + beta and hands each tile to Epilogue.
//
// Parameters, as used by the visible code:
//   x_window_, x_residual_window_, x_bias_window_, gamma_window_, beta_window_
//                      - input windows; each is re-wrapped with a policy tile distribution.
//   y_window           - output window; passed to Epilogue with every normalised tile.
//   y_residual_window_ - re-wrapped copy is written with store_tile in pass 1 and read with
//                        load_tile in pass 2, on the branches that touch it.
//   mean_window        - written only if kSaveMean.
//   inv_std_window     - written only if kSaveInvStd.
//   (SmoothScaleWindow / YScaleWindow arguments are unnamed and unused here.)
//   epsilon            - added to the variance before the square root.
//   row_size           - extent along N; determines the tile count and the tail size.
//   smem               - forwarded to block_norm_reduce_cross_warp_sync.
//   Epilogue           - unnamed; a fresh Epilogue{} is invoked per tile.
//
// NOTE(review): this listing has lost several original lines (gutter numbers 102, 112, 133,
// 142-143, 149, 160, 185, 202, 205, 215, 227 and 257 are absent). Most of them carry the
// `if constexpr` conditions guarding the brace blocks below. Comments at those spots are
// marked as assumptions; check them against the real header.
57 template <typename XWindow,
58 typename XResidualWindow,
59 typename XBiasWindow,
60 typename GammaWindow,
61 typename BetaWindow,
62 typename YWindow,
63 typename YResidualWindow,
64 typename MeanWindow,
65 typename InvStdWindow,
66 typename SmoothScaleWindow,
67 typename YScaleWindow,
68 typename Epilogue>
69 CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
70 const XResidualWindow& x_residual_window_,
71 const XBiasWindow& x_bias_window_,
72 const GammaWindow& gamma_window_,
73 const BetaWindow& beta_window_,
74 YWindow& y_window,
75 const YResidualWindow& y_residual_window_,
76 MeanWindow& mean_window,
77 InvStdWindow& inv_std_window,
78 const SmoothScaleWindow& /*sm_scale_window*/,
79 YScaleWindow& /*y_scale_window*/,
80 ComputeDataType epsilon,
81 ck_tile::index_t row_size,
82 void* smem,
83 Epilogue) const
84 {
85 static_assert(kWelford == true, "2 pass only supports welford merge");
// Re-wrap the caller's windows with this pipeline's tile distributions.
// x and both residual windows use the X distribution; x_bias, gamma and beta use the
// gamma/beta distribution and are moved with one-component offsets below.
86 auto x_window =
87 make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
88 auto x_bias_window = make_tile_window(
89 x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
90 auto gamma_window = make_tile_window(
91 gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
92 auto beta_window = make_tile_window(
93 beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
94 auto x_residual_window = make_tile_window(
95 x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
96 auto y_residual_window = make_tile_window(
97 y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
98
99 // Problem::BlockShape
100 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
// NOTE(review): the initialiser (original line 102) is missing from this listing. The
// cross-reference section names integer_divide_ceil, so this is presumably
// ceil(row_size / Block_N) - confirm.
101 index_t num_n_tile_iteration =
103
104 // total number of count assume current iter have no pad(only last iter has pad)
105 constexpr index_t count_per_iter =
106 Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
// Number of elements along N left over for the final tile.
107 const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
108
// cur_count / max_count are handed to block_norm_reduce on every pass-1 iteration.
// NOTE(review): the second operand of the sum (original line 112) is missing. The
// cross-reference section names block_tile_welford_calculate_max_count(int row_size),
// presumably applied to last_iter_n - confirm.
109 int cur_count = 0;
110 int max_count =
111 (num_n_tile_iteration - 1) * count_per_iter +
113 auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
114 auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
115 auto block_norm_reduce_cross_warp_sync =
116 Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
117
// mean / var accumulators, shaped by the reducer from the ComputeDataType version of an x tile.
118 using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
119 auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
120 auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
121
// ---- Pass 1: accumulate statistics, moving left to right one Block_N tile at a time ----
122 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
123 {
// All three inputs are loaded unconditionally, then their windows advance by one tile.
124 auto x = load_tile(x_window);
125 auto x_resi = load_tile(x_residual_window);
126 const auto x_bias = load_tile(x_bias_window);
127
128 move_tile_window(x_window, {0, Block_N});
129 move_tile_window(x_residual_window, {0, Block_N});
130 move_tile_window(x_bias_window, {Block_N});
131 auto acc = cast_tile<ComputeDataType>(x);
132
// NOTE(review): the guard for this block (original line 133) is missing; presumably
// `if constexpr(kXbias == ...ADD_BIAS)` - confirm.
134 {
135 sweep_tile(x, [&](auto idx) {
136 // compute x = bias + x
// x_bias is indexed by the second tile coordinate only.
137 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
138 acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
139 });
140 }
141
// NOTE(review): the guard for this block (original lines 142-143) is missing; presumably
// it tests kFusedAdd against PRE_ADD_STORE / PRE_ADD - confirm.
144 {
145 sweep_tile(x_resi, [&](auto idx) {
146 // compute x = x_resi + x
147 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
148 });
// NOTE(review): the guard for the nested block (original line 149) is missing;
// presumably the PRE_ADD_STORE case only - confirm.
// Writes the summed tile to y_residual_window, which pass 2 can read back.
150 {
151 store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
152 move_tile_window(y_residual_window, {0, Block_N});
153 }
154 }
// Fold this tile into the running mean / var.
155 block_norm_reduce(acc, mean, var, cur_count, max_count);
156 }
157
// Combine the partial statistics; only the cross-warp step receives the smem pointer.
158 block_norm_reduce_sync(mean, var, cur_count);
159 block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
// NOTE(review): original line 160 is missing here. The cross-reference section names
// block_tile_welford_post_scale_var(var_tensor, count, ...), presumably called on
// (var, cur_count) at this point - confirm.
161
162 // compute inv-std
// Element-wise 1 / sqrt(var + epsilon). The first branch multiplies by the
// __builtin_amdgcn_rcpf reciprocal instead of dividing.
// NOTE(review): this is a plain `if` on compile-time constants, not `if constexpr`,
// so both branches are instantiated for every ComputeDataType.
163 auto inv_std = tile_elementwise_in(
164 [&](const auto& v_) {
165 if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
166 {
167 return type_convert<ComputeDataType>(1.0f) *
168 __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
169 }
170 else
171 {
172 return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
173 }
174 },
175 var);
176 if constexpr(kSaveMean)
177 store_tile(mean_window, cast_tile<MeanDataType>(mean));
178 if constexpr(kSaveInvStd)
179 store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
180
181 // reverse read x to reuse cache
// Offset of the last tile's start along N: row_size - Block_N when row_size is a multiple
// of Block_N, otherwise row_size rounded down to a multiple of Block_N.
182 ck_tile::index_t stride_to_right_most_window =
183 row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
184
// The windows advanced in pass 1 now sit one tile past the end, so step the ones that
// pass 2 will read back by one tile.
// NOTE(review): the condition choosing between the two branches (original line 185) is
// missing; presumably the PRE_ADD_STORE case takes the first branch - confirm.
186 {
187 move_tile_window(y_residual_window, {0, -Block_N});
188 }
189 else
190 {
191 move_tile_window(x_window, {0, -Block_N});
192 move_tile_window(x_residual_window, {0, -Block_N});
193 move_tile_window(x_bias_window, {-Block_N});
194 }
// gamma, beta and y were not moved during pass 1, so jump them straight to the last tile.
195 move_tile_window(gamma_window, {stride_to_right_most_window});
196 move_tile_window(beta_window, {stride_to_right_most_window});
197 move_tile_window(y_window, {0, stride_to_right_most_window});
198
199 // layernorm computation
// ---- Pass 2: normalise, moving right to left one Block_N tile at a time ----
200 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
201 {
// NOTE(review): the start of this statement (original line 202) is missing; presumably it
// declares `acc` as a ComputeDataType tensor with the x tile's distribution - confirm.
203 decltype(load_tile(x_window))::get_tile_distribution());
204
// NOTE(review): the condition (original line 205) is missing; presumably the same
// selector as before pass 2. First branch: reload the pre-norm value stored in pass 1.
206 {
207 acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
208 move_tile_window(y_residual_window, {0, -Block_N});
209 }
210 else
211 {
// Otherwise rebuild the pre-norm value from x (plus bias / residual) as in pass 1.
212 acc = cast_tile<ComputeDataType>(load_tile(x_window));
213 move_tile_window(x_window, {0, -Block_N});
214
// NOTE(review): guard (original line 215) missing; presumably the x-bias check - confirm.
216 {
217 const auto x_bias = load_tile(x_bias_window);
218 move_tile_window(x_bias_window, {-Block_N});
219
220 sweep_tile(acc, [&](auto idx) {
221 // compute x = bias + x
222 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
223 acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
224 });
225 }
226
// NOTE(review): guard (original line 227) missing; presumably the fused-add check - confirm.
228 {
229 auto x_resi = load_tile(x_residual_window);
230 move_tile_window(x_residual_window, {0, -Block_N});
231
232 sweep_tile(x_resi, [&](auto idx) {
233 // compute x = x_resi + x
234 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
235 });
236 }
237 }
238
239 // load gamma/beta (TODO: support no gamma/beta?)
// Both are loaded unconditionally in the visible code.
240 const auto gamma = load_tile(gamma_window);
241 const auto beta = load_tile(beta_window);
242
243 auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
244
// ln = (acc - mean) * inv_std * gamma + beta.
// mean / inv_std are indexed by the first tile coordinate, gamma / beta by the second.
// `mean_ = mean` captures a copy of mean; everything else is captured by reference.
245 sweep_tile(ln, [&, mean_ = mean](auto idx) {
246 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
247 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
248
249 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
250 const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
251
252 auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
253
254 ln(idx) = ln_;
255 });
256
// NOTE(review): original line 257 is missing here. The cross-reference section lists
// DYNAMIC_QUANT, so it may be a check on kFusedQuant - confirm.
// The epilogue receives the output window, the normalised tile and a null third argument.
258 Epilogue{}(y_window, ln, nullptr);
259
// Step gamma, beta and y one tile to the left for the next iteration.
260 move_tile_window(gamma_window, {-Block_N});
261 move_tile_window(beta_window, {-Block_N});
262 move_tile_window(y_window, {0, -Block_N});
263 }
264 }
265};
266} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
Definition block_norm_reduce.hpp:361
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_ &var_tensor, int count, bool_constant< FastFdiv_ >={})
Definition block_norm_reduce.hpp:393
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
@ DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:43
@ ADD_BIAS
Definition layernorm2d_fwd_traits.hpp:14
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
@ PRE_ADD_STORE
Definition layernorm2d_fwd_traits.hpp:27
@ PRE_ADD
Definition layernorm2d_fwd_traits.hpp:29
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition layernorm2d_fwd_pipeline_two_pass.hpp:15
static constexpr auto kXbias
Definition layernorm2d_fwd_pipeline_two_pass.hpp:41
static constexpr bool kHasBeta
Definition layernorm2d_fwd_pipeline_two_pass.hpp:32
static constexpr bool kWelford
Definition layernorm2d_fwd_pipeline_two_pass.hpp:40
ck_tile::remove_cvref_t< typename Problem::XBiasDataType > XBiasDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:20
XDataType XResidualDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:28
static constexpr bool kSaveInvStd
Definition layernorm2d_fwd_pipeline_two_pass.hpp:34
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const XResidualWindow &x_residual_window_, const XBiasWindow &x_bias_window_, const GammaWindow &gamma_window_, const BetaWindow &beta_window_, YWindow &y_window, const YResidualWindow &y_residual_window_, MeanWindow &mean_window, InvStdWindow &inv_std_window, const SmoothScaleWindow &, YScaleWindow &, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const
Definition layernorm2d_fwd_pipeline_two_pass.hpp:69
ck_tile::remove_cvref_t< typename Problem::BetaDataType > BetaDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:22
static constexpr auto kFusedQuant
Definition layernorm2d_fwd_pipeline_two_pass.hpp:43
static constexpr bool kHasGamma
Definition layernorm2d_fwd_pipeline_two_pass.hpp:31
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:21
ck_tile::remove_cvref_t< Problem_ > Problem
Definition layernorm2d_fwd_pipeline_two_pass.hpp:16
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:23
static constexpr bool kFastFDiv
Definition layernorm2d_fwd_pipeline_two_pass.hpp:39
XDataType YResidualDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:29
ck_tile::remove_cvref_t< typename Problem::InvStdDataType > InvStdDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:26
static constexpr bool kSaveMean
Definition layernorm2d_fwd_pipeline_two_pass.hpp:33
static constexpr bool kPadM
Definition layernorm2d_fwd_pipeline_two_pass.hpp:37
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:24
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition layernorm2d_fwd_pipeline_two_pass.hpp:52
static constexpr bool kPadN
Definition layernorm2d_fwd_pipeline_two_pass.hpp:38
static constexpr auto kFusedAdd
Definition layernorm2d_fwd_pipeline_two_pass.hpp:42
ck_tile::remove_cvref_t< Policy_ > Policy
Definition layernorm2d_fwd_pipeline_two_pass.hpp:17
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:19
static constexpr bool kNeedCrossWarpSync
Definition layernorm2d_fwd_pipeline_two_pass.hpp:36
static constexpr const char * name
Definition layernorm2d_fwd_pipeline_two_pass.hpp:45
ck_tile::remove_cvref_t< typename Problem::MeanDataType > MeanDataType
Definition layernorm2d_fwd_pipeline_two_pass.hpp:25
Definition tile/core/numeric/integral_constant.hpp:13