lavapipe: Lower push constants in NIR

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32963>
This commit is contained in:
Konstantin Seurer 2024-12-14 11:20:14 +01:00 committed by Marge Bot
parent a57e8b2e97
commit d9db40208d
2 changed files with 71 additions and 83 deletions

View file

@ -1223,70 +1223,6 @@ lp_offset_in_range(struct lp_build_nir_soa_context *bld,
return LLVMBuildAnd(gallivm->builder, fetch_in_bounds, fetch_non_negative, "");
}
/**
 * Emit the LLVM IR for a constant-buffer load.
 *
 * Fetches `nc` components of `bit_size` bits each from the constant buffer
 * selected by `index` (looked up through bld->consts_ptr) and writes one
 * LLVM value per component into `result`.
 *
 * `offset` is shifted right by bit_size_to_shift_size(bit_size) before it is
 * used, i.e. it is consumed as an index of `bit_size`-wide elements;
 * presumably callers pass it in bytes (TODO confirm). The element count
 * returned for the buffer is rescaled for 8-, 16- and 64-bit loads, which
 * suggests it is reported in 32-bit units (TODO confirm).
 *
 * The code path is chosen by lp_value_is_divergent(offset):
 *  - divergent: one build_gather() per component, handing it a mask of the
 *    lanes whose offset is >= the element count;
 *  - otherwise: one scalar load per component, where the pointer is replaced
 *    by bld->null_qword_ptr whenever lp_offset_in_range() rejects the offset.
 */
static void emit_load_ubo(struct lp_build_nir_soa_context *bld,
                          unsigned nc,
                          unsigned bit_size,
                          LLVMValueRef index,
                          LLVMValueRef offset,
                          LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
{
   struct gallivm_state *gallivm = bld->base.gallivm;
   LLVMBuilderRef builder = gallivm->builder;

   /* Both build contexts are picked from the divergence of the incoming offset. */
   const bool divergent_on_entry = lp_value_is_divergent(offset);
   struct lp_build_context *offset_bld = get_int_bld(bld, true, 32, divergent_on_entry);
   struct lp_build_context *elem_bld = get_int_bld(bld, true, bit_size, divergent_on_entry);

   LLVMValueRef buffer = lp_llvm_buffer_base(gallivm, bld->consts_ptr, index, LP_MAX_TGSI_CONST_BUFFERS);
   LLVMValueRef limit = lp_llvm_buffer_num_elements(gallivm, bld->consts_ptr, index, LP_MAX_TGSI_CONST_BUFFERS);

   /* Turn the offset into an index of bit_size-wide elements. */
   const unsigned elem_shift = bit_size_to_shift_size(bit_size);
   if (elem_shift != 0) {
      LLVMValueRef shift_amount = lp_build_const_int_vec(gallivm, offset_bld->type, elem_shift);
      offset = lp_build_shr(offset_bld, offset, shift_amount);
   }

   /* View the buffer as an array of the element type being loaded. */
   buffer = LLVMBuildBitCast(builder, buffer, LLVMPointerType(elem_bld->elem_type, 0), "");

   if (lp_value_is_divergent(offset)) {
      /* Per-lane offsets: gather, masking off lanes that are past the end. */
      limit = lp_build_broadcast_scalar(offset_bld, limit);
      switch (bit_size) {
      case 64:
         limit = lp_build_shr_imm(offset_bld, limit, 1);
         break;
      case 16:
         limit = lp_build_shl_imm(offset_bld, limit, 1);
         break;
      case 8:
         limit = lp_build_shl_imm(offset_bld, limit, 2);
         break;
      default:
         break;
      }

      for (unsigned comp = 0; comp < nc; comp++) {
         LLVMValueRef comp_index = lp_build_const_int_vec(gallivm, offset_bld->type, comp);
         LLVMValueRef lane_offsets = lp_build_add(offset_bld, offset, comp_index);
         LLVMValueRef past_end = lp_build_compare(gallivm, offset_bld->type, PIPE_FUNC_GEQUAL,
                                                  lane_offsets, limit);
         result[comp] = build_gather(bld, elem_bld, elem_bld->elem_type, buffer,
                                     lane_offsets, past_end, NULL);
      }
      return;
   }

   /* Single offset shared by all lanes: plain scalar loads. */
   struct lp_build_context *scalar_bld = get_int_bld(bld, true, bit_size, false);

   if (bit_size == 8) {
      limit = LLVMBuildShl(builder, limit, lp_build_const_int32(gallivm, 2), "");
   } else if (bit_size == 16) {
      limit = LLVMBuildShl(builder, limit, lp_build_const_int32(gallivm, 1), "");
   } else if (bit_size == 64) {
      limit = LLVMBuildLShr(builder, limit, lp_build_const_int32(gallivm, 1), "");
   }

   for (unsigned comp = 0; comp < nc; comp++) {
      LLVMValueRef elem_index = LLVMBuildAdd(builder, offset, lp_build_const_int32(gallivm, comp), "");
      LLVMValueRef in_bounds = lp_offset_in_range(bld, elem_index, limit);
      LLVMValueRef elem_ptr = LLVMBuildGEP2(builder, elem_bld->elem_type, buffer, &elem_index, 1, "");
      /* Out-of-range reads are redirected to the context's null qword. */
      LLVMValueRef fallback_ptr = LLVMBuildBitCast(builder, bld->null_qword_ptr, LLVMTypeOf(elem_ptr), "");
      elem_ptr = LLVMBuildSelect(builder, in_bounds, elem_ptr, fallback_ptr, "");
      result[comp] = LLVMBuildLoad2(builder, scalar_bld->elem_type, elem_ptr, "");
   }
}
static LLVMValueRef
load_ubo_base_addr(struct lp_build_nir_soa_context *bld, LLVMValueRef index)
{
@ -4037,26 +3973,67 @@ visit_load_ubo(struct lp_build_nir_soa_context *bld,
nir_intrinsic_instr *instr,
LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
{
LLVMValueRef idx = get_src(bld, &instr->src[0], 0);
struct gallivm_state *gallivm = bld->base.gallivm;
LLVMBuilderRef builder = gallivm->builder;
LLVMValueRef index = get_src(bld, &instr->src[0], 0);
LLVMValueRef offset = get_src(bld, &instr->src[1], 0);
emit_load_ubo(bld, instr->def.num_components,
instr->def.bit_size,
idx, offset, result);
}
struct lp_build_context *uint_bld = get_int_bld(bld, true, 32, lp_value_is_divergent(offset));
struct lp_build_context *bld_broad = get_int_bld(bld, true, instr->def.bit_size, lp_value_is_divergent(offset));
static void
visit_load_push_constant(struct lp_build_nir_soa_context *bld,
nir_intrinsic_instr *instr,
LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
{
struct gallivm_state *gallivm = bld->base.gallivm;
LLVMValueRef offset = get_src(bld, &instr->src[0], 0);
LLVMValueRef idx = lp_build_const_int32(gallivm, 0);
LLVMValueRef consts_ptr = lp_llvm_buffer_base(gallivm, bld->consts_ptr, index, LP_MAX_TGSI_CONST_BUFFERS);
LLVMValueRef num_consts = lp_llvm_buffer_num_elements(gallivm, bld->consts_ptr, index, LP_MAX_TGSI_CONST_BUFFERS);
emit_load_ubo(bld, instr->def.num_components,
instr->def.bit_size,
idx, offset, result);
unsigned size_shift = bit_size_to_shift_size(instr->def.bit_size);
if (size_shift)
offset = lp_build_shr(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, size_shift));
LLVMTypeRef ptr_type = LLVMPointerType(bld_broad->elem_type, 0);
consts_ptr = LLVMBuildBitCast(builder, consts_ptr, ptr_type, "");
if (!lp_value_is_divergent(offset)) {
struct lp_build_context *load_bld = get_int_bld(bld, true, instr->def.bit_size, false);
switch (instr->def.bit_size) {
case 8:
num_consts = LLVMBuildShl(gallivm->builder, num_consts, lp_build_const_int32(gallivm, 2), "");
break;
case 16:
num_consts = LLVMBuildShl(gallivm->builder, num_consts, lp_build_const_int32(gallivm, 1), "");
break;
case 64:
num_consts = LLVMBuildLShr(gallivm->builder, num_consts, lp_build_const_int32(gallivm, 1), "");
break;
default: break;
}
for (unsigned c = 0; c < instr->def.num_components; c++) {
LLVMValueRef chan_offset = LLVMBuildAdd(builder, offset, lp_build_const_int32(gallivm, c), "");
LLVMValueRef in_range = lp_offset_in_range(bld, chan_offset, num_consts);
LLVMValueRef ptr = LLVMBuildGEP2(builder, bld_broad->elem_type, consts_ptr, &chan_offset, 1, "");
LLVMValueRef null_ptr = LLVMBuildBitCast(builder, bld->null_qword_ptr, LLVMTypeOf(ptr), "");
ptr = LLVMBuildSelect(builder, in_range, ptr, null_ptr, "");
result[c] = LLVMBuildLoad2(builder, load_bld->elem_type, ptr, "");
}
} else {
LLVMValueRef overflow_mask;
num_consts = lp_build_broadcast_scalar(uint_bld, num_consts);
if (instr->def.bit_size == 64)
num_consts = lp_build_shr_imm(uint_bld, num_consts, 1);
else if (instr->def.bit_size == 16)
num_consts = lp_build_shl_imm(uint_bld, num_consts, 1);
else if (instr->def.bit_size == 8)
num_consts = lp_build_shl_imm(uint_bld, num_consts, 2);
for (unsigned c = 0; c < instr->def.num_components; c++) {
LLVMValueRef this_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c));
overflow_mask = lp_build_compare(gallivm, uint_bld->type, PIPE_FUNC_GEQUAL,
this_offset, num_consts);
result[c] = build_gather(bld, bld_broad, bld_broad->elem_type, consts_ptr, this_offset, overflow_mask, NULL);
}
}
}
static void
@ -4896,9 +4873,6 @@ visit_intrinsic(struct lp_build_nir_soa_context *bld,
case nir_intrinsic_load_ubo:
visit_load_ubo(bld, instr, result);
break;
case nir_intrinsic_load_push_constant:
visit_load_push_constant(bld, instr, result);
break;
case nir_intrinsic_load_ssbo:
visit_load_ssbo(bld, instr, result);
break;

View file

@ -179,6 +179,16 @@ lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intrin, void *data_cb)
return true;
}
/**
 * Rewrite a load_push_constant intrinsic as a load_ubo.
 *
 * The replacement reads from buffer index 0, reuses the original offset
 * source (src[0]) and carries over the intrinsic's range. All uses of the
 * old definition are redirected to the new load and the old instruction is
 * removed. `data_cb` is not used.
 *
 * NOTE(review): this function does not position b->cursor itself, so the new
 * instructions land wherever the builder cursor currently points -- confirm
 * that the caller sets it before invoking this.
 */
static void
lower_push_constant(nir_builder *b, nir_intrinsic_instr *intrin, void *data_cb)
{
   nir_def *buffer_index = nir_imm_int(b, 0);
   nir_def *offset = intrin->src[0].ssa;

   nir_def *ubo_value = nir_load_ubo(b, intrin->def.num_components,
                                     intrin->def.bit_size,
                                     buffer_index, offset,
                                     .range = nir_intrinsic_range(intrin));

   nir_def_rewrite_uses(&intrin->def, ubo_value);
   nir_instr_remove(&intrin->instr);
}
static bool
lower_vri_instr(struct nir_builder *b, nir_instr *instr, void *data_cb)
{
@ -232,6 +242,10 @@ lower_vri_instr(struct nir_builder *b, nir_instr *instr, void *data_cb)
case nir_intrinsic_image_deref_samples:
lower_image_intrinsic(b, intrin, data_cb);
return true;
case nir_intrinsic_load_push_constant:
lower_push_constant(b, intrin, data_cb);
return true;
default:
return false;