c11_atomics: unify host half representation and conversion with wrapper class (#2503)

Introduce a `HostHalf` wrapper class to eliminate the explicit
`cl_half_from_float` and `cl_half_to_float` conversions scattered
throughout the test code. The wrapper provides constructors and
conversion operators for semantic values, plus overloaded arithmetic
and comparison operators, simplifying half-precision arithmetic
operations.

Key improvements:
- `HostHalf` class with operator overloading for arithmetic and
comparisons
- Type traits `is_host_atomic_fp_v` and `is_host_fp_v` for generic FP
handling
- Unified floating-point atomic operations (add/sub/min/max/exchange)
- Removed 300+ lines of half-specific conditional branches
- Consistent calculation for all FP types
This commit is contained in:
Yilong Guo
2025-12-17 00:37:33 +08:00
committed by GitHub
parent 67fbbe4ee2
commit 119af24d54
4 changed files with 248 additions and 494 deletions

View File

@@ -24,6 +24,8 @@
#include "Windows.h"
#endif
extern cl_half_rounding_mode gHalfRoundingMode;
//flag for test verification (good test should discover non-atomic functions and fail)
//#define NON_ATOMIC_FUNCTIONS
@@ -37,6 +39,93 @@ enum TExplicitMemoryOrderType
MEMORY_ORDER_SEQ_CST
};
// Wrapper class for half-precision
// Host-side wrapper around the 16-bit half-precision storage type (cl_half).
// Every operation converts the stored bits to float, computes in single
// precision, and rounds the result back to half with the global rounding
// mode (gHalfRoundingMode), so host reference results follow half semantics.
class HostHalf {
public:
    // Construct from semantic (numeric) values; each rounds to half.
    HostHalf(cl_uint value = 0)
        : value(
              cl_half_from_float(static_cast<float>(value), gHalfRoundingMode))
    {}
    HostHalf(int value): HostHalf(static_cast<cl_uint>(value)) {}
    HostHalf(float value): value(cl_half_from_float(value, gHalfRoundingMode))
    {}
    HostHalf(double value): HostHalf(static_cast<float>(value)) {}

    // Convert back to semantic values.
    operator cl_uint() const
    {
        return static_cast<cl_uint>(cl_half_to_float(value));
    }
    operator float() const { return cl_half_to_float(value); }
    operator double() const
    {
        return static_cast<double>(cl_half_to_float(value));
    }

    // Bit-level round trip: wrap / expose the raw binary16 pattern.
    HostHalf(cl_half value): value(value) {}
    operator cl_half() const { return value; }

    // Unary negation: negate in float, re-round to half.
    HostHalf operator-() const { return HostHalf(-toFloat()); }

    // Comparisons are performed on the float values.
    bool operator==(const HostHalf &other) const
    {
        return toFloat() == other.toFloat();
    }
    bool operator!=(const HostHalf &other) const
    {
        return toFloat() != other.toFloat();
    }
    bool operator<(const HostHalf &other) const
    {
        return toFloat() < other.toFloat();
    }
    bool operator<=(const HostHalf &other) const
    {
        return toFloat() <= other.toFloat();
    }
    bool operator>(const HostHalf &other) const
    {
        return toFloat() > other.toFloat();
    }
    bool operator>=(const HostHalf &other) const
    {
        return toFloat() >= other.toFloat();
    }

    // Arithmetic computes in float; the HostHalf(float) constructor
    // rounds the result back to half.
    HostHalf operator+(const HostHalf &other) const
    {
        return HostHalf(toFloat() + other.toFloat());
    }
    HostHalf operator-(const HostHalf &other) const
    {
        return HostHalf(toFloat() - other.toFloat());
    }
    HostHalf operator*(const HostHalf &other) const
    {
        return HostHalf(toFloat() * other.toFloat());
    }
    HostHalf operator/(const HostHalf &other) const
    {
        return HostHalf(toFloat() / other.toFloat());
    }

    // Compound assignment: delegate to the binary operator, then store.
    HostHalf &operator+=(const HostHalf &other)
    {
        *this = *this + other;
        return *this;
    }
    HostHalf &operator-=(const HostHalf &other)
    {
        *this = *this - other;
        return *this;
    }
    HostHalf &operator*=(const HostHalf &other)
    {
        *this = *this * other;
        return *this;
    }
    HostHalf &operator/=(const HostHalf &other)
    {
        *this = *this / other;
        return *this;
    }

    // Stream the semantic (float) value, not the raw bits.
    friend std::ostream &operator<<(std::ostream &os, const HostHalf &hh)
    {
        return os << cl_half_to_float(hh.value);
    }

private:
    // Decode the stored bits to float for intermediate computation.
    float toFloat() const { return cl_half_to_float(value); }

    cl_half value; // raw IEEE 754 binary16 bit pattern
};
// abs overload so generic test code can call std::abs on HostHalf values.
// For NaN inputs the comparison below is false, so the value is returned
// unchanged (sign bit of a negative NaN is preserved).
// NOTE(review): adding new overloads (as opposed to specializations) to
// namespace std is formally undefined behavior per [namespace.std]. It works
// on common implementations, but consider an ADL-visible free function or a
// HostHalf member instead — TODO confirm this is acceptable for the CTS.
namespace std {
inline HostHalf abs(const HostHalf &value)
{
return value < HostHalf(0) ? -value : value;
}
} // namespace std
// host atomic types (applicable for atomic functions supported on host OS)
#ifdef WIN32
#define HOST_ATOMIC_INT unsigned long
@@ -73,7 +162,7 @@ enum TExplicitMemoryOrderType
#define HOST_UINT cl_uint
#define HOST_LONG cl_long
#define HOST_ULONG cl_ulong
#define HOST_HALF cl_half
#define HOST_HALF HostHalf
#define HOST_FLOAT cl_float
#define HOST_DOUBLE cl_double
@@ -91,6 +180,18 @@ enum TExplicitMemoryOrderType
extern cl_half_rounding_mode gHalfRoundingMode;
// True when HostAtomicType is one of the floating-point host atomic types
// (half, float, double); used to select the mutex-guarded generic FP path
// in the host atomic helpers.
template <typename HostAtomicType>
constexpr bool is_host_atomic_fp_v =
    std::is_same_v<HostAtomicType, HOST_ATOMIC_HALF>
    || std::is_same_v<HostAtomicType, HOST_ATOMIC_FLOAT>
    || std::is_same_v<HostAtomicType, HOST_ATOMIC_DOUBLE>;

// True when HostDataType is one of the floating-point host data types
// (half, float, double).
template <typename HostDataType>
constexpr bool is_host_fp_v =
    std::is_same_v<HostDataType, HOST_HALF>
    || std::is_same_v<HostDataType, HOST_FLOAT>
    || std::is_same_v<HostDataType, HOST_DOUBLE>;
// host atomic functions
void host_atomic_thread_fence(TExplicitMemoryOrderType order);
@@ -98,24 +199,13 @@ template <typename AtomicType, typename CorrespondingType>
CorrespondingType host_atomic_fetch_add(volatile AtomicType *a, CorrespondingType c,
TExplicitMemoryOrderType order)
{
if constexpr (std::is_same_v<AtomicType, HOST_ATOMIC_HALF>)
if constexpr (is_host_atomic_fp_v<AtomicType>)
{
static std::mutex mx;
std::lock_guard<std::mutex> lock(mx);
CorrespondingType old_value = *a;
*a = cl_half_from_float((cl_half_to_float(*a) + cl_half_to_float(c)),
gHalfRoundingMode);
return old_value;
}
else if constexpr (
std::is_same_v<
AtomicType,
HOST_ATOMIC_FLOAT> || std::is_same_v<AtomicType, HOST_ATOMIC_DOUBLE>)
{
static std::mutex mx;
std::lock_guard<std::mutex> lock(mx);
CorrespondingType old_value = *a;
*a += c;
CorrespondingType new_value = old_value + c;
*a = static_cast<AtomicType>(new_value);
return old_value;
}
else
@@ -135,24 +225,13 @@ template <typename AtomicType, typename CorrespondingType>
CorrespondingType host_atomic_fetch_sub(volatile AtomicType *a, CorrespondingType c,
TExplicitMemoryOrderType order)
{
if constexpr (
std::is_same_v<
AtomicType,
HOST_ATOMIC_DOUBLE> || std::is_same_v<AtomicType, HOST_ATOMIC_FLOAT>)
if constexpr (is_host_atomic_fp_v<AtomicType>)
{
static std::mutex mx;
std::lock_guard<std::mutex> lock(mx);
CorrespondingType old_value = *a;
*a -= c;
return old_value;
}
else if constexpr (std::is_same_v<AtomicType, HOST_ATOMIC_HALF>)
{
static std::mutex mx;
std::lock_guard<std::mutex> lock(mx);
CorrespondingType old_value = *a;
*a = cl_half_from_float((cl_half_to_float(*a) - cl_half_to_float(c)),
gHalfRoundingMode);
CorrespondingType new_value = old_value - c;
*a = static_cast<AtomicType>(new_value);
return old_value;
}
else
@@ -173,12 +252,14 @@ CorrespondingType host_atomic_exchange(volatile AtomicType *a, CorrespondingType
TExplicitMemoryOrderType order)
{
#if defined( _MSC_VER ) || (defined( __INTEL_COMPILER ) && defined(WIN32))
if (sizeof(CorrespondingType) == 2)
return InterlockedExchange16(reinterpret_cast<volatile SHORT *>(a), c);
if constexpr (sizeof(CorrespondingType) == 2)
return InterlockedExchange16(reinterpret_cast<volatile SHORT *>(a),
*reinterpret_cast<SHORT *>(&c));
else
return InterlockedExchange(reinterpret_cast<volatile LONG *>(a), c);
return InterlockedExchange(reinterpret_cast<volatile LONG *>(a),
*reinterpret_cast<LONG *>(&c));
#elif defined(__GNUC__)
return __sync_lock_test_and_set(a, c);
return __sync_lock_test_and_set(a, *reinterpret_cast<AtomicType *>(&c));
#else
log_info("Host function not implemented: atomic_exchange\n");
return 0;
@@ -195,30 +276,14 @@ bool host_atomic_compare_exchange(volatile AtomicType *a, CorrespondingType *exp
TExplicitMemoryOrderType order_failure)
{
CorrespondingType tmp;
if constexpr (std::is_same_v<AtomicType, HOST_ATOMIC_HALF>)
if constexpr (is_host_atomic_fp_v<AtomicType>)
{
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
tmp = *reinterpret_cast<volatile cl_half *>(a);
if (cl_half_to_float(tmp) == cl_half_to_float(*expected))
{
*reinterpret_cast<volatile cl_half *>(a) = desired;
return true;
}
*expected = tmp;
}
else if constexpr (
std::is_same_v<
AtomicType,
HOST_ATOMIC_DOUBLE> || std::is_same_v<AtomicType, HOST_ATOMIC_FLOAT>)
{
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
tmp = *reinterpret_cast<volatile float *>(a);
tmp = static_cast<CorrespondingType>(*a);
if (tmp == *expected)
{
*a = desired;
*a = static_cast<AtomicType>(desired);
return true;
}
*expected = tmp;
@@ -244,8 +309,8 @@ CorrespondingType host_atomic_load(volatile AtomicType *a,
TExplicitMemoryOrderType order)
{
#if defined( _MSC_VER ) || (defined( __INTEL_COMPILER ) && defined(WIN32))
if (sizeof(CorrespondingType) == 2)
auto prev = InterlockedOr16(reinterpret_cast<volatile SHORT *>(a), 0);
if constexpr (sizeof(CorrespondingType) == 2)
return InterlockedOr16(reinterpret_cast<volatile SHORT *>(a), 0);
else
return InterlockedExchangeAdd(reinterpret_cast<volatile LONG *>(a), 0);
#elif defined(__GNUC__)