device_reduce_multi_d.hpp Source File

device_reduce_multi_d.hpp Source File#

Composable Kernel: device_reduce_multi_d.hpp Source File
device_reduce_multi_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16template <typename InDataType,
17 typename DsDataType,
18 typename AccDataType,
19 typename OutDataType,
20 index_t Rank,
21 index_t NumReduceDim,
22 typename ReduceOperation,
23 typename InElementwiseOperation,
24 typename OutElementwiseOperation>
26{
27 static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
28
29 static constexpr index_t NumDTensor = DsDataType::Size();
30
31 virtual std::unique_ptr<BaseArgument>
32 MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
33 const std::array<index_t, Rank> inStrides,
34 const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
35 const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
36 const std::array<index_t, NumOutDim> outLengths,
37 const std::array<index_t, NumOutDim> outStrides,
38 const std::array<int, NumReduceDim> reduceDims,
39 const void* in_dev,
40 const std::array<const void*, NumDTensor> ds_dev,
41 void* out_dev,
42 const InElementwiseOperation in_elementwise_op,
43 const OutElementwiseOperation out_elementwise_op) = 0;
44
45 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
46};
47
48template <typename InDataType,
49 typename DsDataType,
50 typename AccDataType,
51 typename OutDataType,
52 index_t Rank,
53 index_t NumReduceDim,
54 typename ReduceOperation,
55 typename InElementwiseOperation,
56 typename OutElementwiseOperation>
57using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
58 DsDataType,
59 AccDataType,
60 OutDataType,
61 Rank,
62 NumReduceDim,
63 ReduceOperation,
64 InElementwiseOperation,
65 OutElementwiseOperation>>;
66
67} // namespace device
68} // namespace tensor_operation
69} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceReduceMultiD< InDataType, DsDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, OutElementwiseOperation > > DeviceReduceMultiDPtr
Definition device_reduce_multi_d.hpp:57
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_reduce_multi_d.hpp:26
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumOutDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumOutDim >, NumDTensor > DsStrides, const std::array< index_t, NumOutDim > outLengths, const std::array< index_t, NumOutDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op)=0
static constexpr index_t NumOutDim
Definition device_reduce_multi_d.hpp:27
static constexpr index_t NumDTensor
Definition device_reduce_multi_d.hpp:29
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0