mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-01-06 13:10:10 +01:00
nir/spirv: Add support for a bunch of ALU operations
This commit is contained in:
parent
d2a7972557
commit
ff828749ea
1 changed file with 195 additions and 7 deletions
|
|
@ -27,6 +27,7 @@
|
|||
|
||||
#include "nir_spirv.h"
|
||||
#include "nir_vla.h"
|
||||
#include "nir_builder.h"
|
||||
#include "spirv.h"
|
||||
|
||||
struct vtn_decoration;
|
||||
|
|
@ -81,6 +82,8 @@ struct vtn_decoration {
|
|||
};
|
||||
|
||||
struct vtn_builder {
|
||||
nir_builder nb;
|
||||
|
||||
nir_shader *shader;
|
||||
nir_function_impl *impl;
|
||||
struct exec_list *cf_list;
|
||||
|
|
@ -705,11 +708,192 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
|
|||
unreachable("Unhandled opcode");
|
||||
}
|
||||
|
||||
static void
|
||||
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count)
|
||||
{
|
||||
unreachable("Matrix math not handled");
|
||||
}
|
||||
|
||||
static void
|
||||
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count)
|
||||
{
|
||||
unreachable("Unhandled opcode");
|
||||
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
|
||||
val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
|
||||
|
||||
/* Collect the various SSA sources */
|
||||
unsigned num_inputs = count - 3;
|
||||
nir_ssa_def *src[4];
|
||||
for (unsigned i = 0; i < num_inputs; i++)
|
||||
src[i] = vtn_ssa_value(b, w[i + 3]);
|
||||
|
||||
/* We use the builder for some of the instructions. Go ahead and
|
||||
* initialize it with the current cf_list.
|
||||
*/
|
||||
nir_builder_insert_after_cf_list(&b->nb, b->cf_list);
|
||||
|
||||
/* Indicates that the first two arguments should be swapped. This is
|
||||
* used for implementing greater-than and less-than-or-equal.
|
||||
*/
|
||||
bool swap = false;
|
||||
|
||||
nir_op op;
|
||||
switch (opcode) {
|
||||
/* Basic ALU operations */
|
||||
case SpvOpSNegate: op = nir_op_ineg; break;
|
||||
case SpvOpFNegate: op = nir_op_fneg; break;
|
||||
case SpvOpNot: op = nir_op_inot; break;
|
||||
|
||||
case SpvOpAny:
|
||||
switch (src[0]->num_components) {
|
||||
case 1: op = nir_op_imov; break;
|
||||
case 2: op = nir_op_bany2; break;
|
||||
case 3: op = nir_op_bany3; break;
|
||||
case 4: op = nir_op_bany4; break;
|
||||
}
|
||||
break;
|
||||
|
||||
case SpvOpAll:
|
||||
switch (src[0]->num_components) {
|
||||
case 1: op = nir_op_imov; break;
|
||||
case 2: op = nir_op_ball2; break;
|
||||
case 3: op = nir_op_ball3; break;
|
||||
case 4: op = nir_op_ball4; break;
|
||||
}
|
||||
break;
|
||||
|
||||
case SpvOpIAdd: op = nir_op_iadd; break;
|
||||
case SpvOpFAdd: op = nir_op_fadd; break;
|
||||
case SpvOpISub: op = nir_op_isub; break;
|
||||
case SpvOpFSub: op = nir_op_fsub; break;
|
||||
case SpvOpIMul: op = nir_op_imul; break;
|
||||
case SpvOpFMul: op = nir_op_fmul; break;
|
||||
case SpvOpUDiv: op = nir_op_udiv; break;
|
||||
case SpvOpSDiv: op = nir_op_idiv; break;
|
||||
case SpvOpFDiv: op = nir_op_fdiv; break;
|
||||
case SpvOpUMod: op = nir_op_umod; break;
|
||||
case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
|
||||
case SpvOpFMod: op = nir_op_fmod; break;
|
||||
|
||||
case SpvOpDot:
|
||||
assert(src[0]->num_components == src[1]->num_components);
|
||||
switch (src[0]->num_components) {
|
||||
case 1: op = nir_op_fmul; break;
|
||||
case 2: op = nir_op_fdot2; break;
|
||||
case 3: op = nir_op_fdot3; break;
|
||||
case 4: op = nir_op_fdot4; break;
|
||||
}
|
||||
break;
|
||||
|
||||
case SpvOpShiftRightLogical: op = nir_op_ushr; break;
|
||||
case SpvOpShiftRightArithmetic: op = nir_op_ishr; break;
|
||||
case SpvOpShiftLeftLogical: op = nir_op_ishl; break;
|
||||
case SpvOpLogicalOr: op = nir_op_ior; break;
|
||||
case SpvOpLogicalXor: op = nir_op_ixor; break;
|
||||
case SpvOpLogicalAnd: op = nir_op_iand; break;
|
||||
case SpvOpBitwiseOr: op = nir_op_ior; break;
|
||||
case SpvOpBitwiseXor: op = nir_op_ixor; break;
|
||||
case SpvOpBitwiseAnd: op = nir_op_iand; break;
|
||||
case SpvOpSelect: op = nir_op_bcsel; break;
|
||||
case SpvOpIEqual: op = nir_op_ieq; break;
|
||||
|
||||
/* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
|
||||
case SpvOpFOrdEqual: op = nir_op_feq; break;
|
||||
case SpvOpFUnordEqual: op = nir_op_feq; break;
|
||||
case SpvOpINotEqual: op = nir_op_ine; break;
|
||||
case SpvOpFOrdNotEqual: op = nir_op_fne; break;
|
||||
case SpvOpFUnordNotEqual: op = nir_op_fne; break;
|
||||
case SpvOpULessThan: op = nir_op_ult; break;
|
||||
case SpvOpSLessThan: op = nir_op_ilt; break;
|
||||
case SpvOpFOrdLessThan: op = nir_op_flt; break;
|
||||
case SpvOpFUnordLessThan: op = nir_op_flt; break;
|
||||
case SpvOpUGreaterThan: op = nir_op_ult; swap = true; break;
|
||||
case SpvOpSGreaterThan: op = nir_op_ilt; swap = true; break;
|
||||
case SpvOpFOrdGreaterThan: op = nir_op_flt; swap = true; break;
|
||||
case SpvOpFUnordGreaterThan: op = nir_op_flt; swap = true; break;
|
||||
case SpvOpULessThanEqual: op = nir_op_uge; swap = true; break;
|
||||
case SpvOpSLessThanEqual: op = nir_op_ige; swap = true; break;
|
||||
case SpvOpFOrdLessThanEqual: op = nir_op_fge; swap = true; break;
|
||||
case SpvOpFUnordLessThanEqual: op = nir_op_fge; swap = true; break;
|
||||
case SpvOpUGreaterThanEqual: op = nir_op_uge; break;
|
||||
case SpvOpSGreaterThanEqual: op = nir_op_ige; break;
|
||||
case SpvOpFOrdGreaterThanEqual: op = nir_op_fge; break;
|
||||
case SpvOpFUnordGreaterThanEqual:op = nir_op_fge; break;
|
||||
|
||||
/* Conversions: */
|
||||
case SpvOpConvertFToU: op = nir_op_f2u; break;
|
||||
case SpvOpConvertFToS: op = nir_op_f2i; break;
|
||||
case SpvOpConvertSToF: op = nir_op_i2f; break;
|
||||
case SpvOpConvertUToF: op = nir_op_u2f; break;
|
||||
case SpvOpBitcast: op = nir_op_imov; break;
|
||||
case SpvOpUConvert:
|
||||
case SpvOpSConvert:
|
||||
op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
|
||||
break;
|
||||
case SpvOpFConvert:
|
||||
op = nir_op_fmov;
|
||||
break;
|
||||
|
||||
/* Derivatives: */
|
||||
case SpvOpDPdx: op = nir_op_fddx; break;
|
||||
case SpvOpDPdy: op = nir_op_fddy; break;
|
||||
case SpvOpDPdxFine: op = nir_op_fddx_fine; break;
|
||||
case SpvOpDPdyFine: op = nir_op_fddy_fine; break;
|
||||
case SpvOpDPdxCoarse: op = nir_op_fddx_coarse; break;
|
||||
case SpvOpDPdyCoarse: op = nir_op_fddy_coarse; break;
|
||||
case SpvOpFwidth:
|
||||
val->ssa = nir_fadd(&b->nb,
|
||||
nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
|
||||
nir_fabs(&b->nb, nir_fddx(&b->nb, src[1])));
|
||||
return;
|
||||
case SpvOpFwidthFine:
|
||||
val->ssa = nir_fadd(&b->nb,
|
||||
nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
|
||||
nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[1])));
|
||||
return;
|
||||
case SpvOpFwidthCoarse:
|
||||
val->ssa = nir_fadd(&b->nb,
|
||||
nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
|
||||
nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[1])));
|
||||
return;
|
||||
|
||||
case SpvOpVectorTimesScalar:
|
||||
/* The builder will take care of splatting for us. */
|
||||
val->ssa = nir_fmul(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
|
||||
case SpvOpSRem:
|
||||
case SpvOpFRem:
|
||||
unreachable("No NIR equivalent");
|
||||
|
||||
case SpvOpIsNan:
|
||||
case SpvOpIsInf:
|
||||
case SpvOpIsFinite:
|
||||
case SpvOpIsNormal:
|
||||
case SpvOpSignBitSet:
|
||||
case SpvOpLessOrGreater:
|
||||
case SpvOpOrdered:
|
||||
case SpvOpUnordered:
|
||||
default:
|
||||
unreachable("Unhandled opcode");
|
||||
}
|
||||
|
||||
if (swap) {
|
||||
nir_ssa_def *tmp = src[0];
|
||||
src[0] = src[1];
|
||||
src[1] = tmp;
|
||||
}
|
||||
|
||||
nir_alu_instr *instr = nir_alu_instr_create(b->shader, op);
|
||||
nir_ssa_dest_init(&instr->instr, &instr->dest.dest,
|
||||
glsl_get_vector_elements(val->type), val->name);
|
||||
val->ssa = &instr->dest.dest.ssa;
|
||||
|
||||
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++)
|
||||
instr->src[i].src = nir_src_for_ssa(src[i]);
|
||||
|
||||
nir_instr_insert_after_cf_list(b->cf_list, &instr->instr);
|
||||
}
|
||||
|
||||
static bool
|
||||
|
|
@ -993,7 +1177,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpPtrCastToGeneric:
|
||||
case SpvOpGenericCastToPtr:
|
||||
case SpvOpBitcast:
|
||||
case SpvOpTranspose:
|
||||
case SpvOpIsNan:
|
||||
case SpvOpIsInf:
|
||||
case SpvOpIsFinite:
|
||||
|
|
@ -1017,11 +1200,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpFRem:
|
||||
case SpvOpFMod:
|
||||
case SpvOpVectorTimesScalar:
|
||||
case SpvOpMatrixTimesScalar:
|
||||
case SpvOpVectorTimesMatrix:
|
||||
case SpvOpMatrixTimesVector:
|
||||
case SpvOpMatrixTimesMatrix:
|
||||
case SpvOpOuterProduct:
|
||||
case SpvOpDot:
|
||||
case SpvOpShiftRightLogical:
|
||||
case SpvOpShiftRightArithmetic:
|
||||
|
|
@ -1067,6 +1245,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
vtn_handle_alu(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
case SpvOpTranspose:
|
||||
case SpvOpOuterProduct:
|
||||
case SpvOpMatrixTimesScalar:
|
||||
case SpvOpVectorTimesMatrix:
|
||||
case SpvOpMatrixTimesVector:
|
||||
case SpvOpMatrixTimesMatrix:
|
||||
vtn_handle_matrix_alu(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
default:
|
||||
unreachable("Unhandled opcode");
|
||||
}
|
||||
|
|
@ -1163,6 +1350,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
|
|||
|
||||
foreach_list_typed(struct vtn_function, func, node, &b->functions) {
|
||||
b->impl = nir_function_impl_create(func->overload);
|
||||
nir_builder_init(&b->nb, b->impl);
|
||||
b->cf_list = &b->impl->body;
|
||||
vtn_walk_blocks(b, func->start_block, NULL);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue