device_grouped_gemm_multi_abd.hpp Source File

device_grouped_gemm_multi_abd.hpp Source File#

Composable Kernel: device_grouped_gemm_multi_abd.hpp Source File
device_grouped_gemm_multi_abd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
9#include "device_base.hpp"
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15struct GemmMultiABDDesc
16{
17 ck::index_t M_, N_, K_;
18
19 std::vector<ck::index_t> stride_As_;
20 std::vector<ck::index_t> stride_Bs_;
21 std::vector<ck::index_t> stride_Ds_;
22
23 ck::index_t stride_C_;
24};
25
26/*
27 * \brief Grouped Gemm Multi ABD
28 *
29 * C = a_op(A, A1...) * b_op(B, B1...)
30 * E = cde_op(C, D0, D1, ...)
31 *
32 * \tparam AsLayout A layouts (tuple).
33 * \tparam BsLayout B layouts (tuple).
34 * \tparam DsLayout Ds layouts (tuple).
35 * \tparam ELayout Output layout.
36 * \tparam AsDataType A data types (tuple).
37 * \tparam BsDataType B data types (tuple).
38 * \tparam DsDataType D data types (tuple).
39 * \tparam EDataType Output data type.
40 * \tparam AElementwiseOperation A elementwise operation.
41 * \tparam BElementwiseOperation B elementwise operation.
42 * \tparam CDEElementwiseOperation CDE elementwise operation.
43 */
44template <typename AsLayout,
45 typename BsLayout,
46 typename DsLayout,
47 typename ELayout,
48 typename AsDataType,
49 typename BsDataType,
50 typename DsDataType,
51 typename EDataType,
52 typename AElementwiseOperation,
53 typename BElementwiseOperation,
54 typename CDEElementwiseOperation>
55struct DeviceGroupedGemmMultiABD : public BaseOperator
56{
57 static constexpr index_t NumATensor = AsDataType::Size();
58 static constexpr index_t NumBTensor = BsDataType::Size();
59 static constexpr index_t NumDTensor = DsDataType::Size();
60
61 static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor");
62 static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
63 static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
64
65 /*
66 * \brief Make argument pointer for grouped gemm multi abd.
67 *
68 * \param p_as Pointers to the A tensors.
69 * \param p_bs Pointers to the B tensors.
70 * \param p_ds Pointers to the D tensors.
71 * \param p_e Pointers to the E tensors.
72 * \param gemm_desc Gemm descriptors for each group.
73 * \param a_element_op A elementwise operation object.
74 * \param b_element_op B elementwise operation object.
75 * \param c_element_op CDE elementwise operation object.
76 * \return Pointer to the argument.
77 */
78 virtual std::unique_ptr<BaseArgument>
79 MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
80 std::vector<std::array<const void*, NumBTensor>>& p_bs,
81 std::vector<std::array<const void*, NumDTensor>>& p_ds,
82 std::vector<void*>& p_e,
83 std::vector<GemmMultiABDDesc>& gemm_desc,
84 AElementwiseOperation a_element_op = AElementwiseOperation{},
85 BElementwiseOperation b_element_op = BElementwiseOperation{},
86 CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0;
87
88 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
89
90 virtual void SetElementwiseOps(BaseArgument* p_arg,
91 AElementwiseOperation a_element_op,
92 BElementwiseOperation b_element_op,
93 CDEElementwiseOperation cde_element_op) const = 0;
94};
95
96} // namespace device
97} // namespace tensor_operation
98} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_base.hpp:197
Definition device_grouped_gemm_multi_abd.hpp:56
virtual void SetElementwiseOps(BaseArgument *p_arg, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) const =0
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multi_abd.hpp:59
static constexpr index_t NumATensor
Definition device_grouped_gemm_multi_abd.hpp:57
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< std::array< const void *, NumATensor > > &p_as, std::vector< std::array< const void *, NumBTensor > > &p_bs, std::vector< std::array< const void *, NumDTensor > > &p_ds, std::vector< void * > &p_e, std::vector< GemmMultiABDDesc > &gemm_desc, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{})=0
static constexpr index_t NumBTensor
Definition device_grouped_gemm_multi_abd.hpp:58
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition device_grouped_gemm_multi_abd.hpp:16
ck::index_t stride_C_
Definition device_grouped_gemm_multi_abd.hpp:23
std::vector< ck::index_t > stride_Ds_
Definition device_grouped_gemm_multi_abd.hpp:21
std::vector< ck::index_t > stride_As_
Definition device_grouped_gemm_multi_abd.hpp:19
ck::index_t M_
Definition device_grouped_gemm_multi_abd.hpp:17
std::vector< ck::index_t > stride_Bs_
Definition device_grouped_gemm_multi_abd.hpp:20
ck::index_t N_
Definition device_grouped_gemm_multi_abd.hpp:17
ck::index_t K_
Definition device_grouped_gemm_multi_abd.hpp:17