nir: Add new intrinsics for Cooperative Matrix

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>
This commit is contained in:
Caio Oliveira 2023-08-08 11:02:14 -07:00 committed by Marge Bot
parent 2d0f4f2c17
commit 3105d516d0
3 changed files with 98 additions and 0 deletions

View file

@ -269,6 +269,17 @@ typedef enum {
nir_resource_intel_non_uniform = 1u << 3,
} nir_resource_data_intel;
/**
* Which components to interpret as signed in cmat_muladd.
* See 'Cooperative Matrix Operands' in SPV_KHR_cooperative_matrix.
*/
typedef enum {
NIR_CMAT_A_SIGNED = 1u << 0,
NIR_CMAT_B_SIGNED = 1u << 1,
NIR_CMAT_C_SIGNED = 1u << 2,
NIR_CMAT_RESULT_SIGNED = 1u << 3,
} nir_cmat_signed;
typedef union {
bool b;
float f32;

View file

@ -312,6 +312,12 @@ index("bool", "legacy_fneg")
# On a register store, floating-point saturate the stored value.
index("bool", "legacy_fsat")
# For Cooperative Matrix intrinsics.
index("struct glsl_cmat_description", "cmat_desc")
index("enum glsl_matrix_layout", "matrix_layout")
index("nir_cmat_signed", "cmat_signed_mask")
index("nir_op", "alu_op")
intrinsic("nop", flags=[CAN_ELIMINATE])
intrinsic("convert_alu_types", dest_comp=0, src_comp=[0],
@ -1196,6 +1202,29 @@ system_value("flat_mask", 1)
# Whether provoking vertex mode is last
system_value("provoking_last", 1)
# SPV_KHR_cooperative_matrix.
#
# Cooperative matrices are referred through derefs to variables,
# the destination of the operations appears as the first source,
# ordering follows SPIR-V operation.
#
# Load/Store include an extra source for stride, since that
# can be a _dynamically_ uniform value.
#
# Length takes a type not a value, that's encoded as a MATRIX_DESC.
intrinsic("cmat_construct", src_comp=[-1, 1])
intrinsic("cmat_load", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
intrinsic("cmat_store", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
intrinsic("cmat_length", src_comp=[], dest_comp=1, indices=[CMAT_DESC], bit_sizes=[32])
intrinsic("cmat_muladd", src_comp=[-1, -1, -1, -1], indices=[SATURATE, CMAT_SIGNED_MASK])
intrinsic("cmat_unary_op", src_comp=[-1, -1], indices=[ALU_OP])
intrinsic("cmat_binary_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
intrinsic("cmat_scalar_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
intrinsic("cmat_bitcast", src_comp=[-1, -1])
intrinsic("cmat_extract", src_comp=[-1, 1], dest_comp=1)
intrinsic("cmat_insert", src_comp=[-1, 1, -1, 1])
intrinsic("cmat_copy", src_comp=[-1, -1])
# IR3-specific version of most SSBO intrinsics. The only different
# compare to the originals is that they add an extra source to hold
# the dword-offset, which is needed by the backend code apart from

View file

@ -1523,6 +1523,64 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
break;
}
case NIR_INTRINSIC_MATRIX_LAYOUT: {
fprintf(fp, "matrix_layout=");
switch (nir_intrinsic_matrix_layout(instr)) {
case GLSL_MATRIX_LAYOUT_ROW_MAJOR:
fprintf(fp, "row_major");
break;
case GLSL_MATRIX_LAYOUT_COLUMN_MAJOR:
fprintf(fp, "col_major");
break;
default:
fprintf(fp, "unknown");
break;
}
break;
}
case NIR_INTRINSIC_CMAT_DESC: {
struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(instr);
const struct glsl_type *t = glsl_cmat_type(&desc);
fprintf(fp, "%s", glsl_get_type_name(t));
break;
}
case NIR_INTRINSIC_CMAT_SIGNED_MASK: {
fprintf(fp, "cmat_signed=");
unsigned int mask = nir_intrinsic_cmat_signed_mask(instr);
if (mask == 0)
fputc('0', fp);
while (mask) {
nir_cmat_signed i = 1u << u_bit_scan(&mask);
switch (i) {
case NIR_CMAT_A_SIGNED:
fputc('A', fp);
break;
case NIR_CMAT_B_SIGNED:
fputc('B', fp);
break;
case NIR_CMAT_C_SIGNED:
fputc('C', fp);
break;
case NIR_CMAT_RESULT_SIGNED:
fprintf(fp, "Result");
break;
default:
fprintf(fp, "unknown");
break;
}
fprintf(fp, "%s", mask ? "|" : "");
}
break;
}
case NIR_INTRINSIC_ALU_OP: {
nir_op alu_op = nir_intrinsic_alu_op(instr);
fprintf(fp, "alu_op=%s", nir_op_infos[alu_op].name);
break;
}
default: {
unsigned off = info->index_map[idx] - 1;
fprintf(fp, "%s=%d", nir_intrinsic_index_names[idx], instr->const_index[off]);