Use CTS type wrappers for test_enqueued_local_size (#1544)

Signed-off-by: John Kesapides <john.kesapides@arm.com>

Signed-off-by: John Kesapides <john.kesapides@arm.com>
This commit is contained in:
John Kesapides
2022-10-14 09:55:10 +01:00
committed by GitHub
parent 5e116e7b0d
commit 90a5183ec4

View File

@@ -26,29 +26,30 @@
#include "procs.h" #include "procs.h"
static const char *enqueued_local_size_2d_code = static const char *enqueued_local_size_2d_code = R"(
"__kernel void test_enqueued_local_size_2d(global int *dst)\n" __kernel void test_enqueued_local_size_2d(global int *dst)
"{\n" {
" if ((get_global_id(0) == 0) && (get_global_id(1) == 0))\n" if ((get_global_id(0) == 0) && (get_global_id(1) == 0))
" {\n" {
" dst[0] = (int)get_enqueued_local_size(0)\n;" dst[0] = (int)get_enqueued_local_size(0);
" dst[1] = (int)get_enqueued_local_size(1)\n;" dst[1] = (int)get_enqueued_local_size(1);
" }\n" }
"}\n"; }
)";
static const char *enqueued_local_size_1d_code = static const char *enqueued_local_size_1d_code = R"(
"__kernel void test_enqueued_local_size_1d(global int *dst)\n" __kernel void test_enqueued_local_size_1d(global int *dst)
"{\n" {
" int tid_x = get_global_id(0);\n" int tid_x = get_global_id(0);
" if (get_global_id(0) == 0)\n" if (get_global_id(0) == 0)
" {\n" {
" dst[tid_x] = (int)get_enqueued_local_size(0)\n;" dst[tid_x] = (int)get_enqueued_local_size(0);
" }\n" }
"}\n"; }
)";
static int static int verify_enqueued_local_size(int *result, size_t *expected, int n)
verify_enqueued_local_size(int *result, size_t *expected, int n)
{ {
int i; int i;
for (i = 0; i < n; i++) for (i = 0; i < n; i++)
@@ -64,14 +65,14 @@ verify_enqueued_local_size(int *result, size_t *expected, int n)
} }
int int test_enqueued_local_size(cl_device_id device, cl_context context,
test_enqueued_local_size(cl_device_id device, cl_context context, cl_command_queue queue, int num_elements) cl_command_queue queue, int num_elements)
{ {
cl_mem streams; clMemWrapper stream;
cl_program program[2]; clProgramWrapper program[2];
cl_kernel kernel[2]; clKernelWrapper kernel[2];
int *output_ptr; cl_int output_ptr[2];
size_t globalsize[2]; size_t globalsize[2];
size_t localsize[2]; size_t localsize[2];
int err; int err;
@@ -97,10 +98,8 @@ test_enqueued_local_size(cl_device_id device, cl_context context, cl_command_que
} }
} }
output_ptr = (int*)malloc(2 * sizeof(int)); stream = clCreateBuffer(context, CL_MEM_READ_WRITE, 2 * sizeof(cl_int),
nullptr, &err);
streams =
clCreateBuffer(context, CL_MEM_READ_WRITE, 2 * sizeof(int), NULL, &err);
test_error(err, "clCreateBuffer failed."); test_error(err, "clCreateBuffer failed.");
std::string cl_std = "-cl-std=CL"; std::string cl_std = "-cl-std=CL";
@@ -114,16 +113,17 @@ test_enqueued_local_size(cl_device_id device, cl_context context, cl_command_que
"test_enqueued_local_size_2d", cl_std.c_str()); "test_enqueued_local_size_2d", cl_std.c_str());
test_error(err, "create_single_kernel_helper failed"); test_error(err, "create_single_kernel_helper failed");
err = clSetKernelArg(kernel[0], 0, sizeof streams, &streams); err = clSetKernelArg(kernel[0], 0, sizeof stream, &stream);
test_error(err, "clSetKernelArgs failed."); test_error(err, "clSetKernelArgs failed.");
err = clSetKernelArg(kernel[1], 0, sizeof streams, &streams); err = clSetKernelArg(kernel[1], 0, sizeof stream, &stream);
test_error(err, "clSetKernelArgs failed."); test_error(err, "clSetKernelArgs failed.");
globalsize[0] = (size_t)num_elements; globalsize[0] = static_cast<size_t>(num_elements);
globalsize[1] = (size_t)num_elements; globalsize[1] = static_cast<size_t>(num_elements);
size_t max_wgs; size_t max_wgs;
err = clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(max_wgs), &max_wgs, NULL); err = clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE,
sizeof(max_wgs), &max_wgs, nullptr);
test_error(err, "clGetDeviceInfo failed."); test_error(err, "clGetDeviceInfo failed.");
localsize[0] = std::min<size_t>(16, max_wgs); localsize[0] = std::min<size_t>(16, max_wgs);
@@ -143,35 +143,31 @@ test_enqueued_local_size(cl_device_id device, cl_context context, cl_command_que
} }
} }
err = clEnqueueNDRangeKernel(queue, kernel[1], 2, NULL, globalsize, localsize, 0, NULL, NULL); err = clEnqueueNDRangeKernel(queue, kernel[1], 2, nullptr, globalsize,
localsize, 0, nullptr, nullptr);
test_error(err, "clEnqueueNDRangeKernel failed."); test_error(err, "clEnqueueNDRangeKernel failed.");
err = clEnqueueReadBuffer(queue, streams, CL_TRUE, 0, 2*sizeof(int), output_ptr, 0, NULL, NULL); err = clEnqueueReadBuffer(queue, stream, CL_BLOCKING, 0, 2 * sizeof(int),
output_ptr, 0, nullptr, nullptr);
test_error(err, "clEnqueueReadBuffer failed."); test_error(err, "clEnqueueReadBuffer failed.");
err = verify_enqueued_local_size(output_ptr, localsize, 2); err = verify_enqueued_local_size(output_ptr, localsize, 2);
globalsize[0] = (size_t)num_elements; globalsize[0] = static_cast<size_t>(num_elements);
localsize[0] = 9; localsize[0] = 9;
if (use_uniform_work_groups && (globalsize[0] % localsize[0])) if (use_uniform_work_groups && (globalsize[0] % localsize[0]))
{ {
globalsize[0] += (localsize[0] - (globalsize[0] % localsize[0])); globalsize[0] += (localsize[0] - (globalsize[0] % localsize[0]));
} }
err = clEnqueueNDRangeKernel(queue, kernel[1], 1, NULL, globalsize, localsize, 0, NULL, NULL); err = clEnqueueNDRangeKernel(queue, kernel[1], 1, nullptr, globalsize,
localsize, 0, nullptr, nullptr);
test_error(err, "clEnqueueNDRangeKernel failed."); test_error(err, "clEnqueueNDRangeKernel failed.");
err = clEnqueueReadBuffer(queue, streams, CL_TRUE, 0, 2*sizeof(int), output_ptr, 0, NULL, NULL); err = clEnqueueReadBuffer(queue, stream, CL_BLOCKING, 0, 2 * sizeof(int),
output_ptr, 0, nullptr, nullptr);
test_error(err, "clEnqueueReadBuffer failed."); test_error(err, "clEnqueueReadBuffer failed.");
err = verify_enqueued_local_size(output_ptr, localsize, 1); err = verify_enqueued_local_size(output_ptr, localsize, 1);
// cleanup
clReleaseMemObject(streams);
clReleaseKernel(kernel[0]);
clReleaseKernel(kernel[1]);
clReleaseProgram(program[0]);
clReleaseProgram(program[1]);
free(output_ptr);
return err; return err;
} }