gridwise_gemm_pipeline_v1.hpp Source File

gridwise_gemm_pipeline_v1.hpp Source File#

Composable Kernel: gridwise_gemm_pipeline_v1.hpp Source File
gridwise_gemm_pipeline_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
11
12template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
14
15// 1-stage prefetch
16template <>
17struct GridwiseGemmPipeline_v1<1, true, true>
18{
19 static constexpr auto I0 = Number<0>{};
20 static constexpr auto I1 = Number<1>{};
21
22 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
23
24 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
25 {
26 return num_loop > 1;
27 }
28
29 template <bool HasMainLoop,
30 typename AGridDesc,
31 typename ABlockDesc,
32 typename ABlockTransfer,
33 typename AGridBuffer,
34 typename ABlockBuffer,
35 typename ABlockTransferStep,
36 typename BGridDesc,
37 typename BBlockDesc,
38 typename BBlockTransfer,
39 typename BGridBuffer,
40 typename BBlockBuffer,
41 typename BBlockTransferStep,
42 typename BlockwiseGemm,
43 typename CThreadBuffer>
44 __device__ static void Run(const AGridDesc& a_grid_desc,
45 const ABlockDesc& a_block_desc,
46 ABlockTransfer& a_blockwise_copy,
47 const AGridBuffer& a_grid_buf,
48 ABlockBuffer& a_block_buf,
49 const ABlockTransferStep& a_block_copy_step,
50 const BGridDesc& b_grid_desc,
51 const BBlockDesc& b_block_desc,
52 BBlockTransfer& b_blockwise_copy,
53 const BGridBuffer& b_grid_buf,
54 BBlockBuffer& b_block_buf,
55 const BBlockTransferStep& b_block_copy_step,
56 const BlockwiseGemm& blockwise_gemm,
57 CThreadBuffer& c_thread_buf,
58 index_t num_loop)
59 {
60 // preload data into LDS
61 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
62 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
63
64 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
65 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
66
67 // Initialize C
68 c_thread_buf.Clear();
69
70 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
71 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
72
73 // main body
74 if constexpr(HasMainLoop)
75 {
76 index_t i = 0;
77
78 do
79 {
80 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
81
83
84 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
85
86 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
87
89
90 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
91 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
92
93 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
94 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
95
96 ++i;
97 } while(i < (num_loop - 1));
98 }
99
100 // tail
101 {
103
104 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
105 }
106 }
107};
108
109// 2-stage prefetch
110template <>
111struct GridwiseGemmPipeline_v1<2, true, true>
112{
113 static constexpr auto I0 = Number<0>{};
114 static constexpr auto I1 = Number<1>{};
115
116 __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
117 {
118 // TODO: improve applicability
119 return num_loop % 2 == 0;
120 }
121
122 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
123 {
124 return (num_loop / 2) > 1;
125 }
126
127 template <bool HasMainLoop,
128 typename AGridDesc,
129 typename ABlockDesc,
130 typename ABlockTransfer,
131 typename AGridBuffer,
132 typename ABlockBuffer,
133 typename ABlockTransferStep,
134 typename BGridDesc,
135 typename BBlockDesc,
136 typename BBlockTransfer,
137 typename BGridBuffer,
138 typename BBlockBuffer,
139 typename BBlockTransferStep,
140 typename BlockwiseGemm,
141 typename CThreadBuffer>
142 static __device__ void Run(const AGridDesc& a_grid_desc,
143 const ABlockDesc& a_block_desc,
144 ABlockTransfer& a_blockwise_copy,
145 const AGridBuffer& a_grid_buf,
146 ABlockBuffer& a_block_buf,
147 const ABlockTransferStep& a_block_copy_step,
148 const BGridDesc& b_grid_desc,
149 const BBlockDesc& b_block_desc,
150 BBlockTransfer& b_blockwise_copy,
151 const BGridBuffer& b_grid_buf,
152 BBlockBuffer& b_block_buf,
153 const BBlockTransferStep& b_block_copy_step,
154 const BlockwiseGemm& blockwise_gemm,
155 CThreadBuffer& c_thread_buf,
156 index_t num_loop)
157 {
158 // preload data into LDS
159 {
160 // Read 0
161 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
162 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
163
164 // Move
165 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
166 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
167
168 // Read 1
169 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
170 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
171 }
172
173 // Initialize C
174 c_thread_buf.Clear();
175
176 // main body
177 if constexpr(HasMainLoop)
178 {
179 index_t i = 0;
180
181 do
182 {
183 // Move
184 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
185 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
186
187 // Write i
188 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
189 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
190
191 // Read i+2
192 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
193 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
194
195 // Sync
197
198 // Gemm i
199 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
200
201 // Sync
203
204 // Move
205 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
206 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
207
208 // Write i+1
209 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
210 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
211
212 // Read i+3
213 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
214 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
215
216 // Sync
218
219 // Gemm i+1
220 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
221
222 // Sync
224
225 i += 2;
226 } while(i < (num_loop - 2));
227 }
228
229 // tail
230 {
231 // Write num_loop - 2
232 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
233 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
234
235 // Sync
237
238 // Gemm num_loop - 2
239 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
240
241 // Sync
243
244 // Write num_loop - 1
245 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
246 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
247
248 // Sync
250
251 // Gemm num_loop - 1
252 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
253 }
254 }
255};
256
257template <>
258struct GridwiseGemmPipeline_v1<1, false, true>
259{
260 static constexpr auto I0 = Number<0>{};
261 static constexpr auto I1 = Number<1>{};
262
263 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
264
265 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
266 {
267 return num_loop > 1;
268 }
269
270 template <bool HasMainLoop,
271 typename AGridDesc,
272 typename ABlockDesc,
273 typename ABlockTransfer,
274 typename AGridBuffer,
275 typename ABlockBuffer,
276 typename ABlockTransferStep,
277 typename BGridDesc,
278 typename BBlockDesc,
279 typename BBlockTransfer,
280 typename BGridBuffer,
281 typename BBlockBuffer,
282 typename BBlockTransferStep,
283 typename BlockwiseGemm,
284 typename CThreadBuffer>
285 __device__ static void Run(const AGridDesc& a_grid_desc,
286 const ABlockDesc& a_block_desc,
287 ABlockTransfer& a_blockwise_copy,
288 const AGridBuffer& a_grid_buf,
289 ABlockBuffer& a_block_buf,
290 const ABlockTransferStep& a_block_copy_step,
291 const BGridDesc& b_grid_desc,
292 const BBlockDesc& b_block_desc,
293 BBlockTransfer& b_blockwise_copy,
294 const BGridBuffer& b_grid_buf,
295 BBlockBuffer& b_block_buf,
296 const BBlockTransferStep& b_block_copy_step,
297 const BlockwiseGemm& blockwise_gemm,
298 CThreadBuffer& c_thread_buf,
299 index_t num_loop)
300 {
301 constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
302 auto a_block_buf_switch = a_block_buf;
303
304 // preload data into LDS
305 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
306 a_blockwise_copy.Run(
307 a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
308
309 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
310 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
311
312 // Initialize C
313 c_thread_buf.Clear();
314
315 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
316
317 // main body
318 if constexpr(HasMainLoop)
319 {
320 index_t i = 0;
321
322 do
323 {
324 a_blockwise_copy.Run(
325 a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
326
328
329 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
330
331 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
332
334
335 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
336 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
337
338 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
339
340 a_block_buf = a_block_buf_switch;
341 ++i;
342 } while(i < (num_loop - 1));
343 }
344
345 // tail
346 {
348
349 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
350
352 }
353 }
354};
355
356template <>
357struct GridwiseGemmPipeline_v1<1, true, false>
358{
359 static constexpr auto I0 = Number<0>{};
360 static constexpr auto I1 = Number<1>{};
361
362 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
363
364 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
365 {
366 return num_loop > 1;
367 }
368
369 template <bool HasMainLoop,
370 typename AGridDesc,
371 typename ABlockDesc,
372 typename ABlockTransfer,
373 typename AGridBuffer,
374 typename ABlockBuffer,
375 typename ABlockTransferStep,
376 typename BGridDesc,
377 typename BBlockDesc,
378 typename BBlockTransfer,
379 typename BGridBuffer,
380 typename BBlockBuffer,
381 typename BBlockTransferStep,
382 typename BlockwiseGemm,
383 typename CThreadBuffer>
384 __device__ static void Run(const AGridDesc& a_grid_desc,
385 const ABlockDesc& a_block_desc,
386 ABlockTransfer& a_blockwise_copy,
387 const AGridBuffer& a_grid_buf,
388 ABlockBuffer& a_block_buf,
389 const ABlockTransferStep& a_block_copy_step,
390 const BGridDesc& b_grid_desc,
391 const BBlockDesc& b_block_desc,
392 BBlockTransfer& b_blockwise_copy,
393 const BGridBuffer& b_grid_buf,
394 BBlockBuffer& b_block_buf,
395 const BBlockTransferStep& b_block_copy_step,
396 const BlockwiseGemm& blockwise_gemm,
397 CThreadBuffer& c_thread_buf,
398 index_t num_loop)
399 {
400 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
401 auto b_block_buf_switch = b_block_buf;
402
403 // preload data into LDS
404 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
405 b_blockwise_copy.Run(
406 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
407
408 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
409 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
410
411 // Initialize C
412 c_thread_buf.Clear();
413
414 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
415
416 // main body
417 if constexpr(HasMainLoop)
418 {
419 index_t i = 0;
420
421 do
422 {
423 b_blockwise_copy.Run(
424 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
425
427
428 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
429
430 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
431
433
434 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
435 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
436
437 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
438
439 b_block_buf = b_block_buf_switch;
440 ++i;
441 } while(i < (num_loop - 1));
442 }
443
444 // tail
445 {
447
448 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
449
451 }
452 }
453};
454
455template <>
456struct GridwiseGemmPipeline_v1<1, false, false>
457{
458 static constexpr auto I0 = Number<0>{};
459 static constexpr auto I1 = Number<1>{};
460
461 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
462
463 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
464 {
465 return num_loop > 1;
466 }
467
468 template <bool HasMainLoop,
469 typename AGridDesc,
470 typename ABlockDesc,
471 typename ABlockTransfer,
472 typename AGridBuffer,
473 typename ABlockBuffer,
474 typename ABlockTransferStep,
475 typename BGridDesc,
476 typename BBlockDesc,
477 typename BBlockTransfer,
478 typename BGridBuffer,
479 typename BBlockBuffer,
480 typename BBlockTransferStep,
481 typename BlockwiseGemm,
482 typename CThreadBuffer>
483 __device__ static void Run(const AGridDesc& a_grid_desc,
484 const ABlockDesc& a_block_desc,
485 ABlockTransfer& a_blockwise_copy,
486 const AGridBuffer& a_grid_buf,
487 ABlockBuffer& a_block_buf,
488 const ABlockTransferStep& a_block_copy_step,
489 const BGridDesc& b_grid_desc,
490 const BBlockDesc& b_block_desc,
491 BBlockTransfer& b_blockwise_copy,
492 const BGridBuffer& b_grid_buf,
493 BBlockBuffer& b_block_buf,
494 const BBlockTransferStep& b_block_copy_step,
495 const BlockwiseGemm& blockwise_gemm,
496 CThreadBuffer& c_thread_buf,
497 index_t num_loop)
498 {
499 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
500 constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
501 auto b_block_buf_switch = b_block_buf;
502 auto a_block_buf_switch = a_block_buf;
503
504 // preload data into LDS
505 a_blockwise_copy.Run(
506 a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
507 b_blockwise_copy.Run(
508 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
509
510 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
511 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
512
513 // Initialize C
514 c_thread_buf.Clear();
515
516 // main body
517 if constexpr(HasMainLoop)
518 {
519 index_t i = 0;
520
521 do
522 {
523 a_blockwise_copy.Run(
524 a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
525 b_blockwise_copy.Run(
526 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
527
529
530 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
531
533
534 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
535 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
536
537 a_block_buf = a_block_buf_switch;
538 b_block_buf = b_block_buf_switch;
539 ++i;
540 } while(i < (num_loop - 1));
541 }
542
543 // tail
544 {
546
547 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
548
550 }
551 }
552};
553
554template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
556
557template <>
559{
560 static constexpr auto I0 = Number<0>{};
561 static constexpr auto I1 = Number<1>{};
562
563 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
564
565 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
566 {
567 return num_loop > 1;
568 }
569
570 template <bool HasMainLoop,
571 typename AGridDesc,
572 typename ABlockDesc,
573 typename ABlockTransfer,
574 typename AGridBuffer,
575 typename ABlockBuffer,
576 typename ABlockTransferStep,
577 typename BGridDesc,
578 typename BBlockDesc,
579 typename BBlockTransfer,
580 typename BGridBuffer,
581 typename BBlockBuffer,
582 typename BBlockTransferStep,
583 typename ScaleGridDesc,
584 typename ScaleGridBuffer,
585 typename BlockwiseGemm,
586 typename CThreadBuffer>
587 __device__ static void Run(const AGridDesc& a_grid_desc,
588 const ABlockDesc& a_block_desc,
589 ABlockTransfer& a_blockwise_copy,
590 const AGridBuffer& a_grid_buf,
591 ABlockBuffer& a_block_buf,
592 const ABlockTransferStep& a_block_copy_step,
593 const BGridDesc& b_grid_desc,
594 const BBlockDesc& b_block_desc,
595 BBlockTransfer& b_blockwise_copy,
596 const BGridBuffer& b_grid_buf,
597 BBlockBuffer& b_block_buf,
598 const BBlockTransferStep& b_block_copy_step,
599 const ScaleGridDesc& scale_grid_desc,
600 const ScaleGridBuffer& scale_grid_buf,
601 const BlockwiseGemm& blockwise_gemm,
602 CThreadBuffer& c_thread_buf,
603 index_t num_loop)
604 {
605 // Global Prefetch Stage 1
606 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
607 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
608 // Scale read once
609 b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
610
611 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
612 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
613
614 // Initialize C
615 c_thread_buf.Clear();
616
617 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
618 // Dequantization fused in blockwise_copy
619 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
620
621 // main body
622 if constexpr(HasMainLoop)
623 {
624 index_t i = 0;
625
626 do
627 {
628 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
629
631
632 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
633
634 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
635
637
638 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
639 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
640
641 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
642 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
643
644 ++i;
645 } while(i < (num_loop - 1));
646 }
647
648 // tail
649 {
651
652 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
653 }
654 }
655};
656
657template <index_t NumPrefetch>
659
660template <>
662{
663 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
664
665 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
666 {
667 return num_loop > 1;
668 }
669
670 template <bool HasMainLoop,
671 typename AGridDesc,
672 typename ABlockDesc,
673 typename ABlockTransfer,
674 typename AGridBuffer,
675 typename ABlockBuffer,
676 typename ABlockTransferStep,
677 typename BGridDesc,
678 typename BBlockDesc,
679 typename BBlockTransfer,
680 typename BGridBuffer,
681 typename BBlockBuffer,
682 typename BBlockTransferStep,
683 typename BlockwiseGemm,
684 typename CThreadBuffer>
685 static __device__ void Run(const AGridDesc& a_grid_desc,
686 const ABlockDesc& a_block_desc,
687 ABlockTransfer& a_blockwise_copy,
688 const AGridBuffer& a_grid_buf,
689 ABlockBuffer& a_block_buf,
690 const ABlockTransferStep& a_block_copy_step,
691 const BGridDesc& b_grid_desc,
692 const BBlockDesc& b_block_desc,
693 BBlockTransfer& b_blockwise_copy,
694 const BGridBuffer& b_grid_buf,
695 BBlockBuffer& b_block_buf,
696 const BBlockTransferStep& b_block_copy_step,
697 const BlockwiseGemm& blockwise_gemm,
698 CThreadBuffer& c_thread_buf,
699 index_t num_loop)
700 {
701 // preload data into LDS
702 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
703 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
704
705 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
706 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
707
708 // Initialize C
709 c_thread_buf.Clear();
710
711 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
712 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
713
714 // main body
715 if constexpr(HasMainLoop)
716 {
717 index_t i = 0;
718
719 do
720 {
721 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
722
724
725 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
726
727 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
728
729 // block_sync_lds(); // moved into blockwise_gemm
730
731 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
732 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
733
734 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
735 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
736
737 ++i;
738 } while(i < (num_loop - 1));
739 }
740
741 // tail
742 {
744
745 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
746 }
747 }
748};
749
750// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
751template <>
753{
754};
755
756// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality
757template <index_t NumPrefetch, LoopScheduler LoopSched>
759{
760 if constexpr(LoopSched == LoopScheduler::Default)
761 {
763 }
764 else if constexpr(LoopSched == LoopScheduler::Interwave)
765 {
767 }
768}
769
770} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr auto GridwiseGemmPipeline_v1_Selector()
Definition gridwise_gemm_pipeline_v1.hpp:758
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__device__ void block_sync_lds()
Definition synchronization.hpp:16
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:483
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v1.hpp:461
static constexpr auto I1
Definition gridwise_gemm_pipeline_v1.hpp:459
static constexpr auto I0
Definition gridwise_gemm_pipeline_v1.hpp:458
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:463
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:265
static constexpr auto I0
Definition gridwise_gemm_pipeline_v1.hpp:260
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v1.hpp:263
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:285
static constexpr auto I1
Definition gridwise_gemm_pipeline_v1.hpp:261
static constexpr auto I1
Definition gridwise_gemm_pipeline_v1.hpp:360
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:384
static constexpr auto I0
Definition gridwise_gemm_pipeline_v1.hpp:359
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:364
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v1.hpp:362
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:44
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:24
static constexpr auto I0
Definition gridwise_gemm_pipeline_v1.hpp:19
static constexpr auto I1
Definition gridwise_gemm_pipeline_v1.hpp:20
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v1.hpp:22
static constexpr auto I1
Definition gridwise_gemm_pipeline_v1.hpp:114
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:142
static constexpr auto I0
Definition gridwise_gemm_pipeline_v1.hpp:113
__host__ static __device__ constexpr bool IsSupported(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:116
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:122
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const ScaleGridDesc &scale_grid_desc, const ScaleGridBuffer &scale_grid_buf, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:587
static constexpr auto I1
Definition gridwise_gemm_pipeline_v1.hpp:561
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v1.hpp:563
static constexpr auto I0
Definition gridwise_gemm_pipeline_v1.hpp:560
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:565
Definition gridwise_gemm_pipeline_v1.hpp:555
Definition gridwise_gemm_pipeline_v1.hpp:13
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v1.hpp:663
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:665
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v1.hpp:685
Definition gridwise_gemm_pipeline_v1.hpp:658