diff --git a/src/asahi/lib/agx_nir_lower_gs.c b/src/asahi/lib/agx_nir_lower_gs.c index 79c967ce391..2e3e6b44949 100644 --- a/src/asahi/lib/agx_nir_lower_gs.c +++ b/src/asahi/lib/agx_nir_lower_gs.c @@ -35,6 +35,9 @@ struct lower_gs_state { int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS]; nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE]; + /* Per-input primitive stride of the output index buffer */ + unsigned max_indices; + /* The count buffer contains `count_stride_el` 32-bit words in a row for each * input primitive, for `input_primitives * count_stride_el * 4` total bytes. */ @@ -772,8 +775,7 @@ lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr, libagx_end_primitive( b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, - previous_vertices(b, state, 0, calc_unrolled_id(b)), - previous_primitives(b, state, 0, calc_unrolled_id(b)), + nir_imul_imm(b, calc_unrolled_id(b), state->max_indices), calc_unrolled_index_id(b), nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS)); } @@ -915,15 +917,24 @@ lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state) b->cursor = nir_before_instr(&intr->instr); switch (intr->intrinsic) { - case nir_intrinsic_set_vertex_and_primitive_count: - /* This instruction is mostly for the count shader, so just remove. But - * for points, we write the index buffer here so the rast shader can map. + case nir_intrinsic_set_vertex_and_primitive_count: { + /* Points write their index buffer here, other primitives write on end. We + * also pad the index buffer here for the rasterization stream. */ + struct lower_gs_state *state_ = state; if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) { lower_end_primitive(b, intr, state); } + if (nir_intrinsic_stream_id(intr) == 0 && !state_->rasterizer_discard) { + libagx_pad_index_gs(b, load_geometry_param(b, output_index_buffer), + intr->src[0].ssa, intr->src[1].ssa, + calc_unrolled_id(b), + nir_imm_int(b, state_->max_indices)); + } + break; + } case nir_intrinsic_end_primitive_with_counter: { unsigned min = verts_in_output_prim(b->shader); @@ -978,7 +989,8 @@ collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data) static nir_shader * agx_nir_create_pre_gs(struct lower_gs_state *state, bool indexed, bool restart, struct nir_xfb_info *xfb, unsigned vertices_per_prim, - uint8_t streams, unsigned invocations) + uint8_t streams, unsigned invocations, + unsigned index_buffer_allocation) { nir_builder b_ = nir_builder_init_simple_shader( MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up"); @@ -991,9 +1003,7 @@ agx_nir_create_pre_gs(struct lower_gs_state *state, bool indexed, bool restart, if (!state->rasterizer_discard) { libagx_build_gs_draw( b, nir_load_geometry_param_buffer_agx(b), - previous_vertices(b, state, 0, unrolled_in_prims), - restart ? previous_primitives(b, state, 0, unrolled_in_prims) - : nir_imm_int(b, 0)); + nir_imul_imm(b, unrolled_in_prims, index_buffer_allocation)); } /* Determine the number of primitives generated in each stream */ @@ -1197,6 +1207,29 @@ agx_nir_lower_gs_instancing(nir_shader *gs) nir_metadata_control_flow, index); } +static unsigned +calculate_max_indices(enum mesa_prim prim, unsigned verts, signed static_verts, + signed static_prims) +{ + /* We always have a static max_vertices, but we might have a tighter bound. + * Use it if we have one + */ + if (static_verts >= 0) { + verts = MIN2(verts, static_verts); + } + + /* Points do not need primitive count added. Other topologies do. If we have + * a static primitive count, we use that. Otherwise, we use a worst case + * estimate that primitives are emitted one-by-one. + */ + if (prim == MESA_PRIM_POINTS) + return verts; + else if (static_prims >= 0) + return verts + static_prims; + else + return verts + (verts / mesa_vertices_per_prim(prim)); +} + bool agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count, nir_shader **gs_copy, nir_shader **pre_gs, @@ -1289,6 +1322,12 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count, } } + /* Using the gathered static counts, choose the index buffer stride. */ + gs_state.max_indices = calculate_max_indices( + gs->info.gs.output_primitive, gs->info.gs.vertices_out, + gs_state.static_count[GS_COUNTER_VERTICES][0], + gs_state.static_count[GS_COUNTER_PRIMITIVES][0]); + bool side_effects_for_rast = false; *gs_copy = agx_nir_create_gs_rast_shader(gs, &side_effects_for_rast); @@ -1388,7 +1427,7 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count, *pre_gs = agx_nir_create_pre_gs( &gs_state, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS, gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask, - gs->info.gs.invocations); + gs->info.gs.invocations, gs_state.max_indices); /* Signal what primitive we want to draw the GS Copy VS with */ *out_mode = gs->info.gs.output_primitive; diff --git a/src/asahi/libagx/geometry.cl b/src/asahi/libagx/geometry.cl index 2becb6574a5..9f55947f72e 100644 --- a/src/asahi/libagx/geometry.cl +++ b/src/asahi/libagx/geometry.cl @@ -566,8 +566,7 @@ libagx_setup_xfb_buffer(global struct agx_geometry_params *p, uint i) */ void libagx_end_primitive(global int *index_buffer, uint total_verts, - uint verts_in_prim, uint total_prims, - uint invocation_vertex_base, uint invocation_prim_base, + uint verts_in_prim, uint total_prims, uint index_offs, uint geometry_base, bool restart) { /* Previous verts/prims are from previous invocations plus earlier @@ -575,14 +574,15 @@ libagx_end_primitive(global int *index_buffer, uint total_verts, * subtract the count for this prim from the inclusive sum NIR gives us. */ uint previous_verts_in_invoc = (total_verts - verts_in_prim); - uint previous_verts = invocation_vertex_base + previous_verts_in_invoc; - uint previous_prims = restart ? invocation_prim_base + (total_prims - 1) : 0; + uint previous_verts = previous_verts_in_invoc; + uint previous_prims = restart ? (total_prims - 1) : 0; /* The indices are encoded as: (unrolled ID * output vertices) + vertex. */ uint index_base = geometry_base + previous_verts_in_invoc; /* Index buffer contains 1 index for each vertex and 1 for each prim */ - global int *out = &index_buffer[previous_verts + previous_prims]; + global int *out = + &index_buffer[index_offs + previous_verts + previous_prims]; /* Write out indices for the strip */ for (uint i = 0; i < verts_in_prim; ++i) { @@ -594,15 +594,20 @@ libagx_end_primitive(global int *index_buffer, uint total_verts, } void -libagx_build_gs_draw(global struct agx_geometry_params *p, uint vertices, - uint primitives) +libagx_pad_index_gs(global int *index_buffer, uint total_verts, + uint total_prims, uint id, uint alloc) +{ + for (uint i = total_verts + total_prims; i < alloc; ++i) { + index_buffer[(id * alloc) + i] = -1; + } +} + +void +libagx_build_gs_draw(global struct agx_geometry_params *p, uint indices) { global uint *descriptor = p->indirect_desc; global struct agx_geometry_state *state = p->state; - /* Setup the indirect draw descriptor */ - uint indices = vertices + primitives; /* includes restart indices */ - /* Allocate the index buffer */ uint index_buffer_offset_B = state->heap_bottom; p->output_index_buffer = @@ -610,6 +615,7 @@ libagx_build_gs_draw(global struct agx_geometry_params *p, uint vertices, state->heap_bottom += (indices * 4); assert(state->heap_bottom < state->heap_size); + /* Setup the indirect draw descriptor */ descriptor[0] = indices; /* count */ descriptor[1] = 1; /* instance count */ descriptor[2] = index_buffer_offset_B / 4; /* start */ diff --git a/src/asahi/vulkan/hk_cmd_draw.c b/src/asahi/vulkan/hk_cmd_draw.c index 9ded93a807e..0aa27c0f8db 100644 --- a/src/asahi/vulkan/hk_cmd_draw.c +++ b/src/asahi/vulkan/hk_cmd_draw.c @@ -1487,10 +1487,8 @@ hk_launch_gs_prerast(struct hk_cmd_buffer *cmd, struct hk_cs *cs, /* Pre-rast geometry shader */ hk_dispatch_with_local_size(cmd, cs, main, grid_gs, agx_workgroup(1, 1, 1)); - bool restart = cmd->state.gfx.topology != AGX_PRIMITIVE_POINTS; return agx_draw_indexed_indirect(cmd->geom_indirect, dev->heap->va->addr, - dev->heap->size, AGX_INDEX_SIZE_U32, - restart); + dev->heap->size, AGX_INDEX_SIZE_U32, true); } static struct agx_draw diff --git a/src/gallium/drivers/asahi/agx_state.c b/src/gallium/drivers/asahi/agx_state.c index 479b4ac15fd..cfe6713173a 100644 --- a/src/gallium/drivers/asahi/agx_state.c +++ b/src/gallium/drivers/asahi/agx_state.c @@ -5076,7 +5076,7 @@ agx_draw_vbo(struct pipe_context *pctx, const struct pipe_draw_info *info, info_gs = (struct pipe_draw_info){ .mode = ctx->gs->gs_output_mode, .index_size = 4, - .primitive_restart = ctx->gs->gs_output_mode != MESA_PRIM_POINTS, + .primitive_restart = true, .restart_index = ~0, .index.resource = ctx->heap, .instance_count = 1,