spirv: Implement SPV_KHR_cooperative_matrix

Includes a modified version of Ian's approach of using extract/insert
for OpLoad/OpStore.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> (earlier version)
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl> (earlier version)
Acked-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>
Caio Oliveira, 2023-06-16 17:02:39 -07:00, committed by Marge Bot
parent b17a2c35bc
commit b98f87612b
7 changed files with 436 additions and 12 deletions
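At its core, the implementation represents each cooperative matrix value as a NIR local variable of cmat type rather than as an SSA vector, and routes every operation through cmat intrinsics on derefs of those variables. A minimal sketch of that pattern, assuming a vtn_builder `b`, a cmat glsl_type `type`, and a vtn_ssa_value `val` as in the hunks below; the temporary's name is illustrative:

/* Sketch (not part of the diff): cooperative matrices live in cmat-typed
 * locals, and vtn SSA values carry the variable instead of a nir_def.
 * All helper names appear in the hunks below. */
nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_tmp");
vtn_set_ssa_value_var(b, val, mat->var);   /* val now refers to the local */

/* Later consumers recover a deref from the value and emit intrinsics: */
nir_deref_instr *d = vtn_get_deref_for_ssa_value(b, val);
nir_cmat_copy(&b->nb, &mat->def, &d->def);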

src/compiler/shader_info.h

@@ -47,6 +47,7 @@ struct spirv_supported_capabilities {
bool amd_shader_explicit_vertex_parameter;
bool amd_trinary_minmax;
bool atomic_storage;
bool cooperative_matrix;
bool demote_to_helper_invocation;
bool derivative_group;
bool descriptor_array_dynamic_indexing;

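Downstream, a driver opts in by setting this new flag. A minimal enablement sketch, assuming the capability struct is reached through the usual caps member of spirv_to_nir_options; this is hypothetical driver-side code, not part of the commit:

/* Hypothetical sketch: a Vulkan driver exposing VK_KHR_cooperative_matrix
 * would set the new bit before handing the module to spirv_to_nir(). */
const struct spirv_to_nir_options spirv_options = {
   .caps = {
      .cooperative_matrix = true,
   },
};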
src/compiler/spirv/meson.build

@@ -51,6 +51,7 @@ files_libvtn = files(
'vtn_alu.c',
'vtn_amd.c',
'vtn_cfg.c',
'vtn_cmat.c',
'vtn_glsl450.c',
'vtn_opencl.c',
'vtn_private.h',

src/compiler/spirv/spirv_to_nir.c

@@ -266,7 +266,10 @@ vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
val->type = glsl_get_bare_type(type);
-   if (glsl_type_is_vector_or_scalar(type)) {
+   if (glsl_type_is_cmat(type)) {
+      nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_undef");
+      vtn_set_ssa_value_var(b, val, mat->var);
+   } else if (glsl_type_is_vector_or_scalar(type)) {
unsigned num_components = glsl_get_vector_elements(val->type);
unsigned bit_size = glsl_get_bit_size(val->type);
val->def = nir_undef(&b->nb, num_components, bit_size);
@@ -296,7 +299,15 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
val->type = glsl_get_bare_type(type);
-   if (glsl_type_is_vector_or_scalar(type)) {
+   if (glsl_type_is_cmat(type)) {
+      const struct glsl_type *element_type = glsl_get_cmat_element(type);
+      nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_constant");
+      nir_cmat_construct(&b->nb, &mat->def,
+                         nir_build_imm(&b->nb, 1, glsl_get_bit_size(element_type),
+                                       constant->values));
+      vtn_set_ssa_value_var(b, val, mat->var);
+   } else if (glsl_type_is_vector_or_scalar(type)) {
val->def = nir_build_imm(&b->nb, glsl_get_vector_elements(val->type),
glsl_get_bit_size(val->type),
constant->values);
@@ -859,6 +870,7 @@ vtn_types_compatible(struct vtn_builder *b,
case vtn_base_type_sampler:
case vtn_base_type_sampled_image:
case vtn_base_type_event:
case vtn_base_type_cooperative_matrix:
return t1->type == t2->type;
case vtn_base_type_array:
@@ -921,6 +933,7 @@ vtn_type_copy(struct vtn_builder *b, struct vtn_type *src)
case vtn_base_type_event:
case vtn_base_type_accel_struct:
case vtn_base_type_ray_query:
case vtn_base_type_cooperative_matrix:
/* Nothing more to do */
break;
@@ -1951,6 +1964,10 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpTypeCooperativeMatrixKHR:
vtn_handle_cooperative_type(b, val, opcode, w, count);
break;
case SpvOpTypeEvent:
val->type->base_type = vtn_base_type_event;
/*
@@ -2135,9 +2152,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
case SpvOpSpecConstantComposite:
case SpvOpConstantComposite: {
unsigned elem_count = count - 3;
-      vtn_fail_if(elem_count != val->type->length,
+      unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
+         1 : val->type->length;
+      vtn_fail_if(elem_count != expected_length,
                   "%s has %u constituents, expected %u",
-                  spirv_op_to_string(opcode), elem_count, val->type->length);
+                  spirv_op_to_string(opcode), elem_count, expected_length);
nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
val->is_undef_constant = true;
@@ -2173,6 +2192,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
val->constant->elements = elems;
break;
case vtn_base_type_cooperative_matrix:
val->constant->values[0] = elems[0]->values[0];
break;
default:
vtn_fail("Result type of %s must be a composite type",
spirv_op_to_string(opcode));
@@ -2685,7 +2708,7 @@ vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
if (!glsl_type_is_vector_or_scalar(type)) {
unsigned elems = glsl_get_length(val->type);
val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
-      if (glsl_type_is_array_or_matrix(type)) {
+      if (glsl_type_is_array_or_matrix(type) || glsl_type_is_cmat(type)) {
const struct glsl_type *elem_type = glsl_get_array_element(type);
for (unsigned i = 0; i < elems; i++)
val->elems[i] = vtn_create_ssa_value(b, elem_type);
@@ -4216,6 +4239,9 @@ vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
struct vtn_ssa_value *insert, const uint32_t *indices,
unsigned num_indices)
{
if (glsl_type_is_cmat(src->type))
return vtn_cooperative_matrix_insert(b, src, insert, indices, num_indices);
struct vtn_ssa_value *dest = vtn_composite_copy(b, src);
struct vtn_ssa_value *cur = dest;
@@ -4254,6 +4280,9 @@ static struct vtn_ssa_value *
vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
const uint32_t *indices, unsigned num_indices)
{
if (glsl_type_is_cmat(src->type))
return vtn_cooperative_matrix_extract(b, src, indices, num_indices);
struct vtn_ssa_value *cur = src;
for (unsigned i = 0; i < num_indices; i++) {
if (glsl_type_is_vector_or_scalar(cur->type)) {
@@ -4310,7 +4339,12 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
case SpvOpCompositeConstruct: {
unsigned elems = count - 3;
assume(elems >= 1);
-      if (glsl_type_is_vector_or_scalar(type->type)) {
+      if (type->base_type == vtn_base_type_cooperative_matrix) {
+         vtn_assert(elems == 1);
+         nir_deref_instr *mat = vtn_create_cmat_temporary(b, type->type, "cmat_construct");
+         nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
+         vtn_set_ssa_value_var(b, ssa, mat->var);
+      } else if (glsl_type_is_vector_or_scalar(type->type)) {
nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < elems; i++) {
srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
@@ -5022,6 +5056,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
spv_check_supported(shader_enqueue, cap);
break;
case SpvCapabilityCooperativeMatrixKHR:
spv_check_supported(cooperative_matrix, cap);
break;
default:
vtn_fail("Unhandled capability: %s (%u)",
spirv_capability_to_string(cap), cap);
@@ -5656,6 +5694,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpTypePipe:
case SpvOpTypeAccelerationStructureKHR:
case SpvOpTypeRayQueryKHR:
case SpvOpTypeCooperativeMatrixKHR:
vtn_handle_type(b, opcode, w, count);
break;
@@ -6621,6 +6660,13 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpFinishWritingNodePayloadAMDX:
break;
case SpvOpCooperativeMatrixLoadKHR:
case SpvOpCooperativeMatrixStoreKHR:
case SpvOpCooperativeMatrixLengthKHR:
case SpvOpCooperativeMatrixMulAddKHR:
vtn_handle_cooperative_instruction(b, opcode, w, count);
break;
default:
vtn_fail_with_opcode("Unhandled opcode", opcode);
}

src/compiler/spirv/vtn_alu.c

@@ -597,6 +597,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
if (glsl_type_is_cmat(dest_type)) {
vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
return;
}
vtn_handle_no_contraction(b, dest_val);
bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
@@ -1297,6 +1302,11 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
*/
struct vtn_type *type = vtn_get_type(b, w[1]);
if (type->base_type == vtn_base_type_cooperative_matrix) {
vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
return;
}
struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
vtn_fail_if(src->num_components * src->bit_size !=

src/compiler/spirv/vtn_cmat.c (new file)

@@ -0,0 +1,296 @@
/*
* Copyright 2023 Intel Corporation
* SPDX-License-Identifier: MIT
*/
#include "glsl_types.h"
#include "nir.h"
#include "nir_types.h"
#include "vtn_private.h"
static enum glsl_cmat_use
vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
{
switch (use) {
case SpvCooperativeMatrixUseMatrixAKHR:
return GLSL_CMAT_USE_A;
case SpvCooperativeMatrixUseMatrixBKHR:
return GLSL_CMAT_USE_B;
case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
return GLSL_CMAT_USE_ACCUMULATOR;
default:
unreachable("Unexpected cooperative matrix use");
}
}
void
vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
SpvOp opcode, const uint32_t *w, unsigned count)
{
vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
struct vtn_type *component_type = vtn_get_type(b, w[2]);
const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
const uint32_t rows = vtn_constant_uint(b, w[4]);
const uint32_t cols = vtn_constant_uint(b, w[5]);
vtn_assert(rows < 256);
vtn_assert(cols < 256);
enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
val->type->base_type = vtn_base_type_cooperative_matrix;
vtn_fail_if(!glsl_type_is_numeric(component_type->type),
"OpTypeCooperativeMatrixKHR "
"Component Type must be a scalar numerical type.");
val->type->desc.element_type = glsl_get_base_type(component_type->type);
val->type->desc.scope = scope;
val->type->desc.rows = rows;
val->type->desc.cols = cols;
val->type->desc.use = use;
val->type->type = glsl_cmat_type(&val->type->desc);
val->type->component_type = component_type;
}
static enum glsl_matrix_layout
vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
{
switch (layout) {
case SpvCooperativeMatrixLayoutRowMajorKHR:
return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
case SpvCooperativeMatrixLayoutColumnMajorKHR:
return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
default:
unreachable("Unexpected cooperative matrix layout");
}
}
nir_deref_instr *
vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
{
nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
return nir_build_deref_var(&b->nb, var);
}
static nir_deref_instr *
vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
{
nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
vtn_assert(glsl_type_is_cmat(deref->type));
return deref;
}
void
vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
switch (opcode) {
case SpvOpCooperativeMatrixLoadKHR: {
struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
if (count > 6) {
unsigned idx = 6, alignment;
SpvScope scope;
vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
vtn_emit_make_visible_barrier(b, access, scope, src->mode);
}
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
.matrix_layout = vtn_matrix_layout_to_glsl(layout));
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
case SpvOpCooperativeMatrixStoreKHR: {
struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
if (count > 5) {
unsigned idx = 5, alignment;
SpvScope scope;
vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
vtn_emit_make_available_barrier(b, access, scope, dest->mode);
}
nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
.matrix_layout = vtn_matrix_layout_to_glsl(layout));
break;
}
case SpvOpCooperativeMatrixLengthKHR: {
struct vtn_type *type = vtn_get_type(b, w[3]);
nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
vtn_push_nir_ssa(b, w[2], def);
break;
}
case SpvOpCooperativeMatrixMulAddKHR: {
nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
const uint32_t operands = count > 6 ? w[6] : 0;
const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
.saturate = saturate,
.cmat_signed_mask = signed_mask);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
case SpvOpBitcast: {
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
default:
unreachable("Unexpected opcode for cooperative matrix instruction");
}
}
void
vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
const struct glsl_type *dest_type, SpvOp opcode,
const uint32_t *w, unsigned count)
{
vtn_assert(glsl_type_is_cmat(dest_type));
switch (opcode) {
case SpvOpConvertFToU:
case SpvOpConvertFToS:
case SpvOpConvertSToF:
case SpvOpConvertUToF:
case SpvOpUConvert:
case SpvOpSConvert:
case SpvOpFConvert:
case SpvOpFNegate:
case SpvOpSNegate: {
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));
bool ignored = false;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
src_bit_size, dst_bit_size);
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
.alu_op = op);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
case SpvOpFAdd:
case SpvOpFSub:
case SpvOpFMul:
case SpvOpFDiv:
case SpvOpIAdd:
case SpvOpISub:
case SpvOpIMul:
case SpvOpSDiv:
case SpvOpUDiv: {
bool ignored = false;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
.alu_op = op);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
case SpvOpMatrixTimesScalar: {
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);
struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
vtn_assert(glsl_type_is_scalar(scalar_val->type));
nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
.alu_op = op);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
default:
unreachable("invalid cooperative matrix alu instruction");
}
}
struct vtn_ssa_value *
vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
const uint32_t *indices, unsigned num_indices)
{
vtn_assert(glsl_type_is_cmat(mat->type));
nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
vtn_assert(num_indices == 1);
nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
return ret;
}
struct vtn_ssa_value *
vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
struct vtn_ssa_value *insert, const uint32_t *indices,
unsigned num_indices)
{
vtn_assert(glsl_type_is_cmat(mat->type));
nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
vtn_assert(num_indices == 1);
nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
vtn_set_ssa_value_var(b, ret, dst->var);
return ret;
}

src/compiler/spirv/vtn_private.h

@@ -283,6 +283,7 @@ enum vtn_base_type {
vtn_base_type_ray_query,
vtn_base_type_function,
vtn_base_type_event,
vtn_base_type_cooperative_matrix,
};
struct vtn_type {
@@ -391,6 +392,12 @@ struct vtn_type {
/* Return type for functions */
struct vtn_type *return_type;
};
/* Members for cooperative matrix types. */
struct {
struct glsl_cmat_description desc;
struct vtn_type *component_type;
};
};
};
@@ -1048,4 +1055,20 @@ void vtn_emit_make_visible_barrier(struct vtn_builder *b, SpvMemoryAccessMask ac
void vtn_emit_make_available_barrier(struct vtn_builder *b, SpvMemoryAccessMask access,
SpvScope scope, enum vtn_variable_mode mode);
void vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
SpvOp opcode, const uint32_t *w, unsigned count);
void vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count);
void vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
const struct glsl_type *dest_type, SpvOp opcode,
const uint32_t *w, unsigned count);
struct vtn_ssa_value *vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
const uint32_t *indices, unsigned num_indices);
struct vtn_ssa_value *vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
struct vtn_ssa_value *insert,
const uint32_t *indices, unsigned num_indices);
nir_deref_instr *vtn_create_cmat_temporary(struct vtn_builder *b,
const struct glsl_type *t, const char *name);
#endif /* _VTN_PRIVATE_H_ */

src/compiler/spirv/vtn_variables.c

@@ -474,8 +474,15 @@ vtn_pointer_dereference(struct vtn_builder *b,
nir_def *arr_index =
vtn_access_link_as_ssa(b, deref_chain->link[idx], 1,
tail->def.bit_size);
+         if (type->base_type == vtn_base_type_cooperative_matrix) {
+            const struct glsl_type *element_type = glsl_get_cmat_element(type->type);
+            tail = nir_build_deref_cast(&b->nb, &tail->def, tail->modes,
+                                        glsl_array_type(element_type, 0, 0), 0);
+            type = type->component_type;
+         } else {
+            type = type->array_element;
+         }
+
          tail = nir_build_deref_array(&b->nb, tail, arr_index);
-         type = type->array_element;
}
tail->arr.in_bounds = deref_chain->in_bounds;
@@ -510,7 +517,16 @@ _vtn_local_load_store(struct vtn_builder *b, bool load, nir_deref_instr *deref,
struct vtn_ssa_value *inout,
enum gl_access_qualifier access)
{
-   if (glsl_type_is_vector_or_scalar(deref->type)) {
+   if (glsl_type_is_cmat(deref->type)) {
+      if (load) {
+         nir_deref_instr *temp = vtn_create_cmat_temporary(b, deref->type, "cmat_ssa");
+         nir_cmat_copy(&b->nb, &temp->def, &deref->def);
+         vtn_set_ssa_value_var(b, inout, temp->var);
+      } else {
+         nir_deref_instr *src_deref = vtn_get_deref_for_ssa_value(b, inout);
+         nir_cmat_copy(&b->nb, &deref->def, &src_deref->def);
+      }
+   } else if (glsl_type_is_vector_or_scalar(deref->type)) {
if (load) {
inout->def = nir_load_deref_with_access(&b->nb, deref, access);
} else {
@@ -555,7 +571,17 @@ get_deref_tail(nir_deref_instr *deref)
nir_deref_instr *parent =
nir_instr_as_deref(deref->parent.ssa->parent_instr);
-   if (glsl_type_is_vector(parent->type))
+   if (parent->deref_type == nir_deref_type_cast &&
+       parent->parent.ssa->parent_instr->type == nir_instr_type_deref) {
+      nir_deref_instr *grandparent =
+         nir_instr_as_deref(parent->parent.ssa->parent_instr);
+
+      if (glsl_type_is_cmat(grandparent->type))
+         return grandparent;
+   }
+
+   if (glsl_type_is_vector(parent->type) ||
+       glsl_type_is_cmat(parent->type))
       return parent;
    else
       return deref;
@@ -571,7 +597,19 @@ vtn_local_load(struct vtn_builder *b, nir_deref_instr *src,
if (src_tail != src) {
val->type = src->type;
-      val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+      if (glsl_type_is_cmat(src_tail->type)) {
+         assert(val->is_variable);
+         nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+
+         /* Reset is_variable because we are repurposing val. */
+         val->is_variable = false;
+         val->def = nir_cmat_extract(&b->nb,
+                                     glsl_get_bit_size(src->type),
+                                     &mat->def, src->arr.index.ssa);
+      } else {
+         val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+      }
}
return val;
@@ -587,8 +625,16 @@ vtn_local_store(struct vtn_builder *b, struct vtn_ssa_value *src,
struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
_vtn_local_load_store(b, true, dest_tail, val, access);
-      val->def = nir_vector_insert(&b->nb, val->def, src->def,
-                                   dest->arr.index.ssa);
+      if (glsl_type_is_cmat(dest_tail->type)) {
+         nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+         nir_deref_instr *dst = vtn_create_cmat_temporary(b, dest_tail->type, "cmat_insert");
+         nir_cmat_insert(&b->nb, &dst->def, src->def, &mat->def, dest->arr.index.ssa);
+         vtn_set_ssa_value_var(b, val, dst->var);
+      } else {
+         val->def = nir_vector_insert(&b->nb, val->def, src->def,
+                                      dest->arr.index.ssa);
+      }
_vtn_local_load_store(b, false, dest_tail, val, access);
} else {
_vtn_local_load_store(b, false, dest_tail, src, access);
@@ -654,6 +700,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_COOPERATIVE_MATRIX:
if (glsl_type_is_vector_or_scalar(ptr->type->type)) {
/* We hit a vector or scalar; go ahead and emit the load[s] */
nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);