mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-26 12:50:10 +01:00
util: Add functions to convert float to/from bfloat16
Reviewed-by: Rohan Garg <rohan.garg@intel.com> Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
parent
3e0418ba02
commit
ecd2d2cf46
1 changed files with 69 additions and 0 deletions
69
src/util/bfloat.h
Normal file
69
src/util/bfloat.h
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* Copyright © 2025 Intel Corporation
|
||||
* SPDX-License-Identifier: MIT
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#include "u_math.h"
|
||||
|
||||
/* When converting a Float NaN value to BFloat16 it is possible that the
|
||||
* significand bits that make the value a NaN will be rounded/truncated off
|
||||
* so ensure at least one significand bit is set.
|
||||
*/
|
||||
static inline uint16_t
|
||||
_mesa_float_nan_to_bfloat_bits(union fi x)
|
||||
{
|
||||
assert(isnan(x.f));
|
||||
return x.ui >> 16 | 1 << 6;
|
||||
}
|
||||
|
||||
/* Round-towards-zero. */
|
||||
static inline uint16_t
|
||||
_mesa_float_to_bfloat16_bits_rtz(float f)
|
||||
{
|
||||
union fi x;
|
||||
x.f = f;
|
||||
|
||||
if (isnan(f))
|
||||
_mesa_float_nan_to_bfloat_bits(x);
|
||||
|
||||
return x.ui >> 16;
|
||||
}
|
||||
|
||||
/* Round-to-nearest-even. */
|
||||
static inline uint16_t
|
||||
_mesa_float_to_bfloat16_bits_rte(float f)
|
||||
{
|
||||
union fi x;
|
||||
x.f = f;
|
||||
|
||||
if (isnan(f))
|
||||
_mesa_float_nan_to_bfloat_bits(x);
|
||||
|
||||
/* Use the tail part that is discarded to decide rounding,
|
||||
* break the tie with the nearest even.
|
||||
*
|
||||
* Overflow of the significand value will turn to zero and
|
||||
* increment the exponent. If exponent reaches 0xff, the
|
||||
* value will correctly end up as +/- Inf.
|
||||
*/
|
||||
uint32_t result = x.ui >> 16;
|
||||
const uint32_t tail = x.ui & 0xffff;
|
||||
if (tail > 0x8000 || (tail == 0x8000 && (result & 1) == 1))
|
||||
result++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static inline float
|
||||
_mesa_bfloat16_bits_to_float(uint16_t bf)
|
||||
{
|
||||
union fi x;
|
||||
x.ui = bf << 16;
|
||||
|
||||
return x.f;
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue