diff --git a/test_conformance/subgroups/subgroup_common_templates.h b/test_conformance/subgroups/subgroup_common_templates.h index 641c1875..0ffa46c8 100644 --- a/test_conformance/subgroups/subgroup_common_templates.h +++ b/test_conformance/subgroups/subgroup_common_templates.h @@ -481,12 +481,12 @@ template struct SHF static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params) { - int i, ii, j, k, l, n, delta; + int i, ii, j, k, n, delta; + cl_uint l; int nw = test_params.local_workgroup_size; int ns = test_params.subgroup_size; int ng = test_params.global_workgroup_size; int nj = (nw + ns - 1) / ns; - int d = ns > 100 ? 100 : ns; ii = 0; ng = ng / nw; for (k = 0; k < ng; ++k) @@ -498,33 +498,10 @@ template struct SHF for (i = 0; i < n; ++i) { int midx = 4 * ii + 4 * i + 2; - l = (int)(genrand_int32(gMTdata) & 0x7fffffff) - % (d > n ? n : d); - switch (operation) - { - case ShuffleOp::shuffle: - case ShuffleOp::shuffle_xor: - // storing information about shuffle index - m[midx] = (cl_int)l; - break; - case ShuffleOp::shuffle_up: - delta = l; // calculate delta for shuffle up - if (i - delta < 0) - { - delta = i; - } - m[midx] = (cl_int)delta; - break; - case ShuffleOp::shuffle_down: - delta = l; // calculate delta for shuffle down - if (i + delta >= n) - { - delta = n - 1 - i; - } - m[midx] = (cl_int)delta; - break; - default: break; - } + l = (((cl_uint)(genrand_int32(gMTdata) & 0x7fffffff) + 1) + % (ns * 2 + 1)) + - 1; + m[midx] = l; cl_ulong number = genrand_int64(gMTdata); set_value(t[ii + i], number); } @@ -542,7 +519,8 @@ template struct SHF static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m, const WorkGroupParams &test_params) { - int ii, i, j, k, l, n; + int ii, i, j, k, n; + cl_uint l; int nw = test_params.local_workgroup_size; int ns = test_params.subgroup_size; int ng = test_params.global_workgroup_size; @@ -567,32 +545,42 @@ template struct SHF { // inside the subgroup // shuffle index storage int midx = 4 * ii + 4 * i + 2; - l = (int)m[midx]; + l = m[midx]; rr = my[ii + i]; + cl_uint tr_idx; + bool skip = false; switch (operation) { // shuffle basic - treat l as index - case ShuffleOp::shuffle: tr = mx[ii + l]; break; - // shuffle up - treat l as delta - case ShuffleOp::shuffle_up: tr = mx[ii + i - l]; break; - // shuffle up - treat l as delta - case ShuffleOp::shuffle_down: - tr = mx[ii + i + l]; - break; + case ShuffleOp::shuffle: tr_idx = l; break; // shuffle xor - treat l as mask - case ShuffleOp::shuffle_xor: - tr = mx[ii + (i ^ l)]; + case ShuffleOp::shuffle_xor: tr_idx = i ^ l; break; + // shuffle up - treat l as delta + case ShuffleOp::shuffle_up: + if (l >= ns) skip = true; + tr_idx = i - l; + break; + // shuffle down - treat l as delta + case ShuffleOp::shuffle_down: + if (l >= ns) skip = true; + tr_idx = i + l; break; default: break; } - if (!compare(rr, tr)) + if (!skip && tr_idx < n) { - log_error("ERROR: sub_group_%s(%s) mismatch for " - "local id %d in sub group %d in group %d\n", - operation_names(operation), - TypeManager::name(), i, j, k); - return TEST_FAIL; + tr = mx[ii + tr_idx]; + + if (!compare(rr, tr)) + { + log_error("ERROR: sub_group_%s(%s) mismatch for " + "local id %d in sub group %d in group " + "%d\n", + operation_names(operation), + TypeManager::name(), i, j, k); + return TEST_FAIL; + } } } }