gallivm/nir/soa: Select more IO to gather/scatter intrinsics

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32963>
This commit is contained in:
Konstantin Seurer 2024-12-07 14:59:48 +01:00 committed by Marge Bot
parent 29a4886cc8
commit 2208379628

View file

@ -1278,7 +1278,7 @@ mem_access_base_pointer(struct lp_build_nir_soa_context *bld,
static void emit_load_mem(struct lp_build_nir_soa_context *bld,
unsigned nc,
unsigned bit_size,
bool index_and_offset_are_uniform,
bool index_uniform, bool offset_uniform,
bool payload,
LLVMValueRef index,
LLVMValueRef offset,
@ -1300,7 +1300,7 @@ static void emit_load_mem(struct lp_build_nir_soa_context *bld,
* though, since those don't do bounds checking and we could use an invalid
* offset if exec_mask == 0.
*/
if (index_and_offset_are_uniform && (invocation_0_must_be_active(bld) || index)) {
if (index_uniform && offset_uniform && (invocation_0_must_be_active(bld) || index)) {
LLVMValueRef ssbo_limit;
LLVMValueRef first_active = first_active_invocation(bld);
LLVMValueRef mem_ptr = mem_access_base_pointer(bld, load_bld, bit_size, payload, index,
@ -1329,14 +1329,39 @@ static void emit_load_mem(struct lp_build_nir_soa_context *bld,
return;
}
/* although the index is dynamically uniform, that doesn't count if the exec mask isn't set, so read them one by one */
LLVMValueRef gather_mask = mask_vec_with_helpers(bld);
LLVMValueRef gather_cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, gather_mask, uint_bld->zero, "");
if (index_uniform) {
LLVMValueRef limit = NULL;
LLVMValueRef mem_ptr = mem_access_base_pointer(bld, load_bld, bit_size, payload, index,
first_active_invocation(bld), &limit);
if (limit) {
limit = lp_build_broadcast_scalar(uint_bld, limit);
} else {
if (payload)
limit = lp_build_const_int_vec(gallivm, uint_bld->type, bld->shader->info.task_payload_size >> shift_val);
else
limit = lp_build_const_int_vec(gallivm, uint_bld->type, bld->shader->info.shared_size >> shift_val);
}
for (unsigned c = 0; c < nc; c++) {
LLVMValueRef channel_offset = LLVMBuildAdd(builder, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c), "channel_offset");
LLVMValueRef oob_cmp = LLVMBuildICmp(builder, LLVMIntULT, channel_offset, limit, "oob_cmp");
LLVMValueRef channel_ptr = LLVMBuildGEP2(builder, load_bld->elem_type, mem_ptr, &channel_offset, 1, "channel_ptr");
LLVMValueRef mask = LLVMBuildAnd(builder, gather_cond, oob_cmp, "mask");
outval[c] = lp_build_masked_gather(gallivm, load_bld->type.length, load_bld->type.width, load_bld->vec_type,
channel_ptr, mask);
}
return;
}
LLVMValueRef result[NIR_MAX_VEC_COMPONENTS];
for (unsigned c = 0; c < nc; c++)
result[c] = lp_build_alloca(gallivm, load_bld->vec_type, "");
LLVMValueRef gather_mask = mask_vec_with_helpers(bld);
LLVMValueRef gather_cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, gather_mask, uint_bld->zero, "");
for (unsigned i = 0; i < uint_bld->type.length; i++) {
LLVMValueRef counter = lp_build_const_int32(gallivm, i);
LLVMValueRef element_gather_cond = LLVMBuildExtractElement(gallivm->builder, gather_cond, counter, "");
@ -1388,7 +1413,7 @@ static void emit_store_mem(struct lp_build_nir_soa_context *bld,
unsigned writemask,
unsigned nc,
unsigned bit_size,
bool index_and_offset_are_uniform,
bool index_uniform, bool offset_uniform,
bool payload,
LLVMValueRef index,
LLVMValueRef offset,
@ -1411,7 +1436,7 @@ static void emit_store_mem(struct lp_build_nir_soa_context *bld,
* don't use first_active_uniform(), since we aren't guaranteed that there is
* actually an active invocation.
*/
if (index_and_offset_are_uniform && invocation_0_must_be_active(bld)) {
if (index_uniform && offset_uniform && invocation_0_must_be_active(bld)) {
cond = LLVMBuildBitCast(builder, cond, LLVMIntTypeInContext(gallivm->context, bld->base.type.length), "exec_bitmask");
cond = LLVMBuildZExt(builder, cond, bld->int_bld.elem_type, "");
@ -1449,6 +1474,36 @@ static void emit_store_mem(struct lp_build_nir_soa_context *bld,
return;
}
if (index_uniform) {
LLVMValueRef limit = NULL;
LLVMValueRef mem_ptr = mem_access_base_pointer(bld, store_bld, bit_size, payload, index,
first_active_invocation(bld), &limit);
if (limit) {
limit = lp_build_broadcast_scalar(uint_bld, limit);
} else {
if (payload)
limit = lp_build_const_int_vec(gallivm, uint_bld->type, bld->shader->info.task_payload_size >> shift_val);
else
limit = lp_build_const_int_vec(gallivm, uint_bld->type, bld->shader->info.shared_size >> shift_val);
}
for (unsigned c = 0; c < nc; c++) {
if (!(writemask & (1u << c)))
continue;
LLVMValueRef channel_offset = LLVMBuildAdd(builder, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c), "channel_offset");
LLVMValueRef oob_cmp = LLVMBuildICmp(builder, LLVMIntULT, channel_offset, limit, "oob_cmp");
LLVMValueRef channel_ptr = LLVMBuildGEP2(builder, store_bld->elem_type, mem_ptr, &channel_offset, 1, "channel_ptr");
LLVMValueRef mask = LLVMBuildAnd(builder, cond, oob_cmp, "mask");
LLVMValueRef value = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, "");
value = LLVMBuildBitCast(gallivm->builder, value, store_bld->vec_type, "");
lp_build_masked_scatter(gallivm, store_bld->type.length, store_bld->type.width, channel_ptr, value, mask);
}
return;
}
for (unsigned i = 0; i < uint_bld->type.length; i++) {
LLVMValueRef counter = lp_build_const_int32(gallivm, i);
LLVMValueRef loop_cond = LLVMBuildExtractElement(gallivm->builder, cond, counter, "");
@ -3872,10 +3927,15 @@ visit_load_reg(struct lp_build_nir_soa_context *bld,
LLVMValueRef max_index = lp_build_const_int_vec(gallivm, uint_bld->type, num_array_elems - 1);
indirect_val = LLVMBuildAdd(builder, indirect_val, indir_src, "");
indirect_val = lp_build_min(uint_bld, indirect_val, max_index);
reg_storage = LLVMBuildBitCast(builder, reg_storage, LLVMPointerType(reg_bld->elem_type, 0), "");
reg_storage = LLVMBuildBitCast(builder, reg_storage, LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), "");
for (unsigned i = 0; i < nc; i++) {
LLVMValueRef indirect_offset = get_soa_array_offsets(uint_bld, indirect_val, nc, i, true);
result[i] = build_gather(bld, reg_bld, reg_bld->elem_type, reg_storage, indirect_offset, NULL, NULL);
indirect_offset = LLVMBuildMul(
builder, indirect_offset,
lp_build_const_int_vec(gallivm, uint_bld->type, reg_bld->type.width / 8), "indirect_offset");
result[i] = lp_build_gather(
gallivm, reg_bld->type.length, reg_bld->type.width, lp_elem_type(reg_bld->type),
true, reg_storage, indirect_offset, false);
}
} else {
for (unsigned i = 0; i < nc; i++) {
@ -4100,12 +4160,12 @@ visit_load_ssbo(struct lp_build_nir_soa_context *bld,
idx = cast_type(bld, idx, nir_type_uint, 32);
LLVMValueRef offset = get_src(bld, instr->src[1]);
bool index_and_offset_are_uniform =
nir_src_is_always_uniform(instr->src[0]) &&
nir_src_is_always_uniform(instr->src[1]);
emit_load_mem(bld, instr->def.num_components,
instr->def.bit_size,
index_and_offset_are_uniform, false, idx, offset, result);
nir_src_is_always_uniform(instr->src[0]),
nir_src_is_always_uniform(instr->src[1]),
false, idx, offset, result);
}
static void
@ -4119,14 +4179,13 @@ visit_store_ssbo(struct lp_build_nir_soa_context *bld,
idx = cast_type(bld, idx, nir_type_uint, 32);
LLVMValueRef offset = get_src(bld, instr->src[2]);
bool index_and_offset_are_uniform =
nir_src_is_always_uniform(instr->src[1]) &&
nir_src_is_always_uniform(instr->src[2]);
int writemask = instr->const_index[0];
int nc = nir_src_num_components(instr->src[0]);
int bitsize = nir_src_bit_size(instr->src[0]);
emit_store_mem(bld, writemask, nc, bitsize,
index_and_offset_are_uniform, false, idx, offset, val);
nir_src_is_always_uniform(instr->src[1]),
nir_src_is_always_uniform(instr->src[2]),
false, idx, offset, val);
}
static void
@ -4435,7 +4494,7 @@ visit_shared_load(struct lp_build_nir_soa_context *bld,
LLVMValueRef offset = get_src(bld, instr->src[0]);
bool offset_is_uniform = nir_src_is_always_uniform(instr->src[0]);
emit_load_mem(bld, instr->def.num_components,
instr->def.bit_size,
instr->def.bit_size, true,
offset_is_uniform, false, NULL, offset, result);
}
@ -4449,7 +4508,7 @@ visit_shared_store(struct lp_build_nir_soa_context *bld,
int writemask = instr->const_index[1];
int nc = nir_src_num_components(instr->src[0]);
int bitsize = nir_src_bit_size(instr->src[0]);
emit_store_mem(bld, writemask, nc, bitsize,
emit_store_mem(bld, writemask, nc, bitsize, true,
offset_is_uniform, false, NULL, offset, val);
}
@ -4873,7 +4932,7 @@ visit_payload_load(struct lp_build_nir_soa_context *bld,
LLVMValueRef offset = get_src(bld, instr->src[0]);
bool offset_is_uniform = nir_src_is_always_uniform(instr->src[0]);
emit_load_mem(bld, instr->def.num_components,
instr->def.bit_size,
instr->def.bit_size, true,
offset_is_uniform, true, NULL, offset, result);
}
@ -4887,7 +4946,7 @@ visit_payload_store(struct lp_build_nir_soa_context *bld,
int writemask = instr->const_index[1];
int nc = nir_src_num_components(instr->src[0]);
int bitsize = nir_src_bit_size(instr->src[0]);
emit_store_mem(bld, writemask, nc, bitsize,
emit_store_mem(bld, writemask, nc, bitsize, true,
offset_is_uniform, true, NULL, offset, val);
}