From ffa4bc7d6a73bec4b754a5d9b5a86606514ea79f Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Sun, 17 May 2026 14:25:24 -0700 Subject: [PATCH] anv: Simplify code that calls brw/jay MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Acked-by: Alyssa Rosenzweig Reviewed-by: Iván Briano Part-of: --- src/intel/vulkan/anv_shader_compile.c | 452 +++++++------------------- 1 file changed, 112 insertions(+), 340 deletions(-) diff --git a/src/intel/vulkan/anv_shader_compile.c b/src/intel/vulkan/anv_shader_compile.c index c3264fae901..fd2781cc042 100644 --- a/src/intel/vulkan/anv_shader_compile.c +++ b/src/intel/vulkan/anv_shader_compile.c @@ -928,167 +928,20 @@ anv_fixup_subgroup_size(struct anv_device *device, nir_shader *shader) } static void -anv_shader_compile_vs(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - char **error_str) +populate_compile_params_tes(union brw_any_compile_params *params, + struct anv_shader_data *shader_data, + struct anv_shader_data *prev_shader_data) { - const struct brw_compiler *compiler = device->physical->compiler; - const struct intel_device_info *devinfo = compiler->devinfo; - nir_shader *nir = shader_data->info->nir; - - shader_data->num_stats = 1; - - struct brw_compile_vs_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.vs.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.vs, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, - }; - - if (intel_use_jay(devinfo, nir->info.stage)) { - struct jay_shader_bin *bin = - jay_compile(devinfo, mem_ctx, nir, - (union brw_any_prog_data *) params.base.prog_data, - (union brw_any_prog_key *) params.base.key); - - shader_data->code = (void *) bin->kernel; - } else { - shader_data->code = (void *) brw_compile(compiler, ¶ms.base); + if (prev_shader_data) { + shader_data->key.tes.inputs_read = + prev_shader_data->info->nir->info.outputs_written; + shader_data->key.tes.patch_inputs_read = + prev_shader_data->info->nir->info.patch_outputs_written; } - *error_str = params.base.error_str; -} - -static void -anv_shader_compile_tcs(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - char **error_str) -{ - const struct brw_compiler *compiler = device->physical->compiler; - nir_shader *nir = shader_data->info->nir; - - shader_data->key.tcs.outputs_written = nir->info.outputs_written; - shader_data->key.tcs.patch_outputs_written = nir->info.patch_outputs_written; - - shader_data->num_stats = 1; - - struct brw_compile_tcs_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.tcs.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.tcs, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, - }; - - shader_data->code = (void *)brw_compile(compiler, ¶ms.base); - *error_str = params.base.error_str; -} - -static void -anv_shader_compile_tes(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *tes_shader_data, - struct anv_shader_data *tcs_shader_data, - char **error_str) -{ - const struct brw_compiler *compiler = device->physical->compiler; - nir_shader *nir = tes_shader_data->info->nir; - - if (tcs_shader_data) { - tes_shader_data->key.tes.inputs_read = - tcs_shader_data->info->nir->info.outputs_written; - tes_shader_data->key.tes.patch_inputs_read = - tcs_shader_data->info->nir->info.patch_outputs_written; - } - - tes_shader_data->num_stats = 1; - - struct brw_compile_tes_params params = { - .base = { - .nir = nir, - .key = &tes_shader_data->key.tes.base, - .prog_data = (struct brw_stage_prog_data *)&tes_shader_data->prog_data.tes, - .stats = tes_shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = tes_shader_data->source_hash, - .archiver = tes_shader_data->archiver, - }, - .input_vue_map = tcs_shader_data ? - &tcs_shader_data->prog_data.tcs.base.vue_map : NULL, - }; - - tes_shader_data->code = (void *)brw_compile(compiler, ¶ms.base); - *error_str = params.base.error_str; -} - -static void -anv_shader_compile_gs(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - char **error_str) -{ - const struct brw_compiler *compiler = device->physical->compiler; - nir_shader *nir = shader_data->info->nir; - - shader_data->num_stats = 1; - - struct brw_compile_gs_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.gs.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.gs, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, - }; - - shader_data->code = (void *)brw_compile(compiler, ¶ms.base); - *error_str = params.base.error_str; -} - -static void -anv_shader_compile_task(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - char **error_str) -{ - const struct brw_compiler *compiler = device->physical->compiler; - nir_shader *nir = shader_data->info->nir; - - shader_data->num_stats = 1; - - struct brw_compile_task_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.task.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.task, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, - }; - - shader_data->code = (void *)brw_compile(compiler, ¶ms.base); - *error_str = params.base.error_str; + params->tes.input_vue_map = prev_shader_data ? + &prev_shader_data->prog_data.tcs.base.vue_map : + NULL; } static nir_def * @@ -1131,49 +984,22 @@ wa_18019110168_load_per_primitive_remap_table(nir_builder *b, void *data) } static void -anv_shader_compile_mesh(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *mesh_shader_data, - struct anv_shader_data *task_shader_data, - char **error_str) +populate_compile_params_mesh(union brw_any_compile_params *params, + struct anv_shader_data *shader_data, + struct anv_shader_data *prev_shader_data) { - const struct brw_compiler *compiler = device->physical->compiler; - nir_shader *nir = mesh_shader_data->info->nir; - - mesh_shader_data->num_stats = 1; - - struct brw_compile_mesh_params params = { - .base = { - .nir = nir, - .key = &mesh_shader_data->key.mesh.base, - .prog_data = (struct brw_stage_prog_data *)&mesh_shader_data->prog_data.mesh, - .stats = mesh_shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = mesh_shader_data->source_hash, - .archiver = mesh_shader_data->archiver, - }, - .tue_map = task_shader_data ? - &task_shader_data->prog_data.task.map : - NULL, - .wa_18019110168_load_provoking_vertex = - wa_18019110168_load_provoking_vertex, - .wa_18019110168_data = (void *)&mesh_shader_data->bind_map, - }; - - mesh_shader_data->code = (void *)brw_compile(compiler, ¶ms.base); - *error_str = params.base.error_str; + params->mesh.tue_map = prev_shader_data ? + &prev_shader_data->prog_data.task.map : NULL; + params->mesh.wa_18019110168_load_provoking_vertex = + wa_18019110168_load_provoking_vertex; + params->mesh.wa_18019110168_data = (void *)&shader_data->bind_map; } static void -anv_shader_compile_fs(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - const struct vk_graphics_pipeline_state *state, - char **error_str) +populate_compile_params_fs(union brw_any_compile_params *params, + const struct intel_device_info *devinfo, + struct anv_shader_data *shader_data) { - const struct brw_compiler *compiler = device->physical->compiler; - const struct intel_device_info *devinfo = compiler->devinfo; nir_shader *nir = shader_data->info->nir; /* When using Primitive Replication for multiview, each view gets its own @@ -1182,111 +1008,22 @@ anv_shader_compile_fs(struct anv_device *device, uint32_t pos_slots = shader_data->use_primitive_replication ? MAX2(1, util_bitcount(shader_data->key.base.view_mask)) : 1; + /* TODO: Should we find a way to pass this to brw_compile? */ struct intel_vue_map prev_vue_map; - brw_compute_vue_map(compiler->devinfo, + brw_compute_vue_map(devinfo, &prev_vue_map, nir->info.inputs_read, nir->info.separate_shader, pos_slots); - struct brw_compile_fs_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.fs.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.fs, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, + params->fs.mue_map = shader_data->mue_map; - .mue_map = shader_data->mue_map, + params->fs.allow_spilling = true; + params->fs.max_polygons = UCHAR_MAX; - .allow_spilling = true, - .max_polygons = UCHAR_MAX, - - .wa_18019110168_load_per_primitive_remap_table_offset = - wa_18019110168_load_per_primitive_remap_table, - .wa_18019110168_data = (void *)&shader_data->bind_map, - }; - - if (intel_use_jay(devinfo, nir->info.stage)) { - struct jay_shader_bin *bin = - jay_compile(devinfo, mem_ctx, nir, - (union brw_any_prog_data *) params.base.prog_data, - (union brw_any_prog_key *) params.base.key); - - shader_data->code = (void *) bin->kernel; - } else { - shader_data->code = (void *) brw_compile(compiler, ¶ms.base); - } - - *error_str = params.base.error_str; - - shader_data->num_stats = (uint32_t)!!shader_data->prog_data.fs.dispatch_multi + - (uint32_t)shader_data->prog_data.fs.dispatch_8 + - (uint32_t)shader_data->prog_data.fs.dispatch_16 + - (uint32_t)shader_data->prog_data.fs.dispatch_32; - assert(shader_data->num_stats <= ARRAY_SIZE(shader_data->stats)); - - /* Update the push constant padding range now that we know the amount of - * per-primitive data delivered in the payload. - */ - for (unsigned i = 0; i < ARRAY_SIZE(shader_data->bind_map.push_ranges); i++) { - if (shader_data->bind_map.push_ranges[i].set == ANV_DESCRIPTOR_SET_PER_PRIM_PADDING) { - shader_data->bind_map.push_ranges[i].length = MAX2( - shader_data->prog_data.fs.num_per_primitive_inputs / 2, - shader_data->bind_map.push_ranges[i].length); - break; - } - } -} - -static void -anv_shader_compile_cs(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - char **error_str) -{ - const struct brw_compiler *compiler = device->physical->compiler; - const struct intel_device_info *devinfo = compiler->devinfo; - nir_shader *nir = shader_data->info->nir; - - shader_data->num_stats = 1; - - struct brw_compile_cs_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.cs.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.cs, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, - }; - - if (intel_use_jay(devinfo, nir->info.stage)) { - struct jay_shader_bin *bin = jay_compile(devinfo, mem_ctx, nir, - (union brw_any_prog_data*)params.base.prog_data, - (union brw_any_prog_key*)params.base.key); - - struct brw_cs_prog_data *prog_data = - (struct brw_cs_prog_data *)params.base.prog_data; - - shader_data->code = (void*)bin->kernel; - shader_data->stats[0] = bin->stats; - - prog_data->local_size[0] = nir->info.workgroup_size[0]; - prog_data->local_size[1] = nir->info.workgroup_size[1]; - prog_data->local_size[2] = nir->info.workgroup_size[2]; - } else { - shader_data->code = (void*)brw_compile(compiler, ¶ms.base); - } - - *error_str = params.base.error_str; + params->fs.wa_18019110168_load_per_primitive_remap_table_offset = + wa_18019110168_load_per_primitive_remap_table; + params->fs.wa_18019110168_data = (void *)&shader_data->bind_map; } static bool @@ -1299,14 +1036,12 @@ should_remat_cb(nir_instr *instr, void *data) } static void -anv_shader_compile_bs(struct anv_device *device, - void *mem_ctx, - struct anv_shader_data *shader_data, - char **error_str) +populate_compile_params_bs(union brw_any_compile_params *params, + const struct intel_device_info *devinfo, + void *mem_ctx, + struct anv_shader_data *shader_data) { - const struct brw_compiler *compiler = device->physical->compiler; nir_shader *nir = shader_data->info->nir; - const struct intel_device_info *devinfo = compiler->devinfo; struct brw_nir_lower_shader_calls_state lowering_state = { .devinfo = devinfo, @@ -1343,25 +1078,8 @@ anv_shader_compile_bs(struct anv_device *device, &shader_data->key.base, devinfo); } - shader_data->num_stats = 1; - - struct brw_compile_bs_params params = { - .base = { - .nir = nir, - .key = &shader_data->key.bs.base, - .prog_data = (struct brw_stage_prog_data *)&shader_data->prog_data.bs, - .stats = shader_data->stats, - .log_data = device, - .mem_ctx = mem_ctx, - .source_hash = shader_data->source_hash, - .archiver = shader_data->archiver, - }, - .num_resume_shaders = num_resume_shaders, - .resume_shaders = resume_shaders, - }; - - shader_data->code = (void *)brw_compile(compiler, ¶ms.base); - *error_str = params.base.error_str; + params->bs.num_resume_shaders = num_resume_shaders; + params->bs.resume_shaders = resume_shaders; } static void @@ -2290,35 +2008,47 @@ anv_shader_compile(struct vk_device *vk_device, struct anv_shader_data *prev_shader_data = s > 0 ? &shaders_data[s - 1] : NULL; - char *error_str = NULL; + const struct brw_compiler *compiler = device->physical->compiler; + const struct intel_device_info *devinfo = compiler->devinfo; + nir_shader *nir = shader_data->info->nir; + + shader_data->num_stats = 1; + + union brw_any_compile_params params = { 0 }; + params.base = (struct brw_compile_params) { + .nir = nir, + .key = &shader_data->key.base, + .prog_data = &shader_data->prog_data.base, + .stats = shader_data->stats, + .log_data = device, + .mem_ctx = mem_ctx, + .source_hash = shader_data->source_hash, + .archiver = shader_data->archiver, + }; + struct brw_compile_params *compile_params = ¶ms.base; + switch (shader_data->info->stage) { case MESA_SHADER_VERTEX: - anv_shader_compile_vs(device, mem_ctx, shader_data, &error_str); + case MESA_SHADER_GEOMETRY: + case MESA_SHADER_TASK: + case MESA_SHADER_COMPUTE: + /* Nothing to do. */ break; case MESA_SHADER_TESS_CTRL: - anv_shader_compile_tcs(device, mem_ctx, shader_data, &error_str); + shader_data->key.tcs.outputs_written = nir->info.outputs_written; + shader_data->key.tcs.patch_outputs_written = + nir->info.patch_outputs_written; break; case MESA_SHADER_TESS_EVAL: - anv_shader_compile_tes(device, mem_ctx, - &shaders_data[s], prev_shader_data, - &error_str); - break; - case MESA_SHADER_GEOMETRY: - anv_shader_compile_gs(device, mem_ctx, shader_data, &error_str); - break; - case MESA_SHADER_TASK: - anv_shader_compile_task(device, mem_ctx, shader_data, &error_str); + populate_compile_params_tes(¶ms, shader_data, + prev_shader_data); break; case MESA_SHADER_MESH: - anv_shader_compile_mesh(device, mem_ctx, - &shaders_data[s], prev_shader_data, - &error_str); + populate_compile_params_mesh(¶ms, shader_data, + prev_shader_data); break; case MESA_SHADER_FRAGMENT: - anv_shader_compile_fs(device, mem_ctx, shader_data, state, &error_str); - break; - case MESA_SHADER_COMPUTE: - anv_shader_compile_cs(device, mem_ctx, shader_data, &error_str); + populate_compile_params_fs(¶ms, devinfo, shader_data); break; case MESA_SHADER_RAYGEN: case MESA_SHADER_ANY_HIT: @@ -2326,15 +2056,57 @@ anv_shader_compile(struct vk_device *vk_device, case MESA_SHADER_MISS: case MESA_SHADER_INTERSECTION: case MESA_SHADER_CALLABLE: - anv_shader_compile_bs(device, mem_ctx, shader_data, &error_str); + populate_compile_params_bs(¶ms, devinfo, mem_ctx, + shader_data); break; default: UNREACHABLE("Invalid graphics shader stage"); } + if (intel_use_jay(devinfo, nir->info.stage)) { + struct jay_shader_bin *bin = + jay_compile(devinfo, mem_ctx, nir, + (union brw_any_prog_data *)compile_params->prog_data, + (union brw_any_prog_key *)compile_params->key); + shader_data->code = bin->kernel; + + if (mesa_shader_stage_uses_workgroup(nir->info.stage)) { + struct brw_cs_prog_data *prog_data = + (struct brw_cs_prog_data *)compile_params->prog_data; + shader_data->stats[0] = bin->stats; + prog_data->local_size[0] = nir->info.workgroup_size[0]; + prog_data->local_size[1] = nir->info.workgroup_size[1]; + prog_data->local_size[2] = nir->info.workgroup_size[2]; + } + } else { + shader_data->code = brw_compile(compiler, compile_params); + } + + if (shader_data->info->stage == MESA_SHADER_FRAGMENT) { + shader_data->num_stats = + (uint32_t)!!shader_data->prog_data.fs.dispatch_multi + + (uint32_t)shader_data->prog_data.fs.dispatch_8 + + (uint32_t)shader_data->prog_data.fs.dispatch_16 + + (uint32_t)shader_data->prog_data.fs.dispatch_32; + assert(shader_data->num_stats <= ARRAY_SIZE(shader_data->stats)); + + /* Update the push constant padding range now that we know the amount + * of per-primitive data delivered in the payload. + */ + for (unsigned i = 0; i < ARRAY_SIZE(shader_data->bind_map.push_ranges); i++) { + if (shader_data->bind_map.push_ranges[i].set == + ANV_DESCRIPTOR_SET_PER_PRIM_PADDING) { + shader_data->bind_map.push_ranges[i].length = MAX2( + shader_data->prog_data.fs.num_per_primitive_inputs / 2, + shader_data->bind_map.push_ranges[i].length); + break; + } + } + } + if (shader_data->code == NULL) { - if (error_str) - result = vk_errorf(device, VK_ERROR_UNKNOWN, "%s", error_str); + if (compile_params->error_str) + result = vk_errorf(device, VK_ERROR_UNKNOWN, "%s", compile_params->error_str); else result = vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY); goto end;