// -*- C++ -*-
//===-- unseq_backend_sycl.h ----------------------------------------------===//
//
// Copyright (C) Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// This file incorporates work covered by the following copyright and permission
// notice:
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//

//!!! NOTE: This file should be included under the macro _ONEDPL_BACKEND_SYCL
#ifndef _ONEDPL_unseq_backend_sycl_H
#define _ONEDPL_unseq_backend_sycl_H

#include <type_traits>

#include "../../onedpl_config.h"
#include "../../utils.h"

#include <CL/sycl.hpp>

namespace oneapi
{
namespace dpl
{
namespace unseq_backend
{
// helpers to encapsulate void and other types
// SFINAE aliases: void_type<_Tp> is well-formed only when _Tp is void, and
// non_void_type<_Tp> only when it is not; they let overloads be enabled or
// disabled based on the void-ness of a deduced type.
template <typename _Tp>
using void_type = typename ::std::enable_if<::std::is_void<_Tp>::value, _Tp>::type;
template <typename _Tp>
using non_void_type = typename ::std::enable_if<!::std::is_void<_Tp>::value, _Tp>::type;

// a way to get value_type from both accessors and USM that is needed for transform_init
// The primary template is intentionally empty: requesting value_type for an
// unsupported type produces a compile-time error.
template <typename _Unknown>
struct __accessor_traits_impl
{
};

// Specialization for sycl::accessor — expose the accessor's own value_type.
template <typename _T, int _Dim, sycl::access::mode _AccMode, sycl::access::target _AccTarget,
          sycl::access::placeholder _Placeholder>
struct __accessor_traits_impl<sycl::accessor<_T, _Dim, _AccMode, _AccTarget, _Placeholder>>
{
    using value_type = typename sycl::accessor<_T, _Dim, _AccMode, _AccTarget, _Placeholder>::value_type;
};

// Specialization for raw (USM) pointers — the pointee type is the value type.
template <typename _RawArrayValueType>
struct __accessor_traits_impl<_RawArrayValueType*>
{
    using value_type = _RawArrayValueType;
};

// Public entry point; strips cv/reference qualifiers before dispatching.
template <typename _Unknown>
using __accessor_traits = __accessor_traits_impl<typename ::std::decay<_Unknown>::type>;

template <typename _ExecutionPolicy, typename _F>
struct walk_n
{
    _F __f;

    // Element-wise application: passes the __idx-th element of every range to
    // __f and returns whatever __f returns. The trailing decltype preserves
    // the exact deduced type, including reference-ness.
    template <typename _ItemId, typename... _Ranges>
    auto
    operator()(const _ItemId __idx, _Ranges&&... __rngs) const -> decltype(__f(__rngs[__idx]...))
    {
        return __f(__rngs[__idx]...);
    }
};

// If a read accessor returns a temporary value, __no_op would return an lvalue
// reference to it; once the temporary is destroyed that reference dangles.
// Therefore, for __no_op we do not invoke the functor at all and simply
// forward the element expression itself.
template <typename _ExecutionPolicy>
struct walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>
{
    oneapi::dpl::__internal::__no_op __f;

    template <typename _ItemId, typename _Range>
    auto
    operator()(const _ItemId __idx, _Range&& __rng) const -> decltype(__rng[__idx])
    {
        return __rng[__idx];
    }
};

//------------------------------------------------------------------------
// walk_adjacent_difference
//------------------------------------------------------------------------

template <typename _ExecutionPolicy, typename _F>
struct walk_adjacent_difference
{
    _F __f;

    // Computes one element of an adjacent-difference style operation: the
    // first output element is a plain copy of the first input element; every
    // other output element is produced by __f(prev, cur, out_slot).
    // (The previously present `using ::std::get;` was unused and has been
    // removed; no tuple access happens here.)
    template <typename _ItemId, typename _Acc1, typename _Acc2>
    void
    operator()(const _ItemId __idx, const _Acc1& _acc_src, _Acc2& _acc_dst) const
    {
        // just copy an element if it is the first one
        if (__idx == 0)
            _acc_dst[__idx] = _acc_src[__idx];
        else
            // "+ (-1)" is kept as-is (rather than "- 1") — presumably chosen
            // for index types without a suitable operator-; verify before
            // simplifying.
            __f(_acc_src[__idx + (-1)], _acc_src[__idx], _acc_dst[__idx]);
    }
};

//------------------------------------------------------------------------
// transform_reduce
//------------------------------------------------------------------------

template <typename _ExecutionPolicy, typename _Operation1, typename _Operation2>
struct transform_init
{
    _Operation1 __binary_op; // reduction operation combining partial results
    _Operation2 __unary_op;  // transformation invoked with (index, ranges...)

    // First stage of transform_reduce: each work item transforms and reduces
    // up to __iters_per_work_item consecutive input elements and stores its
    // partial result into local memory at its local id. Items whose whole
    // sub-range lies at or beyond __n write nothing (their local slot stays
    // untouched — callers are expected to account for that).
    template <typename _NDItemId, typename _Size, typename _AccLocal, typename... _Acc>
    void
    operator()(const _NDItemId __item, _Size __n, ::std::size_t __iters_per_work_item, ::std::size_t __global_id,
               _AccLocal& __local_mem, const _Acc&... __acc) const
    {
        using _Tp = typename __accessor_traits<_AccLocal>::value_type;
        // first input index owned by this work item
        ::std::size_t __adjusted_global_id = __iters_per_work_item * __global_id;
        if (__adjusted_global_id < __n)
        {
            ::std::size_t __local_id = __item.get_local_id(0);
            _Tp __res = __unary_op(__adjusted_global_id, __acc...);
            // Add neighbour to the current __local_mem
            for (::std::size_t __i = 1; __i < __iters_per_work_item; ++__i)
            {
                ::std::size_t __shifted_id = __adjusted_global_id + __i;
                if (__shifted_id < __n)
                    __res = __binary_op(__res, __unary_op(__shifted_id, __acc...));
            }
            __local_mem[__local_id] = __res;
        }
    }
};

// Reduce on local memory
template <typename _ExecutionPolicy, typename _BinaryOperation1, typename _Tp>
struct reduce
{
    _BinaryOperation1 __bin_op1;

    // Tree reduction over the work group's local memory. After the loop,
    // slot 0 holds the group's combined value; each item returns the content
    // of its own slot. The __global_idx/__n guards keep items past the end of
    // the data from contributing.
    template <typename _NDItemId, typename _GlobalIdx, typename _Size, typename _AccLocal>
    _Tp
    operator()(const _NDItemId __item_id, const _GlobalIdx __global_idx, const _Size __n, _AccLocal& __local_mem) const
    {
        auto __local_idx = __item_id.get_local_id(0);
        auto __group_size = __item_id.get_local_range().size();

        // Pairwise combine with doubling stride; the barrier before each round
        // makes the previous round's writes visible to the whole group.
        auto __k = 1;
        do
        {
            __item_id.barrier(sycl::access::fence_space::local_space);
            if (__local_idx % (2 * __k) == 0 && __local_idx + __k < __group_size && __global_idx < __n &&
                __global_idx + __k < __n)
            {
                __local_mem[__local_idx] = __bin_op1(__local_mem[__local_idx], __local_mem[__local_idx + __k]);
            }
            __k *= 2;
        } while (__k < __group_size);
        return __local_mem[__local_idx];
    }
};

// Matchers for early_exit_or and early_exit_find

template <typename _ExecutionPolicy, typename _Pred>
struct single_match_pred_by_idx
{
    _Pred __pred;

    // Reports whether the stored predicate holds at the given index; the
    // predicate receives both the index and the accessor/range.
    template <typename _Idx, typename _Acc>
    bool
    operator()(const _Idx __shifted_idx, _Acc& __acc) const
    {
        const bool __is_match = __pred(__shifted_idx, __acc);
        return __is_match;
    }
};

// Adapts an element predicate into an index-based matcher by routing it
// through walk_n, which dereferences the range at the given index.
template <typename _ExecutionPolicy, typename _Pred>
struct single_match_pred : single_match_pred_by_idx<_ExecutionPolicy, walk_n<_ExecutionPolicy, _Pred>>
{
    single_match_pred(_Pred __p) : single_match_pred_by_idx<_ExecutionPolicy, walk_n<_ExecutionPolicy, _Pred>>{__p} {}
};

template <typename _ExecutionPolicy, typename _Pred>
struct multiple_match_pred
{
    _Pred __pred;

    // search-style matcher: true when the whole sequence in __s_acc matches
    // the subsequence of __acc that starts at __shifted_idx, comparing
    // element-wise with __pred.
    template <typename _Idx, typename _Acc1, typename _Acc2>
    bool
    operator()(const _Idx __shifted_idx, _Acc1& __acc, const _Acc2& __s_acc) const
    {
        auto __n = __acc.size();
        auto __s_n = __s_acc.size();
        // A match is only possible when the pattern fits into the remaining
        // part of the sequence. Checking __s_n <= __n first prevents unsigned
        // wrap-around in __n - __s_n when the pattern is longer than the
        // sequence (the unguarded expression made the test pass and allowed
        // out-of-bounds reads in the loop below).
        bool __result = __s_n <= __n && __shifted_idx <= __n - __s_n;
        const auto __total_shift = __shifted_idx;

        using _Size2 = decltype(__s_n);
        for (_Size2 __ii = 0; __ii < __s_n && __result; ++__ii)
            __result = __pred(__acc[__total_shift + __ii], __s_acc[__ii]);

        return __result;
    }
};

template <typename _ExecutionPolicy, typename _Pred, typename _Tp, typename _Size>
struct n_elem_match_pred
{
    _Pred __pred;
    _Tp __value;
    _Size __count;

    // search_n-style matcher: true when __count consecutive elements starting
    // at __shifted_idx all satisfy __pred(element, __value) and the whole run
    // fits inside the range.
    template <typename _Idx, typename _Acc>
    bool
    operator()(const _Idx __shifted_idx, const _Acc& __acc) const
    {
        // the whole run must lie inside the range
        if (!((__shifted_idx + __count) <= __acc.size()))
            return false;

        auto __pos = __shifted_idx;
        for (auto __left = __count; __left > 0; --__left, ++__pos)
        {
            if (!__pred(__acc[__pos], __value))
                return false;
        }
        return true;
    }
};

template <typename _ExecutionPolicy, typename _Pred>
struct first_match_pred
{
    _Pred __pred;

    // find_first_of-style matcher: true when the element of __acc at
    // __shifted_idx compares equal (via __pred) to any element of __s_acc.
    // Precondition (not checked): __shifted_idx is a valid index into __acc.
    template <typename _Idx, typename _Acc1, typename _Acc2>
    bool
    operator()(const _Idx __shifted_idx, const _Acc1& __acc, const _Acc2& __s_acc) const
    {
        const auto __candidate = __acc[__shifted_idx];
        const auto __s_n = __s_acc.size();

        auto __i = decltype(__s_acc.size())(0);
        while (__i < __s_n)
        {
            if (__pred(__candidate, __s_acc[__i]))
                return true;
            ++__i;
        }
        return false;
    }
};

//------------------------------------------------------------------------
// scan
//------------------------------------------------------------------------

// mask assigner for tuples: stores an input element into component N of a
// tuple-valued destination
template <::std::size_t N>
struct __mask_assigner
{
    // Writes __in_acc[__in_idx] into get<N>(__acc[__out_idx]). The second
    // (output accessor) parameter is unused — it only keeps the call
    // signature uniform with the other assigners.
    template <typename _Acc, typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
    void
    operator()(_Acc& __acc, _OutAcc&, const _OutIdx __out_idx, const _InAcc& __in_acc, const _InIdx __in_idx) const
    {
        using ::std::get;
        auto&& __dst = __acc[__out_idx];
        get<N>(__dst) = __in_acc[__in_idx];
    }
};

// data assigners and accessors for transform_scan
struct __scan_assigner
{
    // Copies __in_acc[__in_idx] into __out_acc[__out_idx].
    template <typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
    void
    operator()(_OutAcc& __out_acc, const _OutIdx __out_idx, const _InAcc& __in_acc, _InIdx __in_idx) const
    {
        auto&& __dst = __out_acc[__out_idx];
        __dst = __in_acc[__in_idx];
    }

    // Same as above; the extra leading accessor parameter is ignored and only
    // keeps the signature uniform with __mask_assigner.
    template <typename _Acc, typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
    void
    operator()(_Acc&, _OutAcc& __out_acc, const _OutIdx __out_idx, const _InAcc& __in_acc, _InIdx __in_idx) const
    {
        auto&& __dst = __out_acc[__out_idx];
        __dst = __in_acc[__in_idx];
    }
};

// Deliberate no-op assigner: keeps the assigner interface while writing
// nothing (used by scan stages that must skip an assignment).
struct __scan_no_assign
{
    template <typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
    void
    operator()(_OutAcc&, const _OutIdx, const _InAcc&, const _InIdx) const
    {
    }
};

// types of initial value for parallel_transform_scan:
// __scan_init carries a user-supplied initial value; __scan_no_init is a tag
// meaning "no initial element" (it only records the value type).
template <typename _InitType>
struct __scan_init
{
    _InitType __value; // the initial value of the scan
    using __value_type = _InitType;
};

template <typename _InitType>
struct __scan_no_init
{
    using __value_type = _InitType;
};

// structure for the correct processing of the initial scan element:
// overloads are selected by the init tag, so "no init" versions compile to
// nothing and "with init" versions seed/fold the initial value.
template <typename _InitType>
struct __scan_init_processing
{
    // With an initial value: overwrite the destination with it.
    template <typename _Tp>
    void
    operator()(const __scan_init<_InitType>& __init, _Tp&& __value) const
    {
        __value = __init.__value;
    }
    // Without an initial value: nothing to do.
    template <typename _Tp>
    void
    operator()(const __scan_no_init<_InitType>&, _Tp&&) const
    {
    }

    // With an initial value: fold it into the running value via __bin_op.
    template <typename _Tp, typename _BinaryOp>
    void
    operator()(const __scan_init<_InitType>& __init, _Tp&& __value, _BinaryOp __bin_op) const
    {
        __value = __bin_op(__init.__value, __value);
    }
    // Without an initial value: the running value is already correct.
    template <typename _Tp, typename _BinaryOp>
    void
    operator()(const __scan_no_init<_InitType>&, _Tp&&, _BinaryOp) const
    {
    }
};

// functors for scan
// Scatter stage of copy_if/unique-style algorithms: the input holds tuples
// whose component N is an inclusive scan of the selection mask. Element
// __item_idx is written to the output iff that scan component increased at
// this position (i.e. the element was selected); its destination index is
// the scan value minus one, combined with the preceding work groups' sums.
template <typename _BinaryOp, typename _Inclusive, ::std::size_t N>
struct __copy_by_mask
{
    _BinaryOp __binary_op; // combines the local output index with work-group sums

    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsAcc, typename _Size,
              typename _SizePerWg>
    void
    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc& __in_acc, const _WgSumsAcc& __wg_sums_acc, _Size __n,
               _SizePerWg __size_per_wg) const
    {
        using ::std::get;
        auto __item_idx = __item.get_linear_id();
        if (__item_idx < __n && get<N>(__in_acc[__item_idx]))
        {
            // local (per-group) destination index: scan value - 1
            auto __out_idx = get<N>(__in_acc[__item_idx]) - 1;

            using __tuple_type = typename __internal::__get_tuple_type<
                typename ::std::decay<decltype(get<0>(__in_acc[__item_idx]))>::type,
                typename ::std::decay<decltype(__out_acc[__out_idx])>::type>::__type;

            // calculation of position for copy
            if (__item_idx >= __size_per_wg)
            {
                auto __wg_sums_idx = __item_idx / __size_per_wg - 1;
                __out_idx = __binary_op(__out_idx, __wg_sums_acc[__wg_sums_idx]);
            }
            if (__item_idx % __size_per_wg == 0 || (get<N>(__in_acc[__item_idx]) != get<N>(__in_acc[__item_idx - 1])))
                // If we work with tuples we might have a situation when internal tuple is assigned to ::std::tuple
                // (e.g. returned by user-provided lambda).
                // For internal::tuple<T...> we have a conversion operator to ::std::tuple<T..>. The problem here
                // is that the types of these 2 tuples may be different but still convertible to each other.
                // Technically this should be solved by adding to internal::tuple<T..> an additional conversion
                // operator to ::std::tuple<U...>, but for some reason this doesn't work(conversion from
                // ::std::tuple<T...> to ::std::tuple<U..> fails). What does work is the explicit cast below:
                // for internal::tuple<T..> we define a field that provides a corresponding ::std::tuple<T..>
                // with matching types. We get this type(see __typle_type definition above) and use it
                // for static cast to explicitly convert internal::tuple<T..> -> ::std::tuple<T..>.
                // Now we have the following assignment ::std::tuple<U..> = ::std::tuple<T..> which works as expected.
                // NOTE: we only need this explicit conversion when we have internal::tuple and
                // ::std::tuple as operands, in all the other cases this is not necessary and no conversion
                // is performed(i.e. __typle_type is the same type as its operand).
                __out_acc[__out_idx] = static_cast<__tuple_type>(get<0>(__in_acc[__item_idx]));
        }
    }
};

// Scatter stage of partition: input tuples carry the data in component 0 and
// an inclusive scan of the selection mask in component 1. Selected elements
// go to component 0 of the output at index (scan - 1); unselected elements go
// to component 1 at index (item_idx - scan). Work-group sums adjust both
// indices for all but the first group.
template <typename _BinaryOp, typename _Inclusive>
struct __partition_by_mask
{
    _BinaryOp __binary_op; // combines the local output index with work-group sums

    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsAcc, typename _Size,
              typename _SizePerWg>
    void
    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc& __in_acc, const _WgSumsAcc& __wg_sums_acc, _Size __n,
               _SizePerWg __size_per_wg) const
    {
        auto __item_idx = __item.get_linear_id();
        if (__item_idx < __n)
        {
            using ::std::get;
            using __in_type = typename ::std::decay<decltype(get<0>(__in_acc[__item_idx]))>::type;
            auto __wg_sums_idx = __item_idx / __size_per_wg;
            bool __not_first_wg = __item_idx >= __size_per_wg;
            // element is selected iff its scan component increased here
            if (get<1>(__in_acc[__item_idx]) &&
                (__item_idx % __size_per_wg == 0 || get<1>(__in_acc[__item_idx]) != get<1>(__in_acc[__item_idx - 1])))
            {
                auto __out_idx = get<1>(__in_acc[__item_idx]) - 1;
                // see the cast rationale documented in __copy_by_mask
                using __tuple_type = typename __internal::__get_tuple_type<
                    __in_type, typename ::std::decay<decltype(get<0>(__out_acc[__out_idx]))>::type>::__type;

                if (__not_first_wg)
                    __out_idx = __binary_op(__out_idx, __wg_sums_acc[__wg_sums_idx - 1]);
                get<0>(__out_acc[__out_idx]) = static_cast<__tuple_type>(get<0>(__in_acc[__item_idx]));
            }
            else
            {
                auto __out_idx = __item_idx - get<1>(__in_acc[__item_idx]);
                using __tuple_type = typename __internal::__get_tuple_type<
                    __in_type, typename ::std::decay<decltype(get<1>(__out_acc[__out_idx]))>::type>::__type;

                if (__not_first_wg)
                    __out_idx -= __wg_sums_acc[__wg_sums_idx - 1];
                get<1>(__out_acc[__out_idx]) = static_cast<__tuple_type>(get<0>(__in_acc[__item_idx]));
            }
        }
    }
};

// Second (global) phase of the two-pass scan: folds the preceding work
// groups' sums into every element that was scanned locally, skipping the
// first group (which needs no adjustment).
template <typename _Inclusive, typename _BinaryOp>
struct __global_scan_functor
{
    _BinaryOp __binary_op;

    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsAcc, typename _Size,
              typename _SizePerWg>
    void
    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc&, const _WgSumsAcc& __wg_sums_acc, _Size __n,
               _SizePerWg __size_per_wg) const
    {
        constexpr auto __shift = _Inclusive{} ? 0 : 1;
        auto __item_idx = __item.get_linear_id();
        // skip the first group scanned locally
        if (__item_idx >= __size_per_wg && __item_idx < __n)
        {
            auto __wg_sums_idx = __item_idx / __size_per_wg - 1;
            // an initial value precedes the first group for the exclusive scan
            __item_idx += __shift;
            auto __bin_op_result = __binary_op(__wg_sums_acc[__wg_sums_idx], __out_acc[__item_idx]);
            // explicit cast — see the tuple-conversion rationale in __copy_by_mask
            using __out_type = typename ::std::decay<decltype(__out_acc[__item_idx])>::type;
            using __in_type = typename ::std::decay<decltype(__bin_op_result)>::type;
            __out_acc[__item_idx] =
                static_cast<typename __internal::__get_tuple_type<__in_type, __out_type>::__type>(__bin_op_result);
        }
    }
};

// Work-group scan kernel (first phase of the two-pass scan). Each group scans
// __size_per_wg elements in __iters_per_wg chunks of __wgroup_size, carrying
// the running total between chunks in __adder, and finally publishes its
// group total to __wg_sums_acc for the global phase. _Inclusive selects
// inclusive vs exclusive output placement (__shift).
// NOTE: the barrier/statement order below is load-bearing — do not reorder.
template <typename _Inclusive, typename _ExecutionPolicy, typename _BinaryOperation, typename _UnaryOp,
          typename _WgAssigner, typename _GlobalAssigner, typename _DataAccessor, typename _InitType>
struct __scan
{
    _BinaryOperation __bin_op;   // scan combination operation
    _UnaryOp __unary_op;         // applied to each input element before scanning
    _WgAssigner __wg_assigner;   // writes the group total to __wg_sums_acc
    _GlobalAssigner __gl_assigner; // writes scanned values to the output
    _DataAccessor __data_acc;    // fetches one input element by global index

    template <typename _NDItemId, typename _Size, typename _AccLocal, typename _InAcc, typename _OutAcc,
              typename _WGSumsAcc, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
    void operator()(_NDItemId __item, _Size __n, _AccLocal& __local_acc, const _InAcc& __acc, _OutAcc& __out_acc,
                    _WGSumsAcc& __wg_sums_acc, _SizePerWG __size_per_wg, _WGSize __wgroup_size,
                    _ItersPerWG __iters_per_wg,
                    _InitType __init = __scan_no_init<typename _InitType::__value_type>{}) const
    {
        using _Tp = typename _InitType::__value_type;
        auto __group_id = __item.get_group(0);
        auto __global_id = __item.get_global_id(0);
        auto __local_id = __item.get_local_id(0);
        auto __use_init = __scan_init_processing<_Tp>{};

        // exclusive scan shifts the output right by one and seeds slot 0
        auto __shift = 0;
        __internal::__invoke_if_not(_Inclusive{}, [&]() {
            __shift = 1;
            if (__global_id == 0)
                __use_init(__init, __out_acc[__global_id]);
        });

        auto __adjusted_global_id = __local_id + __size_per_wg * __group_id;
        auto __adder = __local_acc[0];
        for (auto __iter = 0; __iter < __iters_per_wg; ++__iter, __adjusted_global_id += __wgroup_size)
        {
            if (__adjusted_global_id < __n)
            {
                // get input data
                __local_acc[__local_id] = __data_acc(__adjusted_global_id, __acc);
                // apply unary op
                __local_acc[__local_id] = __unary_op(__local_id, __local_acc);
            }
            // carry the previous chunk's total into this chunk's first slot
            if (__local_id == 0 && __iter > 0)
                __local_acc[0] = __bin_op(__adder, __local_acc[0]);
            else if (__global_id == 0)
                __use_init(__init, __local_acc[__global_id], __bin_op);

            // 1. reduce
            auto __k = 1;
            do
            {
                __item.barrier(sycl::access::fence_space::local_space);
                if (__local_id % (2 * __k) == 0 && __local_id + __k < __wgroup_size && __adjusted_global_id + __k < __n)
                {
                    __local_acc[__local_id + 2 * __k - 1] =
                        __bin_op(__local_acc[__local_id + __k - 1], __local_acc[__local_id + 2 * __k - 1]);
                }
                __k *= 2;
            } while (__k < __wgroup_size);
            __item.barrier(sycl::access::fence_space::local_space);

            // 2. scan
            auto __partial_sums = __local_acc[__local_id];
            __k = 2;
            do
            {
                auto __shifted_local_id = __local_id - __local_id % __k - 1;
                if (__shifted_local_id >= 0 && __adjusted_global_id < __n && __local_id % (2 * __k) >= __k &&
                    __local_id % (2 * __k) < 2 * __k - 1)
                {
                    __partial_sums = __bin_op(__local_acc[__shifted_local_id], __partial_sums);
                }
                __k *= 2;
            } while (__k < __wgroup_size);
            __item.barrier(sycl::access::fence_space::local_space);
            __local_acc[__local_id] = __partial_sums;
            __item.barrier(sycl::access::fence_space::local_space);
            // remember this chunk's total for the next iteration
            __adder = __local_acc[__wgroup_size - 1];
            __item.barrier(sycl::access::fence_space::local_space);

            if (__adjusted_global_id + __shift < __n)
                __gl_assigner(__acc, __out_acc, __adjusted_global_id + __shift, __local_acc, __local_id);

            if (__adjusted_global_id == __n - 1)
                __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
        }

        // publish the group total (last full slot) for the global phase
        if (__local_id == __wgroup_size - 1 && __adjusted_global_id - __wgroup_size < __n)
            __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
    }
};

#if _USE_GROUP_ALGOS

// SFINAE alias: well-formed only for arithmetic _Tp; used to constrain the
// group-algorithm specializations below to types sycl::ONEAPI can handle.
template <typename _Tp, typename = typename ::std::enable_if<::std::is_arithmetic<_Tp>::value, void>::type>
using __enable_if_arithmetic = _Tp;

// Same idea, but constrains on the init type's nested __value_type.
template <typename _InitType,
          typename =
              typename ::std::enable_if<::std::is_arithmetic<typename _InitType::__value_type>::value, void>::type>
using __enable_if_arithmetic_init_type = _InitType;

// Reduce on local memory with subgroups
// Specialization for ::std::plus over arithmetic types: delegates to the
// sycl::ONEAPI::reduce group algorithm instead of the manual tree reduction.
template <typename _ExecutionPolicy, typename _Tp>
struct reduce<_ExecutionPolicy, ::std::plus<_Tp>, __enable_if_arithmetic<_Tp>>
{
    ::std::plus<_Tp> __reduce;

    template <typename _NDItem, typename _GlobalIdx, typename _GlobalSize, typename _LocalAcc>
    _Tp
    operator()(_NDItem __item, _GlobalIdx __global_id, _GlobalSize __n, _LocalAcc __local_mem) const
    {
        auto __local_id = __item.get_local_id(0);
        if (__global_id >= __n)
        {
            // Fill the rest of local buffer with 0s so each of inclusive_scan method could correctly work
            // for each work-item in sub-group
            __local_mem[__local_id] = 0;
        }
        __item.barrier(sycl::access::fence_space::local_space);
        return sycl::ONEAPI::reduce(__item.get_group(), __local_mem[__local_id], sycl::ONEAPI::plus<_Tp>());
    }
};

// Specialization of the work-group scan for ::std::plus over arithmetic
// types: replaces the manual reduce+scan rounds with the
// sycl::ONEAPI::inclusive_scan group algorithm. Structure mirrors the generic
// __scan above (chunked iteration, carried __adder, group-sum publication).
template <typename _Inclusive, typename _ExecutionPolicy, typename _UnaryOp, typename _WgAssigner,
          typename _GlobalAssigner, typename _DataAccessor, typename _InitType>
struct __scan<_Inclusive, _ExecutionPolicy, ::std::plus<typename _InitType::__value_type>, _UnaryOp, _WgAssigner,
              _GlobalAssigner, _DataAccessor, __enable_if_arithmetic_init_type<_InitType>>
{
    using _Tp = typename _InitType::__value_type;
    sycl::ONEAPI::plus<_Tp> __bin_op;
    _UnaryOp __unary_op;
    _WgAssigner __wg_assigner;
    _GlobalAssigner __gl_assigner;
    _DataAccessor __data_acc;

    template <typename _NDItemId, typename _Size, typename _AccLocal, typename _InAcc, typename _OutAcc,
              typename _WGSumsAcc, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
    void operator()(_NDItemId __item, _Size __n, _AccLocal& __local_acc, const _InAcc& __acc, _OutAcc& __out_acc,
                    const _WGSumsAcc& __wg_sums_acc, _SizePerWG __size_per_wg, _WGSize __wgroup_size,
                    _ItersPerWG __iters_per_wg, _InitType __init = __scan_no_init<_Tp>{}) const
    {
        auto __group_id = __item.get_group(0);
        auto __global_id = __item.get_global_id(0);
        auto __local_id = __item.get_local_id(0);
        auto __use_init = __scan_init_processing<_Tp>{};

        // exclusive scan shifts the output right by one and seeds slot 0
        auto __shift = 0;
        __internal::__invoke_if_not(_Inclusive{}, [&]() {
            __shift = 1;
            if (__global_id == 0)
                __use_init(__init, __out_acc[__global_id]);
        });

        auto __adjusted_global_id = __local_id + __size_per_wg * __group_id;
        auto __adder = __local_acc[0];
        for (auto __iter = 0; __iter < __iters_per_wg; ++__iter, __adjusted_global_id += __wgroup_size)
        {
            if (__adjusted_global_id < __n)
                __local_acc[__local_id] = __data_acc(__adjusted_global_id, __acc);
            else
                __local_acc[__local_id] = _Tp{0}; // for plus only
            __item.barrier(sycl::access::fence_space::local_space);

            // the result of __unary_op must be convertible to _Tp
            _Tp __old_value = __unary_op(__local_id, __local_acc);
            // carry the previous chunk's total / fold in the initial value
            if (__iter > 0 && __local_id == 0)
                __old_value = __bin_op(__adder, __old_value);
            else if (__adjusted_global_id == 0)
                __use_init(__init, __old_value, __bin_op);
            __item.barrier(sycl::access::fence_space::local_space);

            __local_acc[__local_id] = sycl::ONEAPI::inclusive_scan(__item.get_group(), __old_value, __bin_op);
            __item.barrier(sycl::access::fence_space::local_space);

            // remember this chunk's total for the next iteration
            __adder = __local_acc[__wgroup_size - 1];
            __item.barrier(sycl::access::fence_space::local_space);

            if (__adjusted_global_id + __shift < __n)
                __gl_assigner(__acc, __out_acc, __adjusted_global_id + __shift, __local_acc, __local_id);

            if (__adjusted_global_id == __n - 1)
                __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
        }

        // publish the group total for the global phase
        if (__local_id == __wgroup_size - 1 && __adjusted_global_id - __wgroup_size < __n)
            __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
    }
};

#endif

//------------------------------------------------------------------------
// __brick_includes
//------------------------------------------------------------------------

// Per-element kernel for std::includes. NOTE the argument order of
// operator(): the SECOND sequence (__b, the candidate subset, length __nb)
// comes FIRST, then the first sequence (__a, length __na). Each item checks
// one element of __b; returning true means __a does NOT include __b at that
// element.
template <typename _ExecutionPolicy, typename _Compare, typename _Size1, typename _Size2>
struct __brick_includes
{
    _Compare __comp;
    _Size1 __na; // length of sequence __a (the would-be superset)
    _Size2 __nb; // length of sequence __b (the would-be subset)

    __brick_includes(_Compare __c, _Size1 __n1, _Size2 __n2) : __comp(__c), __na(__n1), __nb(__n2) {}

    template <typename _ItemId, typename _Acc1, typename _Acc2>
    bool
    operator()(_ItemId __idx, const _Acc1& __b_acc, const _Acc2& __a_acc) const
    {
        using ::std::get;

        auto __a = __a_acc;
        auto __b = __b_acc;

        auto __a_beg = _Size1(0);
        auto __a_end = __na;

        auto __b_beg = _Size2(0);
        auto __b_end = __nb;

        // testing __comp(*__first2, *__first1) or __comp(*(__last1 - 1), *(__last2 - 1))
        if ((__idx == 0 && __comp(__b[__b_beg + 0], __a[__a_beg + 0])) ||
            (__idx == __nb - 1 && __comp(__a[__a_end - 1], __b[__b_end - 1])))
            return true; //__a doesn't include __b

        // binary-search __b's element in __a
        const auto __idx_b = __b_beg + __idx;
        const auto __val_b = __b[__idx_b];
        auto __res = __internal::__pstl_lower_bound(__a, __a_beg, __a_end, __val_b, __comp);

        // {a} < {b} or __val_b != __a[__res]
        if (__res == __a_end || __comp(__val_b, __a[__res]))
            return true; //__a doesn't include __b

        auto __val_a = __a[__res];

        //searching number of duplication
        // (right_bound - res + res - left_bound == right_bound - left_bound;
        // kept in this split form as in the original)
        const auto __count_a = __internal::__pstl_right_bound(__a, __res, __a_end, __val_a, __comp) - __res + __res -
                               __internal::__pstl_left_bound(__a, __a_beg, __res, __val_a, __comp);

        const auto __count_b = __internal::__pstl_right_bound(__b, _Size2(__idx_b), __b_end, __val_b, __comp) -
                               __idx_b + __idx_b -
                               __internal::__pstl_left_bound(__b, __b_beg, _Size2(__idx_b), __val_b, __comp);

        // __b needs more copies of this value than __a has
        return __count_b > __count_a; //false means __a includes __b
    }
};

//------------------------------------------------------------------------
// reverse
//------------------------------------------------------------------------
template <typename _Size>
struct __reverse_functor
{
    _Size __size; // number of elements in the range being reversed

    // Swaps the element at __idx with its mirror element. Applying this to
    // every index of the first half of the range reverses it in place.
    template <typename _Idx, typename _Accessor>
    void
    operator()(const _Idx __idx, _Accessor& __acc) const
    {
        const auto __mirror_idx = __size - __idx - 1;
        ::std::swap(__acc[__idx], __acc[__mirror_idx]);
    }
};

//------------------------------------------------------------------------
// reverse_copy
//------------------------------------------------------------------------
template <typename _Size>
struct __reverse_copy
{
    _Size __size; // number of elements in the source range

    // Writes the mirror element of the source range into position __idx of
    // the destination: __acc2[__idx] = __acc1[__size - __idx - 1].
    template <typename _Idx, typename _AccessorSrc, typename _AccessorDst>
    void
    operator()(const _Idx __idx, const _AccessorSrc& __acc1, _AccessorDst& __acc2) const
    {
        const auto __src_idx = __size - __idx - 1;
        __acc2[__idx] = __acc1[__src_idx];
    }
};

//------------------------------------------------------------------------
// rotate_copy
//------------------------------------------------------------------------
template <typename _Size>
struct __rotate_copy
{
    _Size __size;  // number of elements in the source range
    _Size __shift; // rotation amount (new first element's source index)

    // Copies one element of the left-rotated view of the source:
    // __acc2[__idx] = __acc1[(__shift + __idx) mod __size].
    template <typename _Idx, typename _AccessorSrc, typename _AccessorDst>
    void
    operator()(const _Idx __idx, const _AccessorSrc& __acc1, _AccessorDst& __acc2) const
    {
        const auto __src_idx = (__shift + __idx) % __size;
        __acc2[__idx] = __acc1[__src_idx];
    }
};

//------------------------------------------------------------------------
// brick_set_op for difference and intersection operations
//------------------------------------------------------------------------
// Tag types selecting the set operation. They derive from true_type /
// false_type so a tag instance is also usable as a boolean constant
// ("is this a difference operation?") inside __brick_set_op.
struct _IntersectionTag : public ::std::false_type
{
};
struct _DifferenceTag : public ::std::true_type
{
};

// Per-element kernel computing the selection mask for set_difference /
// set_intersection. For each element of the first sequence it decides (via
// duplicate counting, to honor multiset semantics) whether that element
// belongs to the result, stores the decision in the mask buffer, and returns
// it.
template <typename _ExecutionPolicy, typename _Compare, typename _Size1, typename _Size2, typename _IsOpDifference>
class __brick_set_op
{
    _Compare __comp;
    _Size1 __na; // length of the first sequence
    _Size2 __nb; // length of the second sequence

  public:
    __brick_set_op(_Compare __c, _Size1 __n1, _Size2 __n2) : __comp(__c), __na(__n1), __nb(__n2) {}

    template <typename _ItemId, typename _Acc>
    bool
    operator()(_ItemId __idx, const _Acc& __inout_acc) const
    {
        using ::std::get;
        auto __a = get<0>(__inout_acc.tuple()); // first sequence
        auto __b = get<1>(__inout_acc.tuple()); // second sequence
        auto __c = get<2>(__inout_acc.tuple()); // mask buffer

        auto __a_beg = _Size1(0);
        auto __b_beg = _Size2(0);

        auto __idx_c = __idx;
        const auto __idx_a = __idx;
        auto __val_a = __a[__a_beg + __idx_a];

        auto __res = __internal::__pstl_lower_bound(__b, _Size2(0), __nb, __val_a, __comp);

        bool bres = _IsOpDifference(); //initialized to true for the difference operation; false for intersection
        if (__res == __nb || __comp(__val_a, __b[__b_beg + __res]))
        {
            // __val_a does not occur in __b at all, so it is in the difference {__a}/{__b};
            // the initialization above already encodes the right answer for both operations
        }
        else
        {
            auto __val_b = __b[__b_beg + __res];

            //Difference operation logic: if the number of duplicates of __val_a in __a up to and
            //including __idx exceeds the total number of duplicates in __b, then the mask is 1

            //Intersection operation logic: if the number of duplicates of __val_a in __a up to and
            //including __idx is <= the total number of duplicates in __b, then the mask is 1

            const _Size1 __count_a_left =
                __idx_a - __internal::__pstl_left_bound(__a, _Size1(0), _Size1(__idx_a), __val_a, __comp) + 1;

            const _Size2 __count_b = __internal::__pstl_right_bound(__b, _Size2(__res), __nb, __val_b, __comp) - __res +
                                     __res -
                                     __internal::__pstl_left_bound(__b, _Size2(0), _Size2(__res), __val_b, __comp);

            bres = __internal::__invoke_if_else(_IsOpDifference(),
                                                [&]() { return __count_a_left > __count_b; }, /*difference*/
                                                [&]() { return __count_a_left <= __count_b; } /*intersection*/);
        }
        __c[__idx_c] = bres; //store a mask
        return bres;
    }
};

template <typename _ExecutionPolicy, typename _DiffType>
struct __brick_shift_left
{
    _DiffType __size; // total number of elements in the range
    _DiffType __n;    // shift distance

    // Moves the stripe of elements handled by this item __n positions to the
    // left: element __k + __idx goes to __k + __idx - __n for every stride
    // step __k. Presumably one item per residue class modulo __n is launched,
    // so together the items shift the whole range left by __n.
    template <typename _ItemId, typename _Range>
    void
    operator()(const _ItemId __idx, _Range&& __rng) const
    {
        const _DiffType __dst_base = __idx - __n; // destination lags the source by __n (loop invariant)
        _DiffType __k = __n;
        while (__k < __size)
        {
            if (__k + __idx < __size)
                __rng[__k + __dst_base] = ::std::move(__rng[__k + __idx]);
            __k += __n;
        }
    }
};

} // namespace unseq_backend
} // namespace dpl
} // namespace oneapi

#endif /* _ONEDPL_unseq_backend_sycl_H */