diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index 1032e74d021..acfc1722396 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -1644,6 +1644,234 @@ deref_is_matrix(nir_deref_instr *deref) return NULL; } +/* rewrite all input/output variables using 32bit types and load/stores */ +static bool +lower_64bit_vars_function(nir_shader *shader, nir_function *function, nir_variable *var, struct hash_table *derefs, struct set *deletes) +{ + bool func_progress = false; + if (!function->impl) + return false; + nir_builder b; + nir_builder_init(&b, function->impl); + nir_foreach_block(block, function->impl) { + nir_foreach_instr_safe(instr, block) { + switch (instr->type) { + case nir_instr_type_deref: { + nir_deref_instr *deref = nir_instr_as_deref(instr); + if (!(deref->modes & var->data.mode)) + continue; + if (nir_deref_instr_get_variable(deref) != var) + continue; + + /* matrix types are special: store the original deref type for later use */ + const struct glsl_type *matrix = deref_is_matrix(deref); + nir_deref_instr *parent = nir_deref_instr_parent(deref); + if (!matrix) { + /* if this isn't a direct matrix deref, it's maybe a matrix row deref */ + hash_table_foreach(derefs, he) { + /* propagate parent matrix type to row deref */ + if (he->key == parent) + matrix = he->data; + } + } + if (matrix) + _mesa_hash_table_insert(derefs, deref, (void*)matrix); + if (deref->deref_type == nir_deref_type_var) + deref->type = var->type; + else + deref->type = rewrite_64bit_type(shader, deref->type, var); + } + break; + case nir_instr_type_intrinsic: { + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + if (intr->intrinsic != nir_intrinsic_store_deref && + intr->intrinsic != nir_intrinsic_load_deref) + break; + if (nir_intrinsic_get_var(intr, 0) != var) + break; + if ((intr->intrinsic == nir_intrinsic_store_deref && intr->src[1].ssa->bit_size != 64) || + (intr->intrinsic == nir_intrinsic_load_deref && intr->dest.ssa.bit_size != 64)) + break; + b.cursor = nir_before_instr(instr); + nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); + unsigned num_components = intr->num_components * 2; + nir_ssa_def *comp[NIR_MAX_VEC_COMPONENTS]; + /* this is the stored matrix type from the deref */ + struct hash_entry *he = _mesa_hash_table_search(derefs, deref); + const struct glsl_type *matrix = he ? he->data : NULL; + func_progress = true; + if (intr->intrinsic == nir_intrinsic_store_deref) { + /* first, unpack the src data to 32bit vec2 components */ + for (unsigned i = 0; i < intr->num_components; i++) { + nir_ssa_def *ssa = nir_unpack_64_2x32(&b, nir_channel(&b, intr->src[1].ssa, i)); + comp[i * 2] = nir_channel(&b, ssa, 0); + comp[i * 2 + 1] = nir_channel(&b, ssa, 1); + } + unsigned wrmask = nir_intrinsic_write_mask(intr); + unsigned mask = 0; + /* expand writemask for doubled components */ + for (unsigned i = 0; i < intr->num_components; i++) { + if (wrmask & BITFIELD_BIT(i)) + mask |= BITFIELD_BIT(i * 2) | BITFIELD_BIT(i * 2 + 1); + } + if (matrix) { + /* matrix types always come from array (row) derefs */ + assert(deref->deref_type == nir_deref_type_array); + nir_deref_instr *var_deref = nir_deref_instr_parent(deref); + /* let optimization clean up consts later */ + nir_ssa_def *index = deref->arr.index.ssa; + /* this might be an indirect array index: + * - iterate over matrix columns + * - add if blocks for each column + * - perform the store in the block + */ + for (unsigned idx = 0; idx < glsl_get_matrix_columns(matrix); idx++) { + nir_push_if(&b, nir_ieq_imm(&b, index, idx)); + unsigned vec_components = glsl_get_vector_elements(matrix); + /* always clamp dvec3 to 4 components */ + if (vec_components == 3) + vec_components = 4; + unsigned start_component = idx * vec_components * 2; + /* struct member */ + unsigned member = start_component / 4; + /* number of components remaining */ + unsigned remaining = num_components; + for (unsigned i = 0; i < num_components; member++) { + if (!(mask & BITFIELD_BIT(i))) + continue; + assert(member < glsl_get_length(var_deref->type)); + /* deref the rewritten struct to the appropriate vec4/vec2 */ + nir_deref_instr *strct = nir_build_deref_struct(&b, var_deref, member); + unsigned incr = MIN2(remaining, 4); + /* assemble the write component vec */ + nir_ssa_def *val = nir_vec(&b, &comp[i], incr); + /* use the number of components being written as the writemask */ + if (glsl_get_vector_elements(strct->type) > val->num_components) + val = nir_pad_vector(&b, val, glsl_get_vector_elements(strct->type)); + nir_store_deref(&b, strct, val, BITFIELD_MASK(incr)); + remaining -= incr; + i += incr; + } + nir_pop_if(&b, NULL); + } + _mesa_set_add(deletes, &deref->instr); + } else if (num_components <= 4) { + /* simple store case: just write out the components */ + nir_ssa_def *dest = nir_vec(&b, comp, num_components); + nir_store_deref(&b, deref, dest, mask); + } else { + /* writing > 4 components: access the struct and write to the appropriate vec4 members */ + for (unsigned i = 0; num_components; i++, num_components -= MIN2(num_components, 4)) { + if (!(mask & BITFIELD_MASK(4))) + continue; + nir_deref_instr *strct = nir_build_deref_struct(&b, deref, i); + nir_ssa_def *dest = nir_vec(&b, &comp[i * 4], MIN2(num_components, 4)); + if (glsl_get_vector_elements(strct->type) > dest->num_components) + dest = nir_pad_vector(&b, dest, glsl_get_vector_elements(strct->type)); + nir_store_deref(&b, strct, dest, mask & BITFIELD_MASK(4)); + mask >>= 4; + } + } + } else { + nir_ssa_def *dest = NULL; + if (matrix) { + /* matrix types always come from array (row) derefs */ + assert(deref->deref_type == nir_deref_type_array); + nir_deref_instr *var_deref = nir_deref_instr_parent(deref); + /* let optimization clean up consts later */ + nir_ssa_def *index = deref->arr.index.ssa; + /* this might be an indirect array index: + * - iterate over matrix columns + * - add if blocks for each column + * - phi the loads using the array index + */ + unsigned cols = glsl_get_matrix_columns(matrix); + nir_ssa_def *dests[4]; + for (unsigned idx = 0; idx < cols; idx++) { + /* don't add an if for the final row: this will be handled in the else */ + if (idx < cols - 1) + nir_push_if(&b, nir_ieq_imm(&b, index, idx)); + unsigned vec_components = glsl_get_vector_elements(matrix); + /* always clamp dvec3 to 4 components */ + if (vec_components == 3) + vec_components = 4; + unsigned start_component = idx * vec_components * 2; + /* struct member */ + unsigned member = start_component / 4; + /* number of components remaining */ + unsigned remaining = num_components; + /* component index */ + unsigned comp_idx = 0; + for (unsigned i = 0; i < num_components; member++) { + assert(member < glsl_get_length(var_deref->type)); + nir_deref_instr *strct = nir_build_deref_struct(&b, var_deref, member); + nir_ssa_def *load = nir_load_deref(&b, strct); + unsigned incr = MIN2(remaining, 4); + /* repack the loads to 64bit */ + for (unsigned c = 0; c < incr / 2; c++, comp_idx++) + comp[comp_idx] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_RANGE(c * 2, 2))); + remaining -= incr; + i += incr; + } + dest = dests[idx] = nir_vec(&b, comp, intr->num_components); + if (idx < cols - 1) + nir_push_else(&b, NULL); + } + /* loop over all the if blocks that were made, pop them, and phi the loaded+packed results */ + for (unsigned idx = cols - 1; idx >= 1; idx--) { + nir_pop_if(&b, NULL); + dest = nir_if_phi(&b, dests[idx - 1], dest); + } + _mesa_set_add(deletes, &deref->instr); + } else if (num_components <= 4) { + /* simple load case */ + nir_ssa_def *load = nir_load_deref(&b, deref); + /* pack 32bit loads into 64bit: this will automagically get optimized out later */ + for (unsigned i = 0; i < intr->num_components; i++) { + comp[i] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_RANGE(i * 2, 2))); + } + dest = nir_vec(&b, comp, intr->num_components); + } else { + /* writing > 4 components: access the struct and load the appropriate vec4 members */ + for (unsigned i = 0; i < 2; i++, num_components -= 4) { + nir_deref_instr *strct = nir_build_deref_struct(&b, deref, i); + nir_ssa_def *load = nir_load_deref(&b, strct); + comp[i * 2] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_MASK(2))); + if (num_components > 2) + comp[i * 2 + 1] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_RANGE(2, 2))); + } + dest = nir_vec(&b, comp, intr->num_components); + } + nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, dest, instr); + } + _mesa_set_add(deletes, instr); + break; + } + break; + default: break; + } + } + } + if (func_progress) + nir_metadata_preserve(function->impl, nir_metadata_none); + /* derefs must be queued for deletion to avoid deleting the same deref repeatedly */ + set_foreach_remove(deletes, he) + nir_instr_remove((void*)he->key); + return func_progress; +} + +static bool +lower_64bit_vars_loop(nir_shader *shader, nir_variable *var, struct hash_table *derefs, struct set *deletes) +{ + if (!glsl_type_contains_64bit(var->type)) + return false; + var->type = rewrite_64bit_type(shader, var->type, var); + /* once type is rewritten, rewrite all loads and stores */ + nir_foreach_function(function, shader) + lower_64bit_vars_function(shader, function, var, derefs, deletes); + return true; +} + /* rewrite all input/output variables using 32bit types and load/stores */ static bool lower_64bit_vars(nir_shader *shader) @@ -1651,224 +1879,8 @@ lower_64bit_vars(nir_shader *shader) bool progress = false; struct hash_table *derefs = _mesa_hash_table_create(NULL, _mesa_hash_pointer, _mesa_key_pointer_equal); struct set *deletes = _mesa_set_create(NULL, _mesa_hash_pointer, _mesa_key_pointer_equal); - nir_foreach_variable_with_modes(var, shader, nir_var_shader_in | nir_var_shader_out) { - if (!glsl_type_contains_64bit(var->type)) - continue; - var->type = rewrite_64bit_type(shader, var->type, var); - /* once type is rewritten, rewrite all loads and stores */ - nir_foreach_function(function, shader) { - bool func_progress = false; - if (!function->impl) - continue; - nir_builder b; - nir_builder_init(&b, function->impl); - nir_foreach_block(block, function->impl) { - nir_foreach_instr_safe(instr, block) { - switch (instr->type) { - case nir_instr_type_deref: { - nir_deref_instr *deref = nir_instr_as_deref(instr); - if (!(deref->modes & (nir_var_shader_in | nir_var_shader_out))) - continue; - if (nir_deref_instr_get_variable(deref) != var) - continue; - - /* matrix types are special: store the original deref type for later use */ - const struct glsl_type *matrix = deref_is_matrix(deref); - nir_deref_instr *parent = nir_deref_instr_parent(deref); - if (!matrix) { - /* if this isn't a direct matrix deref, it's maybe a matrix row deref */ - hash_table_foreach(derefs, he) { - /* propagate parent matrix type to row deref */ - if (he->key == parent) - matrix = he->data; - } - } - if (matrix) - _mesa_hash_table_insert(derefs, deref, (void*)matrix); - if (deref->deref_type == nir_deref_type_var) - deref->type = var->type; - else - deref->type = rewrite_64bit_type(shader, deref->type, var); - } - break; - case nir_instr_type_intrinsic: { - nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); - if (intr->intrinsic != nir_intrinsic_store_deref && - intr->intrinsic != nir_intrinsic_load_deref) - break; - if (nir_intrinsic_get_var(intr, 0) != var) - break; - if ((intr->intrinsic == nir_intrinsic_store_deref && intr->src[1].ssa->bit_size != 64) || - (intr->intrinsic == nir_intrinsic_load_deref && intr->dest.ssa.bit_size != 64)) - break; - b.cursor = nir_before_instr(instr); - nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); - unsigned num_components = intr->num_components * 2; - nir_ssa_def *comp[NIR_MAX_VEC_COMPONENTS]; - /* this is the stored matrix type from the deref */ - struct hash_entry *he = _mesa_hash_table_search(derefs, deref); - const struct glsl_type *matrix = he ? he->data : NULL; - func_progress = true; - if (intr->intrinsic == nir_intrinsic_store_deref) { - /* first, unpack the src data to 32bit vec2 components */ - for (unsigned i = 0; i < intr->num_components; i++) { - nir_ssa_def *ssa = nir_unpack_64_2x32(&b, nir_channel(&b, intr->src[1].ssa, i)); - comp[i * 2] = nir_channel(&b, ssa, 0); - comp[i * 2 + 1] = nir_channel(&b, ssa, 1); - } - unsigned wrmask = nir_intrinsic_write_mask(intr); - unsigned mask = 0; - /* expand writemask for doubled components */ - for (unsigned i = 0; i < intr->num_components; i++) { - if (wrmask & BITFIELD_BIT(i)) - mask |= BITFIELD_BIT(i * 2) | BITFIELD_BIT(i * 2 + 1); - } - if (matrix) { - /* matrix types always come from array (row) derefs */ - assert(deref->deref_type == nir_deref_type_array); - nir_deref_instr *var_deref = nir_deref_instr_parent(deref); - /* let optimization clean up consts later */ - nir_ssa_def *index = deref->arr.index.ssa; - /* this might be an indirect array index: - * - iterate over matrix columns - * - add if blocks for each column - * - perform the store in the block - */ - for (unsigned idx = 0; idx < glsl_get_matrix_columns(matrix); idx++) { - nir_push_if(&b, nir_ieq_imm(&b, index, idx)); - unsigned vec_components = glsl_get_vector_elements(matrix); - /* always clamp dvec3 to 4 components */ - if (vec_components == 3) - vec_components = 4; - unsigned start_component = idx * vec_components * 2; - /* struct member */ - unsigned member = start_component / 4; - /* number of components remaining */ - unsigned remaining = num_components; - for (unsigned i = 0; i < num_components; member++) { - if (!(mask & BITFIELD_BIT(i))) - continue; - assert(member < glsl_get_length(var_deref->type)); - /* deref the rewritten struct to the appropriate vec4/vec2 */ - nir_deref_instr *strct = nir_build_deref_struct(&b, var_deref, member); - unsigned incr = MIN2(remaining, 4); - /* assemble the write component vec */ - nir_ssa_def *val = nir_vec(&b, &comp[i], incr); - /* use the number of components being written as the writemask */ - if (glsl_get_vector_elements(strct->type) > val->num_components) - val = nir_pad_vector(&b, val, glsl_get_vector_elements(strct->type)); - nir_store_deref(&b, strct, val, BITFIELD_MASK(incr)); - remaining -= incr; - i += incr; - } - nir_pop_if(&b, NULL); - } - _mesa_set_add(deletes, &deref->instr); - } else if (num_components <= 4) { - /* simple store case: just write out the components */ - nir_ssa_def *dest = nir_vec(&b, comp, num_components); - nir_store_deref(&b, deref, dest, mask); - } else { - /* writing > 4 components: access the struct and write to the appropriate vec4 members */ - for (unsigned i = 0; num_components; i++, num_components -= MIN2(num_components, 4)) { - if (!(mask & BITFIELD_MASK(4))) - continue; - nir_deref_instr *strct = nir_build_deref_struct(&b, deref, i); - nir_ssa_def *dest = nir_vec(&b, &comp[i * 4], MIN2(num_components, 4)); - if (glsl_get_vector_elements(strct->type) > dest->num_components) - dest = nir_pad_vector(&b, dest, glsl_get_vector_elements(strct->type)); - nir_store_deref(&b, strct, dest, mask & BITFIELD_MASK(4)); - mask >>= 4; - } - } - } else { - nir_ssa_def *dest = NULL; - if (matrix) { - /* matrix types always come from array (row) derefs */ - assert(deref->deref_type == nir_deref_type_array); - nir_deref_instr *var_deref = nir_deref_instr_parent(deref); - /* let optimization clean up consts later */ - nir_ssa_def *index = deref->arr.index.ssa; - /* this might be an indirect array index: - * - iterate over matrix columns - * - add if blocks for each column - * - phi the loads using the array index - */ - unsigned cols = glsl_get_matrix_columns(matrix); - nir_ssa_def *dests[4]; - for (unsigned idx = 0; idx < cols; idx++) { - /* don't add an if for the final row: this will be handled in the else */ - if (idx < cols - 1) - nir_push_if(&b, nir_ieq_imm(&b, index, idx)); - unsigned vec_components = glsl_get_vector_elements(matrix); - /* always clamp dvec3 to 4 components */ - if (vec_components == 3) - vec_components = 4; - unsigned start_component = idx * vec_components * 2; - /* struct member */ - unsigned member = start_component / 4; - /* number of components remaining */ - unsigned remaining = num_components; - /* component index */ - unsigned comp_idx = 0; - for (unsigned i = 0; i < num_components; member++) { - assert(member < glsl_get_length(var_deref->type)); - nir_deref_instr *strct = nir_build_deref_struct(&b, var_deref, member); - nir_ssa_def *load = nir_load_deref(&b, strct); - unsigned incr = MIN2(remaining, 4); - /* repack the loads to 64bit */ - for (unsigned c = 0; c < incr / 2; c++, comp_idx++) - comp[comp_idx] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_RANGE(c * 2, 2))); - remaining -= incr; - i += incr; - } - dest = dests[idx] = nir_vec(&b, comp, intr->num_components); - if (idx < cols - 1) - nir_push_else(&b, NULL); - } - /* loop over all the if blocks that were made, pop them, and phi the loaded+packed results */ - for (unsigned idx = cols - 1; idx >= 1; idx--) { - nir_pop_if(&b, NULL); - dest = nir_if_phi(&b, dests[idx - 1], dest); - } - _mesa_set_add(deletes, &deref->instr); - } else if (num_components <= 4) { - /* simple load case */ - nir_ssa_def *load = nir_load_deref(&b, deref); - /* pack 32bit loads into 64bit: this will automagically get optimized out later */ - for (unsigned i = 0; i < intr->num_components; i++) { - comp[i] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_RANGE(i * 2, 2))); - } - dest = nir_vec(&b, comp, intr->num_components); - } else { - /* writing > 4 components: access the struct and load the appropriate vec4 members */ - for (unsigned i = 0; i < 2; i++, num_components -= 4) { - nir_deref_instr *strct = nir_build_deref_struct(&b, deref, i); - nir_ssa_def *load = nir_load_deref(&b, strct); - comp[i * 2] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_MASK(2))); - if (num_components > 2) - comp[i * 2 + 1] = nir_pack_64_2x32(&b, nir_channels(&b, load, BITFIELD_RANGE(2, 2))); - } - dest = nir_vec(&b, comp, intr->num_components); - } - nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, dest, instr); - } - _mesa_set_add(deletes, instr); - break; - } - break; - default: break; - } - } - } - if (func_progress) - nir_metadata_preserve(function->impl, nir_metadata_none); - /* derefs must be queued for deletion to avoid deleting the same deref repeatedly */ - set_foreach_remove(deletes, he) - nir_instr_remove((void*)he->key); - } - progress = true; - } + nir_foreach_variable_with_modes(var, shader, nir_var_shader_in | nir_var_shader_out) + progress |= lower_64bit_vars_loop(shader, var, derefs, deletes); ralloc_free(deletes); ralloc_free(derefs); if (progress) {