gallivm: use masked intrinsics for global and scratch access.

This seems to improve luxmark scores for me on the luxball scene
from numbers in the 4-500 range to 5-700 range.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Tested-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19736>
This commit is contained in:
Dave Airlie 2022-11-14 18:00:10 +10:00 committed by Marge Bot
parent fda262fe64
commit 442d1fe5ad
3 changed files with 135 additions and 115 deletions

View file

@ -598,3 +598,50 @@ lp_build_gather_values(struct gallivm_state * gallivm,
}
return vec;
}
LLVMValueRef
lp_build_masked_gather(struct gallivm_state *gallivm,
unsigned length,
unsigned bit_size,
LLVMTypeRef vec_type,
LLVMValueRef offset_ptr,
LLVMValueRef exec_mask)
{
LLVMBuilderRef builder = gallivm->builder;
LLVMValueRef args[4];
char intrin_name[64];
snprintf(intrin_name, 64, "llvm.masked.gather.v%ui%u.v%up0i%u",
length, bit_size, length, bit_size);
args[0] = offset_ptr;
args[1] = lp_build_const_int32(gallivm, bit_size / 8);
args[2] = LLVMBuildICmp(builder, LLVMIntNE, exec_mask,
LLVMConstNull(LLVMTypeOf(exec_mask)), "");
args[3] = LLVMConstNull(vec_type);
return lp_build_intrinsic(builder, intrin_name, vec_type,
args, 4, 0);
}
void
lp_build_masked_scatter(struct gallivm_state *gallivm,
unsigned length,
unsigned bit_size,
LLVMValueRef offset_ptr,
LLVMValueRef value_vec,
LLVMValueRef exec_mask)
{
LLVMBuilderRef builder = gallivm->builder;
LLVMValueRef args[4];
char intrin_name[64];
snprintf(intrin_name, 64, "llvm.masked.scatter.v%ui%u.v%up0i%u",
length, bit_size, length, bit_size);
args[0] = value_vec;
args[1] = offset_ptr;
args[2] = lp_build_const_int32(gallivm, bit_size / 8);
args[3] = LLVMBuildICmp(builder, LLVMIntNE, exec_mask,
LLVMConstNull(LLVMTypeOf(exec_mask)), "");
lp_build_intrinsic(builder, intrin_name, LLVMVoidTypeInContext(gallivm->context),
args, 4, 0);
}

View file

@ -66,4 +66,20 @@ lp_build_gather_values(struct gallivm_state * gallivm,
LLVMValueRef * values,
unsigned value_count);
LLVMValueRef
lp_build_masked_gather(struct gallivm_state *gallivm,
unsigned length,
unsigned bit_size,
LLVMTypeRef vec_type,
LLVMValueRef offset_ptr,
LLVMValueRef exec_mask);
void
lp_build_masked_scatter(struct gallivm_state *gallivm,
unsigned length,
unsigned bit_size,
LLVMValueRef offset_ptr,
LLVMValueRef value_vec,
LLVMValueRef exec_mask);
#endif /* LP_BLD_GATHER_H_ */

View file

@ -822,6 +822,41 @@ static LLVMValueRef global_addr_to_ptr(struct gallivm_state *gallivm, LLVMValueR
return addr_ptr;
}
static LLVMValueRef global_addr_to_ptr_vec(struct gallivm_state *gallivm, LLVMValueRef addr_ptr, unsigned length, unsigned bit_size)
{
LLVMBuilderRef builder = gallivm->builder;
switch (bit_size) {
case 8:
addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), length), "");
break;
case 16:
addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt16TypeInContext(gallivm->context), 0), length), "");
break;
case 32:
default:
addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt32TypeInContext(gallivm->context), 0), length), "");
break;
case 64:
addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt64TypeInContext(gallivm->context), 0), length), "");
break;
}
return addr_ptr;
}
static LLVMValueRef lp_vec_add_offset_ptr(struct lp_build_nir_context *bld_base,
unsigned bit_size,
LLVMValueRef ptr,
LLVMValueRef offset)
{
struct gallivm_state *gallivm = bld_base->base.gallivm;
LLVMBuilderRef builder = gallivm->builder;
struct lp_build_context *uint_bld = &bld_base->uint_bld;
LLVMValueRef result = LLVMBuildPtrToInt(builder, ptr, bld_base->uint64_bld.vec_type, "");
offset = LLVMBuildZExt(builder, offset, bld_base->uint64_bld.vec_type, "");
result = LLVMBuildAdd(builder, offset, result, "");
return global_addr_to_ptr_vec(gallivm, result, uint_bld->type.length, bit_size);
}
static void emit_load_global(struct lp_build_nir_context *bld_base,
unsigned nc,
unsigned bit_size,
@ -855,30 +890,14 @@ static void emit_load_global(struct lp_build_nir_context *bld_base,
}
for (unsigned c = 0; c < nc; c++) {
LLVMValueRef result = lp_build_alloca(gallivm, res_bld->vec_type, "");
struct lp_build_loop_state loop_state;
lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
LLVMValueRef chan_offset = lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8));
struct lp_build_if_state ifthen;
LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
lp_build_if(&ifthen, gallivm, cond);
LLVMValueRef addr_ptr = LLVMBuildExtractElement(gallivm->builder, addr,
loop_state.counter, "");
addr_ptr = global_addr_to_ptr(gallivm, addr_ptr, bit_size);
LLVMValueRef value_ptr = lp_build_pointer_get2(builder, res_bld->elem_type,
addr_ptr, lp_build_const_int32(gallivm, c));
LLVMValueRef temp_res;
temp_res = LLVMBuildLoad2(builder, res_bld->vec_type, result, "");
temp_res = LLVMBuildInsertElement(builder, temp_res, value_ptr, loop_state.counter, "");
LLVMBuildStore(builder, temp_res, result);
lp_build_endif(&ifthen);
lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
NULL, LLVMIntUGE);
outval[c] = LLVMBuildLoad2(builder, res_bld->vec_type, result, "");
outval[c] = lp_build_masked_gather(gallivm, res_bld->type.length,
bit_size,
res_bld->vec_type,
lp_vec_add_offset_ptr(bld_base, bit_size, addr, chan_offset),
exec_mask);
outval[c] = LLVMBuildBitCast(builder, outval[c], res_bld->vec_type, "");
}
}
@ -898,40 +917,14 @@ static void emit_store_global(struct lp_build_nir_context *bld_base,
if (!(writemask & (1u << c)))
continue;
LLVMValueRef val = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, "");
LLVMValueRef chan_offset = lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8));
struct lp_build_loop_state loop_state;
lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val,
loop_state.counter, "");
LLVMValueRef addr_ptr = LLVMBuildExtractElement(gallivm->builder, addr,
loop_state.counter, "");
addr_ptr = global_addr_to_ptr(gallivm, addr_ptr, bit_size);
switch (bit_size) {
case 8:
value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt8TypeInContext(gallivm->context), "");
break;
case 16:
value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt16TypeInContext(gallivm->context), "");
break;
case 32:
value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt32TypeInContext(gallivm->context), "");
break;
case 64:
value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt64TypeInContext(gallivm->context), "");
break;
default:
break;
}
struct lp_build_if_state ifthen;
LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
lp_build_if(&ifthen, gallivm, cond);
lp_build_pointer_set(builder, addr_ptr, lp_build_const_int32(gallivm, c), value_ptr);
lp_build_endif(&ifthen);
lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
NULL, LLVMIntUGE);
struct lp_build_context *out_bld = get_int_bld(bld_base, false, bit_size);
val = LLVMBuildBitCast(builder, val, out_bld->vec_type, "");
lp_build_masked_scatter(gallivm, out_bld->type.length, bit_size,
lp_vec_add_offset_ptr(bld_base, bit_size,
addr, chan_offset),
val, exec_mask);
}
}
@ -2616,46 +2609,25 @@ emit_load_scratch(struct lp_build_nir_context *bld_base,
struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
struct lp_build_context *uint_bld = &bld_base->uint_bld;
struct lp_build_context *load_bld;
LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);;
uint32_t shift_val = bit_size_to_shift_size(bit_size);
LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);
LLVMValueRef exec_mask = mask_vec(bld_base);
LLVMValueRef scratch_ptr_vec = lp_build_broadcast(gallivm,
LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), uint_bld->type.length),
bld->scratch_ptr);
load_bld = get_int_bld(bld_base, true, bit_size);
offset = lp_build_add(uint_bld, offset, thread_offsets);
offset = lp_build_shr_imm(uint_bld, offset, shift_val);
for (unsigned c = 0; c < nc; c++) {
LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c));
LLVMValueRef chan_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8)));
LLVMValueRef result = lp_build_alloca(gallivm, load_bld->vec_type, "");
struct lp_build_loop_state loop_state;
lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
struct lp_build_if_state ifthen;
LLVMValueRef cond, temp_res;
loop_index = LLVMBuildExtractElement(gallivm->builder, loop_index,
loop_state.counter, "");
cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
lp_build_if(&ifthen, gallivm, cond);
LLVMValueRef scalar;
LLVMValueRef ptr2 = LLVMBuildBitCast(builder, bld->scratch_ptr, LLVMPointerType(load_bld->elem_type, 0), "");
scalar = lp_build_pointer_get2(builder, load_bld->elem_type, ptr2, loop_index);
temp_res = LLVMBuildLoad2(builder, load_bld->vec_type, result, "");
temp_res = LLVMBuildInsertElement(builder, temp_res, scalar, loop_state.counter, "");
LLVMBuildStore(builder, temp_res, result);
lp_build_else(&ifthen);
temp_res = LLVMBuildLoad2(builder, load_bld->vec_type, result, "");
LLVMValueRef zero = lp_build_zero_bits(gallivm, bit_size, false);
temp_res = LLVMBuildInsertElement(builder, temp_res, zero, loop_state.counter, "");
LLVMBuildStore(builder, temp_res, result);
lp_build_endif(&ifthen);
lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
NULL, LLVMIntUGE);
outval[c] = LLVMBuildLoad2(gallivm->builder, load_bld->vec_type, result, "");
outval[c] = lp_build_masked_gather(gallivm, load_bld->type.length, bit_size,
load_bld->vec_type,
lp_vec_add_offset_ptr(bld_base, bit_size,
scratch_ptr_vec,
chan_offset),
exec_mask);
outval[c] = LLVMBuildBitCast(builder, outval[c], load_bld->vec_type, "");
}
}
@ -2670,43 +2642,28 @@ emit_store_scratch(struct lp_build_nir_context *bld_base,
struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
struct lp_build_context *uint_bld = &bld_base->uint_bld;
struct lp_build_context *store_bld;
LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);;
uint32_t shift_val = bit_size_to_shift_size(bit_size);
LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);
LLVMValueRef scratch_ptr_vec = lp_build_broadcast(gallivm,
LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), uint_bld->type.length),
bld->scratch_ptr);
store_bld = get_int_bld(bld_base, true, bit_size);
LLVMValueRef exec_mask = mask_vec(bld_base);
offset = lp_build_add(uint_bld, offset, thread_offsets);
offset = lp_build_shr_imm(uint_bld, offset, shift_val);
for (unsigned c = 0; c < nc; c++) {
if (!(writemask & (1u << c)))
continue;
LLVMValueRef val = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, "");
LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c));
struct lp_build_loop_state loop_state;
lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
LLVMValueRef chan_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8)));
LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val,
loop_state.counter, "");
value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, store_bld->elem_type, "");
val = LLVMBuildBitCast(builder, val, store_bld->vec_type, "");
struct lp_build_if_state ifthen;
LLVMValueRef cond;
loop_index = LLVMBuildExtractElement(gallivm->builder, loop_index,
loop_state.counter, "");
cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
lp_build_if(&ifthen, gallivm, cond);
LLVMValueRef ptr2 = LLVMBuildBitCast(builder, bld->scratch_ptr, LLVMPointerType(store_bld->elem_type, 0), "");
lp_build_pointer_set(builder, ptr2, loop_index, value_ptr);
lp_build_endif(&ifthen);
lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
NULL, LLVMIntUGE);
lp_build_masked_scatter(gallivm, store_bld->type.length, bit_size,
lp_vec_add_offset_ptr(bld_base, bit_size,
scratch_ptr_vec, chan_offset),
val, exec_mask);
}
}