mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-06-15 17:48:22 +02:00
compiler/rust: Add a float16 wrapper
This adds an F16 struct which provides a 16-bit float type using Mesa's existing half-precision support internally. Right now, it only contains the basics but it could be expanded if needed. Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41375>
This commit is contained in:
parent
7e2f41bd8f
commit
b60694b91e
4 changed files with 147 additions and 0 deletions
|
|
@ -3,6 +3,8 @@
|
||||||
* SPDX-License-Identifier: MIT
|
* SPDX-License-Identifier: MIT
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "util/double.h"
|
||||||
|
#include "util/half_float.h"
|
||||||
#include "util/memstream.h"
|
#include "util/memstream.h"
|
||||||
#include "rust_helpers.h"
|
#include "rust_helpers.h"
|
||||||
#include "nir.h"
|
#include "nir.h"
|
||||||
|
|
|
||||||
125
src/compiler/rust/float16.rs
Normal file
125
src/compiler/rust/float16.rs
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
// Copyright © 2026 Collabora, Ltd.
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
use crate::bindings::*;
|
||||||
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
|
use std::ops::Neg;
|
||||||
|
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[derive(Clone, Copy, Default)]
|
||||||
|
pub struct F16 {
|
||||||
|
v: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl F16 {
|
||||||
|
pub const RADIX: u32 = 2;
|
||||||
|
pub const BITS: u32 = 16;
|
||||||
|
pub const MANTISSA_DIGITS: u32 = 11;
|
||||||
|
pub const EPSILON: F16 = F16::from_bits(0x1400);
|
||||||
|
pub const MIN: F16 = F16::from_bits(0xfbff);
|
||||||
|
pub const MIN_POSITIVE: F16 = F16::from_bits(0x0001);
|
||||||
|
pub const MAX: F16 = F16::from_bits(0x7bff);
|
||||||
|
pub const MIN_EXP: i32 = -14;
|
||||||
|
pub const MAX_EXP: i32 = 15;
|
||||||
|
pub const NAN: F16 = F16::from_bits(0x7e00);
|
||||||
|
pub const INFINITY: F16 = F16::from_bits(0x7c00);
|
||||||
|
pub const NEG_INFINITY: F16 = F16::from_bits(0xfc00);
|
||||||
|
|
||||||
|
const SIGN_BIT: u16 = 0x8000;
|
||||||
|
|
||||||
|
pub const fn abs(self) -> F16 {
|
||||||
|
F16::from_bits(self.to_bits() & !F16::SIGN_BIT)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_f32_rtne(v: f32) -> F16 {
|
||||||
|
F16::from_bits(unsafe { _mesa_float_to_float16_rtne(v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_f64_rtne(v: f64) -> F16 {
|
||||||
|
F16::from_bits(unsafe { _mesa_double_to_float16_rtne(v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn is_nan(self) -> bool {
|
||||||
|
self.abs().to_bits() > 0x7c00
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn is_infinite(self) -> bool {
|
||||||
|
self.abs().to_bits() == F16::INFINITY.to_bits()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn is_finite(self) -> bool {
|
||||||
|
!self.is_infinite()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn is_sign_positive(self) -> bool {
|
||||||
|
(self.to_bits() & F16::SIGN_BIT) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn is_sign_negative(self) -> bool {
|
||||||
|
!self.is_sign_positive()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn to_bits(self) -> u16 {
|
||||||
|
self.v
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn from_bits(v: u16) -> F16 {
|
||||||
|
F16 { v }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max(self, other: F16) -> F16 {
|
||||||
|
if self.is_nan() {
|
||||||
|
self
|
||||||
|
} else if self < other {
|
||||||
|
other
|
||||||
|
} else {
|
||||||
|
// We also get here if other is nan
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn min(self, other: F16) -> F16 {
|
||||||
|
if self.is_nan() {
|
||||||
|
self
|
||||||
|
} else if self > other {
|
||||||
|
other
|
||||||
|
} else {
|
||||||
|
// We also get here if other is nan
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<F16> for f32 {
|
||||||
|
fn from(f: F16) -> f32 {
|
||||||
|
unsafe { _mesa_half_to_float(f.v) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<F16> for f64 {
|
||||||
|
fn from(f: F16) -> f64 {
|
||||||
|
unsafe { _mesa_half_to_float(f.v).into() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<F16> for F16 {
|
||||||
|
fn eq(&self, other: &F16) -> bool {
|
||||||
|
f32::from(*self).eq(&f32::from(*other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd<F16> for F16 {
|
||||||
|
fn partial_cmp(&self, other: &F16) -> Option<Ordering> {
|
||||||
|
f32::from(*self).partial_cmp(&f32::from(*other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Neg for F16 {
|
||||||
|
type Output = F16;
|
||||||
|
|
||||||
|
fn neg(self) -> Self::Output {
|
||||||
|
F16 {
|
||||||
|
v: self.v ^ F16::SIGN_BIT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -7,6 +7,7 @@ pub mod bitset;
|
||||||
pub mod cfg;
|
pub mod cfg;
|
||||||
pub mod dataflow;
|
pub mod dataflow;
|
||||||
pub mod depth_first_search;
|
pub mod depth_first_search;
|
||||||
|
pub mod float16;
|
||||||
pub mod memstream;
|
pub mod memstream;
|
||||||
pub mod nir;
|
pub mod nir;
|
||||||
pub mod nir_instr_printer;
|
pub mod nir_instr_printer;
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ _compiler_rs_sources = [
|
||||||
'cfg.rs',
|
'cfg.rs',
|
||||||
'dataflow.rs',
|
'dataflow.rs',
|
||||||
'depth_first_search.rs',
|
'depth_first_search.rs',
|
||||||
|
'float16.rs',
|
||||||
'memstream.rs',
|
'memstream.rs',
|
||||||
'nir_instr_printer.rs',
|
'nir_instr_printer.rs',
|
||||||
'nir.rs',
|
'nir.rs',
|
||||||
|
|
@ -64,6 +65,8 @@ _compiler_bindgen_args = [
|
||||||
'--allowlist-var', 'rust_.*',
|
'--allowlist-var', 'rust_.*',
|
||||||
'--allowlist-function', 'glsl_.*',
|
'--allowlist-function', 'glsl_.*',
|
||||||
'--allowlist-function', '_mesa_shader_stage_to_string',
|
'--allowlist-function', '_mesa_shader_stage_to_string',
|
||||||
|
'--allowlist-function', '_mesa_.*half.*',
|
||||||
|
'--allowlist-function', '_mesa_.*float16.*',
|
||||||
'--allowlist-function', 'nir_.*',
|
'--allowlist-function', 'nir_.*',
|
||||||
'--allowlist-function', 'compiler_rs.*',
|
'--allowlist-function', 'compiler_rs.*',
|
||||||
'--allowlist-function', 'u_memstream.*',
|
'--allowlist-function', 'u_memstream.*',
|
||||||
|
|
@ -94,6 +97,7 @@ _idep_libcompiler_c = declare_dependency(
|
||||||
_compiler_bindings_rs = rust.bindgen(
|
_compiler_bindings_rs = rust.bindgen(
|
||||||
input : ['bindings.h'],
|
input : ['bindings.h'],
|
||||||
output : 'bindings.rs',
|
output : 'bindings.rs',
|
||||||
|
output_inline_wrapper : 'bindings.c',
|
||||||
c_args : [
|
c_args : [
|
||||||
pre_args,
|
pre_args,
|
||||||
],
|
],
|
||||||
|
|
@ -104,6 +108,20 @@ _compiler_bindings_rs = rust.bindgen(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_libcompiler_bindings = static_library(
|
||||||
|
'compiler_bindings',
|
||||||
|
['bindings.h', _compiler_bindings_rs[1]],
|
||||||
|
gnu_symbol_visibility : 'hidden',
|
||||||
|
c_args : [
|
||||||
|
pre_args,
|
||||||
|
cc.get_supported_arguments('-Wno-missing-prototypes'),
|
||||||
|
],
|
||||||
|
dependencies : [
|
||||||
|
idep_nir_headers,
|
||||||
|
idep_mesautil,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
compiler_rs_bindgen_blocklist = []
|
compiler_rs_bindgen_blocklist = []
|
||||||
foreach type : _compiler_binding_types
|
foreach type : _compiler_binding_types
|
||||||
compiler_rs_bindgen_blocklist += ['--blocklist-type', type]
|
compiler_rs_bindgen_blocklist += ['--blocklist-type', type]
|
||||||
|
|
@ -122,6 +140,7 @@ _libcompiler_rs = static_library(
|
||||||
gnu_symbol_visibility : 'hidden',
|
gnu_symbol_visibility : 'hidden',
|
||||||
rust_abi : 'rust',
|
rust_abi : 'rust',
|
||||||
dependencies: [_idep_libcompiler_c],
|
dependencies: [_idep_libcompiler_c],
|
||||||
|
link_with: [_libcompiler_bindings],
|
||||||
rust_args: _compiler_rust_args,
|
rust_args: _compiler_rust_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue