nir: Add a shader bisect tool.

When you're trying to figure out what shader some NIR pass broke, use
nir_shader_bisect_select() to decide between NIR pass behaviors, and then
nir_shader_bisect.py will help you automatically bisect down to which
source_blake3 is at fault.  Once it's identified, it prints you a C call
you can use for selecting that shader specifically, which you can use for
continuing on in your debugging.

On a test I was looking at, this took 10 steps to bisect 134 shaders down
to the source_blake3 of the NIR shader in question.

This idea is heavily lifted from Job Noorman's ir3_shader_bisect.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37468>
This commit is contained in:
Emma Anholt 2025-09-18 09:23:36 -07:00 committed by Marge Bot
parent 6d2b2963a2
commit d01aae2fb1
4 changed files with 177 additions and 0 deletions

View file

@ -322,6 +322,7 @@ else
'nir_search_helpers.h',
'nir_serialize.c',
'nir_serialize.h',
'nir_shader_bisect.c',
'nir_shader_compiler_options.h',
'nir_split_64bit_vec3_and_vec4.c',
'nir_split_conversions.c',

View file

@ -3977,6 +3977,8 @@ nir_shader *nir_shader_create(void *mem_ctx,
mesa_shader_stage stage,
const nir_shader_compiler_options *options);
bool nir_shader_bisect_select(nir_shader *s);
/** Adds a variable to the appropriate list in nir_shader */
void nir_shader_add_variable(nir_shader *shader, nir_variable *var);

View file

@ -0,0 +1,73 @@
#include "nir.h"
#include "util/hex.h"
#include "util/log.h"
#include "util/mesa-blake3.h"
#include "util/u_call_once.h"
/** @file
* Shader bisect support.
*
* Simply use nir_shader_bisect_select() to control some bad behavior you've
* identified (calling a shader pass or executing some bad part of a
* shader_pass), then run your application under nir_shader_bisect.py to be
* interactively guided through bisecting down to which NIR shader in your
* program is being badly affected by the code in question.
*
* Note that doing this requires (unless someone rigs up cache key handling)
* MESA_SHADER_DISABLE_CACHE=1, which is also set by nir_shader_bisect.py.
*/
/* These are expected to be hex dumps of blake3 bytes, no space and no leading
* '0x'
*/
static const char *nir_shader_bisect_lo = NULL;
static const char *nir_shader_bisect_hi = NULL;
DEBUG_GET_ONCE_OPTION(nir_shader_bisect_lo, "NIR_SHADER_BISECT_LO", NULL);
DEBUG_GET_ONCE_OPTION(nir_shader_bisect_hi, "NIR_SHADER_BISECT_HI", NULL);
static void
nir_shader_bisect_init(void)
{
nir_shader_bisect_lo = debug_get_option_nir_shader_bisect_lo();
nir_shader_bisect_hi = debug_get_option_nir_shader_bisect_hi();
if (nir_shader_bisect_lo)
assert(strlen(nir_shader_bisect_lo) + 1 == BLAKE3_HEX_LEN);
if (nir_shader_bisect_hi)
assert(strlen(nir_shader_bisect_hi) + 1 == BLAKE3_HEX_LEN);
}
bool
nir_shader_bisect_select(nir_shader *s)
{
static once_flag once = ONCE_FLAG_INIT;
call_once(&once, nir_shader_bisect_init);
if (!nir_shader_bisect_lo && !nir_shader_bisect_hi)
return false;
char id[BLAKE3_HEX_LEN];
_mesa_blake3_format(id, s->info.source_blake3);
if (nir_shader_bisect_lo && strcmp(id, nir_shader_bisect_lo) < 0)
return false;
if (nir_shader_bisect_hi && strcmp(id, nir_shader_bisect_hi) > 0)
return false;
uint32_t u32[BLAKE3_OUT_LEN32] = { 0 };
for (unsigned i = 0; i < BLAKE3_OUT_LEN; i++)
u32[i / 4] |= (uint32_t)s->info.source_blake3[i] << ((i % 4) * 8);
/* Provide feedback of both the source_blake3 and the blake3_format id to the
* script of what shaders got affected, so it can bisect on the set of
* shaders remaining for the env vars, and print out a final blake3 when we
* get down to 1 shader.
*/
mesa_logi("NIR bisect selected source_blake3: {0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x} (%s)\n",
u32[0], u32[1], u32[2], u32[3],
u32[4], u32[5], u32[6], u32[7], id);
return true;
}

View file

@ -0,0 +1,101 @@
#!/usr/bin/env python3
import argparse
import subprocess
import os
import re
def run(args, lo, hi):
print(f'NIR_SHADER_BISECT_LO: {lo}')
print(f'NIR_SHADER_BISECT_HI: {hi}')
env = os.environ.copy()
env['MESA_SHADER_CACHE_DISABLE'] = '1'
env['NIR_SHADER_BISECT_LO'] = lo
env['NIR_SHADER_BISECT_HI'] = hi
env.pop('MESA_LOG_FILE', None)
cmd = [args.cmd] + args.cmd_args
if args.debug:
print(f"running: {cmd}")
result = subprocess.run(cmd, env=env,
text=True, capture_output=True)
if args.debug:
print(f"Result: {result.returncode}")
print("stdout:")
print(result.stdout)
print("stderr:")
print(result.stderr)
shaders = set(re.findall(
"NIR bisect selected source_blake3: (.*) \((.*)\)", result.stderr))
num = len(shaders)
print(f'Shaders matched: {num}')
if num <= 5:
for blake3, id in sorted(shaders):
print(f' {blake3}')
return shaders
def was_good():
while True:
response = input('Was the previous run [g]ood or [b]ad? ')
if response in ('g', 'b'):
return response == 'g'
def bisect(args):
lo = 0
hi = (1 << (8 * 32)) - 1
lo = f'{lo:064x}'
hi = f'{hi:064x}'
# Do an initial run to sanity check that the user has actually made
# nir_shader_bisect_select() select some broken behavior.
bad = run(args, lo, hi)
if was_good():
print("Entire hash range produced a good result -- did you make nir_shader_bisect_select() select the behavior you want?")
exit(1)
if len(bad) == 0:
print("No bad shaders detected -- did you rebuild after adding nir_shader_bisect_select()?")
exit(1)
while True:
if len(bad) == 1:
for shader in bad:
print(f"Bisected to source_blake3 {shader}")
print(
f"You can now replace nir_shader_bisect_select() with _mesa_printed_blake3_equal(s->info.source_blake3, (uint32_t[]){shader})")
exit(0)
else:
num = len(bad)
print(f"Shaders remaining to bisect: {num}")
if num <= 5:
for shader in sorted(bad):
print(f' {shader}')
# Find the middle shader remaining in the set of shaders from the last
# bad run, and check if the bottom half of set of possibly-bad shaders
# (up to and including it) is bad.
ids = sorted([id for blake3, id in bad])
lo = ids[0]
split = ids[(len(ids) - 1) // 2]
cur = run(args, lo, split)
if was_good():
bad = bad.difference(cur)
else:
bad = cur
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--debug', action='store_true')
parser.add_argument('cmd')
parser.add_argument('cmd_args', nargs=argparse.REMAINDER)
args = parser.parse_args()
bisect(args)