nir: Add component mask re-interpret helpers

These are based on the ones which already existed in the load/store
vectorization pass but I made some improvements while moving them.  In
particular,

 1. They're both faster if the bit sizes are equal
 2. The check is faster if old_bit_size > new_bit_size
 3. The check now fails if it would use more than NIR_MAX_VEC_COMPONENTS

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6871>
Author: Jason Ekstrand, 2020-09-25 16:01:03 -05:00 (committed by Marge Bot)
parent 57e7c5f05e
commit 769ede2de4
3 changed files with 78 additions and 39 deletions

View file

@ -36,6 +36,68 @@
#include "main/menums.h" /* BITFIELD64_MASK */
/** Returns whether a component mask expressed in "old_bit_size"-bit
 * components can be losslessly re-expressed in terms of
 * "new_bit_size"-bit components.
 */
bool
nir_component_mask_can_reinterpret(nir_component_mask_t mask,
                                   unsigned old_bit_size,
                                   unsigned new_bit_size)
{
   assert(util_is_power_of_two_nonzero(old_bit_size));
   assert(util_is_power_of_two_nonzero(new_bit_size));

   /* Identical bit sizes re-interpret trivially. */
   if (new_bit_size == old_bit_size)
      return true;

   /* 1-bit (boolean) masks never mix with real bit sizes. */
   if (new_bit_size == 1 || old_bit_size == 1)
      return false;

   if (new_bit_size < old_bit_size) {
      /* Splitting components: every old component becomes "ratio" new
       * ones, so the only constraint is the vector size limit.
       */
      const unsigned ratio = old_bit_size / new_bit_size;
      return util_last_bit(mask) * ratio <= NIR_MAX_VEC_COMPONENTS;
   }

   /* Merging components: each consecutive run of enabled components must
    * start and end on a new-bit-size component boundary.
    */
   unsigned remaining = mask;
   while (remaining) {
      int first, len;
      u_bit_scan_consecutive_range(&remaining, &first, &len);
      if ((first * old_bit_size) % new_bit_size != 0 ||
          (len * old_bit_size) % new_bit_size != 0)
         return false;
   }

   return true;
}
/** Re-expresses the component mask "mask", given for "old_bit_size"-bit
 * components, as an equivalent mask of "new_bit_size"-bit components.
 *
 * The caller must first verify the conversion is possible with
 * nir_component_mask_can_reinterpret().
 */
nir_component_mask_t
nir_component_mask_reinterpret(nir_component_mask_t mask,
                               unsigned old_bit_size,
                               unsigned new_bit_size)
{
   assert(nir_component_mask_can_reinterpret(mask, old_bit_size, new_bit_size));

   if (new_bit_size == old_bit_size)
      return mask;

   nir_component_mask_t out = 0;
   unsigned remaining = mask;
   while (remaining) {
      int first, len;
      u_bit_scan_consecutive_range(&remaining, &first, &len);
      /* Scale the run's start and length from old to new components. */
      const int new_first = first * old_bit_size / new_bit_size;
      const int new_len = len * old_bit_size / new_bit_size;
      out |= BITFIELD_RANGE(new_first, new_len);
   }

   return out;
}
nir_shader *
nir_shader_create(void *mem_ctx,
gl_shader_stage stage,

View file

@ -76,6 +76,14 @@ nir_num_components_valid(unsigned num_components)
num_components == 16;
}
/** Returns whether a component mask given for "old_bit_size"-bit components
 * can be re-interpreted for use with "new_bit_size"-bit components.
 */
bool nir_component_mask_can_reinterpret(nir_component_mask_t mask,
unsigned old_bit_size,
unsigned new_bit_size);
/** Re-interprets a component mask from "old_bit_size"-bit components to
 * "new_bit_size"-bit components.  Only valid when
 * nir_component_mask_can_reinterpret() returns true for the same arguments.
 */
nir_component_mask_t
nir_component_mask_reinterpret(nir_component_mask_t mask,
unsigned old_bit_size,
unsigned new_bit_size);
/** Defines a cast function
*
* This macro defines a cast function from in_type to out_type where

View file

@ -625,25 +625,6 @@ cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref
return nir_build_deref_cast(b, &deref->dest.ssa, deref->mode, type, 0);
}
/* Returns whether the store write mask "write_mask", expressed for
 * "old_bit_size"-bit elements, can be represented for a store with
 * "new_bit_size"-bit elements. */
static bool
writemask_representable(unsigned write_mask, unsigned old_bit_size, unsigned new_bit_size)
{
   unsigned mask = write_mask;
   unsigned pos = 0;

   while (mask) {
      /* Advance to the next run of enabled components. */
      while (!(mask & 1u)) {
         mask >>= 1;
         pos++;
      }

      /* Measure the run of consecutive enabled components. */
      const unsigned run_start = pos;
      unsigned run_len = 0;
      while (mask & 1u) {
         mask >>= 1;
         pos++;
         run_len++;
      }

      /* The run must begin and end on a new-element boundary. */
      if ((run_start * old_bit_size) % new_bit_size != 0 ||
          (run_len * old_bit_size) % new_bit_size != 0)
         return false;
   }

   return true;
}
/* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
* of "low" and "high". */
static bool
@ -683,33 +664,17 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
return false;
unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
if (!writemask_representable(write_mask, get_bit_size(low), new_bit_size))
if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
return false;
write_mask = nir_intrinsic_write_mask(high->intrin);
if (!writemask_representable(write_mask, get_bit_size(high), new_bit_size))
if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
return false;
}
return true;
}
/* Translates a write mask "write_mask" from "old_bit_size"-bit store
 * elements to "new_bit_size"-bit store elements. */
static uint32_t
update_writemask(unsigned write_mask, unsigned old_bit_size, unsigned new_bit_size)
{
   uint32_t out = 0;
   unsigned mask = write_mask;
   unsigned pos = 0;

   while (mask) {
      /* Advance to the next run of enabled components. */
      while (!(mask & 1u)) {
         mask >>= 1;
         pos++;
      }

      /* Measure the run of consecutive enabled components. */
      const unsigned run_start = pos;
      unsigned run_len = 0;
      while (mask & 1u) {
         mask >>= 1;
         pos++;
         run_len++;
      }

      /* Re-scale the run into new-bit-size elements and merge it in. */
      const unsigned new_start = run_start * old_bit_size / new_bit_size;
      const unsigned new_len = run_len * old_bit_size / new_bit_size;
      out |= ((1u << new_len) - 1) << new_start;
   }

   return out;
}
static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
{
/* avoid adding another deref to the path */
@ -847,8 +812,12 @@ vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
/* get new writemasks */
uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
low_write_mask = update_writemask(low_write_mask, get_bit_size(low), new_bit_size);
high_write_mask = update_writemask(high_write_mask, get_bit_size(high), new_bit_size);
low_write_mask = nir_component_mask_reinterpret(low_write_mask,
get_bit_size(low),
new_bit_size);
high_write_mask = nir_component_mask_reinterpret(high_write_mask,
get_bit_size(high),
new_bit_size);
high_write_mask <<= high_start / new_bit_size;
uint32_t write_mask = low_write_mask | high_write_mask;