From 3b86c96d225c03e5f93072e43ca7f3db3930779c Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Tue, 2 Apr 2024 18:49:22 -0500 Subject: [PATCH] spirv: Handle constant cooperative matrices in OpCompositeExtract Fixes: b98f87612bc1 ("spirv: Implement SPV_KHR_cooperative_matrix") Reviewed-by: Konstantin Seurer Part-of: (cherry picked from commit 8fa46b31a89fde179d87f0b714bc882ebfa43b0d) --- .pick_status.json | 2 +- src/compiler/spirv/spirv_to_nir.c | 49 ++++++++++++++++++------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/.pick_status.json b/.pick_status.json index 8675bcec70b..92d56bcf064 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -224,7 +224,7 @@ "description": "spirv: Handle constant cooperative matrices in OpCompositeExtract", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "b98f87612bc14fe88184dc099d9d4f8e6b3b23cb", "notes": null diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index f2c9ac300d9..7c76c915c76 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2278,31 +2278,38 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, int elem = -1; const struct vtn_type *type = comp->type; for (unsigned i = deref_start; i < count; i++) { - vtn_fail_if(w[i] > type->length, - "%uth index of %s is %u but the type has only " - "%u elements", i - deref_start, - spirv_op_to_string(opcode), w[i], type->length); + if (type->base_type == vtn_base_type_cooperative_matrix) { + /* Cooperative matrices are always scalar constants. We don't + * care about the index w[i] because it's always replicated. + */ + type = type->component_type; + } else { + vtn_fail_if(w[i] > type->length, + "%uth index of %s is %u but the type has only " + "%u elements", i - deref_start, + spirv_op_to_string(opcode), w[i], type->length); - switch (type->base_type) { - case vtn_base_type_vector: - elem = w[i]; - type = type->array_element; - break; + switch (type->base_type) { + case vtn_base_type_vector: + elem = w[i]; + type = type->array_element; + break; - case vtn_base_type_matrix: - case vtn_base_type_array: - c = &(*c)->elements[w[i]]; - type = type->array_element; - break; + case vtn_base_type_matrix: + case vtn_base_type_array: + c = &(*c)->elements[w[i]]; + type = type->array_element; + break; - case vtn_base_type_struct: - c = &(*c)->elements[w[i]]; - type = type->members[w[i]]; - break; + case vtn_base_type_struct: + c = &(*c)->elements[w[i]]; + type = type->members[w[i]]; + break; - default: - vtn_fail("%s must only index into composite types", - spirv_op_to_string(opcode)); + default: + vtn_fail("%s must only index into composite types", + spirv_op_to_string(opcode)); + } } }