diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index fa540f3997d..b958316e3e9 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -2792,6 +2792,119 @@ set_speculate(nir_builder *b, nir_intrinsic_instr *intr, UNUSED void *_)
    return true;
 }
 
+static bool
+optimize_bounds(nir_builder *b, nir_intrinsic_instr *intr, void *data)
+{
+   if (intr->intrinsic != nir_intrinsic_load_constant_agx)
+      return false;
+
+   assert(intr->def.bit_size > 1 && "no if-uses");
+   nir_scalar srcs[2] = {{NULL}};
+   unsigned use_count = 0;
+
+   nir_alu_instr *first_use = NULL;
+
+   nir_foreach_use(use, &intr->def) {
+      /* All uses need to be bounds_agx */
+      nir_instr *parent = nir_src_parent_instr(use);
+      if (parent->type != nir_instr_type_alu)
+         return false;
+
+      nir_alu_instr *alu = nir_instr_as_alu(parent);
+      if ((alu->op != nir_op_bounds_agx) || (alu->src[0].src.ssa != &intr->def))
+         return false;
+
+      assert(alu->def.num_components == 1 && alu->def.bit_size == 32);
+
+      /* All bounds checks need a common offset and bounds */
+      for (unsigned s = 0; s < 2; ++s) {
+         nir_scalar this = nir_scalar_resolved(alu->src[1 + s].src.ssa,
+                                               alu->src[1 + s].swizzle[0]);
+
+         if (srcs[s].def == NULL)
+            srcs[s] = this;
+         else if (!nir_scalar_equal(srcs[s], this))
+            return false;
+
+         /* To avoid dominance problems, we must sink loads. */
+         if (this.def->parent_instr->block != intr->instr.block) {
+            return false;
+         }
+      }
+
+      if (!first_use || first_use->def.index > alu->def.index) {
+         first_use = alu;
+      }
+
+      ++use_count;
+   }
+
+   /* We've matched. Freeze the set of uses before changing things. */
+   nir_alu_instr **uses = alloca(sizeof(nir_alu_instr *) * use_count);
+
+   unsigned i = 0;
+   nir_foreach_use(use, &intr->def) {
+      nir_instr *parent = nir_src_parent_instr(use);
+      uses[i++] = nir_instr_as_alu(parent);
+   }
+   assert(i == use_count && "should not have changed");
+
+   /* Sink the load */
+   nir_instr_remove(&intr->instr);
+   b->cursor = nir_before_instr(&first_use->instr);
+   nir_builder_instr_insert(b, &intr->instr);
+
+   /* Now start rewriting. Grab some common variables */
+   b->cursor = nir_before_instr(&intr->instr);
+   nir_def *offset = nir_channel(b, srcs[0].def, srcs[0].comp);
+   nir_def *bounds = nir_channel(b, srcs[1].def, srcs[1].comp);
+
+   nir_def *in_bounds = nir_uge(b, bounds, offset);
+   nir_def *zero = nir_imm_int(b, 0);
+
+   nir_src *base_src = &intr->src[0];
+   nir_src *offs_src = &intr->src[1];
+
+   nir_def *base_lo = nir_unpack_64_2x32_split_x(b, base_src->ssa);
+   nir_def *base_hi = nir_unpack_64_2x32_split_y(b, base_src->ssa);
+
+   /* Bounds check the base/offset instead. We currently reserve the bottom
+    * 2^36 of VA (this is driver/compiler ABI). With soft fault enabled, that
+    * means any read of the lower region will return zero as required.
+    *
+    * Therefore, when out-of-bounds, we clamp the index to zero and the high
+    * half of the address to zero. We don't need to clamp the low half of the
+    * address. The resulting sum is thus:
+    *
+    *    0*(2^32) + lo + (index << shift)
+    *
+    * ...which will be in the unmapped zero region provided shift < 4.
+    */
+   base_hi = nir_bcsel(b, in_bounds, base_hi, zero);
+
+   /* Clamp index if the shift is too large or sign-extension used */
+   if (nir_intrinsic_base(intr) >= 2 || nir_intrinsic_sign_extend(intr)) {
+      nir_src_rewrite(offs_src, nir_bcsel(b, in_bounds, offs_src->ssa, zero));
+   }
+
+   nir_src_rewrite(base_src, nir_pack_64_2x32_split(b, base_lo, base_hi));
+
+   /* Now that the load itself is bounds checked, all that's left is removing
+    * the bounds checks on the output. This requires a little care to avoid an
+    * infinite loop.
+    *
+    * Also note we cannot remove the uses here, because it would invalidate the
+    * iterator inside intrinsics_pass. I hate C, don't you?
+    */
+   for (unsigned i = 0; i < use_count; ++i) {
+      b->cursor = nir_after_instr(&uses[i]->instr);
+      nir_def *chan = nir_channel(b, &intr->def, uses[i]->src[0].swizzle[0]);
+      nir_def_rewrite_uses(&uses[i]->def, chan);
+   }
+
+   return true;
+}
+
 static void
 agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size)
 {
@@ -2899,6 +3012,21 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size)
       NIR_PASS(progress, nir, agx_nir_fuse_algebraic_late);
    } while (progress);
 
+   /* Before optimizing bounds checks, we need to clean up and index defs so
+    * optimize_bounds does the right thing.
+    */
+   NIR_PASS(_, nir, nir_copy_prop);
+   NIR_PASS(_, nir, nir_opt_dce);
+
+   nir_foreach_function_impl(impl, nir) {
+      nir_index_ssa_defs(impl);
+   }
+
+   if (soft_fault) {
+      NIR_PASS(_, nir, nir_shader_intrinsics_pass, optimize_bounds,
+               nir_metadata_control_flow, NULL);
+   }
+
    /* Do remaining lowering late, since this inserts &s for shifts so we want to
     * do it after fusing constant shifts. Constant folding will clean up. */
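
For illustration only, here is a minimal standalone C sketch of the address-clamping arithmetic the in-patch comment describes. The function name and parameters are hypothetical, not part of the patch: "in_bounds" stands in for the nir_uge(bounds, offset) result, and "shift" for the load's shift (nir_intrinsic_base in the patch). Zeroing the high word of the base, and additionally clamping the index when the shift is 2 or larger, keeps any out-of-bounds address below 2^36, inside the reserved VA region that reads back zero when soft fault is enabled.

/* Hypothetical sketch, not from the patch: models the clamping applied to the
 * load's base and index by optimize_bounds above.
 */
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

uint64_t
clamped_load_address(uint64_t base, uint32_t index, unsigned shift,
                     bool in_bounds)
{
   assert(shift < 4 && "the patch's reasoning assumes shift < 4");

   uint32_t base_lo = (uint32_t)base;
   uint32_t base_hi = (uint32_t)(base >> 32);

   /* Mirrors: base_hi = nir_bcsel(b, in_bounds, base_hi, zero); */
   if (!in_bounds)
      base_hi = 0;

   /* Mirrors the index clamp used for larger shifts / sign extension */
   if (!in_bounds && shift >= 2)
      index = 0;

   uint64_t addr =
      ((uint64_t)base_hi << 32) + base_lo + ((uint64_t)index << shift);

   /* Out-of-bounds accesses land below 2^36, in the reserved zero region */
   assert(in_bounds || addr < (1ull << 36));
   return addr;
}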