compiler/rust: Add a float16 wrapper

This adds an F16 struct which provides a 16-bit float type using Mesa's
existing half-precision support internally.  Right now, it only contains
the basics but it could be expanded if needed.
This commit is contained in:
Faith Ekstrand 2026-05-05 22:36:40 -04:00
parent a9b28b9838
commit efcc465029
4 changed files with 137 additions and 0 deletions

View file

@ -3,6 +3,8 @@
* SPDX-License-Identifier: MIT
*/
#include "util/double.h"
#include "util/half_float.h"
#include "util/memstream.h"
#include "rust_helpers.h"
#include "nir.h"

View file

@ -0,0 +1,115 @@
// Copyright © 2026 Collabora, Ltd.
// SPDX-License-Identifier: MIT
use crate::bindings::*;
use std::cmp::{Ordering, PartialOrd};
use std::ops::Neg;
#[derive(Clone, Copy, Default, PartialEq)]
pub struct F16 {
v: u16,
}
impl F16 {
pub const RADIX: u32 = 2;
pub const BITS: u32 = 16;
pub const MANTISSA_DIGITS: u32 = 11;
pub const EPSILON: F16 = F16::from_bits(0x1400);
pub const MIN: F16 = F16::from_bits(0xfbff);
pub const MIN_POSITIVE: F16 = F16::from_bits(0x0001);
pub const MAX: F16 = F16::from_bits(0x7bff);
pub const MIN_EXP: i32 = -14;
pub const MAX_EXP: i32 = 15;
pub const NAN: F16 = F16::from_bits(0x7e00);
pub const INFINITY: F16 = F16::from_bits(0x7c00);
pub const NEG_INFINITY: F16 = F16::from_bits(0xfc00);
pub const fn abs(self) -> F16 {
F16::from_bits(self.to_bits() & 0x7fff)
}
pub fn from_f32_rtne(v: f32) -> F16 {
F16::from_bits(unsafe { _mesa_float_to_float16_rtne(v) })
}
pub fn from_f64_rtne(v: f64) -> F16 {
F16::from_bits(unsafe { _mesa_double_to_float16_rtne(v) })
}
pub const fn is_nan(self) -> bool {
(self.to_bits() & 0x7fff) > 0x7c00
}
pub const fn is_infinite(self) -> bool {
self.abs().to_bits() == F16::INFINITY.to_bits()
}
pub const fn is_finite(self) -> bool {
!self.is_infinite()
}
pub const fn is_sign_positive(self) -> bool {
(self.to_bits() & 0x8000) == 0
}
pub const fn is_sign_negative(self) -> bool {
!self.is_sign_positive()
}
pub const fn to_bits(self) -> u16 {
self.v
}
pub const fn from_bits(v: u16) -> F16 {
F16 { v }
}
pub fn max(self, other: F16) -> F16 {
if self.is_nan() {
self
} else if self < other {
other
} else {
// We also get here if other is nan
self
}
}
pub fn min(self, other: F16) -> F16 {
if self.is_nan() {
self
} else if self > other {
other
} else {
// We also get here if other is nan
self
}
}
}
impl From<F16> for f32 {
fn from(f: F16) -> f32 {
unsafe { _mesa_half_to_float(f.v) }
}
}
impl From<F16> for f64 {
fn from(f: F16) -> f64 {
unsafe { _mesa_half_to_float(f.v).into() }
}
}
impl PartialOrd<F16> for F16 {
fn partial_cmp(&self, other: &F16) -> Option<Ordering> {
f32::from(*self).partial_cmp(&f32::from(*other))
}
}
impl Neg for F16 {
type Output = F16;
// Required method
fn neg(self) -> Self::Output {
F16 { v: self.v ^ 0x8000 }
}
}

View file

@ -7,6 +7,7 @@ pub mod bitset;
pub mod cfg;
pub mod dataflow;
pub mod depth_first_search;
pub mod float16;
pub mod memstream;
pub mod nir;
pub mod nir_instr_printer;

View file

@ -7,6 +7,7 @@ _compiler_rs_sources = [
'cfg.rs',
'dataflow.rs',
'depth_first_search.rs',
'float16.rs',
'memstream.rs',
'nir_instr_printer.rs',
'nir.rs',
@ -63,6 +64,8 @@ _compiler_bindgen_args = [
'--allowlist-var', 'rust_.*',
'--allowlist-function', 'glsl_.*',
'--allowlist-function', '_mesa_shader_stage_to_string',
'--allowlist-function', '_mesa_.*half.*',
'--allowlist-function', '_mesa_.*float16.*',
'--allowlist-function', 'nir_.*',
'--allowlist-function', 'compiler_rs.*',
'--allowlist-function', 'u_memstream.*',
@ -93,6 +96,7 @@ _idep_libcompiler_c = declare_dependency(
_compiler_bindings_rs = rust.bindgen(
input : ['bindings.h'],
output : 'bindings.rs',
output_inline_wrapper : 'bindings.c',
c_args : [
pre_args,
],
@ -103,6 +107,20 @@ _compiler_bindings_rs = rust.bindgen(
],
)
_libcompiler_bindings = static_library(
'compiler_bindings',
['bindings.h', _compiler_bindings_rs[1]],
gnu_symbol_visibility : 'hidden',
c_args : [
pre_args,
cc.get_supported_arguments('-Wno-missing-prototypes'),
],
dependencies : [
idep_nir_headers,
idep_mesautil,
],
)
compiler_rs_bindgen_blocklist = []
foreach type : _compiler_binding_types
compiler_rs_bindgen_blocklist += ['--blocklist-type', type]
@ -121,6 +139,7 @@ _libcompiler_rs = static_library(
gnu_symbol_visibility : 'hidden',
rust_abi : 'rust',
dependencies: [_idep_libcompiler_c],
link_with: [_libcompiler_bindings],
rust_args: _compiler_rust_args,
)