microsoft/compiler: Emit const accesses as load_deref

There are a few changes in here that are closely inter-related.

First, we stop lowering load_deref on shader_temp to load_ptr_dxil,
and just leave it as load_deref. In order for that to work, we need
the derefs to be in a shape that's acceptable to DXIL, so the only
current producer of shader_temp loads (the CLC frontend) needs to
run some lowering passes on them first.
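As a rough plain-C illustration of the shape those passes produce (the struct,
data, and names below are invented for the example, not taken from the driver):

   #include <stdio.h>

   /* Before lowering: a constant struct-of-arrays, addressed with a
    * multi-level deref chain. */
   struct vert { float pos[3]; int id; };
   static const struct vert verts[2] = {
      { { 0.f, 1.f, 2.f }, 7 }, { { 3.f, 4.f, 5.f }, 8 }
   };

   /* After nir_split_struct_vars + dxil_nir_flatten_var_arrays,
    * conceptually: one single-level scalar array per leaf field. */
   static const float verts_pos[6] = { 0.f, 1.f, 2.f, 3.f, 4.f, 5.f };
   static const int   verts_id[2]  = { 7, 8 };

   int main(void)
   {
      unsigned i = 1, c = 2; /* dynamic indices, as in an indirect access */
      /* verts[i].pos[c] becomes a single-level indexed load: */
      printf("%f == %f\n", verts[i].pos[c], verts_pos[i * 3 + c]);
      printf("%d == %d\n", verts[i].id, verts_id[i]);
      return 0;
   }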

The DXIL backend is augmented to just write out deref indices while
walking a deref chain; these get combined into a GEP instruction when
the load op is emitted. For non-mesh/raytracing shaders, the constant
variables are required to be single-level scalar arrays, but the
complexity here is preparation for when we don't need to do that anymore.
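A toy model of that walk, using simplified stand-in types rather than the real
NIR/DXIL API: one GEP index per deref level, where the var-level deref
contributes the leading 0 that steps through the global's pointer (the pointer
operand itself is prepended separately in the real code):

   #include <stdio.h>

   enum deref_kind { DEREF_VAR, DEREF_ARRAY, DEREF_STRUCT };
   struct deref { enum deref_kind kind; int index; struct deref *parent; };

   /* Recurse to the root (the variable) first so indices come out
    * outermost-first; the variable level contributes the leading 0. */
   static int collect_gep_indices(const struct deref *d, int *out)
   {
      int n = d->parent ? collect_gep_indices(d->parent, out) : 0;
      out[n] = (d->kind == DEREF_VAR) ? 0 : d->index;
      return n + 1;
   }

   int main(void)
   {
      /* var -> array element 3 -> struct field 1 */
      struct deref var = { DEREF_VAR, 0, NULL };
      struct deref arr = { DEREF_ARRAY, 3, &var };
      struct deref fld = { DEREF_STRUCT, 1, &arr };
      int idx[8];
      int n = collect_gep_indices(&fld, idx);
      for (int i = 0; i < n; ++i)
         printf("gep index %d: %d\n", i, idx[i]); /* prints 0, 3, 1 */
      return 0;
   }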

Additionally, the const lookups are changed from using a hash table
to just putting an index on the variable.
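In outline (simplified stand-in types, not the driver's actual ones): number
the constant variables once up front, and every later lookup becomes a flat
array index instead of a hash-table search.

   #include <stdio.h>

   struct variable { unsigned driver_location; };
   struct value { int dummy; }; /* stands in for struct dxil_value */

   /* Lookup is a flat array index: no hashing, no hash_entry allocation. */
   static struct value *lookup_const(struct value **consts,
                                     const struct variable *var)
   {
      return consts[var->driver_location];
   }

   int main(void)
   {
      struct value a = {1}, b = {2};
      struct value *consts[2] = { &a, &b };           /* filled once up front */
      struct variable var = { .driver_location = 1 }; /* numbered up front */
      printf("%d\n", lookup_const(consts, &var)->dummy); /* prints 2 */
      return 0;
   }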

All of this together is enough to enable the test authored forever ago
which uses indirect array access into a const packed struct. The
load_ptr_dxil handling didn't deal with packed structs / unaligned
accesses, but now that we're in a logical address space with derefs
instead of a physical one, there's no alignment to deal with anymore,
and the fact that the struct is packed goes out the window.
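Concretely, for the struct used by the re-enabled test (a standalone
illustration, not driver code):

   #include <stdint.h>
   #include <stddef.h>
   #include <stdio.h>

   #pragma pack(push, 1)
   struct s { uint8_t uc; uint64_t ul; uint16_t us; };
   #pragma pack(pop)

   int main(void)
   {
      /* With pack(1) the uint64_t lives at byte offset 1 and the uint16_t
       * at byte offset 9; neither is loadable as an aligned 32-bit word,
       * which is what the old load_ptr_dxil path assumed. A deref of
       * field 1 or field 2 needs no byte offsets at all. */
      printf("offsetof ul = %zu\n", offsetof(struct s, ul)); /* 1 */
      printf("offsetof us = %zu\n", offsetof(struct s, us)); /* 9 */
      printf("sizeof(s)   = %zu\n", sizeof(struct s));       /* 11 */
      return 0;
   }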

This removes one custom DXIL intrinsic.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23173>
Jesse Natalie, 2023-05-19 08:57:03 -07:00 (committed by Marge Bot)
parent 572e02a3b7
commit f9b0382faf
5 changed files with 163 additions and 235 deletions


@@ -1269,8 +1269,6 @@ intrinsic("store_scratch_dxil", [1, 1])
 load("shared_dxil", [1], [], [CAN_ELIMINATE])
 # src[] = { index }.
 load("scratch_dxil", [1], [], [CAN_ELIMINATE])
-# src[] = { deref_var, offset }
-load("ptr_dxil", [1, 1], [], [])
 # src[] = { index, 16-byte-based-offset }
 load("ubo_dxil", [1, 1], [], [CAN_ELIMINATE, CAN_REORDER])


@@ -885,7 +885,17 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
    }
 
    NIR_PASS_V(nir, nir_lower_memcpy);
   // Attempt to preserve derefs to constants by moving them to shader_temp
    NIR_PASS_V(nir, dxil_nir_lower_constant_to_temp);
+
+   // While inserting new var derefs for our "logical" addressing mode, temporarily
+   // switch the pointer size to 32-bit.
+   nir->info.cs.ptr_size = 32;
+   NIR_PASS_V(nir, nir_split_struct_vars, nir_var_shader_temp);
+   NIR_PASS_V(nir, dxil_nir_flatten_var_arrays, nir_var_shader_temp);
+   NIR_PASS_V(nir, dxil_nir_lower_var_bit_size, nir_var_shader_temp,
+              (supported_int_sizes & 16) ? 16 : 32, (supported_int_sizes & 64) ? 64 : 32);
+   nir->info.cs.ptr_size = 64;
+
    NIR_PASS_V(nir, clc_lower_constant_to_ssbo, out_dxil->kernel, &uav_id);
    NIR_PASS_V(nir, clc_lower_global_to_ssbo);

@@ -895,7 +905,7 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
    NIR_PASS_V(nir, dxil_nir_lower_deref_ssbo);
-   NIR_PASS_V(nir, dxil_nir_split_unaligned_loads_stores, nir_var_all);
+   NIR_PASS_V(nir, dxil_nir_split_unaligned_loads_stores, nir_var_all & ~nir_var_shader_temp);
    assert(nir->info.cs.ptr_size == 64);
    NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ssbo,


@@ -2143,7 +2143,7 @@ TEST_F(ComputeTest, packed_struct_local)
    }
 }
 
-TEST_F(ComputeTest, DISABLED_packed_struct_const)
+TEST_F(ComputeTest, packed_struct_const)
 {
 #pragma pack(push, 1)
    struct s { uint8_t uc; uint64_t ul; uint16_t us; };


@@ -67,72 +67,6 @@ load_comps_to_vec(nir_builder *b, unsigned src_bit_size,
    return nir_vec(b, dst_comps, num_dst_comps);
 }
 
-static nir_ssa_def *
-build_load_ptr_dxil(nir_builder *b, nir_deref_instr *deref, nir_ssa_def *idx)
-{
-   return nir_load_ptr_dxil(b, 1, 32, &deref->dest.ssa, idx);
-}
-
-static bool
-lower_load_deref(nir_builder *b, nir_intrinsic_instr *intr)
-{
-   assert(intr->dest.is_ssa);
-
-   b->cursor = nir_before_instr(&intr->instr);
-
-   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
-   if (!nir_deref_mode_is(deref, nir_var_shader_temp))
-      return false;
-   nir_ssa_def *ptr = nir_u2u32(b, nir_build_deref_offset(b, deref, cl_type_size_align));
-   nir_ssa_def *offset = nir_iand(b, ptr, nir_inot(b, nir_imm_int(b, 3)));
-
-   assert(intr->dest.is_ssa);
-   unsigned num_components = nir_dest_num_components(intr->dest);
-   unsigned bit_size = nir_dest_bit_size(intr->dest);
-   unsigned load_size = MAX2(32, bit_size);
-   unsigned num_bits = num_components * bit_size;
-   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
-   unsigned comp_idx = 0;
-
-   nir_deref_path path;
-   nir_deref_path_init(&path, deref, NULL);
-   nir_ssa_def *base_idx = nir_ishr(b, offset, nir_imm_int(b, 2 /* log2(32 / 8) */));
-
-   /* Split loads into 32-bit chunks */
-   for (unsigned i = 0; i < num_bits; i += load_size) {
-      unsigned subload_num_bits = MIN2(num_bits - i, load_size);
-      nir_ssa_def *idx = nir_iadd(b, base_idx, nir_imm_int(b, i / 32));
-      nir_ssa_def *vec32 = build_load_ptr_dxil(b, path.path[0], idx);
-
-      if (load_size == 64) {
-         idx = nir_iadd(b, idx, nir_imm_int(b, 1));
-         vec32 = nir_vec2(b, vec32,
-                             build_load_ptr_dxil(b, path.path[0], idx));
-      }
-
-      /* If we have 2 bytes or less to load we need to adjust the u32 value so
-       * we can always extract the LSB.
-       */
-      if (subload_num_bits <= 16) {
-         nir_ssa_def *shift = nir_imul(b, nir_iand(b, ptr, nir_imm_int(b, 3)),
-                                          nir_imm_int(b, 8));
-         vec32 = nir_ushr(b, vec32, shift);
-      }
-
-      /* And now comes the pack/unpack step to match the original type. */
-      nir_ssa_def *temp_vec = nir_extract_bits(b, &vec32, 1, 0, subload_num_bits / bit_size, bit_size);
-      for (unsigned comp = 0; comp < subload_num_bits / bit_size; ++comp, ++comp_idx)
-         comps[comp_idx] = nir_channel(b, temp_vec, comp);
-   }
-
-   nir_deref_path_finish(&path);
-   assert(comp_idx == num_components);
-   nir_ssa_def *result = nir_vec(b, comps, num_components);
-   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
-   nir_instr_remove(&intr->instr);
-   return true;
-}
-
 static nir_ssa_def *
 ubo_load_select_32b_comps(nir_builder *b, nir_ssa_def *vec32,
                           nir_ssa_def *offset, unsigned alignment)

@@ -612,12 +546,6 @@ dxil_nir_lower_constant_to_temp(nir_shader *nir)
          /* Change the variable mode. */
          var->data.mode = nir_var_shader_temp;
 
-         /* Make sure the variable has a name.
-          * DXIL variables must have names.
-          */
-         if (!var->name)
-            var->name = ralloc_asprintf(nir, "global_%d", exec_list_length(&nir->variables));
-
          progress = true;
       }

@@ -1034,9 +962,6 @@ dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
 
          switch (intr->intrinsic) {
-         case nir_intrinsic_load_deref:
-            progress |= lower_load_deref(&b, intr);
-            break;
          case nir_intrinsic_load_shared:
         case nir_intrinsic_load_scratch:
             progress |= lower_32b_offset_load(&b, intr);


@@ -32,6 +32,7 @@
 #include "dxil_signature.h"
 #include "nir/nir_builder.h"
+#include "nir_deref.h"
 #include "util/ralloc.h"
 #include "util/u_debug.h"
 #include "util/u_dynarray.h"
@@ -583,7 +584,7 @@ struct ntd_context {
    const struct dxil_value *sharedvars;
    const struct dxil_value *scratchvars;
-   struct hash_table *consts;
+   const struct dxil_value **consts;
 
    nir_variable *ps_front_face;
    nir_variable *system_value[SYSTEM_VALUE_MAX];
@@ -1405,114 +1406,124 @@ emit_uav_var(struct ntd_context *ctx, nir_variable *var, unsigned count)
                         res_kind, name);
 }
 
-static void
-var_fill_const_array_with_vector_or_scalar(struct ntd_context *ctx,
-                                           const struct nir_constant *c,
-                                           const struct glsl_type *type,
-                                           void *const_vals,
-                                           unsigned int offset)
+static const struct dxil_value *
+get_value_for_const(struct dxil_module *mod, nir_const_value *c, const struct dxil_type *type)
 {
-   assert(glsl_type_is_vector_or_scalar(type));
-   unsigned int components = glsl_get_vector_elements(type);
-   unsigned bit_size = glsl_get_bit_size(type);
-   unsigned int increment = bit_size / 8;
-
-   for (unsigned int comp = 0; comp < components; comp++) {
-      uint8_t *dst = (uint8_t *)const_vals + offset;
-
-      switch (bit_size) {
-      case 64:
-         memcpy(dst, &c->values[comp].u64, sizeof(c->values[0].u64));
-         break;
-      case 32:
-         memcpy(dst, &c->values[comp].u32, sizeof(c->values[0].u32));
-         break;
-      case 16:
-         memcpy(dst, &c->values[comp].u16, sizeof(c->values[0].u16));
-         break;
-      case 8:
-         assert(glsl_base_type_is_integer(glsl_get_base_type(type)));
-         memcpy(dst, &c->values[comp].u8, sizeof(c->values[0].u8));
-         break;
-      default:
-         unreachable("unexpected bit-size");
-      }
-
-      offset += increment;
+   if (type == mod->int1_type) return dxil_module_get_int1_const(mod, c->b);
+   if (type == mod->float32_type) return dxil_module_get_float_const(mod, c->f32);
+   if (type == mod->int32_type) return dxil_module_get_int32_const(mod, c->i32);
+   if (type == mod->int16_type) {
+      mod->feats.min_precision = true;
+      return dxil_module_get_int16_const(mod, c->i16);
    }
+   if (type == mod->int64_type) {
+      mod->feats.int64_ops = true;
+      return dxil_module_get_int64_const(mod, c->i64);
+   }
+   if (type == mod->float16_type) {
+      mod->feats.min_precision = true;
+      return dxil_module_get_float16_const(mod, c->u16);
+   }
+   if (type == mod->float64_type) {
+      mod->feats.doubles = true;
+      return dxil_module_get_double_const(mod, c->f64);
+   }
+   unreachable("Invalid type");
 }
 
-static void
-var_fill_const_array(struct ntd_context *ctx, const struct nir_constant *c,
-                     const struct glsl_type *type, void *const_vals,
-                     unsigned int offset)
+static const struct dxil_type *
+get_type_for_glsl_base_type(struct dxil_module *mod, enum glsl_base_type type)
 {
-   assert(!glsl_type_is_interface(type));
+   uint32_t bit_size = glsl_base_type_bit_size(type);
+   if (nir_alu_type_get_base_type(nir_get_nir_type_for_glsl_base_type(type)) == nir_type_float)
+      return dxil_module_get_float_type(mod, bit_size);
+   return dxil_module_get_int_type(mod, bit_size);
+}
+
+static const struct dxil_type *
+get_type_for_glsl_type(struct dxil_module *mod, const struct glsl_type *type)
+{
+   if (glsl_type_is_scalar(type))
+      return get_type_for_glsl_base_type(mod, glsl_get_base_type(type));
+
+   if (glsl_type_is_vector(type))
+      return dxil_module_get_vector_type(mod, get_type_for_glsl_base_type(mod, glsl_get_base_type(type)),
+                                         glsl_get_vector_elements(type));
+
+   if (glsl_type_is_array(type))
+      return dxil_module_get_array_type(mod, get_type_for_glsl_type(mod, glsl_get_array_element(type)),
+                                        glsl_array_size(type));
+
+   assert(glsl_type_is_struct(type));
+   uint32_t size = glsl_get_length(type);
+   const struct dxil_type **fields = calloc(sizeof(const struct dxil_type *), size);
+   for (uint32_t i = 0; i < size; ++i)
+      fields[i] = get_type_for_glsl_type(mod, glsl_get_struct_field(type, i));
+   const struct dxil_type *ret = dxil_module_get_struct_type(mod, glsl_get_type_name(type), fields, size);
+   free((void *)fields);
+   return ret;
+}
+
+static const struct dxil_value *
+get_value_for_const_aggregate(struct dxil_module *mod, nir_constant *c, const struct glsl_type *type)
+{
+   const struct dxil_type *dxil_type = get_type_for_glsl_type(mod, type);
 
    if (glsl_type_is_vector_or_scalar(type)) {
-      var_fill_const_array_with_vector_or_scalar(ctx, c, type,
-                                                 const_vals,
-                                                 offset);
-   } else if (glsl_type_is_array(type)) {
-      assert(!glsl_type_is_unsized_array(type));
-      const struct glsl_type *without = glsl_get_array_element(type);
-      unsigned stride = glsl_get_explicit_stride(type);
-
-      for (unsigned elt = 0; elt < glsl_get_length(type); elt++) {
-         var_fill_const_array(ctx, c->elements[elt], without,
-                              const_vals, offset);
-         offset += stride;
-      }
-   } else if (glsl_type_is_struct(type)) {
-      for (unsigned int elt = 0; elt < glsl_get_length(type); elt++) {
-         const struct glsl_type *elt_type = glsl_get_struct_field(type, elt);
-         unsigned field_offset = glsl_get_struct_field_offset(type, elt);
-
-         var_fill_const_array(ctx, c->elements[elt],
-                              elt_type, const_vals,
-                              offset + field_offset);
-      }
-   } else
-      unreachable("unknown GLSL type in var_fill_const_array");
+      const struct dxil_type *element_type = get_type_for_glsl_base_type(mod, glsl_get_base_type(type));
+      const struct dxil_value *elements[NIR_MAX_VEC_COMPONENTS];
+      for (uint32_t i = 0; i < glsl_get_vector_elements(type); ++i)
+         elements[i] = get_value_for_const(mod, &c->values[i], element_type);
+      if (glsl_type_is_scalar(type))
+         return elements[0];
+      return dxil_module_get_vector_const(mod, dxil_type, elements);
+   }
+
+   uint32_t num_values = glsl_get_length(type);
+   assert(num_values == c->num_elements);
+   const struct dxil_value **values = calloc(sizeof(const struct dxil_value *), num_values);
+   const struct dxil_value *ret;
+   if (glsl_type_is_array(type)) {
+      const struct glsl_type *element_type = glsl_get_array_element(type);
+      for (uint32_t i = 0; i < num_values; ++i)
+         values[i] = get_value_for_const_aggregate(mod, c->elements[i], element_type);
+      ret = dxil_module_get_array_const(mod, dxil_type, values);
+   } else {
+      for (uint32_t i = 0; i < num_values; ++i)
+         values[i] = get_value_for_const_aggregate(mod, c->elements[i], glsl_get_struct_field(type, i));
+      ret = dxil_module_get_struct_const(mod, dxil_type, values);
+   }
+   free((void *)values);
+   return ret;
 }
 
 static bool
 emit_global_consts(struct ntd_context *ctx)
 {
+   uint32_t index = 0;
    nir_foreach_variable_with_modes(var, ctx->shader, nir_var_shader_temp) {
       assert(var->constant_initializer);
+      var->data.driver_location = index++;
+   }
+
+   ctx->consts = ralloc_array(ctx->ralloc_ctx, const struct dxil_value *, index);
 
-      unsigned int num_members = DIV_ROUND_UP(glsl_get_cl_size(var->type), 4);
-      uint32_t *const_ints = ralloc_array(ctx->ralloc_ctx, uint32_t, num_members);
-      var_fill_const_array(ctx, var->constant_initializer, var->type,
-                           const_ints, 0);
-      const struct dxil_value **const_vals =
-         ralloc_array(ctx->ralloc_ctx, const struct dxil_value *, num_members);
-      if (!const_vals)
-         return false;
-      for (int i = 0; i < num_members; i++)
-         const_vals[i] = dxil_module_get_int32_const(&ctx->mod, const_ints[i]);
+   nir_foreach_variable_with_modes(var, ctx->shader, nir_var_shader_temp) {
+      if (!var->name)
+         var->name = ralloc_asprintf(var, "const_%d", var->data.driver_location);
 
-      const struct dxil_type *elt_type = dxil_module_get_int_type(&ctx->mod, 32);
-      if (!elt_type)
-         return false;
-      const struct dxil_type *type =
-         dxil_module_get_array_type(&ctx->mod, elt_type, num_members);
-      if (!type)
-         return false;
       const struct dxil_value *agg_vals =
-         dxil_module_get_array_const(&ctx->mod, type, const_vals);
+         get_value_for_const_aggregate(&ctx->mod, var->constant_initializer, var->type);
       if (!agg_vals)
         return false;
 
-      const struct dxil_value *gvar = dxil_add_global_ptr_var(&ctx->mod, var->name, type,
-                                                              DXIL_AS_DEFAULT, 4,
+      const struct dxil_value *gvar = dxil_add_global_ptr_var(&ctx->mod, var->name,
+                                                              dxil_value_get_type(agg_vals),
+                                                              DXIL_AS_DEFAULT, 16,
                                                               agg_vals);
       if (!gvar)
          return false;
 
-      if (!_mesa_hash_table_insert(ctx->consts, var, (void *)gvar))
-         return false;
+      ctx->consts[var->data.driver_location] = gvar;
    }
 
    return true;
@@ -3273,24 +3284,6 @@ get_int32_undef(struct dxil_module *m)
    return dxil_module_get_undef(m, int32_type);
 }
 
-static const struct dxil_value *
-emit_gep_for_index(struct ntd_context *ctx, const nir_variable *var,
-                   const struct dxil_value *index)
-{
-   assert(var->data.mode == nir_var_shader_temp);
-
-   struct hash_entry *he = _mesa_hash_table_search(ctx->consts, var);
-   assert(he != NULL);
-
-   const struct dxil_value *ptr = he->data;
-   const struct dxil_value *zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return NULL;
-
-   const struct dxil_value *ops[] = { ptr, zero, index };
-   return dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
-}
-
 static const struct dxil_value *
 get_resource_handle(struct ntd_context *ctx, nir_src *src, enum dxil_resource_class class,
                     enum dxil_resource_kind kind)
@@ -4013,23 +4006,37 @@ emit_load_interpolated_input(struct ntd_context *ctx, nir_intrinsic_instr *intr)
    return true;
 }
 
-static bool
-emit_load_ptr(struct ntd_context *ctx, nir_intrinsic_instr *intr)
+static const struct dxil_value *
+deref_to_gep(struct ntd_context *ctx, nir_deref_instr *deref)
 {
-   struct nir_variable *var =
-      nir_deref_instr_get_variable(nir_src_as_deref(intr->src[0]));
-
-   const struct dxil_value *index =
-      get_src(ctx, &intr->src[1], 0, nir_type_uint);
-   if (!index)
-      return false;
-
-   const struct dxil_value *ptr = emit_gep_for_index(ctx, var, index);
+   nir_deref_path path;
+   nir_deref_path_init(&path, deref, ctx->ralloc_ctx);
+   assert(path.path[0]->deref_type == nir_deref_type_var);
+   uint32_t count = 0;
+   while (path.path[count])
+      ++count;
+
+   const struct dxil_value **gep_indices = ralloc_array(ctx->ralloc_ctx,
+                                                        const struct dxil_value *,
+                                                        count + 1);
+   nir_variable *var = path.path[0]->var;
+   gep_indices[0] = ctx->consts[var->data.driver_location];
+   for (uint32_t i = 0; i < count; ++i)
+      gep_indices[i + 1] = get_src_ssa(ctx, &path.path[i]->dest.ssa, 0);
+
+   return dxil_emit_gep_inbounds(&ctx->mod, gep_indices, count + 1);
+}
+
+static bool
+emit_load_deref(struct ntd_context *ctx, nir_intrinsic_instr *intr)
+{
+   const struct dxil_value *ptr = deref_to_gep(ctx, nir_src_as_deref(intr->src[0]));
    if (!ptr)
       return false;
 
    const struct dxil_value *retval =
-      dxil_emit_load(&ctx->mod, ptr, 4, false);
+      dxil_emit_load(&ctx->mod, ptr, nir_dest_bit_size(intr->dest) / 8, false);
    if (!retval)
       return false;
@@ -5014,8 +5021,8 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
       return emit_store_shared(ctx, intr);
    case nir_intrinsic_store_scratch_dxil:
       return emit_store_scratch(ctx, intr);
-   case nir_intrinsic_load_ptr_dxil:
-      return emit_load_ptr(ctx, intr);
+   case nir_intrinsic_load_deref:
+      return emit_load_deref(ctx, intr);
    case nir_intrinsic_load_ubo:
       return emit_load_ubo(ctx, intr);
    case nir_intrinsic_load_ubo_dxil:
@@ -5171,31 +5178,6 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
    }
 }
 
-static const struct dxil_value *
-get_value_for_const(struct dxil_module *mod, nir_const_value *c, const struct dxil_type *type)
-{
-   if (type == mod->int1_type) return dxil_module_get_int1_const(mod, c->b);
-   if (type == mod->float32_type) return dxil_module_get_float_const(mod, c->f32);
-   if (type == mod->int32_type) return dxil_module_get_int32_const(mod, c->i32);
-   if (type == mod->int16_type) {
-      mod->feats.min_precision = true;
-      return dxil_module_get_int16_const(mod, c->i16);
-   }
-   if (type == mod->int64_type) {
-      mod->feats.int64_ops = true;
-      return dxil_module_get_int64_const(mod, c->i64);
-   }
-   if (type == mod->float16_type) {
-      mod->feats.min_precision = true;
-      return dxil_module_get_float16_const(mod, c->u16);
-   }
-   if (type == mod->float64_type) {
-      mod->feats.doubles = true;
-      return dxil_module_get_double_const(mod, c->f64);
-   }
-   unreachable("Invalid type");
-}
-
 static const struct dxil_type *
 dxil_type_for_const(struct ntd_context *ctx, nir_ssa_def *def)
 {
@@ -5218,8 +5200,36 @@ emit_load_const(struct ntd_context *ctx, nir_load_const_instr *load_const)
 static bool
 emit_deref(struct ntd_context* ctx, nir_deref_instr* instr)
 {
-   assert(instr->deref_type == nir_deref_type_var ||
-          instr->deref_type == nir_deref_type_array);
+   /* There's two possible reasons we might be walking through derefs:
+    * 1. Computing an index to be used for a texture/sampler/image binding, which
+    *    can only do array indexing and should compute the indices along the way with
+    *    array-of-array sizes.
+    * 2. Storing an index to be used in a GEP for access to a variable.
+    */
+   nir_variable *var = nir_deref_instr_get_variable(instr);
+   assert(var);
+
+   bool is_aoa_size =
+      glsl_type_is_sampler(glsl_without_array(var->type)) ||
+      glsl_type_is_image(glsl_without_array(var->type)) ||
+      glsl_type_is_texture(glsl_without_array(var->type));
+
+   if (!is_aoa_size) {
+      /* Just store the values, we'll use these to build a GEP in the load or store */
+      switch (instr->deref_type) {
+      case nir_deref_type_var:
+         store_dest(ctx, &instr->dest, 0, dxil_module_get_int_const(&ctx->mod, 0, instr->dest.ssa.bit_size));
+         return true;
+      case nir_deref_type_array:
+         store_dest(ctx, &instr->dest, 0, get_src(ctx, &instr->arr.index, 0, nir_type_int));
+         return true;
+      case nir_deref_type_struct:
+         store_dest(ctx, &instr->dest, 0, dxil_module_get_int_const(&ctx->mod, instr->strct.index, 32));
+         return true;
+      default:
+         unreachable("Other deref types not supported");
+      }
+   }
 
    /* In the CL environment, there's nothing to emit. Any references to
     * derefs will emit the necessary logic to handle scratch/shared GEP addressing

@@ -5227,18 +5237,6 @@ emit_deref(struct ntd_context* ctx, nir_deref_instr* instr)
    if (ctx->opts->environment == DXIL_ENVIRONMENT_CL)
       return true;
 
-   /* In the Vulkan environment, we don't have cached handles for textures or
-    * samplers, so let's use the opportunity of walking through the derefs to
-    * emit those.
-    */
-   nir_variable *var = nir_deref_instr_get_variable(instr);
-   assert(var);
-
-   if (!glsl_type_is_sampler(glsl_without_array(var->type)) &&
-       !glsl_type_is_image(glsl_without_array(var->type)) &&
-       !glsl_type_is_texture(glsl_without_array(var->type)))
-      return true;
-
    const struct glsl_type *type = instr->type;
    const struct dxil_value *binding;
    unsigned binding_val = ctx->opts->environment == DXIL_ENVIRONMENT_GL ?
@@ -6274,9 +6272,6 @@ emit_module(struct ntd_context *ctx, const struct nir_to_dxil_options *opts)
       if (!emit_globals(ctx, opts->num_kernel_globals))
          return false;
 
-      ctx->consts = _mesa_pointer_hash_table_create(ctx->ralloc_ctx);
-      if (!ctx->consts)
-         return false;
-
       if (!emit_global_consts(ctx))
          return false;
    } else if (ctx->opts->environment == DXIL_ENVIRONMENT_VULKAN) {