Extended subgroups - use 128bit masks (#1215)

* Extended subgroups - use 128bit masks

* Refactoring to avoid kernel code duplication

* unify kernel names as the test_ prefix + the subgroup function name
* use string literals that improve readability
* use kernel templates that limit code duplication
* WorkGroupParams allows defining a default kernel - a kernel template shared by multiple functions
* WorkGroupParams allows defining a kernel for one specific subgroup function

Co-authored-by: Stuart Brady <stuart.brady@arm.com>
This commit is contained in:
Grzegorz Wawiorko
2021-10-01 12:28:37 +02:00
committed by GitHub
parent 903f1bf65d
commit 92844bead1
12 changed files with 592 additions and 1054 deletions

View File

@@ -17,13 +17,10 @@
#define SUBGROUPCOMMONTEMPLATES_H
#include "typeWrappers.h"
#include <bitset>
#include "CL/cl_half.h"
#include "subhelpers.h"
#include <set>
typedef std::bitset<128> bs128;
static cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
const std::string &mask_type,
cl_uint max_sub_group_size)
@@ -577,16 +574,21 @@ template <typename Ty, ArithmeticOp operation> struct SCEX_NU
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
uint32_t work_items_mask = test_params.work_items_mask;
ng = ng / nw;
std::string func_name;
work_items_mask ? func_name = "sub_group_non_uniform_scan_exclusive"
: func_name = "sub_group_scan_exclusive";
test_params.work_items_mask.any()
? func_name = "sub_group_non_uniform_scan_exclusive"
: func_name = "sub_group_scan_exclusive";
log_info(" %s_%s(%s)...\n", func_name.c_str(),
operation_names(operation), TypeManager<Ty>::name());
log_info(" test params: global size = %d local size = %d subgroups "
"size = %d work item mask = 0x%x \n",
test_params.global_workgroup_size, nw, ns, work_items_mask);
"size = %d \n",
test_params.global_workgroup_size, nw, ns);
if (test_params.work_items_mask.any())
{
log_info(" work items mask: %s\n",
test_params.work_items_mask.to_string().c_str());
}
genrand<Ty, operation>(x, t, m, ns, nw, ng);
}
@@ -597,18 +599,22 @@ template <typename Ty, ArithmeticOp operation> struct SCEX_NU
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
uint32_t work_items_mask = test_params.work_items_mask;
bs128 work_items_mask = test_params.work_items_mask;
int nj = (nw + ns - 1) / ns;
Ty tr, rr;
ng = ng / nw;
std::string func_name;
work_items_mask ? func_name = "sub_group_non_uniform_scan_exclusive"
: func_name = "sub_group_scan_exclusive";
test_params.work_items_mask.any()
? func_name = "sub_group_non_uniform_scan_exclusive"
: func_name = "sub_group_scan_exclusive";
uint32_t use_work_items_mask;
// for uniform case take into consideration all workitems
use_work_items_mask = !work_items_mask ? 0xFFFFFFFF : work_items_mask;
if (!work_items_mask.any())
{
work_items_mask.set();
}
for (k = 0; k < ng; ++k)
{ // for each work_group
// Map to array indexed by local ID and sub group
@@ -624,8 +630,7 @@ template <typename Ty, ArithmeticOp operation> struct SCEX_NU
std::set<int> active_work_items;
for (i = 0; i < n; ++i)
{
uint32_t check_work_item = 1 << (i % 32);
if (use_work_items_mask & check_work_item)
if (work_items_mask.test(i))
{
active_work_items.insert(i);
}
@@ -688,18 +693,23 @@ template <typename Ty, ArithmeticOp operation> struct SCIN_NU
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
uint32_t work_items_mask = test_params.work_items_mask;
ng = ng / nw;
std::string func_name;
work_items_mask ? func_name = "sub_group_non_uniform_scan_inclusive"
: func_name = "sub_group_scan_inclusive";
test_params.work_items_mask.any()
? func_name = "sub_group_non_uniform_scan_inclusive"
: func_name = "sub_group_scan_inclusive";
genrand<Ty, operation>(x, t, m, ns, nw, ng);
log_info(" %s_%s(%s)...\n", func_name.c_str(),
operation_names(operation), TypeManager<Ty>::name());
log_info(" test params: global size = %d local size = %d subgroups "
"size = %d work item mask = 0x%x \n",
test_params.global_workgroup_size, nw, ns, work_items_mask);
"size = %d \n",
test_params.global_workgroup_size, nw, ns);
if (test_params.work_items_mask.any())
{
log_info(" work items mask: %s\n",
test_params.work_items_mask.to_string().c_str());
}
}
static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
@@ -709,18 +719,22 @@ template <typename Ty, ArithmeticOp operation> struct SCIN_NU
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
uint32_t work_items_mask = test_params.work_items_mask;
bs128 work_items_mask = test_params.work_items_mask;
int nj = (nw + ns - 1) / ns;
Ty tr, rr;
ng = ng / nw;
std::string func_name;
work_items_mask ? func_name = "sub_group_non_uniform_scan_inclusive"
: func_name = "sub_group_scan_inclusive";
work_items_mask.any()
? func_name = "sub_group_non_uniform_scan_inclusive"
: func_name = "sub_group_scan_inclusive";
uint32_t use_work_items_mask;
// for uniform case take into consideration all workitems
use_work_items_mask = !work_items_mask ? 0xFFFFFFFF : work_items_mask;
if (!work_items_mask.any())
{
work_items_mask.set();
}
// std::bitset<32> mask32(use_work_items_mask);
// for (int k) mask32.count();
for (k = 0; k < ng; ++k)
@@ -740,8 +754,7 @@ template <typename Ty, ArithmeticOp operation> struct SCIN_NU
for (i = 0; i < n; ++i)
{
uint32_t check_work_item = 1 << (i % 32);
if (use_work_items_mask & check_work_item)
if (work_items_mask.test(i))
{
if (catch_frist_active == -1)
{
@@ -807,17 +820,22 @@ template <typename Ty, ArithmeticOp operation> struct RED_NU
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
uint32_t work_items_mask = test_params.work_items_mask;
ng = ng / nw;
std::string func_name;
work_items_mask ? func_name = "sub_group_non_uniform_reduce"
: func_name = "sub_group_reduce";
test_params.work_items_mask.any()
? func_name = "sub_group_non_uniform_reduce"
: func_name = "sub_group_reduce";
log_info(" %s_%s(%s)...\n", func_name.c_str(),
operation_names(operation), TypeManager<Ty>::name());
log_info(" test params: global size = %d local size = %d subgroups "
"size = %d work item mask = 0x%x \n",
test_params.global_workgroup_size, nw, ns, work_items_mask);
"size = %d \n",
test_params.global_workgroup_size, nw, ns);
if (test_params.work_items_mask.any())
{
log_info(" work items mask: %s\n",
test_params.work_items_mask.to_string().c_str());
}
genrand<Ty, operation>(x, t, m, ns, nw, ng);
}
@@ -828,14 +846,14 @@ template <typename Ty, ArithmeticOp operation> struct RED_NU
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
uint32_t work_items_mask = test_params.work_items_mask;
bs128 work_items_mask = test_params.work_items_mask;
int nj = (nw + ns - 1) / ns;
ng = ng / nw;
Ty tr, rr;
std::string func_name;
work_items_mask ? func_name = "sub_group_non_uniform_reduce"
: func_name = "sub_group_reduce";
work_items_mask.any() ? func_name = "sub_group_non_uniform_reduce"
: func_name = "sub_group_reduce";
for (k = 0; k < ng; ++k)
{
@@ -847,9 +865,10 @@ template <typename Ty, ArithmeticOp operation> struct RED_NU
my[j] = y[j];
}
uint32_t use_work_items_mask;
use_work_items_mask =
!work_items_mask ? 0xFFFFFFFF : work_items_mask;
if (!work_items_mask.any())
{
work_items_mask.set();
}
for (j = 0; j < nj; ++j)
{
@@ -859,8 +878,7 @@ template <typename Ty, ArithmeticOp operation> struct RED_NU
int catch_frist_active = -1;
for (i = 0; i < n; ++i)
{
uint32_t check_work_item = 1 << (i % 32);
if (use_work_items_mask & check_work_item)
if (work_items_mask.test(i))
{
if (catch_frist_active == -1)
{