diff --git a/test_common/harness/mt19937.cpp b/test_common/harness/mt19937.cpp index c32d9bac..f5665deb 100644 --- a/test_common/harness/mt19937.cpp +++ b/test_common/harness/mt19937.cpp @@ -277,3 +277,5 @@ double genrand_res53(MTdata d) unsigned long a = genrand_int32(d) >> 5, b = genrand_int32(d) >> 6; return (a * 67108864.0 + b) * (1.0 / 9007199254740992.0); } + +bool genrand_bool(MTdata d) { return ((cl_uint)genrand_int32(d) & 1); } diff --git a/test_common/harness/mt19937.h b/test_common/harness/mt19937.h index 35c84933..98eec843 100644 --- a/test_common/harness/mt19937.h +++ b/test_common/harness/mt19937.h @@ -90,6 +90,9 @@ double genrand_res53(MTdata /*data*/); #ifdef __cplusplus +/* generates a random boolean */ +bool genrand_bool(MTdata /*data*/); + #include struct MTdataHolder diff --git a/test_conformance/subgroups/subhelpers.h b/test_conformance/subgroups/subhelpers.h index aa4abc96..153045d0 100644 --- a/test_conformance/subgroups/subhelpers.h +++ b/test_conformance/subgroups/subhelpers.h @@ -34,6 +34,12 @@ extern MTdata gMTdata; typedef std::bitset<128> bs128; extern cl_half_rounding_mode g_rounding_mode; +static bs128 cl_uint4_to_bs128(cl_uint4 v) +{ + return bs128(v.s0) | (bs128(v.s1) << 32) | (bs128(v.s2) << 64) + | (bs128(v.s3) << 96); +} + static cl_uint4 bs128_to_cl_uint4(bs128 v) { bs128 bs128_ffffffff = 0xffffffffU; diff --git a/test_conformance/subgroups/test_subgroup_ballot.cpp b/test_conformance/subgroups/test_subgroup_ballot.cpp index 837988ea..4148707e 100644 --- a/test_conformance/subgroups/test_subgroup_ballot.cpp +++ b/test_conformance/subgroups/test_subgroup_ballot.cpp @@ -31,45 +31,26 @@ template struct BALLOT static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { - // no work here - int gws = test_params.global_workgroup_size; - int lws = test_params.local_workgroup_size; - int sbs = test_params.subgroup_size; - int non_uniform_size = gws % lws; - } - - static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, - const WorkGroupParams &test_params) - { - int wi_id, wg_id, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; - int current_sbs = 0; - cl_uint expected_result, device_result; int non_uniform_size = gws % lws; int wg_number = gws / lws; wg_number = non_uniform_size ? wg_number + 1 : wg_number; int last_subgroup_size = 0; - for (wg_id = 0; wg_id < wg_number; ++wg_id) + for (int wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group if (non_uniform_size && wg_id == wg_number - 1) { set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, last_subgroup_size); } - - for (wi_id = 0; wi_id < lws; ++wi_id) - { // inside the work_group - // read device outputs for work_group - my[wi_id] = y[wi_id]; - } - - for (sb_id = 0; sb_id < sb_number; ++sb_id) + for (int sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; + int current_sbs; if (last_subgroup_size && sb_id == sb_number - 1) { current_sbs = last_subgroup_size; @@ -78,25 +59,121 @@ template struct BALLOT { current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; } - for (wi_id = 0; wi_id < current_sbs; ++wi_id) + + for (int wi_id = 0; wi_id < current_sbs; wi_id++) { - device_result = my[wg_offset + wi_id]; - expected_result = 1; - if (!compare(device_result, expected_result)) + cl_uint v; + if (genrand_bool(gMTdata)) + { + v = genrand_bool(gMTdata); + } + else if (genrand_bool(gMTdata)) + { + v = 1U << ((genrand_int32(gMTdata) % 31) + 1); + } + else + { + v = genrand_int32(gMTdata); + } + cl_uint4 v4 = { v, 0, 0, 0 }; + t[wi_id + wg_offset] = v4; + } + } + // Now map into work group using map from device + for (int wi_id = 0; wi_id < lws; ++wi_id) + { + x[wi_id] = t[wi_id]; + } + x += lws; + m += 4 * lws; + } + } + + static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, + const WorkGroupParams &test_params) + { + int gws = test_params.global_workgroup_size; + int lws = test_params.local_workgroup_size; + int sbs = test_params.subgroup_size; + int sb_number = (lws + sbs - 1) / sbs; + int non_uniform_size = gws % lws; + int wg_number = gws / lws; + wg_number = non_uniform_size ? wg_number + 1 : wg_number; + int last_subgroup_size = 0; + + for (int wg_id = 0; wg_id < wg_number; ++wg_id) + { // for each work_group + if (non_uniform_size && wg_id == wg_number - 1) + { + set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, + last_subgroup_size); + } + for (int wi_id = 0; wi_id < lws; ++wi_id) + { // inside the work_group + mx[wi_id] = x[wi_id]; // read host inputs for work_group + my[wi_id] = y[wi_id]; // read device outputs for work_group + } + + for (int sb_id = 0; sb_id < sb_number; ++sb_id) + { // for each subgroup + int wg_offset = sb_id * sbs; + int current_sbs; + if (last_subgroup_size && sb_id == sb_number - 1) + { + current_sbs = last_subgroup_size; + } + else + { + current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; + } + + bs128 expected_result_bs = 0; + + std::set active_work_items; + for (int wi_id = 0; wi_id < current_sbs; ++wi_id) + { + if (test_params.work_items_mask.test(wi_id)) + { + bool predicate = (mx[wg_offset + wi_id].s0 != 0); + expected_result_bs |= (bs128(predicate) << wi_id); + active_work_items.insert(wi_id); + } + } + if (active_work_items.empty()) + { + continue; + } + + cl_uint4 expected_result = + bs128_to_cl_uint4(expected_result_bs); + for (const int &active_work_item : active_work_items) + { + int wi_id = active_work_item; + + cl_uint4 device_result = my[wg_offset + wi_id]; + bs128 device_result_bs = cl_uint4_to_bs128(device_result); + + if (device_result_bs != expected_result_bs) { log_error( "ERROR: sub_group_ballot mismatch for local id " - "%d in sub group %d in group %d obtained %d, " - "expected %d\n", - wi_id, sb_id, wg_id, device_result, - expected_result); + "%d in sub group %d in group %d obtained {%d, %d, " + "%d, %d}, expected {%d, %d, %d, %d}\n", + wi_id, sb_id, wg_id, device_result.s0, + device_result.s1, device_result.s2, + device_result.s3, expected_result.s0, + expected_result.s1, expected_result.s2, + expected_result.s3); return TEST_FAIL; } } } + + x += lws; y += lws; m += 4 * lws; } + return TEST_PASS; } }; @@ -724,27 +801,26 @@ __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type } )"; std::string sub_group_ballot_source = R"( -__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out) { - uint4 full_ballot = sub_group_ballot(1); - uint divergence_mask; - uint4 partial_ballot; +__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) { uint gid = get_global_id(0); XY(xy,gid); - if (get_sub_group_local_id() & 1) { - divergence_mask = 0xaaaaaaaa; - partial_ballot = sub_group_ballot(1); - } else { - divergence_mask = 0x55555555; - partial_ballot = sub_group_ballot(1); + uint subgroup_local_id = get_sub_group_local_id(); + uint elect_work_item = 1 << (subgroup_local_id % 32); + uint work_item_mask; + if (subgroup_local_id < 32) { + work_item_mask = work_item_mask_vector.x; + } else if(subgroup_local_id < 64) { + work_item_mask = work_item_mask_vector.y; + } else if(subgroup_local_id < 96) { + work_item_mask = work_item_mask_vector.z; + } else if(subgroup_local_id < 128) { + work_item_mask = work_item_mask_vector.w; } - size_t lws = get_local_size(0); - uint4 masked_ballot = full_ballot; - masked_ballot.x &= divergence_mask; - masked_ballot.y &= divergence_mask; - masked_ballot.z &= divergence_mask; - masked_ballot.w &= divergence_mask; - out[gid] = all(masked_ballot == partial_ballot); - + uint4 value = (uint4)(0, 0, 0, 0); + if (elect_work_item & work_item_mask) { + value = sub_group_ballot(in[gid].s0); + } + out[gid] = value; } )"; std::string sub_group_inverse_ballot_source = R"( @@ -952,42 +1028,47 @@ int test_subgroup_functions_ballot(cl_device_id device, cl_context context, error |= rft.run_impl>( "get_sub_group_lt_mask"); - // ballot functions - WorkGroupParams test_params_ballot(global_work_size, local_work_size); - test_params_ballot.save_kernel_source( - sub_group_ballot_bit_scan_find_source); - test_params_ballot.save_kernel_source(sub_group_ballot_source, - "sub_group_ballot"); - test_params_ballot.save_kernel_source(sub_group_inverse_ballot_source, - "sub_group_inverse_ballot"); - test_params_ballot.save_kernel_source(sub_group_ballot_bit_extract_source, - "sub_group_ballot_bit_extract"); + // sub_group_ballot function + WorkGroupParams test_params_ballot(global_work_size, local_work_size, 3); + test_params_ballot.save_kernel_source(sub_group_ballot_source); RunTestForType rft_ballot(device, context, queue, num_elements, test_params_ballot); - error |= rft_ballot.run_impl>("sub_group_ballot"); error |= - rft_ballot.run_impl>( + rft_ballot.run_impl>("sub_group_ballot"); + + // ballot arithmetic functions + WorkGroupParams test_params_arith(global_work_size, local_work_size); + test_params_arith.save_kernel_source(sub_group_ballot_bit_scan_find_source); + test_params_arith.save_kernel_source(sub_group_inverse_ballot_source, + "sub_group_inverse_ballot"); + test_params_arith.save_kernel_source(sub_group_ballot_bit_extract_source, + "sub_group_ballot_bit_extract"); + RunTestForType rft_arith(device, context, queue, num_elements, + test_params_arith); + error |= + rft_arith.run_impl>( "sub_group_inverse_ballot"); - error |= rft_ballot.run_impl< + error |= rft_arith.run_impl< cl_uint4, BALLOT_BIT_EXTRACT>( "sub_group_ballot_bit_extract"); - error |= rft_ballot.run_impl< + error |= rft_arith.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_bit_count"); - error |= rft_ballot.run_impl< + error |= rft_arith.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_inclusive_scan"); - error |= rft_ballot.run_impl< + error |= rft_arith.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_exclusive_scan"); - error |= rft_ballot.run_impl< + error |= rft_arith.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_find_lsb"); - error |= rft_ballot.run_impl< + error |= rft_arith.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_find_msb"); + return error; }