diff --git a/src/glsl/nir/nir.h b/src/glsl/nir/nir.h index 7117113f04e..fa326819a69 100644 --- a/src/glsl/nir/nir.h +++ b/src/glsl/nir/nir.h @@ -1653,7 +1653,7 @@ void nir_lower_vars_to_ssa(nir_shader *shader); void nir_remove_dead_variables(nir_shader *shader); -void nir_lower_vec_to_movs(nir_shader *shader); +bool nir_lower_vec_to_movs(nir_shader *shader); void nir_lower_alu_to_scalar(nir_shader *shader); void nir_lower_load_const_to_scalar(nir_shader *shader); diff --git a/src/glsl/nir/nir_lower_vec_to_movs.c b/src/glsl/nir/nir_lower_vec_to_movs.c index 25a6f7d3ad9..3f4d39d71b6 100644 --- a/src/glsl/nir/nir_lower_vec_to_movs.c +++ b/src/glsl/nir/nir_lower_vec_to_movs.c @@ -32,6 +32,11 @@ * moves with partial writes. */ +struct vec_to_movs_state { + nir_function_impl *impl; + bool progress; +}; + static bool src_matches_dest_reg(nir_dest *dest, nir_src *src) { @@ -84,8 +89,12 @@ insert_mov(nir_alu_instr *vec, unsigned start_channel, } static bool -lower_vec_to_movs_block(nir_block *block, void *shader) +lower_vec_to_movs_block(nir_block *block, void *void_state) { + struct vec_to_movs_state *state = void_state; + nir_function_impl *impl = state->impl; + nir_shader *shader = impl->overload->function->shader; + nir_foreach_instr_safe(block, instr) { if (instr->type != nir_instr_type_alu) continue; @@ -134,24 +143,31 @@ lower_vec_to_movs_block(nir_block *block, void *shader) nir_instr_remove(&vec->instr); ralloc_free(vec); + state->progress = true; } return true; } -static void +static bool nir_lower_vec_to_movs_impl(nir_function_impl *impl) { - nir_shader *shader = impl->overload->function->shader; + struct vec_to_movs_state state = { impl, false }; - nir_foreach_block(impl, lower_vec_to_movs_block, shader); + nir_foreach_block(impl, lower_vec_to_movs_block, &state); + + return state.progress; } -void +bool nir_lower_vec_to_movs(nir_shader *shader) { + bool progress = false; + nir_foreach_overload(shader, overload) { if (overload->impl) - nir_lower_vec_to_movs_impl(overload->impl); + progress = nir_lower_vec_to_movs_impl(overload->impl) || progress; } + + return progress; }