diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 77b3f7ef517..0cf1e79b7e3 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -68,6 +68,7 @@ static const struct spirv_capabilities implemented_capabilities = { .ComputeDerivativeGroupLinearKHR = true, .ComputeDerivativeGroupQuadsKHR = true, .CooperativeMatrixKHR = true, + .CooperativeMatrixConversionsNV = true, .CullDistance = true, .DemoteToHelperInvocation = true, .DenormFlushToZero = true, @@ -6862,6 +6863,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpCooperativeMatrixStoreKHR: case SpvOpCooperativeMatrixLengthKHR: case SpvOpCooperativeMatrixMulAddKHR: + case SpvOpCooperativeMatrixConvertNV: + case SpvOpCooperativeMatrixTransposeNV: vtn_handle_cooperative_instruction(b, opcode, w, count); break; diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c index e43557d9f96..f0ab5ba95c2 100644 --- a/src/compiler/spirv/vtn_cmat.c +++ b/src/compiler/spirv/vtn_cmat.c @@ -179,6 +179,26 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpCooperativeMatrixConvertNV: { + struct vtn_type *dst_type = vtn_get_type(b, w[1]); + nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]); + + nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_convert_nv"); + nir_cmat_convert(&b->nb, &dst->def, &src->def); + vtn_push_var_ssa(b, w[2], dst->var); + break; + } + + case SpvOpCooperativeMatrixTransposeNV: { + struct vtn_type *dst_type = vtn_get_type(b, w[1]); + nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]); + + nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_transpose_nv"); + nir_cmat_transpose(&b->nb, &dst->def, &src->def); + vtn_push_var_ssa(b, w[2], dst->var); + break; + } + default: unreachable("Unexpected opcode for cooperative matrix instruction"); }