/*
 * Copyright 2025 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_worklist.h"

/* Various other IRs have no 1-bit booleans and encode them as 0/1, 0/-1 or
 * 0/1.0 instead. This pass detects phis whose sources all use one of these
 * representations and converts the phi to 1 bit. The cleanup of the related
 * ALU instructions is left to other passes like nir_opt_algebraic.
 */

/* Kind of boolean an SSA def holds, stored in the pass_flags of the
 * instruction that produces it. It is a mask so that constant 0 and undef can
 * be tagged with several representations at once.
 */
enum bool_type {
   /* 0 is false, 1 is true. */
   bool_type_single_bit = BITFIELD_BIT(0),

   /* 0 is false, -1 is true. */
   bool_type_all_bits = BITFIELD_BIT(1),

   /* 0 is false, 1.0 is true. */
   bool_type_float = BITFIELD_BIT(2),

   bool_type_all_types = BITFIELD_MASK(3),
};

/* Returns the bool_type mask recorded on the instruction producing src. */
static inline uint8_t
src_pass_flags(nir_src *src)
{
   nir_instr *producer = src->ssa->parent_instr;
   return producer->pass_flags;
}

/* If block is the first block of a loop, returns the block that precedes the
 * loop cf node; otherwise returns NULL.
 */
static inline nir_block *
block_get_loop_preheader(nir_block *block)
{
   nir_cf_node *enclosing = block->cf_node.parent;

   /* The && short-circuits, so the first-block lookup only runs for loops. */
   bool is_loop_header = enclosing->type == nir_cf_node_loop &&
                         block == nir_cf_node_cf_tree_first(enclosing);

   return is_loop_header ? nir_cf_node_as_block(nir_cf_node_prev(enclosing))
                         : NULL;
}

/* Computes the bool_type mask of a constant: the set of representations that
 * every component is a valid boolean in.
 */
static uint8_t
get_bool_types_const(nir_load_const_instr *load)
{
   const unsigned bit_size = load->def.bit_size;
   uint8_t types = bool_type_all_types;

   for (unsigned c = 0; c < load->def.num_components; c++) {
      switch (nir_const_value_as_int(load->value[c], bit_size)) {
      case 0:
         /* Zero is "false" in every representation: no restriction. */
         break;
      case 1:
         types &= bool_type_single_bit;
         break;
      case -1:
         types &= bool_type_all_bits;
         break;
      default:
         /* The float check is only attempted for bit sizes >= 16. */
         if (bit_size >= 16 &&
             nir_const_value_as_float(load->value[c], bit_size) == 1.0)
            types &= bool_type_float;
         else
            types = 0;
         break;
      }
   }

   return types;
}

/* A phi can only be a bool of the types shared by all of its sources. */
static uint8_t
get_bool_types_phi(nir_phi_instr *phi)
{
   uint8_t common = bool_type_all_types;

   nir_foreach_phi_src(phi_src, phi) {
      common &= src_pass_flags(&phi_src->src);
   }

   return common;
}

/* Bool types after integer negation of src: 0/1 turns into 0/-1 and 0/-1
 * turns into 0/1. The float type is never part of the result.
 */
static uint8_t
negate_int_bool_types(nir_src *src)
{
   const uint8_t in = src_pass_flags(src);
   const uint8_t from_single_bit =
      (in & bool_type_single_bit) ? bool_type_all_bits : 0;
   const uint8_t from_all_bits =
      (in & bool_type_all_bits) ? bool_type_single_bit : 0;

   return from_single_bit | from_all_bits;
}

/* Computes the bool_type mask of an ALU result from its opcode and the masks
 * already recorded on its sources. Opcodes that are not listed yield 0.
 */
static uint8_t
get_bool_types_alu(nir_alu_instr *alu)
{
   switch (alu->op) {
   /* Conversions from a 1-bit bool produce a known representation. */
   case nir_op_b2i8:
   case nir_op_b2i16:
   case nir_op_b2i32:
   case nir_op_b2i64:
      return bool_type_single_bit;

   case nir_op_b2b8:
   case nir_op_b2b16:
   case nir_op_b2b32:
      return bool_type_all_bits;

   case nir_op_b2f16:
   case nir_op_b2f32:
   case nir_op_b2f64:
      return bool_type_float;

   case nir_op_ineg:
      return negate_int_bool_types(&alu->src[0].src);

   case nir_op_inot:
      /* Only the 0/-1 representation is kept across a bitwise not. */
      return src_pass_flags(&alu->src[0].src) & bool_type_all_bits;

   case nir_op_bcsel:
      /* The result is one of the two selected values; the condition in
       * src[0] is not looked at.
       */
      return src_pass_flags(&alu->src[1].src) &
             src_pass_flags(&alu->src[2].src);

   case nir_op_iand: {
      const uint8_t lhs = src_pass_flags(&alu->src[0].src);
      const uint8_t rhs = src_pass_flags(&alu->src[1].src);

      /* ANDing with a 0/-1 mask yields 0 or the other operand, so the other
       * operand's types carry over unchanged.
       */
      if (lhs & bool_type_all_bits)
         return rhs;
      if (rhs & bool_type_all_bits)
         return lhs;

      /* Otherwise iand is treated like the generic integer ops below. */
      return lhs & rhs;
   }

   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
   case nir_op_ior:
   case nir_op_ixor:
      return src_pass_flags(&alu->src[0].src) &
             src_pass_flags(&alu->src[1].src);

   case nir_op_fmax:
   case nir_op_fmin:
   case nir_op_fmul:
   case nir_op_fmulz:
      /* Float ops only keep the 0/1.0 representation. */
      return src_pass_flags(&alu->src[0].src) &
             src_pass_flags(&alu->src[1].src) &
             bool_type_float;

   default:
      return 0;
   }
}

/* Dispatches on the instruction type to compute its bool_type mask. Undefs
 * fit every representation; instruction types not handled here yield 0.
 */
static uint8_t
get_bool_types(nir_instr *instr)
{
   if (instr->type == nir_instr_type_undef)
      return bool_type_all_types;

   if (instr->type == nir_instr_type_load_const)
      return get_bool_types_const(nir_instr_as_load_const(instr));

   if (instr->type == nir_instr_type_phi)
      return get_bool_types_phi(nir_instr_as_phi(instr));

   if (instr->type == nir_instr_type_alu)
      return get_bool_types_alu(nir_instr_as_alu(instr));

   return 0;
}

/* Per-phi callback: turns a phi that was tagged with a bool type into a 1-bit
 * phi. Each source is converted to 1 bit at the end of its predecessor block,
 * and a conversion back to the original representation is emitted after the
 * phis and substituted into all other users.
 *
 * Returns true if the phi was rewritten, false if it was left alone.
 */
static bool
phi_to_bool(nir_builder *b, nir_phi_instr *phi, void *unused)
{
   const uint8_t candidates = phi->instr.pass_flags;
   if (candidates == 0 || phi->def.bit_size == 1)
      return false;

   /* Several representations may fit; use the lowest set bit. */
   const enum bool_type type = BITFIELD_BIT(ffs(candidates) - 1);

   /* Remember the width before the def is switched to 1 bit. */
   const unsigned bit_size = phi->def.bit_size;
   phi->def.bit_size = 1;

   nir_foreach_phi_src(phi_src, phi) {
      b->cursor = nir_after_block_before_jump(phi_src->pred);

      nir_def *old_src = phi_src->src.ssa;

      /* A source that is the phi itself needs no conversion. */
      if (old_src == &phi->def)
         continue;

      nir_def *bool_src;
      if (nir_src_is_undef(phi_src->src))
         bool_src = nir_undef(b, phi->def.num_components, 1);
      else if (type == bool_type_float)
         bool_src = nir_fneu_imm(b, old_src, 0);
      else
         bool_src = nir_i2b(b, old_src);

      nir_src_rewrite(&phi_src->src, bool_src);
   }

   b->cursor = nir_after_phis(phi->instr.block);

   /* Rebuild a value in the original representation and width. */
   nir_def *res = &phi->def;
   switch (type) {
   case bool_type_single_bit:
      res = nir_b2iN(b, res, bit_size);
      break;
   case bool_type_all_bits:
      res = nir_bcsel(b, res, nir_imm_intN_t(b, -1, bit_size), nir_imm_intN_t(b, 0, bit_size));
      break;
   case bool_type_float:
      res = nir_b2fN(b, res, bit_size);
      break;
   default:
      unreachable("invalid bool_type");
   }

   /* Redirect every user to the converted value, except the phi itself and
    * the conversion instruction, which must keep reading the 1-bit phi.
    */
   nir_foreach_use_safe(use, &phi->def) {
      nir_instr *user = nir_src_parent_instr(use);
      if (user != &phi->instr && user != res->parent_instr)
         nir_src_rewrite(use, res);
   }

   return true;
}

/* Seeds pass_flags of every instruction in impl with its bool_type mask,
 * visiting blocks and instructions in order. Phis in the first block of a
 * loop are seeded from their preheader source only and queued on worklist
 * for re-evaluation.
 */
static void
seed_bool_types(nir_function_impl *impl, nir_instr_worklist *worklist)
{
   nir_foreach_block(block, impl) {
      nir_block *preheader = block_get_loop_preheader(block);

      nir_foreach_instr(instr, block) {
         if (instr->type != nir_instr_type_phi || !preheader) {
            instr->pass_flags = get_bool_types(instr);
            continue;
         }

         nir_phi_src *entry_src =
            nir_phi_get_src_from_block(nir_instr_as_phi(instr), preheader);
         instr->pass_flags = src_pass_flags(&entry_src->src);

         /* We only know the types of the preheader phi source so far, so
          * the phi has to be revisited later if necessary.
          */
         if (instr->pass_flags)
            nir_instr_worklist_push_tail(worklist, instr);
      }
   }
}

/* Drains worklist: recomputes the mask of each queued instruction and, when
 * it differs from the stored one, stores it and queues all users of its def.
 */
static void
propagate_bool_types(nir_instr_worklist *worklist)
{
   nir_foreach_instr_in_worklist(instr, worklist) {
      const uint8_t updated = get_bool_types(instr);
      if (updated == instr->pass_flags)
         continue;

      instr->pass_flags = updated;
      nir_foreach_use(use, nir_instr_def(instr)) {
         nir_instr_worklist_push_tail(worklist, nir_src_parent_instr(use));
      }
   }
}

/* Entry point of the pass. Tags every instruction of every function impl
 * with the bool representations it may hold, settles the tags of the queued
 * loop phis and their users, then runs phi_to_bool over all phis.
 *
 * Returns the result of nir_shader_phi_pass (presumably whether any phi was
 * converted -- confirm against its documentation).
 */
bool
nir_opt_phi_to_bool(nir_shader *shader)
{
   nir_instr_worklist *worklist = nir_instr_worklist_create();

   nir_foreach_function_impl(impl, shader) {
      seed_bool_types(impl, worklist);
   }

   propagate_bool_types(worklist);

   nir_instr_worklist_destroy(worklist);

   return nir_shader_phi_pass(shader, phi_to_bool,
                              nir_metadata_control_flow, NULL);
}