Refactor kernel execution in subgroup tests (#1391)

Signed-off-by: Stuart Brady <stuart.brady@arm.com>
This commit is contained in:
Stuart Brady
2022-03-02 13:25:53 +00:00
committed by GitHub
parent 2d93b122c3
commit 279803abab

View File

@@ -1322,73 +1322,129 @@ inline bool compare_ordered(const subgroups::cl_half &lhs, const int &rhs)
return cl_half_to_float(lhs.data) == rhs;
}
// NOTE(review): this span is a rendered diff hunk whose +/- markers and
// indentation were stripped. Lines of the free function run_kernel() and of
// the class template KernelExecutor are interleaved below (presumably the
// removed and the added side of the refactoring, respectively), so the text
// is not compilable as shown. The comments added here describe only what the
// individual statements visibly do.
// Run a test kernel to compute the result of a built-in on an input
static int run_kernel(cl_context context, cl_command_queue queue,
cl_kernel kernel, size_t global, size_t local,
void *idata, size_t isize, void *mdata, size_t msize,
void *odata, size_t osize, size_t tsize = 0)
{
clMemWrapper in;
clMemWrapper xy;
clMemWrapper out;
clMemWrapper tmp;
int error;
in = clCreateBuffer(context, CL_MEM_READ_ONLY, isize, NULL, &error);
test_error(error, "clCreateBuffer failed");
xy = clCreateBuffer(context, CL_MEM_WRITE_ONLY, msize, NULL, &error);
test_error(error, "clCreateBuffer failed");
out = clCreateBuffer(context, CL_MEM_WRITE_ONLY, osize, NULL, &error);
test_error(error, "clCreateBuffer failed");
if (tsize)
// KernelExecutor stores the launch parameters and host buffers of one subgroup
// test kernel so that run() and run_and_check() can use them without taking
// any of them as arguments. Ty is the element type of the host data buffers;
// Fns supplies the static chk() used by run_and_check().
template <typename Ty, typename Fns> class KernelExecutor {
public:
// c, q, k   - OpenCL context, command queue and kernel used by run().
// g, l      - global and local work sizes handed to clEnqueueNDRangeKernel.
// id, is    - host input data and the size used for its device buffer.
// mid, mod  - host buffers that are only forwarded to Fns::chk().
// md, ms    - host buffer written to and read back from kernel argument 1,
//             and its size.
// od, os    - host buffer receiving kernel argument 2, and its size.
// ts        - size of the extra buffer bound to argument 3; when zero, no
//             such buffer is created or bound.
KernelExecutor(cl_context c, cl_command_queue q, cl_kernel k, size_t g,
size_t l, Ty *id, size_t is, Ty *mid, Ty *mod, cl_int *md,
size_t ms, Ty *od, size_t os, size_t ts = 0)
: context(c), queue(q), kernel(k), global(g), local(l), idata(id),
isize(is), mapin_data(mid), mapout_data(mod), mdata(md), msize(ms),
odata(od), osize(os), tsize(ts)
{
tmp = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
tsize, NULL, &error);
// No result has been recorded and no run has failed yet.
has_status = false;
run_failed = false;
}
cl_context context;
cl_command_queue queue;
cl_kernel kernel;
size_t global;
size_t local;
Ty *idata;
size_t isize;
Ty *mapin_data;
Ty *mapout_data;
cl_int *mdata;
size_t msize;
Ty *odata;
size_t osize;
size_t tsize;
// Set to true by run_and_check() when run() reports an OpenCL error.
bool run_failed;
private:
// status is assigned only in run_and_check(); has_status records whether
// that has happened yet and is tested before status is read in the
// accumulation condition there.
bool has_status;
test_status status;
public:
// Run a test kernel to compute the result of a built-in on an input
// Creates the device buffers, binds them as kernel arguments 0-2 (plus 3
// when tsize is non-zero), uploads idata and mdata, launches the kernel,
// reads mdata and odata back and waits with clFinish. Returns the last
// OpenCL error code; test_error presumably returns early on failure (macro
// defined elsewhere - confirm).
int run()
{
clMemWrapper in;
clMemWrapper xy;
clMemWrapper out;
clMemWrapper tmp;
int error;
in = clCreateBuffer(context, CL_MEM_READ_ONLY, isize, NULL, &error);
test_error(error, "clCreateBuffer failed");
}
error = clSetKernelArg(kernel, 0, sizeof(in), (void *)&in);
test_error(error, "clSetKernelArg failed");
xy = clCreateBuffer(context, CL_MEM_WRITE_ONLY, msize, NULL, &error);
test_error(error, "clCreateBuffer failed");
error = clSetKernelArg(kernel, 1, sizeof(xy), (void *)&xy);
test_error(error, "clSetKernelArg failed");
out = clCreateBuffer(context, CL_MEM_WRITE_ONLY, osize, NULL, &error);
test_error(error, "clCreateBuffer failed");
error = clSetKernelArg(kernel, 2, sizeof(out), (void *)&out);
test_error(error, "clSetKernelArg failed");
// The extra buffer is created with CL_MEM_HOST_NO_ACCESS and is never
// written or read by the host code in this span.
if (tsize)
{
tmp = clCreateBuffer(context,
CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
tsize, NULL, &error);
test_error(error, "clCreateBuffer failed");
}
if (tsize)
{
error = clSetKernelArg(kernel, 3, sizeof(tmp), (void *)&tmp);
error = clSetKernelArg(kernel, 0, sizeof(in), (void *)&in);
test_error(error, "clSetKernelArg failed");
error = clSetKernelArg(kernel, 1, sizeof(xy), (void *)&xy);
test_error(error, "clSetKernelArg failed");
error = clSetKernelArg(kernel, 2, sizeof(out), (void *)&out);
test_error(error, "clSetKernelArg failed");
if (tsize)
{
error = clSetKernelArg(kernel, 3, sizeof(tmp), (void *)&tmp);
test_error(error, "clSetKernelArg failed");
}
// All transfers below pass CL_FALSE (non-blocking); the host buffers are
// only safe to inspect after the clFinish call at the end has returned.
error = clEnqueueWriteBuffer(queue, in, CL_FALSE, 0, isize, idata, 0,
NULL, NULL);
test_error(error, "clEnqueueWriteBuffer failed");
error = clEnqueueWriteBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
NULL, NULL);
test_error(error, "clEnqueueWriteBuffer failed");
error = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local,
0, NULL, NULL);
test_error(error, "clEnqueueNDRangeKernel failed");
error = clEnqueueReadBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
NULL, NULL);
test_error(error, "clEnqueueReadBuffer failed");
error = clEnqueueReadBuffer(queue, out, CL_FALSE, 0, osize, odata, 0,
NULL, NULL);
test_error(error, "clEnqueueReadBuffer failed");
error = clFinish(queue);
test_error(error, "clFinish failed");
return error;
}
error = clEnqueueWriteBuffer(queue, in, CL_FALSE, 0, isize, idata, 0, NULL,
NULL);
test_error(error, "clEnqueueWriteBuffer failed");
// Calls run() and, if it succeeds, validates the buffers with Fns::chk and
// folds the outcome into the status kept across calls. When run() fails, the
// error is printed, run_failed is set and TEST_FAIL is returned without
// calling Fns::chk.
test_status run_and_check(const WorkGroupParams &test_params)
{
cl_int error = run();
if (error != CL_SUCCESS)
{
print_error(error, "Failed to run subgroup test kernel");
status = TEST_FAIL;
run_failed = true;
return status;
}
error = clEnqueueWriteBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0, NULL,
NULL);
test_error(error, "clEnqueueWriteBuffer failed");
error = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 0,
NULL, NULL);
test_error(error, "clEnqueueNDRangeKernel failed");
test_status tmp_status =
Fns::chk(idata, odata, mapin_data, mapout_data, mdata, test_params);
error = clEnqueueReadBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0, NULL,
NULL);
test_error(error, "clEnqueueReadBuffer failed");
// Accumulation rule: the first result is always recorded; TEST_FAIL always
// replaces the recorded status; TEST_PASS replaces it unless a TEST_FAIL is
// already recorded; any other result leaves a recorded status untouched.
if (!has_status || tmp_status == TEST_FAIL
|| (tmp_status == TEST_PASS && status != TEST_FAIL))
{
status = tmp_status;
has_status = true;
}
error = clEnqueueReadBuffer(queue, out, CL_FALSE, 0, osize, odata, 0, NULL,
NULL);
test_error(error, "clEnqueueReadBuffer failed");
error = clFinish(queue);
test_error(error, "clFinish failed");
return error;
}
return status;
}
};
// Driver for testing a single built in function
template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
@@ -1536,74 +1592,52 @@ template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
test_error_fail(error, "Unable to set divergence mask argument");
}
KernelExecutor<Ty, Fns> executor(
context, queue, kernel, global, local, idata.data(),
input_array_size * sizeof(Ty), mapin.data(), mapout.data(),
sgmap.data(), global * sizeof(cl_int4), odata.data(),
output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
// Run the kernel once on zeroes to get the map
memset(idata.data(), 0, input_array_size * sizeof(Ty));
error = run_kernel(context, queue, kernel, global, local, idata.data(),
input_array_size * sizeof(Ty), sgmap.data(),
global * sizeof(cl_int4), odata.data(),
output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
error = executor.run();
test_error_fail(error, "Running kernel first time failed");
// Generate the desired input for the kernel
test_params.subgroup_size = subgroup_size;
Fns::gen(idata.data(), mapin.data(), sgmap.data(), test_params);
test_status combined_status;
test_status status;
if (test_params.divergence_mask_arg != -1)
{
combined_status = TEST_SKIPPED_ITSELF;
for (auto &mask : test_params.all_work_item_masks)
{
test_params.work_items_mask = mask;
cl_uint4 mask_vector = bs128_to_cl_uint4(mask);
clSetKernelArg(kernel, test_params.divergence_mask_arg,
sizeof(cl_uint4), &mask_vector);
error = run_kernel(context, queue, kernel, global, local,
idata.data(), input_array_size * sizeof(Ty),
sgmap.data(), global * sizeof(cl_int4),
odata.data(), output_array_size * sizeof(Ty),
TSIZE * sizeof(Ty));
test_error_fail(error, "Running kernel second time failed");
// Check the result
test_status status =
Fns::chk(idata.data(), odata.data(), mapin.data(),
mapout.data(), sgmap.data(), test_params);
if (status == TEST_FAIL
|| (status == TEST_PASS && combined_status != TEST_FAIL))
combined_status = status;
status = executor.run_and_check(test_params);
if (status == TEST_FAIL) break;
}
}
else
{
error =
run_kernel(context, queue, kernel, global, local, idata.data(),
input_array_size * sizeof(Ty), sgmap.data(),
global * sizeof(cl_int4), odata.data(),
output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
test_error_fail(error, "Running kernel second time failed");
// Check the result
combined_status =
Fns::chk(idata.data(), odata.data(), mapin.data(),
mapout.data(), sgmap.data(), test_params);
status = executor.run_and_check(test_params);
}
// Detailed failure and skip messages should be logged by Fns::gen
// and Fns::chk.
if (combined_status == TEST_PASS)
// Detailed failure and skip messages should be logged by
// run_and_check.
if (status == TEST_PASS)
{
Fns::log_test(test_params, " passed");
}
else if (combined_status == TEST_FAIL)
else if (!executor.run_failed && status == TEST_FAIL)
{
test_fail("Data verification failed\n");
}
return combined_status;
return status;
}
};