nir/opt_vectorize: add callback for max vectorization width

The callback allows drivers to request a different vectorization width
per instruction depending on, e.g., bit size or opcode.

This patch also removes the use of the vectorize_vec2_16bit option
from nir_opt_vectorize().
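
For illustration only (not part of this patch), a driver-side callback
could look like the sketch below; the function name is made up. Returning
0 leaves the instruction untouched, any other return value must be a
power of two:

   static uint8_t
   example_vectorize_cb(const nir_instr *instr, const void *data)
   {
      /* Only ALU instructions are considered for vectorization. */
      if (instr->type != nir_instr_type_alu)
         return 0;

      const nir_alu_instr *alu = nir_instr_as_alu(instr);

      /* Pack 16-bit ALU ops into vec2; leave everything else alone
       * by requesting a width of 1. */
      if (alu->dest.dest.ssa.bit_size == 16)
         return 2;

      return 1;
   }

   /* ...and later, in the driver's optimization loop: */
   NIR_PASS(progress, nir, nir_opt_vectorize, example_vectorize_cb, NULL);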

Reviewed-by: Alyssa Rosenzweig <alyssa@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13080>
Daniel Schürmann, 2020-12-18 19:05:47 +01:00, committed by Marge Bot
parent 7ae206d76e
commit bd151a256e
8 changed files with 99 additions and 80 deletions

src/amd/vulkan/radv_pipeline.c

@@ -4042,14 +4042,16 @@ lower_bit_size_callback(const nir_instr *instr, void *_)
    return 0;
 }
 
-static bool
-opt_vectorize_callback(const nir_instr *instr, void *_)
+static uint8_t
+opt_vectorize_callback(const nir_instr *instr, const void *_)
 {
-   assert(instr->type == nir_instr_type_alu);
-   nir_alu_instr *alu = nir_instr_as_alu(instr);
-   unsigned bit_size = alu->dest.dest.ssa.bit_size;
+   if (instr->type != nir_instr_type_alu)
+      return 0;
+
+   const nir_alu_instr *alu = nir_instr_as_alu(instr);
+   const unsigned bit_size = alu->dest.dest.ssa.bit_size;
    if (bit_size != 16)
-      return false;
+      return 1;
 
    switch (alu->op) {
    case nir_op_fadd:
@@ -4069,12 +4071,12 @@ opt_vectorize_callback(const nir_instr *instr, void *_)
    case nir_op_imax:
    case nir_op_umin:
    case nir_op_umax:
-      return true;
+      return 2;
    case nir_op_ishl: /* TODO: in NIR, these have 32bit shift operands */
    case nir_op_ishr: /* while Radeon needs 16bit operands when vectorized */
    case nir_op_ushr:
    default:
-      return false;
+      return 1;
    }
 }

src/compiler/nir/nir.h

@@ -3228,6 +3228,15 @@ typedef enum {
  */
 typedef bool (*nir_instr_filter_cb)(const nir_instr *, const void *);
 
+/** A vectorization width callback
+ *
+ * Returns the maximum vectorization width per instruction.
+ * 0, if the instruction must not be modified.
+ *
+ * The vectorization width must be a power of 2.
+ */
+typedef uint8_t (*nir_vectorize_cb)(const nir_instr *, const void *);
+
 typedef struct nir_shader_compiler_options {
    bool lower_fdiv;
    bool lower_ffma16;
@@ -3455,7 +3464,11 @@ typedef struct nir_shader_compiler_options {
    nir_instr_filter_cb lower_to_scalar_filter;
 
    /**
-    * Whether nir_opt_vectorize should only create 16-bit 2D vectors.
+    * Disables potentially harmful algebraic transformations for architectures
+    * with SIMD-within-a-register semantics.
+    *
+    * Note, to actually vectorize 16bit instructions, use nir_opt_vectorize()
+    * with a suitable callback function.
     */
    bool vectorize_vec2_16bit;
@@ -5485,9 +5498,7 @@ bool nir_lower_undef_to_zero(nir_shader *shader);
 bool nir_opt_uniform_atomics(nir_shader *shader);
 
-typedef bool (*nir_opt_vectorize_cb)(const nir_instr *instr, void *data);
-
-bool nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                        void *data);
 
 bool nir_opt_conditional_discard(nir_shader *shader);

src/compiler/nir/nir_opt_vectorize.c

@@ -22,6 +22,16 @@
  *
  */
 
+/**
+ * nir_opt_vectorize() aims to vectorize ALU instructions.
+ *
+ * The default vectorization width is 4.
+ * If desired, a callback function which returns the max vectorization width
+ * per instruction can be provided.
+ *
+ * The max vectorization width must be a power of 2.
+ */
+
 #include "nir.h"
 #include "nir_vla.h"
 #include "nir_builder.h"
@@ -125,7 +135,7 @@ instrs_equal(const void *data1, const void *data2)
 }
 
 static bool
-instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
+instr_can_rewrite(nir_instr *instr)
 {
    switch (instr->type) {
    case nir_instr_type_alu: {
@@ -139,12 +149,7 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
          return false;
 
       /* no need to hash instructions which are already vectorized */
-      if (alu->dest.dest.ssa.num_components >= 4)
-         return false;
-
-      if (vectorize_16bit &&
-          (alu->dest.dest.ssa.num_components >= 2 ||
-           alu->dest.dest.ssa.bit_size != 16))
+      if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
          return false;
 
       if (nir_op_infos[alu->op].output_size != 0)
@@ -156,8 +161,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
          /* don't hash instructions which are already swizzled
           * outside of max_components: these should better be scalarized */
-         uint32_t mask = vectorize_16bit ? ~1 : ~3;
-         for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
+         uint32_t mask = ~(instr->pass_flags - 1);
+         for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
            if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
               return false;
         }
@@ -179,10 +184,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
  * the same instructions into one vectorized instruction. Note that instr1
  * should dominate instr2.
  */
-
 static nir_instr *
-instr_try_combine(struct nir_shader *nir, struct set *instr_set,
-                  nir_instr *instr1, nir_instr *instr2)
+instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
 {
    assert(instr1->type == nir_instr_type_alu);
    assert(instr2->type == nir_instr_type_alu);
@@ -194,14 +197,10 @@ instr_try_combine(struct nir_shader *nir, struct set *instr_set,
    unsigned alu2_components = alu2->dest.dest.ssa.num_components;
    unsigned total_components = alu1_components + alu2_components;
 
-   if (total_components > 4)
+   assert(instr1->pass_flags == instr2->pass_flags);
+   if (total_components > instr1->pass_flags)
       return NULL;
 
-   if (nir->options->vectorize_vec2_16bit) {
-      assert(total_components == 2);
-      assert(alu1->dest.dest.ssa.bit_size == 16);
-   }
-
    nir_builder b;
    nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
    b.cursor = nir_after_instr(instr1);
@@ -352,28 +351,23 @@ vec_instr_set_destroy(struct set *instr_set)
 }
 
 static bool
-vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
-                             nir_instr *instr,
-                             nir_opt_vectorize_cb filter, void *data)
+vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
+                             nir_vectorize_cb filter, void *data)
 {
-   if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
-      return false;
-
-   if (filter && !filter(instr, data))
-      return false;
-
    /* set max vector to instr pass flags: this is used to hash swizzles */
-   instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
+   instr->pass_flags = filter ? filter(instr, data) : 4;
+   assert(util_is_power_of_two_or_zero(instr->pass_flags));
+
+   if (!instr_can_rewrite(instr))
+      return false;
 
    struct set_entry *entry = _mesa_set_search(instr_set, instr);
    if (entry) {
      nir_instr *old_instr = (nir_instr *) entry->key;
      _mesa_set_remove(instr_set, entry);
-     nir_instr *new_instr = instr_try_combine(nir, instr_set,
-                                              old_instr, instr);
+     nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
      if (new_instr) {
-        if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
-            (!filter || filter(new_instr, data)))
+        if (instr_can_rewrite(new_instr))
           _mesa_set_add(instr_set, new_instr);
        return true;
      }
@@ -384,25 +378,23 @@ vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
 }
 
 static bool
-vectorize_block(struct nir_shader *nir, nir_block *block,
-                struct set *instr_set,
-                nir_opt_vectorize_cb filter, void *data)
+vectorize_block(nir_block *block, struct set *instr_set,
+                nir_vectorize_cb filter, void *data)
 {
    bool progress = false;
 
    nir_foreach_instr_safe(instr, block) {
-      if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
+      if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
         progress = true;
    }
 
    for (unsigned i = 0; i < block->num_dom_children; i++) {
      nir_block *child = block->dom_children[i];
-     progress |= vectorize_block(nir, child, instr_set, filter, data);
+     progress |= vectorize_block(child, instr_set, filter, data);
    }
 
    nir_foreach_instr_reverse(instr, block) {
-      if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
-          (!filter || filter(instr, data)))
+      if (instr_can_rewrite(instr))
         _mesa_set_remove_key(instr_set, instr);
    }
@@ -410,14 +402,14 @@ vectorize_block(struct nir_shader *nir, nir_block *block,
 }
 
 static bool
-nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
-                       nir_opt_vectorize_cb filter, void *data)
+nir_opt_vectorize_impl(nir_function_impl *impl,
+                       nir_vectorize_cb filter, void *data)
 {
    struct set *instr_set = vec_instr_set_create();
 
    nir_metadata_require(impl, nir_metadata_dominance);
 
-   bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
+   bool progress = vectorize_block(nir_start_block(impl), instr_set,
                                    filter, data);
 
    if (progress) {
@@ -432,14 +424,14 @@ nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
 }
 
 bool
-nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                   void *data)
 {
    bool progress = false;
 
    nir_foreach_function(function, shader) {
      if (function->impl)
-        progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
+        progress |= nir_opt_vectorize_impl(function->impl, filter, data);
    }
 
    return progress;
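
A side note on the swizzle mask introduced above (illustrative reasoning,
not part of the diff): because the width stored in pass_flags is a power
of two, ~(instr->pass_flags - 1) clears exactly the bits that select a
lane within a group of that width, so two swizzle components hash alike
only when they address the same aligned group. For width 4 this reproduces
the old hard-coded ~3, for width 2 the old ~1:

   uint32_t mask = ~(4u - 1);          /* == ~3, groups of four lanes   */
   assert((5 & mask) == (6 & mask));   /* components 5,6: same group    */
   assert((3 & mask) != (4 & mask));   /* components 3,4: groups differ */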

src/gallium/auxiliary/nir/nir_to_tgsi.c

@@ -3067,11 +3067,11 @@ type_size(const struct glsl_type *type, bool bindless)
 /* Allow vectorizing of ALU instructions, but avoid vectorizing past what we
  * can handle for 64-bit values in TGSI.
  */
-static bool
-ntt_should_vectorize_instr(const nir_instr *instr, void *data)
+static uint8_t
+ntt_should_vectorize_instr(const nir_instr *instr, const void *data)
 {
    if (instr->type != nir_instr_type_alu)
-      return false;
+      return 0;
 
    nir_alu_instr *alu = nir_instr_as_alu(instr);
@@ -3085,7 +3085,7 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
       *
       * https://gitlab.freedesktop.org/virgl/virglrenderer/-/issues/195
       */
-      return false;
+      return 1;
 
    default:
       break;
@@ -3102,10 +3102,10 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
       * 64-bit instrs in the first place, I don't see much reason to care about
       * this.
       */
-      return false;
+      return 1;
    }
 
-   return true;
+   return 4;
 }
 
 static bool

src/gallium/drivers/radeonsi/si_shader_nir.c

@@ -43,6 +43,18 @@ static bool si_alu_to_scalar_filter(const nir_instr *instr, const void *data)
    return true;
 }
 
+static uint8_t si_vectorize_callback(const nir_instr *instr, const void *data)
+{
+   if (instr->type != nir_instr_type_alu)
+      return 0;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   if (nir_dest_bit_size(alu->dest.dest) == 16)
+      return 2;
+
+   return 1;
+}
+
 void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
 {
    bool progress;
@@ -114,7 +126,7 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
       NIR_PASS_V(nir, nir_opt_move_discards_to_top);
 
       if (sscreen->options.fp16)
-         NIR_PASS(progress, nir, nir_opt_vectorize, NULL, NULL);
+         NIR_PASS(progress, nir, nir_opt_vectorize, si_vectorize_callback, NULL);
    } while (progress);
 
    NIR_PASS_V(nir, nir_lower_var_copies);

src/mesa/state_tracker/st_glsl_to_nir.cpp

@@ -517,7 +517,7 @@ st_glsl_to_nir_post_opts(struct st_context *st, struct gl_program *prog,
    if (nir->options->lower_int64_options)
       NIR_PASS(lowered_64bit_ops, nir, nir_lower_int64);
 
-   if (revectorize)
+   if (revectorize && !nir->options->vectorize_vec2_16bit)
       NIR_PASS_V(nir, nir_opt_vectorize, nullptr, nullptr);
 
    if (revectorize || lowered_64bit_ops)

src/panfrost/bifrost/bifrost_compile.c

@@ -4276,12 +4276,12 @@ bi_lower_bit_size(const nir_instr *instr, UNUSED void *data)
 * (8-bit in Bifrost, 32-bit in NIR TODO - workaround!). Some conversions need
 * to be scalarized due to type size. */
 
-static bool
-bi_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+bi_vectorize_filter(const nir_instr *instr, const void *data)
 {
    /* Defaults work for everything else */
    if (instr->type != nir_instr_type_alu)
-      return true;
+      return 0;
 
    const nir_alu_instr *alu = nir_instr_as_alu(instr);
@@ -4293,10 +4293,17 @@ bi_vectorize_filter(const nir_instr *instr, void *data)
    case nir_op_ushr:
    case nir_op_f2i16:
    case nir_op_f2u16:
-      return false;
+      return 1;
 
    default:
-      return true;
+      break;
    }
+
+   /* Vectorized instructions cannot write more than 32-bit */
+   int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
+   if (dst_bit_size == 16)
+      return 2;
+   else
+      return 1;
 }
 
 static bool

src/panfrost/midgard/midgard_compile.c

@@ -303,25 +303,20 @@ mdg_should_scalarize(const nir_instr *instr, const void *_unused)
 }
 
 /* Only vectorize int64 up to vec2 */
-static bool
-midgard_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+midgard_vectorize_filter(const nir_instr *instr, const void *data)
 {
    if (instr->type != nir_instr_type_alu)
-      return true;
+      return 0;
 
    const nir_alu_instr *alu = nir_instr_as_alu(instr);
-   unsigned num_components = alu->dest.dest.ssa.num_components;
    int src_bit_size = nir_src_bit_size(alu->src[0].src);
    int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
 
-   if (src_bit_size == 64 || dst_bit_size == 64) {
-      if (num_components > 1)
-         return false;
-   }
+   if (src_bit_size == 64 || dst_bit_size == 64)
+      return 2;
 
-   return true;
+   return 4;
 }
 
 static void