From a38960e8f37e667442255aba73cf7dd4cff2e215 Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Fri, 14 Mar 2025 13:29:59 -0700 Subject: [PATCH] brw, nir: Use glsl_base_type instead of nir_alu_type for @dpas_intel This will allow including types that don't have a nir_alu_type equivalent, like bfloat16. Reviewed-by: Rohan Garg Reviewed-by: Ian Romanick Part-of: --- src/compiler/nir/nir_intrinsics.py | 7 +++++- src/intel/compiler/brw_from_nir.cpp | 4 ++-- src/intel/compiler/brw_nir.c | 22 +++++++++++++++++++ src/intel/compiler/brw_nir.h | 1 + .../brw_nir_lower_cooperative_matrix.c | 21 ++++++++---------- 5 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index b6494203ce1..f3a1f051f50 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -223,6 +223,11 @@ index("nir_alu_type", "src_type") # The nir_alu_type of the data output from a load or conversion index("nir_alu_type", "dest_type") +# Source and destination data types for dpas_intel. Needed here to +# represent types that won't have a nir_alu_type. +index("enum glsl_base_type", "src_base_type") +index("enum glsl_base_type", "dest_base_type") + # The swizzle mask for quad_swizzle_amd & masked_swizzle_amd index("unsigned", "swizzle_mask") @@ -2421,7 +2426,7 @@ system_value("ray_query_global_intel", 1, bit_sizes=[64]) # its value. Some supported configurations will have the component count of # that matrix different than the others. intrinsic("dpas_intel", dest_comp=0, src_comp=[0, -1, 0], - indices=[DEST_TYPE, SRC_TYPE, SATURATE, SYSTOLIC_DEPTH, REPEAT_COUNT], + indices=[DEST_BASE_TYPE, SRC_BASE_TYPE, SATURATE, SYSTOLIC_DEPTH, REPEAT_COUNT], flags=[CAN_ELIMINATE]) # NVIDIA-specific intrinsics diff --git a/src/intel/compiler/brw_from_nir.cpp b/src/intel/compiler/brw_from_nir.cpp index dbacdd29aff..ed68045a41d 100644 --- a/src/intel/compiler/brw_from_nir.cpp +++ b/src/intel/compiler/brw_from_nir.cpp @@ -4693,9 +4693,9 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb, const unsigned rcount = nir_intrinsic_repeat_count(instr); const brw_reg_type dest_type = - brw_type_for_nir_type(devinfo, nir_intrinsic_dest_type(instr)); + brw_type_for_base_type(nir_intrinsic_dest_base_type(instr)); const brw_reg_type src_type = - brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr)); + brw_type_for_base_type(nir_intrinsic_src_base_type(instr)); brw_reg src[3] = {}; for (unsigned i = 0; i < ARRAY_SIZE(src); i++) { diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c index be9260b1921..c51269f9f41 100644 --- a/src/intel/compiler/brw_nir.c +++ b/src/intel/compiler/brw_nir.c @@ -2253,6 +2253,28 @@ lsc_op_for_nir_intrinsic(const nir_intrinsic_instr *intrin) } } +enum brw_reg_type +brw_type_for_base_type(enum glsl_base_type base_type) +{ + switch (base_type) { + case GLSL_TYPE_UINT: return BRW_TYPE_UD; + case GLSL_TYPE_INT: return BRW_TYPE_D; + case GLSL_TYPE_FLOAT: return BRW_TYPE_F; + case GLSL_TYPE_FLOAT16: return BRW_TYPE_HF; + case GLSL_TYPE_BFLOAT16: return BRW_TYPE_BF; + case GLSL_TYPE_DOUBLE: return BRW_TYPE_DF; + case GLSL_TYPE_UINT16: return BRW_TYPE_UW; + case GLSL_TYPE_INT16: return BRW_TYPE_W; + case GLSL_TYPE_UINT8: return BRW_TYPE_UB; + case GLSL_TYPE_INT8: return BRW_TYPE_B; + case GLSL_TYPE_UINT64: return BRW_TYPE_UQ; + case GLSL_TYPE_INT64: return BRW_TYPE_Q; + + default: + unreachable("invalid base type"); + } +} + enum brw_reg_type brw_type_for_nir_type(const struct intel_device_info *devinfo, nir_alu_type type) diff --git a/src/intel/compiler/brw_nir.h b/src/intel/compiler/brw_nir.h index 95acf75930f..3e6f57ccd99 100644 --- a/src/intel/compiler/brw_nir.h +++ b/src/intel/compiler/brw_nir.h @@ -242,6 +242,7 @@ unsigned brw_nir_api_subgroup_size(const nir_shader *nir, enum brw_conditional_mod brw_cmod_for_nir_comparison(nir_op op); enum lsc_opcode lsc_op_for_nir_intrinsic(const nir_intrinsic_instr *intrin); +enum brw_reg_type brw_type_for_base_type(enum glsl_base_type base_type); enum brw_reg_type brw_type_for_nir_type(const struct intel_device_info *devinfo, nir_alu_type type); diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index ccd0138465a..bde6b4d0561 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -647,23 +647,20 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == ((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0)); - nir_alu_type src_type = - nir_get_nir_type_for_glsl_base_type(src_desc.element_type); - nir_alu_type dest_type = - nir_get_nir_type_for_glsl_base_type(dst_desc.element_type); + enum glsl_base_type src_type = src_desc.element_type; + enum glsl_base_type dst_type = dst_desc.element_type; /* For integer types, the signedness is determined by flags on the * muladd instruction. The types of the sources play no role. Adjust the * types passed to the dpas_intel intrinsic to match. */ - if (nir_alu_type_get_base_type(src_type) == nir_type_uint || - nir_alu_type_get_base_type(src_type) == nir_type_int) { + if (glsl_base_type_is_integer(src_type)) { if ((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) { - src_type = nir_alu_type_get_type_size(src_type) | nir_type_uint; - dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_uint; + src_type = glsl_unsigned_base_type_of(src_type); + dst_type = glsl_unsigned_base_type_of(dst_type); } else { - src_type = nir_alu_type_get_type_size(src_type) | nir_type_int; - dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_int; + src_type = glsl_signed_base_type_of(src_type); + dst_type = glsl_signed_base_type_of(dst_type); } } @@ -673,8 +670,8 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_load_deref(b, accum_slice), nir_load_deref(b, A_slice), nir_load_deref(b, B_slice), - .dest_type = dest_type, - .src_type = src_type, + .dest_base_type = dst_type, + .src_base_type = src_type, .saturate = nir_intrinsic_saturate(intrin), .systolic_depth = 8, .repeat_count = 8);