diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c index 20c935f1c18..3179b6d2d2e 100644 --- a/src/asahi/compiler/agx_compile.c +++ b/src/asahi/compiler/agx_compile.c @@ -1576,24 +1576,6 @@ agx_emit_alu(agx_builder *b, nir_alu_instr *instr) case nir_op_bcsel: return agx_icmpsel_to(b, dst, s0, i0, s2, s1, AGX_ICOND_UEQ); - case nir_op_b2i32: - case nir_op_b2i16: - case nir_op_b2i8: - return agx_icmpsel_to(b, dst, s0, i0, i0, i1, AGX_ICOND_UEQ); - - case nir_op_b2b32: - return agx_icmpsel_to(b, dst, s0, i0, i0, agx_mov_imm(b, 32, 0xFFFFFFFF), - AGX_ICOND_UEQ); - - case nir_op_b2f16: - case nir_op_b2f32: { - /* At this point, boolean is just zero/nonzero, so compare with zero */ - agx_index f1 = (sz == 16) ? agx_mov_imm(b, 16, _mesa_float_to_half(1.0)) - : agx_mov_imm(b, 32, fui(1.0)); - - return agx_fcmpsel_to(b, dst, s0, i0, i0, f1, AGX_FCOND_EQ); - } - case nir_op_i2i32: { if (src_sz == 8) { /* Sign extend in software, NIR likes 8-bit conversions */ diff --git a/src/asahi/compiler/agx_nir_algebraic.py b/src/asahi/compiler/agx_nir_algebraic.py index 18119230463..bcc630ffdb3 100644 --- a/src/asahi/compiler/agx_nir_algebraic.py +++ b/src/asahi/compiler/agx_nir_algebraic.py @@ -83,6 +83,17 @@ lower_pack = [ ('isub', 32, 'bits'))), ] +lower_selects = [] + +for T, sizes, one in [('f', [16, 32], 1.0), + ('i', [8, 16, 32], 1), + ('b', [32], -1)]: + for size in sizes: + lower_selects.extend([ + ((f'b2{T}{size}', ('inot', 'a@1')), ('bcsel', a, 0, one)), + ((f'b2{T}{size}', 'a@1'), ('bcsel', a, one, 0)), + ]) + fuse_extr = [] for start in range(32): fuse_extr.extend([ @@ -170,7 +181,8 @@ def run(): print('#include "agx_nir.h"') print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late", - lower_sm5_shift + lower_pack).render()) + lower_sm5_shift + lower_pack + + lower_selects).render()) print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late", fuse_extr + fuse_ubfe + fuse_imad).render()) print(nir_algebraic.AlgebraicPass("agx_nir_opt_ixor_bcsel",