From dde6fa57286e59e6ecdc73f38d3dcb92ab0f5e6f Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Fri, 6 Sep 2024 16:37:17 -0700 Subject: [PATCH] spirv: Implement SPV_KHR_untyped_pointers The untyped pointer types only have a storage class associated, and the operations using them would carry the necessary "data type" information. Untyped pointers themselves are identified by "vtn_type::pointed" being NULL. For the NIR lowering the operations will have explicit casts before them when applicable and the nir_derefs representing untyped pointers will use the "void" glsl_type. Reviewed-by: Faith Ekstrand Part-of: --- src/compiler/spirv/spirv_to_nir.c | 54 +++++++++++- src/compiler/spirv/vtn_cmat.c | 20 +++++ src/compiler/spirv/vtn_private.h | 8 +- src/compiler/spirv/vtn_variables.c | 132 ++++++++++++++++++++++++----- 4 files changed, 190 insertions(+), 24 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index abf302665ac..c7031ef8e2c 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -203,6 +203,7 @@ static const struct spirv_capabilities implemented_capabilities = { .UniformDecoration = true, .UniformTexelBufferArrayDynamicIndexingEXT = true, .UniformTexelBufferArrayNonUniformIndexingEXT = true, + .UntypedPointersKHR = true, .VariablePointers = true, .VariablePointersStorageBuffer = true, .Vector16 = true, @@ -1348,6 +1349,9 @@ const struct glsl_type * vtn_type_get_nir_type(struct vtn_builder *b, struct vtn_type *type, enum vtn_variable_mode mode) { + if (type == NULL) + return glsl_void_type(); + if (mode == vtn_variable_mode_atomic_counter) { vtn_fail_if(glsl_without_array(type->type) != glsl_uint_type(), "Variables in the AtomicCounter storage class should be " @@ -1462,6 +1466,7 @@ array_stride_decoration_cb(struct vtn_builder *b, if (dec->decoration == SpvDecorationArrayStride) { if (type->base_type == vtn_base_type_pointer && + type->pointed != NULL && (type->pointed->block || type->pointed->buffer_block)) { vtn_warn("A pointer to a structure decorated with *Block* or " "*BufferBlock* must not have an *ArrayStride* decoration"); @@ -2110,6 +2115,26 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpTypeUntypedPointerKHR: { + SpvStorageClass storage_class = w[2]; + val->type->base_type = vtn_base_type_pointer; + val->type->storage_class = storage_class; + + /* For untyped pointers, storage class alone should be sufficient to + * identify the right variable_mode (and glsl_type). The special cases + * are either handling legacy stuff or classes not used with untyped + * pointers. + */ + enum vtn_variable_mode mode = vtn_storage_class_to_mode( + b, storage_class, NULL, NULL); + val->type->type = nir_address_format_to_glsl_type( + vtn_mode_to_address_format(b, mode)); + + vtn_foreach_decoration(b, val, array_stride_decoration_cb, NULL); + + break; + } + case SpvOpTypePointer: case SpvOpTypeForwardPointer: { /* We can't blindly push the value because it might be a forward @@ -5868,6 +5893,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpTypeAccelerationStructureKHR: case SpvOpTypeRayQueryKHR: case SpvOpTypeCooperativeMatrixKHR: + case SpvOpTypeUntypedPointerKHR: vtn_handle_type(b, opcode, w, count); break; @@ -5889,6 +5915,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpUndef: case SpvOpVariable: case SpvOpConstantSampler: + case SpvOpUntypedVariableKHR: vtn_handle_variables(b, opcode, w, count); break; @@ -6029,10 +6056,26 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpPtrDiff: { - /* OpPtrDiff returns the difference in number of elements (not byte offset). */ + /* OpPtrDiff returns the difference in number of elements (not byte offset). + * + * SPV_KHR_untyped_pointers extension adds + * + * "The types of Operand 1 and Operand 2 must be the same + * OpTypePointer or OpTypeUntypedPointerKHR." + */ unsigned elem_size, elem_align; - glsl_get_natural_size_align_bytes(type1->pointed->type, - &elem_size, &elem_align); + if (type1->pointed != NULL) { + vtn_assert(type2->pointed != NULL); + glsl_get_natural_size_align_bytes(type1->pointed->type, + &elem_size, &elem_align); + } else { + vtn_assert(type2->pointed == NULL); + /* If 'Operand 1' and 'Operand 2' are OpTypeUntypedPointerKHR, + * the array is interpreted as an array of 8-bit integers. + */ + elem_size = 1; + elem_align = 1; + } def = nir_build_addr_isub(&b->nb, vtn_get_nir_ssa(b, w[3]), @@ -6421,6 +6464,11 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpSubgroupBlockReadINTEL: case SpvOpSubgroupBlockWriteINTEL: case SpvOpConvertUToAccelerationStructureKHR: + case SpvOpUntypedAccessChainKHR: + case SpvOpUntypedPtrAccessChainKHR: + case SpvOpUntypedInBoundsAccessChainKHR: + case SpvOpUntypedInBoundsPtrAccessChainKHR: + case SpvOpUntypedArrayLengthKHR: vtn_handle_variables(b, opcode, w, count); break; diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c index f0ab5ba95c2..eb66d893ffb 100644 --- a/src/compiler/spirv/vtn_cmat.c +++ b/src/compiler/spirv/vtn_cmat.c @@ -84,6 +84,18 @@ vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id) return deref; } +static struct vtn_pointer * +vtn_cast_pointer_to_byte_pointer(struct vtn_builder *b, struct vtn_pointer *p) +{ + assert(!p->type->pointed); + + struct vtn_type *t = vtn_zalloc(b, struct vtn_type); + t->base_type = vtn_base_type_scalar; + t->type = glsl_uint8_t_type(); + t->length = 1; + return vtn_cast_pointer(b, p, t); +} + void vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -94,6 +106,10 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode, struct vtn_pointer *src = vtn_value_to_pointer(b, src_val); struct vtn_type *dst_type = vtn_get_type(b, w[1]); + /* Untyped pointers are effectively used as byte pointers. */ + if (!src->type->pointed) + src = vtn_cast_pointer_to_byte_pointer(b, src); + const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]); nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32); @@ -116,6 +132,10 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer); struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val); + /* Untyped pointers are effectively used as byte pointers. */ + if (!dest->type->pointed) + dest = vtn_cast_pointer_to_byte_pointer(b, dest); + const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]); nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32); diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 702e2e6c87b..6a026794b9e 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -364,7 +364,9 @@ struct vtn_type { /* Members for pointer types */ struct { - /* For pointers, the vtn_type of the object pointed to. */ + /* For regular pointers, the vtn_type of the object pointed to; + * for untyped pointers it must be NULL. + */ struct vtn_type *pointed; /* Storage class for pointers */ @@ -922,6 +924,10 @@ nir_def * vtn_pointer_to_offset(struct vtn_builder *b, struct vtn_pointer *ptr, nir_def **index_out); +struct vtn_pointer * +vtn_cast_pointer(struct vtn_builder *b, struct vtn_pointer *p, + struct vtn_type *pointed); + nir_deref_instr * vtn_get_call_payload_for_location(struct vtn_builder *b, uint32_t location_id); diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index 3fd157ddf26..2ae7ca3ccf7 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -331,6 +331,7 @@ vtn_create_internal_pointer_type(struct vtn_builder *b, t->pointed = pointed; t->storage_class = original->storage_class; t->type = original->type; + t->stride = original->stride; return t; } @@ -1958,6 +1959,10 @@ vtn_pointer_ssa_is_desc_index(struct vtn_builder *b, if (ptr->mode == vtn_variable_mode_phys_ssbo) return false; + /* Untyped pointers are never in desc_index form. */ + if (ptr->type->pointed == NULL) + return false; + return vtn_pointer_is_external_block(b, ptr) && vtn_type_is_block_array(b, ptr->type->pointed); } @@ -1990,8 +1995,10 @@ vtn_pointer_from_ssa(struct vtn_builder *b, nir_def *ssa, { vtn_assert(ptr_type->base_type == vtn_base_type_pointer); + const bool untyped = !ptr_type->pointed; + struct vtn_pointer *ptr = vtn_zalloc(b, struct vtn_pointer); - struct vtn_type *without_array = + struct vtn_type *without_array = untyped ? NULL : vtn_type_without_array(ptr_type->pointed); nir_variable_mode nir_mode; @@ -2639,6 +2646,42 @@ ptr_nonuniform_workaround_cb(struct vtn_builder *b, struct vtn_value *val, } } +struct vtn_pointer * +vtn_cast_pointer(struct vtn_builder *b, struct vtn_pointer *p, + struct vtn_type *pointed) +{ + assert(pointed); + + struct vtn_pointer *casted = vtn_zalloc(b, struct vtn_pointer); + *casted = *p; + casted->type = vtn_create_internal_pointer_type(b, p->type, pointed); + vtn_assert(pointed == casted->type->pointed); + + if (p->deref) { + casted->deref = nir_build_deref_cast(&b->nb, &p->deref->def, + p->deref->modes, + pointed->type, 0); + } else if (p->desc_index != NULL) { + /* Nothing to do for descriptor index pointers. */ + } else if (p->var != NULL) { + struct vtn_variable *var = p->var; + + if (b->options->environment == NIR_SPIRV_VULKAN && + vtn_pointer_is_external_block(b, casted)) { + casted->desc_index = vtn_variable_resource_index(b, var, NULL); + } else { + vtn_assert(var->var); + nir_deref_instr *deref = nir_build_deref_var(&b->nb, var->var); + casted->deref = nir_build_deref_cast(&b->nb, &deref->def, + deref->modes, pointed->type, 0); + } + } else { + vtn_fail("Invalid pointer"); + } + + return casted; +} + void vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -2651,9 +2694,12 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, break; } - case SpvOpVariable: { + case SpvOpVariable: + case SpvOpUntypedVariableKHR: { + const bool untyped = opcode == SpvOpUntypedVariableKHR; + struct vtn_type *ptr_type = vtn_get_type(b, w[1]); - struct vtn_type *data_type = ptr_type->pointed; + struct vtn_type *data_type = untyped ? vtn_get_type(b, w[4]) : ptr_type->pointed; SpvStorageClass storage_class = w[3]; @@ -2672,7 +2718,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, } struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_pointer); - struct vtn_value *initializer = count > 4 ? vtn_untyped_value(b, w[4]) : NULL; + + const unsigned init_idx = untyped ? 5 : 4; + struct vtn_value *initializer = + count > init_idx ? vtn_untyped_value(b, w[init_idx]) : NULL; vtn_create_variable(b, val, ptr_type, data_type, storage_class, initializer); @@ -2707,14 +2756,27 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, case SpvOpAccessChain: case SpvOpPtrAccessChain: case SpvOpInBoundsAccessChain: - case SpvOpInBoundsPtrAccessChain: { + case SpvOpInBoundsPtrAccessChain: + case SpvOpUntypedAccessChainKHR: + case SpvOpUntypedPtrAccessChainKHR: + case SpvOpUntypedInBoundsAccessChainKHR: + case SpvOpUntypedInBoundsPtrAccessChainKHR: { bool ptr_as_array = opcode == SpvOpPtrAccessChain || - opcode == SpvOpInBoundsPtrAccessChain; + opcode == SpvOpInBoundsPtrAccessChain || + opcode == SpvOpUntypedPtrAccessChainKHR || + opcode == SpvOpUntypedInBoundsPtrAccessChainKHR; + + const bool untyped = opcode == SpvOpUntypedAccessChainKHR || + opcode == SpvOpUntypedInBoundsAccessChainKHR || + opcode == SpvOpUntypedPtrAccessChainKHR || + opcode == SpvOpUntypedInBoundsPtrAccessChainKHR; struct vtn_type *ptr_type = vtn_get_type(b, w[1]); - struct vtn_pointer *base = vtn_pointer(b, w[3]); + struct vtn_pointer *base = + untyped ? vtn_cast_pointer(b, vtn_pointer(b, w[4]), vtn_get_type(b, w[3])) + : vtn_pointer(b, w[3]); - unsigned first_idx = 4; + unsigned first_idx = untyped ? 5 : 4; /* The SPIR-V spec says * @@ -2723,6 +2785,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, * then the behavior is undefined." */ if (ptr_as_array && + base->type->pointed && (base->type->pointed->block || base->type->pointed->buffer_block)) { struct vtn_value *val = vtn_untyped_value(b, w[first_idx]); @@ -2759,7 +2822,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, idx++; } - chain->in_bounds = (opcode == SpvOpInBoundsAccessChain || opcode == SpvOpInBoundsPtrAccessChain); + chain->in_bounds = opcode == SpvOpInBoundsAccessChain || + opcode == SpvOpInBoundsPtrAccessChain || + opcode == SpvOpUntypedInBoundsAccessChainKHR || + opcode == SpvOpUntypedInBoundsPtrAccessChainKHR; /* Workaround for https://gitlab.freedesktop.org/mesa/mesa/-/issues/3406 */ access |= base->access & ACCESS_NON_UNIFORM; @@ -2780,8 +2846,15 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val); struct vtn_pointer *src = vtn_value_to_pointer(b, src_val); - vtn_assert_types_equal(b, opcode, dest_val->type->pointed, - src_val->type->pointed); + /* At least one must be a regular (typed) pointer. */ + vtn_assert(dest->type->pointed || src->type->pointed); + + if (!dest->type->pointed) + dest = vtn_cast_pointer(b, dest, src->type->pointed); + else if (!src->type->pointed) + src = vtn_cast_pointer(b, src, dest->type->pointed); + + vtn_assert_types_equal(b, opcode, src->type->pointed, dest->type->pointed); unsigned idx = 3, dest_alignment, src_alignment; SpvMemoryAccessMask dest_access, src_access; @@ -2844,7 +2917,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer); struct vtn_pointer *src = vtn_value_to_pointer(b, src_val); - vtn_assert_types_equal(b, opcode, res_type, src_val->type->pointed); + if (!src->type->pointed) + src = vtn_cast_pointer(b, src, res_type); + + vtn_assert_types_equal(b, opcode, res_type, src->type->pointed); unsigned idx = 4, alignment; SpvMemoryAccessMask access; @@ -2863,6 +2939,9 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val); struct vtn_value *src_val = vtn_untyped_value(b, w[2]); + if (!dest->type->pointed) + dest = vtn_cast_pointer(b, dest, src_val->type); + /* OpStore requires us to actually have a storage type */ vtn_fail_if(dest->type->pointed->type == NULL, "Invalid destination type for OpStore"); @@ -2886,7 +2965,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, break; } - vtn_assert_types_equal(b, opcode, dest_val->type->pointed, src_val->type); + vtn_assert_types_equal(b, opcode, dest->type->pointed, src_val->type); unsigned idx = 3, alignment; SpvMemoryAccessMask access; @@ -2901,14 +2980,27 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, break; } - case SpvOpArrayLength: { - struct vtn_pointer *ptr = vtn_pointer(b, w[3]); - const uint32_t field = w[4]; + case SpvOpArrayLength: + case SpvOpUntypedArrayLengthKHR: { + const bool untyped = opcode == SpvOpUntypedArrayLengthKHR; - vtn_fail_if(ptr->type->pointed->base_type != vtn_base_type_struct, + unsigned idx = 3; + struct vtn_pointer *ptr; + struct vtn_type *struct_type; + if (untyped) { + struct_type = vtn_get_type(b, w[idx++]); + ptr = vtn_cast_pointer(b, vtn_pointer(b, w[idx++]), struct_type); + } else { + ptr = vtn_pointer(b, w[idx++]); + struct_type = ptr->type->pointed; + } + + const uint32_t field = w[idx]; + + vtn_fail_if(struct_type->base_type != vtn_base_type_struct, "OpArrayLength must take a pointer to a structure type"); - vtn_fail_if(field != ptr->type->pointed->length - 1 || - ptr->type->pointed->members[field]->base_type != vtn_base_type_array, + vtn_fail_if(field != struct_type->length - 1 || + struct_type->members[field]->base_type != vtn_base_type_array, "OpArrayLength must reference the last member of the " "structure and that must be an array"); @@ -2923,7 +3015,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, nir_def *array_length = nir_deref_buffer_array_length(&b->nb, 32, vtn_pointer_to_ssa(b, array), - .access=ptr->access | ptr->type->pointed->access); + .access=ptr->access | struct_type->access); vtn_push_nir_ssa(b, w[2], array_length); break;