diff --git a/src/gallium/frontends/clover/nir/invocation.cpp b/src/gallium/frontends/clover/nir/invocation.cpp index d493201eb54..c0bf64652f6 100644 --- a/src/gallium/frontends/clover/nir/invocation.cpp +++ b/src/gallium/frontends/clover/nir/invocation.cpp @@ -76,6 +76,7 @@ static void debug_function(void *private_data, struct clover_lower_nir_state { std::vector &args; uint32_t global_dims; + nir_variable *constant_var; nir_variable *offset_vars[3]; }; @@ -123,15 +124,34 @@ clover_lower_nir_instr(nir_builder *b, nir_instr *instr, void *_state) return nir_u2u(b, nir_vec(b, loads, state->global_dims), nir_dest_bit_size(intrinsic->dest)); } + case nir_intrinsic_load_constant_base_ptr: { + return nir_load_var(b, state->constant_var); + } + default: return NULL; } } static bool -clover_lower_nir(nir_shader *nir, std::vector &args, uint32_t dims) +clover_lower_nir(nir_shader *nir, std::vector &args, + uint32_t dims, uint32_t pointer_bit_size) { - clover_lower_nir_state state = { args, dims }; + nir_variable *constant_var = NULL; + if (nir->constant_data_size) { + const glsl_type *type = pointer_bit_size == 64 ? glsl_uint64_t_type() : glsl_uint_type(); + + constant_var = nir_variable_create(nir, nir_var_uniform, type, + "constant_buffer_addr"); + constant_var->data.location = args.size(); + + args.emplace_back(module::argument::global, + pointer_bit_size / 8, pointer_bit_size / 8, pointer_bit_size / 8, + module::argument::zero_ext, + module::argument::constant_buffer); + } + + clover_lower_nir_state state = { args, dims, constant_var }; return nir_shader_lower_instructions(nir, clover_lower_nir_filter, clover_lower_nir_instr, &state); } @@ -306,11 +326,16 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev, sysval_options.has_base_global_invocation_id = true; NIR_PASS_V(nir, nir_lower_compute_system_values, &sysval_options); - auto args = sym.args; - NIR_PASS_V(nir, clover_lower_nir, args, dev.max_block_size().size()); - + NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_mem_constant, NULL); NIR_PASS_V(nir, nir_lower_mem_constant_vars, glsl_get_cl_type_size_align); + NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant, + spirv_options.constant_addr_format); + + auto args = sym.args; + NIR_PASS_V(nir, clover_lower_nir, args, dev.max_block_size().size(), + dev.address_bits()); + NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_uniform | nir_var_mem_shared | nir_var_mem_global | nir_var_function_temp, @@ -342,6 +367,19 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev, NIR_PASS_V(nir, nir_opt_dce); + if (nir->constant_data_size) { + const char *ptr = reinterpret_cast(nir->constant_data); + const module::section constants { + section_id, + module::section::data_constant, + nir->constant_data_size, + { ptr, ptr + nir->constant_data_size } + }; + nir->constant_data = NULL; + nir->constant_data_size = 0; + m.secs.push_back(constants); + } + struct blob blob; blob_init(&blob); nir_serialize(&blob, nir, false);