diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c index d1127ca8b00..3aa16304909 100644 --- a/src/microsoft/clc/clc_compiler.c +++ b/src/microsoft/clc/clc_compiler.c @@ -645,6 +645,8 @@ clc_parse_spirv(const struct clc_binary *in_spirv, if (!clc_spirv_get_kernels_info(in_spirv, &out_data->kernels, &out_data->num_kernels, + &out_data->spec_constants, + &out_data->num_spec_constants, logger)) return false; diff --git a/src/microsoft/clc/clc_compiler.h b/src/microsoft/clc/clc_compiler.h index 0cf508deb48..cc957d0df83 100644 --- a/src/microsoft/clc/clc_compiler.h +++ b/src/microsoft/clc/clc_compiler.h @@ -108,9 +108,32 @@ struct clc_kernel_info { enum clc_vec_hint_type vec_hint_type; }; +enum clc_spec_constant_type { + CLC_SPEC_CONSTANT_UNKNOWN, + CLC_SPEC_CONSTANT_BOOL, + CLC_SPEC_CONSTANT_FLOAT, + CLC_SPEC_CONSTANT_DOUBLE, + CLC_SPEC_CONSTANT_INT8, + CLC_SPEC_CONSTANT_UINT8, + CLC_SPEC_CONSTANT_INT16, + CLC_SPEC_CONSTANT_UINT16, + CLC_SPEC_CONSTANT_INT32, + CLC_SPEC_CONSTANT_UINT32, + CLC_SPEC_CONSTANT_INT64, + CLC_SPEC_CONSTANT_UINT64, +}; + +struct clc_parsed_spec_constant { + uint32_t id; + enum clc_spec_constant_type type; +}; + struct clc_parsed_spirv { const struct clc_kernel_info *kernels; unsigned num_kernels; + + const struct clc_parsed_spec_constant *spec_constants; + unsigned num_spec_constants; }; #define CLC_MAX_CONSTS 32 diff --git a/src/microsoft/clc/clc_helpers.cpp b/src/microsoft/clc/clc_helpers.cpp index 62ae00282ae..68591bcf1a0 100644 --- a/src/microsoft/clc/clc_helpers.cpp +++ b/src/microsoft/clc/clc_helpers.cpp @@ -293,6 +293,18 @@ public: assert(op->type == SPV_OPERAND_TYPE_DECORATION); decoration = ins->words[op->offset]; + if (decoration == SpvDecorationSpecId) { + uint32_t spec_id = ins->words[ins->operands[2].offset]; + for (auto &c : specConstants) { + if (c.second.id == spec_id) { + assert(c.first == id); + return; + } + } + specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id }); + return; + } + for (auto &kernel : kernels) { for (auto &arg : kernel.args) { if (arg.id == id) { @@ -414,6 +426,104 @@ public: } } + void parseLiteralType(const spv_parsed_instruction_t *ins) + { + uint32_t typeId = ins->words[ins->operands[0].offset]; + auto& literalType = literalTypes[typeId]; + switch (ins->opcode) { + case SpvOpTypeBool: + literalType = CLC_SPEC_CONSTANT_BOOL; + break; + case SpvOpTypeFloat: { + uint32_t sizeInBits = ins->words[ins->operands[1].offset]; + switch (sizeInBits) { + case 32: + literalType = CLC_SPEC_CONSTANT_FLOAT; + break; + case 64: + literalType = CLC_SPEC_CONSTANT_DOUBLE; + break; + case 16: + /* Can't be used for a spec constant */ + break; + default: + unreachable("Unexpected float bit size"); + } + break; + } + case SpvOpTypeInt: { + uint32_t sizeInBits = ins->words[ins->operands[1].offset]; + bool isSigned = ins->words[ins->operands[2].offset]; + if (isSigned) { + switch (sizeInBits) { + case 8: + literalType = CLC_SPEC_CONSTANT_INT8; + break; + case 16: + literalType = CLC_SPEC_CONSTANT_INT16; + break; + case 32: + literalType = CLC_SPEC_CONSTANT_INT32; + break; + case 64: + literalType = CLC_SPEC_CONSTANT_INT64; + break; + default: + unreachable("Unexpected int bit size"); + } + } else { + switch (sizeInBits) { + case 8: + literalType = CLC_SPEC_CONSTANT_UINT8; + break; + case 16: + literalType = CLC_SPEC_CONSTANT_UINT16; + break; + case 32: + literalType = CLC_SPEC_CONSTANT_UINT32; + break; + case 64: + literalType = CLC_SPEC_CONSTANT_UINT64; + break; + default: + unreachable("Unexpected uint bit size"); + } + } + break; + } + default: + unreachable("Unexpected type opcode"); + } + } + + void parseSpecConstant(const spv_parsed_instruction_t *ins) + { + uint32_t id = ins->result_id; + for (auto& c : specConstants) { + if (c.first == id) { + auto& data = c.second; + switch (ins->opcode) { + case SpvOpSpecConstant: { + uint32_t typeId = ins->words[ins->operands[0].offset]; + + // This better be an integer or float type + auto typeIter = literalTypes.find(typeId); + assert(typeIter != literalTypes.end()); + + data.type = typeIter->second; + break; + } + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: + data.type = CLC_SPEC_CONSTANT_BOOL; + break; + default: + unreachable("Composites and Ops are not directly specializable."); + } + } + } + } + static spv_result_t parseInstruction(void *data, const spv_parsed_instruction_t *ins) { @@ -454,6 +564,16 @@ public: case SpvOpExecutionMode: parser->parseExecutionMode(ins); break; + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + parser->parseLiteralType(ins); + break; + case SpvOpSpecConstant: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: + parser->parseSpecConstant(ins); + break; default: break; } @@ -504,6 +624,8 @@ public: } std::vector kernels; + std::vector> specConstants; + std::map literalTypes; std::map> decorationGroups; SPIRVKernelInfo *curKernel; spv_context ctx; @@ -513,9 +635,12 @@ bool clc_spirv_get_kernels_info(const struct clc_binary *spvbin, const struct clc_kernel_info **out_kernels, unsigned *num_kernels, + const struct clc_parsed_spec_constant **out_spec_constants, + unsigned *num_spec_constants, const struct clc_logger *logger) { struct clc_kernel_info *kernels; + struct clc_parsed_spec_constant *spec_constants; SPIRVKernelParser parser; @@ -523,6 +648,7 @@ clc_spirv_get_kernels_info(const struct clc_binary *spvbin, return false; *num_kernels = parser.kernels.size(); + *num_spec_constants = parser.specConstants.size(); if (!*num_kernels) return false; @@ -553,7 +679,18 @@ clc_spirv_get_kernels_info(const struct clc_binary *spvbin, } } + if (*num_spec_constants) { + spec_constants = reinterpret_cast(calloc(*num_spec_constants, + sizeof(*spec_constants))); + assert(spec_constants); + + for (unsigned i = 0; i < parser.specConstants.size(); ++i) { + spec_constants[i] = parser.specConstants[i].second; + } + } + *out_kernels = kernels; + *out_spec_constants = spec_constants; return true; } diff --git a/src/microsoft/clc/clc_helpers.h b/src/microsoft/clc/clc_helpers.h index afb50f8e355..7a5cf0c9fa9 100644 --- a/src/microsoft/clc/clc_helpers.h +++ b/src/microsoft/clc/clc_helpers.h @@ -42,6 +42,8 @@ bool clc_spirv_get_kernels_info(const struct clc_binary *spvbin, const struct clc_kernel_info **kernels, unsigned *num_kernels, + const struct clc_parsed_spec_constant **spec_constants, + unsigned *num_spec_constants, const struct clc_logger *logger); void