diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index fc618de1079..3d4ea35b5ef 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -149,22 +149,31 @@ Temp get_ssa_temp(struct isel_context *ctx, nir_ssa_def *def) return ctx->allocated[def->index]; } -Temp emit_mbcnt(isel_context *ctx, Definition dst, - Operand mask_lo = Operand((uint32_t) -1), Operand mask_hi = Operand((uint32_t) -1)) +Temp emit_mbcnt(isel_context *ctx, Temp dst, Temp mask = Temp()) { Builder bld(ctx->program, ctx->block); - Definition lo_def = ctx->program->wave_size == 32 ? dst : bld.def(v1); - Temp thread_id_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, lo_def, mask_lo, Operand(0u)); + assert(mask.id() == 0 || mask.regClass() == bld.lm); if (ctx->program->wave_size == 32) { - return thread_id_lo; - } else if (ctx->program->chip_class <= GFX7) { - Temp thread_id_hi = bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, dst, mask_hi, thread_id_lo); - return thread_id_hi; - } else { - Temp thread_id_hi = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32_e64, dst, mask_hi, thread_id_lo); - return thread_id_hi; + Operand mask_lo = mask.id() ? Operand(mask) : Operand(-1u); + return bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, Definition(dst), mask_lo, Operand(0u)); } + + Operand mask_lo(-1u); + Operand mask_hi(-1u); + + if (mask.id()) { + Builder::Result mask_split = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), mask); + mask_lo = Operand(mask_split.def(0).getTemp()); + mask_hi = Operand(mask_split.def(1).getTemp()); + } + + Temp mbcnt_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), mask_lo, Operand(0u)); + + if (ctx->program->chip_class <= GFX7) + return bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, Definition(dst), mask_hi, mbcnt_lo); + else + return bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32_e64, Definition(dst), mask_hi, mbcnt_lo); } Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_needs_wqm = false) @@ -4194,7 +4203,7 @@ void visit_store_ls_or_es_output(isel_context *ctx, nir_intrinsic_instr *instr) unsigned itemsize = ctx->stage == vertex_geometry_gs ? ctx->program->info->vs.es_info.esgs_itemsize : ctx->program->info->tes.es_info.esgs_itemsize; - Temp thread_id = emit_mbcnt(ctx, bld.def(v1)); + Temp thread_id = emit_mbcnt(ctx, bld.tmp(v1)); Temp wave_idx = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), get_arg(ctx, ctx->args->merged_wave_info), Operand(4u << 16 | 24)); Temp vertex_idx = bld.vop2(aco_opcode::v_or_b32, bld.def(v1), thread_id, bld.v_mul24_imm(bld.def(v1), as_vgpr(ctx, wave_idx), ctx->program->wave_size)); @@ -6946,7 +6955,7 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te // return ((val & exec) >> cluster_offset) & cluster_mask != 0 //subgroupClusteredXor(): // return v_bnt_u32_b32(((val & exec) >> cluster_offset) & cluster_mask, 0) & 1 != 0 - Temp lane_id = emit_mbcnt(ctx, bld.def(v1)); + Temp lane_id = emit_mbcnt(ctx, bld.tmp(v1)); Temp cluster_offset = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(~uint32_t(cluster_size - 1)), lane_id); Temp tmp; @@ -6996,10 +7005,7 @@ Temp emit_boolean_exclusive_scan(isel_context *ctx, nir_op op, Temp src) else tmp = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm)); - Builder::Result lohi = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), tmp); - Temp lo = lohi.def(0).getTemp(); - Temp hi = lohi.def(1).getTemp(); - Temp mbcnt = emit_mbcnt(ctx, bld.def(v1), Operand(lo), Operand(hi)); + Temp mbcnt = emit_mbcnt(ctx, bld.tmp(v1), tmp); Definition cmp_def = Definition(); if (op == nir_op_iand) @@ -7470,7 +7476,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) break; } case nir_intrinsic_load_local_invocation_index: { - Temp id = emit_mbcnt(ctx, bld.def(v1)); + Temp id = emit_mbcnt(ctx, bld.tmp(v1)); /* The tg_size bits [6:11] contain the subgroup id, * we need this multiplied by the wave size, and then OR the thread id to it. @@ -7498,7 +7504,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) break; } case nir_intrinsic_load_subgroup_invocation: { - emit_mbcnt(ctx, Definition(get_ssa_temp(ctx, &instr->dest.ssa))); + emit_mbcnt(ctx, get_ssa_temp(ctx, &instr->dest.ssa)); break; } case nir_intrinsic_load_num_subgroups: { @@ -7912,11 +7918,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) } case nir_intrinsic_mbcnt_amd: { Temp src = get_ssa_temp(ctx, instr->src[0].ssa); - RegClass rc = RegClass(src.type(), 1); - Temp mask_lo = bld.tmp(rc), mask_hi = bld.tmp(rc); - bld.pseudo(aco_opcode::p_split_vector, Definition(mask_lo), Definition(mask_hi), src); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); - Temp wqm_tmp = emit_mbcnt(ctx, bld.def(v1), Operand(mask_lo), Operand(mask_hi)); + Temp wqm_tmp = emit_mbcnt(ctx, bld.tmp(v1), src); emit_wqm(ctx, wqm_tmp, dst); break; } @@ -10383,7 +10386,7 @@ static void emit_streamout(isel_context *ctx, unsigned stream) Temp so_vtx_count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), get_arg(ctx, ctx->args->streamout_config), Operand(0x70010u)); - Temp tid = emit_mbcnt(ctx, bld.def(v1)); + Temp tid = emit_mbcnt(ctx, bld.tmp(v1)); Temp can_emit = bld.vopc(aco_opcode::v_cmp_gt_i32, bld.def(bld.lm), so_vtx_count, tid); @@ -10842,7 +10845,7 @@ void ngg_emit_nogs_output(isel_context *ctx) /* Calculate LDS address where the GS threads stored the primitive ID. */ Temp wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16))); - Temp thread_id_in_wave = emit_mbcnt(ctx, bld.def(v1)); + Temp thread_id_in_wave = emit_mbcnt(ctx, bld.tmp(v1)); Temp wave_id_mul = bld.v_mul24_imm(bld.def(v1), as_vgpr(ctx, wave_id_in_tg), ctx->program->wave_size); Temp thread_id_in_tg = bld.vadd32(bld.def(v1), Operand(wave_id_mul), Operand(thread_id_in_wave)); Temp addr = bld.v_mul24_imm(bld.def(v1), thread_id_in_tg, 4u);