tensor_utils.hpp Source File

tensor_utils.hpp Source File

Composable Kernel: tensor_utils.hpp Source File
tensor_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
7
10#include "ck/utility/tuple.hpp"
15
16// Disable from doxygen docs generation
18namespace ck {
19namespace wrapper {
21
30using MemoryTypeEnum = AddressSpaceEnum;
31
32// Disable from doxygen docs generation
34// forward declarations
35template <typename Shape, typename UnrolledDescriptorType>
36struct Layout;
37template <MemoryTypeEnum BufferAddressSpace,
38 typename ElementType,
39 typename Shape,
40 typename UnrolledDescriptorType>
41struct Tensor;
42
43template <typename FromType, typename ToType>
44struct Slice
45{
46 __host__ __device__ constexpr Slice() : from_(), to_() {}
47 __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
48
55 template <typename T>
56 __host__ __device__ constexpr auto range(const T& dim) const
57 {
58 if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
59 is_same_v<std::remove_const_t<T>, index_t>)
60 {
61 if(to_ < 0)
62 {
63 return dim - from_ + to_ + 1;
64 }
65 else
66 {
67 // workaround if one end of the interval is index_t and the second one is Number
68 return static_cast<index_t>(to_) - static_cast<index_t>(from_);
69 }
70 }
71 else
72 {
73 static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
74 (ToType{} < 0 || ToType{} > FromType{}),
75 "Invalid range");
76 if constexpr(ToType{} < 0)
77 {
78 return dim - from_ + to_ + Number<1>{};
79 }
80 else
81 {
82 return to_ - from_;
83 }
84 }
85 }
86
87 __host__ __device__ static constexpr bool IsSlice() { return true; }
88
89 const FromType from_;
90 const ToType to_;
91};
92
/**
 * @brief Detection alias naming the return type of `IsSlice()` called on an lvalue of `SliceT`.
 *
 * It is ill-formed (a substitution failure in SFINAE contexts) when `SliceT` has no callable
 * `IsSlice()` member. Presumably it is consumed by a detection helper elsewhere in ck —
 * verify at call sites.
 */
template <typename SliceT>
using is_slice = decltype(std::declval<SliceT&>().IsSlice());

/**
 * @brief Detection alias naming the return type of `IsTuple()` called on an lvalue of `TupleT`.
 *
 * It is ill-formed when `TupleT` has no callable `IsTuple()` member.
 */
template <typename TupleT>
using is_tuple = decltype(std::declval<TupleT&>().IsTuple());
99
108template <MemoryTypeEnum MemoryType,
109 typename ElementType,
110 typename Shape,
111 typename UnrolledDescriptorType>
117
125template <MemoryTypeEnum MemoryType,
126 typename ElementType,
127 typename Shape,
128 typename UnrolledDescriptorType>
133
139template <MemoryTypeEnum BufferAddressSpace,
140 typename ElementType,
141 typename Shape,
142 typename UnrolledDescriptorType>
143__host__ __device__ void
150
157template <MemoryTypeEnum BufferAddressSpace,
158 typename ElementType,
159 typename Shape,
160 typename UnrolledDescriptorType>
161__host__ __device__ constexpr const auto&
166
174template <index_t... Idxs,
175 MemoryTypeEnum BufferAddressSpace,
176 typename ElementType,
177 typename Shape,
178 typename UnrolledDescriptorType>
179__host__ __device__ constexpr auto
181{
182 return size<Idxs...>(tensor.GetLayout());
183}
184
192template <index_t... Idxs,
193 MemoryTypeEnum BufferAddressSpace,
194 typename ElementType,
195 typename Shape,
196 typename UnrolledDescriptorType>
197__host__ __device__ constexpr auto
202
210template <index_t... Idxs,
211 MemoryTypeEnum BufferAddressSpace,
212 typename ElementType,
213 typename Shape,
214 typename UnrolledDescriptorType>
215__host__ __device__ constexpr auto
220
227template <MemoryTypeEnum BufferAddressSpace,
228 typename ElementType,
229 typename Shape,
230 typename UnrolledDescriptorType>
231__host__ __device__ constexpr const auto&
236
244template <typename FromType, typename ToType>
245constexpr auto slice(const FromType from, const ToType to)
246{
247 return Slice<FromType, ToType>(from, to);
248}
249
256template <typename ToType>
257constexpr auto slice(const ToType to)
258{
259 if constexpr(is_same_v<ToType, index_t>)
260 {
261 return Slice<index_t, ToType>(0, to);
262 }
263 else
264 {
265 return Slice<Number<0>, ToType>(Number<0>{}, to);
266 }
267}
268
274constexpr auto slice() { return Slice<Number<0>, Number<-1>>(Number<0>{}, Number<-1>{}); }
275
276} // namespace wrapper
277} // namespace ck
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
__host__ __device__ constexpr auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition layout_utils.hpp:371
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
const GenericPointer< typename T::ValueType > & pointer
Definition pointer.h:1514
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Tensor wrapper that performs static and dynamic buffer logic. The tensor is based on a descriptor sto...
Definition library/utility/host_tensor.hpp:694
__host__ __device__ constexpr const Layout< Shape, UnrolledDescriptorType > & GetLayout() const
Definition tensor.hpp:246
static constexpr bool IsDynamicBuffer
Definition tensor.hpp:224
__host__ __device__ constexpr auto & GetBuffer()
Definition tensor.hpp:380
constexpr auto slice()
Get a slice covering the whole dimension (from = 0, to = -1).
Definition tensor_utils.hpp:274
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition tensor_utils.hpp:162
__host__ __device__ void clear(Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Clear tensor (only for VGPR/SGPR tensors).
Definition tensor_utils.hpp:144
AddressSpaceEnum MemoryTypeEnum
Memory type, allowed members:
Definition tensor_utils.hpp:30
constexpr auto make_register_tensor(const Layout< Shape, UnrolledDescriptorType > &layout)
Make SGPR or VGPR tensor function.
Definition tensor_utils.hpp:129
constexpr auto make_tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Make tensor function.
Definition tensor_utils.hpp:112