grouped_convolution_utils.hpp Source File

grouped_convolution_utils.hpp Source File

Composable Kernel: grouped_convolution_utils.hpp Source File
grouped_convolution_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
// GroupedConvHostArgs: host-side argument bundle for a grouped convolution kernel.
// It derives from conv::ConvParam, which is copy-initialised from conv_param, and adds:
//   in_ptr / wei_ptr / out_ptr - input, weight and output tensor pointers (stored as given),
//   ds_ptr                     - pointers to the auxiliary D tensors,
//   k_batch                    - batch factor forwarded to the kernel (presumably split-K; confirm),
//   elfunc                     - C/D elementwise functor (defaults to a value-initialised CDElementwise).
// NOTE(review): this rendered listing omits source lines 19, 21-22 and 43. Per the member
// index, these are the struct head, the deleted default constructor, the constructor head
// taking ConvParam conv_param, and the k_batch member declaration.
// NOTE(review): ds_ptr_ is taken by value and then copied into a const member. Together with
// the const elfunc, this costs an extra vector copy and makes the struct non-assignable and
// non-movable for those members. Consider `const std::vector<const void*>&` and non-const members.
18template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
20{
23 InPtr in_ptr_,
24 WeiPtr wei_ptr_,
25 const std::vector<const void*> ds_ptr_,
26 OutPtr out_ptr_,
27 index_t k_batch_,
28 CDElementwise elfunc_ = CDElementwise{})
29 : conv::ConvParam(conv_param),
30 in_ptr(in_ptr_),
31 wei_ptr(wei_ptr_),
32 ds_ptr(ds_ptr_),
33 out_ptr(out_ptr_),
34 k_batch(k_batch_),
35 elfunc(elfunc_)
36 {
37 }
38
// Tensor pointers and epilogue state. No ownership handling is visible in this listing.
39 InPtr in_ptr;
40 WeiPtr wei_ptr;
41 const std::vector<const void*> ds_ptr;
42 OutPtr out_ptr;
44 const CDElementwise elfunc;
45};
46
48
49template <typename CDElementwise = PassThrough>
55
// Compile-time parameter bundle for the grouped-convolution implicit-GEMM kernels.
// Template parameters:
//   NDimSpatial_        - number of spatial dimensions,
//   ConvSpecialization_ - convolution specialization,
//   InLayout_ / WeiLayout_ / DsLayout_ / OutLayout_ - tensor layouts
//                         (DsLayout_ must expose a constexpr size()),
//   VectorSizeA_/B_/C_  - vector sizes for the A/B/C GEMM operands (default 1),
//   NumGroupsToMerge_   - number of convolution groups merged together (default 1),
//   EnableSplitImage_   - enables split-image processing (default false).
// NOTE(review): the struct name (source line 67) and several member lines are missing from
// this rendered listing. The member index identifies them as TilePartitionerGroupNum/M01,
// ELayout, the per-direction As/Bs/C layouts, and the GroupedConvImplicitGemmTraits* aliases.
56template <index_t NDimSpatial_,
57 ConvolutionSpecialization ConvSpecialization_,
58 typename InLayout_,
59 typename WeiLayout_,
60 typename DsLayout_,
61 typename OutLayout_,
62 index_t VectorSizeA_ = 1,
63 index_t VectorSizeB_ = 1,
64 index_t VectorSizeC_ = 1,
65 index_t NumGroupsToMerge_ = 1,
66 bool EnableSplitImage_ = false>
68{
69 private:
// Builds a tuple with one gemm::RowMajor layout per D tensor (DsLayout_::size() entries).
70 static constexpr auto generate_implicit_gemm_layout()
71 {
72 return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
73 number<DsLayout_::size()>{});
74 }
75
76 public:
77 // Fixed values for Implicit GEMM
// Padding is enabled on M, N and K. C is not transposed, the vector size is fixed,
// structured sparsity is off, and the kernel is non-persistent.
79 {
82 static constexpr bool kPadM = true;
83 static constexpr bool kPadN = true;
84 static constexpr bool kPadK = true;
85 static constexpr bool TransposeC = false;
86 static constexpr bool FixedVectorSize = true;
87 static constexpr bool UseStructuredSparsity = false;
88 static constexpr bool Persistent = false;
90 };
91 // Compile time parameters
92 static constexpr bool EnableSplitImage = EnableSplitImage_;
93 static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
94 static constexpr index_t NDimSpatial = NDimSpatial_;
95 static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
96 using InLayout = InLayout_;
97 using WeiLayout = WeiLayout_;
98 using DsLayout = DsLayout_;
99 using OutLayout = OutLayout_;
100
101 // Forward GEMM layouts: A RowMajor, B ColumnMajor, C RowMajor
105 // Backward-data GEMM layouts: A RowMajor, B RowMajor, C RowMajor
109 // Backward-weight GEMM layouts: A ColumnMajor, B RowMajor, C RowMajor
113
// TileGemmTraits aliases for forward, backward-data and backward-weight, each templated
// on NumWaveGroups (default 1) and built from the matching As/Bs/C layouts above.
114 template <ck_tile::index_t NumWaveGroups = 1>
117 template <ck_tile::index_t NumWaveGroups = 1>
119 true,
120 true,
124 NumWaveGroups>;
125 template <ck_tile::index_t NumWaveGroups = 1>
127 true,
128 true,
132 NumWaveGroups>;
133 static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
134 static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
135 static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
// Number of auxiliary D tensors, and their implicit-GEMM layouts (all RowMajor).
136 static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
137 using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
138};
139
151
// calculate_spatial_piece: computes the output-space region and GPU block range of one
// piece of a split-image convolution.
//
// Unflattening: piece_idx is unflattened with W varying fastest, then H, then D, into
// (d_idx, h_idx, w_idx).
//
// Piece extents: along each axis the piece starts at idx * base_piece_<axis>. The last
// piece on an axis takes the remainder (total_<axis> - start), so the pieces cover the
// full extent.
//
// Grid size: the piece's GEMM has M = N * d_size * h_size * w_size and N_gemm = K. Its grid is
// ceil(M / TilePartitioner::MPerBlock) * ceil(K / TilePartitioner::NPerBlock) blocks.
//
// Result: a SplitImagePieceInfo is brace-initialised from
//   (total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size).
// total_blocks is therefore the block offset at which this piece starts. Presumably the
// caller accumulates it over the preceding pieces; confirm against callers.
//
// NOTE(review): N presumably is the batch size and K the output-channel count; confirm.
// NOTE(review): num_w_pieces and num_h_pieces are used as divisors and must be non-zero.
// The M product is computed in 32-bit index_t and could overflow for very large problems.
// NOTE(review): source lines 177, 187 and 188 are missing from this rendered listing.
// They hold the return type, function name, piece_idx, N and K.
176template <typename TilePartitioner>
178 ck_tile::index_t num_d_pieces,
179 ck_tile::index_t num_h_pieces,
180 ck_tile::index_t num_w_pieces,
181 ck_tile::index_t base_piece_d,
182 ck_tile::index_t base_piece_h,
183 ck_tile::index_t base_piece_w,
184 ck_tile::index_t total_d,
185 ck_tile::index_t total_h,
186 ck_tile::index_t total_w,
189 ck_tile::index_t total_blocks)
190{
191 // Unflatten piece index into 3D coordinates (W varies fastest, then H, then D)
192 const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
193 const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
194 const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
195
196 // Calculate spatial start positions
197 const ck_tile::index_t w_start = w_idx * base_piece_w;
198 const ck_tile::index_t h_start = h_idx * base_piece_h;
199 const ck_tile::index_t d_start = d_idx * base_piece_d;
200
201 // Calculate piece sizes (last piece may be larger to cover remainder)
202 const ck_tile::index_t w_size =
203 (w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
204 const ck_tile::index_t h_size =
205 (h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
206 const ck_tile::index_t d_size =
207 (d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
208
209 // Calculate GEMM dimensions for this piece
210 const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
211 const ck_tile::index_t piece_gemm_n = K;
212
213 // Calculate GPU grid size for this piece (ceil-divide M and N by the block tile sizes)
214 const ck_tile::index_t piece_grid =
215 ((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
216 ((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
217
218 return {
219 total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
220}
221
222} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/host/convolution_host_tensor_descriptor_helper.hpp:11
Definition tile/core/algorithm/cluster_descriptor.hpp:13
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
GroupedConvHostArgs< const void *, const void *, void *, CDElementwise > GroupedConvFwdHostArgs
Definition grouped_convolution_utils.hpp:50
ck_tile::element_wise::PassThrough PassThrough
Definition grouped_convolution_utils.hpp:47
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx, ck_tile::index_t num_d_pieces, ck_tile::index_t num_h_pieces, ck_tile::index_t num_w_pieces, ck_tile::index_t base_piece_d, ck_tile::index_t base_piece_h, ck_tile::index_t base_piece_w, ck_tile::index_t total_d, ck_tile::index_t total_h, ck_tile::index_t total_w, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t total_blocks)
Calculate piece information for split-image convolution.
Definition grouped_convolution_utils.hpp:177
GroupedConvHostArgs< const void *, void *, const void *, PassThrough > GroupedConvBwdWeightHostArgs
Definition grouped_convolution_utils.hpp:51
GroupedConvHostArgs< void *, const void *, const void *, PassThrough > GroupedConvBwdDataHostArgs
Definition grouped_convolution_utils.hpp:53
int32_t index_t
Definition integer.hpp:9
The Grouped Conv kernel host arguments.
Definition grouped_convolution_utils.hpp:20
InPtr in_ptr
Definition grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition grouped_convolution_utils.hpp:40
index_t k_batch
Definition grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition grouped_convolution_utils.hpp:41
const CDElementwise elfunc
Definition grouped_convolution_utils.hpp:44
CK_TILE_HOST GroupedConvHostArgs()=delete
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, InPtr in_ptr_, WeiPtr wei_ptr_, const std::vector< const void * > ds_ptr_, OutPtr out_ptr_, index_t k_batch_, CDElementwise elfunc_=CDElementwise{})
Definition grouped_convolution_utils.hpp:22
Definition grouped_convolution_utils.hpp:79
static constexpr bool UseStructuredSparsity
Definition grouped_convolution_utils.hpp:87
static constexpr ck_tile::index_t TilePartitionerM01
Definition grouped_convolution_utils.hpp:81
static constexpr bool kPadN
Definition grouped_convolution_utils.hpp:83
static constexpr bool Persistent
Definition grouped_convolution_utils.hpp:88
static constexpr bool FixedVectorSize
Definition grouped_convolution_utils.hpp:86
static constexpr bool kPadM
Definition grouped_convolution_utils.hpp:82
static constexpr ck_tile::index_t TilePartitionerGroupNum
Definition grouped_convolution_utils.hpp:80
static constexpr bool kPadK
Definition grouped_convolution_utils.hpp:84
static constexpr bool TransposeC
Definition grouped_convolution_utils.hpp:85
ck_tile::tensor_layout::gemm::RowMajor ELayout
Definition grouped_convolution_utils.hpp:89
Definition grouped_convolution_utils.hpp:68
ck_tile::tensor_layout::gemm::RowMajor AsLayoutFwd
Definition grouped_convolution_utils.hpp:102
ck_tile::tensor_layout::gemm::RowMajor CLayoutBwdData
Definition grouped_convolution_utils.hpp:108
TileGemmTraits< true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups > GroupedConvImplicitGemmTraitsFwd
Definition grouped_convolution_utils.hpp:115
static constexpr index_t NDimSpatial
Definition grouped_convolution_utils.hpp:94
TileGemmTraits< true, true, true, AsLayoutBwdData, BsLayoutBwdData, CLayoutBwdData, NumWaveGroups > GroupedConvImplicitGemmTraitsBwdData
Definition grouped_convolution_utils.hpp:118
WeiLayout_ WeiLayout
Definition grouped_convolution_utils.hpp:97
ck_tile::tensor_layout::gemm::RowMajor BsLayoutBwdWeight
Definition grouped_convolution_utils.hpp:111
TileGemmTraits< true, true, true, AsLayoutBwdWeight, BsLayoutBwdWeight, CLayoutBwdWeight, NumWaveGroups > GroupedConvImplicitGemmTraitsBwdWeight
Definition grouped_convolution_utils.hpp:126
ck_tile::tensor_layout::gemm::RowMajor AsLayoutBwdData
Definition grouped_convolution_utils.hpp:106
InLayout_ InLayout
Definition grouped_convolution_utils.hpp:96
static constexpr ck_tile::index_t VectorSizeA
Definition grouped_convolution_utils.hpp:133
static constexpr ck_tile::index_t VectorSizeB
Definition grouped_convolution_utils.hpp:134
ck_tile::tensor_layout::gemm::ColumnMajor BsLayoutFwd
Definition grouped_convolution_utils.hpp:103
static constexpr ck_tile::index_t VectorSizeC
Definition grouped_convolution_utils.hpp:135
static constexpr bool EnableSplitImage
Definition grouped_convolution_utils.hpp:92
ck_tile::tensor_layout::gemm::RowMajor CLayoutFwd
Definition grouped_convolution_utils.hpp:104
OutLayout_ OutLayout
Definition grouped_convolution_utils.hpp:99
static constexpr ck_tile::index_t NumDTensor
Definition grouped_convolution_utils.hpp:136
static constexpr index_t NumGroupsToMerge
Definition grouped_convolution_utils.hpp:93
DsLayout_ DsLayout
Definition grouped_convolution_utils.hpp:98
ck_tile::tensor_layout::gemm::ColumnMajor AsLayoutBwdWeight
Definition grouped_convolution_utils.hpp:110
decltype(generate_implicit_gemm_layout()) ImplicitGemmDsLayout
Definition grouped_convolution_utils.hpp:137
static constexpr ConvolutionSpecialization ConvSpecialization
Definition grouped_convolution_utils.hpp:95
ck_tile::tensor_layout::gemm::RowMajor CLayoutBwdWeight
Definition grouped_convolution_utils.hpp:112
ck_tile::tensor_layout::gemm::RowMajor BsLayoutBwdData
Definition grouped_convolution_utils.hpp:107
Helper struct for split-image piece information.
Definition grouped_convolution_utils.hpp:146
ck_tile::index_t block_end
GPU block range for this piece.
Definition grouped_convolution_utils.hpp:147
ck_tile::index_t d_size
Definition grouped_convolution_utils.hpp:149
ck_tile::index_t d_start
Definition grouped_convolution_utils.hpp:148
ck_tile::index_t w_start
Spatial start coordinates (output space).
Definition grouped_convolution_utils.hpp:148
ck_tile::index_t h_size
Definition grouped_convolution_utils.hpp:149
ck_tile::index_t h_start
Definition grouped_convolution_utils.hpp:148
ck_tile::index_t w_size
Spatial dimensions of this piece.
Definition grouped_convolution_utils.hpp:149
ck_tile::index_t block_start
Definition grouped_convolution_utils.hpp:147
Definition tile_gemm_traits.hpp:18
Definition tile/host/convolution_parameter.hpp:15
ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, ck_tile::index_t n_out_channels, ck_tile::index_t n_in_channels, const std::vector< ck_tile::index_t > &filters_len, const std::vector< ck_tile::index_t > &input_len, const std::vector< ck_tile::index_t > &strides, const std::vector< ck_tile::index_t > &dilations, const std::vector< ck_tile::index_t > &left_pads, const std::vector< ck_tile::index_t > &right_pads)
Definition tile/host/convolution_parameter.hpp:16
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:437
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17