mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-03 16:28:08 +02:00
spirv, radv, intel: Add NIR intrinsic for cmat conversion
A cooperative matrix conversion operation was represented in NIR by the cmat_unary_op intrinsic with an nir_alu_op as extra parameter, that was already lowered to a specific conversion operation based on the matrix types. Instead of that, add a new intrinsic `cmat_convert` that is specific for that conversion. In addition to the src/dst matrix descriptions already available, also include the signedness information in the intrinsic (reuse nir_cmat_signed for that). This is needed because different Convert operations define different interpretations for integers, regardless their original type. In this patch, both radv and intel were changed to use the same logic that was previously used to pick the lowered ALU op. This change will help represent cmat conversions involving BFloat16, because it avoids having to create new NIR ALU ops for all the combinations involving BFloat16. Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34511>
This commit is contained in:
parent
2f02fa5db4
commit
d5ad798140
8 changed files with 104 additions and 17 deletions
|
|
@ -430,16 +430,26 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
progress = true;
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_unary_op: {
|
||||
case nir_intrinsic_cmat_convert: {
|
||||
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_def *src = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 16 &&
|
||||
glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
|
||||
|
||||
enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
|
||||
dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
|
||||
enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
|
||||
src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
|
||||
|
||||
nir_op op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_element_type),
|
||||
nir_get_nir_type_for_glsl_base_type(dst_element_type),
|
||||
nir_rounding_mode_undef);
|
||||
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 16 &&
|
||||
glsl_base_type_bit_size(dst_element_type) == 32 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
|
||||
components[i] = nir_channel(&b, src, i * 2);
|
||||
|
|
@ -449,8 +459,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
|
||||
nir_def *ret = nir_build_alu1(&b, op, src);
|
||||
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 32 &&
|
||||
glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 32 &&
|
||||
glsl_base_type_bit_size(dst_element_type) == 16 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i < ret->num_components; ++i) {
|
||||
components[i * 2] = nir_channel(&b, ret, i);
|
||||
|
|
@ -464,6 +474,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
progress = true;
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_unary_op: {
|
||||
nir_def *src = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
nir_def *ret = nir_build_alu1(&b, op, src);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
|
||||
nir_component_mask(ret->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_scalar_op: {
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
|
|
|||
|
|
@ -3923,3 +3923,24 @@ glsl_type_get_image_count(const glsl_type *type)
|
|||
{
|
||||
return glsl_type_count(type, GLSL_TYPE_IMAGE);
|
||||
}
|
||||
|
||||
enum glsl_base_type
|
||||
glsl_apply_signedness_to_base_type(enum glsl_base_type type, bool signedness)
|
||||
{
|
||||
switch (type) {
|
||||
case GLSL_TYPE_UINT:
|
||||
case GLSL_TYPE_INT:
|
||||
return signedness ? GLSL_TYPE_INT : GLSL_TYPE_UINT;
|
||||
case GLSL_TYPE_UINT8:
|
||||
case GLSL_TYPE_INT8:
|
||||
return signedness ? GLSL_TYPE_INT8 : GLSL_TYPE_UINT8;
|
||||
case GLSL_TYPE_UINT16:
|
||||
case GLSL_TYPE_INT16:
|
||||
return signedness ? GLSL_TYPE_INT16 : GLSL_TYPE_UINT16;
|
||||
case GLSL_TYPE_UINT64:
|
||||
case GLSL_TYPE_INT64:
|
||||
return signedness ? GLSL_TYPE_INT64 : GLSL_TYPE_UINT64;
|
||||
default:
|
||||
return type;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -232,6 +232,12 @@ glsl_signed_base_type_of(enum glsl_base_type type)
|
|||
}
|
||||
}
|
||||
|
||||
/* Change integer types to be signed or unsigned. Other types remain
|
||||
* unchanged.
|
||||
*/
|
||||
enum glsl_base_type
|
||||
glsl_apply_signedness_to_base_type(enum glsl_base_type type, bool signedness);
|
||||
|
||||
int
|
||||
glsl_get_sampler_dim_coordinate_components(enum glsl_sampler_dim dim);
|
||||
|
||||
|
|
|
|||
|
|
@ -1336,6 +1336,7 @@ intrinsic("cmat_load", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
|
|||
intrinsic("cmat_store", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
|
||||
intrinsic("cmat_length", src_comp=[], dest_comp=1, indices=[CMAT_DESC], bit_sizes=[32])
|
||||
intrinsic("cmat_muladd", src_comp=[-1, -1, -1, -1], indices=[SATURATE, CMAT_SIGNED_MASK])
|
||||
intrinsic("cmat_convert", src_comp=[-1, -1], indices=[CMAT_SIGNED_MASK])
|
||||
intrinsic("cmat_unary_op", src_comp=[-1, -1], indices=[ALU_OP])
|
||||
intrinsic("cmat_binary_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
|
||||
intrinsic("cmat_scalar_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
|
||||
|
|
|
|||
|
|
@ -235,8 +235,8 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
}
|
||||
}
|
||||
|
||||
static nir_alu_type
|
||||
convert_op_src_type(SpvOp opcode)
|
||||
nir_alu_type
|
||||
vtn_convert_op_src_type(SpvOp opcode)
|
||||
{
|
||||
switch (opcode) {
|
||||
case SpvOpFConvert:
|
||||
|
|
@ -256,8 +256,8 @@ convert_op_src_type(SpvOp opcode)
|
|||
}
|
||||
}
|
||||
|
||||
static nir_alu_type
|
||||
convert_op_dst_type(SpvOp opcode)
|
||||
nir_alu_type
|
||||
vtn_convert_op_dst_type(SpvOp opcode)
|
||||
{
|
||||
switch (opcode) {
|
||||
case SpvOpFConvert:
|
||||
|
|
@ -378,8 +378,8 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
|
|||
case SpvOpConvertUToF:
|
||||
case SpvOpSConvert:
|
||||
case SpvOpFConvert: {
|
||||
nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
|
||||
nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
|
||||
nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size;
|
||||
nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size;
|
||||
return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
|
||||
}
|
||||
|
||||
|
|
@ -914,8 +914,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpSatConvertUToS: {
|
||||
unsigned src_bit_size = src[0]->bit_size;
|
||||
unsigned dst_bit_size = glsl_get_bit_size(dest_type);
|
||||
nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
|
||||
nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
|
||||
nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size;
|
||||
nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size;
|
||||
|
||||
struct conversion_opts opts = {
|
||||
.rounding_mode = nir_rounding_mode_undef,
|
||||
|
|
|
|||
|
|
@ -198,7 +198,26 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
|
|||
case SpvOpConvertUToF:
|
||||
case SpvOpUConvert:
|
||||
case SpvOpSConvert:
|
||||
case SpvOpFConvert:
|
||||
case SpvOpFConvert: {
|
||||
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
|
||||
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
|
||||
|
||||
/* The Convert operations define whether integers are interpreted
|
||||
* as signed or unsigned regardless of their original type. So take
|
||||
* note of that in the intrinsic. Reuse nir_cmat_signed for that.
|
||||
*/
|
||||
const unsigned signed_mask =
|
||||
(vtn_convert_op_src_type(opcode) == nir_type_int ? NIR_CMAT_A_SIGNED : 0) |
|
||||
(vtn_convert_op_dst_type(opcode) == nir_type_int ? NIR_CMAT_RESULT_SIGNED : 0);
|
||||
|
||||
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_convert");
|
||||
nir_cmat_convert(&b->nb, &dst->def, &src->def, .cmat_signed_mask = signed_mask);
|
||||
vtn_push_var_ssa(b, w[2], dst->var);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpFNegate:
|
||||
case SpvOpSNegate: {
|
||||
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
|
||||
|
|
|
|||
|
|
@ -954,6 +954,9 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *,
|
|||
void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value,
|
||||
vtn_execution_mode_foreach_cb cb, void *data);
|
||||
|
||||
nir_alu_type vtn_convert_op_src_type(SpvOp opcode);
|
||||
nir_alu_type vtn_convert_op_dst_type(SpvOp opcode);
|
||||
|
||||
nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
|
||||
SpvOp opcode, bool *swap, bool *exact,
|
||||
unsigned src_bit_size, unsigned dst_bit_size);
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@ lower_cmat_filter(const nir_instr *instr, const void *_state)
|
|||
case nir_intrinsic_cmat_store:
|
||||
case nir_intrinsic_cmat_length:
|
||||
case nir_intrinsic_cmat_muladd:
|
||||
case nir_intrinsic_cmat_convert:
|
||||
case nir_intrinsic_cmat_unary_op:
|
||||
case nir_intrinsic_cmat_binary_op:
|
||||
case nir_intrinsic_cmat_scalar_op:
|
||||
|
|
@ -418,7 +419,22 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
const unsigned src_packing_factor =
|
||||
get_packing_factor(src_desc, src_slice->type);
|
||||
|
||||
const nir_op op = nir_intrinsic_alu_op(intrin);
|
||||
nir_op op;
|
||||
|
||||
if (intrin->intrinsic == nir_intrinsic_cmat_unary_op) {
|
||||
op = nir_intrinsic_alu_op(intrin);
|
||||
} else {
|
||||
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
|
||||
|
||||
enum glsl_base_type src_base_type = glsl_apply_signedness_to_base_type(
|
||||
src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
|
||||
enum glsl_base_type dst_base_type = glsl_apply_signedness_to_base_type(
|
||||
dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
|
||||
|
||||
op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_base_type),
|
||||
nir_get_nir_type_for_glsl_base_type(dst_base_type),
|
||||
nir_rounding_mode_undef);
|
||||
}
|
||||
|
||||
/* With the combinations of formats exposed on all platforms, matrices with
|
||||
* the same dimensions will always have the same data size. The only real
|
||||
|
|
@ -585,6 +601,7 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
|
|||
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
|
||||
}
|
||||
|
||||
case nir_intrinsic_cmat_convert:
|
||||
case nir_intrinsic_cmat_unary_op:
|
||||
lower_cmat_unary_op(b, intrin, state);
|
||||
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue