anv: Simplify code that calls brw/jay

Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Reviewed-by: Iván Briano <ivan.briano@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41633>
This commit is contained in:
Caio Oliveira 2026-05-17 14:25:24 -07:00
parent 33475c0cce
commit ffa4bc7d6a

View file

@ -928,167 +928,20 @@ anv_fixup_subgroup_size(struct anv_device *device, nir_shader *shader)
}
static void
anv_shader_compile_vs(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
char **error_str)
populate_compile_params_tes(union brw_any_compile_params *params,
struct anv_shader_data *shader_data,
struct anv_shader_data *prev_shader_data)
{
const struct brw_compiler *compiler = device->physical->compiler;
const struct intel_device_info *devinfo = compiler->devinfo;
nir_shader *nir = shader_data->info->nir;
shader_data->num_stats = 1;
struct brw_compile_vs_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.vs.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.vs,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
};
if (intel_use_jay(devinfo, nir->info.stage)) {
struct jay_shader_bin *bin =
jay_compile(devinfo, mem_ctx, nir,
(union brw_any_prog_data *) params.base.prog_data,
(union brw_any_prog_key *) params.base.key);
shader_data->code = (void *) bin->kernel;
} else {
shader_data->code = (void *) brw_compile(compiler, &params.base);
if (prev_shader_data) {
shader_data->key.tes.inputs_read =
prev_shader_data->info->nir->info.outputs_written;
shader_data->key.tes.patch_inputs_read =
prev_shader_data->info->nir->info.patch_outputs_written;
}
*error_str = params.base.error_str;
}
static void
anv_shader_compile_tcs(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
char **error_str)
{
const struct brw_compiler *compiler = device->physical->compiler;
nir_shader *nir = shader_data->info->nir;
shader_data->key.tcs.outputs_written = nir->info.outputs_written;
shader_data->key.tcs.patch_outputs_written = nir->info.patch_outputs_written;
shader_data->num_stats = 1;
struct brw_compile_tcs_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.tcs.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.tcs,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
};
shader_data->code = (void *)brw_compile(compiler, &params.base);
*error_str = params.base.error_str;
}
static void
anv_shader_compile_tes(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *tes_shader_data,
struct anv_shader_data *tcs_shader_data,
char **error_str)
{
const struct brw_compiler *compiler = device->physical->compiler;
nir_shader *nir = tes_shader_data->info->nir;
if (tcs_shader_data) {
tes_shader_data->key.tes.inputs_read =
tcs_shader_data->info->nir->info.outputs_written;
tes_shader_data->key.tes.patch_inputs_read =
tcs_shader_data->info->nir->info.patch_outputs_written;
}
tes_shader_data->num_stats = 1;
struct brw_compile_tes_params params = {
.base = {
.nir = nir,
.key = &tes_shader_data->key.tes.base,
.prog_data = (struct brw_stage_prog_data *)&tes_shader_data->prog_data.tes,
.stats = tes_shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = tes_shader_data->source_hash,
.archiver = tes_shader_data->archiver,
},
.input_vue_map = tcs_shader_data ?
&tcs_shader_data->prog_data.tcs.base.vue_map : NULL,
};
tes_shader_data->code = (void *)brw_compile(compiler, &params.base);
*error_str = params.base.error_str;
}
static void
anv_shader_compile_gs(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
char **error_str)
{
const struct brw_compiler *compiler = device->physical->compiler;
nir_shader *nir = shader_data->info->nir;
shader_data->num_stats = 1;
struct brw_compile_gs_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.gs.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.gs,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
};
shader_data->code = (void *)brw_compile(compiler, &params.base);
*error_str = params.base.error_str;
}
static void
anv_shader_compile_task(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
char **error_str)
{
const struct brw_compiler *compiler = device->physical->compiler;
nir_shader *nir = shader_data->info->nir;
shader_data->num_stats = 1;
struct brw_compile_task_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.task.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.task,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
};
shader_data->code = (void *)brw_compile(compiler, &params.base);
*error_str = params.base.error_str;
params->tes.input_vue_map = prev_shader_data ?
&prev_shader_data->prog_data.tcs.base.vue_map :
NULL;
}
static nir_def *
@ -1131,49 +984,22 @@ wa_18019110168_load_per_primitive_remap_table(nir_builder *b, void *data)
}
static void
anv_shader_compile_mesh(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *mesh_shader_data,
struct anv_shader_data *task_shader_data,
char **error_str)
populate_compile_params_mesh(union brw_any_compile_params *params,
struct anv_shader_data *shader_data,
struct anv_shader_data *prev_shader_data)
{
const struct brw_compiler *compiler = device->physical->compiler;
nir_shader *nir = mesh_shader_data->info->nir;
mesh_shader_data->num_stats = 1;
struct brw_compile_mesh_params params = {
.base = {
.nir = nir,
.key = &mesh_shader_data->key.mesh.base,
.prog_data = (struct brw_stage_prog_data *)&mesh_shader_data->prog_data.mesh,
.stats = mesh_shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = mesh_shader_data->source_hash,
.archiver = mesh_shader_data->archiver,
},
.tue_map = task_shader_data ?
&task_shader_data->prog_data.task.map :
NULL,
.wa_18019110168_load_provoking_vertex =
wa_18019110168_load_provoking_vertex,
.wa_18019110168_data = (void *)&mesh_shader_data->bind_map,
};
mesh_shader_data->code = (void *)brw_compile(compiler, &params.base);
*error_str = params.base.error_str;
params->mesh.tue_map = prev_shader_data ?
&prev_shader_data->prog_data.task.map : NULL;
params->mesh.wa_18019110168_load_provoking_vertex =
wa_18019110168_load_provoking_vertex;
params->mesh.wa_18019110168_data = (void *)&shader_data->bind_map;
}
static void
anv_shader_compile_fs(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
const struct vk_graphics_pipeline_state *state,
char **error_str)
populate_compile_params_fs(union brw_any_compile_params *params,
const struct intel_device_info *devinfo,
struct anv_shader_data *shader_data)
{
const struct brw_compiler *compiler = device->physical->compiler;
const struct intel_device_info *devinfo = compiler->devinfo;
nir_shader *nir = shader_data->info->nir;
/* When using Primitive Replication for multiview, each view gets its own
@ -1182,111 +1008,22 @@ anv_shader_compile_fs(struct anv_device *device,
uint32_t pos_slots = shader_data->use_primitive_replication ?
MAX2(1, util_bitcount(shader_data->key.base.view_mask)) : 1;
/* TODO: Should we find a way to pass this to brw_compile? */
struct intel_vue_map prev_vue_map;
brw_compute_vue_map(compiler->devinfo,
brw_compute_vue_map(devinfo,
&prev_vue_map,
nir->info.inputs_read,
nir->info.separate_shader,
pos_slots);
struct brw_compile_fs_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.fs.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.fs,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
params->fs.mue_map = shader_data->mue_map;
.mue_map = shader_data->mue_map,
params->fs.allow_spilling = true;
params->fs.max_polygons = UCHAR_MAX;
.allow_spilling = true,
.max_polygons = UCHAR_MAX,
.wa_18019110168_load_per_primitive_remap_table_offset =
wa_18019110168_load_per_primitive_remap_table,
.wa_18019110168_data = (void *)&shader_data->bind_map,
};
if (intel_use_jay(devinfo, nir->info.stage)) {
struct jay_shader_bin *bin =
jay_compile(devinfo, mem_ctx, nir,
(union brw_any_prog_data *) params.base.prog_data,
(union brw_any_prog_key *) params.base.key);
shader_data->code = (void *) bin->kernel;
} else {
shader_data->code = (void *) brw_compile(compiler, &params.base);
}
*error_str = params.base.error_str;
shader_data->num_stats = (uint32_t)!!shader_data->prog_data.fs.dispatch_multi +
(uint32_t)shader_data->prog_data.fs.dispatch_8 +
(uint32_t)shader_data->prog_data.fs.dispatch_16 +
(uint32_t)shader_data->prog_data.fs.dispatch_32;
assert(shader_data->num_stats <= ARRAY_SIZE(shader_data->stats));
/* Update the push constant padding range now that we know the amount of
* per-primitive data delivered in the payload.
*/
for (unsigned i = 0; i < ARRAY_SIZE(shader_data->bind_map.push_ranges); i++) {
if (shader_data->bind_map.push_ranges[i].set == ANV_DESCRIPTOR_SET_PER_PRIM_PADDING) {
shader_data->bind_map.push_ranges[i].length = MAX2(
shader_data->prog_data.fs.num_per_primitive_inputs / 2,
shader_data->bind_map.push_ranges[i].length);
break;
}
}
}
static void
anv_shader_compile_cs(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
char **error_str)
{
const struct brw_compiler *compiler = device->physical->compiler;
const struct intel_device_info *devinfo = compiler->devinfo;
nir_shader *nir = shader_data->info->nir;
shader_data->num_stats = 1;
struct brw_compile_cs_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.cs.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.cs,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
};
if (intel_use_jay(devinfo, nir->info.stage)) {
struct jay_shader_bin *bin = jay_compile(devinfo, mem_ctx, nir,
(union brw_any_prog_data*)params.base.prog_data,
(union brw_any_prog_key*)params.base.key);
struct brw_cs_prog_data *prog_data =
(struct brw_cs_prog_data *)params.base.prog_data;
shader_data->code = (void*)bin->kernel;
shader_data->stats[0] = bin->stats;
prog_data->local_size[0] = nir->info.workgroup_size[0];
prog_data->local_size[1] = nir->info.workgroup_size[1];
prog_data->local_size[2] = nir->info.workgroup_size[2];
} else {
shader_data->code = (void*)brw_compile(compiler, &params.base);
}
*error_str = params.base.error_str;
params->fs.wa_18019110168_load_per_primitive_remap_table_offset =
wa_18019110168_load_per_primitive_remap_table;
params->fs.wa_18019110168_data = (void *)&shader_data->bind_map;
}
static bool
@ -1299,14 +1036,12 @@ should_remat_cb(nir_instr *instr, void *data)
}
static void
anv_shader_compile_bs(struct anv_device *device,
void *mem_ctx,
struct anv_shader_data *shader_data,
char **error_str)
populate_compile_params_bs(union brw_any_compile_params *params,
const struct intel_device_info *devinfo,
void *mem_ctx,
struct anv_shader_data *shader_data)
{
const struct brw_compiler *compiler = device->physical->compiler;
nir_shader *nir = shader_data->info->nir;
const struct intel_device_info *devinfo = compiler->devinfo;
struct brw_nir_lower_shader_calls_state lowering_state = {
.devinfo = devinfo,
@ -1343,25 +1078,8 @@ anv_shader_compile_bs(struct anv_device *device,
&shader_data->key.base, devinfo);
}
shader_data->num_stats = 1;
struct brw_compile_bs_params params = {
.base = {
.nir = nir,
.key = &shader_data->key.bs.base,
.prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.bs,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
},
.num_resume_shaders = num_resume_shaders,
.resume_shaders = resume_shaders,
};
shader_data->code = (void *)brw_compile(compiler, &params.base);
*error_str = params.base.error_str;
params->bs.num_resume_shaders = num_resume_shaders;
params->bs.resume_shaders = resume_shaders;
}
static void
@ -2290,35 +2008,47 @@ anv_shader_compile(struct vk_device *vk_device,
struct anv_shader_data *prev_shader_data =
s > 0 ? &shaders_data[s - 1] : NULL;
char *error_str = NULL;
const struct brw_compiler *compiler = device->physical->compiler;
const struct intel_device_info *devinfo = compiler->devinfo;
nir_shader *nir = shader_data->info->nir;
shader_data->num_stats = 1;
union brw_any_compile_params params = { 0 };
params.base = (struct brw_compile_params) {
.nir = nir,
.key = &shader_data->key.base,
.prog_data = &shader_data->prog_data.base,
.stats = shader_data->stats,
.log_data = device,
.mem_ctx = mem_ctx,
.source_hash = shader_data->source_hash,
.archiver = shader_data->archiver,
};
struct brw_compile_params *compile_params = &params.base;
switch (shader_data->info->stage) {
case MESA_SHADER_VERTEX:
anv_shader_compile_vs(device, mem_ctx, shader_data, &error_str);
case MESA_SHADER_GEOMETRY:
case MESA_SHADER_TASK:
case MESA_SHADER_COMPUTE:
/* Nothing to do. */
break;
case MESA_SHADER_TESS_CTRL:
anv_shader_compile_tcs(device, mem_ctx, shader_data, &error_str);
shader_data->key.tcs.outputs_written = nir->info.outputs_written;
shader_data->key.tcs.patch_outputs_written =
nir->info.patch_outputs_written;
break;
case MESA_SHADER_TESS_EVAL:
anv_shader_compile_tes(device, mem_ctx,
&shaders_data[s], prev_shader_data,
&error_str);
break;
case MESA_SHADER_GEOMETRY:
anv_shader_compile_gs(device, mem_ctx, shader_data, &error_str);
break;
case MESA_SHADER_TASK:
anv_shader_compile_task(device, mem_ctx, shader_data, &error_str);
populate_compile_params_tes(&params, shader_data,
prev_shader_data);
break;
case MESA_SHADER_MESH:
anv_shader_compile_mesh(device, mem_ctx,
&shaders_data[s], prev_shader_data,
&error_str);
populate_compile_params_mesh(&params, shader_data,
prev_shader_data);
break;
case MESA_SHADER_FRAGMENT:
anv_shader_compile_fs(device, mem_ctx, shader_data, state, &error_str);
break;
case MESA_SHADER_COMPUTE:
anv_shader_compile_cs(device, mem_ctx, shader_data, &error_str);
populate_compile_params_fs(&params, devinfo, shader_data);
break;
case MESA_SHADER_RAYGEN:
case MESA_SHADER_ANY_HIT:
@ -2326,15 +2056,57 @@ anv_shader_compile(struct vk_device *vk_device,
case MESA_SHADER_MISS:
case MESA_SHADER_INTERSECTION:
case MESA_SHADER_CALLABLE:
anv_shader_compile_bs(device, mem_ctx, shader_data, &error_str);
populate_compile_params_bs(&params, devinfo, mem_ctx,
shader_data);
break;
default:
UNREACHABLE("Invalid graphics shader stage");
}
if (intel_use_jay(devinfo, nir->info.stage)) {
struct jay_shader_bin *bin =
jay_compile(devinfo, mem_ctx, nir,
(union brw_any_prog_data *)compile_params->prog_data,
(union brw_any_prog_key *)compile_params->key);
shader_data->code = bin->kernel;
if (mesa_shader_stage_uses_workgroup(nir->info.stage)) {
struct brw_cs_prog_data *prog_data =
(struct brw_cs_prog_data *)compile_params->prog_data;
shader_data->stats[0] = bin->stats;
prog_data->local_size[0] = nir->info.workgroup_size[0];
prog_data->local_size[1] = nir->info.workgroup_size[1];
prog_data->local_size[2] = nir->info.workgroup_size[2];
}
} else {
shader_data->code = brw_compile(compiler, compile_params);
}
if (shader_data->info->stage == MESA_SHADER_FRAGMENT) {
shader_data->num_stats =
(uint32_t)!!shader_data->prog_data.fs.dispatch_multi +
(uint32_t)shader_data->prog_data.fs.dispatch_8 +
(uint32_t)shader_data->prog_data.fs.dispatch_16 +
(uint32_t)shader_data->prog_data.fs.dispatch_32;
assert(shader_data->num_stats <= ARRAY_SIZE(shader_data->stats));
/* Update the push constant padding range now that we know the amount
* of per-primitive data delivered in the payload.
*/
for (unsigned i = 0; i < ARRAY_SIZE(shader_data->bind_map.push_ranges); i++) {
if (shader_data->bind_map.push_ranges[i].set ==
ANV_DESCRIPTOR_SET_PER_PRIM_PADDING) {
shader_data->bind_map.push_ranges[i].length = MAX2(
shader_data->prog_data.fs.num_per_primitive_inputs / 2,
shader_data->bind_map.push_ranges[i].length);
break;
}
}
}
if (shader_data->code == NULL) {
if (error_str)
result = vk_errorf(device, VK_ERROR_UNKNOWN, "%s", error_str);
if (compile_params->error_str)
result = vk_errorf(device, VK_ERROR_UNKNOWN, "%s", compile_params->error_str);
else
result = vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
goto end;