diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c b/src/gallium/auxiliary/gallivm/lp_bld_nir.c
index d41be6ba0fc..aaa88cdf494 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c
@@ -1789,6 +1789,11 @@ static void visit_intrinsic(struct lp_build_nir_context *bld_base,
    case nir_intrinsic_elect:
       bld_base->elect(bld_base, result);
       break;
+   case nir_intrinsic_reduce:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_exclusive_scan:
+      bld_base->reduce(bld_base, cast_type(bld_base, get_src(bld_base, instr->src[0]), nir_type_int, nir_src_bit_size(instr->src[0])), instr, result);
+      break;
    case nir_intrinsic_interp_deref_at_offset:
    case nir_intrinsic_interp_deref_at_centroid:
    case nir_intrinsic_interp_deref_at_sample:
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.h b/src/gallium/auxiliary/gallivm/lp_bld_nir.h
index e396db71eb8..d65309e6b0a 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir.h
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.h
@@ -186,6 +186,7 @@ struct lp_build_nir_context
 
    void (*vote)(struct lp_build_nir_context *bld_base, LLVMValueRef src, nir_intrinsic_instr *instr, LLVMValueRef dst[4]);
    void (*elect)(struct lp_build_nir_context *bld_base, LLVMValueRef dst[4]);
+   void (*reduce)(struct lp_build_nir_context *bld_base, LLVMValueRef src, nir_intrinsic_instr *instr, LLVMValueRef dst[4]);
    void (*helper_invocation)(struct lp_build_nir_context *bld_base, LLVMValueRef *dst);
 
    void (*interp_at)(struct lp_build_nir_context *bld_base,
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
index 6c122fc18ba..5965f863f86 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
@@ -1920,6 +1920,188 @@ static void emit_elect(struct lp_build_nir_context *bld_base, LLVMValueRef resul
                                       "");
 }
 
+/*
+ * Emit a subgroup reduce / inclusive scan / exclusive scan.
+ *
+ * The invocations are walked one by one in a scalar loop, folding each
+ * active lane of `src` into a scalar accumulator (scan_store) with the
+ * operation named by nir_intrinsic_reduction_op().  `src` is handled as an
+ * integer vector; float ops bitcast each element around the arithmetic.
+ *
+ *  - reduce:         result[0] is the final accumulator broadcast to all lanes
+ *  - inclusive_scan: each active lane gets the accumulator after its own
+ *                    value has been folded in
+ *  - exclusive_scan: each active lane gets the accumulator before its own
+ *                    value has been folded in
+ *
+ * Only result[0] is written.
+ */
+static void emit_reduce(struct lp_build_nir_context *bld_base, LLVMValueRef src,
+                        nir_intrinsic_instr *instr, LLVMValueRef result[4])
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+   uint32_t bit_size = nir_src_bit_size(instr->src[0]);
+   /* can't use llvm reduction intrinsics because of exec_mask */
+   LLVMValueRef exec_mask = mask_vec(bld_base);
+   struct lp_build_loop_state loop_state;
+   nir_op reduction_op = nir_intrinsic_reduction_op(instr);
+
+   LLVMValueRef res_store = NULL;
+   LLVMValueRef scan_store;
+   struct lp_build_context *int_bld = get_int_bld(bld_base, true, bit_size);
+
+   /* per-lane results are only needed for the scan variants */
+   if (instr->intrinsic != nir_intrinsic_reduce)
+      res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
+
+   /* scalar running accumulator, stored with int_bld's element type */
+   scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");
+
+   struct lp_build_context elem_bld;
+   bool is_flt = reduction_op == nir_op_fadd ||
+                 reduction_op == nir_op_fmul ||
+                 reduction_op == nir_op_fmin ||
+                 reduction_op == nir_op_fmax;
+   bool is_unsigned = reduction_op == nir_op_umin ||
+                      reduction_op == nir_op_umax;
+
+   struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld_base, bit_size) :
+                                               get_int_bld(bld_base, is_unsigned, bit_size);
+
+   /* scalar context matching the operation's element type */
+   lp_build_context_init(&elem_bld, gallivm, lp_elem_type(vec_bld->type));
+
+   LLVMValueRef store_val = NULL;
+   /*
+    * Put the identity value for the operation into the storage.
+    *
+    * Integer identities are built at int_bld->elem_type so that they match
+    * the accumulator for every bit_size; a 32-bit constant
+    * (lp_build_const_int32) only matches when bit_size is 32.
+    *
+    * Ops not listed (add/or/xor/umax) store nothing here; their identity is
+    * zero, which presumably relies on lp_build_alloca() zero-initialising
+    * the slot -- TODO confirm.
+    */
+   switch (reduction_op) {
+   case nir_op_fmin: {
+      LLVMValueRef flt_max = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), INFINITY) :
+         lp_build_const_float(gallivm, INFINITY);
+      store_val = LLVMBuildBitCast(builder, flt_max, int_bld->elem_type, "");
+      break;
+   }
+   case nir_op_fmax: {
+      LLVMValueRef flt_min = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), -INFINITY) :
+         lp_build_const_float(gallivm, -INFINITY);
+      store_val = LLVMBuildBitCast(builder, flt_min, int_bld->elem_type, "");
+      break;
+   }
+   case nir_op_fmul: {
+      LLVMValueRef flt_one = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), 1.0) :
+         lp_build_const_float(gallivm, 1.0);
+      store_val = LLVMBuildBitCast(builder, flt_one, int_bld->elem_type, "");
+      break;
+   }
+   case nir_op_umin:
+   case nir_op_iand:
+      /* all bits set: unsigned maximum / AND identity */
+      store_val = LLVMConstAllOnes(int_bld->elem_type);
+      break;
+   case nir_op_imin:
+      /* signed maximum (0x7f..f); two shifts keep each count below 64 */
+      store_val = LLVMConstInt(int_bld->elem_type, (~0ull >> (64 - bit_size)) >> 1, 0);
+      break;
+   case nir_op_imax:
+      /* signed minimum (0x80..0) */
+      store_val = LLVMConstInt(int_bld->elem_type, 1ull << (bit_size - 1), 0);
+      break;
+   case nir_op_imul:
+      store_val = LLVMConstInt(int_bld->elem_type, 1, 0);
+      break;
+   default:
+      break;
+   }
+   if (store_val)
+      LLVMBuildStore(builder, store_val, scan_store);
+
+   /* lanes whose exec_mask element is non-zero take part */
+   LLVMValueRef outer_cond = LLVMBuildICmp(builder, LLVMIntNE, exec_mask, bld_base->uint_bld.zero, "");
+
+   lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
+
+   struct lp_build_if_state ifthen;
+   LLVMValueRef if_cond = LLVMBuildExtractElement(gallivm->builder, outer_cond, loop_state.counter, "");
+   lp_build_if(&ifthen, gallivm, if_cond);
+   LLVMValueRef value = LLVMBuildExtractElement(gallivm->builder, src, loop_state.counter, "");
+
+   LLVMValueRef res = NULL;
+   LLVMValueRef scan_val = LLVMBuildLoad(gallivm->builder, scan_store, "");
+   if (instr->intrinsic != nir_intrinsic_reduce)
+      res = LLVMBuildLoad(gallivm->builder, res_store, "");
+
+   /* exclusive scan records the accumulator before this lane is folded in */
+   if (instr->intrinsic == nir_intrinsic_exclusive_scan)
+      res = LLVMBuildInsertElement(builder, res, scan_val, loop_state.counter, "");
+
+   if (is_flt) {
+      scan_val = LLVMBuildBitCast(builder, scan_val, elem_bld.elem_type, "");
+      value = LLVMBuildBitCast(builder, value, elem_bld.elem_type, "");
+   }
+   switch (reduction_op) {
+   case nir_op_fadd:
+   case nir_op_iadd:
+      scan_val = lp_build_add(&elem_bld, value, scan_val);
+      break;
+   case nir_op_fmul:
+   case nir_op_imul:
+      scan_val = lp_build_mul(&elem_bld, value, scan_val);
+      break;
+   case nir_op_imin:
+   case nir_op_umin:
+   case nir_op_fmin:
+      scan_val = lp_build_min(&elem_bld, value, scan_val);
+      break;
+   case nir_op_imax:
+   case nir_op_umax:
+   case nir_op_fmax:
+      scan_val = lp_build_max(&elem_bld, value, scan_val);
+      break;
+   case nir_op_iand:
+      scan_val = lp_build_and(&elem_bld, value, scan_val);
+      break;
+   case nir_op_ior:
+      scan_val = lp_build_or(&elem_bld, value, scan_val);
+      break;
+   case nir_op_ixor:
+      scan_val = lp_build_xor(&elem_bld, value, scan_val);
+      break;
+   default:
+      assert(0);
+      break;
+   }
+   if (is_flt)
+      scan_val = LLVMBuildBitCast(builder, scan_val, int_bld->elem_type, "");
+   LLVMBuildStore(builder, scan_val, scan_store);
+
+   /* inclusive scan records the accumulator after this lane is folded in */
+   if (instr->intrinsic == nir_intrinsic_inclusive_scan) {
+      res = LLVMBuildInsertElement(builder, res, scan_val, loop_state.counter, "");
+   }
+
+   if (instr->intrinsic != nir_intrinsic_reduce)
+      LLVMBuildStore(builder, res, res_store);
+   lp_build_endif(&ifthen);
+
+   /* loop bound is the vector length (uint_bld.type.length) */
+   lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, bld_base->uint_bld.type.length),
+                          NULL, LLVMIntUGE);
+   if (instr->intrinsic == nir_intrinsic_reduce)
+      result[0] = lp_build_broadcast_scalar(int_bld, LLVMBuildLoad(builder, scan_store, ""));
+   else
+      result[0] = LLVMBuildLoad(builder, res_store, "");
+}
+
 static void
 emit_interp_at(struct lp_build_nir_context *bld_base,
                unsigned num_components,
@@ -2166,6 +2348,7 @@ void lp_build_nir_soa(struct gallivm_state *gallivm,
    bld.bld_base.image_size = emit_image_size;
    bld.bld_base.vote = emit_vote;
    bld.bld_base.elect = emit_elect;
+   bld.bld_base.reduce = emit_reduce;
    bld.bld_base.helper_invocation = emit_helper_invocation;
    bld.bld_base.interp_at = emit_interp_at;
    bld.bld_base.load_scratch = emit_load_scratch;