From 2d0f4f2c17b79830e9780a68bc473718d4abd4ad Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Tue, 30 May 2023 23:26:14 -0700 Subject: [PATCH] compiler/types: Add support for Cooperative Matrix types Reviewed-by: Jesse Natalie Reviewed-by: Ian Romanick Reviewed-by: Bas Nieuwenhuizen Part-of: --- src/compiler/glsl/ast_to_hir.cpp | 3 + .../glsl/gl_nir_link_uniform_initializers.c | 2 + src/compiler/glsl/ir_clone.cpp | 3 + src/compiler/glsl_types.cpp | 92 +++++++++++++++++++ src/compiler/glsl_types.h | 32 +++++++ src/compiler/nir/nir.c | 1 + src/compiler/nir_types.cpp | 28 ++++++ src/compiler/nir_types.h | 6 ++ src/intel/compiler/brw_shader.cpp | 1 + src/intel/compiler/brw_vec4_visitor.cpp | 1 + src/mesa/main/uniform_query.cpp | 1 + 11 files changed, 170 insertions(+) diff --git a/src/compiler/glsl/ast_to_hir.cpp b/src/compiler/glsl/ast_to_hir.cpp index f91d1ae3b23..b9721600281 100644 --- a/src/compiler/glsl/ast_to_hir.cpp +++ b/src/compiler/glsl/ast_to_hir.cpp @@ -1191,6 +1191,9 @@ do_comparison(void *mem_ctx, int operation, ir_rvalue *op0, ir_rvalue *op1) * ignores the sampler present in the type. 
*/ break; + + case GLSL_TYPE_COOPERATIVE_MATRIX: + unreachable("unsupported base type cooperative matrix"); } if (cmp == NULL) diff --git a/src/compiler/glsl/gl_nir_link_uniform_initializers.c b/src/compiler/glsl/gl_nir_link_uniform_initializers.c index 74e52d898c6..80cd6a15e2b 100644 --- a/src/compiler/glsl/gl_nir_link_uniform_initializers.c +++ b/src/compiler/glsl/gl_nir_link_uniform_initializers.c @@ -169,6 +169,8 @@ copy_constant_to_storage(union gl_constant_value *storage, */ assert(!"Should not get here."); break; + case GLSL_TYPE_COOPERATIVE_MATRIX: + unreachable("unsupported base type cooperative matrix"); } i += dmul; } diff --git a/src/compiler/glsl/ir_clone.cpp b/src/compiler/glsl/ir_clone.cpp index 5a384478078..059ae579b8a 100644 --- a/src/compiler/glsl/ir_clone.cpp +++ b/src/compiler/glsl/ir_clone.cpp @@ -370,6 +370,9 @@ ir_constant::clone(void *mem_ctx, struct hash_table *ht) const case GLSL_TYPE_INTERFACE: assert(!"Should not get here."); break; + + case GLSL_TYPE_COOPERATIVE_MATRIX: + unreachable("unsupported base type cooperative matrix"); } return NULL; diff --git a/src/compiler/glsl_types.cpp b/src/compiler/glsl_types.cpp index e9871d2b3a4..d563b204a5d 100644 --- a/src/compiler/glsl_types.cpp +++ b/src/compiler/glsl_types.cpp @@ -51,6 +51,7 @@ static struct { hash_table *explicit_matrix_types; hash_table *array_types; + hash_table *cmat_types; hash_table *struct_types; hash_table *interface_types; hash_table *subroutine_types; @@ -391,6 +392,7 @@ const glsl_type *glsl_type::get_bare_type() const return get_array_instance(this->fields.array->get_bare_type(), this->length); + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_SAMPLER: case GLSL_TYPE_TEXTURE: case GLSL_TYPE_IMAGE: @@ -527,6 +529,19 @@ make_array_type(linear_ctx *lin_ctx, const glsl_type *element_type, unsigned len return t; } +static const char * /* Label for a cooperative matrix use; make_cmat_type() places it in the type name. */ +glsl_cmat_use_to_string(enum glsl_cmat_use use) +{ + switch (use) { + case GLSL_CMAT_USE_NONE: return "NONE"; + case GLSL_CMAT_USE_A: 
return "A"; + case GLSL_CMAT_USE_B: return "B"; + case GLSL_CMAT_USE_ACCUMULATOR: return "ACCUMULATOR"; + default: + unreachable("invalid cooperative matrix use"); + } +} + const glsl_type * glsl_type::vec(unsigned components, const glsl_type *const ts[]) { @@ -1250,6 +1265,68 @@ glsl_type::get_array_instance(const glsl_type *element, return t; } +static const struct glsl_type * /* Allocate and name a new cooperative matrix type from lin_ctx; caching is left to the caller. */ +make_cmat_type(linear_ctx *lin_ctx, const struct glsl_cmat_description desc) +{ + assert(lin_ctx != NULL); + + struct glsl_type *t = linear_zalloc(lin_ctx, struct glsl_type); + t->base_type = GLSL_TYPE_COOPERATIVE_MATRIX; + t->sampled_type = GLSL_TYPE_VOID; + t->vector_elements = 1; + t->cmat_desc = desc; + + const struct glsl_type *element_type = glsl_type::get_instance(desc.element_type, 1, 1); + t->name_id = (uintptr_t ) linear_asprintf(lin_ctx, "coopmat<%s, %s, %u, %u, %s>", + glsl_get_type_name(element_type), + mesa_scope_name((mesa_scope)desc.scope), + desc.rows, desc.cols, + glsl_cmat_use_to_string((enum glsl_cmat_use)desc.use)); + + return t; +} + +const glsl_type * /* Return the cached type for desc, creating and inserting it on the first request. */ +glsl_type::get_cmat_instance(const struct glsl_cmat_description desc) +{ + STATIC_ASSERT(sizeof(struct glsl_cmat_description) == 4); /* every field must fit in the 32-bit key below */ + + const uint32_t key = (uint32_t)desc.element_type | (uint32_t)desc.scope << 5 | + (uint32_t)desc.rows << 8 | (uint32_t)desc.cols << 16 | + (uint32_t)desc.use << 24; /* casts keep the shifts unsigned; uint8_t operands would promote to int */ + const uint32_t key_hash = _mesa_hash_uint(&key); + + simple_mtx_lock(&glsl_type_cache_mutex); /* lookup and insertion both happen while holding the cache mutex */ + assert(glsl_type_cache.users > 0); + void *mem_ctx = glsl_type_cache.mem_ctx; + + if (glsl_type_cache.cmat_types == NULL) { /* table is created lazily on first use */ + glsl_type_cache.cmat_types = + _mesa_hash_table_create_u32_keys(mem_ctx); + } + hash_table *cmat_types = glsl_type_cache.cmat_types; + + const struct hash_entry *entry = _mesa_hash_table_search_pre_hashed( + cmat_types, key_hash, (void *) (uintptr_t) key); + if (entry == NULL) { + const struct glsl_type *t = make_cmat_type(glsl_type_cache.lin_ctx, desc); + entry = _mesa_hash_table_insert_pre_hashed(cmat_types, key_hash, + (void *) (uintptr_t) key, 
(void *) t); + } + + const struct glsl_type *t = (const struct glsl_type *)entry->data; + simple_mtx_unlock(&glsl_type_cache_mutex); + + assert(t->base_type == GLSL_TYPE_COOPERATIVE_MATRIX); /* debug check: the cached entry must describe the requested matrix */ + assert(t->cmat_desc.element_type == desc.element_type); + assert(t->cmat_desc.scope == desc.scope); + assert(t->cmat_desc.rows == desc.rows); + assert(t->cmat_desc.cols == desc.cols); + assert(t->cmat_desc.use == desc.use); + + return t; +} + bool glsl_type::compare_no_precision(const glsl_type *b) const { @@ -1679,6 +1756,7 @@ glsl_type::component_slots() const case GLSL_TYPE_SUBROUTINE: return 1; + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_ATOMIC_UINT: case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: @@ -1745,6 +1823,7 @@ glsl_type::component_slots_aligned(unsigned offset) const case GLSL_TYPE_SUBROUTINE: return 1; + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_ATOMIC_UINT: case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: @@ -2599,6 +2678,10 @@ glsl_type::get_explicit_type_for_size_align(glsl_type_size_align_func type_info, type_info(this, size, alignment); assert(*alignment > 0); return this; + } else if (this->is_cmat()) { + *size = 0; + *alignment = 0; + return this; } else if (this->is_scalar()) { type_info(this, size, alignment); assert(*size == explicit_type_scalar_byte_size(this)); @@ -2822,6 +2905,7 @@ glsl_type::count_vec4_slots(bool is_gl_vertex_input, bool is_bindless) const case GLSL_TYPE_SUBROUTINE: return 1; + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_ATOMIC_UINT: case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: @@ -2925,6 +3009,7 @@ union packed_type { unsigned length:13; unsigned explicit_stride:14; } array; + glsl_cmat_description cmat_desc; struct { unsigned base_type:5; unsigned interface_packing_or_packed:2; @@ -3039,6 +3124,10 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type) blob_write_uint32(blob, type->explicit_stride); encode_type_to_blob(blob, type->fields.array); return; + case GLSL_TYPE_COOPERATIVE_MATRIX: /* NOTE(review): cmat_desc shares storage with the base_type bit-field of the other packed_type members; confirm this assignment leaves a base_type the decoder can dispatch on. */ + encoded.cmat_desc = 
type->cmat_desc; + blob_write_uint32(blob, encoded.u32); + return; case GLSL_TYPE_STRUCT: case GLSL_TYPE_INTERFACE: encoded.strct.length = MIN2(type->length, 0xfffff); @@ -3145,6 +3234,9 @@ decode_type_from_blob(struct blob_reader *blob) return glsl_type::get_array_instance(decode_type_from_blob(blob), length, explicit_stride); } + case GLSL_TYPE_COOPERATIVE_MATRIX: { /* NOTE(review): presumably reached through the decoded base_type bits, which cmat_desc overlays as well; verify against the encoder. */ + return glsl_type::get_cmat_instance(encoded.cmat_desc); + } case GLSL_TYPE_STRUCT: case GLSL_TYPE_INTERFACE: { char *name = blob_read_string(blob); diff --git a/src/compiler/glsl_types.h b/src/compiler/glsl_types.h index 06e109695fb..9d2e7044bb5 100644 --- a/src/compiler/glsl_types.h +++ b/src/compiler/glsl_types.h @@ -76,6 +76,7 @@ enum glsl_base_type { GLSL_TYPE_UINT64, GLSL_TYPE_INT64, GLSL_TYPE_BOOL, + GLSL_TYPE_COOPERATIVE_MATRIX, GLSL_TYPE_SAMPLER, GLSL_TYPE_TEXTURE, GLSL_TYPE_IMAGE, @@ -167,6 +168,7 @@ glsl_base_type_get_bit_size(const enum glsl_base_type base_type) case GLSL_TYPE_UINT: case GLSL_TYPE_FLOAT: /* TODO handle mediump */ case GLSL_TYPE_SUBROUTINE: + case GLSL_TYPE_COOPERATIVE_MATRIX: return 32; case GLSL_TYPE_FLOAT16: @@ -279,6 +281,24 @@ enum { GLSL_PRECISION_LOW }; +enum glsl_cmat_use { /* Use of a cooperative matrix (none, A, B or accumulator); stored in the use field of glsl_cmat_description. */ + GLSL_CMAT_USE_NONE = 0, + GLSL_CMAT_USE_A, + GLSL_CMAT_USE_B, + GLSL_CMAT_USE_ACCUMULATOR, +}; + +struct glsl_cmat_description { + /* MSVC can't merge bitfields of different types and also sign-extends enums, + * so use uint8_t for those cases. The struct must stay exactly 4 bytes: glsl_type::get_cmat_instance() static-asserts its size. + */ + uint8_t element_type:5; /* enum glsl_base_type */ + uint8_t scope:3; /* mesa_scope */ + uint8_t rows; + uint8_t cols; + uint8_t use; /* enum glsl_cmat_use */ +}; + const char *glsl_get_type_name(const struct glsl_type *type); struct glsl_type { @@ -297,6 +317,8 @@ struct glsl_type { unsigned interface_packing:2; unsigned interface_row_major:1; + struct glsl_cmat_description cmat_desc; /* Filled in by make_cmat_type() for GLSL_TYPE_COOPERATIVE_MATRIX types. */ + /** * For \c GLSL_TYPE_STRUCT this specifies if the struct is packed or not. 
* @@ -456,6 +478,11 @@ struct glsl_type { unsigned array_size, unsigned explicit_stride = 0); + /** + * Get the instance of a cooperative matrix type + */ + static const glsl_type *get_cmat_instance(const struct glsl_cmat_description desc); + /** * Get the instance of a record type */ @@ -931,6 +958,11 @@ struct glsl_type { return is_array() && fields.array->is_array(); } + bool is_cmat() const + { + return base_type == GLSL_TYPE_COOPERATIVE_MATRIX; + } + /** * Query whether or not a type is a record */ diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c index 26dedff4e37..6282dea3a53 100644 --- a/src/compiler/nir/nir.c +++ b/src/compiler/nir/nir.c @@ -2755,6 +2755,7 @@ nir_get_nir_type_for_glsl_base_type(enum glsl_base_type base_type) case GLSL_TYPE_DOUBLE: return nir_type_float64; /* clang-format on */ + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_SAMPLER: case GLSL_TYPE_TEXTURE: case GLSL_TYPE_IMAGE: diff --git a/src/compiler/nir_types.cpp b/src/compiler/nir_types.cpp index e167105e728..b84c6086a65 100644 --- a/src/compiler/nir_types.cpp +++ b/src/compiler/nir_types.cpp @@ -335,6 +335,12 @@ glsl_type_is_array_or_matrix(const struct glsl_type *type) return type->is_array() || type->is_matrix(); } +bool +glsl_type_is_cmat(const struct glsl_type *type) +{ + return type->is_cmat(); +} + bool glsl_type_is_struct(const struct glsl_type *type) { @@ -642,6 +648,12 @@ glsl_array_type(const glsl_type *element, unsigned array_size, return glsl_type::get_array_instance(element, array_size, explicit_stride); } +const glsl_type * +glsl_cmat_type(const glsl_cmat_description *desc) +{ + return glsl_type::get_cmat_instance(*desc); +} + const glsl_type * glsl_replace_vector_type(const glsl_type *t, unsigned components) { @@ -857,6 +869,7 @@ glsl_get_natural_size_align_bytes(const struct glsl_type *type, case GLSL_TYPE_ATOMIC_UINT: case GLSL_TYPE_SUBROUTINE: + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: unreachable("type does not 
have a natural size"); @@ -910,6 +923,7 @@ glsl_get_vec4_size_align_bytes(const struct glsl_type *type, case GLSL_TYPE_IMAGE: case GLSL_TYPE_ATOMIC_UINT: case GLSL_TYPE_SUBROUTINE: + case GLSL_TYPE_COOPERATIVE_MATRIX: case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: unreachable("type does not make sense for glsl_get_vec4_size_align_bytes()"); @@ -1102,3 +1116,17 @@ glsl_type_replace_vec3_with_vec4(const struct glsl_type *type) { return type->replace_vec3_with_vec4(); } + +const struct glsl_type * +glsl_get_cmat_element(const struct glsl_type *type) +{ + assert(type->base_type == GLSL_TYPE_COOPERATIVE_MATRIX); + return glsl_type::get_instance(type->cmat_desc.element_type, 1, 1); +} + +const struct glsl_cmat_description * +glsl_get_cmat_description(const struct glsl_type *type) +{ + assert(type->base_type == GLSL_TYPE_COOPERATIVE_MATRIX); + return &type->cmat_desc; +} diff --git a/src/compiler/nir_types.h b/src/compiler/nir_types.h index 22a9ec2bd8a..ff6172a1cb4 100644 --- a/src/compiler/nir_types.h +++ b/src/compiler/nir_types.h @@ -140,6 +140,7 @@ bool glsl_type_is_array(const struct glsl_type *type); bool glsl_type_is_unsized_array(const struct glsl_type *type); bool glsl_type_is_array_of_arrays(const struct glsl_type *type); bool glsl_type_is_array_or_matrix(const struct glsl_type *type); +bool glsl_type_is_cmat(const struct glsl_type *type); bool glsl_type_is_struct(const struct glsl_type *type); bool glsl_type_is_interface(const struct glsl_type *type); bool glsl_type_is_struct_or_ifc(const struct glsl_type *type); @@ -201,6 +202,8 @@ const struct glsl_type *glsl_array_type(const struct glsl_type *element, unsigned array_size, unsigned explicit_stride); +const struct glsl_type *glsl_cmat_type(const struct glsl_cmat_description *desc); + const struct glsl_type *glsl_struct_type(const struct glsl_struct_field *fields, unsigned num_fields, const char *name, bool packed); @@ -254,6 +257,9 @@ int glsl_get_field_index(const struct glsl_type *type, const char *name); bool 
glsl_type_is_leaf(const struct glsl_type *type); +const struct glsl_type *glsl_get_cmat_element(const struct glsl_type *type); /* element type of a cooperative matrix type */ +const struct glsl_cmat_description *glsl_get_cmat_description(const struct glsl_type *type); /* description stored in a cooperative matrix type */ + #ifdef __cplusplus } #endif diff --git a/src/intel/compiler/brw_shader.cpp b/src/intel/compiler/brw_shader.cpp index 88beeb5e29d..423c1976b7c 100644 --- a/src/intel/compiler/brw_shader.cpp +++ b/src/intel/compiler/brw_shader.cpp @@ -74,6 +74,7 @@ brw_type_for_base_type(const struct glsl_type *type) return BRW_REGISTER_TYPE_Q; case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: + case GLSL_TYPE_COOPERATIVE_MATRIX: unreachable("not reached"); } diff --git a/src/intel/compiler/brw_vec4_visitor.cpp b/src/intel/compiler/brw_vec4_visitor.cpp index 9a56f1e4d8c..54866dcb868 100644 --- a/src/intel/compiler/brw_vec4_visitor.cpp +++ b/src/intel/compiler/brw_vec4_visitor.cpp @@ -622,6 +622,7 @@ type_size_xvec4(const struct glsl_type *type, bool as_vec4, bool bindless) return bindless ? 1 : DIV_ROUND_UP(BRW_IMAGE_PARAM_SIZE, 4); case GLSL_TYPE_VOID: case GLSL_TYPE_ERROR: + case GLSL_TYPE_COOPERATIVE_MATRIX: unreachable("not reached"); } diff --git a/src/mesa/main/uniform_query.cpp b/src/mesa/main/uniform_query.cpp index 245c9fa2dd3..d1863697e72 100644 --- a/src/mesa/main/uniform_query.cpp +++ b/src/mesa/main/uniform_query.cpp @@ -1011,6 +1011,7 @@ associate_uniform_storage(struct gl_context *ctx, case GLSL_TYPE_STRUCT: case GLSL_TYPE_ERROR: case GLSL_TYPE_INTERFACE: + case GLSL_TYPE_COOPERATIVE_MATRIX: assert(!"Should not get here."); break; }