diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c index 2b10baba5ae..652b160db2d 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c @@ -35,6 +35,22 @@ #include "lp_bld_coro.h" #include "lp_bld_printf.h" #include "util/u_math.h" + +static int bit_size_to_shift_size(int bit_size) +{ + switch (bit_size) { + case 64: + return 3; + default: + case 32: + return 2; + case 16: + return 1; + case 8: + return 0; + } +} + /* * combine the execution mask if there is one with the current mask. */ @@ -709,14 +725,8 @@ static void emit_load_kernel_arg(struct lp_build_nir_context *bld_base, LLVMBuilderRef builder = gallivm->builder; struct lp_build_context *bld_broad = get_int_bld(bld_base, true, bit_size); LLVMValueRef kernel_args_ptr = bld->kernel_args_ptr; - unsigned size_shift = 0; + unsigned size_shift = bit_size_to_shift_size(bit_size); struct lp_build_context *bld_offset = get_int_bld(bld_base, true, offset_bit_size); - if (bit_size == 16) - size_shift = 1; - else if (bit_size == 32) - size_shift = 2; - else if (bit_size == 64) - size_shift = 3; if (size_shift) offset = lp_build_shr(bld_offset, offset, lp_build_const_int_vec(gallivm, bld_offset->type, size_shift)); @@ -945,11 +955,7 @@ static void emit_load_ubo(struct lp_build_nir_context *bld_base, struct lp_build_context *uint_bld = &bld_base->uint_bld; struct lp_build_context *bld_broad = bit_size == 64 ? &bld_base->dbl_bld : &bld_base->base; LLVMValueRef consts_ptr = lp_build_array_get(gallivm, bld->consts_ptr, index); - unsigned size_shift = 0; - if (bit_size == 32) - size_shift = 2; - else if (bit_size == 64) - size_shift = 3; + unsigned size_shift = bit_size_to_shift_size(bit_size); if (size_shift) offset = lp_build_shr(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, size_shift)); if (bit_size == 64) { @@ -993,19 +999,22 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder; LLVMValueRef ssbo_ptr = NULL; struct lp_build_context *uint_bld = &bld_base->uint_bld; - struct lp_build_context *uint64_bld = &bld_base->uint64_bld; LLVMValueRef ssbo_limit = NULL; + struct lp_build_context *load_bld; + uint32_t shift_val = bit_size_to_shift_size(bit_size); + + load_bld = get_int_bld(bld_base, true, bit_size); if (index) { LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, LLVMBuildExtractElement(builder, index, lp_build_const_int32(gallivm, 0), "")); - ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, bit_size == 64 ? 3 : 2), ""); + ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); ssbo_limit = lp_build_broadcast_scalar(uint_bld, ssbo_limit); ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, LLVMBuildExtractElement(builder, index, lp_build_const_int32(gallivm, 0), "")); } else ssbo_ptr = bld->shared_ptr; - offset = LLVMBuildAShr(gallivm->builder, offset, lp_build_const_int_vec(gallivm, uint_bld->type, bit_size == 64 ? 3 : 2), ""); + offset = LLVMBuildAShr(gallivm->builder, offset, lp_build_const_int_vec(gallivm, uint_bld->type, shift_val), ""); for (unsigned c = 0; c < nc; c++) { LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c)); LLVMValueRef exec_mask = mask_vec(bld_base); @@ -1015,7 +1024,7 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, exec_mask = LLVMBuildAnd(builder, exec_mask, ssbo_oob_cmp, ""); } - LLVMValueRef result = lp_build_alloca(gallivm, bit_size == 64 ? uint64_bld->vec_type : uint_bld->vec_type, ""); + LLVMValueRef result = lp_build_alloca(gallivm, load_bld->vec_type, ""); struct lp_build_loop_state loop_state; lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0)); @@ -1030,8 +1039,8 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, lp_build_if(&ifthen, gallivm, cond); LLVMValueRef scalar; - if (bit_size == 64) { - LLVMValueRef ssbo_ptr2 = LLVMBuildBitCast(builder, ssbo_ptr, LLVMPointerType(uint64_bld->elem_type, 0), ""); + if (bit_size != 32) { + LLVMValueRef ssbo_ptr2 = LLVMBuildBitCast(builder, ssbo_ptr, LLVMPointerType(load_bld->elem_type, 0), ""); scalar = lp_build_pointer_get(builder, ssbo_ptr2, loop_index); } else scalar = lp_build_pointer_get(builder, ssbo_ptr, loop_index); @@ -1044,6 +1053,10 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, LLVMValueRef zero; if (bit_size == 64) zero = LLVMConstInt(LLVMInt64TypeInContext(gallivm->context), 0, 0); + else if (bit_size == 16) + zero = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 0, 0); + else if (bit_size == 8) + zero = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 0, 0); else zero = lp_build_const_int32(gallivm, 0); temp_res = LLVMBuildInsertElement(builder, temp_res, zero, loop_state.counter, ""); @@ -1069,16 +1082,19 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base, LLVMValueRef ssbo_ptr; struct lp_build_context *uint_bld = &bld_base->uint_bld; LLVMValueRef ssbo_limit = NULL; + struct lp_build_context *store_bld; + uint32_t shift_val = bit_size_to_shift_size(bit_size); + store_bld = get_int_bld(bld_base, true, bit_size); if (index) { LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, LLVMBuildExtractElement(builder, index, lp_build_const_int32(gallivm, 0), "")); - ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, bit_size == 64 ? 3 : 2), ""); + ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); ssbo_limit = lp_build_broadcast_scalar(uint_bld, ssbo_limit); ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, LLVMBuildExtractElement(builder, index, lp_build_const_int32(gallivm, 0), "")); } else ssbo_ptr = bld->shared_ptr; - offset = lp_build_shr_imm(uint_bld, offset, bit_size == 64 ? 3 : 2); + offset = lp_build_shr_imm(uint_bld, offset, shift_val); for (unsigned c = 0; c < nc; c++) { if (!(writemask & (1u << c))) continue; @@ -1095,10 +1111,7 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base, lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0)); LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val, loop_state.counter, ""); - if (bit_size == 64) - value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, bld_base->uint64_bld.elem_type, ""); - else - value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, uint_bld->elem_type, ""); + value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, store_bld->elem_type, ""); struct lp_build_if_state ifthen; LLVMValueRef cond; @@ -1107,8 +1120,8 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base, cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, ""); cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, ""); lp_build_if(&ifthen, gallivm, cond); - if (bit_size == 64) { - LLVMValueRef ssbo_ptr2 = LLVMBuildBitCast(builder, ssbo_ptr, LLVMPointerType(bld_base->uint64_bld.elem_type, 0), ""); + if (bit_size != 32) { + LLVMValueRef ssbo_ptr2 = LLVMBuildBitCast(builder, ssbo_ptr, LLVMPointerType(store_bld->elem_type, 0), ""); lp_build_pointer_set(builder, ssbo_ptr2, loop_index, value_ptr); } else lp_build_pointer_set(builder, ssbo_ptr, loop_index, value_ptr);