15 static constexpr const char*
name =
"Add";
17 template <
typename Y,
typename X0,
typename X1>
18 __host__ __device__
constexpr void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
21 __host__ __device__
constexpr void
22 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
28 __host__ __device__
constexpr void
29 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
35 __host__ __device__
constexpr void
36 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
42 __host__ __device__
constexpr void
43 operator()<
half_t>(
half_t& y,
const float& x0,
const float& x1)
const
49 __host__ __device__
constexpr void
56 __host__ __device__
constexpr void
63 __host__ __device__
constexpr void
64 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
71 __host__ __device__
constexpr void
76 const float y_tmp = x1_tmp + x2_tmp;
81 __host__ __device__
constexpr void
85 const float y_tmp = x0 + x2_tmp;
90 __host__ __device__
constexpr void
99 static constexpr const char*
name =
"Max";
101 template <
typename Y,
typename X0,
typename X1>
102 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
112 static constexpr const char*
name =
"Min";
114 template <
typename Y,
typename X0,
typename X1>
115 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
125 static constexpr const char*
name =
"Multiply";
127 template <
typename Y,
typename X0,
typename X1>
128 __host__ __device__
constexpr void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
131 __host__ __device__
constexpr void
132 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
138 __host__ __device__
constexpr void
139 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
145 __host__ __device__
constexpr void
146 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
152 __host__ __device__
constexpr void
159 __host__ __device__
constexpr void
166 __host__ __device__
constexpr void
173 __host__ __device__
constexpr void
174 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
181 __host__ __device__
constexpr void
186 const float y_tmp = x1_tmp * x2_tmp;
191 __host__ __device__
constexpr void
196 const float y_tmp = x1_tmp * x2_tmp;
201 __host__ __device__
constexpr void
205 const float y_tmp = x0 * x2_tmp;
210 __host__ __device__
constexpr void
219 static constexpr const char*
name =
"ScaleAdd";
223 template <
typename Y,
typename X0,
typename X1>
224 __host__ __device__
constexpr void operator()(Y& y,
const X0& x0,
const X1& x1)
const
230 __host__ __device__
void
231 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
237 __host__ __device__
void
238 operator()<float, float,
bhalf_t>(
float& y,
const float& x0,
const bhalf_t& x1)
const
248 static constexpr const char*
name =
"Subtract";
250 template <
typename T>
251 __host__ __device__
constexpr void operator()(T& y,
const T& x0,
const T& x1)
const;
254 __host__ __device__
constexpr void
255 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
261 __host__ __device__
constexpr void
262 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
268 __host__ __device__
constexpr void
275 __host__ __device__
constexpr void
280 const float y_tmp = x1_tmp - x2_tmp;
285 __host__ __device__
constexpr void
294 static constexpr const char*
name =
"Bilinear";
298 template <
typename Y,
typename X0,
typename X1>
299 __host__ __device__
constexpr void operator()(Y&,
const X0&,
const X1&)
const;
302 __host__ __device__
constexpr void
303 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
309 __host__ __device__
constexpr void
310 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
316 __host__ __device__
constexpr void
324 __host__ __device__
constexpr void
331 __host__ __device__
constexpr void
338 __host__ __device__
constexpr void
343 const float y_tmp =
alpha_ * x0_tmp +
beta_ * x1_tmp;
348 __host__ __device__
constexpr void
357 __host__ __device__
constexpr void
370 static constexpr const char*
name =
"AddClamp";
375 template <
typename Y,
typename X0,
typename X1>
376 __host__ __device__
constexpr void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
379 __host__ __device__
constexpr void
380 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
382 const float a = x0 + x1;
387 __host__ __device__
constexpr void
388 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
390 const double a = x0 + x1;
395 __host__ __device__
constexpr void
401 y =
a > floor ? (
a < ceil ?
a : ceil) : floor;
405 __host__ __device__
constexpr void
414 __host__ __device__
constexpr void
415 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
422 __host__ __device__
constexpr void
431 __host__ __device__
constexpr void
440 __host__ __device__
constexpr void
441 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
448 __host__ __device__
constexpr void
461 static constexpr const char*
name =
"AddRelu";
463 template <
typename Y,
typename X0,
typename X1>
464 __host__ __device__
constexpr void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
467 __host__ __device__
constexpr void
468 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
470 const float a = x0 + x1;
471 y =
a > 0.0f ?
a : 0.0f;
475 __host__ __device__
constexpr void
476 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
478 const double a = x0 + x1;
479 y =
a > 0.0 ?
a : 0.0;
483 __host__ __device__
constexpr void
491 __host__ __device__
constexpr void
495 const float b =
a > 0.0f ?
a : 0.0f;
500 __host__ __device__
constexpr void
501 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
504 y =
a > 0.0f ?
a : 0.0f;
508 __host__ __device__
constexpr void
512 const float b =
a > 0.0f ?
a : 0.0f;
517 __host__ __device__
constexpr void
521 const float b =
a > 0.0f ?
a : 0.0f;
526 __host__ __device__
constexpr void
527 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
534 __host__ __device__
constexpr void
544 static constexpr const char*
name =
"AddHardswish";
546 template <
typename T>
547 __host__ __device__
constexpr void operator()(T& y,
const T& x0,
const T& x1)
const;
550 __host__ __device__
constexpr void
551 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
554 float b =
a +
float{3};
555 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
560 __host__ __device__
constexpr void
561 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
565 double c = (b > 0) * (b > 6.0 ? 6.0 : b) *
a * 0.166667;
570 __host__ __device__
constexpr void
575 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
583 static constexpr const char*
name =
"AddFastGelu";
585 template <
typename E,
typename C,
typename D>
586 __host__ __device__
constexpr void operator()(E& e,
const C& c,
const D& d)
const;
589 __host__ __device__
constexpr void
590 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
592 const float x = c + d;
594 FastGelu{}.template operator()<float,
float>(e, x);
598 __host__ __device__
constexpr void
607 __host__ __device__
constexpr void
610 const float x0_f = c + d;
621 __host__ __device__
constexpr void
628 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
634 __host__ __device__
constexpr void
641 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
650 static constexpr const char*
name =
"MultiplyFastGelu";
652 template <
typename E,
typename C,
typename D>
653 __host__ __device__
constexpr void operator()(E& e,
const C& c,
const D& d)
const;
656 __host__ __device__
constexpr void
657 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
659 const float x = c * d;
661 FastGelu{}.template operator()<float,
float>(e, x);
665 __host__ __device__
constexpr void
674 __host__ __device__
constexpr void
677 const float x0_f = c * d;
688 __host__ __device__
constexpr void
695 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
701 __host__ __device__
constexpr void
708 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
717 static constexpr const char*
name =
"AddSilu";
719 template <
typename E,
typename C,
typename D>
720 __host__ __device__
constexpr void operator()(E& e,
const C& c,
const D& d)
const;
723 __host__ __device__
constexpr void
724 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
726 const float x = c + d;
728 Silu{}.template operator()<
float>(e, x);
732 __host__ __device__
constexpr void
741 __host__ __device__
constexpr void
744 const float x0_f = c + d;
748 Silu{}.template operator()<
float>(x1_f, x0_f);
754 __host__ __device__
constexpr void
761 Silu{}.template operator()<
float>(x1_f, x0_f);
769 static constexpr const char*
name =
"ConvScaleAdd";
772 float scale_wei = 1.f,
773 float scale_out = 1.f)
778 template <
typename E,
typename C,
typename D>
779 __host__ __device__
void operator()(E& e,
const C& c,
const D& d)
const;
782 __host__ __device__
void
783 operator()<
f8_t, float,
float>(
f8_t& e,
const float& c,
const float& d)
const
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition binary_element_wise_operation.hpp:11
Definition convolution_backward_data_specialization.hpp:7
ushort bhalf_t
Definition data_type.hpp:30
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
static constexpr const char * name
Definition binary_element_wise_operation.hpp:370
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition binary_element_wise_operation.hpp:372
const float ceil_
Definition binary_element_wise_operation.hpp:456
const float floor_
Definition binary_element_wise_operation.hpp:455
Definition binary_element_wise_operation.hpp:582
__host__ __device__ constexpr void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition binary_element_wise_operation.hpp:583
Definition binary_element_wise_operation.hpp:543
__host__ __device__ constexpr void operator()(T &y, const T &x0, const T &x1) const
static constexpr const char * name
Definition binary_element_wise_operation.hpp:544
Definition binary_element_wise_operation.hpp:14
static constexpr const char * name
Definition binary_element_wise_operation.hpp:15
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:460
static constexpr const char * name
Definition binary_element_wise_operation.hpp:461
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:716
__host__ __device__ constexpr void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition binary_element_wise_operation.hpp:717
__host__ __device__ constexpr void operator()(Y &, const X0 &, const X1 &) const
Bilinear(float alpha=1.f, float beta=1.f)
Definition binary_element_wise_operation.hpp:296
static constexpr const char * name
Definition binary_element_wise_operation.hpp:294
float beta_
Definition binary_element_wise_operation.hpp:365
float alpha_
Definition binary_element_wise_operation.hpp:364
float scale_in_
Definition binary_element_wise_operation.hpp:790
float scale_wei_
Definition binary_element_wise_operation.hpp:791
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition binary_element_wise_operation.hpp:771
float scale_out_
Definition binary_element_wise_operation.hpp:792
static constexpr const char * name
Definition binary_element_wise_operation.hpp:769
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:924
Definition binary_element_wise_operation.hpp:98
static constexpr const char * name
Definition binary_element_wise_operation.hpp:99
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:102
Definition binary_element_wise_operation.hpp:111
static constexpr const char * name
Definition binary_element_wise_operation.hpp:112
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:115
Definition binary_element_wise_operation.hpp:649
static constexpr const char * name
Definition binary_element_wise_operation.hpp:650
__host__ __device__ constexpr void operator()(E &e, const C &c, const D &d) const
Definition binary_element_wise_operation.hpp:124
static constexpr const char * name
Definition binary_element_wise_operation.hpp:125
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition binary_element_wise_operation.hpp:224
float scale_
Definition binary_element_wise_operation.hpp:243
__host__ __device__ ScaleAdd(float scale=1.f)
Definition binary_element_wise_operation.hpp:221
static constexpr const char * name
Definition binary_element_wise_operation.hpp:219
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
Definition binary_element_wise_operation.hpp:247
static constexpr const char * name
Definition binary_element_wise_operation.hpp:248
__host__ __device__ constexpr void operator()(T &y, const T &x0, const T &x1) const