gallivm: let reduce ops use llvm intrinsics

As part of coopmat, I want to make reductions faster as I need
them to implement coopmat.

The intrinsics can't be used directly, as we have to take the
exec_mask into account, but they can still be used by picking
a value to insert into the disabled lanes and then calling
the LLVM intrinsic.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39225>
This commit is contained in:
Dave Airlie 2025-12-17 10:41:41 +10:00 committed by Marge Bot
parent 1e59dbf66b
commit 12bceb228a

View file

@ -2315,15 +2315,31 @@ static void emit_elect(struct lp_build_nir_soa_context *bld, LLVMValueRef result
result[0] = LLVMBuildICmp(builder, LLVMIntNE, result[0], lp_build_const_int_vec(gallivm, bld->int_bld.type, 0), "");
}
/*
 * Build the identity value of a reduction operation as an integer constant.
 *
 * The identity is looked up with nir_alu_binop_identity() for the given
 * reduction_op / bit_size pair, read back as an unsigned integer, and emitted
 * as a constant of the element type of int_bld. Callers in this file either
 * broadcast the result across the vector or store it to an element-sized
 * slot.
 *
 * NOTE(review): for floating-point reduction ops the returned constant
 * presumably carries the float's bit pattern in an integer type -- confirm
 * against nir_const_value_as_uint().
 */
static LLVMValueRef build_reduction_identity_val(struct gallivm_state *gallivm,
                                                 struct lp_build_context *int_bld,
                                                 nir_op reduction_op,
                                                 unsigned bit_size)
{
   const struct lp_type scalar_type = lp_elem_type(int_bld->type);
   const nir_const_value identity = nir_alu_binop_identity(reduction_op, bit_size);
   const uint64_t identity_bits = nir_const_value_as_uint(identity, bit_size);

   return lp_build_const_int_vec(gallivm, scalar_type, identity_bits);
}
static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
nir_intrinsic_instr *instr, LLVMValueRef result[4])
{
struct gallivm_state *gallivm = bld->base.gallivm;
LLVMBuilderRef builder = gallivm->builder;
uint32_t bit_size = nir_src_bit_size(instr->src[0]);
/* can't use llvm reduction intrinsics because of exec_mask */
LLVMValueRef exec_mask = group_op_mask_vec(bld);
nir_op reduction_op = nir_intrinsic_reduction_op(instr);
bool is_flt = reduction_op == nir_op_fadd ||
reduction_op == nir_op_fmul ||
reduction_op == nir_op_fmin ||
reduction_op == nir_op_fmax;
bool is_unsigned = reduction_op == nir_op_umin ||
reduction_op == nir_op_umax;
uint32_t cluster_size = 0;
@ -2338,137 +2354,101 @@ static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
src = LLVMBuildZExt(builder, src, bld->uint8_bld.vec_type, "");
}
struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true);
struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) :
get_int_bld(bld, is_unsigned, bit_size, true);
/*
* For a reduce operation with the correct cluster size, the llvm
* intrinsics can be used as long as the exec_mask is taken into account.
* Values are defaulted in disabled lanes depending on the operation.
*/
if (instr->intrinsic == nir_intrinsic_reduce &&
cluster_size == bld->int_bld.type.length) {
char intrinsic[64];
uint32_t length = vec_bld->type.length;
uint32_t src_width = bit_size;
src = LLVMBuildBitCast(builder, src, int_bld->vec_type, "");
if (bit_size < 32)
exec_mask = LLVMBuildTrunc(builder, exec_mask, int_bld->vec_type, "");
if (bit_size > 32)
exec_mask = LLVMBuildSExt(builder, exec_mask, int_bld->vec_type, "");
LLVMValueRef masked_val = lp_build_and(int_bld, src, exec_mask);
const char *opname;
switch (reduction_op) {
case nir_op_iadd: opname = "add"; break;
case nir_op_iand: opname = "and"; break;
case nir_op_ior: opname = "or"; break;
case nir_op_ixor: opname = "xor"; break;
case nir_op_imul: opname = "mul"; break;
case nir_op_fadd: opname = "fadd"; break;
case nir_op_fmul: opname = "fmul"; break;
case nir_op_imin: opname = "smin"; break;
case nir_op_umin: opname = "umin"; break;
case nir_op_fmin: opname = "fmin"; break;
case nir_op_imax: opname = "smax"; break;
case nir_op_umax: opname = "umax"; break;
case nir_op_fmax: opname = "fmax"; break;
default:
assert(0);
};
snprintf(intrinsic, sizeof intrinsic, "llvm.vector.reduce.%s.v%u%s%u",
opname,
length, is_flt ? "f" : "i" , src_width);
LLVMValueRef init_val = build_reduction_identity_val(gallivm,
int_bld,
reduction_op,
bit_size);
if (init_val) {
init_val = lp_build_broadcast_scalar(int_bld, init_val);
init_val = lp_build_andnot(int_bld, init_val, exec_mask);
masked_val = lp_build_or(int_bld, masked_val, init_val);
}
if (is_flt)
masked_val = LLVMBuildBitCast(builder, masked_val, vec_bld->vec_type, "");
LLVMValueRef args[2];
int num_args = 1;
if (reduction_op == nir_op_fadd ||
reduction_op == nir_op_fmul) {
if (reduction_op == nir_op_fmul) {
args[0] = lp_build_const_elem(gallivm, vec_bld->type, 1);
} else {
args[0] = lp_build_const_elem(gallivm, vec_bld->type, -0.0);
}
args[1] = masked_val;
num_args++;
} else {
args[0] = masked_val;
}
LLVMValueRef res = lp_build_intrinsic(builder, intrinsic, vec_bld->elem_type, args, num_args, 0);
result[0] = lp_build_broadcast(gallivm, vec_bld->vec_type, res);
if (instr->def.bit_size == 1)
result[0] = LLVMBuildICmp(builder, LLVMIntNE, result[0], int_bld->zero, "");
return;
}
LLVMValueRef res_store = NULL;
LLVMValueRef scan_store;
struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true);
res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");
struct lp_build_context elem_bld;
bool is_flt = reduction_op == nir_op_fadd ||
reduction_op == nir_op_fmul ||
reduction_op == nir_op_fmin ||
reduction_op == nir_op_fmax;
bool is_unsigned = reduction_op == nir_op_umin ||
reduction_op == nir_op_umax;
struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) :
get_int_bld(bld, is_unsigned, bit_size, true);
lp_build_context_init(&elem_bld, gallivm, lp_elem_type(vec_bld->type));
LLVMValueRef store_val = NULL;
LLVMValueRef store_val = build_reduction_identity_val(gallivm, int_bld, reduction_op, bit_size);
/*
* Put the identity value for the operation into the storage
*/
switch (reduction_op) {
case nir_op_fmin: {
LLVMValueRef flt_max = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), INFINITY) :
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), INFINITY) : lp_build_const_float(gallivm, INFINITY));
store_val = LLVMBuildBitCast(builder, flt_max, int_bld->elem_type, "");
break;
}
case nir_op_fmax: {
LLVMValueRef flt_min = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), -INFINITY) :
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), -INFINITY) : lp_build_const_float(gallivm, -INFINITY));
store_val = LLVMBuildBitCast(builder, flt_min, int_bld->elem_type, "");
break;
}
case nir_op_fmul: {
LLVMValueRef flt_one = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), 1.0) :
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), 1.0) : lp_build_const_float(gallivm, 1.0));
store_val = LLVMBuildBitCast(builder, flt_one, int_bld->elem_type, "");
break;
}
case nir_op_umin:
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), UINT8_MAX, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), UINT16_MAX, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, UINT_MAX);
break;
case 64:
store_val = lp_build_const_int64(gallivm, UINT64_MAX);
break;
}
break;
case nir_op_imin:
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MAX, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MAX, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, INT_MAX);
break;
case 64:
store_val = lp_build_const_int64(gallivm, INT64_MAX);
break;
}
break;
case nir_op_imax:
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MIN, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MIN, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, INT_MIN);
break;
case 64:
store_val = lp_build_const_int64(gallivm, INT64_MIN);
break;
}
break;
case nir_op_imul:
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 1, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 1, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, 1);
break;
case 64:
store_val = lp_build_const_int64(gallivm, 1);
break;
}
break;
case nir_op_iand:
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 0xff, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 0xffff, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, 0xffffffff);
break;
case 64:
store_val = lp_build_const_int64(gallivm, 0xffffffffffffffffLL);
break;
}
break;
default:
break;
}
if (store_val)
LLVMBuildStore(builder, store_val, scan_store);