mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-21 18:00:13 +01:00
spirv: Implement SPV_KHR_cooperative_matrix
Includes a modified version of using extract/insert for OpLoad/OpStore from Ian. Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> (earlier version) Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl> (earlier version) Acked-by: Faith Ekstrand <faith.ekstrand@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>
This commit is contained in:
parent
b17a2c35bc
commit
b98f87612b
7 changed files with 436 additions and 12 deletions
|
|
@ -47,6 +47,7 @@ struct spirv_supported_capabilities {
|
|||
bool amd_shader_explicit_vertex_parameter;
|
||||
bool amd_trinary_minmax;
|
||||
bool atomic_storage;
|
||||
bool cooperative_matrix;
|
||||
bool demote_to_helper_invocation;
|
||||
bool derivative_group;
|
||||
bool descriptor_array_dynamic_indexing;
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ files_libvtn = files(
|
|||
'vtn_alu.c',
|
||||
'vtn_amd.c',
|
||||
'vtn_cfg.c',
|
||||
'vtn_cmat.c',
|
||||
'vtn_glsl450.c',
|
||||
'vtn_opencl.c',
|
||||
'vtn_private.h',
|
||||
|
|
|
|||
|
|
@ -266,7 +266,10 @@ vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
|
|||
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
|
||||
val->type = glsl_get_bare_type(type);
|
||||
|
||||
if (glsl_type_is_vector_or_scalar(type)) {
|
||||
if (glsl_type_is_cmat(type)) {
|
||||
nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_undef");
|
||||
vtn_set_ssa_value_var(b, val, mat->var);
|
||||
} else if (glsl_type_is_vector_or_scalar(type)) {
|
||||
unsigned num_components = glsl_get_vector_elements(val->type);
|
||||
unsigned bit_size = glsl_get_bit_size(val->type);
|
||||
val->def = nir_undef(&b->nb, num_components, bit_size);
|
||||
|
|
@ -296,7 +299,15 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
|
|||
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
|
||||
val->type = glsl_get_bare_type(type);
|
||||
|
||||
if (glsl_type_is_vector_or_scalar(type)) {
|
||||
if (glsl_type_is_cmat(type)) {
|
||||
const struct glsl_type *element_type = glsl_get_cmat_element(type);
|
||||
|
||||
nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_constant");
|
||||
nir_cmat_construct(&b->nb, &mat->def,
|
||||
nir_build_imm(&b->nb, 1, glsl_get_bit_size(element_type),
|
||||
constant->values));
|
||||
vtn_set_ssa_value_var(b, val, mat->var);
|
||||
} else if (glsl_type_is_vector_or_scalar(type)) {
|
||||
val->def = nir_build_imm(&b->nb, glsl_get_vector_elements(val->type),
|
||||
glsl_get_bit_size(val->type),
|
||||
constant->values);
|
||||
|
|
@ -859,6 +870,7 @@ vtn_types_compatible(struct vtn_builder *b,
|
|||
case vtn_base_type_sampler:
|
||||
case vtn_base_type_sampled_image:
|
||||
case vtn_base_type_event:
|
||||
case vtn_base_type_cooperative_matrix:
|
||||
return t1->type == t2->type;
|
||||
|
||||
case vtn_base_type_array:
|
||||
|
|
@ -921,6 +933,7 @@ vtn_type_copy(struct vtn_builder *b, struct vtn_type *src)
|
|||
case vtn_base_type_event:
|
||||
case vtn_base_type_accel_struct:
|
||||
case vtn_base_type_ray_query:
|
||||
case vtn_base_type_cooperative_matrix:
|
||||
/* Nothing more to do */
|
||||
break;
|
||||
|
||||
|
|
@ -1951,6 +1964,10 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
|
|||
break;
|
||||
}
|
||||
|
||||
case SpvOpTypeCooperativeMatrixKHR:
|
||||
vtn_handle_cooperative_type(b, val, opcode, w, count);
|
||||
break;
|
||||
|
||||
case SpvOpTypeEvent:
|
||||
val->type->base_type = vtn_base_type_event;
|
||||
/*
|
||||
|
|
@ -2135,9 +2152,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpSpecConstantComposite:
|
||||
case SpvOpConstantComposite: {
|
||||
unsigned elem_count = count - 3;
|
||||
vtn_fail_if(elem_count != val->type->length,
|
||||
unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
|
||||
1 : val->type->length;
|
||||
vtn_fail_if(elem_count != expected_length,
|
||||
"%s has %u constituents, expected %u",
|
||||
spirv_op_to_string(opcode), elem_count, val->type->length);
|
||||
spirv_op_to_string(opcode), elem_count, expected_length);
|
||||
|
||||
nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
|
||||
val->is_undef_constant = true;
|
||||
|
|
@ -2173,6 +2192,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
val->constant->elements = elems;
|
||||
break;
|
||||
|
||||
case vtn_base_type_cooperative_matrix:
|
||||
val->constant->values[0] = elems[0]->values[0];
|
||||
break;
|
||||
|
||||
default:
|
||||
vtn_fail("Result type of %s must be a composite type",
|
||||
spirv_op_to_string(opcode));
|
||||
|
|
@ -2685,7 +2708,7 @@ vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
|
|||
if (!glsl_type_is_vector_or_scalar(type)) {
|
||||
unsigned elems = glsl_get_length(val->type);
|
||||
val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
|
||||
if (glsl_type_is_array_or_matrix(type)) {
|
||||
if (glsl_type_is_array_or_matrix(type) || glsl_type_is_cmat(type)) {
|
||||
const struct glsl_type *elem_type = glsl_get_array_element(type);
|
||||
for (unsigned i = 0; i < elems; i++)
|
||||
val->elems[i] = vtn_create_ssa_value(b, elem_type);
|
||||
|
|
@ -4216,6 +4239,9 @@ vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
|
|||
struct vtn_ssa_value *insert, const uint32_t *indices,
|
||||
unsigned num_indices)
|
||||
{
|
||||
if (glsl_type_is_cmat(src->type))
|
||||
return vtn_cooperative_matrix_insert(b, src, insert, indices, num_indices);
|
||||
|
||||
struct vtn_ssa_value *dest = vtn_composite_copy(b, src);
|
||||
|
||||
struct vtn_ssa_value *cur = dest;
|
||||
|
|
@ -4254,6 +4280,9 @@ static struct vtn_ssa_value *
|
|||
vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
|
||||
const uint32_t *indices, unsigned num_indices)
|
||||
{
|
||||
if (glsl_type_is_cmat(src->type))
|
||||
return vtn_cooperative_matrix_extract(b, src, indices, num_indices);
|
||||
|
||||
struct vtn_ssa_value *cur = src;
|
||||
for (unsigned i = 0; i < num_indices; i++) {
|
||||
if (glsl_type_is_vector_or_scalar(cur->type)) {
|
||||
|
|
@ -4310,7 +4339,12 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpCompositeConstruct: {
|
||||
unsigned elems = count - 3;
|
||||
assume(elems >= 1);
|
||||
if (glsl_type_is_vector_or_scalar(type->type)) {
|
||||
if (type->base_type == vtn_base_type_cooperative_matrix) {
|
||||
vtn_assert(elems == 1);
|
||||
nir_deref_instr *mat = vtn_create_cmat_temporary(b, type->type, "cmat_construct");
|
||||
nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
|
||||
vtn_set_ssa_value_var(b, ssa, mat->var);
|
||||
} else if (glsl_type_is_vector_or_scalar(type->type)) {
|
||||
nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i < elems; i++) {
|
||||
srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
|
||||
|
|
@ -5022,6 +5056,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
spv_check_supported(shader_enqueue, cap);
|
||||
break;
|
||||
|
||||
case SpvCapabilityCooperativeMatrixKHR:
|
||||
spv_check_supported(cooperative_matrix, cap);
|
||||
break;
|
||||
|
||||
default:
|
||||
vtn_fail("Unhandled capability: %s (%u)",
|
||||
spirv_capability_to_string(cap), cap);
|
||||
|
|
@ -5656,6 +5694,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpTypePipe:
|
||||
case SpvOpTypeAccelerationStructureKHR:
|
||||
case SpvOpTypeRayQueryKHR:
|
||||
case SpvOpTypeCooperativeMatrixKHR:
|
||||
vtn_handle_type(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
|
|
@ -6621,6 +6660,13 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpFinishWritingNodePayloadAMDX:
|
||||
break;
|
||||
|
||||
case SpvOpCooperativeMatrixLoadKHR:
|
||||
case SpvOpCooperativeMatrixStoreKHR:
|
||||
case SpvOpCooperativeMatrixLengthKHR:
|
||||
case SpvOpCooperativeMatrixMulAddKHR:
|
||||
vtn_handle_cooperative_instruction(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
default:
|
||||
vtn_fail_with_opcode("Unhandled opcode", opcode);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -597,6 +597,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
|
||||
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
|
||||
|
||||
if (glsl_type_is_cmat(dest_type)) {
|
||||
vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
|
||||
return;
|
||||
}
|
||||
|
||||
vtn_handle_no_contraction(b, dest_val);
|
||||
bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
|
||||
|
||||
|
|
@ -1297,6 +1302,11 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
|
|||
*/
|
||||
|
||||
struct vtn_type *type = vtn_get_type(b, w[1]);
|
||||
if (type->base_type == vtn_base_type_cooperative_matrix) {
|
||||
vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
|
||||
return;
|
||||
}
|
||||
|
||||
struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
|
||||
|
||||
vtn_fail_if(src->num_components * src->bit_size !=
|
||||
|
|
|
|||
296
src/compiler/spirv/vtn_cmat.c
Normal file
296
src/compiler/spirv/vtn_cmat.c
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
/*
|
||||
* Copyright 2023 Intel Corporation
|
||||
* SPDX-License-Identifier: MIT
|
||||
*/
|
||||
|
||||
#include "glsl_types.h"
|
||||
#include "nir.h"
|
||||
#include "nir_types.h"
|
||||
#include "vtn_private.h"
|
||||
|
||||
static enum glsl_cmat_use
|
||||
vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
|
||||
{
|
||||
switch (use) {
|
||||
case SpvCooperativeMatrixUseMatrixAKHR:
|
||||
return GLSL_CMAT_USE_A;
|
||||
case SpvCooperativeMatrixUseMatrixBKHR:
|
||||
return GLSL_CMAT_USE_B;
|
||||
case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
|
||||
return GLSL_CMAT_USE_ACCUMULATOR;
|
||||
default:
|
||||
unreachable("Unexpected cooperative matrix use");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
|
||||
SpvOp opcode, const uint32_t *w, unsigned count)
|
||||
{
|
||||
vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
|
||||
|
||||
struct vtn_type *component_type = vtn_get_type(b, w[2]);
|
||||
|
||||
const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
|
||||
const uint32_t rows = vtn_constant_uint(b, w[4]);
|
||||
const uint32_t cols = vtn_constant_uint(b, w[5]);
|
||||
|
||||
vtn_assert(rows < 256);
|
||||
vtn_assert(cols < 256);
|
||||
|
||||
enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
|
||||
|
||||
val->type->base_type = vtn_base_type_cooperative_matrix;
|
||||
vtn_fail_if(!glsl_type_is_numeric(component_type->type),
|
||||
"OpTypeCooperativeMatrixKHR "
|
||||
"Component Type must be a scalar numerical type.");
|
||||
|
||||
val->type->desc.element_type = glsl_get_base_type(component_type->type);
|
||||
val->type->desc.scope = scope;
|
||||
val->type->desc.rows = rows;
|
||||
val->type->desc.cols = cols;
|
||||
val->type->desc.use = use;
|
||||
|
||||
val->type->type = glsl_cmat_type(&val->type->desc);
|
||||
val->type->component_type = component_type;
|
||||
}
|
||||
|
||||
static enum glsl_matrix_layout
|
||||
vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
|
||||
{
|
||||
switch (layout) {
|
||||
case SpvCooperativeMatrixLayoutRowMajorKHR:
|
||||
return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
|
||||
case SpvCooperativeMatrixLayoutColumnMajorKHR:
|
||||
return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
default:
|
||||
unreachable("Unexpected cooperative matrix layout");
|
||||
}
|
||||
}
|
||||
|
||||
nir_deref_instr *
|
||||
vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
|
||||
{
|
||||
nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
|
||||
return nir_build_deref_var(&b->nb, var);
|
||||
}
|
||||
|
||||
static nir_deref_instr *
|
||||
vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
|
||||
{
|
||||
nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
|
||||
vtn_assert(glsl_type_is_cmat(deref->type));
|
||||
return deref;
|
||||
}
|
||||
|
||||
void
|
||||
vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count)
|
||||
{
|
||||
switch (opcode) {
|
||||
case SpvOpCooperativeMatrixLoadKHR: {
|
||||
struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
|
||||
struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
|
||||
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
|
||||
|
||||
const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
|
||||
nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
|
||||
|
||||
SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
|
||||
if (count > 6) {
|
||||
unsigned idx = 6, alignment;
|
||||
SpvScope scope;
|
||||
vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
|
||||
vtn_emit_make_visible_barrier(b, access, scope, src->mode);
|
||||
}
|
||||
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
|
||||
nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
|
||||
.matrix_layout = vtn_matrix_layout_to_glsl(layout));
|
||||
vtn_push_var_ssa(b, w[2], dst->var);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpCooperativeMatrixStoreKHR: {
|
||||
struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
|
||||
struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
|
||||
|
||||
const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
|
||||
nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
|
||||
|
||||
SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
|
||||
if (count > 5) {
|
||||
unsigned idx = 5, alignment;
|
||||
SpvScope scope;
|
||||
vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
|
||||
vtn_emit_make_available_barrier(b, access, scope, dest->mode);
|
||||
}
|
||||
|
||||
nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
|
||||
nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
|
||||
.matrix_layout = vtn_matrix_layout_to_glsl(layout));
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpCooperativeMatrixLengthKHR: {
|
||||
struct vtn_type *type = vtn_get_type(b, w[3]);
|
||||
nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
|
||||
vtn_push_nir_ssa(b, w[2], def);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpCooperativeMatrixMulAddKHR: {
|
||||
nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
|
||||
nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
|
||||
nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
|
||||
|
||||
const uint32_t operands = count > 6 ? w[6] : 0;
|
||||
const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
|
||||
const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
|
||||
SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
|
||||
SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
|
||||
SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
|
||||
|
||||
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
|
||||
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
|
||||
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
|
||||
STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
|
||||
|
||||
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
|
||||
|
||||
nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
|
||||
.saturate = saturate,
|
||||
.cmat_signed_mask = signed_mask);
|
||||
|
||||
vtn_push_var_ssa(b, w[2], dst->var);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpBitcast: {
|
||||
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
|
||||
vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
|
||||
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
|
||||
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
|
||||
nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
|
||||
vtn_push_var_ssa(b, w[2], dst->var);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
unreachable("Unexpected opcode for cooperative matrix instruction");
|
||||
}
|
||||
}
|
||||
|
||||
/* Handle SPIR-V ALU instructions whose result type is a cooperative matrix.
 * Called from vtn_handle_alu before the regular (component-wise) path.
 *
 * Each case builds the matching NIR cmat intrinsic (unary, binary, or
 * matrix-times-scalar) and pushes the result as a variable-backed SSA value.
 */
void
vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
                           const struct glsl_type *dest_type, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   vtn_assert(glsl_type_is_cmat(dest_type));

   switch (opcode) {
   /* Unary ops: conversions and negations, lowered to nir_cmat_unary_op. */
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpUConvert:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpFNegate:
   case SpvOpSNegate: {
      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);

      /* Bit sizes of the *element* types select the right conversion op. */
      unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
      unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));

      /* Swap/exactness outputs of the op lookup are irrelevant here. */
      bool ignored = false;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
                                                  src_bit_size, dst_bit_size);

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
      nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
                        .alu_op = op);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

   /* Element-wise binary ops between two matrices. */
   case SpvOpFAdd:
   case SpvOpFSub:
   case SpvOpFMul:
   case SpvOpFDiv:
   case SpvOpIAdd:
   case SpvOpISub:
   case SpvOpIMul:
   case SpvOpSDiv:
   case SpvOpUDiv: {
      bool ignored = false;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);

      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
      nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
      nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
                         .alu_op = op);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

   /* Scale every element of the matrix (w[3]) by a scalar (w[4]). */
   case SpvOpMatrixTimesScalar: {
      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);

      struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
      vtn_assert(glsl_type_is_scalar(scalar_val->type));
      /* Pick integer vs. float multiply from the scalar's type. */
      nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
      nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
                         .alu_op = op);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

   default:
      unreachable("invalid cooperative matrix alu instruction");
   }
}
|
||||
|
||||
struct vtn_ssa_value *
|
||||
vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
|
||||
const uint32_t *indices, unsigned num_indices)
|
||||
{
|
||||
vtn_assert(glsl_type_is_cmat(mat->type));
|
||||
nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
|
||||
|
||||
vtn_assert(num_indices == 1);
|
||||
nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
|
||||
|
||||
const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
|
||||
struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
|
||||
ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct vtn_ssa_value *
|
||||
vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
|
||||
struct vtn_ssa_value *insert, const uint32_t *indices,
|
||||
unsigned num_indices)
|
||||
{
|
||||
vtn_assert(glsl_type_is_cmat(mat->type));
|
||||
nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
|
||||
|
||||
vtn_assert(num_indices == 1);
|
||||
nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
|
||||
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
|
||||
nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
|
||||
|
||||
struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
|
||||
vtn_set_ssa_value_var(b, ret, dst->var);
|
||||
return ret;
|
||||
}
|
||||
|
|
@ -283,6 +283,7 @@ enum vtn_base_type {
|
|||
vtn_base_type_ray_query,
|
||||
vtn_base_type_function,
|
||||
vtn_base_type_event,
|
||||
vtn_base_type_cooperative_matrix,
|
||||
};
|
||||
|
||||
struct vtn_type {
|
||||
|
|
@ -391,6 +392,12 @@ struct vtn_type {
|
|||
/* Return type for functions */
|
||||
struct vtn_type *return_type;
|
||||
};
|
||||
|
||||
/* Members for cooperative matrix types. */
|
||||
struct {
|
||||
struct glsl_cmat_description desc;
|
||||
struct vtn_type *component_type;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
|
|
@ -1048,4 +1055,20 @@ void vtn_emit_make_visible_barrier(struct vtn_builder *b, SpvMemoryAccessMask ac
|
|||
void vtn_emit_make_available_barrier(struct vtn_builder *b, SpvMemoryAccessMask access,
|
||||
SpvScope scope, enum vtn_variable_mode mode);
|
||||
|
||||
|
||||
void vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
|
||||
SpvOp opcode, const uint32_t *w, unsigned count);
|
||||
void vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count);
|
||||
void vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
|
||||
const struct glsl_type *dest_type, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count);
|
||||
struct vtn_ssa_value *vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
|
||||
const uint32_t *indices, unsigned num_indices);
|
||||
struct vtn_ssa_value *vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
|
||||
struct vtn_ssa_value *insert,
|
||||
const uint32_t *indices, unsigned num_indices);
|
||||
nir_deref_instr *vtn_create_cmat_temporary(struct vtn_builder *b,
|
||||
const struct glsl_type *t, const char *name);
|
||||
|
||||
#endif /* _VTN_PRIVATE_H_ */
|
||||
|
|
|
|||
|
|
@ -474,8 +474,15 @@ vtn_pointer_dereference(struct vtn_builder *b,
|
|||
nir_def *arr_index =
|
||||
vtn_access_link_as_ssa(b, deref_chain->link[idx], 1,
|
||||
tail->def.bit_size);
|
||||
if (type->base_type == vtn_base_type_cooperative_matrix) {
|
||||
const struct glsl_type *element_type = glsl_get_cmat_element(type->type);
|
||||
tail = nir_build_deref_cast(&b->nb, &tail->def, tail->modes,
|
||||
glsl_array_type(element_type, 0, 0), 0);
|
||||
type = type->component_type;
|
||||
} else {
|
||||
type = type->array_element;
|
||||
}
|
||||
tail = nir_build_deref_array(&b->nb, tail, arr_index);
|
||||
type = type->array_element;
|
||||
}
|
||||
tail->arr.in_bounds = deref_chain->in_bounds;
|
||||
|
||||
|
|
@ -510,7 +517,16 @@ _vtn_local_load_store(struct vtn_builder *b, bool load, nir_deref_instr *deref,
|
|||
struct vtn_ssa_value *inout,
|
||||
enum gl_access_qualifier access)
|
||||
{
|
||||
if (glsl_type_is_vector_or_scalar(deref->type)) {
|
||||
if (glsl_type_is_cmat(deref->type)) {
|
||||
if (load) {
|
||||
nir_deref_instr *temp = vtn_create_cmat_temporary(b, deref->type, "cmat_ssa");
|
||||
nir_cmat_copy(&b->nb, &temp->def, &deref->def);
|
||||
vtn_set_ssa_value_var(b, inout, temp->var);
|
||||
} else {
|
||||
nir_deref_instr *src_deref = vtn_get_deref_for_ssa_value(b, inout);
|
||||
nir_cmat_copy(&b->nb, &deref->def, &src_deref->def);
|
||||
}
|
||||
} else if (glsl_type_is_vector_or_scalar(deref->type)) {
|
||||
if (load) {
|
||||
inout->def = nir_load_deref_with_access(&b->nb, deref, access);
|
||||
} else {
|
||||
|
|
@ -555,7 +571,17 @@ get_deref_tail(nir_deref_instr *deref)
|
|||
nir_deref_instr *parent =
|
||||
nir_instr_as_deref(deref->parent.ssa->parent_instr);
|
||||
|
||||
if (glsl_type_is_vector(parent->type))
|
||||
if (parent->deref_type == nir_deref_type_cast &&
|
||||
parent->parent.ssa->parent_instr->type == nir_instr_type_deref) {
|
||||
nir_deref_instr *grandparent =
|
||||
nir_instr_as_deref(parent->parent.ssa->parent_instr);
|
||||
|
||||
if (glsl_type_is_cmat(grandparent->type))
|
||||
return grandparent;
|
||||
}
|
||||
|
||||
if (glsl_type_is_vector(parent->type) ||
|
||||
glsl_type_is_cmat(parent->type))
|
||||
return parent;
|
||||
else
|
||||
return deref;
|
||||
|
|
@ -571,7 +597,19 @@ vtn_local_load(struct vtn_builder *b, nir_deref_instr *src,
|
|||
|
||||
if (src_tail != src) {
|
||||
val->type = src->type;
|
||||
val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
|
||||
|
||||
if (glsl_type_is_cmat(src_tail->type)) {
|
||||
assert(val->is_variable);
|
||||
nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
|
||||
|
||||
/* Reset is_variable because we are repurposing val. */
|
||||
val->is_variable = false;
|
||||
val->def = nir_cmat_extract(&b->nb,
|
||||
glsl_get_bit_size(src->type),
|
||||
&mat->def, src->arr.index.ssa);
|
||||
} else {
|
||||
val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
|
||||
}
|
||||
}
|
||||
|
||||
return val;
|
||||
|
|
@ -587,8 +625,16 @@ vtn_local_store(struct vtn_builder *b, struct vtn_ssa_value *src,
|
|||
struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
|
||||
_vtn_local_load_store(b, true, dest_tail, val, access);
|
||||
|
||||
val->def = nir_vector_insert(&b->nb, val->def, src->def,
|
||||
dest->arr.index.ssa);
|
||||
if (glsl_type_is_cmat(dest_tail->type)) {
|
||||
nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dest_tail->type, "cmat_insert");
|
||||
nir_cmat_insert(&b->nb, &dst->def, src->def, &mat->def, dest->arr.index.ssa);
|
||||
vtn_set_ssa_value_var(b, val, dst->var);
|
||||
} else {
|
||||
val->def = nir_vector_insert(&b->nb, val->def, src->def,
|
||||
dest->arr.index.ssa);
|
||||
}
|
||||
|
||||
_vtn_local_load_store(b, false, dest_tail, val, access);
|
||||
} else {
|
||||
_vtn_local_load_store(b, false, dest_tail, src, access);
|
||||
|
|
@ -654,6 +700,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
|
|||
case GLSL_TYPE_FLOAT16:
|
||||
case GLSL_TYPE_BOOL:
|
||||
case GLSL_TYPE_DOUBLE:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
if (glsl_type_is_vector_or_scalar(ptr->type->type)) {
|
||||
/* We hit a vector or scalar; go ahead and emit the load[s] */
|
||||
nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue