diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index 5ff7dc68205..20ae7203489 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -1564,13 +1564,13 @@ zink_compiler_assign_io(nir_shader *producer, nir_shader *consumer)
 
 /* all types that hit this function contain something that is 64bit */
 static const struct glsl_type *
-rewrite_64bit_type(nir_shader *nir, const struct glsl_type *type, nir_variable *var)
+rewrite_64bit_type(nir_shader *nir, const struct glsl_type *type, nir_variable *var, bool doubles_only)
 {
    if (glsl_type_is_array(type)) {
       const struct glsl_type *child = glsl_get_array_element(type);
       unsigned elements = glsl_array_size(type);
       unsigned stride = glsl_get_explicit_stride(type);
-      return glsl_array_type(rewrite_64bit_type(nir, child, var), elements, stride);
+      return glsl_array_type(rewrite_64bit_type(nir, child, var, doubles_only), elements, stride);
    }
    /* rewrite structs recursively */
    if (glsl_type_is_struct_or_ifc(type)) {
@@ -1582,15 +1582,18 @@ rewrite_64bit_type(nir_shader *nir, const struct glsl_type *type, nir_variable *
          fields[i] = *f;
          xfb_offset += glsl_get_component_slots(fields[i].type) * 4;
          if (i < nmembers - 1 && xfb_offset % 8 &&
-             glsl_type_contains_64bit(glsl_get_struct_field(type, i + 1))) {
+             (glsl_contains_double(glsl_get_struct_field(type, i + 1)) ||
+              (glsl_type_contains_64bit(glsl_get_struct_field(type, i + 1)) && !doubles_only))) {
             var->data.is_xfb = true;
          }
-         fields[i].type = rewrite_64bit_type(nir, f->type, var);
+         fields[i].type = rewrite_64bit_type(nir, f->type, var, doubles_only);
       }
       return glsl_struct_type(fields, nmembers, glsl_get_type_name(type), glsl_struct_type_is_packed(type));
    }
-   if (!glsl_type_is_64bit(type))
+   if (!glsl_type_is_64bit(type) || (!glsl_contains_double(type) && doubles_only))
       return type;
+   if (doubles_only && glsl_type_is_vector_or_scalar(type))
+      return glsl_vector_type(GLSL_TYPE_UINT64, glsl_get_vector_elements(type));
    enum glsl_base_type base_type;
    switch (glsl_get_base_type(type)) {
    case GLSL_TYPE_UINT64:
@@ -1646,7 +1649,8 @@ deref_is_matrix(nir_deref_instr *deref)
 }
 
 static bool
-lower_64bit_vars_function(nir_shader *shader, nir_function *function, nir_variable *var, struct hash_table *derefs, struct set *deletes)
+lower_64bit_vars_function(nir_shader *shader, nir_function *function, nir_variable *var,
+                          struct hash_table *derefs, struct set *deletes, bool doubles_only)
 {
    bool func_progress = false;
    if (!function->impl)
@@ -1679,7 +1683,7 @@ lower_64bit_vars_function(nir_shader *shader, nir_function *function, nir_variab
             if (deref->deref_type == nir_deref_type_var)
                deref->type = var->type;
             else
-               deref->type = rewrite_64bit_type(shader, deref->type, var);
+               deref->type = rewrite_64bit_type(shader, deref->type, var, doubles_only);
          }
          break;
          case nir_instr_type_intrinsic: {
@@ -1699,6 +1703,8 @@ lower_64bit_vars_function(nir_shader *shader, nir_function *function, nir_variab
             /* this is the stored matrix type from the deref */
             struct hash_entry *he = _mesa_hash_table_search(derefs, deref);
             const struct glsl_type *matrix = he ? he->data : NULL;
+            if (doubles_only && !matrix)
+               break;
             func_progress = true;
             if (intr->intrinsic == nir_intrinsic_store_deref) {
                /* first, unpack the src data to 32bit vec2 components */
@@ -1861,32 +1867,33 @@ lower_64bit_vars_function(nir_shader *shader, nir_function *function, nir_variab
 }
 
 static bool
-lower_64bit_vars_loop(nir_shader *shader, nir_variable *var, struct hash_table *derefs, struct set *deletes)
+lower_64bit_vars_loop(nir_shader *shader, nir_variable *var, struct hash_table *derefs,
+                      struct set *deletes, bool doubles_only)
 {
-   if (!glsl_type_contains_64bit(var->type))
+   if (!glsl_type_contains_64bit(var->type) || (doubles_only && !glsl_contains_double(var->type)))
       return false;
-   var->type = rewrite_64bit_type(shader, var->type, var);
+   var->type = rewrite_64bit_type(shader, var->type, var, doubles_only);
    /* once type is rewritten, rewrite all loads and stores */
    nir_foreach_function(function, shader)
-      lower_64bit_vars_function(shader, function, var, derefs, deletes);
+      lower_64bit_vars_function(shader, function, var, derefs, deletes, doubles_only);
    return true;
 }
 
 /* rewrite all input/output variables using 32bit types and load/stores */
 static bool
-lower_64bit_vars(nir_shader *shader)
+lower_64bit_vars(nir_shader *shader, bool doubles_only)
 {
    bool progress = false;
    struct hash_table *derefs = _mesa_hash_table_create(NULL, _mesa_hash_pointer, _mesa_key_pointer_equal);
    struct set *deletes = _mesa_set_create(NULL, _mesa_hash_pointer, _mesa_key_pointer_equal);
    nir_foreach_variable_with_modes(var, shader, nir_var_shader_in | nir_var_shader_out)
-      progress |= lower_64bit_vars_loop(shader, var, derefs, deletes);
+      progress |= lower_64bit_vars_loop(shader, var, derefs, deletes, doubles_only);
    nir_foreach_function(function, shader) {
       nir_foreach_function_temp_variable(var, function->impl) {
-         if (!glsl_type_contains_64bit(var->type))
+         if (!glsl_type_contains_64bit(var->type) || (doubles_only && !glsl_contains_double(var->type)))
            continue;
-         var->type = rewrite_64bit_type(shader, var->type, var);
-         progress |= lower_64bit_vars_function(shader, function, var, derefs, deletes);
+         var->type = rewrite_64bit_type(shader, var->type, var, doubles_only);
+         progress |= lower_64bit_vars_function(shader, function, var, derefs, deletes, doubles_only);
       }
    }
    ralloc_free(deletes);
@@ -3256,7 +3263,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    ret->bindless |= bindless_lowered;
 
    if (!screen->info.feats.features.shaderInt64 || !screen->info.feats.features.shaderFloat64)
-      NIR_PASS_V(nir, lower_64bit_vars);
+      NIR_PASS_V(nir, lower_64bit_vars, screen->info.feats.features.shaderInt64);
    NIR_PASS_V(nir, match_tex_dests);
 
    ret->nir = nir;
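
Below is a minimal standalone sketch (not from the Mesa tree) of the gating rule the new doubles_only parameter adds: when the device supports shaderInt64 but not shaderFloat64, only variables whose types actually contain a double are rewritten, and pure 64-bit integer types are left alone. The toy_type struct and needs_lowering helper are hypothetical stand-ins for the real glsl_type queries (glsl_type_contains_64bit / glsl_contains_double) used in the diff.

/* Toy model of the doubles_only predicate from lower_64bit_vars_loop. */
#include <stdbool.h>
#include <stdio.h>

struct toy_type {
   bool contains_64bit;   /* any 64-bit component (int64 or double) */
   bool contains_double;  /* at least one double component */
};

static bool
needs_lowering(const struct toy_type *t, bool doubles_only)
{
   if (!t->contains_64bit)
      return false;                 /* nothing 64-bit: never lowered */
   if (doubles_only && !t->contains_double)
      return false;                 /* int64 is natively supported: skip */
   return true;                     /* contains a double, or full 64-bit lowering */
}

int
main(void)
{
   const struct toy_type dvec2   = { .contains_64bit = true, .contains_double = true  };
   const struct toy_type i64vec2 = { .contains_64bit = true, .contains_double = false };

   /* doubles_only == true models shaderInt64 && !shaderFloat64 */
   printf("dvec2 lowered: %d, i64vec2 lowered: %d\n",
          needs_lowering(&dvec2, true), needs_lowering(&i64vec2, true));
   return 0;
}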