mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-24 06:40:11 +01:00
nir: add a pass to optimize phis to 1bit
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33498>
This commit is contained in:
parent
a5f5d26080
commit
dd1a7f0e8c
3 changed files with 239 additions and 0 deletions
|
|
@ -255,6 +255,7 @@ files_libnir = files(
|
|||
'nir_opt_offsets.c',
|
||||
'nir_opt_peephole_select.c',
|
||||
'nir_opt_phi_precision.c',
|
||||
'nir_opt_phi_to_bool.c',
|
||||
'nir_opt_preamble.c',
|
||||
'nir_opt_ray_queries.c',
|
||||
'nir_opt_reassociate_bfi.c',
|
||||
|
|
|
|||
|
|
@ -6081,6 +6081,8 @@ bool nir_remove_single_src_phis_block(nir_block *block);
|
|||
|
||||
bool nir_opt_phi_precision(nir_shader *shader);
|
||||
|
||||
bool nir_opt_phi_to_bool(nir_shader *shader);
|
||||
|
||||
bool nir_opt_shrink_stores(nir_shader *shader, bool shrink_image_store);
|
||||
|
||||
bool nir_opt_shrink_vectors(nir_shader *shader, bool shrink_start);
|
||||
|
|
|
|||
236
src/compiler/nir/nir_opt_phi_to_bool.c
Normal file
236
src/compiler/nir/nir_opt_phi_to_bool.c
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
* Copyright 2025 Valve Corporation
|
||||
* SPDX-License-Identifier: MIT
|
||||
*/
|
||||
|
||||
#include "nir.h"
|
||||
#include "nir_builder.h"
|
||||
#include "nir_worklist.h"
|
||||
|
||||
/* Various other IRs do not have 1bit booleans and instead use 0/1, 0/-1, or 0/1.0.
|
||||
* This pass detects phis with all sources in one of these representations and
|
||||
* converts the phi to 1bit. The cleanup of related alu is left to other passes
|
||||
* like nir_opt_algebraic.
|
||||
*/
|
||||
|
||||
/* This enum is used to store what kind of bool the ssa def is in pass_flags.
|
||||
* It's a mask to allow multiple types for constant 0 and undef.
|
||||
*/
|
||||
enum bool_type {
|
||||
/* 0 is false, 1 is true. */
|
||||
bool_type_single_bit = BITFIELD_BIT(0),
|
||||
/* 0 is false, -1 is true. */
|
||||
bool_type_all_bits = BITFIELD_BIT(1),
|
||||
/* 0 is false, 1.0 is true. */
|
||||
bool_type_float = BITFIELD_BIT(2),
|
||||
|
||||
bool_type_all_types = BITFIELD_MASK(3),
|
||||
};
|
||||
|
||||
static inline uint8_t
|
||||
src_pass_flags(nir_src *src)
|
||||
{
|
||||
return src->ssa->parent_instr->pass_flags;
|
||||
}
|
||||
|
||||
/* If block is the first block of a loop, return the control-flow node that
 * precedes that loop, as a block. Return NULL for every other block.
 */
static inline nir_block *
block_get_loop_preheader(nir_block *block)
{
   nir_cf_node *enclosing = block->cf_node.parent;

   /* Only the very first block of a loop body qualifies. */
   const bool is_loop_header = enclosing->type == nir_cf_node_loop &&
                               nir_cf_node_cf_tree_first(enclosing) == block;
   if (!is_loop_header)
      return NULL;

   return nir_cf_node_as_block(nir_cf_node_prev(enclosing));
}
|
||||
|
||||
static uint8_t
|
||||
get_bool_types_const(nir_load_const_instr *load)
|
||||
{
|
||||
uint8_t res = bool_type_all_types;
|
||||
unsigned bit_size = load->def.bit_size;
|
||||
for (unsigned i = 0; i < load->def.num_components; i++) {
|
||||
int64_t ival = nir_const_value_as_int(load->value[i], bit_size);
|
||||
if (ival == 0)
|
||||
continue;
|
||||
else if (ival == 1)
|
||||
res &= bool_type_single_bit;
|
||||
else if (ival == -1)
|
||||
res &= bool_type_all_bits;
|
||||
else if (bit_size >= 16 && nir_const_value_as_float(load->value[i], bit_size) == 1.0)
|
||||
res &= bool_type_float;
|
||||
else
|
||||
res = 0;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static uint8_t
|
||||
get_bool_types_phi(nir_phi_instr *phi)
|
||||
{
|
||||
uint8_t res = bool_type_all_types;
|
||||
nir_foreach_phi_src(phi_src, phi)
|
||||
res &= src_pass_flags(&phi_src->src);
|
||||
return res;
|
||||
}
|
||||
|
||||
static uint8_t
|
||||
negate_int_bool_types(nir_src *src)
|
||||
{
|
||||
uint8_t src_types = src_pass_flags(src);
|
||||
uint8_t res = 0;
|
||||
if (src_types & bool_type_single_bit)
|
||||
res |= bool_type_all_bits;
|
||||
if (src_types & bool_type_all_bits)
|
||||
res |= bool_type_single_bit;
|
||||
return res;
|
||||
}
|
||||
|
||||
static uint8_t
|
||||
get_bool_types_alu(nir_alu_instr *alu)
|
||||
{
|
||||
switch (alu->op) {
|
||||
case nir_op_b2i8:
|
||||
case nir_op_b2i16:
|
||||
case nir_op_b2i32:
|
||||
case nir_op_b2i64:
|
||||
return bool_type_single_bit;
|
||||
case nir_op_b2b8:
|
||||
case nir_op_b2b16:
|
||||
case nir_op_b2b32:
|
||||
return bool_type_all_bits;
|
||||
case nir_op_b2f16:
|
||||
case nir_op_b2f32:
|
||||
case nir_op_b2f64:
|
||||
return bool_type_float;
|
||||
case nir_op_ineg:
|
||||
return negate_int_bool_types(&alu->src[0].src);
|
||||
case nir_op_inot:
|
||||
return src_pass_flags(&alu->src[0].src) & bool_type_all_bits;
|
||||
case nir_op_bcsel:
|
||||
return src_pass_flags(&alu->src[1].src) & src_pass_flags(&alu->src[2].src);
|
||||
case nir_op_iand:
|
||||
if (src_pass_flags(&alu->src[0].src) & bool_type_all_bits)
|
||||
return src_pass_flags(&alu->src[1].src);
|
||||
if (src_pass_flags(&alu->src[1].src) & bool_type_all_bits)
|
||||
return src_pass_flags(&alu->src[0].src);
|
||||
FALLTHROUGH;
|
||||
case nir_op_imin:
|
||||
case nir_op_imax:
|
||||
case nir_op_umin:
|
||||
case nir_op_umax:
|
||||
case nir_op_ior:
|
||||
case nir_op_ixor:
|
||||
return src_pass_flags(&alu->src[0].src) & src_pass_flags(&alu->src[1].src);
|
||||
case nir_op_fmax:
|
||||
case nir_op_fmin:
|
||||
case nir_op_fmul:
|
||||
case nir_op_fmulz:
|
||||
return src_pass_flags(&alu->src[0].src) & src_pass_flags(&alu->src[1].src) & bool_type_float;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static uint8_t
|
||||
get_bool_types(nir_instr *instr)
|
||||
{
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_undef:
|
||||
return bool_type_all_types;
|
||||
case nir_instr_type_load_const:
|
||||
return get_bool_types_const(nir_instr_as_load_const(instr));
|
||||
case nir_instr_type_phi:
|
||||
return get_bool_types_phi(nir_instr_as_phi(instr));
|
||||
case nir_instr_type_alu:
|
||||
return get_bool_types_alu(nir_instr_as_alu(instr));
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
phi_to_bool(nir_builder *b, nir_phi_instr *phi, void *unused)
|
||||
{
|
||||
if (!phi->instr.pass_flags || phi->def.bit_size == 1)
|
||||
return false;
|
||||
|
||||
enum bool_type type = BITFIELD_BIT(ffs(phi->instr.pass_flags) - 1);
|
||||
|
||||
unsigned bit_size = phi->def.bit_size;
|
||||
phi->def.bit_size = 1;
|
||||
|
||||
nir_foreach_phi_src(phi_src, phi) {
|
||||
b->cursor = nir_after_block_before_jump(phi_src->pred);
|
||||
nir_def *src = phi_src->src.ssa;
|
||||
if (src == &phi->def)
|
||||
continue;
|
||||
else if (nir_src_is_undef(phi_src->src))
|
||||
src = nir_undef(b, phi->def.num_components, 1);
|
||||
else if (type == bool_type_float)
|
||||
src = nir_fneu_imm(b, src, 0);
|
||||
else
|
||||
src = nir_i2b(b, src);
|
||||
|
||||
nir_src_rewrite(&phi_src->src, src);
|
||||
}
|
||||
|
||||
b->cursor = nir_after_phis(phi->instr.block);
|
||||
|
||||
nir_def *res = &phi->def;
|
||||
if (type == bool_type_single_bit)
|
||||
res = nir_b2iN(b, res, bit_size);
|
||||
else if (type == bool_type_all_bits)
|
||||
res = nir_bcsel(b, res, nir_imm_intN_t(b, -1, bit_size), nir_imm_intN_t(b, 0, bit_size));
|
||||
else if (type == bool_type_float)
|
||||
res = nir_b2fN(b, res, bit_size);
|
||||
else
|
||||
unreachable("invalid bool_type");
|
||||
|
||||
nir_foreach_use_safe(src, &phi->def) {
|
||||
if (nir_src_parent_instr(src) == &phi->instr ||
|
||||
nir_src_parent_instr(src) == res->parent_instr)
|
||||
continue;
|
||||
nir_src_rewrite(src, res);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
nir_opt_phi_to_bool(nir_shader *shader)
|
||||
{
|
||||
nir_instr_worklist *worklist = nir_instr_worklist_create();
|
||||
|
||||
nir_foreach_function_impl(impl, shader) {
|
||||
nir_foreach_block(block, impl) {
|
||||
nir_block *preheader = block_get_loop_preheader(block);
|
||||
nir_foreach_instr(instr, block) {
|
||||
if (instr->type == nir_instr_type_phi && preheader) {
|
||||
nir_phi_src *phi_src = nir_phi_get_src_from_block(nir_instr_as_phi(instr), preheader);
|
||||
instr->pass_flags = src_pass_flags(&phi_src->src);
|
||||
/* We only know the types of the preheader phi source
|
||||
* so we need to revisit it later if nessecary.
|
||||
*/
|
||||
if (instr->pass_flags)
|
||||
nir_instr_worklist_push_tail(worklist, instr);
|
||||
} else {
|
||||
instr->pass_flags = get_bool_types(instr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nir_foreach_instr_in_worklist(instr, worklist) {
|
||||
uint8_t bool_types = get_bool_types(instr);
|
||||
if (instr->pass_flags != bool_types) {
|
||||
instr->pass_flags = bool_types;
|
||||
nir_foreach_use(use, nir_instr_def(instr))
|
||||
nir_instr_worklist_push_tail(worklist, nir_src_parent_instr(use));
|
||||
}
|
||||
}
|
||||
|
||||
nir_instr_worklist_destroy(worklist);
|
||||
|
||||
return nir_shader_phi_pass(shader, phi_to_bool, nir_metadata_control_flow, NULL);
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue