nir: remove workgroup_id_zero_base

This removes the need for drivers to handle both versions. The base will
get added once in nir_lower_system_values when converting from deref to
intrinsic and will be replaced by a zero for users not supporting it.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26800>
This commit is contained in:
Karol Herbst 2024-03-21 00:15:48 +01:00 committed by Marge Bot
parent 3217838fef
commit d22f936019
12 changed files with 22 additions and 30 deletions

View file

@ -3659,7 +3659,6 @@ ntq_emit_intrinsic(struct v3d_compile *c, nir_intrinsic_instr *instr)
}
break;
case nir_intrinsic_load_workgroup_id_zero_base:
case nir_intrinsic_load_workgroup_id: {
struct qreg x = vir_AND(c, c->cs_payload[0],
vir_uniform_ui(c, 0xffff));

View file

@ -401,7 +401,6 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_load_workgroup_index:
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
assert(gl_shader_stage_uses_workgroup(stage));
if (stage == MESA_SHADER_COMPUTE)
is_divergent |= (options & nir_divergence_multiple_workgroup_per_compute_subgroup);

View file

@ -886,10 +886,8 @@ system_value("tess_level_inner_default", 2)
system_value("patch_vertices_in", 1)
system_value("local_invocation_id", 3)
system_value("local_invocation_index", 1)
# zero_base indicates it starts from 0 for the current dispatch
# non-zero_base indicates the base is included
# workgroup_id does not include the base_workgroup_id
system_value("workgroup_id", 3)
system_value("workgroup_id_zero_base", 3)
# The workgroup_index is intended for situations when a 3 dimensional
# workgroup_id is not available on the HW, but a 1 dimensional index is.
system_value("workgroup_index", 1)
@ -927,9 +925,7 @@ system_value("num_subgroups", 1)
system_value("subgroup_id", 1)
system_value("workgroup_size", 3)
# note: the definition of global_invocation_id is based on
# (workgroup_id * workgroup_size) + local_invocation_id.
# it is *not* based on workgroup_id_zero_base, meaning the work group
# base is already accounted for, and the global base is additive on top of that
# ((workgroup_id + base_workgroup_id) * workgroup_size) + local_invocation_id.
system_value("global_invocation_id", 3, bit_sizes=[32, 64])
system_value("base_global_invocation_id", 3, bit_sizes=[32, 64])
system_value("global_invocation_index", 1, bit_sizes=[32, 64])

View file

@ -207,6 +207,11 @@ lower_system_value_instr(nir_builder *b, nir_instr *instr, void *_state)
nir_load_base_global_invocation_id(b, bit_size));
}
case SYSTEM_VALUE_WORKGROUP_ID: {
return nir_iadd(b, nir_u2uN(b, nir_load_workgroup_id(b), bit_size),
nir_load_base_workgroup_id(b, bit_size));
}
case SYSTEM_VALUE_SUBGROUP_EQ_MASK:
case SYSTEM_VALUE_SUBGROUP_GE_MASK:
case SYSTEM_VALUE_SUBGROUP_GT_MASK:
@ -683,10 +688,12 @@ lower_compute_system_value_instr(nir_builder *b,
!b->shader->options->has_cs_global_id) {
nir_def *group_size = nir_load_workgroup_size(b);
nir_def *group_id = nir_load_workgroup_id(b);
nir_def *base_group_id = nir_load_base_workgroup_id(b, bit_size);
nir_def *local_id = nir_load_local_invocation_id(b);
return nir_iadd(b, nir_imul(b, nir_u2uN(b, group_id, bit_size),
nir_u2uN(b, group_size, bit_size)),
return nir_iadd(b, nir_imul(b, nir_iadd(b, nir_u2uN(b, group_id, bit_size),
base_group_id),
nir_u2uN(b, group_size, bit_size)),
nir_u2uN(b, local_id, bit_size));
} else {
return NULL;
@ -699,6 +706,12 @@ lower_compute_system_value_instr(nir_builder *b,
return NULL;
}
case nir_intrinsic_load_base_workgroup_id: {
if (options && !options->has_base_workgroup_id)
return nir_imm_zero(b, 3, bit_size);
return NULL;
}
case nir_intrinsic_load_global_invocation_index: {
/* OpenCL's global_linear_id explicitly ignores the global offset */
assert(b->shader->info.stage == MESA_SHADER_KERNEL);
@ -716,10 +729,7 @@ lower_compute_system_value_instr(nir_builder *b,
}
case nir_intrinsic_load_workgroup_id: {
if (options && options->has_base_workgroup_id)
return nir_iadd(b, nir_u2uN(b, nir_load_workgroup_id_zero_base(b), bit_size),
nir_load_base_workgroup_id(b, bit_size));
else if (options && options->lower_workgroup_id_to_index) {
if (options && options->lower_workgroup_id_to_index) {
nir_def *wg_idx = nir_load_workgroup_index(b);
nir_def *val =

View file

@ -2483,7 +2483,6 @@ emit_intrinsic(struct ir3_context *ctx, nir_intrinsic_instr *intr)
ir3_split_dest(b, dst, ctx->local_invocation_id, 0, 3);
break;
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
if (ctx->compiler->has_shared_regfile) {
if (!ctx->work_group_id) {
ctx->work_group_id =

View file

@ -287,7 +287,6 @@ emit_system_values_block(nir_to_brw_state &ntb, nir_block *block)
break;
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
if (gl_shader_stage_is_mesh(s.stage))
unreachable("should be lowered by nir_lower_compute_system_values().");
assert(gl_shader_stage_is_compute(s.stage));
@ -4370,8 +4369,7 @@ fs_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
bld.MOV(offset(dest, bld, i), s.cs_payload().local_invocation_id[i]);
break;
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base: {
case nir_intrinsic_load_workgroup_id: {
fs_reg val = ntb.system_values[SYSTEM_VALUE_WORKGROUP_ID];
assert(val.file != BAD_FILE);
dest.type = val.type;

View file

@ -456,7 +456,7 @@ brw_nir_create_raygen_trampoline(const struct brw_compiler *compiler,
nir_def *raygen_bsr_addr =
nir_if_phi(&b, raygen_indirect_bsr_addr, raygen_param_bsr_addr);
nir_def *global_id = nir_load_workgroup_id_zero_base(&b);
nir_def *global_id = nir_load_workgroup_id(&b);
nir_def *simd_channel = nir_load_subgroup_invocation(&b);
nir_def *local_x =
nir_ubfe(&b, simd_channel, nir_imm_int(&b, 0),

View file

@ -240,7 +240,6 @@ emit_system_values_block(nir_to_elk_state &ntb, nir_block *block)
break;
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
assert(gl_shader_stage_is_compute(s.stage));
reg = &ntb.system_values[SYSTEM_VALUE_WORKGROUP_ID];
if (reg->file == BAD_FILE)
@ -4062,8 +4061,7 @@ fs_nir_emit_cs_intrinsic(nir_to_elk_state &ntb,
s.cs_payload().load_subgroup_id(bld, dest);
break;
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base: {
case nir_intrinsic_load_workgroup_id: {
elk_fs_reg val = ntb.system_values[SYSTEM_VALUE_WORKGROUP_ID];
assert(val.file != BAD_FILE);
dest.type = val.type;

View file

@ -317,7 +317,6 @@ clc_lower_64bit_semantics(nir_shader *nir)
case nir_intrinsic_load_base_global_invocation_id:
case nir_intrinsic_load_local_invocation_id:
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
case nir_intrinsic_load_base_workgroup_id:
case nir_intrinsic_load_num_workgroups:
break;

View file

@ -4792,7 +4792,6 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
case nir_intrinsic_load_local_invocation_index:
return emit_load_local_invocation_index(ctx, intr);
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
return emit_load_local_workgroup_id(ctx, intr);
case nir_intrinsic_load_ssbo:
return emit_load_ssbo(ctx, intr);

View file

@ -1615,7 +1615,6 @@ Converter::convert(nir_intrinsic_op intr)
case nir_intrinsic_load_vertex_id:
return SV_VERTEX_ID;
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
return SV_CTAID;
case nir_intrinsic_load_work_dim:
return SV_WORK_DIM;
@ -1903,7 +1902,6 @@ Converter::visit(nir_intrinsic_instr *insn)
case nir_intrinsic_load_tess_level_outer:
case nir_intrinsic_load_vertex_id:
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base:
case nir_intrinsic_load_work_dim: {
const DataType dType = getDType(insn);
SVSemantic sv = convert(op);

View file

@ -538,11 +538,8 @@ nak_nir_lower_system_value_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
case nir_intrinsic_load_helper_invocation:
case nir_intrinsic_load_invocation_id:
case nir_intrinsic_load_local_invocation_id:
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_workgroup_id_zero_base: {
case nir_intrinsic_load_workgroup_id: {
const gl_system_value sysval =
intrin->intrinsic == nir_intrinsic_load_workgroup_id_zero_base ?
SYSTEM_VALUE_WORKGROUP_ID :
nir_system_value_from_intrinsic(intrin->intrinsic);
const uint32_t idx = nak_sysval_sysval_idx(sysval);
nir_def *comps[3];