spirv, radv, intel: Add NIR intrinsic for cmat conversion

A cooperative matrix conversion operation was represented in NIR by the
cmat_unary_op intrinsic with an nir_alu_op as extra parameter,
that was already lowered to a specific conversion operation
based on the matrix types.

Instead of that, add a new intrinsic `cmat_convert` that is specific
for that conversion.  In addition to the src/dst matrix descriptions
already available, also include the signedness information in the
intrinsic (reuse nir_cmat_signed for that).  This is needed because
different Convert operations define different interpretations for
integers, regardless of their original type.

In this patch, both radv and intel were changed to use the same logic
that was previously used to pick the lowered ALU op.

This change will help represent cmat conversions involving BFloat16,
because it avoids having to create new NIR ALU ops for all the
combinations involving BFloat16.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34511>
This commit is contained in:
Caio Oliveira 2025-04-14 11:53:40 -07:00 committed by Marge Bot
parent 2f02fa5db4
commit d5ad798140
8 changed files with 104 additions and 17 deletions

View file

@ -430,16 +430,26 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
progress = true;
break;
}
case nir_intrinsic_cmat_unary_op: {
case nir_intrinsic_cmat_convert: {
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_deref->type);
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 16 &&
glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
nir_op op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_element_type),
nir_get_nir_type_for_glsl_base_type(dst_element_type),
nir_rounding_mode_undef);
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 16 &&
glsl_base_type_bit_size(dst_element_type) == 32 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
components[i] = nir_channel(&b, src, i * 2);
@ -449,8 +459,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_def *ret = nir_build_alu1(&b, op, src);
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 32 &&
glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 32 &&
glsl_base_type_bit_size(dst_element_type) == 16 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < ret->num_components; ++i) {
components[i * 2] = nir_channel(&b, ret, i);
@ -464,6 +474,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
progress = true;
break;
}
case nir_intrinsic_cmat_unary_op: {
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu1(&b, op, src);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
}
case nir_intrinsic_cmat_scalar_op: {
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);

View file

@ -3923,3 +3923,24 @@ glsl_type_get_image_count(const glsl_type *type)
{
return glsl_type_count(type, GLSL_TYPE_IMAGE);
}
/* Map an integer base type to its signed or unsigned counterpart,
 * selected by \p signedness.  Types that are not 8/16/32/64-bit
 * integers are returned unchanged.
 */
enum glsl_base_type
glsl_apply_signedness_to_base_type(enum glsl_base_type type, bool signedness)
{
   if (type == GLSL_TYPE_UINT || type == GLSL_TYPE_INT)
      return signedness ? GLSL_TYPE_INT : GLSL_TYPE_UINT;

   if (type == GLSL_TYPE_UINT8 || type == GLSL_TYPE_INT8)
      return signedness ? GLSL_TYPE_INT8 : GLSL_TYPE_UINT8;

   if (type == GLSL_TYPE_UINT16 || type == GLSL_TYPE_INT16)
      return signedness ? GLSL_TYPE_INT16 : GLSL_TYPE_UINT16;

   if (type == GLSL_TYPE_UINT64 || type == GLSL_TYPE_INT64)
      return signedness ? GLSL_TYPE_INT64 : GLSL_TYPE_UINT64;

   /* Non-integer (or unsupported-width) types pass through untouched. */
   return type;
}

View file

@ -232,6 +232,12 @@ glsl_signed_base_type_of(enum glsl_base_type type)
}
}
/* Change integer types to be signed or unsigned. Other types remain
* unchanged.
*/
enum glsl_base_type
glsl_apply_signedness_to_base_type(enum glsl_base_type type, bool signedness);
int
glsl_get_sampler_dim_coordinate_components(enum glsl_sampler_dim dim);

View file

@ -1336,6 +1336,7 @@ intrinsic("cmat_load", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
intrinsic("cmat_store", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
intrinsic("cmat_length", src_comp=[], dest_comp=1, indices=[CMAT_DESC], bit_sizes=[32])
intrinsic("cmat_muladd", src_comp=[-1, -1, -1, -1], indices=[SATURATE, CMAT_SIGNED_MASK])
intrinsic("cmat_convert", src_comp=[-1, -1], indices=[CMAT_SIGNED_MASK])
intrinsic("cmat_unary_op", src_comp=[-1, -1], indices=[ALU_OP])
intrinsic("cmat_binary_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
intrinsic("cmat_scalar_op", src_comp=[-1, -1, -1], indices=[ALU_OP])

View file

@ -235,8 +235,8 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
}
}
static nir_alu_type
convert_op_src_type(SpvOp opcode)
nir_alu_type
vtn_convert_op_src_type(SpvOp opcode)
{
switch (opcode) {
case SpvOpFConvert:
@ -256,8 +256,8 @@ convert_op_src_type(SpvOp opcode)
}
}
static nir_alu_type
convert_op_dst_type(SpvOp opcode)
nir_alu_type
vtn_convert_op_dst_type(SpvOp opcode)
{
switch (opcode) {
case SpvOpFConvert:
@ -378,8 +378,8 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
case SpvOpConvertUToF:
case SpvOpSConvert:
case SpvOpFConvert: {
nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size;
nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size;
return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
}
@ -914,8 +914,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
case SpvOpSatConvertUToS: {
unsigned src_bit_size = src[0]->bit_size;
unsigned dst_bit_size = glsl_get_bit_size(dest_type);
nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size;
nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size;
struct conversion_opts opts = {
.rounding_mode = nir_rounding_mode_undef,

View file

@ -198,7 +198,26 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
case SpvOpConvertUToF:
case SpvOpUConvert:
case SpvOpSConvert:
case SpvOpFConvert:
case SpvOpFConvert: {
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
/* The Convert operations define whether integers are interpreted
* as signed or unsigned regardless of their original type. So take
* note of that in the intrinsic. Reuse nir_cmat_signed for that.
*/
const unsigned signed_mask =
(vtn_convert_op_src_type(opcode) == nir_type_int ? NIR_CMAT_A_SIGNED : 0) |
(vtn_convert_op_dst_type(opcode) == nir_type_int ? NIR_CMAT_RESULT_SIGNED : 0);
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_convert");
nir_cmat_convert(&b->nb, &dst->def, &src->def, .cmat_signed_mask = signed_mask);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
case SpvOpFNegate:
case SpvOpSNegate: {
struct vtn_type *dst_type = vtn_get_type(b, w[1]);

View file

@ -954,6 +954,9 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *,
void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value,
vtn_execution_mode_foreach_cb cb, void *data);
nir_alu_type vtn_convert_op_src_type(SpvOp opcode);
nir_alu_type vtn_convert_op_dst_type(SpvOp opcode);
nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
SpvOp opcode, bool *swap, bool *exact,
unsigned src_bit_size, unsigned dst_bit_size);

View file

@ -101,6 +101,7 @@ lower_cmat_filter(const nir_instr *instr, const void *_state)
case nir_intrinsic_cmat_store:
case nir_intrinsic_cmat_length:
case nir_intrinsic_cmat_muladd:
case nir_intrinsic_cmat_convert:
case nir_intrinsic_cmat_unary_op:
case nir_intrinsic_cmat_binary_op:
case nir_intrinsic_cmat_scalar_op:
@ -418,7 +419,22 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
const unsigned src_packing_factor =
get_packing_factor(src_desc, src_slice->type);
const nir_op op = nir_intrinsic_alu_op(intrin);
nir_op op;
if (intrin->intrinsic == nir_intrinsic_cmat_unary_op) {
op = nir_intrinsic_alu_op(intrin);
} else {
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
enum glsl_base_type src_base_type = glsl_apply_signedness_to_base_type(
src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
enum glsl_base_type dst_base_type = glsl_apply_signedness_to_base_type(
dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_base_type),
nir_get_nir_type_for_glsl_base_type(dst_base_type),
nir_rounding_mode_undef);
}
/* With the combinations of formats exposed on all platforms, matrices with
* the same dimensions will always have the same data size. The only real
@ -585,6 +601,7 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}
case nir_intrinsic_cmat_convert:
case nir_intrinsic_cmat_unary_op:
lower_cmat_unary_op(b, intrin, state);
return NIR_LOWER_INSTR_PROGRESS_REPLACE;