ac/nir: mask shift operands

NIR shifts are defined to truncate the shift amount to the number of bits
needed to represent the bit-size of the value shifted. LLVM treats large
shifts as poison. This fix achieves NIR semantics for shifts.

As an example, a|(b << 32), where "a" is 32bits, should produce a|b
according to NIR (because 32&31 == 0).

This caused LLVM to incorrectly optimize "(a >> c) | (b << (32 - c))" to a
u2u32(pack_64_2x32(a, b) >> c) (v_alignbit_b32), when the original NIR
should have returned "a | b" if c==0.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Mihai Preda <mhpreda@gmail.com>
Reviewed-by: Qiang Yu <yuq825@gmail.com>
Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Cc: mesa-stable
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19966>
(cherry picked from commit 064336d359)
This commit is contained in:
Rhys Perry 2022-11-23 20:41:29 +00:00 committed by Dylan Baker
parent 19e4daa0d1
commit 8fe5aa95cb
2 changed files with 16 additions and 4 deletions

View file

@ -796,7 +796,7 @@
"description": "ac/nir: mask shift operands",
"nominated": true,
"nomination_type": 0,
"resolution": 0,
"resolution": 1,
"main_sha": null,
"because_sha": null
},

View file

@ -732,33 +732,45 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
case nir_op_ixor:
result = LLVMBuildXor(ctx->ac.builder, src[0], src[1], "");
break;
case nir_op_ishl:
case nir_op_ishl: {
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) <
ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) >
ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
LLVMTypeRef type = LLVMTypeOf(src[0]);
src[1] = LLVMBuildAnd(ctx->ac.builder, src[1],
LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), "");
result = LLVMBuildShl(ctx->ac.builder, src[0], src[1], "");
break;
case nir_op_ishr:
}
case nir_op_ishr: {
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) <
ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) >
ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
LLVMTypeRef type = LLVMTypeOf(src[0]);
src[1] = LLVMBuildAnd(ctx->ac.builder, src[1],
LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), "");
result = LLVMBuildAShr(ctx->ac.builder, src[0], src[1], "");
break;
case nir_op_ushr:
}
case nir_op_ushr: {
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) <
ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) >
ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
LLVMTypeRef type = LLVMTypeOf(src[0]);
src[1] = LLVMBuildAnd(ctx->ac.builder, src[1],
LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), "");
result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], "");
break;
}
case nir_op_ilt:
result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]);
break;