diff --git a/src/compiler/nir/nir_lower_gs_intrinsics.c b/src/compiler/nir/nir_lower_gs_intrinsics.c index a9c2ffcf3e9..1d8b6ee6099 100644 --- a/src/compiler/nir/nir_lower_gs_intrinsics.c +++ b/src/compiler/nir/nir_lower_gs_intrinsics.c @@ -63,6 +63,7 @@ struct state { bool count_prims; bool count_vtx_per_prim; bool overwrite_incomplete; + bool is_points; bool progress; }; @@ -89,6 +90,8 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state) if (state->count_vtx_per_prim) count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]); + else if (state->is_points) + count_per_primitive = nir_imm_int(b, 0); else count_per_primitive = nir_ssa_undef(b, 1, 32); @@ -208,6 +211,8 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state) if (state->count_vtx_per_prim) count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]); + else if (state->is_points) + count_per_primitive = nir_imm_int(b, 0); else count_per_primitive = nir_ssa_undef(b, count->num_components, count->bit_size); @@ -290,7 +295,7 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state) if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream))) { /* Inactive stream: vertex count is 0, primitive count is 0 or undef. */ vtx_cnt = nir_imm_int(b, 0); - prim_cnt = state->count_prims + prim_cnt = state->count_prims || state->is_points ? nir_imm_int(b, 0) : nir_ssa_undef(b, 1, 32); } else { @@ -298,9 +303,16 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state) overwrite_incomplete_primitives(state, stream); vtx_cnt = nir_load_var(b, state->vertex_count_vars[stream]); - prim_cnt = state->count_prims - ? nir_load_var(b, state->primitive_count_vars[stream]) - : nir_ssa_undef(b, 1, 32); + + if (state->count_prims) + prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]); + else if (state->is_points) + /* EndPrimitive does not affect primitive count for points, + * just use vertex count instead + */ + prim_cnt = vtx_cnt; + else + prim_cnt = nir_ssa_undef(b, 1, 32); } nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt, stream); @@ -361,12 +373,23 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option overwrite_incomplete || (options & nir_lower_gs_intrinsics_count_vertices_per_primitive); + bool is_points = shader->info.gs.output_primitive == SHADER_PRIM_POINTS; + /* points are always complete primitives with a single vertex, so these are + * not needed when primitive is points. + */ + if (is_points) { + count_primitives = false; + overwrite_incomplete = false; + count_vtx_per_prim = false; + } + struct state state; state.progress = false; state.count_prims = count_primitives; state.count_vtx_per_prim = count_vtx_per_prim; state.overwrite_incomplete = overwrite_incomplete; state.per_stream = per_stream; + state.is_points = is_points; nir_function_impl *impl = nir_shader_get_entrypoint(shader); assert(impl);