diff --git a/src/gallium/auxiliary/gallivm/lp_bld_gather.c b/src/gallium/auxiliary/gallivm/lp_bld_gather.c index 2f2506803cf..b93251b4444 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_gather.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_gather.c @@ -598,3 +598,50 @@ lp_build_gather_values(struct gallivm_state * gallivm, } return vec; } + +LLVMValueRef +lp_build_masked_gather(struct gallivm_state *gallivm, + unsigned length, + unsigned bit_size, + LLVMTypeRef vec_type, + LLVMValueRef offset_ptr, + LLVMValueRef exec_mask) +{ + LLVMBuilderRef builder = gallivm->builder; + LLVMValueRef args[4]; + char intrin_name[64]; + + snprintf(intrin_name, 64, "llvm.masked.gather.v%ui%u.v%up0i%u", + length, bit_size, length, bit_size); + args[0] = offset_ptr; + args[1] = lp_build_const_int32(gallivm, bit_size / 8); + args[2] = LLVMBuildICmp(builder, LLVMIntNE, exec_mask, + LLVMConstNull(LLVMTypeOf(exec_mask)), ""); + args[3] = LLVMConstNull(vec_type); + return lp_build_intrinsic(builder, intrin_name, vec_type, + args, 4, 0); + +} + +void +lp_build_masked_scatter(struct gallivm_state *gallivm, + unsigned length, + unsigned bit_size, + LLVMValueRef offset_ptr, + LLVMValueRef value_vec, + LLVMValueRef exec_mask) +{ + LLVMBuilderRef builder = gallivm->builder; + LLVMValueRef args[4]; + char intrin_name[64]; + + snprintf(intrin_name, 64, "llvm.masked.scatter.v%ui%u.v%up0i%u", + length, bit_size, length, bit_size); + args[0] = value_vec; + args[1] = offset_ptr; + args[2] = lp_build_const_int32(gallivm, bit_size / 8); + args[3] = LLVMBuildICmp(builder, LLVMIntNE, exec_mask, + LLVMConstNull(LLVMTypeOf(exec_mask)), ""); + lp_build_intrinsic(builder, intrin_name, LLVMVoidTypeInContext(gallivm->context), + args, 4, 0); +} diff --git a/src/gallium/auxiliary/gallivm/lp_bld_gather.h b/src/gallium/auxiliary/gallivm/lp_bld_gather.h index 7930864e611..5fabed956ca 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_gather.h +++ b/src/gallium/auxiliary/gallivm/lp_bld_gather.h @@ -66,4 +66,20 @@ lp_build_gather_values(struct gallivm_state * gallivm, LLVMValueRef * values, unsigned value_count); +LLVMValueRef +lp_build_masked_gather(struct gallivm_state *gallivm, + unsigned length, + unsigned bit_size, + LLVMTypeRef vec_type, + LLVMValueRef offset_ptr, + LLVMValueRef exec_mask); + +void +lp_build_masked_scatter(struct gallivm_state *gallivm, + unsigned length, + unsigned bit_size, + LLVMValueRef offset_ptr, + LLVMValueRef value_vec, + LLVMValueRef exec_mask); + #endif /* LP_BLD_GATHER_H_ */ diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c index 443bf4fea8a..57c953a8d3b 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c @@ -822,6 +822,41 @@ static LLVMValueRef global_addr_to_ptr(struct gallivm_state *gallivm, LLVMValueR return addr_ptr; } +static LLVMValueRef global_addr_to_ptr_vec(struct gallivm_state *gallivm, LLVMValueRef addr_ptr, unsigned length, unsigned bit_size) +{ + LLVMBuilderRef builder = gallivm->builder; + switch (bit_size) { + case 8: + addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), length), ""); + break; + case 16: + addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt16TypeInContext(gallivm->context), 0), length), ""); + break; + case 32: + default: + addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt32TypeInContext(gallivm->context), 0), length), ""); + break; + case 64: + addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt64TypeInContext(gallivm->context), 0), length), ""); + break; + } + return addr_ptr; +} + +static LLVMValueRef lp_vec_add_offset_ptr(struct lp_build_nir_context *bld_base, + unsigned bit_size, + LLVMValueRef ptr, + LLVMValueRef offset) +{ + struct gallivm_state *gallivm = bld_base->base.gallivm; + LLVMBuilderRef builder = gallivm->builder; + struct lp_build_context *uint_bld = &bld_base->uint_bld; + LLVMValueRef result = LLVMBuildPtrToInt(builder, ptr, bld_base->uint64_bld.vec_type, ""); + offset = LLVMBuildZExt(builder, offset, bld_base->uint64_bld.vec_type, ""); + result = LLVMBuildAdd(builder, offset, result, ""); + return global_addr_to_ptr_vec(gallivm, result, uint_bld->type.length, bit_size); +} + static void emit_load_global(struct lp_build_nir_context *bld_base, unsigned nc, unsigned bit_size, @@ -855,30 +890,14 @@ static void emit_load_global(struct lp_build_nir_context *bld_base, } for (unsigned c = 0; c < nc; c++) { - LLVMValueRef result = lp_build_alloca(gallivm, res_bld->vec_type, ""); - struct lp_build_loop_state loop_state; - lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0)); + LLVMValueRef chan_offset = lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8)); - struct lp_build_if_state ifthen; - LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, ""); - cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, ""); - lp_build_if(&ifthen, gallivm, cond); - - LLVMValueRef addr_ptr = LLVMBuildExtractElement(gallivm->builder, addr, - loop_state.counter, ""); - addr_ptr = global_addr_to_ptr(gallivm, addr_ptr, bit_size); - - LLVMValueRef value_ptr = lp_build_pointer_get2(builder, res_bld->elem_type, - addr_ptr, lp_build_const_int32(gallivm, c)); - - LLVMValueRef temp_res; - temp_res = LLVMBuildLoad2(builder, res_bld->vec_type, result, ""); - temp_res = LLVMBuildInsertElement(builder, temp_res, value_ptr, loop_state.counter, ""); - LLVMBuildStore(builder, temp_res, result); - lp_build_endif(&ifthen); - lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length), - NULL, LLVMIntUGE); - outval[c] = LLVMBuildLoad2(builder, res_bld->vec_type, result, ""); + outval[c] = lp_build_masked_gather(gallivm, res_bld->type.length, + bit_size, + res_bld->vec_type, + lp_vec_add_offset_ptr(bld_base, bit_size, addr, chan_offset), + exec_mask); + outval[c] = LLVMBuildBitCast(builder, outval[c], res_bld->vec_type, ""); } } @@ -898,40 +917,14 @@ static void emit_store_global(struct lp_build_nir_context *bld_base, if (!(writemask & (1u << c))) continue; LLVMValueRef val = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, ""); + LLVMValueRef chan_offset = lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8)); - struct lp_build_loop_state loop_state; - lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0)); - LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val, - loop_state.counter, ""); - - LLVMValueRef addr_ptr = LLVMBuildExtractElement(gallivm->builder, addr, - loop_state.counter, ""); - addr_ptr = global_addr_to_ptr(gallivm, addr_ptr, bit_size); - switch (bit_size) { - case 8: - value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt8TypeInContext(gallivm->context), ""); - break; - case 16: - value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt16TypeInContext(gallivm->context), ""); - break; - case 32: - value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt32TypeInContext(gallivm->context), ""); - break; - case 64: - value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt64TypeInContext(gallivm->context), ""); - break; - default: - break; - } - struct lp_build_if_state ifthen; - - LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, ""); - cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, ""); - lp_build_if(&ifthen, gallivm, cond); - lp_build_pointer_set(builder, addr_ptr, lp_build_const_int32(gallivm, c), value_ptr); - lp_build_endif(&ifthen); - lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length), - NULL, LLVMIntUGE); + struct lp_build_context *out_bld = get_int_bld(bld_base, false, bit_size); + val = LLVMBuildBitCast(builder, val, out_bld->vec_type, ""); + lp_build_masked_scatter(gallivm, out_bld->type.length, bit_size, + lp_vec_add_offset_ptr(bld_base, bit_size, + addr, chan_offset), + val, exec_mask); } } @@ -2616,46 +2609,25 @@ emit_load_scratch(struct lp_build_nir_context *bld_base, struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; struct lp_build_context *uint_bld = &bld_base->uint_bld; struct lp_build_context *load_bld; - LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);; - uint32_t shift_val = bit_size_to_shift_size(bit_size); + LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size); LLVMValueRef exec_mask = mask_vec(bld_base); - + LLVMValueRef scratch_ptr_vec = lp_build_broadcast(gallivm, + LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), uint_bld->type.length), + bld->scratch_ptr); load_bld = get_int_bld(bld_base, true, bit_size); offset = lp_build_add(uint_bld, offset, thread_offsets); - offset = lp_build_shr_imm(uint_bld, offset, shift_val); + for (unsigned c = 0; c < nc; c++) { - LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c)); + LLVMValueRef chan_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8))); - LLVMValueRef result = lp_build_alloca(gallivm, load_bld->vec_type, ""); - struct lp_build_loop_state loop_state; - lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0)); - - struct lp_build_if_state ifthen; - LLVMValueRef cond, temp_res; - - loop_index = LLVMBuildExtractElement(gallivm->builder, loop_index, - loop_state.counter, ""); - cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, ""); - cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, ""); - - lp_build_if(&ifthen, gallivm, cond); - LLVMValueRef scalar; - LLVMValueRef ptr2 = LLVMBuildBitCast(builder, bld->scratch_ptr, LLVMPointerType(load_bld->elem_type, 0), ""); - scalar = lp_build_pointer_get2(builder, load_bld->elem_type, ptr2, loop_index); - - temp_res = LLVMBuildLoad2(builder, load_bld->vec_type, result, ""); - temp_res = LLVMBuildInsertElement(builder, temp_res, scalar, loop_state.counter, ""); - LLVMBuildStore(builder, temp_res, result); - lp_build_else(&ifthen); - temp_res = LLVMBuildLoad2(builder, load_bld->vec_type, result, ""); - LLVMValueRef zero = lp_build_zero_bits(gallivm, bit_size, false); - temp_res = LLVMBuildInsertElement(builder, temp_res, zero, loop_state.counter, ""); - LLVMBuildStore(builder, temp_res, result); - lp_build_endif(&ifthen); - lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length), - NULL, LLVMIntUGE); - outval[c] = LLVMBuildLoad2(gallivm->builder, load_bld->vec_type, result, ""); + outval[c] = lp_build_masked_gather(gallivm, load_bld->type.length, bit_size, + load_bld->vec_type, + lp_vec_add_offset_ptr(bld_base, bit_size, + scratch_ptr_vec, + chan_offset), + exec_mask); + outval[c] = LLVMBuildBitCast(builder, outval[c], load_bld->vec_type, ""); } } @@ -2670,43 +2642,28 @@ emit_store_scratch(struct lp_build_nir_context *bld_base, struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; struct lp_build_context *uint_bld = &bld_base->uint_bld; struct lp_build_context *store_bld; - LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);; - uint32_t shift_val = bit_size_to_shift_size(bit_size); + LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size); + LLVMValueRef scratch_ptr_vec = lp_build_broadcast(gallivm, + LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), uint_bld->type.length), + bld->scratch_ptr); store_bld = get_int_bld(bld_base, true, bit_size); LLVMValueRef exec_mask = mask_vec(bld_base); offset = lp_build_add(uint_bld, offset, thread_offsets); - offset = lp_build_shr_imm(uint_bld, offset, shift_val); for (unsigned c = 0; c < nc; c++) { if (!(writemask & (1u << c))) continue; LLVMValueRef val = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, ""); - LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c)); - struct lp_build_loop_state loop_state; - lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0)); + LLVMValueRef chan_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8))); - LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val, - loop_state.counter, ""); - value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, store_bld->elem_type, ""); + val = LLVMBuildBitCast(builder, val, store_bld->vec_type, ""); - struct lp_build_if_state ifthen; - LLVMValueRef cond; - - loop_index = LLVMBuildExtractElement(gallivm->builder, loop_index, - loop_state.counter, ""); - - cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, ""); - cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, ""); - lp_build_if(&ifthen, gallivm, cond); - - LLVMValueRef ptr2 = LLVMBuildBitCast(builder, bld->scratch_ptr, LLVMPointerType(store_bld->elem_type, 0), ""); - lp_build_pointer_set(builder, ptr2, loop_index, value_ptr); - - lp_build_endif(&ifthen); - lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length), - NULL, LLVMIntUGE); + lp_build_masked_scatter(gallivm, store_bld->type.length, bit_size, + lp_vec_add_offset_ptr(bld_base, bit_size, + scratch_ptr_vec, chan_offset), + val, exec_mask); } }