brw/nir: Prepare try_rebuild_source for scalar values

Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29884>
This commit is contained in:
Ian Romanick 2024-07-03 11:04:03 -07:00
parent 59f66b4150
commit 5ea9ed4798

View file

@ -4728,6 +4728,7 @@ brw_reduce_op_for_nir_reduction_op(nir_op op)
/**
 * Scratch state handed to add_rebuild_src() through its opaque `state`
 * pointer while try_rebuild_source() walks the sources of a resource
 * definition.
 */
struct rebuild_resource {
/* Reset to 0 by try_rebuild_source() before the source walk starts.
 * NOTE(review): presumably a cursor into `array`; its consumer is not
 * visible here, confirm against the rest of the function.
 */
unsigned idx;
/* NIR defs gathered so far.  add_rebuild_src() scans this list to detect
 * a def that has already been recorded.
 */
std::vector<nir_def *> array;
/* Backend registers looked up by NIR SSA def index; set from
 * ntb.ssa_values.  add_rebuild_src() returns early for any source whose
 * entry has is_scalar set.  The struct only reads through this pointer.
 */
const brw_reg *ssa_values;
};
static bool
@ -4757,6 +4758,9 @@ add_rebuild_src(nir_src *src, void *state)
{
struct rebuild_resource *res = (struct rebuild_resource *) state;
if (res->ssa_values[src->ssa->index].is_scalar)
return true;
for (nir_def *def : res->array) {
if (def == src->ssa)
return true;
@ -4772,11 +4776,22 @@ static brw_reg
try_rebuild_source(nir_to_brw_state &ntb, const brw::fs_builder &bld,
nir_def *resource_def, bool a64 = false)
{
if (ntb.ssa_values[resource_def->index].is_scalar) {
brw_reg r = ntb.ssa_values[resource_def->index];
/* All users of try_rebuild_source expect an integer type. Smash the
* type to an integer type with the same size.
*/
r.type = brw_type_with_size(BRW_TYPE_D, resource_def->bit_size);
return component(r, 0);
}
/* Create a builder at the location of the resource_intel intrinsic */
fs_builder ubld = bld.exec_all().group(8 * reg_unit(ntb.devinfo), 0);
const unsigned grf_size = REG_SIZE * reg_unit(ntb.devinfo);
struct rebuild_resource resources = {};
resources.ssa_values = ntb.ssa_values;
resources.idx = 0;
if (!nir_foreach_src(resource_def->parent_instr,
@ -4815,17 +4830,41 @@ try_rebuild_source(nir_to_brw_state &ntb, const brw::fs_builder &bld,
brw_reg srcs[3];
for (unsigned s = 0; s < nir_op_infos[alu->op].num_inputs; s++) {
srcs[s] = offset(
ntb.resource_insts[alu->src[s].src.ssa->index]->dst,
ubld, alu->src[s].swizzle[0]);
brw_reg reg;
if (ntb.resource_insts[alu->src[s].src.ssa->index] == NULL) {
reg = ntb.ssa_values[alu->src[s].src.ssa->index];
assert(reg.is_scalar);
srcs[s] = retype(offset(reg, ubld, alu->src[s].swizzle[0]),
brw_int_type(brw_type_size_bytes(reg.type), false));
} else {
reg = ntb.resource_insts[alu->src[s].src.ssa->index]->dst;
srcs[s] = offset(reg, ubld, alu->src[s].swizzle[0]);
}
assert(srcs[s].file != BAD_FILE);
}
const enum brw_reg_type utype0 =
brw_type_with_size(BRW_TYPE_UD,
brw_type_size_bits(srcs[0].type));
const enum brw_reg_type utype1 =
nir_op_infos[alu->op].num_inputs > 1 ?
brw_type_with_size(BRW_TYPE_UD,
brw_type_size_bits(srcs[1].type)) :
BRW_TYPE_UD;
switch (alu->op) {
case nir_op_iadd:
ubld.ADD(srcs[0].file != IMM ? srcs[0] : srcs[1],
srcs[0].file != IMM ? srcs[1] : srcs[0],
&ntb.resource_insts[def->index]);
if (srcs[0].file != IMM) {
ubld.ADD(retype(srcs[0], utype0),
retype(srcs[1], utype1),
&ntb.resource_insts[def->index]);
} else {
ubld.ADD(retype(srcs[1], utype1),
retype(srcs[0], utype0),
&ntb.resource_insts[def->index]);
}
break;
case nir_op_iadd3: {
brw_reg dst = ubld.vgrf(srcs[0].type);
@ -4836,34 +4875,31 @@ try_rebuild_source(nir_to_brw_state &ntb, const brw::fs_builder &bld,
srcs[2]);
break;
}
case nir_op_ushr: {
enum brw_reg_type utype =
brw_type_with_size(BRW_TYPE_UD,
brw_type_size_bits(srcs[0].type));
ubld.SHR(retype(srcs[0], utype),
retype(srcs[1], utype),
case nir_op_ushr:
ubld.SHR(retype(srcs[0], utype0),
retype(srcs[1], utype1),
&ntb.resource_insts[def->index]);
break;
}
case nir_op_iand:
ubld.AND(srcs[0], srcs[1], &ntb.resource_insts[def->index]);
ubld.AND(retype(srcs[0], utype0),
retype(srcs[1], utype1),
&ntb.resource_insts[def->index]);
break;
case nir_op_ishl:
ubld.SHL(srcs[0], srcs[1], &ntb.resource_insts[def->index]);
ubld.SHL(retype(srcs[0], utype0),
retype(srcs[1], utype1),
&ntb.resource_insts[def->index]);
break;
case nir_op_mov:
break;
case nir_op_ult32: {
if (brw_type_size_bits(srcs[0].type) != 32)
break;
enum brw_reg_type utype =
brw_type_with_size(BRW_TYPE_UD,
brw_type_size_bits(srcs[0].type));
brw_reg dst = ubld.vgrf(utype);
brw_reg dst = ubld.vgrf(utype0);
ntb.resource_insts[def->index] =
ubld.CMP(dst,
retype(srcs[0], utype),
retype(srcs[1], utype),
retype(srcs[0], utype0),
retype(srcs[1], utype1),
brw_cmod_for_nir_comparison(alu->op));
break;
}