From 8fe5aa95cb5572a28c4794e45164abd0fa29706c Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 23 Nov 2022 20:41:29 +0000 Subject: [PATCH] ac/nir: mask shift operands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NIR shifts are defined to truncate the shift amount to the number of bits needed to represent the bit-size of the value shifted. LLVM treats large shifts as poison. This fix achieves NIR semantics for shifts. As an example, a|(b << 32), where "a" is 32bits, should produce a|b according to NIR (because 32&31 == 0). This caused LLVM to incorrectly optimize "(a >> c) | (b << (32 - c))" to a u2u32(pack_64_2x32(a, b) >> c) (v_alignbit_b32), when the original NIR should have returned "a | b" if c==0. Signed-off-by: Rhys Perry Reviewed-by: Mihai Preda Reviewed-by: Qiang Yu Reviewed-by: Marek Olšák Cc: mesa-stable Part-of: (cherry picked from commit 064336d35977abd0d5b6ed37784c6cc42cf4f66f) --- .pick_status.json | 2 +- src/amd/llvm/ac_nir_to_llvm.c | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.pick_status.json b/.pick_status.json index f52197291ad..2b751c0c97a 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -796,7 +796,7 @@ "description": "ac/nir: mask shift operands", "nominated": true, "nomination_type": 0, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null }, diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c index 9d3d60603a6..b34883d66cc 100644 --- a/src/amd/llvm/ac_nir_to_llvm.c +++ b/src/amd/llvm/ac_nir_to_llvm.c @@ -732,33 +732,45 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr) case nir_op_ixor: result = LLVMBuildXor(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ishl: + case nir_op_ishl: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildShl(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ishr: + } + case nir_op_ishr: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildAShr(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ushr: + } + case nir_op_ushr: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], ""); break; + } case nir_op_ilt: result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]); break;