diff --git a/src/intel/compiler/brw_nir_lower_cs_intrinsics.c b/src/intel/compiler/brw_nir_lower_cs_intrinsics.c index f96b98edd2d..2d33e13b473 100644 --- a/src/intel/compiler/brw_nir_lower_cs_intrinsics.c +++ b/src/intel/compiler/brw_nir_lower_cs_intrinsics.c @@ -27,23 +27,80 @@ struct lower_intrinsics_state { nir_shader *nir; nir_function_impl *impl; + enum gl_derivative_group derivative_group; bool progress; bool hw_generated_local_id; nir_builder builder; + + /* Per-block cached values. */ + bool computed; + nir_def *hw_index; + nir_def *local_index; + nir_def *local_id; }; static void -compute_local_index_id(nir_builder *b, - nir_shader *nir, - nir_def **local_index, - nir_def **local_id) +compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr *current) { - nir_def *subgroup_id = nir_load_subgroup_id(b); + assert(!state->computed); + state->hw_index = NULL; + state->local_index = NULL; + state->local_id = NULL; + state->computed = true; - nir_def *thread_local_id = - nir_imul(b, subgroup_id, nir_load_simd_width_intel(b)); - nir_def *channel = nir_load_subgroup_invocation(b); - nir_def *linear = nir_iadd(b, channel, thread_local_id); + nir_shader *nir = state->nir; + nir_builder *b = &state->builder; + + if (!nir->info.workgroup_size_variable) { + /* Don't calculate anything for a single invocation workgroup. */ + const uint16_t *ws = nir->info.workgroup_size; + if (ws[0] * ws[1] * ws[2] == 1) { + nir_def *zero = nir_imm_int(b, 0); + state->local_index = zero; + state->local_id = nir_replicate(b, zero, 3); + return; + } + + if (state->hw_generated_local_id) { + assert(state->derivative_group != DERIVATIVE_GROUP_QUADS); + + nir_def *local_id_vec = nir_load_local_invocation_id(b); + nir_def *local_id[3] = { nir_channel(b, local_id_vec, 0), + nir_channel(b, local_id_vec, 1), + nir_channel(b, local_id_vec, 2) }; + nir_def *size_x = nir_imm_int(b, nir->info.workgroup_size[0]); + nir_def *size_y = nir_imm_int(b, nir->info.workgroup_size[1]); + + nir_def *local_index = nir_imul(b, local_id[2], nir_imul(b, size_x, size_y)); + local_index = nir_iadd(b, local_index, nir_imul(b, local_id[1], size_x)); + local_index = nir_iadd(b, local_index, local_id[0]); + + state->local_index = local_index; + state->local_id = NULL; + return; + } + } + + /* Linear index. Depending on the heuristic or the derivative group, will + * need to be processed to become the actual local_index. + */ + nir_def *linear; + + if (nir->info.stage == MESA_SHADER_MESH || nir->info.stage == MESA_SHADER_TASK) { + /* Thread payload provides a linear index, keep track of it + * so it doesn't get removed. + */ + state->hw_index = + current->intrinsic == nir_intrinsic_load_local_invocation_index ? + ¤t->def : nir_load_local_invocation_index(b); + linear = state->hw_index; + } else { + nir_def *subgroup_id = nir_load_subgroup_id(b); + nir_def *thread_local_id = + nir_imul(b, subgroup_id, nir_load_simd_width_intel(b)); + nir_def *channel = nir_load_subgroup_invocation(b); + linear = nir_iadd(b, channel, thread_local_id); + } nir_def *size_x; nir_def *size_y; @@ -75,7 +132,7 @@ compute_local_index_id(nir_builder *b, */ nir_def *id_x, *id_y, *id_z; - switch (nir->info.cs.derivative_group) { + switch (state->derivative_group) { case DERIVATIVE_GROUP_NONE: if (nir->info.num_images == 0 && nir->info.num_textures == 0) { @@ -85,7 +142,7 @@ compute_local_index_id(nir_builder *b, */ id_x = nir_umod(b, linear, size_x); id_y = nir_umod(b, nir_udiv(b, linear, size_x), size_y); - *local_index = linear; + state->local_index = linear; } else if (!nir->info.workgroup_size_variable && nir->info.workgroup_size[1] % 4 == 0) { /* 1x4 block X-major lid order. Same as X-major except increments in @@ -116,11 +173,11 @@ compute_local_index_id(nir_builder *b, } id_z = nir_udiv(b, linear, size_xy); - *local_id = nir_vec3(b, id_x, id_y, id_z); - if (!*local_index) { - *local_index = nir_iadd(b, nir_iadd(b, id_x, - nir_imul(b, id_y, size_x)), - nir_imul(b, id_z, size_xy)); + state->local_id = nir_vec3(b, id_x, id_y, id_z); + if (!state->local_index) { + state->local_index = nir_iadd(b, nir_iadd(b, id_x, + nir_imul(b, id_y, size_x)), + nir_imul(b, id_z, size_xy)); } break; case DERIVATIVE_GROUP_LINEAR: @@ -130,8 +187,8 @@ compute_local_index_id(nir_builder *b, id_x = nir_umod(b, linear, size_x); id_y = nir_umod(b, nir_udiv(b, linear, size_x), size_y); id_z = nir_udiv(b, linear, size_xy); - *local_id = nir_vec3(b, id_x, id_y, id_z); - *local_index = linear; + state->local_id = nir_vec3(b, id_x, id_y, id_z); + state->local_index = linear; break; case DERIVATIVE_GROUP_QUADS: { /* For quads, first we figure out the 2x2 grid the invocation @@ -157,10 +214,10 @@ compute_local_index_id(nir_builder *b, nir_ishl(b, y_row_pairs, one), nir_iand(b, nir_ishr(b, row_pair_id, one), one)); - *local_id = nir_vec3(b, x, - nir_umod(b, y, size_y), - nir_udiv(b, y, size_y)); - *local_index = nir_iadd(b, x, nir_imul(b, y, size_x)); + state->local_id = nir_vec3(b, x, + nir_umod(b, y, size_y), + nir_udiv(b, y, size_y)); + state->local_index = nir_iadd(b, x, nir_imul(b, y, size_x)); break; } default: @@ -176,9 +233,8 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state, nir_builder *b = &state->builder; nir_shader *nir = state->nir; - /* Reuse calculated values inside the block. */ - nir_def *local_index = NULL; - nir_def *local_id = NULL; + /* Reset per-block definitions. */ + state->computed = false; nir_foreach_instr_safe(instr, block) { if (instr->type != nir_instr_type_intrinsic) @@ -190,56 +246,30 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state, nir_def *sysval; switch (intrinsic->intrinsic) { - case nir_intrinsic_load_local_invocation_id: - if (state->hw_generated_local_id) + case nir_intrinsic_load_local_invocation_id: { + if (!state->computed) + compute_local_index_id(state, intrinsic); + + if (!state->local_id) { + /* Will be lowered later by the backend. */ + assert(state->hw_generated_local_id); + continue; + } + + sysval = state->local_id; + break; + } + + case nir_intrinsic_load_local_invocation_index: { + if (!state->computed) + compute_local_index_id(state, intrinsic); + + /* Will be lowered later by the backend. */ + if (&intrinsic->def == state->hw_index) continue; - FALLTHROUGH; - case nir_intrinsic_load_local_invocation_index: { - if (!local_index && !nir->info.workgroup_size_variable) { - const uint16_t *ws = nir->info.workgroup_size; - if (ws[0] * ws[1] * ws[2] == 1) { - nir_def *zero = nir_imm_int(b, 0); - local_index = zero; - local_id = nir_replicate(b, zero, 3); - } - } - - if (!local_index) { - if (nir->info.stage == MESA_SHADER_TASK || - nir->info.stage == MESA_SHADER_MESH) { - /* Will be lowered by nir_emit_task_mesh_intrinsic() using - * information from the payload. - */ - continue; - } - - if (state->hw_generated_local_id) { - nir_def *local_id_vec = nir_load_local_invocation_id(b); - nir_def *local_id[3] = { nir_channel(b, local_id_vec, 0), - nir_channel(b, local_id_vec, 1), - nir_channel(b, local_id_vec, 2) }; - nir_def *size_x = nir_imm_int(b, nir->info.workgroup_size[0]); - nir_def *size_y = nir_imm_int(b, nir->info.workgroup_size[1]); - - sysval = nir_imul(b, local_id[2], nir_imul(b, size_x, size_y)); - sysval = nir_iadd(b, sysval, nir_imul(b, local_id[1], size_x)); - sysval = nir_iadd(b, sysval, local_id[0]); - local_index = sysval; - break; - } - - /* First time we are using those, so let's calculate them. */ - assert(!local_id); - compute_local_index_id(b, nir, &local_index, &local_id); - } - - assert(local_id); - assert(local_index); - if (intrinsic->intrinsic == nir_intrinsic_load_local_invocation_id) - sysval = local_id; - else - sysval = local_index; + assert(state->local_index); + sysval = state->local_index; break; } @@ -303,15 +333,16 @@ brw_nir_lower_cs_intrinsics(nir_shader *nir, struct lower_intrinsics_state state = { .nir = nir, .hw_generated_local_id = false, + .derivative_group = gl_shader_stage_is_compute(nir->info.stage) ? + nir->info.cs.derivative_group : DERIVATIVE_GROUP_NONE, }; /* Constraints from NV_compute_shader_derivatives. */ - if (gl_shader_stage_is_compute(nir->info.stage) && - !nir->info.workgroup_size_variable) { - if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_QUADS) { + if (!nir->info.workgroup_size_variable) { + if (state.derivative_group == DERIVATIVE_GROUP_QUADS) { assert(nir->info.workgroup_size[0] % 2 == 0); assert(nir->info.workgroup_size[1] % 2 == 0); - } else if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_LINEAR) { + } else if (state.derivative_group == DERIVATIVE_GROUP_LINEAR) { ASSERTED unsigned workgroup_size = nir->info.workgroup_size[0] * nir->info.workgroup_size[1] * @@ -322,7 +353,7 @@ brw_nir_lower_cs_intrinsics(nir_shader *nir, if (devinfo->verx10 >= 125 && prog_data && nir->info.stage == MESA_SHADER_COMPUTE && - nir->info.cs.derivative_group != DERIVATIVE_GROUP_QUADS && + state.derivative_group != DERIVATIVE_GROUP_QUADS && !nir->info.workgroup_size_variable && util_is_power_of_two_nonzero(nir->info.workgroup_size[0]) && util_is_power_of_two_nonzero(nir->info.workgroup_size[1])) {