Extended subgroups - use 128bit masks (#1215)

* Extended subgroups - use 128bit masks

* Refactoring to avoid kernel code duplication

* unify kernel names as test_ prefix + subgroup function name
* use string literals to improve readability
* use kernel templates to limit code duplication
* WorkGroupParams allows defining a default kernel - a kernel template shared by multiple functions
* WorkGroupParams allows defining a kernel for one specific subgroup function

Co-authored-by: Stuart Brady <stuart.brady@arm.com>
This commit is contained in:
Grzegorz Wawiorko
2021-10-01 12:28:37 +02:00
committed by GitHub
parent 903f1bf65d
commit 92844bead1
12 changed files with 592 additions and 1054 deletions

View File

@@ -150,25 +150,25 @@ template <typename T>
// Runs the subgroup broadcast, reduce (add/max/min), inclusive-scan
// (add/max/min) and exclusive-scan (add/max/min) checks for element type T
// by calling rft.run_impl once per operation, and returns the bitwise OR of
// all the run_impl results.
//
// NOTE(review): presumably a zero result means every check passed; confirm
// against RunTestForType::run_impl.
// NOTE(review): this text is a diff hunk rendered without +/- markers. The
// calls that pass a "test_*" kernel name together with a *_source argument
// appear to be the removed lines, and the calls that pass only a
// "sub_group_*" function name appear to be their replacements; verify
// against the actual commit before treating this as compilable code.
int run_broadcast_scan_reduction_for_type(RunTestForType rft)
{
int error = rft.run_impl<T, BC<T, SubgroupsBroadcastOp::broadcast>>(
"test_bcast", bcast_source);
error |= rft.run_impl<T, RED_NU<T, ArithmeticOp::add_>>("test_redadd",
redadd_source);
error |= rft.run_impl<T, RED_NU<T, ArithmeticOp::max_>>("test_redmax",
redmax_source);
error |= rft.run_impl<T, RED_NU<T, ArithmeticOp::min_>>("test_redmin",
redmin_source);
error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::add_>>("test_scinadd",
scinadd_source);
error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::max_>>("test_scinmax",
scinmax_source);
error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::min_>>("test_scinmin",
scinmin_source);
error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::add_>>("test_scexadd",
scexadd_source);
error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::max_>>("test_scexmax",
scexmax_source);
error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::min_>>("test_scexmin",
scexmin_source);
"sub_group_broadcast");
error |=
rft.run_impl<T, RED_NU<T, ArithmeticOp::add_>>("sub_group_reduce_add");
error |=
rft.run_impl<T, RED_NU<T, ArithmeticOp::max_>>("sub_group_reduce_max");
error |=
rft.run_impl<T, RED_NU<T, ArithmeticOp::min_>>("sub_group_reduce_min");
error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::add_>>(
"sub_group_scan_inclusive_add");
error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::max_>>(
"sub_group_scan_inclusive_max");
error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::min_>>(
"sub_group_scan_inclusive_min");
error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::add_>>(
"sub_group_scan_exclusive_add");
error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::max_>>(
"sub_group_scan_exclusive_max");
error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::min_>>(
"sub_group_scan_exclusive_min");
return error;
}
@@ -181,11 +181,14 @@ int test_subgroup_functions(cl_device_id device, cl_context context,
constexpr size_t global_work_size = 2000;
constexpr size_t local_work_size = 200;
WorkGroupParams test_params(global_work_size, local_work_size);
test_params.save_kernel_source(sub_group_reduction_scan_source);
test_params.save_kernel_source(sub_group_generic_source,
"sub_group_broadcast");
RunTestForType rft(device, context, queue, num_elements, test_params);
int error =
rft.run_impl<cl_int, AA<NonUniformVoteOp::any>>("test_any", any_source);
error |=
rft.run_impl<cl_int, AA<NonUniformVoteOp::all>>("test_all", all_source);
rft.run_impl<cl_int, AA<NonUniformVoteOp::any>>("sub_group_any");
error |= rft.run_impl<cl_int, AA<NonUniformVoteOp::all>>("sub_group_all");
error |= run_broadcast_scan_reduction_for_type<cl_int>(rft);
error |= run_broadcast_scan_reduction_for_type<cl_uint>(rft);
error |= run_broadcast_scan_reduction_for_type<cl_long>(rft);