diff --git a/test_conformance/math_brute_force/common.cpp b/test_conformance/math_brute_force/common.cpp index 71b4defe..83a33516 100644 --- a/test_conformance/math_brute_force/common.cpp +++ b/test_conformance/math_brute_force/common.cpp @@ -67,22 +67,28 @@ void EmitDefineUndef(std::ostringstream &kernel, const char *name, kernel << "#define " << name << " " << GetUndefValue(type) << '\n'; } -void EmitEnableExtension(std::ostringstream &kernel, ParameterType type) +void EmitEnableExtension(std::ostringstream &kernel, + const std::initializer_list &types) { - switch (type) - { - case ParameterType::Double: - kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"; - break; + bool needsFp64 = false; - case ParameterType::Float: - case ParameterType::Int: - case ParameterType::UInt: - case ParameterType::Long: - case ParameterType::ULong: - // No extension required. - break; + for (const auto &type : types) + { + switch (type) + { + case ParameterType::Double: needsFp64 = true; break; + + case ParameterType::Float: + case ParameterType::Int: + case ParameterType::UInt: + case ParameterType::Long: + case ParameterType::ULong: + // No extension required. + break; + } } + + if (needsFp64) kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"; } std::string GetBuildOptions(bool relaxed_mode) @@ -123,7 +129,7 @@ std::string GetUnaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineType(kernel, "RETTYPE", retType, vector_size_index); EmitDefineType(kernel, "TYPE1", type1, vector_size_index); EmitDefineUndef(kernel, "UNDEF1", type1); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType, type1 }); // clang-format off const char *kernel_nonvec3[] = { R"( @@ -199,7 +205,7 @@ std::string GetUnaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineType(kernel, "TYPE1", type1, vector_size_index); EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEFR2", retType2); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType1, retType2, type1 }); // clang-format off const char *kernel_nonvec3[] = { R"( @@ -282,7 +288,7 @@ std::string GetBinaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineType(kernel, "TYPE2", type2, vector_size_index); EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEF2", type2); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType, type1, type2 }); const bool is_vec3 = sizeValues[vector_size_index] == 3; @@ -384,7 +390,7 @@ std::string GetBinaryKernel(const std::string &kernel_name, const char *builtin, EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEF2", type2); EmitDefineUndef(kernel, "UNDEFR2", retType2); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType1, retType2, type1, type2 }); // clang-format off const char *kernel_nonvec3[] = { R"( @@ -476,7 +482,7 @@ std::string GetTernaryKernel(const std::string &kernel_name, EmitDefineUndef(kernel, "UNDEF1", type1); EmitDefineUndef(kernel, "UNDEF2", type2); EmitDefineUndef(kernel, "UNDEF3", type3); - EmitEnableExtension(kernel, type1); + EmitEnableExtension(kernel, { retType, type1, type2, type3 }); // clang-format off const char *kernel_nonvec3[] = { R"(