From b71c2047943a44a2e99c367e406e680caa160bfe Mon Sep 17 00:00:00 2001 From: Grzegorz Wawiorko Date: Wed, 5 Jan 2022 17:08:52 +0100 Subject: [PATCH] test_subgroups - Set safe input values for half type and mul, add operations (#1346) * Set safe input values for half type and mul, add operations * Set safe values for all data types * Typo fix * Set constant seed for shuffle * Change function name to more specific * set_value takes an integer value, not a bit pattern --- .../subgroups/subgroup_common_templates.h | 48 +++++++++++++++---- .../test_subgroup_clustered_reduce.cpp | 2 +- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/test_conformance/subgroups/subgroup_common_templates.h b/test_conformance/subgroups/subgroup_common_templates.h index fc0b03b5..641c1875 100644 --- a/test_conformance/subgroups/subgroup_common_templates.h +++ b/test_conformance/subgroups/subgroup_common_templates.h @@ -20,6 +20,8 @@ #include "CL/cl_half.h" #include "subhelpers.h" #include +#include +#include static cl_uint4 generate_bit_mask(cl_uint subgroup_local_id, const std::string &mask_type, @@ -391,11 +393,44 @@ template bool is_floating_point() || std::is_same::value; } +// limit possible input values to avoid arithmetic rounding/overflow issues. 
+// within a subgroup, up to five work-items get the distinct values 2..6, +// every remaining work-item gets 1, +// and the resulting values are shuffled +static void fill_and_shuffle_safe_values(std::vector &safe_values, + int sb_size) +{ + // max product is 720, cl_half has enough precision for it + const std::vector non_one_values{ 2, 3, 4, 5, 6 }; + + if (sb_size <= non_one_values.size()) + { + safe_values.assign(non_one_values.begin(), + non_one_values.begin() + sb_size); + } + else + { + safe_values.assign(sb_size, 1); + std::copy(non_one_values.begin(), non_one_values.end(), + safe_values.begin()); + } + + std::mt19937 mersenne_twister_engine(10000); + std::shuffle(safe_values.begin(), safe_values.end(), + mersenne_twister_engine); +}; + template -void genrand(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng) +void generate_inputs(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng) { int nj = (nw + ns - 1) / ns; + std::vector safe_values; + if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_) + { + fill_and_shuffle_safe_values(safe_values, ns); + } + for (int k = 0; k < ng; ++k) { for (int j = 0; j < nj; ++j) @@ -406,13 +441,10 @@ void genrand(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng) for (int i = 0; i < n; ++i) { cl_ulong out_value; - double y; if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_) { - // work around to avoid overflow, do not use 0 for - // multiplication - out_value = (genrand_int32(gMTdata) % 4) + 1; + out_value = safe_values[i]; } else { @@ -591,7 +623,7 @@ template struct SCEX_NU int ns = test_params.subgroup_size; int ng = test_params.global_workgroup_size; ng = ng / nw; - genrand(x, t, m, ns, nw, ng); + generate_inputs(x, t, m, ns, nw, ng); } static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, @@ -689,7 +721,7 @@ template struct SCIN_NU int ns = test_params.subgroup_size; int ng = test_params.global_workgroup_size; ng = ng / nw; - genrand(x, t, m, ns, nw, ng); + generate_inputs(x, t, m, ns, nw, ng); } static test_status chk(Ty *x, 
Ty *y, Ty *mx, Ty *my, cl_int *m, @@ -805,7 +837,7 @@ template struct RED_NU int ns = test_params.subgroup_size; int ng = test_params.global_workgroup_size; ng = ng / nw; - genrand(x, t, m, ns, nw, ng); + generate_inputs(x, t, m, ns, nw, ng); } static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, diff --git a/test_conformance/subgroups/test_subgroup_clustered_reduce.cpp b/test_conformance/subgroups/test_subgroup_clustered_reduce.cpp index f5872006..527be5ad 100644 --- a/test_conformance/subgroups/test_subgroup_clustered_reduce.cpp +++ b/test_conformance/subgroups/test_subgroup_clustered_reduce.cpp @@ -52,7 +52,7 @@ template struct RED_CLU int ns = test_params.subgroup_size; int ng = test_params.global_workgroup_size; ng = ng / nw; - genrand(x, t, m, ns, nw, ng); + generate_inputs(x, t, m, ns, nw, ng); } static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,