inner_product.hpp Source File

inner_product.hpp Source File

Composable Kernel: inner_product.hpp Source File
inner_product.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#include "data_type.hpp"
6#include "type_convert.hpp"
7
8namespace ck {
9
// Accumulating inner product: adds the dot product of `a` and `b` into `c`.
// Every specialization in this header updates `c` in place (c += <a, b>).
//
// Only the primary template is declared here; no generic definition appears
// in this header, so callers can only use the (TA, TB, TC) combinations that
// are explicitly specialized.
template <class TA, class TB, class TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);
12
// Scalar fp32 multiply-accumulate: c += a * b.
//
// With CK_USE_AMD_V_MAC_INLINE_ASM enabled, the update is emitted as one
// hand-written instruction:
//   - v_mac_f32  when CK_USE_AMD_V_MAC_F32 is defined,
//   - v_fmac_f32 when CK_USE_AMD_V_FMAC_F32 is defined.
// Otherwise plain C++ is used and instruction selection is left to the compiler.
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
    // "+v": c sits in a VGPR that the instruction both reads and overwrites.
    asm volatile("\n \
 v_mac_f32 %0, %1, %2 \n \
 "
                 : "+v"(c)
                 : "v"(a), "v"(b));
#elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
    // Same read-modify-write contract as above, using the fused variant.
    asm volatile("\n \
 v_fmac_f32 %0, %1, %2 \n \
 "
                 : "+v"(c)
                 : "v"(a), "v"(b));
#else
    c += a * b;
#endif
}
32
// fp32x2 dot product accumulated into c: c += a[0]*b[0] + a[1]*b[1].
// Lanes are processed in ascending order, each through the scalar fp32
// specialization, so its v_mac/v_fmac inline-asm path applies when enabled.
template <>
__device__ void
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every lane.
    const vector_type<float, 2> a_vector{a};
    const vector_type<float, 2> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<float>()[i], b_vector.AsType<float>()[i], c);
    });
}
48
// fp32x4 dot product accumulated into c: c += sum_{i=0..3} a[i]*b[i].
// Lanes are processed in ascending order, each through the scalar fp32
// specialization, so its v_mac/v_fmac inline-asm path applies when enabled.
template <>
__device__ void
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every lane.
    const vector_type<float, 4> a_vector{a};
    const vector_type<float, 4> b_vector{b};

    static_for<0, 4, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<float>()[i], b_vector.AsType<float>()[i], c);
    });
}
74
// bf16 multiply-accumulate into an fp32 sum: c += float(a) * float(b).
// bhalf_t is a raw ushort bit pattern, so multiplying a and b directly would
// multiply integers. Both operands are converted to float first, then
// forwarded to the fp32 specialization (which may use v_mac/v_fmac asm).
template <>
__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
    inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
80
// fp16 multiply-accumulate into an fp32 sum: c += float(a) * float(b).
// Both operands are widened to float and forwarded to the fp32 specialization
// (which may use v_mac/v_fmac asm), so the product and sum are formed in fp32.
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
    inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
86
// Packed fp16x2 dot product accumulated into an fp32 sum (c += <a, b>).
//
// One of three implementations is chosen at compile time:
//  - CK_USE_AMD_V_DOT2_F32_F16 and CK_USE_AMD_V_DOT_INLINE_ASM:
//    hand-written v_dot2_f32_f16 followed by `s_nop 2` (equivalent to three
//    s_nop instructions) to avoid the hazard described in the MI200 (CDNA2)
//    ISA guide, page 47:
//    https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
//  - CK_USE_AMD_V_DOT2_F32_F16 only: the __builtin_amdgcn_fdot2 builtin.
//  - otherwise: each half lane is widened to float and accumulated,
//    lane 0 first, then lane 1.
template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_V_DOT_INLINE_ASM
    // c is both the accumulator source (last operand) and the destination,
    // hence a single read-write "+v" operand.
    asm volatile("\n \
 v_dot2_f32_f16 %0, %1, %2, %0\n \
 s_nop 2 \n \
 "
                 : "+v"(c)
                 : "v"(a), "v"(b));
#else
    c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif
#else
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const vector_type<half_t, 2> lhs{a};
    const vector_type<half_t, 2> rhs{b};

    c += type_convert<float>(lhs.AsType<half_t>()[I0]) *
         type_convert<float>(rhs.AsType<half_t>()[I0]);
    c += type_convert<float>(lhs.AsType<half_t>()[I1]) *
         type_convert<float>(rhs.AsType<half_t>()[I1]);
#endif
}
114
// fp16x4 dot product accumulated into an fp32 sum (c += <a, b>).
// The operands are split into two half2 chunks, processed low chunk first,
// so each chunk can use the packed dot2 path of the half2 specialization
// when CK_USE_AMD_V_DOT2_F32_F16 is defined.
template <>
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every chunk.
    const vector_type<half_t, 4> a_vector{a};
    const vector_type<half_t, 4> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<half2_t>()[i], b_vector.AsType<half2_t>()[i], c);
    });
}
129
// fp16x8 dot product accumulated into an fp32 sum (c += <a, b>).
// The operands are split into four half2 chunks, processed in ascending
// order, so each chunk can use the packed dot2 path of the half2
// specialization when CK_USE_AMD_V_DOT2_F32_F16 is defined.
template <>
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every chunk.
    const vector_type<half_t, 8> a_vector{a};
    const vector_type<half_t, 8> b_vector{b};

    static_for<0, 4, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<half2_t>()[i], b_vector.AsType<half2_t>()[i], c);
    });
}
154
// Scalar int8 multiply-accumulate into an int32 sum:
//   c += int32(a) * int32(b)
// Both operands are widened to int32 before multiplying, as in the scalar
// fallback of the int8x4 specialization.
template <>
__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
{
    c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
}
160
// int8x2 dot product accumulated into an int32 sum:
//   c += int32(a[0])*int32(b[0]) + int32(a[1])*int32(b[1])
// Each lane is forwarded, lane 0 first, to the scalar int8 specialization.
template <>
__device__ void
inner_product<int8x2_t, int8x2_t, int32_t>(const int8x2_t& a, const int8x2_t& b, int32_t& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every lane.
    const vector_type<int8_t, 2> a_vector{a};
    const vector_type<int8_t, 2> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<int8_t>()[i], b_vector.AsType<int8_t>()[i], c);
    });
}
176
// int8x4 dot product accumulated into an int32 sum:
//   c += sum_{i=0..3} int32(a[i]) * int32(b[i])
//
// One of four implementations is chosen at compile time:
//  - CK_USE_AMD_V_DOT4_I32_I8 and CK_USE_AMD_V_DOT_INLINE_ASM:
//    hand-written v_dot4_i32_i8 on the operands bit-cast to int32, followed by
//    `s_nop 2` (equivalent to three s_nop instructions) to avoid the hazard
//    described in the MI200 (CDNA2) ISA guide, page 47:
//    https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
//  - CK_USE_AMD_V_DOT4_I32_I8 only: the __builtin_amdgcn_sdot4 builtin.
//  - CK_USE_AMD_V_DOT4_I32_I8_GFX11: __builtin_amdgcn_sudot4, with both
//    operands flagged as signed (`true`).
//  - otherwise: each int8 lane is widened to int32 and accumulated in order.
template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_V_DOT_INLINE_ASM
    asm volatile("\n \
 v_dot4_i32_i8 %0, %1, %2, %0\n \
 s_nop 2 \n \
 "
                 : "=v"(c)
                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
    c = __builtin_amdgcn_sudot4(true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
#else
    const vector_type<int8_t, 4> a_vector{a};
    const vector_type<int8_t, 4> b_vector{b};

    static_for<0, 4, 1>{}([&](auto i) {
        c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
             type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
    });
#endif
}
207
// int8x8 dot product accumulated into an int32 sum (c += <a, b>).
// The operands are split into two int8x4 chunks, processed low chunk first,
// so each chunk can use the dot4 paths of the int8x4 specialization when
// they are enabled.
template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every chunk.
    const vector_type<int8_t, 8> a_vector{a};
    const vector_type<int8_t, 8> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<int8x4_t>()[i], b_vector.AsType<int8x4_t>()[i], c);
    });
}
223
// int8x16 dot product accumulated into an int32 sum (c += <a, b>).
// The operands are split into four int8x4 chunks, processed in ascending
// order, so each chunk can use the dot4 paths of the int8x4 specialization
// when they are enabled.
template <>
__device__ void
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    // Wrap each packed operand once rather than re-wrapping it for every chunk.
    const vector_type<int8_t, 16> a_vector{a};
    const vector_type<int8_t, 16> b_vector{b};

    static_for<0, 4, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<int8x4_t>()[i], b_vector.AsType<int8x4_t>()[i], c);
    });
}
249
250} // namespace ck
Definition ck.hpp:268
__device__ void inner_product< half_t, half_t, float >(const half_t &a, const half_t &b, float &c)
Definition inner_product.hpp:82
ushort bhalf_t
Definition data_type.hpp:30
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
__device__ void inner_product< float2_t, float2_t, float >(const float2_t &a, const float2_t &b, float &c)
Definition inner_product.hpp:35
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
__device__ void inner_product< int8x2_t, int8x2_t, int32_t >(const int8x2_t &a, const int8x2_t &b, int32_t &c)
Definition inner_product.hpp:163
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ void inner_product< half2_t, half2_t, float >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product.hpp:88
__device__ void inner_product< float4_t, float4_t, float >(const float4_t &a, const float4_t &b, float &c)
Definition inner_product.hpp:51
__device__ void inner_product< int8x8_t, int8x8_t, int32_t >(const int8x8_t &a, const int8x8_t &b, int32_t &c)
Definition inner_product.hpp:210
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
typename vector_type< float, 4 >::type float4_t
Definition dtype_vector.hpp:2146
__device__ void inner_product< half4_t, half4_t, float >(const half4_t &a, const half4_t &b, float &c)
Definition inner_product.hpp:116
typename vector_type< float, 2 >::type float2_t
Definition dtype_vector.hpp:2145
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
__device__ void inner_product< bhalf_t, bhalf_t, float >(const bhalf_t &a, const bhalf_t &b, float &c)
Definition inner_product.hpp:76
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__device__ void inner_product< int8x16_t, int8x16_t, int32_t >(const int8x16_t &a, const int8x16_t &b, int32_t &c)
Definition inner_product.hpp:226
__device__ void inner_product< half8_t, half8_t, float >(const half8_t &a, const half8_t &b, float &c)
Definition inner_product.hpp:131
__device__ void inner_product< int8_t, int8_t, int32_t >(const int8_t &a, const int8_t &b, int32_t &c)
Definition inner_product.hpp:156
__device__ void inner_product< int8x4_t, int8x4_t, int32_t >(const int8x4_t &a, const int8x4_t &b, int32_t &c)
Definition inner_product.hpp:179
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void inner_product(const TA &a, const TB &b, TC &c)
__device__ void inner_product< float, float, float >(const float &a, const float &b, float &c)
Definition inner_product.hpp:14
typename vector_type< half_t, 4 >::type half4_t
Definition dtype_vector.hpp:2154
typename vector_type< int8_t, 2 >::type int8x2_t
Definition dtype_vector.hpp:2176
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Definition functional2.hpp:33
Definition dtype_vector.hpp:10