22template <
typename Range>
26 int precision = std::cout.precision(),
36 os << std::setw(width) << std::setprecision(precision) << v;
41template <
typename T,
typename Range>
45 int precision = std::cout.precision(),
55 os << std::setw(width) << std::setprecision(precision) << static_cast<T>(v);
60template <
typename F,
typename T, std::size_t... Is>
63 return f(std::get<Is>(args)...);
66template <
typename F,
typename T>
69 constexpr std::size_t N = std::tuple_size<T>{};
74template <
typename F,
typename T, std::size_t... Is>
77 return F(std::get<Is>(args)...);
80template <
typename F,
typename T>
83 constexpr std::size_t N = std::tuple_size<T>{};
108 mStrides.resize(mLens.size(), 0);
113 std::partial_sum(mLens.rbegin(),
115 mStrides.rbegin() + 1,
116 std::multiplies<std::size_t>());
119 template <
typename X,
typename = std::enable_if_t<std::is_convertible_v<X, std::
size_t>>>
125 template <
typename Lengths,
126 typename = std::enable_if_t<
127 std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t>>>
133 template <
typename X,
135 typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
136 std::is_convertible_v<Y, std::size_t>>>
138 const std::initializer_list<Y>& strides)
139 : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
143 template <
typename Lengths,
145 typename = std::enable_if_t<
146 std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t> &&
147 std::is_convertible_v<ck_tile::ranges::range_value_t<Strides>, std::size_t>>>
149 : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
167 assert(mLens.size() == mStrides.size());
168 return std::accumulate(
169 mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
185 std::size_t space = 1;
186 for(std::size_t i = 0; i < mLens.size(); ++i)
191 space += (mLens[i] - 1) * mStrides[i];
196 std::size_t
get_length(std::size_t dim)
const {
return mLens[dim]; }
198 const std::vector<std::size_t>&
get_lengths()
const {
return mLens; }
200 std::size_t
get_stride(std::size_t dim)
const {
return mStrides[dim]; }
202 const std::vector<std::size_t>&
get_strides()
const {
return mStrides; }
216 template <
typename... Is>
220 std::initializer_list<std::size_t> iss{
static_cast<std::size_t
>(is)...};
221 return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
235 return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
254 std::vector<std::size_t> mLens;
255 std::vector<std::size_t> mStrides;
258template <
typename New2Old>
262 std::vector<std::size_t> new_lengths(
a.get_num_of_dimension());
263 std::vector<std::size_t> new_strides(
a.get_num_of_dimension());
265 for(std::size_t i = 0; i <
a.get_num_of_dimension(); i++)
267 new_lengths[i] =
a.get_lengths()[new2old[i]];
268 new_strides[i] =
a.get_strides()[new2old[i]];
274template <
typename F,
typename... Xs>
278 static constexpr std::size_t
NDIM =
sizeof...(Xs);
279 std::array<std::size_t, NDIM>
mLens;
286 std::partial_sum(
mLens.rbegin(),
289 std::multiplies<std::size_t>());
295 std::array<std::size_t, NDIM> indices;
297 for(std::size_t idim = 0; idim <
NDIM; ++idim)
300 i -= indices[idim] *
mStrides[idim];
308 std::size_t work_per_thread = (
mN1d + num_thread - 1) / num_thread;
310 std::vector<joinable_thread> threads(num_thread);
312 for(std::size_t it = 0; it < num_thread; ++it)
314 std::size_t iw_begin = it * work_per_thread;
315 std::size_t iw_end = std::min((it + 1) * work_per_thread,
mN1d);
317 auto f = [
this, iw_begin, iw_end] {
318 for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
328template <
typename F,
typename... Xs>
340 template <
typename X>
345 template <
typename X,
typename Y>
346 HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
351 template <
typename Lengths>
356 template <
typename Lengths,
typename Str
ides>
364 template <
typename OutT>
369 return ck_tile::type_convert<OutT>(value);
383 template <
typename FromT>
403 return mDesc.get_element_space_size() / PackedSize;
414 if constexpr(std::is_same_v<T, e8m0_t>)
420 template <
typename F>
429 for(
size_t i = 0; i <
mDesc.get_lengths()[
rank]; i++)
436 template <
typename F>
439 std::vector<size_t> idx(
mDesc.get_num_of_dimension(), 0);
443 template <
typename F>
452 for(
size_t i = 0; i <
mDesc.get_lengths()[
rank]; i++)
459 template <
typename F>
462 std::vector<size_t> idx(
mDesc.get_num_of_dimension(), 0);
466 template <
typename G>
469 switch(
mDesc.get_num_of_dimension())
472 auto f = [&](
auto i) { (*this)(i) = g(i); };
477 auto f = [&](
auto i0,
auto i1) { (*this)(i0, i1) = g(i0, i1); };
483 auto f = [&](
auto i0,
auto i1,
auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
485 mDesc.get_lengths()[0],
486 mDesc.get_lengths()[1],
487 mDesc.get_lengths()[2])(num_thread);
491 auto f = [&](
auto i0,
auto i1,
auto i2,
auto i3) {
492 (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
495 mDesc.get_lengths()[0],
496 mDesc.get_lengths()[1],
497 mDesc.get_lengths()[2],
498 mDesc.get_lengths()[3])(num_thread);
502 auto f = [&](
auto i0,
auto i1,
auto i2,
auto i3,
auto i4) {
503 (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
506 mDesc.get_lengths()[0],
507 mDesc.get_lengths()[1],
508 mDesc.get_lengths()[2],
509 mDesc.get_lengths()[3],
510 mDesc.get_lengths()[4])(num_thread);
514 auto f = [&](
auto i0,
auto i1,
auto i2,
auto i3,
auto i4,
auto i5) {
515 (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
518 mDesc.get_lengths()[0],
519 mDesc.get_lengths()[1],
520 mDesc.get_lengths()[2],
521 mDesc.get_lengths()[3],
522 mDesc.get_lengths()[4],
523 mDesc.get_lengths()[5])(num_thread);
526 default:
throw std::runtime_error(
"unspported dimension");
530 template <
typename... Is>
534 return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
537 template <
typename... Is>
543 template <
typename... Is>
554 const T&
operator()(
const std::vector<std::size_t>& idx)
const
564 std::iota(axes.rbegin(), axes.rend(), 0);
568 throw std::runtime_error(
569 "HostTensor::transpose(): size of axes must match tensor dimension");
571 std::vector<size_t> tlengths, tstrides;
572 for(
const auto& axis : axes)
578 ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
589 typename Data::iterator
end() {
return mData.end(); }
593 typename Data::const_iterator
begin()
const {
return mData.begin(); }
595 typename Data::const_iterator
end()
const {
return mData.end(); }
597 typename Data::const_pointer
data()
const {
return mData.data(); }
599 typename Data::size_type
size()
const {
return mData.size(); }
603 auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end)
const
605 assert(s_begin.size() == s_end.size());
608 std::vector<size_t> s_len(s_begin.size());
610 s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
613 sliced_tensor.
ForEach([&](
auto& self,
auto idx) {
614 std::vector<size_t> src_idx(idx.size());
616 idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
620 return sliced_tensor;
623 template <
typename U = T>
626 constexpr std::size_t FromSize =
sizeof(T);
627 constexpr std::size_t ToSize =
sizeof(U);
629 using Element = std::add_const_t<std::remove_reference_t<U>>;
631 size() * FromSize / ToSize};
634 template <
typename U = T>
637 constexpr std::size_t FromSize =
sizeof(T);
638 constexpr std::size_t ToSize =
sizeof(U);
640 using Element = std::remove_reference_t<U>;
642 size() * FromSize / ToSize};
656 for(
typename Data::size_type idx = 0; idx < std::min(n,
mData.size()); ++idx)
662 if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
663 std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
665 os << type_convert<float>(
mData[idx]) <<
" #### ";
667 else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
670 os <<
"pk(" <<
static_cast<int>(unpacked[0]) <<
", "
671 <<
static_cast<int>(unpacked[1]) <<
") #### ";
673 else if constexpr(std::is_same_v<T, int8_t>)
675 os << static_cast<int>(
mData[idx]);
694 for(
typename Data::size_type idx = 0; idx < t.
mData.size(); ++idx)
700 if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
701 std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
703 os << type_convert<float>(t.
mData[idx]) <<
" #### ";
705 else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
708 os <<
"pk(" <<
static_cast<int>(unpacked[0]) <<
", "
709 <<
static_cast<int>(unpacked[1]) <<
") #### ";
727 void loadtxt(std::string file_name, std::string dtype =
"float")
729 std::ifstream file(file_name);
736 while(std::getline(file, line))
740 throw std::runtime_error(std::string(
"data read from file:") + file_name +
748 else if(dtype ==
"int" || dtype ==
"int32")
757 std::cerr <<
"Warning! reading from file:" << file_name
758 <<
", does not match the size of this tensor" << std::endl;
765 throw std::runtime_error(std::string(
"unable to open file:") + file_name);
771 void savetxt(std::string file_name, std::string dtype =
"float")
773 std::ofstream file(file_name);
777 for(
auto& itm :
mData)
780 file << type_convert<float>(itm) << std::endl;
781 else if(dtype ==
"int")
782 file << type_convert<int>(itm) << std::endl;
783 else if(dtype ==
"int8_t")
788 file << type_convert<float>(itm) << std::endl;
796 throw std::runtime_error(std::string(
"unable to open file:") + file_name);
822template <
bool is_row_major>
830 if constexpr(is_row_major)
840template <
bool is_row_major>
848 if constexpr(is_row_major)
Definition tile/core/container/span.hpp:18
#define CK_TILE_HOST
Definition config.hpp:40
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
Definition tile/core/utility/literals.hpp:9
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
Definition tile/host/host_tensor.hpp:67
CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor &a, const New2Old &new2old)
Definition tile/host/host_tensor.hpp:259
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_HOST std::ostream & LogRangeAsType(std::ostream &os, Range &&range, std::string delim, int precision=std::cout.precision(), int width=0)
Definition tile/host/host_tensor.hpp:42
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence< Is... >)
Definition tile/host/host_tensor.hpp:61
auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, bool_constant< is_row_major >)
Creates a host tensor descriptor with specified dimensions and layout.
Definition tile/host/host_tensor.hpp:823
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
Definition tile/host/host_tensor.hpp:81
e8m0_bexp_t e8m0_t
Definition tile/core/numeric/e8m0.hpp:49
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t &x)
Definition pk_int4.hpp:169
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence< Is... >)
Definition tile/host/host_tensor.hpp:75
CK_TILE_HOST std::ostream & LogRange(std::ostream &os, Range &&range, std::string delim, int precision=std::cout.precision(), int width=0)
Definition tile/host/host_tensor.hpp:23
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant< is_row_major >)
Definition tile/host/host_tensor.hpp:841
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Descriptor for tensors in host memory.
Definition tile/host/host_tensor.hpp:102
std::size_t get_stride(std::size_t dim) const
Definition tile/host/host_tensor.hpp:200
std::size_t GetOffsetFromMultiIndex(Is... is) const
Calculates the linear offset from multi-dimensional indices.
Definition tile/host/host_tensor.hpp:217
std::size_t get_element_size() const
Calculates the total number of elements in the tensor.
Definition tile/host/host_tensor.hpp:165
void CalculateStrides()
Definition tile/host/host_tensor.hpp:105
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:153
const std::vector< std::size_t > & get_strides() const
Definition tile/host/host_tensor.hpp:202
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
HostTensorDescriptor(const std::initializer_list< X > &lens, const std::initializer_list< Y > &strides)
Definition tile/host/host_tensor.hpp:137
std::size_t get_element_space_size() const
Calculates the total element space required for the tensor in memory.
Definition tile/host/host_tensor.hpp:183
std::size_t get_length(std::size_t dim) const
Definition tile/host/host_tensor.hpp:196
HostTensorDescriptor(const Lengths &lens, const Strides &strides)
Definition tile/host/host_tensor.hpp:148
friend std::ostream & operator<<(std::ostream &os, const HostTensorDescriptor &desc)
Definition tile/host/host_tensor.hpp:238
HostTensorDescriptor()=default
HostTensorDescriptor(const std::initializer_list< X > &lens)
Definition tile/host/host_tensor.hpp:120
std::size_t GetOffsetFromMultiIndex(const std::vector< std::size_t > &iss) const
Calculates the linear memory offset from a multi-dimensional index.
Definition tile/host/host_tensor.hpp:233
HostTensorDescriptor(const Lengths &lens)
Definition tile/host/host_tensor.hpp:128
Definition tile/host/host_tensor.hpp:336
void ForEach(F &&f)
Definition tile/host/host_tensor.hpp:437
HostTensor & operator=(const HostTensor &)=default
std::size_t get_stride(std::size_t dim) const
Definition tile/host/host_tensor.hpp:392
void ForEach(const F &&f) const
Definition tile/host/host_tensor.hpp:460
HostTensor(HostTensor &&)=default
T & operator()(Is... is)
Definition tile/host/host_tensor.hpp:538
Data::size_type size() const
Definition tile/host/host_tensor.hpp:599
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
friend std::ostream & operator<<(std::ostream &os, const HostTensor< T > &t)
Definition tile/host/host_tensor.hpp:690
HostTensor(std::initializer_list< X > lens, std::initializer_list< Y > strides)
Definition tile/host/host_tensor.hpp:346
HostTensor(std::initializer_list< X > lens)
Definition tile/host/host_tensor.hpp:341
const T & operator()(const std::vector< std::size_t > &idx) const
Definition tile/host/host_tensor.hpp:554
std::size_t get_element_space_size_in_bytes() const
Definition tile/host/host_tensor.hpp:406
decltype(auto) get_strides() const
Definition tile/host/host_tensor.hpp:394
HostTensor(const HostTensor &)=default
const T & operator()(Is... is) const
Definition tile/host/host_tensor.hpp:544
Data::iterator end()
Definition tile/host/host_tensor.hpp:589
void GenerateTensorValue(G g, std::size_t num_thread=1)
Definition tile/host/host_tensor.hpp:467
void SetZero()
Definition tile/host/host_tensor.hpp:412
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800
T & operator()(const std::vector< std::size_t > &idx)
Definition tile/host/host_tensor.hpp:549
HostTensor(const Lengths &lens)
Definition tile/host/host_tensor.hpp:352
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition tile/host/host_tensor.hpp:531
Data::pointer data()
Definition tile/host/host_tensor.hpp:591
HostTensorDescriptor Descriptor
Definition tile/host/host_tensor.hpp:337
auto AsSpan() const
Definition tile/host/host_tensor.hpp:624
auto slice(std::vector< size_t > s_begin, std::vector< size_t > s_end) const
Definition tile/host/host_tensor.hpp:603
std::vector< T > Data
Definition tile/host/host_tensor.hpp:338
auto AsSpan()
Definition tile/host/host_tensor.hpp:635
Data::const_iterator begin() const
Definition tile/host/host_tensor.hpp:593
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396
HostTensor & operator=(HostTensor &&)=default
std::size_t get_element_space_size() const
Definition tile/host/host_tensor.hpp:400
HostTensor< T > transpose(std::vector< size_t > axes={})
Definition tile/host/host_tensor.hpp:582
HostTensor(const Lengths &lens, const Strides &strides)
Definition tile/host/host_tensor.hpp:357
void loadtxt(std::string file_name, std::string dtype="float")
Definition tile/host/host_tensor.hpp:727
Data::const_pointer data() const
Definition tile/host/host_tensor.hpp:597
std::ostream & print_first_n(std::ostream &os, std::size_t n=5) const
Print only the first N elements of the tensor.
Definition tile/host/host_tensor.hpp:652
void ForEach_impl(const F &&f, std::vector< size_t > &idx, size_t rank) const
Definition tile/host/host_tensor.hpp:444
HostTensor(const Descriptor &desc)
Definition tile/host/host_tensor.hpp:362
Data::iterator begin()
Definition tile/host/host_tensor.hpp:587
HostTensor< OutT > CopyAsType() const
Definition tile/host/host_tensor.hpp:365
void savetxt(std::string file_name, std::string dtype="float")
Definition tile/host/host_tensor.hpp:771
HostTensor(const HostTensor< FromT > &other)
Definition tile/host/host_tensor.hpp:384
std::size_t get_length(std::size_t dim) const
Definition tile/host/host_tensor.hpp:388
HostTensor< T > transpose(std::vector< size_t > axes={}) const
Definition tile/host/host_tensor.hpp:559
std::size_t get_element_size() const
Definition tile/host/host_tensor.hpp:398
void ForEach_impl(F &&f, std::vector< size_t > &idx, size_t rank)
Definition tile/host/host_tensor.hpp:421
Data::const_iterator end() const
Definition tile/host/host_tensor.hpp:595
Data mData
Definition tile/host/host_tensor.hpp:801
Definition tile/host/host_tensor.hpp:276
void operator()(std::size_t num_thread=1) const
Definition tile/host/host_tensor.hpp:306
std::array< std::size_t, NDIM > GetNdIndices(std::size_t i) const
Definition tile/host/host_tensor.hpp:293
ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:283
std::size_t mN1d
Definition tile/host/host_tensor.hpp:281
std::array< std::size_t, NDIM > mLens
Definition tile/host/host_tensor.hpp:279
std::array< std::size_t, NDIM > mStrides
Definition tile/host/host_tensor.hpp:280
static constexpr std::size_t NDIM
Definition tile/host/host_tensor.hpp:278
F mF
Definition tile/host/host_tensor.hpp:277
Definition joinable_thread.hpp:12
Definition tile/core/numeric/numeric.hpp:81