zink: move bo load offset adjustment to compiler passes

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13484>
This commit is contained in:
Mike Blumenkrantz 2021-09-21 17:11:05 -04:00 committed by Marge Bot
parent 14f7eb9d4c
commit ee22cd619f
2 changed files with 35 additions and 10 deletions

View file

@ -1941,8 +1941,6 @@ emit_load_bo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
/* destination type for the load */
SpvId type = get_dest_uvec_type(ctx, &intr->dest);
/* an id of an array member in bytes */
SpvId uint_size = emit_uint_const(ctx, 32, MIN2(bit_size, 32) / 8);
/* we grab a single array member at a time, so it's a pointer to a uint */
SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
@ -1955,18 +1953,13 @@ emit_load_bo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
* uint base[array_size];
* };
*
* where 'array_size' is set as though every member of the ubo takes up a vec4,
* even if it's only a vec2 or a float.
*
* first, access 'base'
*/
SpvId member = emit_uint_const(ctx, 32, 0);
/* this is the offset (in bytes) that we're accessing:
/* this is the array member we're accessing:
* it may be a const value or it may be dynamic in the shader
*/
SpvId offset = get_src(ctx, &intr->src[1]);
/* calculate the byte offset in the array */
SpvId vec_offset = emit_binop(ctx, SpvOpUDiv, uint_type, offset, uint_size);
/* OpAccessChain takes an array of indices that drill into a hierarchy based on the type:
* index 0 is accessing 'base'
* index 1 is accessing 'base[index 1]'
@ -1976,7 +1969,7 @@ emit_load_bo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
* (composite|vector)_extract both take literals
*/
for (unsigned i = 0; i < num_components; i++) {
SpvId indices[2] = { member, vec_offset };
SpvId indices[2] = { member, offset };
SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
bo, indices,
ARRAY_SIZE(indices));
@ -1985,8 +1978,9 @@ emit_load_bo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
constituents[i] = emit_atomic(ctx, SpvOpAtomicLoad, uint_type, ptr, 0, 0);
else
constituents[i] = spirv_builder_emit_load(&ctx->builder, uint_type, ptr);
/* increment to the next member index for the next load */
vec_offset = emit_binop(ctx, SpvOpIAdd, uint_type, vec_offset, one);
offset = emit_binop(ctx, SpvOpIAdd, uint_type, offset, one);
}
/* if we're loading a 64bit value, we have to reassemble all the u32 values we've loaded into u64 values

View file

@ -637,6 +637,32 @@ decompose_attribs(nir_shader *nir, uint32_t decomposed_attrs, uint32_t decompose
return true;
}
/* Per-instruction callback for rewrite_bo_access().
 *
 * For load_ssbo / load_ubo / load_ubo_vec4 intrinsics, src[1] is replaced by
 * src[1] divided by the size in bytes of a single loaded element, where the
 * element size is the destination bit size capped at 32 bits.
 *
 * 'data' is unused.
 *
 * Returns true if the instruction was rewritten, false otherwise.
 */
static bool
rewrite_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
{
   /* only intrinsics are of interest */
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   /* only buffer loads are rewritten; every other intrinsic
    * (store_ssbo included) is left untouched
    */
   if (intr->intrinsic != nir_intrinsic_load_ssbo &&
       intr->intrinsic != nir_intrinsic_load_ubo &&
       intr->intrinsic != nir_intrinsic_load_ubo_vec4)
      return false;

   /* the new instructions go immediately before the load */
   b->cursor = nir_before_instr(instr);

   /* size in bytes of one element: at most 4 (i.e. 32 bits) */
   const unsigned elem_bytes = MIN2(nir_dest_bit_size(intr->dest), 32) / 8;

   /* swap src[1] for (src[1] / elem_bytes) */
   nir_instr_rewrite_src_ssa(instr, &intr->src[1],
                             nir_udiv_imm(b, intr->src[1].ssa, elem_bytes));
   return true;
}
/* Applies rewrite_bo_access_instr() to the instructions of 'shader' through
 * nir_shader_instructions_pass(), passing nir_metadata_dominance as the
 * metadata argument (presumably the metadata the pass preserves) and no
 * callback data.
 *
 * Returns whatever nir_shader_instructions_pass() returns.
 */
static bool
rewrite_bo_access(nir_shader *shader)
{
   const bool progress = nir_shader_instructions_pass(shader,
                                                      rewrite_bo_access_instr,
                                                      nir_metadata_dominance,
                                                      NULL);
   return progress;
}
static void
assign_producer_var_io(gl_shader_stage stage, nir_variable *var, unsigned *reserved, unsigned char *slot_map)
{
@ -872,6 +898,8 @@ zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs, nir_shad
default: break;
}
}
if (screen->driconf.inline_uniforms)
NIR_PASS_V(nir, rewrite_bo_access);
if (inlined_uniforms) {
optimize_nir(nir);
@ -1384,6 +1412,9 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
nir->info.fs.color_is_dual_source ? 1 : 8);
NIR_PASS_V(nir, lower_64bit_vertex_attribs);
NIR_PASS_V(nir, unbreak_bos);
/* run in compile if there could be inlined uniforms */
if (!screen->driconf.inline_uniforms)
NIR_PASS_V(nir, rewrite_bo_access);
if (zink_debug & ZINK_DEBUG_NIR) {
fprintf(stderr, "NIR shader:\n---8<---\n");