gridwise_tensor_rearrange.hpp Source File

gridwise_tensor_rearrange.hpp Source File

Composable Kernel: gridwise_tensor_rearrange.hpp Source File
gridwise_tensor_rearrange.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck {
18
19template <typename InputGridDesc,
20 typename InputDataType,
21 typename OutputGridDesc,
22 typename OutputDataType,
23 typename Block2ETileMap,
24 typename ComputePtrOffsetOfStridedBatch,
25 typename GridwiseTensorRearrangeKernel>
26__global__ void
27#if CK_USE_LAUNCH_BOUNDS
29#endif
30 kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
31 const InputDataType* __restrict__ p_in_global,
32 const OutputGridDesc out_grid_desc,
33 OutputDataType* __restrict__ p_out_global,
34 const index_t batch_count,
35 const Block2ETileMap block_2_tile_map,
36 const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
37{
38#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \
39 defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
40 GridwiseTensorRearrangeKernel::Run(in_grid_desc,
41 p_in_global,
42 out_grid_desc,
43 p_out_global,
44 batch_count,
45 block_2_tile_map,
46 compute_ptr_offset_of_batch);
47#else
48 ignore = in_grid_desc;
49 ignore = p_in_global;
50 ignore = out_grid_desc;
51 ignore = p_out_global;
52 ignore = batch_count;
53 ignore = block_2_tile_map;
54 ignore = compute_ptr_offset_of_batch;
55#endif
56}
57
// Grid-level tensor rearrangement helper. Run() is the device routine that copies data from the
// input buffer to the output buffer through a copy object built from both grid descriptors;
// CheckValidity() is a host-side divisibility check on the descriptor lengths.
//
// NOTE(review): this listing was taken from a rendered documentation page. The embedded source
// line numbers skip 67 and 70 here, so one template parameter and the declaration line that
// names this type are presumably missing -- restore them from the original header.
//
// Template parameters, as used by the visible code:
//   InputGridDesc / OutputGridDesc - descriptor types; GetLength() and GetElementSpaceSize()
//                                    are called on them.
//   InputDataType / OutputDataType - element types of p_in_global / p_out_global.
//   BlockSize                      - not referenced by any visible line (TODO confirm its use).
//   MPerBlock / KPerBlock          - scale block_work_idx[I0] / [I1] into the tile origin and
//                                    are the divisors tested by CheckValidity().
//   ThreadClusterLengths,
//   ScalarPerVector                - passed as template arguments to the copy object's type.
//   Block2ETileMap                 - provides CalculateBottomIndex() for the flat block id.
//   ComputePtrOffsetOfStridedBatch - provides GetAPtrOffset() / GetCPtrOffset() per batch index.
58template <typename InputGridDesc,
59 typename InputDataType,
60 typename OutputGridDesc,
61 typename OutputDataType,
62 index_t BlockSize,
63 index_t MPerBlock,
64 index_t KPerBlock,
65 typename ThreadClusterLengths,
66 index_t ScalarPerVector,
68 typename Block2ETileMap,
69 typename ComputePtrOffsetOfStridedBatch>
71{
72
// Compile-time indices 0 and 1, used below to pick the first and second dimension.
73 static constexpr auto I0 = Number<0>{};
74 static constexpr auto I1 = Number<1>{};
75
77
// Device-side copy routine; in this file it is invoked from kernel_tensor_rearrange.
//
//   in_grid_desc / out_grid_desc - descriptors handed to the copy object and used to size the
//                                  buffers
//   p_in_global / p_out_global   - base pointers; a per-batch offset is added before use
//   batch_count                  - divisor used to derive the number of blocks per batch
//   block_2_tile_map             - converts the flat block id into block_work_idx
//   compute_ptr_offset_of_batch  - yields the input (A) and output (C) offsets for a batch index
78 __device__ static void Run(const InputGridDesc& in_grid_desc,
79 const InputDataType* __restrict__ p_in_global,
80 const OutputGridDesc& out_grid_desc,
81 OutputDataType* __restrict__ p_out_global,
82 const index_t batch_count,
83 const Block2ETileMap& block_2_tile_map,
84 const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
85 {
// Translate the flat block id into a multi-index; entries I0 and I1 are consumed below.
86 const auto block_work_idx =
87 block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
88
// Tile origin along the first dimension: block index times MPerBlock. The value is routed
// through readfirstlane -- presumably to keep it wave-uniform; confirm against the builtin docs.
89 const index_t m_block_data_idx_on_grid =
90 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
91
// Tile origin along the second dimension: block index times KPerBlock.
92 const index_t k_block_data_idx_on_grid =
93 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
94
// Build the copy object. The same (m, k) origin is supplied for the source and the destination.
// NOTE(review): embedded lines 96-98, 101-103, 105-106, 109-110 and 115 are absent from this
// listing, so the copy type's name, several of its template arguments and the end of the
// initializer are not visible here -- restore them from the original header.
95 auto copy_global_to_global =
99 decltype(tie(in_grid_desc)),
100 decltype(tie(out_grid_desc)),
104 ThreadClusterLengths,
107 I1,
108 ScalarPerVector,
111 in_grid_desc,
112 make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
113 out_grid_desc,
114 make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
116
// Blocks are split evenly across batches; the batch index is block id / blocks-per-batch.
// NOTE(review): both are integer divisions with no guard. batch_count == 0, or a grid smaller
// than batch_count (quotient 0), would divide by zero -- confirm the launcher rules this out.
117 const index_t num_blocks_per_batch =
118 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
119 const index_t g_idx =
120 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
121
// Per-batch offsets for the input (A) and output (C) pointers.
// NOTE(review): each offset is cast to long_index_t but then stored in an index_t. The
// reference section of this listing gives index_t = int32_t and long_index_t = int64_t, so the
// value is narrowed to 32 bits -- confirm batch offsets can never exceed that range.
122 // Global Memory
123 const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
124 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
125 const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
126 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
127
// Wrap the batch-offset pointers as global-address-space buffers sized by each descriptor's
// element space.
// NOTE(review): embedded line 130, which presumably opens the out_global_buf declaration, is
// absent from this listing; only its argument line (131) survives.
128 const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
129 p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
131 p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
132
// Execute the copy from the input buffer to the output buffer.
133 copy_global_to_global.Run(
134 tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
135 }
136
// Host-side precondition check. Returns true only when, for both descriptors, the length of
// dimension 0 is a multiple of MPerBlock and the length of dimension 1 is a multiple of
// KPerBlock; returns false otherwise.
137 __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
138 const OutputGridDesc& out_grid_desc)
139 {
// Reject an input whose extents are not whole multiples of the block tile.
140 if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
141 in_grid_desc.GetLength(I1) % KPerBlock != 0)
142 return false;
// Apply the same requirement to the output.
143 if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
144 out_grid_desc.GetLength(I1) % KPerBlock != 0)
145 return false;
146 return true;
147 }
148};
149
150} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__global__ void kernel_tensor_rearrange(const InputGridDesc in_grid_desc, const InputDataType *__restrict__ p_in_global, const OutputGridDesc out_grid_desc, OutputDataType *__restrict__ p_out_global, const index_t batch_count, const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
Definition gridwise_tensor_rearrange.hpp:30
Definition gridwise_tensor_rearrange.hpp:71
static __host__ constexpr bool CheckValidity(const InputGridDesc &in_grid_desc, const OutputGridDesc &out_grid_desc)
Definition gridwise_tensor_rearrange.hpp:137
static __device__ void Run(const InputGridDesc &in_grid_desc, const InputDataType *__restrict__ p_in_global, const OutputGridDesc &out_grid_desc, OutputDataType *__restrict__ p_out_global, const index_t batch_count, const Block2ETileMap &block_2_tile_map, const ComputePtrOffsetOfStridedBatch &compute_ptr_offset_of_batch)
Definition gridwise_tensor_rearrange.hpp:78
Definition utility/sequence.hpp:43
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition utility/tuple.hpp:117
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340