lavapipe: Implement clustered reductions

Replaces the runtime loop with a compile-time loop and restarts the scan
on multiples of cluster_size.

Reviewed-By: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31250>
This commit is contained in:
Konstantin Seurer 2024-09-19 10:10:15 +02:00 committed by Marge Bot
parent 1f3b8bb918
commit dfa314e805
2 changed files with 101 additions and 66 deletions

View file

@ -2437,16 +2437,21 @@ static void emit_reduce(struct lp_build_nir_context *bld_base, LLVMValueRef src,
uint32_t bit_size = nir_src_bit_size(instr->src[0]);
/* can't use llvm reduction intrinsics because of exec_mask */
LLVMValueRef exec_mask = mask_vec(bld_base);
struct lp_build_loop_state loop_state;
nir_op reduction_op = nir_intrinsic_reduction_op(instr);
uint32_t cluster_size = 0;
if (instr->intrinsic == nir_intrinsic_reduce)
cluster_size = nir_intrinsic_cluster_size(instr);
if (cluster_size == 0)
cluster_size = bld_base->int_bld.type.length;
LLVMValueRef res_store = NULL;
LLVMValueRef scan_store;
struct lp_build_context *int_bld = get_int_bld(bld_base, true, bit_size);
if (instr->intrinsic != nir_intrinsic_reduce)
res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");
struct lp_build_context elem_bld;
@ -2578,75 +2583,104 @@ static void emit_reduce(struct lp_build_nir_context *bld_base, LLVMValueRef src,
LLVMValueRef outer_cond = LLVMBuildICmp(builder, LLVMIntNE, exec_mask, bld_base->uint_bld.zero, "");
lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
for (uint32_t i = 0; i < bld_base->uint_bld.type.length; i++) {
LLVMValueRef counter = lp_build_const_int32(gallivm, i);
struct lp_build_if_state ifthen;
LLVMValueRef if_cond = LLVMBuildExtractElement(gallivm->builder, outer_cond, loop_state.counter, "");
lp_build_if(&ifthen, gallivm, if_cond);
LLVMValueRef value = LLVMBuildExtractElement(gallivm->builder, src, loop_state.counter, "");
struct lp_build_if_state ifthen;
LLVMValueRef if_cond = LLVMBuildExtractElement(gallivm->builder, outer_cond, counter, "");
lp_build_if(&ifthen, gallivm, if_cond);
LLVMValueRef res = NULL;
LLVMValueRef scan_val = LLVMBuildLoad2(gallivm->builder, int_bld->elem_type, scan_store, "");
if (instr->intrinsic != nir_intrinsic_reduce)
res = LLVMBuildLoad2(gallivm->builder, int_bld->vec_type, res_store, "");
LLVMValueRef value = LLVMBuildExtractElement(gallivm->builder, src, counter, "");
if (instr->intrinsic == nir_intrinsic_exclusive_scan)
res = LLVMBuildInsertElement(builder, res, scan_val, loop_state.counter, "");
LLVMValueRef res = NULL;
LLVMValueRef scan_val = LLVMBuildLoad2(gallivm->builder, int_bld->elem_type, scan_store, "");
if (is_flt) {
scan_val = LLVMBuildBitCast(builder, scan_val, elem_bld.elem_type, "");
value = LLVMBuildBitCast(builder, value, elem_bld.elem_type, "");
}
switch (reduction_op) {
case nir_op_fadd:
case nir_op_iadd:
scan_val = lp_build_add(&elem_bld, value, scan_val);
break;
case nir_op_fmul:
case nir_op_imul:
scan_val = lp_build_mul(&elem_bld, value, scan_val);
break;
case nir_op_imin:
case nir_op_umin:
case nir_op_fmin:
scan_val = lp_build_min(&elem_bld, value, scan_val);
break;
case nir_op_imax:
case nir_op_umax:
case nir_op_fmax:
scan_val = lp_build_max(&elem_bld, value, scan_val);
break;
case nir_op_iand:
scan_val = lp_build_and(&elem_bld, value, scan_val);
break;
case nir_op_ior:
scan_val = lp_build_or(&elem_bld, value, scan_val);
break;
case nir_op_ixor:
scan_val = lp_build_xor(&elem_bld, value, scan_val);
break;
default:
assert(0);
break;
}
if (is_flt)
scan_val = LLVMBuildBitCast(builder, scan_val, int_bld->elem_type, "");
LLVMBuildStore(builder, scan_val, scan_store);
if (instr->intrinsic != nir_intrinsic_reduce)
res = LLVMBuildLoad2(gallivm->builder, int_bld->vec_type, res_store, "");
if (instr->intrinsic == nir_intrinsic_inclusive_scan) {
res = LLVMBuildInsertElement(builder, res, scan_val, loop_state.counter, "");
if (instr->intrinsic == nir_intrinsic_exclusive_scan)
res = LLVMBuildInsertElement(builder, res, scan_val, counter, "");
if (is_flt) {
scan_val = LLVMBuildBitCast(builder, scan_val, elem_bld.elem_type, "");
value = LLVMBuildBitCast(builder, value, elem_bld.elem_type, "");
}
switch (reduction_op) {
case nir_op_fadd:
case nir_op_iadd:
scan_val = lp_build_add(&elem_bld, value, scan_val);
break;
case nir_op_fmul:
case nir_op_imul:
scan_val = lp_build_mul(&elem_bld, value, scan_val);
break;
case nir_op_imin:
case nir_op_umin:
case nir_op_fmin:
scan_val = lp_build_min(&elem_bld, value, scan_val);
break;
case nir_op_imax:
case nir_op_umax:
case nir_op_fmax:
scan_val = lp_build_max(&elem_bld, value, scan_val);
break;
case nir_op_iand:
scan_val = lp_build_and(&elem_bld, value, scan_val);
break;
case nir_op_ior:
scan_val = lp_build_or(&elem_bld, value, scan_val);
break;
case nir_op_ixor:
scan_val = lp_build_xor(&elem_bld, value, scan_val);
break;
default:
assert(0);
break;
}
if (is_flt)
scan_val = LLVMBuildBitCast(builder, scan_val, int_bld->elem_type, "");
LLVMBuildStore(builder, scan_val, scan_store);
if (instr->intrinsic == nir_intrinsic_inclusive_scan)
res = LLVMBuildInsertElement(builder, res, scan_val, counter, "");
if (instr->intrinsic != nir_intrinsic_reduce)
LLVMBuildStore(builder, res, res_store);
lp_build_endif(&ifthen);
if (instr->intrinsic == nir_intrinsic_reduce && (i % cluster_size) == (cluster_size - 1)) {
res = LLVMBuildLoad2(gallivm->builder, int_bld->vec_type, res_store, "");
scan_val = LLVMBuildLoad2(gallivm->builder, int_bld->elem_type, scan_store, "");
if (store_val)
LLVMBuildStore(builder, store_val, scan_store);
else
LLVMBuildStore(builder, LLVMConstNull(int_bld->elem_type), scan_store);
LLVMValueRef cluster_index = lp_build_const_int32(gallivm, i / cluster_size);
res = LLVMBuildInsertElement(builder, res, scan_val, cluster_index, "");
LLVMBuildStore(builder, res, res_store);
}
}
if (instr->intrinsic != nir_intrinsic_reduce)
LLVMBuildStore(builder, res, res_store);
lp_build_endif(&ifthen);
LLVMValueRef res = LLVMBuildLoad2(gallivm->builder, int_bld->vec_type, res_store, "");
if (instr->intrinsic == nir_intrinsic_reduce) {
LLVMValueRef swizzle[LP_MAX_VECTOR_LENGTH];
for (uint32_t i = 0; i < bld_base->int_bld.type.length; i++)
swizzle[i] = lp_build_const_int32(gallivm, i / cluster_size);
LLVMValueRef undef = LLVMGetUndef(int_bld->vec_type);
result[0] = LLVMBuildShuffleVector(
builder, res, undef, LLVMConstVector(swizzle, bld_base->int_bld.type.length), "");
} else {
result[0] = res;
}
lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, bld_base->uint_bld.type.length),
NULL, LLVMIntUGE);
if (instr->intrinsic == nir_intrinsic_reduce)
result[0] = lp_build_broadcast_scalar(int_bld, LLVMBuildLoad2(builder, int_bld->elem_type, scan_store, ""));
else
result[0] = LLVMBuildLoad2(builder, int_bld->vec_type, res_store, "");
}
static void emit_read_invocation(struct lp_build_nir_context *bld_base,

View file

@ -1204,7 +1204,8 @@ lvp_get_properties(const struct lvp_physical_device *device, struct vk_propertie
memset(p->deviceLUID, 0, VK_LUID_SIZE);
#if LLVM_VERSION_MAJOR >= 10
p->subgroupSupportedOperations |= VK_SUBGROUP_FEATURE_SHUFFLE_BIT | VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT | VK_SUBGROUP_FEATURE_QUAD_BIT;
p->subgroupSupportedOperations |= VK_SUBGROUP_FEATURE_SHUFFLE_BIT | VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT | VK_SUBGROUP_FEATURE_QUAD_BIT |
VK_SUBGROUP_FEATURE_CLUSTERED_BIT;
#endif
/* Vulkan 1.2 */