// // Copyright (c) 2021 The Khronos Group Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "procs.h" #include "subhelpers.h" #include "subgroup_common_templates.h" #include "harness/typeWrappers.h" #include namespace { // Test for ballot functions template struct BALLOT { static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { // no work here int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int non_uniform_size = gws % lws; log_info(" sub_group_ballot...\n"); } static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; int current_sbs = 0; cl_uint expected_result, device_result; int non_uniform_size = gws % lws; int wg_number = gws / lws; wg_number = non_uniform_size ? wg_number + 1 : wg_number; int last_subgroup_size = 0; for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group if (non_uniform_size && wg_id == wg_number - 1) { set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, last_subgroup_size); } for (wi_id = 0; wi_id < lws; ++wi_id) { // inside the work_group // read device outputs for work_group my[wi_id] = y[wi_id]; } for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; if (last_subgroup_size && sb_id == sb_number - 1) { current_sbs = last_subgroup_size; } else { current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; } for (wi_id = 0; wi_id < current_sbs; ++wi_id) { device_result = my[wg_offset + wi_id]; expected_result = 1; if (!compare(device_result, expected_result)) { log_error( "ERROR: sub_group_ballot mismatch for local id " "%d in sub group %d in group %d obtained {%d}, " "expected {%d} \n", wi_id, sb_id, wg_id, device_result, expected_result); return TEST_FAIL; } } } y += lws; m += 4 * lws; } log_info(" sub_group_ballot... passed\n"); return TEST_PASS; } }; // Test for bit extract ballot functions template struct BALLOT_BIT_EXTRACT { static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { int wi_id, sb_id, wg_id, l; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; int wg_number = gws / lws; int limit_sbs = sbs > 100 ? 100 : sbs; int non_uniform_size = gws % lws; log_info(" sub_group_%s(%s)...\n", operation_names(operation), TypeManager::name()); for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; // rand index to bit extract int index_for_odd = (int)(genrand_int32(gMTdata) & 0x7fffffff) % (limit_sbs > current_sbs ? current_sbs : limit_sbs); int index_for_even = (int)(genrand_int32(gMTdata) & 0x7fffffff) % (limit_sbs > current_sbs ? current_sbs : limit_sbs); for (wi_id = 0; wi_id < current_sbs; ++wi_id) { // index of the third element int the vector. int midx = 4 * wg_offset + 4 * wi_id + 2; // storing information about index to bit extract m[midx] = (cl_int)index_for_odd; m[++midx] = (cl_int)index_for_even; } set_randomdata_for_subgroup(t, wg_offset, current_sbs); } // Now map into work group using map from device for (wi_id = 0; wi_id < lws; ++wi_id) { x[wi_id] = t[wi_id]; } x += lws; m += 4 * lws; } } static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, l, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; int wg_number = gws / lws; cl_uint4 expected_result, device_result; int last_subgroup_size = 0; int current_sbs = 0; int non_uniform_size = gws % lws; for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group if (non_uniform_size && wg_id == wg_number - 1) { set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, last_subgroup_size); } // Map to array indexed to array indexed by local ID and sub group for (wi_id = 0; wi_id < lws; ++wi_id) { // inside the work_group // read host inputs for work_group mx[wi_id] = x[wi_id]; // read device outputs for work_group my[wi_id] = y[wi_id]; } for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; if (last_subgroup_size && sb_id == sb_number - 1) { current_sbs = last_subgroup_size; } else { current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; } // take index of array where info which work_item will // be broadcast its value is stored int midx = 4 * wg_offset + 2; // take subgroup local id of this work_item int index_for_odd = (int)m[midx]; int index_for_even = (int)m[++midx]; for (wi_id = 0; wi_id < current_sbs; ++wi_id) { // for each subgroup int bit_value = 0; // from which value of bitfield bit // verification will be done int take_shift = (wi_id & 1) ? index_for_odd % 32 : index_for_even % 32; int bit_mask = 1 << take_shift; if (wi_id < 32) (mx[wg_offset + wi_id].s0 & bit_mask) > 0 ? bit_value = 1 : bit_value = 0; if (wi_id >= 32 && wi_id < 64) (mx[wg_offset + wi_id].s1 & bit_mask) > 0 ? bit_value = 1 : bit_value = 0; if (wi_id >= 64 && wi_id < 96) (mx[wg_offset + wi_id].s2 & bit_mask) > 0 ? bit_value = 1 : bit_value = 0; if (wi_id >= 96 && wi_id < 128) (mx[wg_offset + wi_id].s3 & bit_mask) > 0 ? bit_value = 1 : bit_value = 0; if (wi_id & 1) { bit_value ? expected_result = { 1, 0, 0, 1 } : expected_result = { 0, 0, 0, 1 }; } else { bit_value ? expected_result = { 1, 0, 0, 2 } : expected_result = { 0, 0, 0, 2 }; } device_result = my[wg_offset + wi_id]; if (!compare(device_result, expected_result)) { log_error( "ERROR: sub_group_%s mismatch for local id %d in " "sub group %d in group %d obtained {%d, %d, %d, " "%d}, expected {%d, %d, %d, %d}\n", operation_names(operation), wi_id, sb_id, wg_id, device_result.s0, device_result.s1, device_result.s2, device_result.s3, expected_result.s0, expected_result.s1, expected_result.s2, expected_result.s3); return TEST_FAIL; } } } x += lws; y += lws; m += 4 * lws; } log_info(" sub_group_%s(%s)... passed\n", operation_names(operation), TypeManager::name()); return TEST_PASS; } }; template struct BALLOT_INVERSE { static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int non_uniform_size = gws % lws; log_info(" sub_group_inverse_ballot...\n"); // no work here } static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; cl_uint4 expected_result, device_result; int non_uniform_size = gws % lws; int wg_number = gws / lws; int last_subgroup_size = 0; int current_sbs = 0; if (non_uniform_size) wg_number++; for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group if (non_uniform_size && wg_id == wg_number - 1) { set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, last_subgroup_size); } // Map to array indexed to array indexed by local ID and sub group for (wi_id = 0; wi_id < lws; ++wi_id) { // inside the work_group mx[wi_id] = x[wi_id]; // read host inputs for work_group my[wi_id] = y[wi_id]; // read device outputs for work_group } for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; if (last_subgroup_size && sb_id == sb_number - 1) { current_sbs = last_subgroup_size; } else { current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; } // take index of array where info which work_item will // be broadcast its value is stored int midx = 4 * wg_offset + 2; // take subgroup local id of this work_item // Check result for (wi_id = 0; wi_id < current_sbs; ++wi_id) { // for each subgroup work item wi_id & 1 ? expected_result = { 1, 0, 0, 1 } : expected_result = { 1, 0, 0, 2 }; device_result = my[wg_offset + wi_id]; if (!compare(device_result, expected_result)) { log_error( "ERROR: sub_group_%s mismatch for local id %d in " "sub group %d in group %d obtained {%d, %d, %d, " "%d}, expected {%d, %d, %d, %d}\n", operation_names(operation), wi_id, sb_id, wg_id, device_result.s0, device_result.s1, device_result.s2, device_result.s3, expected_result.s0, expected_result.s1, expected_result.s2, expected_result.s3); return TEST_FAIL; } } } x += lws; y += lws; m += 4 * lws; } log_info(" sub_group_inverse_ballot... passed\n"); return TEST_PASS; } }; // Test for bit count/inclusive and exclusive scan/ find lsb msb ballot function template struct BALLOT_COUNT_SCAN_FIND { static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; int non_uniform_size = gws % lws; int wg_number = gws / lws; int last_subgroup_size = 0; int current_sbs = 0; log_info(" sub_group_%s(%s)...\n", operation_names(operation), TypeManager::name()); if (non_uniform_size) { wg_number++; } int e; for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group if (non_uniform_size && wg_id == wg_number - 1) { set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, last_subgroup_size); } for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; if (last_subgroup_size && sb_id == sb_number - 1) { current_sbs = last_subgroup_size; } else { current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; } if (operation == BallotOp::ballot_bit_count || operation == BallotOp::ballot_inclusive_scan || operation == BallotOp::ballot_exclusive_scan) { set_randomdata_for_subgroup(t, wg_offset, current_sbs); } else if (operation == BallotOp::ballot_find_lsb || operation == BallotOp::ballot_find_msb) { // Regarding to the spec, find lsb and find msb result is // undefined behavior if input value is zero, so generate // only non-zero values. for (wi_id = 0; wi_id < current_sbs; ++wi_id) { char x = (genrand_int32(gMTdata)) & 0xff; // undefined behaviour in case of 0; x = x ? x : 1; memset(&t[wg_offset + wi_id], x, sizeof(Ty)); } } else { log_error("Unknown operation..."); } } // Now map into work group using map from device for (wi_id = 0; wi_id < lws; ++wi_id) { x[wi_id] = t[wi_id]; } x += lws; m += 4 * lws; } } static bs128 getImportantBits(cl_uint sub_group_local_id, cl_uint sub_group_size) { bs128 mask; if (operation == BallotOp::ballot_bit_count || operation == BallotOp::ballot_find_lsb || operation == BallotOp::ballot_find_msb) { for (cl_uint i = 0; i < sub_group_size; ++i) mask.set(i); } else if (operation == BallotOp::ballot_inclusive_scan || operation == BallotOp::ballot_exclusive_scan) { for (cl_uint i = 0; i <= sub_group_local_id; ++i) mask.set(i); if (operation == BallotOp::ballot_exclusive_scan) mask.reset(sub_group_local_id); } return mask; } static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; int non_uniform_size = gws % lws; int wg_number = gws / lws; wg_number = non_uniform_size ? wg_number + 1 : wg_number; cl_uint4 expected_result, device_result; int last_subgroup_size = 0; int current_sbs = 0; for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group if (non_uniform_size && wg_id == wg_number - 1) { set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws, last_subgroup_size); } // Map to array indexed to array indexed by local ID and sub group for (wi_id = 0; wi_id < lws; ++wi_id) { // inside the work_group // read host inputs for work_group mx[wi_id] = x[wi_id]; // read device outputs for work_group my[wi_id] = y[wi_id]; } for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; if (last_subgroup_size && sb_id == sb_number - 1) { current_sbs = last_subgroup_size; } else { current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; } // Check result expected_result = { 0, 0, 0, 0 }; for (wi_id = 0; wi_id < current_sbs; ++wi_id) { // for subgroup element bs128 bs; // convert cl_uint4 input into std::bitset<128> bs |= bs128(mx[wg_offset + wi_id].s0) | (bs128(mx[wg_offset + wi_id].s1) << 32) | (bs128(mx[wg_offset + wi_id].s2) << 64) | (bs128(mx[wg_offset + wi_id].s3) << 96); bs &= getImportantBits(wi_id, current_sbs); device_result = my[wg_offset + wi_id]; if (operation == BallotOp::ballot_inclusive_scan || operation == BallotOp::ballot_exclusive_scan || operation == BallotOp::ballot_bit_count) { expected_result.s0 = bs.count(); if (!compare(device_result, expected_result)) { log_error("ERROR: sub_group_%s " "mismatch for local id %d in sub group " "%d in group %d obtained {%d, %d, %d, " "%d}, expected {%d, %d, %d, %d}\n", operation_names(operation), wi_id, sb_id, wg_id, device_result.s0, device_result.s1, device_result.s2, device_result.s3, expected_result.s0, expected_result.s1, expected_result.s2, expected_result.s3); return TEST_FAIL; } } else if (operation == BallotOp::ballot_find_lsb) { for (int id = 0; id < current_sbs; ++id) { if (bs.test(id)) { expected_result.s0 = id; break; } } if (!compare(device_result, expected_result)) { log_error("ERROR: sub_group_ballot_find_lsb " "mismatch for local id %d in sub group " "%d in group %d obtained {%d, %d, %d, " "%d}, expected {%d, %d, %d, %d}\n", wi_id, sb_id, wg_id, device_result.s0, device_result.s1, device_result.s2, device_result.s3, expected_result.s0, expected_result.s1, expected_result.s2, expected_result.s3); return TEST_FAIL; } } else if (operation == BallotOp::ballot_find_msb) { for (int id = current_sbs - 1; id >= 0; --id) { if (bs.test(id)) { expected_result.s0 = id; break; } } if (!compare(device_result, expected_result)) { log_error("ERROR: sub_group_ballot_find_msb " "mismatch for local id %d in sub group " "%d in group %d obtained {%d, %d, %d, " "%d}, expected {%d, %d, %d, %d}\n", wi_id, sb_id, wg_id, device_result.s0, device_result.s1, device_result.s2, device_result.s3, expected_result.s0, expected_result.s1, expected_result.s2, expected_result.s3); return TEST_FAIL; } } } } x += lws; y += lws; m += 4 * lws; } log_info(" sub_group_ballot_%s(%s)... passed\n", operation_names(operation), TypeManager::name()); return TEST_PASS; } }; // test mask functions template struct SMASK { static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, l, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; int wg_number = gws / lws; log_info(" get_sub_group_%s_mask...\n", operation_names(operation)); for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group for (sb_id = 0; sb_id < sb_number; ++sb_id) { // for each subgroup int wg_offset = sb_id * sbs; int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; // Produce expected masks for each work item in the subgroup for (wi_id = 0; wi_id < current_sbs; ++wi_id) { int midx = 4 * wg_offset + 4 * wi_id; cl_uint max_sub_group_size = m[midx + 2]; cl_uint4 expected_mask = { 0 }; expected_mask = generate_bit_mask( wi_id, operation_names(operation), max_sub_group_size); set_value(t[wg_offset + wi_id], expected_mask); } } // Now map into work group using map from device for (wi_id = 0; wi_id < lws; ++wi_id) { x[wi_id] = t[wi_id]; } x += lws; m += 4 * lws; } } static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, const WorkGroupParams &test_params) { int wi_id, wg_id, sb_id; int gws = test_params.global_workgroup_size; int lws = test_params.local_workgroup_size; int sbs = test_params.subgroup_size; int sb_number = (lws + sbs - 1) / sbs; Ty expected_result, device_result; int wg_number = gws / lws; for (wg_id = 0; wg_id < wg_number; ++wg_id) { // for each work_group for (wi_id = 0; wi_id < lws; ++wi_id) { // inside the work_group mx[wi_id] = x[wi_id]; // read host inputs for work_group my[wi_id] = y[wi_id]; // read device outputs for work_group } for (sb_id = 0; sb_id < sb_number; ++sb_id) { int wg_offset = sb_id * sbs; int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs; // Check result for (wi_id = 0; wi_id < current_sbs; ++wi_id) { // inside the subgroup expected_result = mx[wg_offset + wi_id]; // read host input for subgroup device_result = my[wg_offset + wi_id]; // read device outputs for subgroup if (!compare(device_result, expected_result)) { log_error("ERROR: get_sub_group_%s_mask... mismatch " "for local id %d in sub group %d in group " "%d, obtained %d, expected %d\n", operation_names(operation), wi_id, sb_id, wg_id, device_result, expected_result); return TEST_FAIL; } } } x += lws; y += lws; m += 4 * lws; } log_info(" get_sub_group_%s_mask... passed\n", operation_names(operation)); return TEST_PASS; } }; std::string sub_group_non_uniform_broadcast_source = R"( __kernel void test_sub_group_non_uniform_broadcast(const __global Type *in, __global int4 *xy, __global Type *out) { int gid = get_global_id(0); XY(xy,gid); Type x = in[gid]; if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) { out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z); } else { out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w); } } )"; std::string sub_group_broadcast_first_source = R"( __kernel void test_sub_group_broadcast_first(const __global Type *in, __global int4 *xy, __global Type *out) { int gid = get_global_id(0); XY(xy,gid); Type x = in[gid]; if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) { out[gid] = sub_group_broadcast_first(x);; } else { out[gid] = sub_group_broadcast_first(x);; } } )"; std::string sub_group_ballot_bit_scan_find_source = R"( __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) { int gid = get_global_id(0); XY(xy,gid); Type x = in[gid]; uint4 value = (uint4)(0,0,0,0); value = (uint4)(%s(x),0,0,0); out[gid] = value; } )"; std::string sub_group_ballot_mask_source = R"( __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) { int gid = get_global_id(0); XY(xy,gid); xy[gid].z = get_max_sub_group_size(); Type x = in[gid]; uint4 mask = %s(); out[gid] = mask; } )"; std::string sub_group_ballot_source = R"( __kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out) { uint4 full_ballot = sub_group_ballot(1); uint divergence_mask; uint4 partial_ballot; uint gid = get_global_id(0); XY(xy,gid); if (get_sub_group_local_id() & 1) { divergence_mask = 0xaaaaaaaa; partial_ballot = sub_group_ballot(1); } else { divergence_mask = 0x55555555; partial_ballot = sub_group_ballot(1); } size_t lws = get_local_size(0); uint4 masked_ballot = full_ballot; masked_ballot.x &= divergence_mask; masked_ballot.y &= divergence_mask; masked_ballot.z &= divergence_mask; masked_ballot.w &= divergence_mask; out[gid] = all(masked_ballot == partial_ballot); } )"; std::string sub_group_inverse_ballot_source = R"( __kernel void test_sub_group_inverse_ballot(const __global Type *in, __global int4 *xy, __global Type *out) { int gid = get_global_id(0); XY(xy,gid); Type x = in[gid]; uint4 value = (uint4)(10,0,0,0); if (get_sub_group_local_id() & 1) { uint4 partial_ballot_mask = (uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA); if (sub_group_inverse_ballot(partial_ballot_mask)) { value = (uint4)(1,0,0,1); } else { value = (uint4)(0,0,0,1); } } else { uint4 partial_ballot_mask = (uint4)(0x55555555,0x55555555,0x55555555,0x55555555); if (sub_group_inverse_ballot(partial_ballot_mask)) { value = (uint4)(1,0,0,2); } else { value = (uint4)(0,0,0,2); } } out[gid] = value; } )"; std::string sub_group_ballot_bit_extract_source = R"( __kernel void test_sub_group_ballot_bit_extract(const __global Type *in, __global int4 *xy, __global Type *out) { int gid = get_global_id(0); XY(xy,gid); Type x = in[gid]; uint index = xy[gid].z; uint4 value = (uint4)(10,0,0,0); if (get_sub_group_local_id() & 1) { if (sub_group_ballot_bit_extract(x, xy[gid].z)) { value = (uint4)(1,0,0,1); } else { value = (uint4)(0,0,0,1); } } else { if (sub_group_ballot_bit_extract(x, xy[gid].w)) { value = (uint4)(1,0,0,2); } else { value = (uint4)(0,0,0,2); } } out[gid] = value; } )"; template int run_non_uniform_broadcast_for_type(RunTestForType rft) { int error = rft.run_impl>( "sub_group_non_uniform_broadcast"); return error; } } int test_subgroup_functions_ballot(cl_device_id device, cl_context context, cl_command_queue queue, int num_elements) { if (!is_extension_available(device, "cl_khr_subgroup_ballot")) { log_info("cl_khr_subgroup_ballot is not supported on this device, " "skipping test.\n"); return TEST_SKIPPED_ITSELF; } constexpr size_t global_work_size = 170; constexpr size_t local_work_size = 64; WorkGroupParams test_params(global_work_size, local_work_size); test_params.save_kernel_source(sub_group_ballot_mask_source); test_params.save_kernel_source(sub_group_non_uniform_broadcast_source, "sub_group_non_uniform_broadcast"); test_params.save_kernel_source(sub_group_broadcast_first_source, "sub_group_broadcast_first"); RunTestForType rft(device, context, queue, num_elements, test_params); // non uniform broadcast functions int error = run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); error |= run_non_uniform_broadcast_for_type(rft); // broadcast first functions error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl>( "sub_group_broadcast_first"); error |= rft.run_impl< subgroups::cl_half, BC>( "sub_group_broadcast_first"); // mask functions error |= rft.run_impl>( "get_sub_group_eq_mask"); error |= rft.run_impl>( "get_sub_group_ge_mask"); error |= rft.run_impl>( "get_sub_group_gt_mask"); error |= rft.run_impl>( "get_sub_group_le_mask"); error |= rft.run_impl>( "get_sub_group_lt_mask"); // ballot functions WorkGroupParams test_params_ballot(global_work_size, local_work_size); test_params_ballot.save_kernel_source( sub_group_ballot_bit_scan_find_source); test_params_ballot.save_kernel_source(sub_group_ballot_source, "sub_group_ballot"); test_params_ballot.save_kernel_source(sub_group_inverse_ballot_source, "sub_group_inverse_ballot"); test_params_ballot.save_kernel_source(sub_group_ballot_bit_extract_source, "sub_group_ballot_bit_extract"); RunTestForType rft_ballot(device, context, queue, num_elements, test_params_ballot); error |= rft_ballot.run_impl>("sub_group_ballot"); error |= rft_ballot.run_impl>( "sub_group_inverse_ballot"); error |= rft_ballot.run_impl< cl_uint4, BALLOT_BIT_EXTRACT>( "sub_group_ballot_bit_extract"); error |= rft_ballot.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_bit_count"); error |= rft_ballot.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_inclusive_scan"); error |= rft_ballot.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_exclusive_scan"); error |= rft_ballot.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_find_lsb"); error |= rft_ballot.run_impl< cl_uint4, BALLOT_COUNT_SCAN_FIND>( "sub_group_ballot_find_msb"); return error; }