thread_group_tensor_slice_transfer_v7r3_scatter.hpp Source File

thread_group_tensor_slice_transfer_v7r3_scatter.hpp Source File#

Composable Kernel: thread_group_tensor_slice_transfer_v7r3_scatter.hpp Source File
thread_group_tensor_slice_transfer_v7r3_scatter.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15// Thread-group level multi-source, multi-destination tensor slice data movement
16// Assume:
17// 1. All sources and destinations are DynamicBuffer
18// 2. Same VectorDim and ScalarPerVector for all sources and destinations
19// 3. DstInMemOps are per destination tensor
20// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
21// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
22//
23// Does the following to avoid scratch memory issues:
24// 1. Pass tensor descriptors by reference (or tuple of references)
25// 2. Does not keep reference to tensor descriptor
26// 3. Does not construct new tensor coordinate when call Run()
27template <typename ThreadGroup,
28 typename SrcDatas,
29 typename DstDatas,
30 typename SrcDescs,
31 typename DstDescs,
32 typename ElementwiseOperation,
33 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
34 typename SliceLengths,
35 typename ThreadClusterLengths,
36 typename ThreadClusterArrangeOrder,
37 typename SrcDimAccessOrder,
38 typename DstDimAccessOrder,
39 index_t SrcVectorDim,
40 index_t DstVectorDim,
41 typename SrcScalarPerVectors,
42 index_t DstScalarPerVector,
43 typename ThreadTransferSrcResetCoordinateAfterRunFlags,
44 typename ThreadTransferDstResetCoordinateAfterRunFlags,
45 typename IndexType,
46 index_t ScatterDim = 1,
47 bool OutputScatter = true,
48 index_t ScatterWeightIdx = 3,
49 index_t NumThreadScratch = 1>
51{
52 static constexpr index_t nDim =
54
// Modulus applied to the thread id when the constructor computes the destination
// cluster index; it is only used on the OutputScatter == true path.
// The value is read from entry 3 of ThreadClusterLengths, i.e. the dimension index is
// hard-coded here rather than derived from a template parameter.
// NOTE(review): presumably this requires ThreadClusterLengths to have at least 4
// entries -- confirm against the instantiations.
55 static constexpr index_t mod_num =
56 ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
59
61
// Slice lengths handled by a single thread: SliceLengths divided by ThreadClusterLengths.
// The constructor static_asserts that multiplying this back by ThreadClusterLengths
// yields exactly SliceLengths, so the thread cluster covers the whole slicing window.
62 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
64
66 const SrcDescs& src_descs,
67 const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
68 const DstDescs& dst_descs,
69 const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
70 const ElementwiseOperation& element_op)
71 : threadwise_transfer_(src_descs,
73 dst_descs,
75 element_op)
76 {
77 static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
78 nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
79 nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
80 nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
81 "wrong!");
82
83 static_for<0, nSrc, 1>{}([&](auto i) {
84 static_assert(
85 nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
86 "wrong!");
87 });
88
89 static_for<0, nDst, 1>{}([&](auto i) {
90 static_assert(
91 nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
92 "wrong!");
93 });
94
95 static_assert(nDim == ThreadClusterLengths::Size() &&
96 nDim == ThreadClusterArrangeOrder::Size() &&
97 nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
98 "wrong! nDim not consistent");
99
100 static_assert(
101 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
102 "wrong! threads should be mapped to cover entire slicing window");
103
104 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
105 "wrong! ThreadGroup::GetNumOfThread() too small");
106
107 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
108 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
109 {
110 const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
111 make_multi_index(ThreadGroup::GetThreadId()));
112 const auto src_thread_slice_origins = generate_tuple(
113 [&](auto i) {
114 return src_block_slice_origins[i] +
115 src_thread_cluster_idx * thread_slice_lengths;
116 },
117 Number<nSrc>{});
118
119 const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
120 make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num
121 : ThreadGroup::GetThreadId()));
122 const auto dst_thread_slice_origins = generate_tuple(
123 [&](auto i) {
124 return dst_block_slice_origins[i] +
125 dst_thread_cluster_idx * thread_slice_lengths;
126 },
127 Number<nDst>{});
128
129 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
130 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
131 }
132 }
133
// Read step of the transfer: forwards the source descriptors, source buffers and the
// thread-scratch id to the per-thread transfer object.
//   src_descs         - source tensor descriptors, passed by reference
//   src_bufs          - source buffers, passed by const reference
//   thread_scratch_id - forwarded unchanged to threadwise_transfer_.RunRead
134 template <typename SrcBuffers, index_t ThreadScratchId = 0>
135 __device__ void RunRead(const SrcDescs& src_descs,
136 const SrcBuffers& src_bufs,
138 {
// Participation guard: proceed if the thread group has exactly as many threads as the
// cluster has elements, or if this thread's id is below the cluster element count.
// Any other thread does nothing.
139 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
140 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
141 {
142 threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
143 }
144 }
145
// Detection alias: names the return type of calling IsTuple() on an lvalue of T, and is
// ill-formed when no such call is valid. RunWrite feeds it to is_detected to decide
// whether the destination buffers are passed through directly or wrapped with tie().
146 template <typename T>
147 using is_tuple = decltype(std::declval<T&>().IsTuple());
148
// Write step of the transfer: forwards the destination descriptors, destination buffers,
// scatter offsets and thread-scratch id to the per-thread transfer object.
//   dst_descs       - destination tensor descriptors, passed by reference
//   dst_bufs        - destination buffer(s), taken by value; either something exposing
//                     IsTuple() or a single buffer
//   scatter_offsets - forwarded unchanged to threadwise_transfer_.RunWrite
149 template <typename DstBuffers, index_t ThreadScratchId = 0>
150 __device__ void RunWrite(const DstDescs& dst_descs,
151 DstBuffers dst_bufs,
154 {
// Same participation guard as the other member functions: threads whose id is not below
// the cluster element count skip the write (unless the group size equals the cluster size).
155 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
156 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
157 {
// Compile-time dispatch: if dst_bufs' type has an IsTuple() member it is passed as is;
// otherwise it is wrapped with tie() before being handed to the threadwise transfer.
158 if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
159 threadwise_transfer_.RunWrite(
160 dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
161 else
162 threadwise_transfer_.RunWrite(
163 dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
164 }
165 }
166
// Convenience entry point: performs RunRead followed by RunWrite on this object.
// No thread-scratch id is passed to either call, so both use their default argument.
//   src_descs / src_bufs - forwarded to RunRead
//   dst_descs / dst_bufs - forwarded to RunWrite together with scatter_offsets
167 template <typename SrcBuffers, typename DstBuffers>
168 __device__ void Run(const SrcDescs& src_descs,
169 const SrcBuffers& src_bufs,
170 const DstDescs& dst_descs,
171 DstBuffers dst_bufs,
173 {
174 RunRead(src_descs, src_bufs);
175 RunWrite(dst_descs, dst_bufs, scatter_offsets);
176 }
177
// Moves the slice window of one source tensor.
//   src_descs - source tensor descriptors, passed by reference
//   iSrc      - compile-time index (Number<ISrc>) selecting which source to move
//   step      - multi-index step, forwarded unchanged
// The call is forwarded to threadwise_transfer_.MoveSrcSliceWindow only for
// participating threads.
178 template <index_t ISrc>
179 __device__ void
180 MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
181 {
// Participation guard: all threads proceed when the group size equals the cluster
// element count; otherwise only threads with id below that count proceed.
182 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
183 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
184 {
185 threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
186 }
187 }
188
// Moves the slice window of every source tensor by the same step: iterates at compile
// time over indices 0 .. SrcDescs::Size()-1 and calls the per-source overload for each.
189 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
190 {
191 static_for<0, SrcDescs::Size(), 1>{}(
192 [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
193 }
194
// Moves the slice window of one destination tensor.
//   dst_descs - destination tensor descriptors, passed by reference
//   iDst      - compile-time index (Number<IDst>) selecting which destination to move
//   step      - multi-index step, forwarded unchanged
// The call is forwarded to threadwise_transfer_.MoveDstSliceWindow only for
// participating threads.
195 template <index_t IDst>
196 __device__ void
197 MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
198 {
// Participation guard: all threads proceed when the group size equals the cluster
// element count; otherwise only threads with id below that count proceed.
199 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
200 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
201 {
202 threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
203 }
204 }
205
// Moves the slice window of every destination tensor by the same step: iterates at
// compile time over indices 0 .. DstDescs::Size()-1 and calls the per-destination
// overload for each.
206 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
207 {
208 static_for<0, DstDescs::Size(), 1>{}(
209 [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
210 }
211
212 private:
// Cluster descriptor built from ThreadClusterLengths and ThreadClusterArrangeOrder.
// Its GetElementSize() is the threshold used by the participation guards, and its
// CalculateBottomIndex() converts a thread id into a multi-dimensional cluster index
// in the constructor.
213 static constexpr auto thread_cluster_desc_ =
214 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
215
// Per-thread transfer type that does the actual work. It receives the same template
// arguments as this class, except that the slice lengths are the per-thread
// thread_slice_lengths and the ThreadGroup / ThreadCluster* parameters are not passed on.
216 using ThreadwiseTransfer =
217 ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
218 DstDatas,
219 SrcDescs,
220 DstDescs,
221 ElementwiseOperation,
222 DstInMemOps,
223 decltype(thread_slice_lengths),
224 SrcDimAccessOrder,
225 DstDimAccessOrder,
226 SrcVectorDim,
227 DstVectorDim,
228 SrcScalarPerVectors,
229 DstScalarPerVector,
230 ThreadTransferSrcResetCoordinateAfterRunFlags,
231 ThreadTransferDstResetCoordinateAfterRunFlags,
232 IndexType,
233 ScatterDim,
234 OutputScatter,
235 ScatterWeightIdx,
236 NumThreadScratch>;
237
// The per-thread transfer instance that every public member function forwards to.
238 ThreadwiseTransfer threadwise_transfer_;
239};
240
241} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition tuple_helper.hpp:176
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:58
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:60
static constexpr index_t mod_num
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:55
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets)
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:168
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:206
static constexpr index_t scatter_num
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:63
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:180
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:57
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:189
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:52
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:65
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:197
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:150
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:62
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:135
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:147
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:390
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:155
Definition functional2.hpp:33