spirv: add support for SPV_KHR_constant_data

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40722>
This commit is contained in:
Samuel Pitoiset 2026-03-19 13:20:11 +01:00 committed by Marge Bot
parent 378974588e
commit 34b8ce948a

View file

@ -51,6 +51,7 @@ static const struct spirv_capabilities implemented_capabilities = {
.ClipDistance = true,
.ComputeDerivativeGroupLinearKHR = true,
.ComputeDerivativeGroupQuadsKHR = true,
.ConstantDataKHR = true,
.CooperativeMatrixKHR = true,
.CooperativeMatrixConversionsNV = true,
.CooperativeMatrixReductionsNV = true,
@ -1856,6 +1857,10 @@ type_decoration_cb(struct vtn_builder *b,
/* User semantic decorations can safely be ignored by the driver. */
break;
case SpvDecorationUTFEncodedKHR:
/* Ignored. */
break;
default:
vtn_fail_with_decoration("Unhandled decoration", dec->decoration);
}
@ -2565,6 +2570,45 @@ spec_constant_decoration_cb(struct vtn_builder *b, UNUSED struct vtn_value *val,
}
}
struct spec_constant_data_cb_data {
nir_constant **elems;
unsigned num_elems;
unsigned elem_byte_size;
};
static void
spec_constant_data_decoration_cb(struct vtn_builder *b, UNUSED struct vtn_value *val,
ASSERTED int member,
const struct vtn_decoration *dec, void *data)
{
vtn_assert(member == -1);
if (dec->decoration != SpvDecorationSpecId)
return;
if (!b->specialization)
return;
const struct spec_constant_data_cb_data *ctx = data;
for (unsigned i = 0; i < b->specialization->num_entries; i++) {
const struct nir_spirv_specialization_entry *entry =
&b->specialization->entries[i];
if (entry->id != dec->operands[0])
continue;
if (entry->size == 0)
continue;
for (unsigned e = 0; e < ctx->num_elems; e++) {
unsigned offset = e * ctx->elem_byte_size;
assert(offset + ctx->elem_byte_size <= entry->size);
memcpy(&ctx->elems[e]->values[0], entry->data + offset, ctx->elem_byte_size);
}
return;
}
}
static void
handle_workgroup_size_decoration_cb(struct vtn_builder *b,
struct vtn_value *val,
@ -3072,6 +3116,40 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpConstantDataKHR:
case SpvOpSpecConstantDataKHR: {
struct vtn_type *elem_type = val->type->array_element;
vtn_fail_if(val->type->base_type != vtn_base_type_array,
"Result type must be an array");
vtn_fail_if(!glsl_type_is_integer(elem_type->type),
"Element type must be a scalar integer");
unsigned num_elems = val->type->length;
unsigned bit_size = glsl_get_bit_size(elem_type->type);
unsigned byte_size = bit_size / 8;
const uint8_t *data = (const uint8_t *)&w[3];
nir_constant **elems = ralloc_array(b, nir_constant *, num_elems);
for (unsigned i = 0; i < num_elems; i++) {
nir_constant *c = rzalloc(b, nir_constant);
memcpy(&c->values[0], data + i * byte_size, byte_size);
elems[i] = c;
}
if (opcode == SpvOpSpecConstantDataKHR) {
struct spec_constant_data_cb_data ctx = { elems, num_elems, byte_size };
vtn_foreach_decoration(b, val, spec_constant_data_decoration_cb, &ctx);
}
ralloc_steal(val->constant, elems);
val->constant->num_elements = num_elems;
val->constant->elements = elems;
val->is_undef_constant = false;
break;
}
default:
vtn_fail_with_opcode("Unhandled opcode", opcode);
}
@ -6186,6 +6264,8 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpConstantComposite:
case SpvOpConstantCompositeReplicateEXT:
case SpvOpConstantNull:
case SpvOpConstantDataKHR:
case SpvOpSpecConstantDataKHR:
case SpvOpSpecConstantTrue:
case SpvOpSpecConstantFalse:
case SpvOpSpecConstant: