mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-07 02:48:06 +02:00
gallivm: let reduce ops use llvm intrinsics
As part of coopmat, I want to make reductions faster as I need them to implement coopmat. The intrinsics can't be used directly as we have to take into account the exec_mask, but it can be done by picking a value to insert into the disabled lanes, then calling the LLVM intrinsic. Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39225>
This commit is contained in:
parent
1e59dbf66b
commit
12bceb228a
1 changed files with 99 additions and 119 deletions
|
|
@ -2315,15 +2315,31 @@ static void emit_elect(struct lp_build_nir_soa_context *bld, LLVMValueRef result
|
||||||
result[0] = LLVMBuildICmp(builder, LLVMIntNE, result[0], lp_build_const_int_vec(gallivm, bld->int_bld.type, 0), "");
|
result[0] = LLVMBuildICmp(builder, LLVMIntNE, result[0], lp_build_const_int_vec(gallivm, bld->int_bld.type, 0), "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LLVMValueRef build_reduction_identity_val(struct gallivm_state *gallivm,
|
||||||
|
struct lp_build_context *int_bld,
|
||||||
|
nir_op reduction_op,
|
||||||
|
unsigned bit_size)
|
||||||
|
{
|
||||||
|
nir_const_value const_val = nir_alu_binop_identity(reduction_op, bit_size);
|
||||||
|
|
||||||
|
return lp_build_const_int_vec(gallivm, lp_elem_type(int_bld->type),
|
||||||
|
nir_const_value_as_uint(const_val, bit_size));
|
||||||
|
}
|
||||||
|
|
||||||
static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
|
static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
|
||||||
nir_intrinsic_instr *instr, LLVMValueRef result[4])
|
nir_intrinsic_instr *instr, LLVMValueRef result[4])
|
||||||
{
|
{
|
||||||
struct gallivm_state *gallivm = bld->base.gallivm;
|
struct gallivm_state *gallivm = bld->base.gallivm;
|
||||||
LLVMBuilderRef builder = gallivm->builder;
|
LLVMBuilderRef builder = gallivm->builder;
|
||||||
uint32_t bit_size = nir_src_bit_size(instr->src[0]);
|
uint32_t bit_size = nir_src_bit_size(instr->src[0]);
|
||||||
/* can't use llvm reduction intrinsics because of exec_mask */
|
|
||||||
LLVMValueRef exec_mask = group_op_mask_vec(bld);
|
LLVMValueRef exec_mask = group_op_mask_vec(bld);
|
||||||
nir_op reduction_op = nir_intrinsic_reduction_op(instr);
|
nir_op reduction_op = nir_intrinsic_reduction_op(instr);
|
||||||
|
bool is_flt = reduction_op == nir_op_fadd ||
|
||||||
|
reduction_op == nir_op_fmul ||
|
||||||
|
reduction_op == nir_op_fmin ||
|
||||||
|
reduction_op == nir_op_fmax;
|
||||||
|
bool is_unsigned = reduction_op == nir_op_umin ||
|
||||||
|
reduction_op == nir_op_umax;
|
||||||
|
|
||||||
uint32_t cluster_size = 0;
|
uint32_t cluster_size = 0;
|
||||||
|
|
||||||
|
|
@ -2338,137 +2354,101 @@ static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
|
||||||
src = LLVMBuildZExt(builder, src, bld->uint8_bld.vec_type, "");
|
src = LLVMBuildZExt(builder, src, bld->uint8_bld.vec_type, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true);
|
||||||
|
struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) :
|
||||||
|
get_int_bld(bld, is_unsigned, bit_size, true);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* For a reduce operation with the correct cluster size, the llvm
|
||||||
|
* intrinsics can be used as long as the exec_mask is taken into account.
|
||||||
|
* Values are defaulted in disabled lanes depending on the operation.
|
||||||
|
*/
|
||||||
|
if (instr->intrinsic == nir_intrinsic_reduce &&
|
||||||
|
cluster_size == bld->int_bld.type.length) {
|
||||||
|
char intrinsic[64];
|
||||||
|
uint32_t length = vec_bld->type.length;
|
||||||
|
uint32_t src_width = bit_size;
|
||||||
|
|
||||||
|
src = LLVMBuildBitCast(builder, src, int_bld->vec_type, "");
|
||||||
|
if (bit_size < 32)
|
||||||
|
exec_mask = LLVMBuildTrunc(builder, exec_mask, int_bld->vec_type, "");
|
||||||
|
if (bit_size > 32)
|
||||||
|
exec_mask = LLVMBuildSExt(builder, exec_mask, int_bld->vec_type, "");
|
||||||
|
LLVMValueRef masked_val = lp_build_and(int_bld, src, exec_mask);
|
||||||
|
const char *opname;
|
||||||
|
|
||||||
|
switch (reduction_op) {
|
||||||
|
case nir_op_iadd: opname = "add"; break;
|
||||||
|
case nir_op_iand: opname = "and"; break;
|
||||||
|
case nir_op_ior: opname = "or"; break;
|
||||||
|
case nir_op_ixor: opname = "xor"; break;
|
||||||
|
case nir_op_imul: opname = "mul"; break;
|
||||||
|
case nir_op_fadd: opname = "fadd"; break;
|
||||||
|
case nir_op_fmul: opname = "fmul"; break;
|
||||||
|
case nir_op_imin: opname = "smin"; break;
|
||||||
|
case nir_op_umin: opname = "umin"; break;
|
||||||
|
case nir_op_fmin: opname = "fmin"; break;
|
||||||
|
case nir_op_imax: opname = "smax"; break;
|
||||||
|
case nir_op_umax: opname = "umax"; break;
|
||||||
|
case nir_op_fmax: opname = "fmax"; break;
|
||||||
|
default:
|
||||||
|
assert(0);
|
||||||
|
};
|
||||||
|
snprintf(intrinsic, sizeof intrinsic, "llvm.vector.reduce.%s.v%u%s%u",
|
||||||
|
opname,
|
||||||
|
length, is_flt ? "f" : "i" , src_width);
|
||||||
|
|
||||||
|
LLVMValueRef init_val = build_reduction_identity_val(gallivm,
|
||||||
|
int_bld,
|
||||||
|
reduction_op,
|
||||||
|
bit_size);
|
||||||
|
if (init_val) {
|
||||||
|
init_val = lp_build_broadcast_scalar(int_bld, init_val);
|
||||||
|
init_val = lp_build_andnot(int_bld, init_val, exec_mask);
|
||||||
|
masked_val = lp_build_or(int_bld, masked_val, init_val);
|
||||||
|
}
|
||||||
|
if (is_flt)
|
||||||
|
masked_val = LLVMBuildBitCast(builder, masked_val, vec_bld->vec_type, "");
|
||||||
|
|
||||||
|
LLVMValueRef args[2];
|
||||||
|
int num_args = 1;
|
||||||
|
|
||||||
|
if (reduction_op == nir_op_fadd ||
|
||||||
|
reduction_op == nir_op_fmul) {
|
||||||
|
if (reduction_op == nir_op_fmul) {
|
||||||
|
args[0] = lp_build_const_elem(gallivm, vec_bld->type, 1);
|
||||||
|
} else {
|
||||||
|
args[0] = lp_build_const_elem(gallivm, vec_bld->type, -0.0);
|
||||||
|
}
|
||||||
|
args[1] = masked_val;
|
||||||
|
num_args++;
|
||||||
|
} else {
|
||||||
|
args[0] = masked_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
LLVMValueRef res = lp_build_intrinsic(builder, intrinsic, vec_bld->elem_type, args, num_args, 0);
|
||||||
|
|
||||||
|
result[0] = lp_build_broadcast(gallivm, vec_bld->vec_type, res);
|
||||||
|
|
||||||
|
if (instr->def.bit_size == 1)
|
||||||
|
result[0] = LLVMBuildICmp(builder, LLVMIntNE, result[0], int_bld->zero, "");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
LLVMValueRef res_store = NULL;
|
LLVMValueRef res_store = NULL;
|
||||||
LLVMValueRef scan_store;
|
LLVMValueRef scan_store;
|
||||||
struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true);
|
|
||||||
|
|
||||||
res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
|
res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
|
||||||
scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");
|
scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");
|
||||||
|
|
||||||
struct lp_build_context elem_bld;
|
struct lp_build_context elem_bld;
|
||||||
bool is_flt = reduction_op == nir_op_fadd ||
|
|
||||||
reduction_op == nir_op_fmul ||
|
|
||||||
reduction_op == nir_op_fmin ||
|
|
||||||
reduction_op == nir_op_fmax;
|
|
||||||
bool is_unsigned = reduction_op == nir_op_umin ||
|
|
||||||
reduction_op == nir_op_umax;
|
|
||||||
|
|
||||||
struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) :
|
|
||||||
get_int_bld(bld, is_unsigned, bit_size, true);
|
|
||||||
|
|
||||||
lp_build_context_init(&elem_bld, gallivm, lp_elem_type(vec_bld->type));
|
lp_build_context_init(&elem_bld, gallivm, lp_elem_type(vec_bld->type));
|
||||||
|
|
||||||
LLVMValueRef store_val = NULL;
|
LLVMValueRef store_val = build_reduction_identity_val(gallivm, int_bld, reduction_op, bit_size);
|
||||||
/*
|
/*
|
||||||
* Put the identity value for the operation into the storage
|
* Put the identity value for the operation into the storage
|
||||||
*/
|
*/
|
||||||
switch (reduction_op) {
|
|
||||||
case nir_op_fmin: {
|
|
||||||
LLVMValueRef flt_max = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), INFINITY) :
|
|
||||||
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), INFINITY) : lp_build_const_float(gallivm, INFINITY));
|
|
||||||
store_val = LLVMBuildBitCast(builder, flt_max, int_bld->elem_type, "");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case nir_op_fmax: {
|
|
||||||
LLVMValueRef flt_min = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), -INFINITY) :
|
|
||||||
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), -INFINITY) : lp_build_const_float(gallivm, -INFINITY));
|
|
||||||
store_val = LLVMBuildBitCast(builder, flt_min, int_bld->elem_type, "");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case nir_op_fmul: {
|
|
||||||
LLVMValueRef flt_one = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), 1.0) :
|
|
||||||
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), 1.0) : lp_build_const_float(gallivm, 1.0));
|
|
||||||
store_val = LLVMBuildBitCast(builder, flt_one, int_bld->elem_type, "");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case nir_op_umin:
|
|
||||||
switch (bit_size) {
|
|
||||||
case 8:
|
|
||||||
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), UINT8_MAX, 0);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), UINT16_MAX, 0);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
default:
|
|
||||||
store_val = lp_build_const_int32(gallivm, UINT_MAX);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
store_val = lp_build_const_int64(gallivm, UINT64_MAX);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case nir_op_imin:
|
|
||||||
switch (bit_size) {
|
|
||||||
case 8:
|
|
||||||
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MAX, 0);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MAX, 0);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
default:
|
|
||||||
store_val = lp_build_const_int32(gallivm, INT_MAX);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
store_val = lp_build_const_int64(gallivm, INT64_MAX);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case nir_op_imax:
|
|
||||||
switch (bit_size) {
|
|
||||||
case 8:
|
|
||||||
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MIN, 0);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MIN, 0);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
default:
|
|
||||||
store_val = lp_build_const_int32(gallivm, INT_MIN);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
store_val = lp_build_const_int64(gallivm, INT64_MIN);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case nir_op_imul:
|
|
||||||
switch (bit_size) {
|
|
||||||
case 8:
|
|
||||||
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 1, 0);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 1, 0);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
default:
|
|
||||||
store_val = lp_build_const_int32(gallivm, 1);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
store_val = lp_build_const_int64(gallivm, 1);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case nir_op_iand:
|
|
||||||
switch (bit_size) {
|
|
||||||
case 8:
|
|
||||||
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 0xff, 0);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 0xffff, 0);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
default:
|
|
||||||
store_val = lp_build_const_int32(gallivm, 0xffffffff);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
store_val = lp_build_const_int64(gallivm, 0xffffffffffffffffLL);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (store_val)
|
if (store_val)
|
||||||
LLVMBuildStore(builder, store_val, scan_store);
|
LLVMBuildStore(builder, store_val, scan_store);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue