From 8272c83c6fb87dc4ac5ad08169fa59f9a07e086b Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Wed, 10 May 2023 10:45:44 +0100 Subject: [PATCH] math_brute_force: consider all types for extension pragmas (#1705) When generating the kernel code, consider the return type(s) and the types of all parameters, instead of only the first parameter type. This fixes a missing extension pragma for certain cases (such as `nan`). Signed-off-by: Sven van Haastregt --- test_conformance/math_brute_force/common.cpp | 42 +++++++++++--------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/test_conformance/math_brute_force/common.cpp b/test_conformance/math_brute_force/common.cpp index 71b4defe..83a33516 100644 --- a/test_conformance/math_brute_force/common.cpp +++ b/test_conformance/math_brute_force/common.cpp @@ -67,22 +67,28 @@ void EmitDefineUndef(std::ostringstream &kernel, const char *name, kernel << "#define " << name << " " << GetUndefValue(type) << '\n'; } -void EmitEnableExtension(std::ostringstream &kernel, ParameterType type) +void EmitEnableExtension(std::ostringstream &kernel, + const std::initializer_list<ParameterType> &types) { - switch (type) - { - case ParameterType::Double: - kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"; - break; + bool needsFp64 = false; - case ParameterType::Float: - case ParameterType::Int: - case ParameterType::UInt: - case ParameterType::Long: - case ParameterType::ULong: - // No extension required. - break; + for (const auto &type : types) + { + switch (type) + { + case ParameterType::Double: needsFp64 = true; break; + + case ParameterType::Float: + case ParameterType::Int: + case ParameterType::UInt: + case ParameterType::Long: + case ParameterType::ULong: + // No extension required. 
+ break; + } } + + if (needsFp64) kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"; } std::string GetBuildOptions(bool relaxed_mode) @@ -123,7 +129,7 @@ std::string GetUnaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineType(kernel, "RETTYPE", retType, vector_size_index); EmitDefineType(kernel, "TYPE1", type1, vector_size_index); EmitDefineUndef(kernel, "UNDEF1", type1); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType, type1 }); // clang-format off const char *kernel_nonvec3[] = { R"( @@ -199,7 +205,7 @@ std::string GetUnaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineType(kernel, "TYPE1", type1, vector_size_index); EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEFR2", retType2); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType1, retType2, type1 }); // clang-format off const char *kernel_nonvec3[] = { R"( @@ -282,7 +288,7 @@ std::string GetBinaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineType(kernel, "TYPE2", type2, vector_size_index); EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEF2", type2); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType, type1, type2 }); const bool is_vec3 = sizeValues[vector_size_index] == 3; @@ -384,7 +390,7 @@ std::string GetBinaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEF2", type2); EmitDefineUndef(kernel, "UNDEFR2", retType2); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType1, retType2, type1, type2 }); // clang-format off const char *kernel_nonvec3[] = { R"( @@ -476,7 +482,7 @@ std::string GetTernaryKernel(const std::string &kernel_name, EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEF2", type2); EmitDefineUndef(kernel, "UNDEF3", type3); - EmitEnableExtension(kernel, type1); + 
EmitEnableExtension(kernel, { retType, type1, type2, type3 }); // clang-format off const char *kernel_nonvec3[] = { R"(