32 const CDEElementWise& cde_elementwise)
34 std::cout <<
"Calculating reference using optimized flat indexing with parallel processing..."
38 auto f_gm = [&](
auto g_flat,
auto m_flat) {
47 a_full_dims.
mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
49 b_full_dims.
mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
50 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
54 EDataType result =
static_cast<EDataType
>(sum);
55 if(ds_full_dims_host.size() == 0)
59 else if(ds_full_dims_host.size() == 1)
61 cde_elementwise(result,
64 ds_full_dims_host[0].mData[g_flat * M_total * N_total +
65 m_flat * N_total + n_flat]));
67 else if(ds_full_dims_host.size() == 2)
74 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
77 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
79 else if(ds_full_dims_host.size() == 3)
86 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
89 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
92 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
94 else if(ds_full_dims_host.size() == 4)
101 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
104 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
107 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
110 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
114 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
118 e_full_dims_host_ref.
mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
119 static_cast<EDataType
>(result);
139 const std::vector<index_t>& G_dims,
140 const std::vector<index_t>& M_dims,
141 const std::vector<index_t>& N_dims,
142 const std::vector<index_t>& K_dims,
143 const std::vector<index_t>& A_dims,
144 const std::vector<index_t>& B_dims,
145 const std::vector<index_t>& E_dims,
146 const CDEElementWise& cde_elementwise)
148 std::cout <<
"Calculating reference using multi-dimensional indexing..." << std::endl;
150 std::vector<std::size_t> g_idx(G_dims.size());
151 std::vector<std::size_t> m_idx(M_dims.size());
152 std::vector<std::size_t> n_idx(N_dims.size());
153 std::vector<std::size_t> k_idx(K_dims.size());
154 std::vector<std::size_t> a_idx, b_idx, e_idx;
156 a_idx.reserve(A_dims.size());
157 b_idx.reserve(B_dims.size());
158 e_idx.reserve(E_dims.size());
160 auto calculate_total_elements = [](
const std::vector<ck_tile::index_t>& dims) {
161 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
164 for(
ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
167 for(
int i = G_dims.size() - 1; i >= 0; --i)
169 g_idx[i] = temp % G_dims[i];
173 for(
ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
176 for(
int i = M_dims.size() - 1; i >= 0; --i)
178 m_idx[i] = temp % M_dims[i];
182 for(
ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
185 for(
int i = N_dims.size() - 1; i >= 0; --i)
187 n_idx[i] = temp % N_dims[i];
197 for(
int i = K_dims.size() - 1; i >= 0; --i)
199 k_idx[i] = temp % K_dims[i];
206 a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
207 a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
208 a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
210 b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
211 b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
212 b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
214 auto a_val = a_full_dims(a_idx);
215 auto b_val = b_full_dims(b_idx);
217 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
221 e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
222 e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
223 e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
225 EDataType result =
static_cast<EDataType
>(sum);
226 if(ds_full_dims_host.size() == 0)
230 else if(ds_full_dims_host.size() == 1)
232 cde_elementwise(result,
236 else if(ds_full_dims_host.size() == 2)
238 cde_elementwise(result,
243 else if(ds_full_dims_host.size() == 3)
245 cde_elementwise(result,
251 else if(ds_full_dims_host.size() == 4)
253 cde_elementwise(result,
262 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
265 e_full_dims_host_ref(e_idx) =
static_cast<EDataType
>(result);
void calculate_reference_multi_dimensional(const HostTensor< ADataType > &a_full_dims, const HostTensor< BDataType > &b_full_dims, const std::vector< HostTensor< DDataType > > &ds_full_dims_host, HostTensor< EDataType > &e_full_dims_host_ref, const std::vector< index_t > &G_dims, const std::vector< index_t > &M_dims, const std::vector< index_t > &N_dims, const std::vector< index_t > &K_dims, const std::vector< index_t > &A_dims, const std::vector< index_t > &B_dims, const std::vector< index_t > &E_dims, const CDEElementWise &cde_elementwise)
Definition reference_batched_contraction.hpp:134
void calculate_reference_flat_indexing(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::vector< ck_tile::HostTensor< DDataType > > &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise)
Definition reference_batched_contraction.hpp:23