diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c index aab2f9f74b2..4cb456bc747 100644 --- a/src/compiler/nir/nir_divergence_analysis.c +++ b/src/compiler/nir/nir_divergence_analysis.c @@ -615,6 +615,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr) case nir_intrinsic_load_ray_triangle_vertex_positions: case nir_intrinsic_cmat_extract: case nir_intrinsic_cmat_muladd_amd: + case nir_intrinsic_dpas_intel: case nir_intrinsic_isberd_nv: case nir_intrinsic_al2p_nv: case nir_intrinsic_ald_nv: diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 73a6e12b63e..f1303dabf54 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -318,6 +318,10 @@ index("enum glsl_matrix_layout", "matrix_layout") index("nir_cmat_signed", "cmat_signed_mask") index("nir_op", "alu_op") +# For Intel DPAS intrinsic. +index("unsigned", "systolic_depth") +index("unsigned", "repeat_count") + intrinsic("nop", flags=[CAN_ELIMINATE]) intrinsic("convert_alu_types", dest_comp=0, src_comp=[0], @@ -2015,6 +2019,15 @@ system_value("leaf_procedural_intel", 1, bit_sizes=[1]) system_value("btd_shader_type_intel", 1) system_value("ray_query_global_intel", 1, bit_sizes=[64]) +# Source 0: A matrix (type specified by SRC_TYPE) +# Source 1: B matrix (type specified by SRC_TYPE) +# Source 2: Accumulator matrix (type specified by DEST_TYPE) +# +# The matrix parameters are the slices owned by the invocation. 
+intrinsic("dpas_intel", dest_comp=0, src_comp=[0, 0, 0], + indices=[DEST_TYPE, SRC_TYPE, SATURATE, CMAT_SIGNED_MASK, SYSTOLIC_DEPTH, REPEAT_COUNT], + flags=[CAN_ELIMINATE]) + # NVIDIA-specific intrinsics intrinsic("load_sysval_nv", dest_comp=1, src_comp=[], bit_sizes=[32, 64], indices=[ACCESS, BASE], flags=[CAN_ELIMINATE]) diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp index 088542dc717..1bf964a0c43 100644 --- a/src/intel/compiler/brw_fs_nir.cpp +++ b/src/intel/compiler/brw_fs_nir.cpp @@ -4587,6 +4587,65 @@ fs_nir_emit_cs_intrinsic(nir_to_brw_state &ntb, break; } + case nir_intrinsic_dpas_intel: { + const unsigned sdepth = nir_intrinsic_systolic_depth(instr); + const unsigned rcount = nir_intrinsic_repeat_count(instr); + + const brw_reg_type dest_type = + brw_type_for_nir_type(devinfo, nir_intrinsic_dest_type(instr)); + const brw_reg_type src_type = + brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr)); + + dest = retype(dest, dest_type); + fs_reg src2 = retype(get_nir_src(ntb, instr->src[2]), dest_type); + const fs_reg dest_hf = dest; + + fs_builder bld8 = bld.exec_all().group(8, 0); + fs_builder bld16 = bld.exec_all().group(16, 0); + + /* DG2 cannot have the destination or source 0 of DPAS be float16. It is + * still advantageous to support these formats for memory and bandwidth + * savings. + * + * The float16 source must be expanded to float32. 
+ */ + if (devinfo->verx10 == 125 && dest_type == BRW_REGISTER_TYPE_HF && + !s.compiler->lower_dpas) { + dest = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); + + if (src2.file != ARF) { + const fs_reg src2_hf = src2; + + src2 = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); + + for (unsigned i = 0; i < 4; i++) { + bld16.MOV(byte_offset(src2, REG_SIZE * i * 2), + byte_offset(src2_hf, REG_SIZE * i)); + } + } else { + src2 = retype(src2, BRW_REGISTER_TYPE_F); + } + } + + bld8.DPAS(dest, + src2, + retype(get_nir_src(ntb, instr->src[1]), src_type), + retype(get_nir_src(ntb, instr->src[0]), src_type), + sdepth, + rcount) + ->saturate = nir_intrinsic_saturate(instr); + + /* Compact the destination to float16 (from float32). */ + if (!dest.equals(dest_hf)) { + for (unsigned i = 0; i < 4; i++) { + bld16.MOV(byte_offset(dest_hf, REG_SIZE * i), + byte_offset(dest, REG_SIZE * i * 2)); + } + } + + break; + } + default: fs_nir_emit_intrinsic(ntb, bld, instr); break; diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index 0efc3218f71..6d31342376b 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -621,9 +621,39 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) glsl_get_vector_elements(slice_type)), 32); } - case nir_intrinsic_cmat_muladd: - /* FINISHME. 
*/ + case nir_intrinsic_cmat_muladd: { + nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); + nir_deref_instr *A_slice = nir_src_as_deref(intrin->src[1]); + nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]); + nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]); + + const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); + const struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_mat_type); + + const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, A_slice); + const struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_mat_type); + + const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type); + const unsigned num_components = glsl_get_vector_elements(dst_slice->type); + + nir_def *result = + nir_dpas_intel(b, + packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type), + nir_load_deref(b, A_slice), + nir_load_deref(b, B_slice), + nir_load_deref(b, accum_slice), + .dest_type = nir_get_nir_type_for_glsl_base_type(dst_desc.element_type), + .src_type = nir_get_nir_type_for_glsl_base_type(src_desc.element_type), + .saturate = nir_intrinsic_saturate(intrin), + .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin), + .systolic_depth = 8, + .repeat_count = 8); + + nir_store_deref(b, dst_slice, result, + nir_component_mask(num_components)); + return NIR_LOWER_INSTR_PROGRESS_REPLACE; + } case nir_intrinsic_cmat_bitcast: { nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);