From 200ca690869f7e3e1b5dd6bb87cb52ff915bfd80 Mon Sep 17 00:00:00 2001 From: Alyssa Rosenzweig Date: Thu, 1 Feb 2024 23:24:38 -0400 Subject: [PATCH] asahi: support stage override in sysval lower for gs rast program. should clean up later but for now this will do Signed-off-by: Alyssa Rosenzweig Part-of: --- .../drivers/asahi/agx_nir_lower_sysvals.c | 14 +++++++++-- src/gallium/drivers/asahi/agx_state.c | 24 +++++++++++-------- src/gallium/drivers/asahi/agx_state.h | 3 ++- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/gallium/drivers/asahi/agx_nir_lower_sysvals.c b/src/gallium/drivers/asahi/agx_nir_lower_sysvals.c index 5b5919a6d43..ae07488c443 100644 --- a/src/gallium/drivers/asahi/agx_nir_lower_sysvals.c +++ b/src/gallium/drivers/asahi/agx_nir_lower_sysvals.c @@ -392,11 +392,21 @@ lay_out_uniforms(struct agx_compiled_shader *shader, struct state *state) } bool -agx_nir_lower_sysvals(nir_shader *shader, bool lower_draw_params) +agx_nir_lower_sysvals(nir_shader *shader, enum pipe_shader_type desc_stage, + bool lower_draw_params) { - return nir_shader_instructions_pass( + /* override stage for the duration on the pass. XXX: should refactor, but + * it's annoying! + */ + enum pipe_shader_type phys_stage = shader->info.stage; + shader->info.stage = desc_stage; + + bool progress = nir_shader_instructions_pass( shader, lower_sysvals, nir_metadata_block_index | nir_metadata_dominance, &lower_draw_params); + + shader->info.stage = phys_stage; + return progress; } bool diff --git a/src/gallium/drivers/asahi/agx_state.c b/src/gallium/drivers/asahi/agx_state.c index d35ba7b03f8..92043b41300 100644 --- a/src/gallium/drivers/asahi/agx_state.c +++ b/src/gallium/drivers/asahi/agx_state.c @@ -1784,7 +1784,7 @@ agx_nir_lower_stats_fs(nir_shader *s) static struct agx_compiled_shader * agx_compile_nir(struct agx_device *dev, nir_shader *nir, const struct agx_shader_key *base_key, - struct util_debug_callback *debug) + struct util_debug_callback *debug, enum pipe_shader_type stage) { struct agx_compiled_shader *compiled = CALLOC_STRUCT(agx_compiled_shader); struct util_dynarray binary; @@ -1797,7 +1797,7 @@ agx_compile_nir(struct agx_device *dev, nir_shader *nir, key.libagx = dev->libagx; key.has_scratch = true; - NIR_PASS(_, nir, agx_nir_lower_sysvals, true); + NIR_PASS(_, nir, agx_nir_lower_sysvals, stage, true); NIR_PASS(_, nir, agx_nir_layout_uniforms, compiled, &key.reserved_preamble); agx_compile_shader_nir(nir, &key, debug, &binary, &compiled->info); @@ -1910,7 +1910,7 @@ agx_compile_variant(struct agx_device *dev, struct pipe_context *pctx, nir_metadata_block_index | nir_metadata_dominance, NULL); } } else if (key->next_stage == ASAHI_VS_GS) { - NIR_PASS(_, nir, agx_nir_lower_sysvals, false); + NIR_PASS(_, nir, agx_nir_lower_sysvals, PIPE_SHADER_VERTEX, false); NIR_PASS(_, nir, agx_nir_lower_vs_before_gs, dev->libagx, key->next.gs.index_size_B, &outputs); } @@ -1926,7 +1926,7 @@ agx_compile_variant(struct agx_device *dev, struct pipe_context *pctx, /* Apply the VS key to the VS before linking it in */ NIR_PASS_V(vs, lower_vbo, key->attribs); NIR_PASS_V(vs, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL); - NIR_PASS_V(vs, agx_nir_lower_sysvals, false); + NIR_PASS_V(vs, agx_nir_lower_sysvals, PIPE_SHADER_VERTEX, false); NIR_PASS_V(nir, agx_nir_lower_tcs, vs, dev->libagx, key->index_size_B); ralloc_free(vs); @@ -2064,7 +2064,7 @@ agx_compile_variant(struct agx_device *dev, struct pipe_context *pctx, } struct agx_compiled_shader *compiled = - agx_compile_nir(dev, nir, &base_key, debug); + agx_compile_nir(dev, nir, &base_key, debug, so->type); compiled->so = so; @@ -2080,13 +2080,16 @@ agx_compile_variant(struct agx_device *dev, struct pipe_context *pctx, /* Compile auxiliary programs */ if (gs_count) { - compiled->gs_count = agx_compile_nir(dev, gs_count, &base_key, debug); + compiled->gs_count = + agx_compile_nir(dev, gs_count, &base_key, debug, so->type); compiled->gs_count->so = so; compiled->gs_count->stage = so->type; } - if (pre_gs) - compiled->pre_gs = agx_compile_nir(dev, pre_gs, &base_key, debug); + if (pre_gs) { + compiled->pre_gs = + agx_compile_nir(dev, pre_gs, &base_key, debug, PIPE_SHADER_COMPUTE); + } if (gs_copy) { struct asahi_gs_shader_key *key = &key_->gs; @@ -2103,7 +2106,8 @@ agx_compile_variant(struct agx_device *dev, struct pipe_context *pctx, base_key.vs.outputs_flat_shaded = key->outputs_flat_shaded; base_key.vs.outputs_linear_shaded = key->outputs_linear_shaded; - compiled->gs_copy = agx_compile_nir(dev, gs_copy, &base_key, debug); + compiled->gs_copy = + agx_compile_nir(dev, gs_copy, &base_key, debug, PIPE_SHADER_GEOMETRY); compiled->gs_copy->so = so; compiled->gs_copy->stage = so->type; } @@ -2799,7 +2803,7 @@ agx_build_meta_shader(struct agx_context *ctx, meta_shader_builder_t builder, struct agx_shader_key base_key = {0}; struct agx_compiled_shader *shader = - agx_compile_nir(dev, b.shader, &base_key, NULL); + agx_compile_nir(dev, b.shader, &base_key, NULL, PIPE_SHADER_COMPUTE); ralloc_free(b.shader); diff --git a/src/gallium/drivers/asahi/agx_state.h b/src/gallium/drivers/asahi/agx_state.h index 8406df2c2e5..5d8debd005b 100644 --- a/src/gallium/drivers/asahi/agx_state.h +++ b/src/gallium/drivers/asahi/agx_state.h @@ -968,7 +968,8 @@ void agx_set_ssbo_uniforms(struct agx_batch *batch, bool agx_nir_lower_point_size(nir_shader *nir, bool fixed_point_size); -bool agx_nir_lower_sysvals(nir_shader *shader, bool lower_draw_params); +bool agx_nir_lower_sysvals(nir_shader *shader, enum pipe_shader_type desc_stage, + bool lower_draw_params); bool agx_nir_layout_uniforms(nir_shader *shader, struct agx_compiled_shader *compiled,