Improve testing of sub_group_ballot (#1382)

Signed-off-by: Stuart Brady <stuart.brady@arm.com>
This commit is contained in:
Stuart Brady
2022-01-28 09:15:44 +00:00
committed by GitHub
parent 656886030b
commit 60471a5208
4 changed files with 159 additions and 67 deletions

View File

@@ -34,6 +34,12 @@ extern MTdata gMTdata;
typedef std::bitset<128> bs128;
extern cl_half_rounding_mode g_rounding_mode;
static bs128 cl_uint4_to_bs128(cl_uint4 v)
{
return bs128(v.s0) | (bs128(v.s1) << 32) | (bs128(v.s2) << 64)
| (bs128(v.s3) << 96);
}
static cl_uint4 bs128_to_cl_uint4(bs128 v)
{
bs128 bs128_ffffffff = 0xffffffffU;

View File

@@ -31,45 +31,26 @@ template <typename Ty> struct BALLOT
static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
{
// no work here
int gws = test_params.global_workgroup_size;
int lws = test_params.local_workgroup_size;
int sbs = test_params.subgroup_size;
int non_uniform_size = gws % lws;
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int wi_id, wg_id, sb_id;
int gws = test_params.global_workgroup_size;
int lws = test_params.local_workgroup_size;
int sbs = test_params.subgroup_size;
int sb_number = (lws + sbs - 1) / sbs;
int current_sbs = 0;
cl_uint expected_result, device_result;
int non_uniform_size = gws % lws;
int wg_number = gws / lws;
wg_number = non_uniform_size ? wg_number + 1 : wg_number;
int last_subgroup_size = 0;
for (wg_id = 0; wg_id < wg_number; ++wg_id)
for (int wg_id = 0; wg_id < wg_number; ++wg_id)
{ // for each work_group
if (non_uniform_size && wg_id == wg_number - 1)
{
set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
last_subgroup_size);
}
for (wi_id = 0; wi_id < lws; ++wi_id)
{ // inside the work_group
// read device outputs for work_group
my[wi_id] = y[wi_id];
}
for (sb_id = 0; sb_id < sb_number; ++sb_id)
for (int sb_id = 0; sb_id < sb_number; ++sb_id)
{ // for each subgroup
int wg_offset = sb_id * sbs;
int current_sbs;
if (last_subgroup_size && sb_id == sb_number - 1)
{
current_sbs = last_subgroup_size;
@@ -78,25 +59,121 @@ template <typename Ty> struct BALLOT
{
current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
}
for (wi_id = 0; wi_id < current_sbs; ++wi_id)
for (int wi_id = 0; wi_id < current_sbs; wi_id++)
{
device_result = my[wg_offset + wi_id];
expected_result = 1;
if (!compare(device_result, expected_result))
cl_uint v;
if (genrand_bool(gMTdata))
{
v = genrand_bool(gMTdata);
}
else if (genrand_bool(gMTdata))
{
v = 1U << ((genrand_int32(gMTdata) % 31) + 1);
}
else
{
v = genrand_int32(gMTdata);
}
cl_uint4 v4 = { v, 0, 0, 0 };
t[wi_id + wg_offset] = v4;
}
}
// Now map into work group using map from device
for (int wi_id = 0; wi_id < lws; ++wi_id)
{
x[wi_id] = t[wi_id];
}
x += lws;
m += 4 * lws;
}
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int gws = test_params.global_workgroup_size;
int lws = test_params.local_workgroup_size;
int sbs = test_params.subgroup_size;
int sb_number = (lws + sbs - 1) / sbs;
int non_uniform_size = gws % lws;
int wg_number = gws / lws;
wg_number = non_uniform_size ? wg_number + 1 : wg_number;
int last_subgroup_size = 0;
for (int wg_id = 0; wg_id < wg_number; ++wg_id)
{ // for each work_group
if (non_uniform_size && wg_id == wg_number - 1)
{
set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
last_subgroup_size);
}
for (int wi_id = 0; wi_id < lws; ++wi_id)
{ // inside the work_group
mx[wi_id] = x[wi_id]; // read host inputs for work_group
my[wi_id] = y[wi_id]; // read device outputs for work_group
}
for (int sb_id = 0; sb_id < sb_number; ++sb_id)
{ // for each subgroup
int wg_offset = sb_id * sbs;
int current_sbs;
if (last_subgroup_size && sb_id == sb_number - 1)
{
current_sbs = last_subgroup_size;
}
else
{
current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
}
bs128 expected_result_bs = 0;
std::set<int> active_work_items;
for (int wi_id = 0; wi_id < current_sbs; ++wi_id)
{
if (test_params.work_items_mask.test(wi_id))
{
bool predicate = (mx[wg_offset + wi_id].s0 != 0);
expected_result_bs |= (bs128(predicate) << wi_id);
active_work_items.insert(wi_id);
}
}
if (active_work_items.empty())
{
continue;
}
cl_uint4 expected_result =
bs128_to_cl_uint4(expected_result_bs);
for (const int &active_work_item : active_work_items)
{
int wi_id = active_work_item;
cl_uint4 device_result = my[wg_offset + wi_id];
bs128 device_result_bs = cl_uint4_to_bs128(device_result);
if (device_result_bs != expected_result_bs)
{
log_error(
"ERROR: sub_group_ballot mismatch for local id "
"%d in sub group %d in group %d obtained %d, "
"expected %d\n",
wi_id, sb_id, wg_id, device_result,
expected_result);
"%d in sub group %d in group %d obtained {%d, %d, "
"%d, %d}, expected {%d, %d, %d, %d}\n",
wi_id, sb_id, wg_id, device_result.s0,
device_result.s1, device_result.s2,
device_result.s3, expected_result.s0,
expected_result.s1, expected_result.s2,
expected_result.s3);
return TEST_FAIL;
}
}
}
x += lws;
y += lws;
m += 4 * lws;
}
return TEST_PASS;
}
};
@@ -724,27 +801,26 @@ __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type
}
)";
std::string sub_group_ballot_source = R"(
__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out) {
uint4 full_ballot = sub_group_ballot(1);
uint divergence_mask;
uint4 partial_ballot;
__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
uint gid = get_global_id(0);
XY(xy,gid);
if (get_sub_group_local_id() & 1) {
divergence_mask = 0xaaaaaaaa;
partial_ballot = sub_group_ballot(1);
} else {
divergence_mask = 0x55555555;
partial_ballot = sub_group_ballot(1);
uint subgroup_local_id = get_sub_group_local_id();
uint elect_work_item = 1 << (subgroup_local_id % 32);
uint work_item_mask;
if (subgroup_local_id < 32) {
work_item_mask = work_item_mask_vector.x;
} else if(subgroup_local_id < 64) {
work_item_mask = work_item_mask_vector.y;
} else if(subgroup_local_id < 96) {
work_item_mask = work_item_mask_vector.z;
} else if(subgroup_local_id < 128) {
work_item_mask = work_item_mask_vector.w;
}
size_t lws = get_local_size(0);
uint4 masked_ballot = full_ballot;
masked_ballot.x &= divergence_mask;
masked_ballot.y &= divergence_mask;
masked_ballot.z &= divergence_mask;
masked_ballot.w &= divergence_mask;
out[gid] = all(masked_ballot == partial_ballot);
uint4 value = (uint4)(0, 0, 0, 0);
if (elect_work_item & work_item_mask) {
value = sub_group_ballot(in[gid].s0);
}
out[gid] = value;
}
)";
std::string sub_group_inverse_ballot_source = R"(
@@ -952,42 +1028,47 @@ int test_subgroup_functions_ballot(cl_device_id device, cl_context context,
error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::lt_mask>>(
"get_sub_group_lt_mask");
// ballot functions
WorkGroupParams test_params_ballot(global_work_size, local_work_size);
test_params_ballot.save_kernel_source(
sub_group_ballot_bit_scan_find_source);
test_params_ballot.save_kernel_source(sub_group_ballot_source,
"sub_group_ballot");
test_params_ballot.save_kernel_source(sub_group_inverse_ballot_source,
"sub_group_inverse_ballot");
test_params_ballot.save_kernel_source(sub_group_ballot_bit_extract_source,
"sub_group_ballot_bit_extract");
// sub_group_ballot function
WorkGroupParams test_params_ballot(global_work_size, local_work_size, 3);
test_params_ballot.save_kernel_source(sub_group_ballot_source);
RunTestForType rft_ballot(device, context, queue, num_elements,
test_params_ballot);
error |= rft_ballot.run_impl<cl_uint, BALLOT<cl_uint>>("sub_group_ballot");
error |=
rft_ballot.run_impl<cl_uint4,
BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
rft_ballot.run_impl<cl_uint4, BALLOT<cl_uint4>>("sub_group_ballot");
// ballot arithmetic functions
WorkGroupParams test_params_arith(global_work_size, local_work_size);
test_params_arith.save_kernel_source(sub_group_ballot_bit_scan_find_source);
test_params_arith.save_kernel_source(sub_group_inverse_ballot_source,
"sub_group_inverse_ballot");
test_params_arith.save_kernel_source(sub_group_ballot_bit_extract_source,
"sub_group_ballot_bit_extract");
RunTestForType rft_arith(device, context, queue, num_elements,
test_params_arith);
error |=
rft_arith.run_impl<cl_uint4,
BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
"sub_group_inverse_ballot");
error |= rft_ballot.run_impl<
error |= rft_arith.run_impl<
cl_uint4, BALLOT_BIT_EXTRACT<cl_uint4, BallotOp::ballot_bit_extract>>(
"sub_group_ballot_bit_extract");
error |= rft_ballot.run_impl<
error |= rft_arith.run_impl<
cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_bit_count>>(
"sub_group_ballot_bit_count");
error |= rft_ballot.run_impl<
error |= rft_arith.run_impl<
cl_uint4,
BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_inclusive_scan>>(
"sub_group_ballot_inclusive_scan");
error |= rft_ballot.run_impl<
error |= rft_arith.run_impl<
cl_uint4,
BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_exclusive_scan>>(
"sub_group_ballot_exclusive_scan");
error |= rft_ballot.run_impl<
error |= rft_arith.run_impl<
cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_lsb>>(
"sub_group_ballot_find_lsb");
error |= rft_ballot.run_impl<
error |= rft_arith.run_impl<
cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_msb>>(
"sub_group_ballot_find_msb");
return error;
}