diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c b/src/gallium/auxiliary/gallivm/lp_bld_nir.c index 5292798fdd2..3cf44bfa7e1 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c @@ -1554,8 +1554,10 @@ visit_load_ssbo(struct lp_build_nir_context *bld_base, nir_intrinsic_instr *instr, LLVMValueRef result[NIR_MAX_VEC_COMPONENTS]) { - LLVMValueRef idx = cast_type(bld_base, get_src(bld_base, instr->src[0]), - nir_type_uint, 32); + LLVMValueRef idx = get_src(bld_base, instr->src[0]); + if (nir_src_num_components(instr->src[0]) == 1) + idx = cast_type(bld_base, idx, nir_type_uint, 32); + LLVMValueRef offset = get_src(bld_base, instr->src[1]); bool index_and_offset_are_uniform = nir_src_is_always_uniform(instr->src[0]) && @@ -1571,8 +1573,11 @@ visit_store_ssbo(struct lp_build_nir_context *bld_base, nir_intrinsic_instr *instr) { LLVMValueRef val = get_src(bld_base, instr->src[0]); - LLVMValueRef idx = cast_type(bld_base, get_src(bld_base, instr->src[1]), - nir_type_uint, 32); + + LLVMValueRef idx = get_src(bld_base, instr->src[1]); + if (nir_src_num_components(instr->src[1]) == 1) + idx = cast_type(bld_base, idx, nir_type_uint, 32); + LLVMValueRef offset = get_src(bld_base, instr->src[2]); bool index_and_offset_are_uniform = nir_src_is_always_uniform(instr->src[1]) && @@ -1590,9 +1595,10 @@ visit_get_ssbo_size(struct lp_build_nir_context *bld_base, nir_intrinsic_instr *instr, LLVMValueRef result[NIR_MAX_VEC_COMPONENTS]) { - LLVMValueRef idx = cast_type(bld_base, - get_src(bld_base, instr->src[0]), - nir_type_uint, 32); + LLVMValueRef idx = get_src(bld_base, instr->src[0]); + if (nir_src_num_components(instr->src[0]) == 1) + idx = cast_type(bld_base, idx, nir_type_uint, 32); + result[0] = bld_base->get_ssbo_size(bld_base, idx); } @@ -1602,8 +1608,10 @@ visit_ssbo_atomic(struct lp_build_nir_context *bld_base, nir_intrinsic_instr *instr, LLVMValueRef result[NIR_MAX_VEC_COMPONENTS]) { - LLVMValueRef idx = cast_type(bld_base, get_src(bld_base, instr->src[0]), - nir_type_uint, 32); + LLVMValueRef idx = get_src(bld_base, instr->src[0]); + if (nir_src_num_components(instr->src[0]) == 1) + idx = cast_type(bld_base, idx, nir_type_uint, 32); + LLVMValueRef offset = get_src(bld_base, instr->src[1]); LLVMValueRef val = get_src(bld_base, instr->src[2]); LLVMValueRef val2 = NULL; diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c index 7d11b1d98e1..437a24422d2 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c @@ -1312,9 +1312,30 @@ ssbo_base_pointer(struct lp_build_nir_context *bld_base, struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; uint32_t shift_val = bit_size_to_shift_size(bit_size); - LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, invocation, ""); - LLVMValueRef ssbo_size_ptr = lp_llvm_buffer_num_elements(gallivm, bld->ssbo_ptr, ssbo_idx, LP_MAX_TGSI_SHADER_BUFFERS); - LLVMValueRef ssbo_ptr = lp_llvm_buffer_base(gallivm, bld->ssbo_ptr, ssbo_idx, LP_MAX_TGSI_SHADER_BUFFERS); + LLVMValueRef ssbo_idx; + LLVMValueRef buffers; + uint32_t buffers_limit; + if (LLVMGetTypeKind(LLVMTypeOf(index)) == LLVMArrayTypeKind) { + LLVMValueRef set = LLVMBuildExtractValue(gallivm->builder, index, 0, ""); + set = LLVMBuildExtractElement(gallivm->builder, set, invocation, ""); + + LLVMValueRef binding = LLVMBuildExtractValue(gallivm->builder, index, 1, ""); + binding = LLVMBuildExtractElement(gallivm->builder, binding, invocation, ""); + + LLVMValueRef components[2] = { set, binding }; + ssbo_idx = lp_nir_array_build_gather_values(gallivm->builder, components, 2); + + buffers = bld->consts_ptr; + buffers_limit = LP_MAX_TGSI_CONST_BUFFERS; + } else { + ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, invocation, ""); + + buffers = bld->ssbo_ptr; + buffers_limit = LP_MAX_TGSI_SHADER_BUFFERS; + } + + LLVMValueRef ssbo_size_ptr = lp_llvm_buffer_num_elements(gallivm, buffers, ssbo_idx, buffers_limit); + LLVMValueRef ssbo_ptr = lp_llvm_buffer_base(gallivm, buffers, ssbo_idx, buffers_limit); if (bounds) *bounds = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); @@ -1667,17 +1688,12 @@ static void emit_barrier(struct lp_build_nir_context *bld_base) static LLVMValueRef emit_get_ssbo_size(struct lp_build_nir_context *bld_base, LLVMValueRef index) { - struct gallivm_state *gallivm = bld_base->base.gallivm; - struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; - LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder; struct lp_build_context *bld_broad = &bld_base->uint_bld; - LLVMValueRef ssbo_index = LLVMBuildExtractElement(builder, index, - first_active_invocation(bld_base), ""); - LLVMValueRef size_ptr = lp_llvm_buffer_num_elements(gallivm, bld->ssbo_ptr, - ssbo_index, - LP_MAX_TGSI_SHADER_BUFFERS); - return lp_build_broadcast_scalar(bld_broad, size_ptr); + LLVMValueRef size; + ssbo_base_pointer(bld_base, 8, index, first_active_invocation(bld_base), &size); + + return lp_build_broadcast_scalar(bld_broad, size); } static void emit_image_op(struct lp_build_nir_context *bld_base,