compiler/types: add a bfloat16 type

Signed-off-by: Rohan Garg <rohan.garg@intel.com>
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Rohan Garg 2024-09-11 12:07:52 +02:00 committed by Marge Bot
parent ecd2d2cf46
commit 9e5d7eb88d
12 changed files with 59 additions and 4 deletions

View file

@ -26,12 +26,21 @@ def sampler_type(name, gl_type, base_type, dim, shadow, array, sampled_type):
})
def vector_type(base_name, vec_name, base_type, gl_type, extra_gl_type=None):
    """Register a base type together with its vector variants.

    Calls simple_type() once for each of the component counts
    1, 2, 3, 4, 5, 8 and 16 (always with a single column).

    base_name     -- name used for the 1-component type.
    vec_name      -- prefix for the vector names; the component count is
                     appended to it (e.g. vec_name + "2").
    base_type     -- base type identifier forwarded unchanged to simple_type().
    gl_type       -- prefix of the GL enum name, or None/empty when the type
                     has no GL enum; in that case every size is registered
                     with a GL type of None.
    extra_gl_type -- optional suffix appended to each generated GL enum name.
    """
    if extra_gl_type is None:
        extra_gl_type = ""

    # Only sizes 1-4 can carry a GL enum, and only when the base type has
    # one. Building the names under this guard keeps a gl_type of None away
    # from string concatenation, which would raise TypeError.
    gl_types = [None, None, None, None]
    if gl_type:
        gl_types = [gl_type + extra_gl_type,
                    gl_type + "_VEC2" + extra_gl_type,
                    gl_type + "_VEC3" + extra_gl_type,
                    gl_type + "_VEC4" + extra_gl_type]

    # Each size is registered exactly once.
    simple_type(base_name, gl_types[0], base_type, 1, 1)
    simple_type(vec_name + "2", gl_types[1], base_type, 2, 1)
    simple_type(vec_name + "3", gl_types[2], base_type, 3, 1)
    simple_type(vec_name + "4", gl_types[3], base_type, 4, 1)

    # The 5-, 8- and 16-component vectors are always registered without a
    # GL enum.
    simple_type(vec_name + "5", None, base_type, 5, 1)
    simple_type(vec_name + "8", None, base_type, 8, 1)
    simple_type(vec_name + "16", None, base_type, 16, 1)
@ -52,6 +61,8 @@ vector_type("uint16_t", "u16vec", "GLSL_TYPE_UINT16", "GL_UNSIGNED_INT16", "_N
vector_type("int8_t", "i8vec", "GLSL_TYPE_INT8", "GL_INT8", "_NV")
vector_type("uint8_t", "u8vec", "GLSL_TYPE_UINT8", "GL_UNSIGNED_INT8", "_NV")
vector_type("bfloat16_t", "bf16vec", "GLSL_TYPE_BFLOAT16", None)
simple_type("mat2", "GL_FLOAT_MAT2", "GLSL_TYPE_FLOAT", 2, 2)
simple_type("mat3", "GL_FLOAT_MAT3", "GLSL_TYPE_FLOAT", 3, 3)
simple_type("mat4", "GL_FLOAT_MAT4", "GLSL_TYPE_FLOAT", 4, 4)

View file

@ -1135,6 +1135,7 @@ do_comparison(void *mem_ctx, int operation, ir_rvalue *op0, ir_rvalue *op1)
switch (op0->type->base_type) {
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_BOOL:

View file

@ -163,6 +163,7 @@ copy_constant_to_storage(union gl_constant_value *storage,
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
/* All other types should have already been filtered by other
* paths in the caller.
*/

View file

@ -338,6 +338,7 @@ ir_constant::clone(void *mem_ctx, struct hash_table *ht) const
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_UINT64:

View file

@ -347,6 +347,8 @@ glsl_get_base_glsl_type(const glsl_type *t)
return &glsl_type_builtin_float16_t;
case GLSL_TYPE_DOUBLE:
return &glsl_type_builtin_double;
case GLSL_TYPE_BFLOAT16:
return &glsl_type_builtin_bfloat16_t;
case GLSL_TYPE_BOOL:
return &glsl_type_builtin_bool;
case GLSL_TYPE_UINT64:
@ -384,6 +386,7 @@ glsl_get_bare_type(const glsl_type *t)
case GLSL_TYPE_UINT16:
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
@ -593,6 +596,7 @@ glsl_ ## vname ## _type (unsigned components) \
VECN(components, float, vec)
VECN(components, float16_t, f16vec)
VECN(components, bfloat16_t, bf16vec)
VECN(components, double, dvec)
VECN(components, int, ivec)
VECN(components, uint, uvec)
@ -641,6 +645,8 @@ glsl_simple_explicit_type(unsigned base_type, unsigned rows, unsigned columns,
return glsl_vec_type(rows);
case GLSL_TYPE_FLOAT16:
return glsl_f16vec_type(rows);
case GLSL_TYPE_BFLOAT16:
return glsl_bf16vec_type(rows);
case GLSL_TYPE_DOUBLE:
return glsl_dvec_type(rows);
case GLSL_TYPE_BOOL:
@ -1742,6 +1748,7 @@ glsl_get_component_slots(const glsl_type *t)
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
return glsl_get_components(t);
@ -1794,6 +1801,7 @@ glsl_get_component_slots_aligned(const glsl_type *t, unsigned offset)
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
return glsl_get_components(t);
@ -2880,6 +2888,7 @@ glsl_count_vec4_slots(const glsl_type *t, bool is_gl_vertex_input, bool is_bindl
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
return t->matrix_columns;
case GLSL_TYPE_DOUBLE:
@ -3084,6 +3093,7 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8:
@ -3732,6 +3742,7 @@ glsl_get_natural_size_align_bytes(const glsl_type *type,
case GLSL_TYPE_UINT16:
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
@ -3791,6 +3802,7 @@ glsl_get_word_size_align_bytes(const glsl_type *type,
case GLSL_TYPE_UINT16:
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
@ -3850,6 +3862,7 @@ glsl_get_vec4_size_align_bytes(const glsl_type *type,
case GLSL_TYPE_UINT16:
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:

View file

@ -63,6 +63,7 @@ enum glsl_base_type {
GLSL_TYPE_INT,
GLSL_TYPE_FLOAT,
GLSL_TYPE_FLOAT16,
GLSL_TYPE_BFLOAT16,
GLSL_TYPE_DOUBLE,
GLSL_TYPE_UINT8,
GLSL_TYPE_INT8,
@ -99,6 +100,7 @@ static unsigned glsl_base_type_bit_size(enum glsl_base_type type)
return 32;
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT16:
case GLSL_TYPE_INT16:
return 16;
@ -167,6 +169,7 @@ glsl_base_type_get_bit_size(const enum glsl_base_type base_type)
return 32;
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_UINT16:
case GLSL_TYPE_INT16:
return 16;
@ -621,6 +624,12 @@ glsl_type_is_float_16_32_64(const glsl_type *t)
return t->base_type == GLSL_TYPE_FLOAT16 || glsl_type_is_float(t) || glsl_type_is_double(t);
}
/* Returns true when the base type of t is GLSL_TYPE_BFLOAT16, false for
 * every other base type.
 */
static inline bool
glsl_type_is_bfloat_16(const glsl_type *t)
{
   const bool is_bf16 = (GLSL_TYPE_BFLOAT16 == t->base_type);
   return is_bf16;
}
static inline bool
glsl_type_is_int_16_32_64(const glsl_type *t)
{
@ -937,6 +946,7 @@ static inline const glsl_type *glsl_int8_t_type(void) { return &glsl_type_builti
static inline const glsl_type *glsl_uint8_t_type(void) { return &glsl_type_builtin_uint8_t; }
static inline const glsl_type *glsl_bool_type(void) { return &glsl_type_builtin_bool; }
static inline const glsl_type *glsl_atomic_uint_type(void) { return &glsl_type_builtin_atomic_uint; }
/* Returns a pointer to the builtin bfloat16_t type object. */
static inline const glsl_type *glsl_bfloat16_t_type(void) { return &glsl_type_builtin_bfloat16_t; }
static inline const glsl_type *
glsl_floatN_t_type(unsigned bit_size)
@ -950,6 +960,16 @@ glsl_floatN_t_type(unsigned bit_size)
}
}
/* Maps a bit size to the builtin bfloat type of that width.
 *
 * Only a bit size of 16 is handled, which yields the builtin bfloat16_t
 * type; every other value reaches unreachable().
 */
static inline const glsl_type *
glsl_bfloatN_t_type(unsigned bit_size)
{
   if (bit_size == 16)
      return &glsl_type_builtin_bfloat16_t;

   unreachable("Unsupported bit size");
}
static inline const glsl_type *
glsl_intN_t_type(unsigned bit_size)
{
@ -978,6 +998,7 @@ glsl_uintN_t_type(unsigned bit_size)
const glsl_type *glsl_vec_type(unsigned components);
const glsl_type *glsl_f16vec_type(unsigned components);
const glsl_type *glsl_bf16vec_type(unsigned components);
const glsl_type *glsl_dvec_type(unsigned components);
const glsl_type *glsl_ivec_type(unsigned components);
const glsl_type *glsl_uvec_type(unsigned components);

View file

@ -2901,6 +2901,7 @@ nir_get_nir_type_for_glsl_base_type(enum glsl_base_type base_type)
case GLSL_TYPE_INT64: return nir_type_int64;
case GLSL_TYPE_FLOAT: return nir_type_float32;
case GLSL_TYPE_FLOAT16: return nir_type_float16;
case GLSL_TYPE_BFLOAT16: return nir_type_uint16;
case GLSL_TYPE_DOUBLE: return nir_type_float64;
/* clang-format on */

View file

@ -716,6 +716,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
case GLSL_TYPE_INT64:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_COOPERATIVE_MATRIX:
@ -809,6 +810,7 @@ _vtn_variable_copy(struct vtn_builder *b, struct vtn_pointer *dest,
case GLSL_TYPE_INT64:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_BOOL:
/* At this point, we have a scalar, vector, or matrix so we know that

View file

@ -41,6 +41,7 @@ type_size_xvec4(const struct glsl_type *type, bool as_vec4, bool bindless)
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_UINT16:

View file

@ -76,6 +76,7 @@ elk_type_for_base_type(const struct glsl_type *type)
case GLSL_TYPE_VOID:
case GLSL_TYPE_ERROR:
case GLSL_TYPE_COOPERATIVE_MATRIX:
case GLSL_TYPE_BFLOAT16:
unreachable("not reached");
}

View file

@ -574,6 +574,7 @@ elk_type_size_xvec4(const struct glsl_type *type, bool as_vec4, bool bindless)
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_UINT16:

View file

@ -989,6 +989,7 @@ associate_uniform_storage(struct gl_context *ctx,
FALLTHROUGH;
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
format = uniform_native;
columns = storage->type->matrix_columns;
break;