math_brute_force: consider all types for extension pragmas (#1705)

When generating the kernel code, consider the return type(s) and the
types of all parameters, instead of only the first parameter type.
This fixes a missing extension pragma for certain cases (such as
`nan`).

Signed-off-by: Sven van Haastregt <sven.vanhaastregt@arm.com>
This commit is contained in:
Sven van Haastregt
2023-05-10 10:45:44 +01:00
committed by GitHub
parent 20afedbd4a
commit 8272c83c6f

View File

@@ -67,22 +67,28 @@ void EmitDefineUndef(std::ostringstream &kernel, const char *name,
kernel << "#define " << name << " " << GetUndefValue(type) << '\n';
}
void EmitEnableExtension(std::ostringstream &kernel, ParameterType type)
void EmitEnableExtension(std::ostringstream &kernel,
const std::initializer_list<ParameterType> &types)
{
switch (type)
{
case ParameterType::Double:
kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
break;
bool needsFp64 = false;
case ParameterType::Float:
case ParameterType::Int:
case ParameterType::UInt:
case ParameterType::Long:
case ParameterType::ULong:
// No extension required.
break;
for (const auto &type : types)
{
switch (type)
{
case ParameterType::Double: needsFp64 = true; break;
case ParameterType::Float:
case ParameterType::Int:
case ParameterType::UInt:
case ParameterType::Long:
case ParameterType::ULong:
// No extension required.
break;
}
}
if (needsFp64) kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
}
std::string GetBuildOptions(bool relaxed_mode)
@@ -123,7 +129,7 @@ std::string GetUnaryKernel(const std::string &kernel_name, const char *builtin,
EmitDefineType(kernel, "RETTYPE", retType, vector_size_index);
EmitDefineType(kernel, "TYPE1", type1, vector_size_index);
EmitDefineUndef(kernel, "UNDEF1", type1);
EmitEnableExtension(kernel, type1);
EmitEnableExtension(kernel, { retType, type1 });
// clang-format off
const char *kernel_nonvec3[] = { R"(
@@ -199,7 +205,7 @@ std::string GetUnaryKernel(const std::string &kernel_name, const char *builtin,
EmitDefineType(kernel, "TYPE1", type1, vector_size_index);
EmitDefineUndef(kernel, "UNDEF1", type1);
EmitDefineUndef(kernel, "UNDEFR2", retType2);
EmitEnableExtension(kernel, type1);
EmitEnableExtension(kernel, { retType1, retType2, type1 });
// clang-format off
const char *kernel_nonvec3[] = { R"(
@@ -282,7 +288,7 @@ std::string GetBinaryKernel(const std::string &kernel_name, const char *builtin,
EmitDefineType(kernel, "TYPE2", type2, vector_size_index);
EmitDefineUndef(kernel, "UNDEF1", type1);
EmitDefineUndef(kernel, "UNDEF2", type2);
EmitEnableExtension(kernel, type1);
EmitEnableExtension(kernel, { retType, type1, type2 });
const bool is_vec3 = sizeValues[vector_size_index] == 3;
@@ -384,7 +390,7 @@ std::string GetBinaryKernel(const std::string &kernel_name, const char *builtin,
EmitDefineUndef(kernel, "UNDEF1", type1);
EmitDefineUndef(kernel, "UNDEF2", type2);
EmitDefineUndef(kernel, "UNDEFR2", retType2);
EmitEnableExtension(kernel, type1);
EmitEnableExtension(kernel, { retType1, retType2, type1, type2 });
// clang-format off
const char *kernel_nonvec3[] = { R"(
@@ -476,7 +482,7 @@ std::string GetTernaryKernel(const std::string &kernel_name,
EmitDefineUndef(kernel, "UNDEF1", type1);
EmitDefineUndef(kernel, "UNDEF2", type2);
EmitDefineUndef(kernel, "UNDEF3", type3);
EmitEnableExtension(kernel, type1);
EmitEnableExtension(kernel, { retType, type1, type2, type3 });
// clang-format off
const char *kernel_nonvec3[] = { R"(