From efcc465029763ed4435c3bef58c825cd4fc18ff2 Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Tue, 5 May 2026 22:36:40 -0400 Subject: [PATCH] compiler/rust: Add a float16 wrapper This adds an F16 struct which provides a 16-bit float type using Mesa's existing half-precision support internally. Right now, it only contains the basics but it could be expanded if needed. --- src/compiler/rust/bindings.h | 2 + src/compiler/rust/float16.rs | 115 ++++++++++++++++++++++++++++++++++ src/compiler/rust/lib.rs | 1 + src/compiler/rust/meson.build | 19 ++++++ 4 files changed, 137 insertions(+) create mode 100644 src/compiler/rust/float16.rs diff --git a/src/compiler/rust/bindings.h b/src/compiler/rust/bindings.h index d1c6b08b6bb..1ad0f2157e5 100644 --- a/src/compiler/rust/bindings.h +++ b/src/compiler/rust/bindings.h @@ -3,6 +3,8 @@ * SPDX-License-Identifier: MIT */ +#include "util/double.h" +#include "util/half_float.h" #include "util/memstream.h" #include "rust_helpers.h" #include "nir.h" diff --git a/src/compiler/rust/float16.rs b/src/compiler/rust/float16.rs new file mode 100644 index 00000000000..f997da5e3d0 --- /dev/null +++ b/src/compiler/rust/float16.rs @@ -0,0 +1,115 @@ +// Copyright © 2026 Collabora, Ltd. +// SPDX-License-Identifier: MIT + +use crate::bindings::*; +use std::cmp::{Ordering, PartialOrd}; +use std::ops::Neg; + +#[derive(Clone, Copy, Default, PartialEq)] +pub struct F16 { + v: u16, +} + +impl F16 { + pub const RADIX: u32 = 2; + pub const BITS: u32 = 16; + pub const MANTISSA_DIGITS: u32 = 11; + pub const EPSILON: F16 = F16::from_bits(0x1400); + pub const MIN: F16 = F16::from_bits(0xfbff); + pub const MIN_POSITIVE: F16 = F16::from_bits(0x0001); + pub const MAX: F16 = F16::from_bits(0x7bff); + pub const MIN_EXP: i32 = -14; + pub const MAX_EXP: i32 = 15; + pub const NAN: F16 = F16::from_bits(0x7e00); + pub const INFINITY: F16 = F16::from_bits(0x7c00); + pub const NEG_INFINITY: F16 = F16::from_bits(0xfc00); + + pub const fn abs(self) -> F16 { + F16::from_bits(self.to_bits() & 0x7fff) + } + + pub fn from_f32_rtne(v: f32) -> F16 { + F16::from_bits(unsafe { _mesa_float_to_float16_rtne(v) }) + } + + pub fn from_f64_rtne(v: f64) -> F16 { + F16::from_bits(unsafe { _mesa_double_to_float16_rtne(v) }) + } + + pub const fn is_nan(self) -> bool { + (self.to_bits() & 0x7fff) > 0x7c00 + } + + pub const fn is_infinite(self) -> bool { + self.abs().to_bits() == F16::INFINITY.to_bits() + } + + pub const fn is_finite(self) -> bool { + !self.is_infinite() + } + + pub const fn is_sign_positive(self) -> bool { + (self.to_bits() & 0x8000) == 0 + } + + pub const fn is_sign_negative(self) -> bool { + !self.is_sign_positive() + } + + pub const fn to_bits(self) -> u16 { + self.v + } + + pub const fn from_bits(v: u16) -> F16 { + F16 { v } + } + + pub fn max(self, other: F16) -> F16 { + if self.is_nan() { + self + } else if self < other { + other + } else { + // We also get here if other is nan + self + } + } + + pub fn min(self, other: F16) -> F16 { + if self.is_nan() { + self + } else if self > other { + other + } else { + // We also get here if other is nan + self + } + } +} + +impl From for f32 { + fn from(f: F16) -> f32 { + unsafe { _mesa_half_to_float(f.v) } + } +} + +impl From for f64 { + fn from(f: F16) -> f64 { + unsafe { _mesa_half_to_float(f.v).into() } + } +} + +impl PartialOrd for F16 { + fn partial_cmp(&self, other: &F16) -> Option { + f32::from(*self).partial_cmp(&f32::from(*other)) + } +} + +impl Neg for F16 { + type Output = F16; + + // Required method + fn neg(self) -> Self::Output { + F16 { v: self.v ^ 0x8000 } + } +} diff --git a/src/compiler/rust/lib.rs b/src/compiler/rust/lib.rs index 68b0087f18b..a7cafb5b4ac 100644 --- a/src/compiler/rust/lib.rs +++ b/src/compiler/rust/lib.rs @@ -7,6 +7,7 @@ pub mod bitset; pub mod cfg; pub mod dataflow; pub mod depth_first_search; +pub mod float16; pub mod memstream; pub mod nir; pub mod nir_instr_printer; diff --git a/src/compiler/rust/meson.build b/src/compiler/rust/meson.build index 2dbf2bfe176..258bff92c51 100644 --- a/src/compiler/rust/meson.build +++ b/src/compiler/rust/meson.build @@ -7,6 +7,7 @@ _compiler_rs_sources = [ 'cfg.rs', 'dataflow.rs', 'depth_first_search.rs', + 'float16.rs', 'memstream.rs', 'nir_instr_printer.rs', 'nir.rs', @@ -63,6 +64,8 @@ _compiler_bindgen_args = [ '--allowlist-var', 'rust_.*', '--allowlist-function', 'glsl_.*', '--allowlist-function', '_mesa_shader_stage_to_string', + '--allowlist-function', '_mesa_.*half.*', + '--allowlist-function', '_mesa_.*float16.*', '--allowlist-function', 'nir_.*', '--allowlist-function', 'compiler_rs.*', '--allowlist-function', 'u_memstream.*', @@ -93,6 +96,7 @@ _idep_libcompiler_c = declare_dependency( _compiler_bindings_rs = rust.bindgen( input : ['bindings.h'], output : 'bindings.rs', + output_inline_wrapper : 'bindings.c', c_args : [ pre_args, ], @@ -103,6 +107,20 @@ _compiler_bindings_rs = rust.bindgen( ], ) +_libcompiler_bindings = static_library( + 'compiler_bindings', + ['bindings.h', _compiler_bindings_rs[1]], + gnu_symbol_visibility : 'hidden', + c_args : [ + pre_args, + cc.get_supported_arguments('-Wno-missing-prototypes'), + ], + dependencies : [ + idep_nir_headers, + idep_mesautil, + ], +) + compiler_rs_bindgen_blocklist = [] foreach type : _compiler_binding_types compiler_rs_bindgen_blocklist += ['--blocklist-type', type] @@ -121,6 +139,7 @@ _libcompiler_rs = static_library( gnu_symbol_visibility : 'hidden', rust_abi : 'rust', dependencies: [_idep_libcompiler_c], + link_with: [_libcompiler_bindings], rust_args: _compiler_rust_args, )