Extended subgroups - use 128bit masks (#1215)

* Extended subgroups - use 128bit masks

* Refactoring to avoid kernel code duplication

* unify kernel names as the test_ prefix + the subgroup function name
* use raw string literals to improve readability
* use kernel templates to limit code duplication
* WorkGroupParams allows defining a default kernel - a kernel template for multiple functions
* WorkGroupParams allows defining a kernel for one specific subgroup function

Co-authored-by: Stuart Brady <stuart.brady@arm.com>
This commit is contained in:
Grzegorz Wawiorko
2021-10-01 12:28:37 +02:00
committed by GitHub
parent 903f1bf65d
commit 92844bead1
12 changed files with 592 additions and 1054 deletions

View File

@@ -684,239 +684,127 @@ template <typename Ty, BallotOp operation> struct SMASK
}
};
// OpenCL C source for the kernel test_bcast_non_uniform.
// Each work-item reads in[gid] and writes to out[gid] the result of
// sub_group_non_uniform_broadcast(x, index), where index is xy[gid].z when
// xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS and xy[gid].w otherwise.
// Type, XY and NR_OF_ACTIVE_WORK_ITEMS are not defined in this string;
// presumably they are supplied when the program is built - TODO confirm.
static const char *bcast_non_uniform_source =
"__kernel void test_bcast_non_uniform(const __global Type *in, __global "
"int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {\n"
" out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z);\n"
" } else {\n"
" out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w);\n"
" }\n"
"}\n";
// OpenCL C source for the kernel test_sub_group_non_uniform_broadcast.
// Each work-item reads in[gid] and writes to out[gid] the result of
// sub_group_non_uniform_broadcast(x, index), where index is xy[gid].z when
// xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS and xy[gid].w otherwise.
// Type, XY and NR_OF_ACTIVE_WORK_ITEMS are not defined in this string;
// presumably they are supplied when the program is built - TODO confirm.
std::string sub_group_non_uniform_broadcast_source = R"(
__kernel void test_sub_group_non_uniform_broadcast(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {
out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z);
} else {
out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w);
}
}
)";
// OpenCL C source for the kernel test_sub_group_broadcast_first.
// Each work-item reads in[gid] and writes sub_group_broadcast_first(x) to
// out[gid]. The call appears in both arms of a branch on
// xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS; the two arms are textually identical
// (presumably so the call is made under divergent control flow - confirm).
// NOTE(review): each assignment ends in ";;" - the second ';' looks like a
// leftover and is worth removing once confirmed to be unintentional.
std::string sub_group_broadcast_first_source = R"(
__kernel void test_sub_group_broadcast_first(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {
out[gid] = sub_group_broadcast_first(x);;
} else {
out[gid] = sub_group_broadcast_first(x);;
}
}
)";
// Kernel source template with two "%s" placeholders: the first completes the
// kernel name after the "test_" prefix, the second names the function that is
// called with x = in[gid]. The call result becomes the first component of a
// uint4 whose remaining components are 0, and that vector is stored in
// out[gid].
// How the placeholders are filled in is not visible here; presumably both
// receive the sub-group function name of the test being run - TODO confirm.
std::string sub_group_ballot_bit_scan_find_source = R"(
__kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
uint4 value = (uint4)(0,0,0,0);
value = (uint4)(%s(x),0,0,0);
out[gid] = value;
}
)";
// Kernel source template with two "%s" placeholders: the first completes the
// kernel name after the "test_" prefix, the second names a zero-argument
// function whose uint4 result is stored in out[gid].
// The kernel also writes get_max_sub_group_size() to xy[gid].z. The value x
// is loaded from in[gid] but is not used afterwards.
// How the placeholders are filled in is not visible here; presumably both
// receive the mask function name of the test being run - TODO confirm.
std::string sub_group_ballot_mask_source = R"(
__kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
xy[gid].z = get_max_sub_group_size();
Type x = in[gid];
uint4 mask = %s();
out[gid] = mask;
}
)";
// OpenCL C source for the kernel test_sub_group_ballot.
// The kernel takes sub_group_ballot(1) once before branching (full_ballot)
// and once inside each arm of a branch on the parity of
// get_sub_group_local_id() (partial_ballot). Odd ids use divergence_mask
// 0xaaaaaaaa, even ids use 0x55555555. Every component of full_ballot is
// ANDed with divergence_mask and out[gid] receives
// all(masked_ballot == partial_ballot). The local lws is computed but unused.
// NOTE(review): as written, this raw string literal does not close until the
// final `)";` below, so the `static const char *..._source` definitions that
// follow the `out[gid] = all(...)` statement are textually part of the string.
// They look like the pre-refactoring kernel sources left over from the change
// - confirm and remove them so the literal holds only the kernel above.
std::string sub_group_ballot_source = R"(
__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out) {
uint4 full_ballot = sub_group_ballot(1);
uint divergence_mask;
uint4 partial_ballot;
uint gid = get_global_id(0);
XY(xy,gid);
if (get_sub_group_local_id() & 1) {
divergence_mask = 0xaaaaaaaa;
partial_ballot = sub_group_ballot(1);
} else {
divergence_mask = 0x55555555;
partial_ballot = sub_group_ballot(1);
}
size_t lws = get_local_size(0);
uint4 masked_ballot = full_ballot;
masked_ballot.x &= divergence_mask;
masked_ballot.y &= divergence_mask;
masked_ballot.z &= divergence_mask;
masked_ballot.w &= divergence_mask;
out[gid] = all(masked_ballot == partial_ballot);
static const char *bcast_first_source =
"__kernel void test_bcast_first(const __global Type *in, __global int4 "
"*xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {\n"
" out[gid] = sub_group_broadcast_first(x);\n"
" } else {\n"
" out[gid] = sub_group_broadcast_first(x);\n"
" }\n"
"}\n";
static const char *ballot_bit_count_source =
"__kernel void test_sub_group_ballot_bit_count(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint4 value = (uint4)(0,0,0,0);\n"
" value = (uint4)(sub_group_ballot_bit_count(x),0,0,0);\n"
" out[gid] = value;\n"
"}\n";
static const char *ballot_inclusive_scan_source =
"__kernel void test_sub_group_ballot_inclusive_scan(const __global Type "
"*in, __global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint4 value = (uint4)(0,0,0,0);\n"
" value = (uint4)(sub_group_ballot_inclusive_scan(x),0,0,0);\n"
" out[gid] = value;\n"
"}\n";
static const char *ballot_exclusive_scan_source =
"__kernel void test_sub_group_ballot_exclusive_scan(const __global Type "
"*in, __global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint4 value = (uint4)(0,0,0,0);\n"
" value = (uint4)(sub_group_ballot_exclusive_scan(x),0,0,0);\n"
" out[gid] = value;\n"
"}\n";
static const char *ballot_find_lsb_source =
"__kernel void test_sub_group_ballot_find_lsb(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint4 value = (uint4)(0,0,0,0);\n"
" value = (uint4)(sub_group_ballot_find_lsb(x),0,0,0);\n"
" out[gid] = value;\n"
"}\n";
static const char *ballot_find_msb_source =
"__kernel void test_sub_group_ballot_find_msb(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint4 value = (uint4)(0,0,0,0);"
" value = (uint4)(sub_group_ballot_find_msb(x),0,0,0);"
" out[gid] = value ;"
"}\n";
static const char *get_subgroup_ge_mask_source =
"__kernel void test_get_sub_group_ge_mask(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" xy[gid].z = get_max_sub_group_size();\n"
" Type x = in[gid];\n"
" uint4 mask = get_sub_group_ge_mask();"
" out[gid] = mask;\n"
"}\n";
static const char *get_subgroup_gt_mask_source =
"__kernel void test_get_sub_group_gt_mask(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" xy[gid].z = get_max_sub_group_size();\n"
" Type x = in[gid];\n"
" uint4 mask = get_sub_group_gt_mask();"
" out[gid] = mask;\n"
"}\n";
static const char *get_subgroup_le_mask_source =
"__kernel void test_get_sub_group_le_mask(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" xy[gid].z = get_max_sub_group_size();\n"
" Type x = in[gid];\n"
" uint4 mask = get_sub_group_le_mask();"
" out[gid] = mask;\n"
"}\n";
static const char *get_subgroup_lt_mask_source =
"__kernel void test_get_sub_group_lt_mask(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" xy[gid].z = get_max_sub_group_size();\n"
" Type x = in[gid];\n"
" uint4 mask = get_sub_group_lt_mask();"
" out[gid] = mask;\n"
"}\n";
static const char *get_subgroup_eq_mask_source =
"__kernel void test_get_sub_group_eq_mask(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" xy[gid].z = get_max_sub_group_size();\n"
" Type x = in[gid];\n"
" uint4 mask = get_sub_group_eq_mask();"
" out[gid] = mask;\n"
"}\n";
static const char *ballot_source =
"__kernel void test_sub_group_ballot(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
"uint4 full_ballot = sub_group_ballot(1);\n"
"uint divergence_mask;\n"
"uint4 partial_ballot;\n"
"uint gid = get_global_id(0);"
"XY(xy,gid);\n"
"if (get_sub_group_local_id() & 1) {\n"
" divergence_mask = 0xaaaaaaaa;\n"
" partial_ballot = sub_group_ballot(1);\n"
"} else {\n"
" divergence_mask = 0x55555555;\n"
" partial_ballot = sub_group_ballot(1);\n"
"}\n"
" size_t lws = get_local_size(0);\n"
"uint4 masked_ballot = full_ballot;\n"
"masked_ballot.x &= divergence_mask;\n"
"masked_ballot.y &= divergence_mask;\n"
"masked_ballot.z &= divergence_mask;\n"
"masked_ballot.w &= divergence_mask;\n"
"out[gid] = all(masked_ballot == partial_ballot);\n"
"} \n";
static const char *ballot_source_inverse =
"__kernel void test_sub_group_ballot_inverse(const __global "
"Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint4 value = (uint4)(10,0,0,0);\n"
" if (get_sub_group_local_id() & 1) {"
" uint4 partial_ballot_mask = "
"(uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA);"
" if (sub_group_inverse_ballot(partial_ballot_mask)) {\n"
" value = (uint4)(1,0,0,1);\n"
" } else {\n"
" value = (uint4)(0,0,0,1);\n"
" }\n"
" } else {\n"
" uint4 partial_ballot_mask = "
"(uint4)(0x55555555,0x55555555,0x55555555,0x55555555);"
" if (sub_group_inverse_ballot(partial_ballot_mask)) {\n"
" value = (uint4)(1,0,0,2);\n"
" } else {\n"
" value = (uint4)(0,0,0,2);\n"
" }\n"
" }\n"
" out[gid] = value;\n"
"}\n";
static const char *ballot_bit_extract_source =
"__kernel void test_sub_group_ballot_bit_extract(const __global Type *in, "
"__global int4 *xy, __global Type *out)\n"
"{\n"
" int gid = get_global_id(0);\n"
" XY(xy,gid);\n"
" Type x = in[gid];\n"
" uint index = xy[gid].z;\n"
" uint4 value = (uint4)(10,0,0,0);\n"
" if (get_sub_group_local_id() & 1) {"
" if (sub_group_ballot_bit_extract(x, xy[gid].z)) {\n"
" value = (uint4)(1,0,0,1);\n"
" } else {\n"
" value = (uint4)(0,0,0,1);\n"
" }\n"
" } else {\n"
" if (sub_group_ballot_bit_extract(x, xy[gid].w)) {\n"
" value = (uint4)(1,0,0,2);\n"
" } else {\n"
" value = (uint4)(0,0,0,2);\n"
" }\n"
" }\n"
" out[gid] = value;\n"
"}\n";
}
)";
// OpenCL C source for the kernel test_sub_group_inverse_ballot.
// Work-items with an odd get_sub_group_local_id() call
// sub_group_inverse_ballot with every mask component set to 0xAAAAAAAA;
// the others use 0x55555555. out[gid] receives a uint4 whose first component
// is 1 when the call returned true and 0 otherwise, and whose last component
// is 1 for the odd arm and 2 for the even arm. The initial value
// (10,0,0,0) is always overwritten by one of the four assignments.
// The value x is loaded from in[gid] but is not used afterwards.
std::string sub_group_inverse_ballot_source = R"(
__kernel void test_sub_group_inverse_ballot(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
uint4 value = (uint4)(10,0,0,0);
if (get_sub_group_local_id() & 1) {
uint4 partial_ballot_mask = (uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA);
if (sub_group_inverse_ballot(partial_ballot_mask)) {
value = (uint4)(1,0,0,1);
} else {
value = (uint4)(0,0,0,1);
}
} else {
uint4 partial_ballot_mask = (uint4)(0x55555555,0x55555555,0x55555555,0x55555555);
if (sub_group_inverse_ballot(partial_ballot_mask)) {
value = (uint4)(1,0,0,2);
} else {
value = (uint4)(0,0,0,2);
}
}
out[gid] = value;
}
)";
// OpenCL C source for the kernel test_sub_group_ballot_bit_extract.
// Work-items with an odd get_sub_group_local_id() call
// sub_group_ballot_bit_extract(x, xy[gid].z); the others pass xy[gid].w as
// the second argument. out[gid] receives a uint4 whose first component is 1
// when the call returned true and 0 otherwise, and whose last component is 1
// for the odd arm and 2 for the even arm. The initial value (10,0,0,0) is
// always overwritten by one of the four assignments.
// The local `index` is assigned xy[gid].z but is not used afterwards.
std::string sub_group_ballot_bit_extract_source = R"(
__kernel void test_sub_group_ballot_bit_extract(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
uint index = xy[gid].z;
uint4 value = (uint4)(10,0,0,0);
if (get_sub_group_local_id() & 1) {
if (sub_group_ballot_bit_extract(x, xy[gid].z)) {
value = (uint4)(1,0,0,1);
} else {
value = (uint4)(0,0,0,1);
}
} else {
if (sub_group_ballot_bit_extract(x, xy[gid].w)) {
value = (uint4)(1,0,0,2);
} else {
value = (uint4)(0,0,0,2);
}
}
out[gid] = value;
}
)";
// Runs the BC<T, SubgroupsBroadcastOp::non_uniform_broadcast> test for element
// type T through rft.run_impl and returns the resulting error code.
// NOTE(review): the run_impl call below carries two argument lines - a
// (kernel name, source) pair and a lone "sub_group_non_uniform_broadcast"
// name. They look like the before and after forms of the same call; only one
// of them can remain for this to compile - confirm which one is intended.
template <typename T> int run_non_uniform_broadcast_for_type(RunTestForType rft)
{
int error =
rft.run_impl<T, BC<T, SubgroupsBroadcastOp::non_uniform_broadcast>>(
"test_bcast_non_uniform", bcast_non_uniform_source);
"sub_group_non_uniform_broadcast");
return error;
}
@@ -932,9 +820,15 @@ int test_subgroup_functions_ballot(cl_device_id device, cl_context context,
"skipping test.\n");
return TEST_SKIPPED_ITSELF;
}
constexpr size_t global_work_size = 170;
constexpr size_t local_work_size = 64;
WorkGroupParams test_params(global_work_size, local_work_size);
test_params.save_kernel_source(sub_group_ballot_mask_source);
test_params.save_kernel_source(sub_group_non_uniform_broadcast_source,
"sub_group_non_uniform_broadcast");
test_params.save_kernel_source(sub_group_broadcast_first_source,
"sub_group_broadcast_first");
RunTestForType rft(device, context, queue, num_elements, test_params);
// non uniform broadcast functions
@@ -1018,76 +912,87 @@ int test_subgroup_functions_ballot(cl_device_id device, cl_context context,
// broadcast first functions
error |=
rft.run_impl<cl_int, BC<cl_int, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_uint,
BC<cl_uint, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_long,
BC<cl_long, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_ulong,
BC<cl_ulong, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_short,
BC<cl_short, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_ushort,
BC<cl_ushort, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_char,
BC<cl_char, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_uchar,
BC<cl_uchar, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_float,
BC<cl_float, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<cl_double,
BC<cl_double, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
error |= rft.run_impl<
subgroups::cl_half,
BC<subgroups::cl_half, SubgroupsBroadcastOp::broadcast_first>>(
"test_bcast_first", bcast_first_source);
"sub_group_broadcast_first");
// mask functions
error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::eq_mask>>(
"test_get_sub_group_eq_mask", get_subgroup_eq_mask_source);
"get_sub_group_eq_mask");
error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::ge_mask>>(
"test_get_sub_group_ge_mask", get_subgroup_ge_mask_source);
"get_sub_group_ge_mask");
error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::gt_mask>>(
"test_get_sub_group_gt_mask", get_subgroup_gt_mask_source);
"get_sub_group_gt_mask");
error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::le_mask>>(
"test_get_sub_group_le_mask", get_subgroup_le_mask_source);
"get_sub_group_le_mask");
error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::lt_mask>>(
"test_get_sub_group_lt_mask", get_subgroup_lt_mask_source);
"get_sub_group_lt_mask");
// ballot functions
error |= rft.run_impl<cl_uint, BALLOT<cl_uint>>("test_sub_group_ballot",
ballot_source);
error |= rft.run_impl<cl_uint4,
BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
"test_sub_group_ballot_inverse", ballot_source_inverse);
error |= rft.run_impl<
WorkGroupParams test_params_ballot(global_work_size, local_work_size);
test_params_ballot.save_kernel_source(
sub_group_ballot_bit_scan_find_source);
test_params_ballot.save_kernel_source(sub_group_ballot_source,
"sub_group_ballot");
test_params_ballot.save_kernel_source(sub_group_inverse_ballot_source,
"sub_group_inverse_ballot");
test_params_ballot.save_kernel_source(sub_group_ballot_bit_extract_source,
"sub_group_ballot_bit_extract");
RunTestForType rft_ballot(device, context, queue, num_elements,
test_params_ballot);
error |= rft_ballot.run_impl<cl_uint, BALLOT<cl_uint>>("sub_group_ballot");
error |=
rft_ballot.run_impl<cl_uint4,
BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
"sub_group_inverse_ballot");
error |= rft_ballot.run_impl<
cl_uint4, BALLOT_BIT_EXTRACT<cl_uint4, BallotOp::ballot_bit_extract>>(
"test_sub_group_ballot_bit_extract", ballot_bit_extract_source);
error |= rft.run_impl<
"sub_group_ballot_bit_extract");
error |= rft_ballot.run_impl<
cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_bit_count>>(
"test_sub_group_ballot_bit_count", ballot_bit_count_source);
error |= rft.run_impl<
"sub_group_ballot_bit_count");
error |= rft_ballot.run_impl<
cl_uint4,
BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_inclusive_scan>>(
"test_sub_group_ballot_inclusive_scan", ballot_inclusive_scan_source);
error |= rft.run_impl<
"sub_group_ballot_inclusive_scan");
error |= rft_ballot.run_impl<
cl_uint4,
BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_exclusive_scan>>(
"test_sub_group_ballot_exclusive_scan", ballot_exclusive_scan_source);
error |= rft.run_impl<
"sub_group_ballot_exclusive_scan");
error |= rft_ballot.run_impl<
cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_lsb>>(
"test_sub_group_ballot_find_lsb", ballot_find_lsb_source);
error |= rft.run_impl<
"sub_group_ballot_find_lsb");
error |= rft_ballot.run_impl<
cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_msb>>(
"test_sub_group_ballot_find_msb", ballot_find_msb_source);
"sub_group_ballot_find_msb");
return error;
}