diff options
author | Dmitriy Sobolev <Dmitriy.Sobolev@intel.com> | 2021-02-20 19:30:57 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-20 19:30:57 +0300 |
commit | 7b97a68f5d0484208c3e07aa76febc67fe4343d1 (patch) | |
tree | 04ed1d3774767b646f95dfa080f7319fb7f3a594 | |
parent | Remove explicit default copy constructor in copy_constructible_value_holder (... (diff) | |
download | llvm-project-7b97a68f5d0484208c3e07aa76febc67fe4343d1.tar.gz llvm-project-7b97a68f5d0484208c3e07aa76febc67fe4343d1.tar.bz2 llvm-project-7b97a68f5d0484208c3e07aa76febc67fe4343d1.zip |
Avoid divergence of work items in the same SIMD before calling collectives (#129)
-rw-r--r-- | include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h | 82 |
1 file changed, 39 insertions, 43 deletions
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h index 13d21dd380b4..dbb88446ee5a 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h @@ -182,22 +182,6 @@ __convert_to_ordered(_T __value) // radix sort: run-time device info functions //------------------------------------------------------------------------ -// get item id in sub-group -inline ::std::uint32_t -__get_sg_item_idx(const sycl::nd_item<1>& __idx) -{ - // technically sycl::id<1>::operator[int] returns a value that always fits in uint8_t (no overflow) - // and since 64-bit arithmetic is more expensive, the return type is set to ::std::uint32_t - return static_cast<::std::uint32_t>(__idx.get_sub_group().get_local_id()[0]); -} - -// get number of items in sub-group -inline ::std::uint32_t -__get_sg_item_num(const sycl::nd_item<1>& __idx) -{ - return __idx.get_sub_group().get_local_range()[0]; -} - // get rounded up result of (__number / __divisor) template <typename _T1, typename _T2> inline auto @@ -275,6 +259,20 @@ __get_bucket_value(_T __value, ::std::uint32_t __radix_iter) return (__value >> __bucket_offset) & __bucket_mask; } +template <typename _T, bool __is_comp_asc> +inline __enable_if_t<__is_comp_asc, _T> +__get_last_value() +{ + return ::std::numeric_limits<_T>::max(); +}; + +template <typename _T, bool __is_comp_asc> +inline __enable_if_t<!__is_comp_asc, _T> +__get_last_value() +{ + return ::std::numeric_limits<_T>::min(); +}; + //----------------------------------------------------------------------- // radix sort: count kernel (per iteration) //----------------------------------------------------------------------- @@ -517,36 +515,34 @@ __radix_sort_reorder_submit(_ExecutionPolicy&& __exec, ::std::size_t __segments, for (::std::size_t __block_idx = 0; __block_idx < 
__blocks_per_segment * __it_size; ++__block_idx) { const ::std::size_t __val_idx = __start_idx + __sg_size * __block_idx; - // TODO: profile how it affects performance - if (__val_idx < __inout_buf_size) + + // get value, convert it to ordered (in terms of bitness) + // if the index is outside of the range, use fake value which will not affect other values + __ordered_t<_InputT> __batch_val = __val_idx < __inout_buf_size + ? __convert_to_ordered(__input_rng[__val_idx]) + : __get_last_value<__ordered_t<_InputT>, __is_comp_asc>(); + + // get bit values in a certain bucket of a value + ::std::uint32_t __bucket_val = + __get_bucket_value<__radix_bits, __is_comp_asc>(__batch_val, __radix_iter); + + _OffsetT __new_offset_idx = 0; + // TODO: most computation-heavy code segment - find a better optimized solution + for (::std::uint32_t __radix_state_idx = 0; __radix_state_idx < __radix_states; ++__radix_state_idx) { - // get value, convert it to ordered (in terms of bitness) - __ordered_t<_InputT> __batch_val = __convert_to_ordered(__input_rng[__val_idx]); - // get bit values in a certain bucket of a value - ::std::uint32_t __bucket_val = - __get_bucket_value<__radix_bits, __is_comp_asc>(__batch_val, __radix_iter); - - _OffsetT __new_offset_idx = 0; - // TODO: most computation-heavy code segment - find a better optimized solution - for (::std::uint32_t __radix_state_idx = 0; __radix_state_idx < __radix_states; - ++__radix_state_idx) - { - ::std::uint32_t __is_current_bucket = __bucket_val == __radix_state_idx; - ::std::uint32_t __sg_item_offset = - sycl::ONEAPI::exclusive_scan(__self_item.get_sub_group(), __is_current_bucket, - sycl::ONEAPI::plus<::std::uint32_t>()); - - __new_offset_idx |= - __is_current_bucket * (__offset_arr[__radix_state_idx] + __sg_item_offset); - ::std::uint32_t __sg_total_offset = - sycl::ONEAPI::reduce(__self_item.get_sub_group(), __is_current_bucket, - sycl::ONEAPI::plus<::std::uint32_t>()); - - __offset_arr[__radix_state_idx] = 
__offset_arr[__radix_state_idx] + __sg_total_offset; - } + ::std::uint32_t __is_current_bucket = __bucket_val == __radix_state_idx; + ::std::uint32_t __sg_item_offset = sycl::ONEAPI::exclusive_scan( + __self_item.get_sub_group(), __is_current_bucket, sycl::ONEAPI::plus<::std::uint32_t>()); - __output_rng[__new_offset_idx] = __input_rng[__val_idx]; + __new_offset_idx |= __is_current_bucket * (__offset_arr[__radix_state_idx] + __sg_item_offset); + ::std::uint32_t __sg_total_offset = sycl::ONEAPI::reduce( + __self_item.get_sub_group(), __is_current_bucket, sycl::ONEAPI::plus<::std::uint32_t>()); + + __offset_arr[__radix_state_idx] = __offset_arr[__radix_state_idx] + __sg_total_offset; } + + if (__val_idx < __inout_buf_size) + __output_rng[__new_offset_idx] = __input_rng[__val_idx]; } }); }); |