agx: Mask shifts in the backend

This gives our shifts SM5 behaviour at the cost of a little extra ALU. That way,
we match NIR's shifts.

This fixes unsoundness of GLSL expressions like "a << (b & 31)", where the &
would mistakenly get optimized away.

Closes: #8181
Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reported-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21673>
This commit is contained in:
Alyssa Rosenzweig 2023-01-29 13:29:32 -05:00 committed by Marge Bot
parent f4e2b22646
commit 3032e3ad23
4 changed files with 69 additions and 1 deletions

View file

@ -31,6 +31,7 @@
#include "agx_builder.h"
#include "agx_compiler.h"
#include "agx_internal_formats.h"
#include "agx_nir.h"
/* Alignment for shader programs. I'm not sure what the optimal value is. */
#define AGX_CODE_ALIGN 0x100
@ -1918,6 +1919,7 @@ agx_optimize_nir(nir_shader *nir, unsigned *preamble_size)
NIR_PASS_V(nir, nir_opt_peephole_select, 64, false, true);
NIR_PASS_V(nir, nir_opt_algebraic_late);
NIR_PASS_V(nir, agx_nir_lower_algebraic_late);
NIR_PASS_V(nir, nir_opt_constant_folding);
/* Must run after uses are fixed but before a last round of copyprop + DCE */

View file

@ -0,0 +1,14 @@
/*
* Copyright 2023 Alyssa Rosenzweig
* SPDX-License-Identifier: MIT
*/
#ifndef AGX_NIR_H
#define AGX_NIR_H
#include <stdbool.h>
struct nir_shader;
bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
#endif

View file

@ -0,0 +1,41 @@
# Copyright 2022 Alyssa Rosenzweig
# Copyright 2021 Collabora, Ltd.
# Copyright 2016 Intel Corporation
# SPDX-License-Identifier: MIT
import argparse
import sys
import math
a = 'a'
b = 'b'
c = 'c'
lower_sm5_shift = []
# Our shifts differ from SM5 for the upper bits. Mask to match the NIR
# behaviour. Because this happens as a late lowering, NIR won't optimize the
# masking back out (that happens in the main nir_opt_algebraic).
for s in [8, 16, 32, 64]:
for shift in ["ishl", "ishr", "ushr"]:
lower_sm5_shift += [((shift, f'a@{s}', b),
(shift, a, ('iand', b, s - 1)))]
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--import-path', required=True)
args = parser.parse_args()
sys.path.insert(0, args.import_path)
run()
def run():
import nir_algebraic # pylint: disable=import-error
print('#include "agx_nir.h"')
print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
lower_sm5_shift).render())
if __name__ == '__main__':
main()

View file

@ -43,6 +43,17 @@ libasahi_agx_files = files(
'agx_validate.c',
)
agx_nir_algebraic_c = custom_target(
'agx_nir_algebraic.c',
input : 'agx_nir_algebraic.py',
output : 'agx_nir_algebraic.c',
command : [
prog_python, '@INPUT@', '-p', dir_compiler_nir,
],
capture : true,
depend_files : nir_algebraic_depends,
)
agx_opcodes_h = custom_target(
'agx_opcodes.h',
input : ['agx_opcodes.h.py'],
@ -82,7 +93,7 @@ idep_agx_builder_h = declare_dependency(
libasahi_compiler = static_library(
'asahi_compiler',
[libasahi_agx_files, agx_opcodes_c],
[libasahi_agx_files, agx_opcodes_c, agx_nir_algebraic_c],
include_directories : [inc_include, inc_src, inc_mesa, inc_gallium, inc_gallium_aux, inc_mapi],
dependencies: [idep_nir, idep_agx_opcodes_h, idep_agx_builder_h, idep_agx_pack],
c_args : [no_override_init_args],