diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h
index 785473a85a2..51000b2d8fc 100644
--- a/src/compiler/shader_info.h
+++ b/src/compiler/shader_info.h
@@ -47,6 +47,7 @@ struct spirv_supported_capabilities {
    bool amd_shader_explicit_vertex_parameter;
    bool amd_trinary_minmax;
    bool atomic_storage;
+   bool cooperative_matrix;
    bool demote_to_helper_invocation;
    bool derivative_group;
    bool descriptor_array_dynamic_indexing;
diff --git a/src/compiler/spirv/meson.build b/src/compiler/spirv/meson.build
index 06dc9f7979b..dfb53d6738c 100644
--- a/src/compiler/spirv/meson.build
+++ b/src/compiler/spirv/meson.build
@@ -51,6 +51,7 @@ files_libvtn = files(
   'vtn_alu.c',
   'vtn_amd.c',
   'vtn_cfg.c',
+  'vtn_cmat.c',
   'vtn_glsl450.c',
   'vtn_opencl.c',
   'vtn_private.h',
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index c876f85091f..5b20c5cf985 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -266,7 +266,10 @@ vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
    struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = glsl_get_bare_type(type);
 
-   if (glsl_type_is_vector_or_scalar(type)) {
+   if (glsl_type_is_cmat(type)) {
+      nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_undef");
+      vtn_set_ssa_value_var(b, val, mat->var);
+   } else if (glsl_type_is_vector_or_scalar(type)) {
       unsigned num_components = glsl_get_vector_elements(val->type);
       unsigned bit_size = glsl_get_bit_size(val->type);
       val->def = nir_undef(&b->nb, num_components, bit_size);
@@ -296,7 +299,15 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
    struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = glsl_get_bare_type(type);
 
-   if (glsl_type_is_vector_or_scalar(type)) {
+   if (glsl_type_is_cmat(type)) {
+      const struct glsl_type *element_type = glsl_get_cmat_element(type);
+
+      nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_constant");
+      nir_cmat_construct(&b->nb, &mat->def,
+                         nir_build_imm(&b->nb, 1, glsl_get_bit_size(element_type),
+                                       constant->values));
+      vtn_set_ssa_value_var(b, val, mat->var);
+   } else if (glsl_type_is_vector_or_scalar(type)) {
       val->def = nir_build_imm(&b->nb, glsl_get_vector_elements(val->type),
                                glsl_get_bit_size(val->type),
                                constant->values);
@@ -859,6 +870,7 @@ vtn_types_compatible(struct vtn_builder *b,
    case vtn_base_type_sampler:
    case vtn_base_type_sampled_image:
    case vtn_base_type_event:
+   case vtn_base_type_cooperative_matrix:
       return t1->type == t2->type;
 
    case vtn_base_type_array:
@@ -921,6 +933,7 @@ vtn_type_copy(struct vtn_builder *b, struct vtn_type *src)
    case vtn_base_type_event:
    case vtn_base_type_accel_struct:
    case vtn_base_type_ray_query:
+   case vtn_base_type_cooperative_matrix:
       /* Nothing more to do */
       break;
 
@@ -1951,6 +1964,10 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   case SpvOpTypeCooperativeMatrixKHR:
+      vtn_handle_cooperative_type(b, val, opcode, w, count);
+      break;
+
    case SpvOpTypeEvent:
       val->type->base_type = vtn_base_type_event;
       /*
@@ -2135,9 +2152,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
    case SpvOpSpecConstantComposite:
    case SpvOpConstantComposite: {
       unsigned elem_count = count - 3;
-      vtn_fail_if(elem_count != val->type->length,
+      unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
+                                 1 : val->type->length;
+      vtn_fail_if(elem_count != expected_length,
                   "%s has %u constituents, expected %u",
-                  spirv_op_to_string(opcode), elem_count, val->type->length);
+                  spirv_op_to_string(opcode), elem_count, expected_length);
 
       nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
       val->is_undef_constant = true;
@@ -2173,6 +2192,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
          val->constant->elements = elems;
         break;
 
+      case vtn_base_type_cooperative_matrix:
+         val->constant->values[0] = elems[0]->values[0];
+         break;
+
       default:
          vtn_fail("Result type of %s must be a composite type",
                   spirv_op_to_string(opcode));
@@ -2685,7 +2708,7 @@ vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
    if (!glsl_type_is_vector_or_scalar(type)) {
       unsigned elems = glsl_get_length(val->type);
       val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
-      if (glsl_type_is_array_or_matrix(type)) {
+      if (glsl_type_is_array_or_matrix(type) || glsl_type_is_cmat(type)) {
          const struct glsl_type *elem_type = glsl_get_array_element(type);
         for (unsigned i = 0; i < elems; i++)
            val->elems[i] = vtn_create_ssa_value(b, elem_type);
@@ -4216,6 +4239,9 @@ vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
                      struct vtn_ssa_value *insert, const uint32_t *indices,
                      unsigned num_indices)
 {
+   if (glsl_type_is_cmat(src->type))
+      return vtn_cooperative_matrix_insert(b, src, insert, indices, num_indices);
+
    struct vtn_ssa_value *dest = vtn_composite_copy(b, src);
 
    struct vtn_ssa_value *cur = dest;
@@ -4254,6 +4280,9 @@ static struct vtn_ssa_value *
 vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
                       const uint32_t *indices, unsigned num_indices)
 {
+   if (glsl_type_is_cmat(src->type))
+      return vtn_cooperative_matrix_extract(b, src, indices, num_indices);
+
    struct vtn_ssa_value *cur = src;
    for (unsigned i = 0; i < num_indices; i++) {
       if (glsl_type_is_vector_or_scalar(cur->type)) {
@@ -4310,7 +4339,12 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
    case SpvOpCompositeConstruct: {
       unsigned elems = count - 3;
       assume(elems >= 1);
-      if (glsl_type_is_vector_or_scalar(type->type)) {
+      if (type->base_type == vtn_base_type_cooperative_matrix) {
+         vtn_assert(elems == 1);
+         nir_deref_instr *mat = vtn_create_cmat_temporary(b, type->type, "cmat_construct");
+         nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
+         vtn_set_ssa_value_var(b, ssa, mat->var);
+      } else if (glsl_type_is_vector_or_scalar(type->type)) {
          nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
         for (unsigned i = 0; i < elems; i++) {
            srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
@@ -5022,6 +5056,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          spv_check_supported(shader_enqueue, cap);
          break;
 
+      case SpvCapabilityCooperativeMatrixKHR:
+         spv_check_supported(cooperative_matrix, cap);
+         break;
+
       default:
          vtn_fail("Unhandled capability: %s (%u)",
                   spirv_capability_to_string(cap), cap);
@@ -5656,6 +5694,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpTypePipe:
    case SpvOpTypeAccelerationStructureKHR:
    case SpvOpTypeRayQueryKHR:
+   case SpvOpTypeCooperativeMatrixKHR:
       vtn_handle_type(b, opcode, w, count);
       break;
 
@@ -6621,6 +6660,13 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFinishWritingNodePayloadAMDX:
       break;
 
+   case SpvOpCooperativeMatrixLoadKHR:
+   case SpvOpCooperativeMatrixStoreKHR:
+   case SpvOpCooperativeMatrixLengthKHR:
+   case SpvOpCooperativeMatrixMulAddKHR:
+      vtn_handle_cooperative_instruction(b, opcode, w, count);
+      break;
+
    default:
       vtn_fail_with_opcode("Unhandled opcode", opcode);
    }
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 4cba6048123..04d71ae6eb2 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -597,6 +597,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
+   if (glsl_type_is_cmat(dest_type)) {
+      vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
+      return;
+   }
+
    vtn_handle_no_contraction(b, dest_val);
 
    bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
@@ -1297,6 +1302,11 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
     */
    struct vtn_type *type = vtn_get_type(b, w[1]);
 
+   if (type->base_type == vtn_base_type_cooperative_matrix) {
+      vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
+      return;
+   }
+
    struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
 
    vtn_fail_if(src->num_components * src->bit_size !=
diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c
new file mode 100644
index 00000000000..2e2898429bc
--- /dev/null
+++ b/src/compiler/spirv/vtn_cmat.c
@@ -0,0 +1,296 @@
+/*
+ * Copyright 2023 Intel Corporation
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "glsl_types.h"
+#include "nir.h"
+#include "nir_types.h"
+#include "vtn_private.h"
+
+static enum glsl_cmat_use
+vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
+{
+   switch (use) {
+   case SpvCooperativeMatrixUseMatrixAKHR:
+      return GLSL_CMAT_USE_A;
+   case SpvCooperativeMatrixUseMatrixBKHR:
+      return GLSL_CMAT_USE_B;
+   case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
+      return GLSL_CMAT_USE_ACCUMULATOR;
+   default:
+      unreachable("Unexpected cooperative matrix use");
+   }
+}
+
+void
+vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
+                            SpvOp opcode, const uint32_t *w, unsigned count)
+{
+   vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
+
+   struct vtn_type *component_type = vtn_get_type(b, w[2]);
+
+   const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
+   const uint32_t rows = vtn_constant_uint(b, w[4]);
+   const uint32_t cols = vtn_constant_uint(b, w[5]);
+
+   vtn_assert(rows < 256);
+   vtn_assert(cols < 256);
+
+   enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
+
+   val->type->base_type = vtn_base_type_cooperative_matrix;
+   vtn_fail_if(!glsl_type_is_numeric(component_type->type),
+               "OpTypeCooperativeMatrixKHR "
+               "Component Type must be a scalar numerical type.");
+
+   val->type->desc.element_type = glsl_get_base_type(component_type->type);
+   val->type->desc.scope = scope;
+   val->type->desc.rows = rows;
+   val->type->desc.cols = cols;
+   val->type->desc.use = use;
+
+   val->type->type = glsl_cmat_type(&val->type->desc);
+   val->type->component_type = component_type;
+}
+
+static enum glsl_matrix_layout
+vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
+{
+   switch (layout) {
+   case SpvCooperativeMatrixLayoutRowMajorKHR:
+      return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
+   case SpvCooperativeMatrixLayoutColumnMajorKHR:
+      return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
+   default:
+      unreachable("Unexpected cooperative matrix layout");
+   }
+}
+
+nir_deref_instr *
+vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
+{
+   nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
+   return nir_build_deref_var(&b->nb, var);
+}
+
+static nir_deref_instr *
+vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
+{
+   nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
+   vtn_assert(glsl_type_is_cmat(deref->type));
+   return deref;
+}
+
+void
+vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
+                                   const uint32_t *w, unsigned count)
+{
+   switch (opcode) {
+   case SpvOpCooperativeMatrixLoadKHR: {
+      struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
+      struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+
+      const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
+      nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
+
+      SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
+      if (count > 6) {
+         unsigned idx = 6, alignment;
+         SpvScope scope;
+         vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
+         vtn_emit_make_visible_barrier(b, access, scope, src->mode);
+      }
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
+      nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
+                    .matrix_layout = vtn_matrix_layout_to_glsl(layout));
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   case SpvOpCooperativeMatrixStoreKHR: {
+      struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
+      struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
+
+      const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
+      nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
+
+      SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
+      if (count > 5) {
+         unsigned idx = 5, alignment;
+         SpvScope scope;
+         vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
+         vtn_emit_make_available_barrier(b, access, scope, dest->mode);
+      }
+
+      nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
+      nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
+                     .matrix_layout = vtn_matrix_layout_to_glsl(layout));
+      break;
+   }
+
+   case SpvOpCooperativeMatrixLengthKHR: {
+      struct vtn_type *type = vtn_get_type(b, w[3]);
+      nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
+      vtn_push_nir_ssa(b, w[2], def);
+      break;
+   }
+
+   case SpvOpCooperativeMatrixMulAddKHR: {
+      nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
+      nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
+      nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
+
+      const uint32_t operands = count > 6 ? w[6] : 0;
+      const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
+      const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
+                                               SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
+                                               SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
+                                               SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
+
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
+
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
+
+      nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
+                      .saturate = saturate,
+                      .cmat_signed_mask = signed_mask);
+
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   case SpvOpBitcast: {
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
+      nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
+      nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   default:
+      unreachable("Unexpected opcode for cooperative matrix instruction");
+   }
+}
+
+void
+vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
+                           const struct glsl_type *dest_type, SpvOp opcode,
+                           const uint32_t *w, unsigned count)
+{
+   vtn_assert(glsl_type_is_cmat(dest_type));
+
+   switch (opcode) {
+   case SpvOpConvertFToU:
+   case SpvOpConvertFToS:
+   case SpvOpConvertSToF:
+   case SpvOpConvertUToF:
+   case SpvOpUConvert:
+   case SpvOpSConvert:
+   case SpvOpFConvert:
+   case SpvOpFNegate:
+   case SpvOpSNegate: {
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
+
+      unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
+      unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));
+
+      bool ignored = false;
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
+                                                  src_bit_size, dst_bit_size);
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
+      nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
+                        .alu_op = op);
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   case SpvOpFAdd:
+   case SpvOpFSub:
+   case SpvOpFMul:
+   case SpvOpFDiv:
+   case SpvOpIAdd:
+   case SpvOpISub:
+   case SpvOpIMul:
+   case SpvOpSDiv:
+   case SpvOpUDiv: {
+      bool ignored = false;
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);
+
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
+      nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
+      nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
+                         .alu_op = op);
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   case SpvOpMatrixTimesScalar: {
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);
+
+      struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
+      vtn_assert(glsl_type_is_scalar(scalar_val->type));
+      nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
+      nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
+                         .alu_op = op);
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   default:
+      unreachable("invalid cooperative matrix alu instruction");
+   }
+}
+
+struct vtn_ssa_value *
+vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                               const uint32_t *indices, unsigned num_indices)
+{
+   vtn_assert(glsl_type_is_cmat(mat->type));
+   nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
+
+   vtn_assert(num_indices == 1);
+   nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
+
+   const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
+   struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
+   ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
+   return ret;
+}
+
+struct vtn_ssa_value *
+vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                              struct vtn_ssa_value *insert, const uint32_t *indices,
+                              unsigned num_indices)
+{
+   vtn_assert(glsl_type_is_cmat(mat->type));
+   nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
+
+   vtn_assert(num_indices == 1);
+   nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
+
+   nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
+   nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
+
+   struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
+   vtn_set_ssa_value_var(b, ret, dst->var);
+   return ret;
+}
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index 66c5cdbbea2..02fe2f26ae8 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -283,6 +283,7 @@ enum vtn_base_type {
    vtn_base_type_ray_query,
    vtn_base_type_function,
    vtn_base_type_event,
+   vtn_base_type_cooperative_matrix,
 };
 
 struct vtn_type {
@@ -391,6 +392,12 @@ struct vtn_type {
 
          /* Return type for functions */
          struct vtn_type *return_type;
       };
+
+      /* Members for cooperative matrix types. */
+      struct {
+         struct glsl_cmat_description desc;
+         struct vtn_type *component_type;
+      };
    };
 };
@@ -1048,4 +1055,20 @@ void vtn_emit_make_visible_barrier(struct vtn_builder *b, SpvMemoryAccessMask ac
 void vtn_emit_make_available_barrier(struct vtn_builder *b, SpvMemoryAccessMask access,
                                      SpvScope scope, enum vtn_variable_mode mode);
 
+
+void vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
+                                 SpvOp opcode, const uint32_t *w, unsigned count);
+void vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
+                                        const uint32_t *w, unsigned count);
+void vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
+                                const struct glsl_type *dest_type, SpvOp opcode,
+                                const uint32_t *w, unsigned count);
+struct vtn_ssa_value *vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                                                     const uint32_t *indices, unsigned num_indices);
+struct vtn_ssa_value *vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                                                    struct vtn_ssa_value *insert,
+                                                    const uint32_t *indices, unsigned num_indices);
+nir_deref_instr *vtn_create_cmat_temporary(struct vtn_builder *b,
+                                           const struct glsl_type *t, const char *name);
+
 #endif /* _VTN_PRIVATE_H_ */
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 9e14d695e12..d00e3d785fd 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -474,8 +474,15 @@ vtn_pointer_dereference(struct vtn_builder *b,
          nir_def *arr_index =
             vtn_access_link_as_ssa(b, deref_chain->link[idx], 1,
                                    tail->def.bit_size);
+
+         if (type->base_type == vtn_base_type_cooperative_matrix) {
+            const struct glsl_type *element_type = glsl_get_cmat_element(type->type);
+            tail = nir_build_deref_cast(&b->nb, &tail->def, tail->modes,
+                                        glsl_array_type(element_type, 0, 0), 0);
+            type = type->component_type;
+         } else {
+            type = type->array_element;
+         }
          tail = nir_build_deref_array(&b->nb, tail, arr_index);
-         type = type->array_element;
       }
 
       tail->arr.in_bounds = deref_chain->in_bounds;
@@ -510,7 +517,16 @@ _vtn_local_load_store(struct vtn_builder *b, bool load, nir_deref_instr *deref,
                       struct vtn_ssa_value *inout,
                       enum gl_access_qualifier access)
 {
-   if (glsl_type_is_vector_or_scalar(deref->type)) {
+   if (glsl_type_is_cmat(deref->type)) {
+      if (load) {
+         nir_deref_instr *temp = vtn_create_cmat_temporary(b, deref->type, "cmat_ssa");
+         nir_cmat_copy(&b->nb, &temp->def, &deref->def);
+         vtn_set_ssa_value_var(b, inout, temp->var);
+      } else {
+         nir_deref_instr *src_deref = vtn_get_deref_for_ssa_value(b, inout);
+         nir_cmat_copy(&b->nb, &deref->def, &src_deref->def);
+      }
+   } else if (glsl_type_is_vector_or_scalar(deref->type)) {
       if (load) {
          inout->def = nir_load_deref_with_access(&b->nb, deref, access);
       } else {
@@ -555,7 +571,17 @@ get_deref_tail(nir_deref_instr *deref)
    nir_deref_instr *parent =
       nir_instr_as_deref(deref->parent.ssa->parent_instr);
 
-   if (glsl_type_is_vector(parent->type))
+   if (parent->deref_type == nir_deref_type_cast &&
+       parent->parent.ssa->parent_instr->type == nir_instr_type_deref) {
+      nir_deref_instr *grandparent =
+         nir_instr_as_deref(parent->parent.ssa->parent_instr);
+
+      if (glsl_type_is_cmat(grandparent->type))
+         return grandparent;
+   }
+
+   if (glsl_type_is_vector(parent->type) ||
+       glsl_type_is_cmat(parent->type))
       return parent;
    else
       return deref;
@@ -571,7 +597,19 @@ vtn_local_load(struct vtn_builder *b, nir_deref_instr *src,
 
    if (src_tail != src) {
       val->type = src->type;
-      val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+
+      if (glsl_type_is_cmat(src_tail->type)) {
+         assert(val->is_variable);
+         nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+
+         /* Reset is_variable because we are repurposing val. */
+         val->is_variable = false;
+         val->def = nir_cmat_extract(&b->nb,
+                                     glsl_get_bit_size(src->type),
+                                     &mat->def, src->arr.index.ssa);
+      } else {
+         val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+      }
    }
 
    return val;
@@ -587,8 +625,16 @@ vtn_local_store(struct vtn_builder *b, struct vtn_ssa_value *src,
       struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
       _vtn_local_load_store(b, true, dest_tail, val, access);
 
-      val->def = nir_vector_insert(&b->nb, val->def, src->def,
-                                   dest->arr.index.ssa);
+      if (glsl_type_is_cmat(dest_tail->type)) {
+         nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+         nir_deref_instr *dst = vtn_create_cmat_temporary(b, dest_tail->type, "cmat_insert");
+         nir_cmat_insert(&b->nb, &dst->def, src->def, &mat->def, dest->arr.index.ssa);
+         vtn_set_ssa_value_var(b, val, dst->var);
+      } else {
+         val->def = nir_vector_insert(&b->nb, val->def, src->def,
+                                      dest->arr.index.ssa);
+      }
+
       _vtn_local_load_store(b, false, dest_tail, val, access);
    } else {
       _vtn_local_load_store(b, false, dest_tail, src, access);
@@ -654,6 +700,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
    case GLSL_TYPE_FLOAT16:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_DOUBLE:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
       if (glsl_type_is_vector_or_scalar(ptr->type->type)) {
          /* We hit a vector or scalar; go ahead and emit the load[s] */
          nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);