sweep_tile.hpp Source File

sweep_tile.hpp Source File

Composable Kernel: sweep_tile.hpp Source File
sweep_tile.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck_tile {
15
16// sweep over a span of a distributed tile and apply lambda function F
17template <typename TileDistributedSpan_, // tile_distributed_span<...>
18 typename F // signature: F(tile_distributed_index<...>)
19 >
20CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
21{
23
24 static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
25 constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
26
27 f(dstr_idx);
28 });
29}
30
31// unpacked span: this version supports spans with an unpack (multi-arg) functor
32//
33template <
34 typename TileDistributedSpan_, // tile_distributed_span<...>
35 typename F, // signature: F(tile_distributed_index<...>)
36 typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
37CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
38{
40
41 static_uford<typename DstrSpan::Impl, Unpacks>{}(
42 [&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
43}
44
45namespace impl {
46
47template <typename, typename, typename>
49
50template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
51struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
52{
53 CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
54 {
55 constexpr auto spans = DistributedTensor::get_distributed_spans();
56 constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
57 constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
58 constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
59 return y_unpacks;
60 }
62 {
63 constexpr auto spans = DistributedTensor::get_distributed_spans();
64 constexpr auto u =
65 static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
66 return u.get_num_of_access() *
67 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
68 .get_num_of_access();
69 }
70 template <typename F, typename SpanIdx>
71 CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
72 {
73 constexpr auto spans = DistributedTensor::get_distributed_spans();
74
76 spans[number<I>{}],
77 [&](auto... i_idx) {
78 const auto next_span_idx = embed_tuples(
79 [&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
80 span_idx);
81 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
82 f, next_span_idx);
83 },
85 }
86 template <typename F, typename SpanIdx, index_t i_access>
87 CK_TILE_HOST_DEVICE constexpr void
88 operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
89 {
90 constexpr auto spans = DistributedTensor::get_distributed_spans();
91 constexpr auto u =
92 static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
93 constexpr auto access_stride =
94 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
95 .get_num_of_access();
96 constexpr auto curr_i_access = number<i_access / access_stride>{};
97 constexpr auto next_i_access = number<i_access % access_stride>{};
98 u(
99 [&](auto... i_idx) {
100 const auto next_span_idx = embed_tuples(
101 [&](auto si) {
104 },
105 span_idx);
106 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
107 f, next_span_idx, next_i_access);
108 },
109 curr_i_access);
110 }
111};
112
// Specialization of sweep_tile_impl for an empty dimension sequence (sequence<>).
// The non-empty specializations instantiate sweep_tile_impl with sequence<Is...>, so this is
// the case they reach once no dimension index is left; here the functor is handed to unpack().
113template <typename DistributedTensor, typename UnpacksPerXDim>
114struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
115{
 // Always 1: the non-empty specializations multiply their own access count by the count of
 // the remaining dimensions, so the empty case contributes the neutral factor.
116 CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
 // Full-sweep overload: passes f and the span_idx received from the caller to unpack().
 // NOTE(review): presumably unpack() calls f with the elements of span_idx as separate
 // arguments -- confirm against unpack() in functional.hpp.
117 template <typename F, typename SpanIdx>
118 CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
119 {
120 unpack(f, span_idx);
121 }
 // Single-access overload: identical body; the number<i_access> parameter is unnamed and
 // is not used here.
122 template <typename F, typename SpanIdx, index_t i_access>
123 CK_TILE_HOST_DEVICE constexpr void
124 operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
125 {
126 unpack(f, span_idx);
127 }
128};
129
130template <typename, typename, typename>
132
133// TODO: support empty tuple to remove this "entry-point" like function
134template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
135struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
136{
137 CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
138 {
139 constexpr auto spans = DistributedTensor::get_distributed_spans();
140 constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
141 constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
142 constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
143 return y_unpacks;
144 }
146 {
147 constexpr auto spans = DistributedTensor::get_distributed_spans();
148 constexpr auto u =
149 static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
150 return u.get_num_of_access() *
151 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
152 .get_num_of_access();
153 }
154 template <typename F>
155 CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
156 {
157 constexpr auto spans = DistributedTensor::get_distributed_spans();
159 spans[number<I>{}],
160 [&](auto... i_idx) {
161 constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
162 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
163 f, next_span_idx);
164 },
165 get_y_unpacks());
166 }
167 template <typename F, index_t i_access>
168 CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
169 {
170 constexpr auto spans = DistributedTensor::get_distributed_spans();
171 constexpr auto u =
172 static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
173 constexpr auto access_stride =
174 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
175 .get_num_of_access();
176 constexpr auto curr_i_access = number<i_access / access_stride>{};
177 constexpr auto next_i_access = number<i_access % access_stride>{};
178 u(
179 [&](auto... i_idx) {
180 constexpr auto next_span_idx =
182 sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
183 f, next_span_idx, next_i_access);
184 },
185 curr_i_access);
186 }
187};
188
189} // namespace impl
190
191/*
192 * Enhanced sweep-tile utility that can control the unpacks along each X-dim.
193 * The lambda function argument is the distributed-idx, which can be directly
194 * plugged into the distributed tensor as a setter/getter
195 *
196 * e.g. in the functions below, y has the type DistributedTensor and r is the row scale
197 *
198 * // sweep tile 1 by 1
199 * sweep_tile<DistributedTensor>([&](auto idx) {
200 * constexpr auto row_id = make_tuple(idx[number<0>{}]);
201 * y(idx) = y(idx) * r(row_id);
202 * });
203 *
204 * // sweep tile with 2 pixels from the last dim per function call
205 * sweep_tile<DistributedTensor>(
206 * [&](auto idx_0, auto idx_1) {
207 * constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
208 * y(idx_0) = y(idx_0) * r(row_id);
209 * y(idx_1) = y(idx_1) * r(row_id);
210 * },
211 * sequence<1, 2>{});
212 *
213 * // sweep tile with 2x2 pixels per function call
214 * sweep_tile<DistributedTensor>(
215 * [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
216 * constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
217 * constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
218 * y(idx_00) = y(idx_00) * r(row_id0);
219 * y(idx_01) = y(idx_01) * r(row_id0);
220 * y(idx_10) = y(idx_10) * r(row_id1);
221 * y(idx_11) = y(idx_11) * r(row_id1);
222 * },
223 * sequence<2, 2>{});
224 *
225 * TODO: do we need constexpr? lambda function could be non-constexpr
226 */
// sweep_tile overload that takes the tensor type only as an explicit template argument
// (no tensor object is passed).
//   f              : functor forwarded to impl::sweep_tile_impl_0.
//   UnpacksPerXDim : unnamed parameter, used only for its type; the default type comes from
//                    uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>.
227template <typename DistributedTensor,
228 typename F,
229 typename UnpacksPerXDim =
230 typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
231CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
232{
233 constexpr auto spans = DistributedTensor::get_distributed_spans();
234
 // Instantiate the entry-point helper over the sequence generated by
 // arithmetic_sequence_gen<0, spans.size(), 1> and call it with f.
 // NOTE(review): presumably that sequence is 0, 1, ..., spans.size() - 1, i.e. one index
 // per distributed span -- confirm against arithmetic_sequence_gen in sequence.hpp.
235 impl::sweep_tile_impl_0<DistributedTensor,
236 UnpacksPerXDim,
237 typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
238}
239
240template <typename DistributedTensor,
241 typename F,
242 typename UnpacksPerXDim =
243 typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
244CK_TILE_HOST_DEVICE constexpr void
245sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
246{
248}
249
250/*
251 * Construct a sweep-tile instance, which supports issuing the lambda one access at a time.
252 * Note that this struct holds the lambda functor, but does not hold the distributed tensor.
253 * The functionality is otherwise the same as sweep_tile()
254 */
255template <typename DistributedTensor_,
256 typename F_,
257 typename UnpacksPerXDim_ =
258 typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
260{
264
267 : f(f_)
268 {
269 }
271 {
272 constexpr auto spans = DistributedTensor::get_distributed_spans();
273 constexpr auto tmp =
276 typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
277 return tmp.get_num_of_access();
278 }
279
284
285 template <index_t i_access>
287 {
288 constexpr auto spans = DistributedTensor::get_distributed_spans();
289
292 typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
294 }
296};
297
298// partial deduction is not allowed
299// template <typename T, typename F, typename U>
300// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
301
302// deduction guide
303template <typename T,
304 typename F,
305 typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
307
308} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE_EXTERN
Definition config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence< Is... >)
Definition tile_distribution.hpp:59
Definition tile/core/arch/amd_buffer_addressing.hpp:110
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F &f, Unpacks={})
Definition sweep_tile.hpp:37
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T &, const F &, U={}) -> tile_sweeper< T, F, U >
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_HOST_DEVICE constexpr auto unpack(F &&f, X &&x)
Definition tile/core/utility/functional.hpp:200
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number< XUnpacks >)
Definition static_distributed_tensor.hpp:197
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple< X... > &tx, const tuple< Y... > &ty)
Definition tile/core/container/tuple.hpp:453
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X &x)
Definition tile/core/container/tuple.hpp:546
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition tile/core/container/sequence.hpp:287
CK_TILE_HOST_DEVICE constexpr void operator()(const F &f, const SpanIdx &span_idx, number< i_access >) const
Definition sweep_tile.hpp:124
CK_TILE_HOST_DEVICE constexpr void operator()(const F &f, const SpanIdx &span_idx) const
Definition sweep_tile.hpp:118
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
Definition sweep_tile.hpp:116
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
Definition sweep_tile.hpp:53
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
Definition sweep_tile.hpp:61
CK_TILE_HOST_DEVICE constexpr void operator()(const F &f, const SpanIdx &span_idx, number< i_access >) const
Definition sweep_tile.hpp:88
CK_TILE_HOST_DEVICE constexpr void operator()(const F &f, const SpanIdx &span_idx) const
Definition sweep_tile.hpp:71
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
Definition sweep_tile.hpp:137
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
Definition sweep_tile.hpp:145
CK_TILE_HOST_DEVICE constexpr void operator()(const F &f) const
Definition sweep_tile.hpp:155
CK_TILE_HOST_DEVICE constexpr void operator()(const F &f, number< i_access >) const
Definition sweep_tile.hpp:168
Definition sweep_tile.hpp:131
Definition sweep_tile.hpp:48
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:141
Definition functional_with_tuple.hpp:129
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access()
Definition functional_with_tuple.hpp:141
remove_cvref_t< DistributedTensor_ > DistributedTensor
Definition sweep_tile.hpp:261
remove_cvref_t< F_ > F
Definition sweep_tile.hpp:262
remove_cvref_t< UnpacksPerXDim_ > UnpacksPerXDim
Definition sweep_tile.hpp:263
CK_TILE_HOST_DEVICE void operator()(number< i_access >) const
Definition sweep_tile.hpp:286
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor &, const F &f_, UnpacksPerXDim={})
Definition sweep_tile.hpp:266
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access()
Definition sweep_tile.hpp:270
F f
Definition sweep_tile.hpp:295
CK_TILE_HOST_DEVICE void operator()() const
Definition sweep_tile.hpp:280
CK_TILE_HOST_DEVICE tile_sweeper(const F &f_, UnpacksPerXDim={})
Definition sweep_tile.hpp:265
Definition tile/core/container/sequence.hpp:314