microsoft/compiler: Support lowering SSBO accesses to 16bit vectors

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21029>
Jesse Natalie 2023-02-01 10:58:18 -08:00 committed by Marge Bot
parent 0f56fc09d9
commit 58e7acb0e2
5 changed files with 89 additions and 74 deletions


@@ -132,7 +132,8 @@ compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel,
    NIR_PASS_V(nir, d3d12_lower_load_draw_params);
    NIR_PASS_V(nir, d3d12_lower_load_patch_vertices_in);
    NIR_PASS_V(nir, d3d12_lower_state_vars, shader);
-   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
+   const struct dxil_nir_lower_loads_stores_options loads_stores_options = {};
+   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
    NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
    NIR_PASS_V(nir, dxil_nir_lower_double_math);
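
Note: every caller updated in this commit leaves use_16bit_ssbo at false (an empty designated initializer zero-initializes the field). As a hedged illustration only, a backend able to consume 16-bit SSBO vectors would presumably opt in as sketched below, using the options struct and pass signature introduced by this commit; the `true` value is hypothetical and not part of the change:

   /* Hypothetical opt-in sketch, not from this commit: assumes the backend
    * can handle 16-bit SSBO load/store vectors. */
   const struct dxil_nir_lower_loads_stores_options loads_stores_options = {
      .use_16bit_ssbo = true,
   };
   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);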


@@ -952,7 +952,10 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
    NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo,
               nir_address_format_32bit_index_offset);
    NIR_PASS_V(nir, clc_nir_lower_system_values, work_properties_var);
-   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
+   const struct dxil_nir_lower_loads_stores_options loads_stores_options = {
+      .use_16bit_ssbo = false,
+   };
+   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
    NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
    NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
    NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);


@@ -38,30 +38,31 @@ cl_type_size_align(const struct glsl_type *type, unsigned *size,
 }
 
 static nir_ssa_def *
-load_comps_to_vec32(nir_builder *b, unsigned src_bit_size,
-                    nir_ssa_def **src_comps, unsigned num_src_comps)
+load_comps_to_vec(nir_builder *b, unsigned src_bit_size,
+                  nir_ssa_def **src_comps, unsigned num_src_comps,
+                  unsigned dst_bit_size)
 {
-   if (src_bit_size == 32)
+   if (src_bit_size == dst_bit_size)
       return nir_vec(b, src_comps, num_src_comps);
-   else if (src_bit_size > 32)
-      return nir_extract_bits(b, src_comps, num_src_comps, 0, src_bit_size * num_src_comps / 32, 32);
+   else if (src_bit_size > dst_bit_size)
+      return nir_extract_bits(b, src_comps, num_src_comps, 0, src_bit_size * num_src_comps / dst_bit_size, dst_bit_size);
 
-   unsigned num_vec32comps = DIV_ROUND_UP(num_src_comps * src_bit_size, 32);
-   unsigned comps_per32b = 32 / src_bit_size;
-   nir_ssa_def *vec32comps[4];
+   unsigned num_dst_comps = DIV_ROUND_UP(num_src_comps * src_bit_size, dst_bit_size);
+   unsigned comps_per_dst = dst_bit_size / src_bit_size;
+   nir_ssa_def *dst_comps[4];
 
-   for (unsigned i = 0; i < num_vec32comps; i++) {
-      unsigned src_offs = i * comps_per32b;
+   for (unsigned i = 0; i < num_dst_comps; i++) {
+      unsigned src_offs = i * comps_per_dst;
 
-      vec32comps[i] = nir_u2u32(b, src_comps[src_offs]);
-      for (unsigned j = 1; j < comps_per32b && src_offs + j < num_src_comps; j++) {
-         nir_ssa_def *tmp = nir_ishl(b, nir_u2u32(b, src_comps[src_offs + j]),
+      dst_comps[i] = nir_u2uN(b, src_comps[src_offs], dst_bit_size);
+      for (unsigned j = 1; j < comps_per_dst && src_offs + j < num_src_comps; j++) {
+         nir_ssa_def *tmp = nir_ishl(b, nir_u2uN(b, src_comps[src_offs + j], dst_bit_size),
                                      nir_imm_int(b, j * src_bit_size));
-         vec32comps[i] = nir_ior(b, vec32comps[i], tmp);
+         dst_comps[i] = nir_ior(b, dst_comps[i], tmp);
       }
    }
 
-   return nir_vec(b, vec32comps, num_vec32comps);
+   return nir_vec(b, dst_comps, num_dst_comps);
 }
 
 static nir_ssa_def *
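
As a worked example of the packing loop in load_comps_to_vec above: with src_bit_size = 8 and dst_bit_size = 16, num_dst_comps = DIV_ROUND_UP(4 * 8, 16) = 2 and comps_per_dst = 16 / 8 = 2. A hypothetical plain-C sketch of the same shift/or arithmetic (not NIR builder code; the names are invented for illustration):

#include <stdint.h>

/* Packs four 8-bit components into two 16-bit components, mirroring the
 * nir_u2uN / nir_ishl / nir_ior sequence in load_comps_to_vec. */
void pack_u8x4_to_u16x2(const uint8_t src[4], uint16_t dst[2])
{
   for (unsigned i = 0; i < 2; i++) {                  /* num_dst_comps */
      unsigned src_offs = i * 2;                       /* comps_per_dst */
      dst[i] = src[src_offs];                          /* low byte */
      dst[i] |= (uint16_t)(src[src_offs + 1]) << 8;    /* j * src_bit_size = 8 */
   }
}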
@@ -214,7 +215,7 @@ build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer,
 }
 
 static bool
-lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
+lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr, unsigned min_bit_size)
 {
    assert(intr->dest.is_ssa);
    assert(intr->src[0].is_ssa);
@@ -222,46 +223,46 @@ lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
 
    b->cursor = nir_before_instr(&intr->instr);
 
+   unsigned src_bit_size = nir_dest_bit_size(intr->dest);
+   unsigned store_bit_size = CLAMP(src_bit_size, min_bit_size, 32);
+   unsigned offset_mask = store_bit_size / 8 - 1;
+
    nir_ssa_def *buffer = intr->src[0].ssa;
-   nir_ssa_def *offset = nir_iand(b, intr->src[1].ssa, nir_imm_int(b, ~3));
+   nir_ssa_def *offset = nir_iand(b, intr->src[1].ssa, nir_imm_int(b, ~offset_mask));
    enum gl_access_qualifier access = nir_intrinsic_access(intr);
-   unsigned bit_size = nir_dest_bit_size(intr->dest);
    unsigned num_components = nir_dest_num_components(intr->dest);
-   unsigned num_bits = num_components * bit_size;
+   unsigned num_bits = num_components * src_bit_size;
 
    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
    unsigned comp_idx = 0;
 
-   /* We need to split loads in 16byte chunks because that's the optimal
-    * granularity of bufferLoad(). Minimum alignment is 4byte, which saves
-    * from us from extra complexity to extract >= 32 bit components.
+   /* We need to split loads in 4-component chunks because that's the optimal
+    * granularity of bufferLoad(). Minimum alignment is 2-byte.
     */
-   for (unsigned i = 0; i < num_bits; i += 4 * 32) {
-      /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
-       * load.
-       */
-      unsigned subload_num_bits = MIN2(num_bits - i, 4 * 32);
+   for (unsigned i = 0; i < num_bits; i += 4 * store_bit_size) {
+      /* For each 4-component chunk (or smaller) we generate a N-bit ssbo vec load. */
+      unsigned subload_num_bits = MIN2(num_bits - i, 4 * store_bit_size);
 
       /* The number of components to store depends on the number of bytes. */
-      nir_ssa_def *vec32 =
-         nir_load_ssbo(b, DIV_ROUND_UP(subload_num_bits, 32), 32,
+      nir_ssa_def *result =
+         nir_load_ssbo(b, DIV_ROUND_UP(subload_num_bits, store_bit_size), store_bit_size,
                        buffer, nir_iadd(b, offset, nir_imm_int(b, i / 8)),
-                       .align_mul = 4,
+                       .align_mul = store_bit_size / 8,
                        .align_offset = 0,
                        .access = access);
 
-      /* If we have 2 bytes or less to load we need to adjust the u32 value so
+      /* If we have an unaligned load we need to adjust the result value so
       * we can always extract the LSB.
       */
-      if (subload_num_bits <= 16) {
-         nir_ssa_def *shift = nir_imul(b, nir_iand(b, intr->src[1].ssa, nir_imm_int(b, 3)),
+      if (nir_intrinsic_align(intr) < store_bit_size / 8) {
+         nir_ssa_def *shift = nir_imul(b, nir_iand(b, intr->src[1].ssa, nir_imm_int(b, offset_mask)),
                                        nir_imm_int(b, 8));
-         vec32 = nir_ushr(b, vec32, shift);
+         result = nir_ushr(b, result, nir_u2uN(b, shift, store_bit_size));
      }
 
      /* And now comes the pack/unpack step to match the original type. */
-      nir_ssa_def *temp_vec = nir_extract_bits(b, &vec32, 1, 0, subload_num_bits / bit_size, bit_size);
-      for (unsigned comp = 0; comp < subload_num_bits / bit_size; ++comp, ++comp_idx)
+      nir_ssa_def *temp_vec = nir_extract_bits(b, &result, 1, 0, subload_num_bits / src_bit_size, src_bit_size);
+      for (unsigned comp = 0; comp < subload_num_bits / src_bit_size; ++comp, ++comp_idx)
         comps[comp_idx] = nir_channel(b, temp_vec, comp);
   }
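
For illustration, the chunking above can be traced with assumed values: a 6-component 16-bit load with min_bit_size == 16 gives store_bit_size = CLAMP(16, 16, 32) = 16 and num_bits = 96, so the loop emits one 4-component and one 2-component 16-bit nir_load_ssbo. A standalone, hypothetical sketch of just that arithmetic (not part of the pass):

#include <stdio.h>

#define MIN2(a, b) ((a) < (b) ? (a) : (b))
#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

int main(void)
{
   const unsigned src_bit_size = 16;             /* assumed dest bit size */
   const unsigned store_bit_size = 16;           /* CLAMP(16, 16, 32) */
   const unsigned num_bits = 6 * src_bit_size;   /* assumed 6-component load */

   for (unsigned i = 0; i < num_bits; i += 4 * store_bit_size) {
      unsigned subload_num_bits = MIN2(num_bits - i, 4 * store_bit_size);
      printf("load_ssbo: %u x %u-bit components at byte offset %u\n",
             DIV_ROUND_UP(subload_num_bits, store_bit_size),
             store_bit_size, i / 8);
   }
   return 0;
}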
@@ -273,7 +274,7 @@ lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
 }
 
 static bool
-lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
+lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr, unsigned min_bit_size)
 {
    b->cursor = nir_before_instr(&intr->instr);
@@ -283,11 +284,14 @@ lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
 
    nir_ssa_def *val = intr->src[0].ssa;
    nir_ssa_def *buffer = intr->src[1].ssa;
-   nir_ssa_def *offset = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, ~3));
-   unsigned bit_size = val->bit_size;
+   unsigned src_bit_size = val->bit_size;
+   unsigned store_bit_size = CLAMP(src_bit_size, min_bit_size, 32);
    unsigned num_components = val->num_components;
-   unsigned num_bits = num_components * bit_size;
+   unsigned num_bits = num_components * src_bit_size;
+
+   unsigned offset_mask = store_bit_size / 8 - 1;
+   nir_ssa_def *offset = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, ~offset_mask));
 
    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS] = { 0 };
    unsigned comp_idx = 0;
@@ -297,75 +301,74 @@ lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
      if (write_mask & (1 << i))
         comps[i] = nir_channel(b, val, i);
 
-   /* We split stores in 16byte chunks because that's the optimal granularity
-    * of bufferStore(). Minimum alignment is 4byte, which saves from us from
-    * extra complexity to store >= 32 bit components.
-    */
+   /* We split stores in 4-component chunks because that's the optimal granularity
+    * of bufferStore(). Minimum alignment is 2-byte. */
    unsigned bit_offset = 0;
    while (true) {
      /* Skip over holes in the write mask */
      while (comp_idx < num_components && comps[comp_idx] == NULL) {
         comp_idx++;
-         bit_offset += bit_size;
+         bit_offset += src_bit_size;
      }
      if (comp_idx >= num_components)
         break;
 
-      /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
+      /* For each 4-component chunk (or smaller) we generate a ssbo vec
       * store. If a component is skipped by the write mask, do a smaller
       * sub-store
       */
      unsigned num_src_comps_stored = 0, substore_num_bits = 0;
      while(num_src_comps_stored + comp_idx < num_components &&
            substore_num_bits + bit_offset < num_bits &&
-            substore_num_bits < 4 * 32 &&
+            substore_num_bits < 4 * store_bit_size &&
            comps[comp_idx + num_src_comps_stored]) {
         ++num_src_comps_stored;
-         substore_num_bits += bit_size;
+         substore_num_bits += src_bit_size;
      }
 
-      if (substore_num_bits == 48) {
-         /* Split this into two, one unmasked store of the first 32 bits,
+      if (substore_num_bits > store_bit_size &&
+          substore_num_bits % store_bit_size != 0) {
+         /* Split this into two, one unmasked store of the first bits,
          * and then the second loop iteration will handle a masked store
-          * for the other 16. */
+          * for the rest. */
         assert(num_src_comps_stored == 3);
         --num_src_comps_stored;
-         substore_num_bits = 32;
+         substore_num_bits = store_bit_size;
      }
 
      nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, bit_offset / 8));
-      nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
-                                               num_src_comps_stored);
+      nir_ssa_def *store_vec = load_comps_to_vec(b, src_bit_size, &comps[comp_idx],
+                                                 num_src_comps_stored, store_bit_size);
 
      nir_intrinsic_instr *store;
 
-      if (substore_num_bits < 32) {
-         nir_ssa_def *mask = nir_imm_int(b, (1 << substore_num_bits) - 1);
+      if (substore_num_bits < store_bit_size) {
+         nir_ssa_def *mask = nir_imm_intN_t(b, (1 << substore_num_bits) - 1, store_bit_size);
 
-         /* If we have small alignments we need to place them correctly in the u32 component. */
-         if (nir_intrinsic_align(intr) <= 2) {
-            nir_ssa_def *pos = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, 3));
-            nir_ssa_def *shift = nir_imul_imm(b, pos, 8);
+         /* If we have small alignments we need to place them correctly in the component. */
+         if (nir_intrinsic_align(intr) <= store_bit_size / 8) {
+            nir_ssa_def *pos = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, offset_mask));
+            nir_ssa_def *shift = nir_u2uN(b, nir_imul_imm(b, pos, 8), store_bit_size);
 
-            vec32 = nir_ishl(b, vec32, shift);
+            store_vec = nir_ishl(b, store_vec, shift);
            mask = nir_ishl(b, mask, shift);
        }
 
        store = nir_intrinsic_instr_create(b->shader,
                                           nir_intrinsic_store_ssbo_masked_dxil);
-         store->src[0] = nir_src_for_ssa(vec32);
+         store->src[0] = nir_src_for_ssa(store_vec);
        store->src[1] = nir_src_for_ssa(nir_inot(b, mask));
        store->src[2] = nir_src_for_ssa(buffer);
        store->src[3] = nir_src_for_ssa(local_offset);
      } else {
        store = nir_intrinsic_instr_create(b->shader,
                                           nir_intrinsic_store_ssbo);
-         store->src[0] = nir_src_for_ssa(vec32);
+         store->src[0] = nir_src_for_ssa(store_vec);
        store->src[1] = nir_src_for_ssa(buffer);
        store->src[2] = nir_src_for_ssa(local_offset);
 
-         nir_intrinsic_set_align(store, 4, 0);
+         nir_intrinsic_set_align(store, store_bit_size / 8, 0);
      }
 
      /* The number of components to store depends on the number of bits. */
-      store->num_components = DIV_ROUND_UP(substore_num_bits, 32);
+      store->num_components = DIV_ROUND_UP(substore_num_bits, store_bit_size);
      nir_builder_instr_insert(b, &store->instr);
 
      comp_idx += num_src_comps_stored;
      bit_offset += substore_num_bits;
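
The masked path above passes nir_inot(mask) as the write mask of store_ssbo_masked_dxil, with the mask shifted to the sub-word byte position when the alignment is small. A hypothetical helper showing the same mask arithmetic (on this path substore_num_bits is always less than store_bit_size, which is at most 32):

#include <stdint.h>

/* Bits that the masked store will overwrite, placed at byte position `pos`
 * inside the store_bit_size-wide word; the pass hands the inverse of this
 * value to the intrinsic. */
uint32_t substore_write_bits(unsigned substore_num_bits, unsigned pos)
{
   uint32_t mask = (1u << substore_num_bits) - 1;   /* substore_num_bits < 32 */
   return mask << (pos * 8);                        /* shift = pos * 8 */
}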
@@ -528,8 +531,8 @@ lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
      /* For each 4byte chunk (or smaller) we generate a 32bit scalar store. */
      unsigned substore_num_bits = MIN2(num_bits - i, step);
      nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
-      nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
-                                               substore_num_bits / bit_size);
+      nir_ssa_def *vec32 = load_comps_to_vec(b, bit_size, &comps[comp_idx],
+                                             substore_num_bits / bit_size, 32);
      nir_ssa_def *index = nir_ushr(b, local_offset, nir_imm_int(b, 2));
 
      /* For anything less than 32bits we need to use the masked version of the
@@ -687,7 +690,8 @@ lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intr)
 }
 
 bool
-dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
+dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
+                                    const struct dxil_nir_lower_loads_stores_options *options)
 {
    bool progress = false;
@@ -714,7 +718,7 @@ dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
               progress |= lower_32b_offset_load(&b, intr);
               break;
            case nir_intrinsic_load_ssbo:
-              progress |= lower_load_ssbo(&b, intr);
+              progress |= lower_load_ssbo(&b, intr, options->use_16bit_ssbo ? 16 : 32);
               break;
            case nir_intrinsic_load_ubo:
               progress |= lower_load_ubo(&b, intr);
@@ -724,7 +728,7 @@ dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
               progress |= lower_32b_offset_store(&b, intr);
               break;
            case nir_intrinsic_store_ssbo:
-              progress |= lower_store_ssbo(&b, intr);
+              progress |= lower_store_ssbo(&b, intr, options->use_16bit_ssbo ? 16 : 32);
               break;
            default:
               break;


@@ -37,7 +37,11 @@ bool dxil_nir_lower_16bit_conv(nir_shader *shader);
 bool dxil_nir_lower_x2b(nir_shader *shader);
 bool dxil_nir_lower_fquantize2f16(nir_shader *shader);
 bool dxil_nir_lower_ubo_to_temp(nir_shader *shader);
-bool dxil_nir_lower_loads_stores_to_dxil(nir_shader *shader);
+struct dxil_nir_lower_loads_stores_options {
+   bool use_16bit_ssbo;
+};
+bool dxil_nir_lower_loads_stores_to_dxil(nir_shader *shader,
+                                         const struct dxil_nir_lower_loads_stores_options *options);
 bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader);
 bool dxil_nir_lower_deref_ssbo(nir_shader *shader);
 bool dxil_nir_opt_alu_deref_srcs(nir_shader *shader);


@@ -1082,7 +1082,10 @@ dxil_spirv_nir_passes(nir_shader *nir,
    NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
    NIR_PASS_V(nir, dxil_nir_split_clip_cull_distance);
-   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
+   const struct dxil_nir_lower_loads_stores_options loads_stores_options = {
+      .use_16bit_ssbo = false,
+   };
+   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
    NIR_PASS_V(nir, dxil_nir_split_typed_samplers);
    NIR_PASS_V(nir, dxil_nir_lower_ubo_array_one_to_static);
    NIR_PASS_V(nir, nir_opt_dce);