brw, nir: Use glsl_base_type instead of nir_alu_type for @dpas_intel

This will allow including types that don't have a nir_alu_type
equivalent, like bfloat16.

Reviewed-by: Rohan Garg <rohan.garg@intel.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-03-14 13:29:59 -07:00 committed by Marge Bot
parent cf4021f93c
commit a38960e8f3
5 changed files with 40 additions and 15 deletions

View file

@ -223,6 +223,11 @@ index("nir_alu_type", "src_type")
# The nir_alu_type of the data output from a load or conversion
index("nir_alu_type", "dest_type")
# Source and destination data types for dpas_intel. Needed here to
# represent types that won't have a nir_alu_type.
index("enum glsl_base_type", "src_base_type")
index("enum glsl_base_type", "dest_base_type")
# The swizzle mask for quad_swizzle_amd & masked_swizzle_amd
index("unsigned", "swizzle_mask")
@ -2421,7 +2426,7 @@ system_value("ray_query_global_intel", 1, bit_sizes=[64])
# its value. Some supported configurations will have the component count of
# that matrix different than the others.
intrinsic("dpas_intel", dest_comp=0, src_comp=[0, -1, 0],
indices=[DEST_TYPE, SRC_TYPE, SATURATE, SYSTOLIC_DEPTH, REPEAT_COUNT],
indices=[DEST_BASE_TYPE, SRC_BASE_TYPE, SATURATE, SYSTOLIC_DEPTH, REPEAT_COUNT],
flags=[CAN_ELIMINATE])
# NVIDIA-specific intrinsics

View file

@ -4693,9 +4693,9 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
const unsigned rcount = nir_intrinsic_repeat_count(instr);
const brw_reg_type dest_type =
brw_type_for_nir_type(devinfo, nir_intrinsic_dest_type(instr));
brw_type_for_base_type(nir_intrinsic_dest_base_type(instr));
const brw_reg_type src_type =
brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr));
brw_type_for_base_type(nir_intrinsic_src_base_type(instr));
brw_reg src[3] = {};
for (unsigned i = 0; i < ARRAY_SIZE(src); i++) {

View file

@ -2253,6 +2253,28 @@ lsc_op_for_nir_intrinsic(const nir_intrinsic_instr *intrin)
}
}
/**
 * Translate a GLSL base type into the corresponding BRW register type.
 *
 * Only the integer and floating-point base types listed in the switch have
 * a mapping; every other base type reaches unreachable().
 */
enum brw_reg_type
brw_type_for_base_type(enum glsl_base_type base_type)
{
   switch (base_type) {
   /* Floating-point types. */
   case GLSL_TYPE_BFLOAT16:
      return BRW_TYPE_BF;
   case GLSL_TYPE_FLOAT16:
      return BRW_TYPE_HF;
   case GLSL_TYPE_FLOAT:
      return BRW_TYPE_F;
   case GLSL_TYPE_DOUBLE:
      return BRW_TYPE_DF;

   /* Unsigned integer types. */
   case GLSL_TYPE_UINT8:
      return BRW_TYPE_UB;
   case GLSL_TYPE_UINT16:
      return BRW_TYPE_UW;
   case GLSL_TYPE_UINT:
      return BRW_TYPE_UD;
   case GLSL_TYPE_UINT64:
      return BRW_TYPE_UQ;

   /* Signed integer types. */
   case GLSL_TYPE_INT8:
      return BRW_TYPE_B;
   case GLSL_TYPE_INT16:
      return BRW_TYPE_W;
   case GLSL_TYPE_INT:
      return BRW_TYPE_D;
   case GLSL_TYPE_INT64:
      return BRW_TYPE_Q;

   default:
      unreachable("invalid base type");
   }
}
enum brw_reg_type
brw_type_for_nir_type(const struct intel_device_info *devinfo,
nir_alu_type type)

View file

@ -242,6 +242,7 @@ unsigned brw_nir_api_subgroup_size(const nir_shader *nir,
enum brw_conditional_mod brw_cmod_for_nir_comparison(nir_op op);
enum lsc_opcode lsc_op_for_nir_intrinsic(const nir_intrinsic_instr *intrin);
/* Map a GLSL base type (integer or floating-point, including bfloat16) to
 * the matching BRW register type.
 */
enum brw_reg_type brw_type_for_base_type(enum glsl_base_type base_type);
enum brw_reg_type brw_type_for_nir_type(const struct intel_device_info *devinfo,
nir_alu_type type);

View file

@ -647,23 +647,20 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0));
nir_alu_type src_type =
nir_get_nir_type_for_glsl_base_type(src_desc.element_type);
nir_alu_type dest_type =
nir_get_nir_type_for_glsl_base_type(dst_desc.element_type);
enum glsl_base_type src_type = src_desc.element_type;
enum glsl_base_type dst_type = dst_desc.element_type;
/* For integer types, the signedness is determined by flags on the
* muladd instruction. The types of the sources play no role. Adjust the
* types passed to the dpas_intel intrinsic to match.
*/
if (nir_alu_type_get_base_type(src_type) == nir_type_uint ||
nir_alu_type_get_base_type(src_type) == nir_type_int) {
if (glsl_base_type_is_integer(src_type)) {
if ((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) {
src_type = nir_alu_type_get_type_size(src_type) | nir_type_uint;
dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_uint;
src_type = glsl_unsigned_base_type_of(src_type);
dst_type = glsl_unsigned_base_type_of(dst_type);
} else {
src_type = nir_alu_type_get_type_size(src_type) | nir_type_int;
dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_int;
src_type = glsl_signed_base_type_of(src_type);
dst_type = glsl_signed_base_type_of(dst_type);
}
}
@ -673,8 +670,8 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
nir_load_deref(b, accum_slice),
nir_load_deref(b, A_slice),
nir_load_deref(b, B_slice),
.dest_type = dest_type,
.src_type = src_type,
.dest_base_type = dst_type,
.src_base_type = src_type,
.saturate = nir_intrinsic_saturate(intrin),
.systolic_depth = 8,
.repeat_count = 8);