spirv_new: fix test_decorate to use the device's default rounding (#1987)

The verification code assumes the hardware uses CL_HALF_RTE, which
causes mismatched computation results when the hardware uses RTZ. Fix it
to use the hardware's default rounding mode.
This commit is contained in:
Chuang-Yu Cheng
2024-07-03 01:29:00 +09:00
committed by GitHub
parent 340b7c956a
commit 1cd0266ca1

View File

@@ -216,7 +216,8 @@ static inline Ti generate_saturated_rhs_input(RandomSeed &seed)
}
template <typename Ti, typename Tl, typename To>
static inline To compute_saturated_output(Ti lhs, Ti rhs)
static inline To compute_saturated_output(Ti lhs, Ti rhs,
cl_half_rounding_mode half_rounding)
{
constexpr auto loVal = std::numeric_limits<To>::min();
constexpr auto hiVal = std::numeric_limits<To>::max();
@@ -226,7 +227,7 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
cl_float f = cl_half_to_float(lhs) * cl_half_to_float(rhs);
// Quantize to fp16:
f = cl_half_to_float(cl_half_from_float(f, CL_HALF_RTE));
f = cl_half_to_float(cl_half_from_float(f, half_rounding));
To val = (To)std::min<float>(std::max<float>(f, loVal), hiVal);
if (isnan(cl_half_to_float(rhs)))
@@ -246,6 +247,26 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
return val;
}
// Translates the device's default half-precision rounding mode, as reported
// by get_default_rounding_mode() for CL_DEVICE_HALF_FP_CONFIG, into the
// corresponding cl_half_rounding_mode value.
//
// Returns CL_HALF_RTZ for CL_FP_ROUND_TO_ZERO and CL_HALF_RTE for
// CL_FP_ROUND_TO_NEAREST. Any other reported value is logged as an error and
// CL_HALF_RTE is returned as the fallback.
static cl_half_rounding_mode get_half_rounding_mode(cl_device_id deviceID)
{
    const cl_device_fp_config defaultMode =
        get_default_rounding_mode(deviceID, CL_DEVICE_HALF_FP_CONFIG);

    if (defaultMode == CL_FP_ROUND_TO_ZERO)
    {
        return CL_HALF_RTZ;
    }

    // Anything that is neither round-to-zero nor round-to-nearest is
    // unexpected: report it, then fall through to the RTE default below.
    if (defaultMode != CL_FP_ROUND_TO_NEAREST)
    {
        log_error("Error while acquiring half rounding mode");
    }

    return CL_HALF_RTE;
}
template <typename Ti, typename Tl, typename To>
int verify_saturated_results(cl_device_id deviceID, cl_context context,
cl_command_queue queue, const char *kname,
@@ -303,9 +324,16 @@ int verify_saturated_results(cl_device_id deviceID, cl_context context,
err = clEnqueueReadBuffer(queue, res, CL_TRUE, 0, out_bytes, &h_res[0], 0, NULL, NULL);
SPIRV_CHECK_ERROR(err, "Failed to read to output");
cl_half_rounding_mode half_rounding = CL_HALF_RTE;
if (std::is_same<Ti, cl_half>::value)
{
half_rounding = get_half_rounding_mode(deviceID);
}
for (int i = 0; i < num; i++)
{
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i]);
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i],
half_rounding);
if (val != h_res[i])
{