spirv: Implement SPV_KHR_untyped_pointers

The untyped pointer types only have a storage class associated, and the
operations using them would carry the necessary "data type" information.

Untyped pointers themselves are identified by "vtn_type::pointed" being
NULL.  For the NIR lowering the operations will have explicit casts
before them when applicable and the nir_derefs representing untyped
pointers will use the "void" glsl_type.

Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36427>
This commit is contained in:
Caio Oliveira 2024-09-06 16:37:17 -07:00 committed by Marge Bot
parent 8eaf1dced0
commit dde6fa5728
4 changed files with 190 additions and 24 deletions

View file

@ -203,6 +203,7 @@ static const struct spirv_capabilities implemented_capabilities = {
.UniformDecoration = true,
.UniformTexelBufferArrayDynamicIndexingEXT = true,
.UniformTexelBufferArrayNonUniformIndexingEXT = true,
.UntypedPointersKHR = true,
.VariablePointers = true,
.VariablePointersStorageBuffer = true,
.Vector16 = true,
@ -1348,6 +1349,9 @@ const struct glsl_type *
vtn_type_get_nir_type(struct vtn_builder *b, struct vtn_type *type,
enum vtn_variable_mode mode)
{
if (type == NULL)
return glsl_void_type();
if (mode == vtn_variable_mode_atomic_counter) {
vtn_fail_if(glsl_without_array(type->type) != glsl_uint_type(),
"Variables in the AtomicCounter storage class should be "
@ -1462,6 +1466,7 @@ array_stride_decoration_cb(struct vtn_builder *b,
if (dec->decoration == SpvDecorationArrayStride) {
if (type->base_type == vtn_base_type_pointer &&
type->pointed != NULL &&
(type->pointed->block || type->pointed->buffer_block)) {
vtn_warn("A pointer to a structure decorated with *Block* or "
"*BufferBlock* must not have an *ArrayStride* decoration");
@ -2110,6 +2115,26 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpTypeUntypedPointerKHR: {
SpvStorageClass storage_class = w[2];
val->type->base_type = vtn_base_type_pointer;
val->type->storage_class = storage_class;
/* For untyped pointers, storage class alone should be sufficient to
* identify the right variable_mode (and glsl_type). The special cases
* are either handling legacy stuff or classes not used with untyped
* pointers.
*/
enum vtn_variable_mode mode = vtn_storage_class_to_mode(
b, storage_class, NULL, NULL);
val->type->type = nir_address_format_to_glsl_type(
vtn_mode_to_address_format(b, mode));
vtn_foreach_decoration(b, val, array_stride_decoration_cb, NULL);
break;
}
case SpvOpTypePointer:
case SpvOpTypeForwardPointer: {
/* We can't blindly push the value because it might be a forward
@ -5868,6 +5893,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpTypeAccelerationStructureKHR:
case SpvOpTypeRayQueryKHR:
case SpvOpTypeCooperativeMatrixKHR:
case SpvOpTypeUntypedPointerKHR:
vtn_handle_type(b, opcode, w, count);
break;
@ -5889,6 +5915,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpUndef:
case SpvOpVariable:
case SpvOpConstantSampler:
case SpvOpUntypedVariableKHR:
vtn_handle_variables(b, opcode, w, count);
break;
@ -6029,10 +6056,26 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode,
switch (opcode) {
case SpvOpPtrDiff: {
/* OpPtrDiff returns the difference in number of elements (not byte offset). */
/* OpPtrDiff returns the difference in number of elements (not byte offset).
*
* SPV_KHR_untyped_pointers extension adds
*
* "The types of Operand 1 and Operand 2 must be the same
* OpTypePointer or OpTypeUntypedPointerKHR."
*/
unsigned elem_size, elem_align;
glsl_get_natural_size_align_bytes(type1->pointed->type,
&elem_size, &elem_align);
if (type1->pointed != NULL) {
vtn_assert(type2->pointed != NULL);
glsl_get_natural_size_align_bytes(type1->pointed->type,
&elem_size, &elem_align);
} else {
vtn_assert(type2->pointed == NULL);
/* If 'Operand 1' and 'Operand 2' are OpTypeUntypedPointerKHR,
* the array is interpreted as an array of 8-bit integers.
*/
elem_size = 1;
elem_align = 1;
}
def = nir_build_addr_isub(&b->nb,
vtn_get_nir_ssa(b, w[3]),
@ -6421,6 +6464,11 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpSubgroupBlockReadINTEL:
case SpvOpSubgroupBlockWriteINTEL:
case SpvOpConvertUToAccelerationStructureKHR:
case SpvOpUntypedAccessChainKHR:
case SpvOpUntypedPtrAccessChainKHR:
case SpvOpUntypedInBoundsAccessChainKHR:
case SpvOpUntypedInBoundsPtrAccessChainKHR:
case SpvOpUntypedArrayLengthKHR:
vtn_handle_variables(b, opcode, w, count);
break;

View file

@ -84,6 +84,18 @@ vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
return deref;
}
static struct vtn_pointer *
vtn_cast_pointer_to_byte_pointer(struct vtn_builder *b, struct vtn_pointer *p)
{
assert(!p->type->pointed);
struct vtn_type *t = vtn_zalloc(b, struct vtn_type);
t->base_type = vtn_base_type_scalar;
t->type = glsl_uint8_t_type();
t->length = 1;
return vtn_cast_pointer(b, p, t);
}
void
vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -94,6 +106,10 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
/* Untyped pointers are effectively used as byte pointers. */
if (!src->type->pointed)
src = vtn_cast_pointer_to_byte_pointer(b, src);
const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
@ -116,6 +132,10 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
/* Untyped pointers are effectively used as byte pointers. */
if (!dest->type->pointed)
dest = vtn_cast_pointer_to_byte_pointer(b, dest);
const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);

View file

@ -364,7 +364,9 @@ struct vtn_type {
/* Members for pointer types */
struct {
/* For pointers, the vtn_type of the object pointed to. */
/* For regular pointers, the vtn_type of the object pointed to;
* for untyped pointers it must be NULL.
*/
struct vtn_type *pointed;
/* Storage class for pointers */
@ -922,6 +924,10 @@ nir_def *
vtn_pointer_to_offset(struct vtn_builder *b, struct vtn_pointer *ptr,
nir_def **index_out);
struct vtn_pointer *
vtn_cast_pointer(struct vtn_builder *b, struct vtn_pointer *p,
struct vtn_type *pointed);
nir_deref_instr *
vtn_get_call_payload_for_location(struct vtn_builder *b, uint32_t location_id);

View file

@ -331,6 +331,7 @@ vtn_create_internal_pointer_type(struct vtn_builder *b,
t->pointed = pointed;
t->storage_class = original->storage_class;
t->type = original->type;
t->stride = original->stride;
return t;
}
@ -1958,6 +1959,10 @@ vtn_pointer_ssa_is_desc_index(struct vtn_builder *b,
if (ptr->mode == vtn_variable_mode_phys_ssbo)
return false;
/* Untyped pointers are never in desc_index form. */
if (ptr->type->pointed == NULL)
return false;
return vtn_pointer_is_external_block(b, ptr) &&
vtn_type_is_block_array(b, ptr->type->pointed);
}
@ -1990,8 +1995,10 @@ vtn_pointer_from_ssa(struct vtn_builder *b, nir_def *ssa,
{
vtn_assert(ptr_type->base_type == vtn_base_type_pointer);
const bool untyped = !ptr_type->pointed;
struct vtn_pointer *ptr = vtn_zalloc(b, struct vtn_pointer);
struct vtn_type *without_array =
struct vtn_type *without_array = untyped ? NULL :
vtn_type_without_array(ptr_type->pointed);
nir_variable_mode nir_mode;
@ -2639,6 +2646,42 @@ ptr_nonuniform_workaround_cb(struct vtn_builder *b, struct vtn_value *val,
}
}
struct vtn_pointer *
vtn_cast_pointer(struct vtn_builder *b, struct vtn_pointer *p,
struct vtn_type *pointed)
{
assert(pointed);
struct vtn_pointer *casted = vtn_zalloc(b, struct vtn_pointer);
*casted = *p;
casted->type = vtn_create_internal_pointer_type(b, p->type, pointed);
vtn_assert(pointed == casted->type->pointed);
if (p->deref) {
casted->deref = nir_build_deref_cast(&b->nb, &p->deref->def,
p->deref->modes,
pointed->type, 0);
} else if (p->desc_index != NULL) {
/* Nothing to do for descriptor index pointers. */
} else if (p->var != NULL) {
struct vtn_variable *var = p->var;
if (b->options->environment == NIR_SPIRV_VULKAN &&
vtn_pointer_is_external_block(b, casted)) {
casted->desc_index = vtn_variable_resource_index(b, var, NULL);
} else {
vtn_assert(var->var);
nir_deref_instr *deref = nir_build_deref_var(&b->nb, var->var);
casted->deref = nir_build_deref_cast(&b->nb, &deref->def,
deref->modes, pointed->type, 0);
}
} else {
vtn_fail("Invalid pointer");
}
return casted;
}
void
vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -2651,9 +2694,12 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpVariable: {
case SpvOpVariable:
case SpvOpUntypedVariableKHR: {
const bool untyped = opcode == SpvOpUntypedVariableKHR;
struct vtn_type *ptr_type = vtn_get_type(b, w[1]);
struct vtn_type *data_type = ptr_type->pointed;
struct vtn_type *data_type = untyped ? vtn_get_type(b, w[4]) : ptr_type->pointed;
SpvStorageClass storage_class = w[3];
@ -2672,7 +2718,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
}
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_pointer);
struct vtn_value *initializer = count > 4 ? vtn_untyped_value(b, w[4]) : NULL;
const unsigned init_idx = untyped ? 5 : 4;
struct vtn_value *initializer =
count > init_idx ? vtn_untyped_value(b, w[init_idx]) : NULL;
vtn_create_variable(b, val, ptr_type, data_type, storage_class, initializer);
@ -2707,14 +2756,27 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
case SpvOpAccessChain:
case SpvOpPtrAccessChain:
case SpvOpInBoundsAccessChain:
case SpvOpInBoundsPtrAccessChain: {
case SpvOpInBoundsPtrAccessChain:
case SpvOpUntypedAccessChainKHR:
case SpvOpUntypedPtrAccessChainKHR:
case SpvOpUntypedInBoundsAccessChainKHR:
case SpvOpUntypedInBoundsPtrAccessChainKHR: {
bool ptr_as_array = opcode == SpvOpPtrAccessChain ||
opcode == SpvOpInBoundsPtrAccessChain;
opcode == SpvOpInBoundsPtrAccessChain ||
opcode == SpvOpUntypedPtrAccessChainKHR ||
opcode == SpvOpUntypedInBoundsPtrAccessChainKHR;
const bool untyped = opcode == SpvOpUntypedAccessChainKHR ||
opcode == SpvOpUntypedInBoundsAccessChainKHR ||
opcode == SpvOpUntypedPtrAccessChainKHR ||
opcode == SpvOpUntypedInBoundsPtrAccessChainKHR;
struct vtn_type *ptr_type = vtn_get_type(b, w[1]);
struct vtn_pointer *base = vtn_pointer(b, w[3]);
struct vtn_pointer *base =
untyped ? vtn_cast_pointer(b, vtn_pointer(b, w[4]), vtn_get_type(b, w[3]))
: vtn_pointer(b, w[3]);
unsigned first_idx = 4;
unsigned first_idx = untyped ? 5 : 4;
/* The SPIR-V spec says
*
@ -2723,6 +2785,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
* then the behavior is undefined."
*/
if (ptr_as_array &&
base->type->pointed &&
(base->type->pointed->block || base->type->pointed->buffer_block)) {
struct vtn_value *val = vtn_untyped_value(b, w[first_idx]);
@ -2759,7 +2822,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
idx++;
}
chain->in_bounds = (opcode == SpvOpInBoundsAccessChain || opcode == SpvOpInBoundsPtrAccessChain);
chain->in_bounds = opcode == SpvOpInBoundsAccessChain ||
opcode == SpvOpInBoundsPtrAccessChain ||
opcode == SpvOpUntypedInBoundsAccessChainKHR ||
opcode == SpvOpUntypedInBoundsPtrAccessChainKHR;
/* Workaround for https://gitlab.freedesktop.org/mesa/mesa/-/issues/3406 */
access |= base->access & ACCESS_NON_UNIFORM;
@ -2780,8 +2846,15 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
vtn_assert_types_equal(b, opcode, dest_val->type->pointed,
src_val->type->pointed);
/* At least one must be a regular (typed) pointer. */
vtn_assert(dest->type->pointed || src->type->pointed);
if (!dest->type->pointed)
dest = vtn_cast_pointer(b, dest, src->type->pointed);
else if (!src->type->pointed)
src = vtn_cast_pointer(b, src, dest->type->pointed);
vtn_assert_types_equal(b, opcode, src->type->pointed, dest->type->pointed);
unsigned idx = 3, dest_alignment, src_alignment;
SpvMemoryAccessMask dest_access, src_access;
@ -2844,7 +2917,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
vtn_assert_types_equal(b, opcode, res_type, src_val->type->pointed);
if (!src->type->pointed)
src = vtn_cast_pointer(b, src, res_type);
vtn_assert_types_equal(b, opcode, res_type, src->type->pointed);
unsigned idx = 4, alignment;
SpvMemoryAccessMask access;
@ -2863,6 +2939,9 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
struct vtn_value *src_val = vtn_untyped_value(b, w[2]);
if (!dest->type->pointed)
dest = vtn_cast_pointer(b, dest, src_val->type);
/* OpStore requires us to actually have a storage type */
vtn_fail_if(dest->type->pointed->type == NULL,
"Invalid destination type for OpStore");
@ -2886,7 +2965,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
break;
}
vtn_assert_types_equal(b, opcode, dest_val->type->pointed, src_val->type);
vtn_assert_types_equal(b, opcode, dest->type->pointed, src_val->type);
unsigned idx = 3, alignment;
SpvMemoryAccessMask access;
@ -2901,14 +2980,27 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpArrayLength: {
struct vtn_pointer *ptr = vtn_pointer(b, w[3]);
const uint32_t field = w[4];
case SpvOpArrayLength:
case SpvOpUntypedArrayLengthKHR: {
const bool untyped = opcode == SpvOpUntypedArrayLengthKHR;
vtn_fail_if(ptr->type->pointed->base_type != vtn_base_type_struct,
unsigned idx = 3;
struct vtn_pointer *ptr;
struct vtn_type *struct_type;
if (untyped) {
struct_type = vtn_get_type(b, w[idx++]);
ptr = vtn_cast_pointer(b, vtn_pointer(b, w[idx++]), struct_type);
} else {
ptr = vtn_pointer(b, w[idx++]);
struct_type = ptr->type->pointed;
}
const uint32_t field = w[idx];
vtn_fail_if(struct_type->base_type != vtn_base_type_struct,
"OpArrayLength must take a pointer to a structure type");
vtn_fail_if(field != ptr->type->pointed->length - 1 ||
ptr->type->pointed->members[field]->base_type != vtn_base_type_array,
vtn_fail_if(field != struct_type->length - 1 ||
struct_type->members[field]->base_type != vtn_base_type_array,
"OpArrayLength must reference the last member of the "
"structure and that must be an array");
@ -2923,7 +3015,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
nir_def *array_length =
nir_deref_buffer_array_length(&b->nb, 32,
vtn_pointer_to_ssa(b, array),
.access=ptr->access | ptr->type->pointed->access);
.access=ptr->access | struct_type->access);
vtn_push_nir_ssa(b, w[2], array_length);
break;