#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/hip_cooperative_groups_helper.h>
#endif

#define __hip_abort() \
  { abort(); }
#if defined(NDEBUG)
#define __hip_assert(COND)
#else
#define __hip_assert(COND) \
  {                        \
    if (!(COND)) {         \
      __hip_abort();       \
    }                      \
  }
#endif

// Device side implementation of the cooperative groups feature.
namespace cooperative_groups {
class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // lane mask for coalesced and tiled-partitioned group types

  // Construct a thread group and record its type, size, and mask. This generic
  // constructor is used directly only for the grid and multi-grid groups of a
  // kernel launch; all other derived thread groups call it through their own
  // constructors.
  __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size = static_cast<uint32_t>(0),
                                uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };

  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;

 public:
  // Total number of threads in the thread group; valid for all derived group
  // types since the size is saved at construction time
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is the calling thread part of a valid group of this type?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the group
  __CG_QUALIFIER__ void sync() const;
};
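//
// A minimal usage sketch (illustrative only, not part of this header): any
// derived group type below can be handled through a thread_group reference;
// thread_rank()/is_valid()/sync() dispatch on _type at runtime (see the
// switch-based definitions further down).
//
//   __global__ void generic_group_example() {
//     cooperative_groups::thread_block block =
//         cooperative_groups::this_thread_block();
//     const cooperative_groups::thread_group& g = block;  // generic handle
//     unsigned int r = g.thread_rank();  // resolves to thread_block's rank
//     g.sync();                          // workgroup-level barrier
//   }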
class multi_grid_group : public thread_group {
  // Only this friend function may construct an object of this class and
  // access its resources
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct a multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of grids participating in this multi-grid group, i.e. the number of GPUs
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this grid within the multi-grid group, i.e. the GPU id
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};
__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}
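//
// Sketch of intended use (illustrative only; assumes the kernel was launched
// with hipLaunchCooperativeKernelMultiDevice so that a multi-grid group
// actually exists):
//
//   __global__ void multi_gpu_kernel() {
//     auto mg = cooperative_groups::this_multi_grid();
//     if (mg.grid_rank() == 0) {
//       // work performed by the first GPU only
//     }
//     mg.sync();  // barrier across every thread on every participating GPU
//   }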
class grid_group : public thread_group {
  // Only this friend function may construct an object of this class and
  // access its resources
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct a grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};
__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }
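//
// Sketch of intended use (illustrative only; grid-wide sync is valid only for
// kernels launched cooperatively, e.g. via hipLaunchCooperativeKernel):
//
//   __global__ void two_phase_kernel(float* partials) {
//     auto grid = cooperative_groups::this_grid();
//     // phase 1: each block writes a partial result to `partials` ...
//     grid.sync();  // wait for every thread in the whole grid
//     // phase 2: safely read results written by other blocks ...
//   }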
class thread_block : public thread_group {
  // Only these friend functions may construct an object of this class and
  // access its resources
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);

 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}

  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // The tile size must be a nonzero power of two no larger than the wavefront
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size")
    }

    thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size;
    return tiledGroup;
  }

 public:
  // 3-dimensional block index within the grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};
__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}
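//
// Sketch of intended use (illustrative only):
//
//   __global__ void block_example(int* out) {
//     auto block = cooperative_groups::this_thread_block();
//     // thread_rank() linearizes the 3D thread index within the block
//     out[block.thread_rank()] = static_cast<int>(block.size());
//     block.sync();  // equivalent to __syncthreads()
//   }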
class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);

  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size")
    }

    // A tile can only be subdivided, never grown
    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  // Tile sizes are powers of two, so the rank within the tile is simply the
  // low bits of the rank within the workgroup
  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const { internal::tiled_group::sync(); }
};
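//
// Worked example of the masking above (illustrative): in a tile of size 8, a
// thread whose workgroup rank is 13 has tile rank 13 & 7 == 5, and the
// partitioning in thread_block::new_tiled_group places it in meta group
// 13 / 8 == 1.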
class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                          unsigned int tile_size);

  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // If the parent group is already tiled, build the member mask directly so
    // that the partitioned group stays coalesced
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    } else {
      // The parent group is not tiled: walk the member mask and collect the
      // lanes that fall into this thread's tile
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Consider only lanes that are active in the parent group
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            // Accumulate a member mask appropriate for this tile
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }

 protected:
  // Construct a coalesced group from an explicit lane mask
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;  // may be a partial mask for partitioned groups
    coalesced_info.size = __popcll(coalesced_info.member_mask);  // number of active lanes
    coalesced_info.tiled_info.is_tiled = false;
    coalesced_info.tiled_info.meta_group_rank = 0;
    coalesced_info.tiled_info.meta_group_size = 1;
  }
 public:
  __CG_QUALIFIER__ unsigned int size() const { return coalesced_info.size; }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const { internal::coalesced_group::sync(); }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
        : (__AMDGCN_WAVEFRONT_SIZE == 64)
            ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
            : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    } else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    // Out-of-range source lanes keep their own value
    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    } else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    // Out-of-range source lanes keep their own value
    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};
__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}
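//
// Sketch of intended use (illustrative only): inside a divergent branch,
// coalesced_threads() captures exactly the lanes that took the branch, so
// rank 0 of the coalesced group can act on behalf of the others (a classic
// aggregated-atomics pattern).
//
//   __global__ void aggregated_atomics(int* counter, const int* data) {
//     if (data[threadIdx.x] > 0) {
//       auto active = cooperative_groups::coalesced_threads();
//       int base = 0;
//       if (active.thread_rank() == 0) {
//         base = atomicAdd(counter, active.size());  // one atomic per group
//       }
//       base = active.shfl(base, 0);  // broadcast from lane 0 of the group
//       // each active lane now owns slot base + active.thread_rank()
//     }
//   }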
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type")
      return -1;
    }
  }
}
__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type")
      return false;
    }
  }
}
__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type")
      break;
    }
  }
}
template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }
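//
// These free functions enable generic device code over any group type
// (illustrative sketch only): the same helper works for a thread_block, a
// tiled_group, or a coalesced_group, since all of them expose
// size/thread_rank/sync.
//
//   template <class Group>
//   __device__ void fill_slot(const Group& g, int* buf, int value) {
//     buf[cooperative_groups::thread_rank(g)] = value;
//     cooperative_groups::sync(g);
//   }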
template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the thread within this tile
  __CG_STATIC_QUALIFIER__ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};
template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() { internal::tiled_group::sync(); }

  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};
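//
// Sketch of a butterfly reduction built on the shfl_xor above (illustrative
// only; assumes the tile was obtained via tiled_partition<8>, defined later
// in this header):
//
//   __device__ int tile_sum(cooperative_groups::thread_block_tile<8> tile,
//                           int v) {
//     for (unsigned int offset = tile.size() / 2; offset > 0; offset /= 2) {
//       v += tile.shfl_xor(v, offset);  // exchange with the lane offset away
//     }
//     return v;  // every lane ends up holding the tile-wide sum
//   }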
template <unsigned int tileSize, typename ParentCGTy> class parent_group_info {
 public:
  // Rank of this tile within the set of tiles partitioned from the parent group
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Number of tiles created when the parent group was partitioned
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};
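//
// Worked example (illustrative): partitioning a 64-thread block into tiles of
// 16 gives meta_group_size() == (64 + 15) / 16 == 4, and a thread with block
// rank 37 lands in meta_group_rank() == 37 / 16 == 2.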
template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }
};
// Partial specialization: the parent type has been erased, so the meta group
// information is carried in the object instead of being computed statically
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank,
                                          unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
};
// Runtime tiled_partition: dispatch on the dynamic group type of the parent
__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}

// thread_block overload
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// A coalesced group that is partitioned further remains coalesced
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                 unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}
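//
// Sketch of the runtime API above (illustrative only; tile_size must be a
// power of two no larger than the wavefront size):
//
//   __global__ void runtime_partition_example() {
//     auto block = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_group tile8 =
//         cooperative_groups::tiled_partition(block, 8);
//     // each tile of 8 threads synchronizes independently
//     tile8.sync();
//   }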
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;

template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl
template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  // Allow the parent type to decay, e.g. when the tile is passed to a
  // function taking a thread_block_tile<size>
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};

template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};
template <unsigned int size, class ParentCGTy = void> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;

template <unsigned int size>
struct tiled_partition_internal<size, thread_block> : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};
}  // namespace impl
// Compile-time tiled_partition: the tile size is a template parameter, which
// enables the richer thread_block_tile API (shfl_xor, meta group queries, ...)
template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size. Currently not supported ");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
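//
// Sketch of the compile-time API above (illustrative only):
//
//   __global__ void static_tile_example() {
//     auto block = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_block_tile<16> tile =
//         cooperative_groups::tiled_partition<16>(block);
//     unsigned int r = tile.thread_rank();       // rank within the 16-wide tile
//     unsigned int mg = tile.meta_group_rank();  // which tile within the block
//     tile.sync();
//   }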
}  // namespace cooperative_groups

#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H