nir/spirv: Add support for a bunch of ALU operations

Jason Ekstrand 2015-05-04 12:04:02 -07:00
parent d2a7972557
commit ff828749ea
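
For context on the decoding below: a binary SPIR-V ALU instruction such as OpFAdd is one word group of the form (word-count|opcode, result-type id, result id, operand ids...), which is why vtn_handle_alu computes num_inputs = count - 3 and reads operands from w[i + 3]. A minimal, self-contained sketch of that layout (the SPV_OP_FADD constant and the toy decoder are illustrative assumptions, not the real spirv.h or vtn code):

#include <stdio.h>
#include <stdint.h>

#define SPV_OP_FADD 129u   /* assumed opcode value for OpFAdd */

int
main(void)
{
   /* OpFAdd %float %result %a %b:
    *   w[0] = (word count << 16) | opcode
    *   w[1] = result-type <id>,  w[2] = result <id>,  w[3..] = operand <id>s
    */
   const uint32_t w[] = { (5u << 16) | SPV_OP_FADD, 1, 2, 3, 4 };
   unsigned count = w[0] >> 16;
   unsigned num_inputs = count - 3;

   printf("opcode=%u  type=%%%u  result=%%%u\n",
          (unsigned)(w[0] & 0xffffu), (unsigned)w[1], (unsigned)w[2]);
   for (unsigned i = 0; i < num_inputs; i++)
      printf("  src[%u] = %%%u\n", i, (unsigned)w[i + 3]);

   return 0;
}

Compiled on its own, this prints the two operand ids, mirroring how the loop in the handler collects src[0..num_inputs-1].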


@@ -27,6 +27,7 @@
#include "nir_spirv.h"
#include "nir_vla.h"
#include "nir_builder.h"
#include "spirv.h"
struct vtn_decoration;
@@ -81,6 +82,8 @@ struct vtn_decoration {
};
struct vtn_builder {
nir_builder nb;
nir_shader *shader;
nir_function_impl *impl;
struct exec_list *cf_list;
@@ -705,11 +708,192 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
unreachable("Unhandled opcode");
}
static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
unreachable("Matrix math not handled");
}
static void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
unreachable("Unhandled opcode");
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
/* Collect the various SSA sources */
unsigned num_inputs = count - 3;
nir_ssa_def *src[4];
for (unsigned i = 0; i < num_inputs; i++)
src[i] = vtn_ssa_value(b, w[i + 3]);
/* We use the builder for some of the instructions. Go ahead and
* initialize it with the current cf_list.
*/
nir_builder_insert_after_cf_list(&b->nb, b->cf_list);
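/* The nir_fadd()/nir_fabs()/nir_fmul() helpers used for the Fwidth and
* VectorTimesScalar cases below go through b->nb; every other case is
* emitted by hand at the bottom of this function.
*/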
/* Indicates that the first two arguments should be swapped. This is
* used for implementing greater-than and less-than-or-equal.
*/
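/* e.g. OpUGreaterThan(a, b) is emitted as nir_op_ult(b, a). */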
bool swap = false;
nir_op op;
switch (opcode) {
/* Basic ALU operations */
case SpvOpSNegate: op = nir_op_ineg; break;
case SpvOpFNegate: op = nir_op_fneg; break;
case SpvOpNot: op = nir_op_inot; break;
case SpvOpAny:
switch (src[0]->num_components) {
case 1: op = nir_op_imov; break;
case 2: op = nir_op_bany2; break;
case 3: op = nir_op_bany3; break;
case 4: op = nir_op_bany4; break;
}
break;
case SpvOpAll:
switch (src[0]->num_components) {
case 1: op = nir_op_imov; break;
case 2: op = nir_op_ball2; break;
case 3: op = nir_op_ball3; break;
case 4: op = nir_op_ball4; break;
}
break;
case SpvOpIAdd: op = nir_op_iadd; break;
case SpvOpFAdd: op = nir_op_fadd; break;
case SpvOpISub: op = nir_op_isub; break;
case SpvOpFSub: op = nir_op_fsub; break;
case SpvOpIMul: op = nir_op_imul; break;
case SpvOpFMul: op = nir_op_fmul; break;
case SpvOpUDiv: op = nir_op_udiv; break;
case SpvOpSDiv: op = nir_op_idiv; break;
case SpvOpFDiv: op = nir_op_fdiv; break;
case SpvOpUMod: op = nir_op_umod; break;
case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
case SpvOpFMod: op = nir_op_fmod; break;
case SpvOpDot:
assert(src[0]->num_components == src[1]->num_components);
switch (src[0]->num_components) {
case 1: op = nir_op_fmul; break;
case 2: op = nir_op_fdot2; break;
case 3: op = nir_op_fdot3; break;
case 4: op = nir_op_fdot4; break;
}
break;
case SpvOpShiftRightLogical: op = nir_op_ushr; break;
case SpvOpShiftRightArithmetic: op = nir_op_ishr; break;
case SpvOpShiftLeftLogical: op = nir_op_ishl; break;
case SpvOpLogicalOr: op = nir_op_ior; break;
case SpvOpLogicalXor: op = nir_op_ixor; break;
case SpvOpLogicalAnd: op = nir_op_iand; break;
case SpvOpBitwiseOr: op = nir_op_ior; break;
case SpvOpBitwiseXor: op = nir_op_ixor; break;
case SpvOpBitwiseAnd: op = nir_op_iand; break;
case SpvOpSelect: op = nir_op_bcsel; break;
case SpvOpIEqual: op = nir_op_ieq; break;
/* Comparisons: (TODO: How do we want to handle ordered/unordered?) */
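/* Per SPIR-V, an ordered compare is false if either operand is NaN, while an
* unordered compare is true in that case; both variants currently map to the
* same NIR opcode below.
*/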
case SpvOpFOrdEqual: op = nir_op_feq; break;
case SpvOpFUnordEqual: op = nir_op_feq; break;
case SpvOpINotEqual: op = nir_op_ine; break;
case SpvOpFOrdNotEqual: op = nir_op_fne; break;
case SpvOpFUnordNotEqual: op = nir_op_fne; break;
case SpvOpULessThan: op = nir_op_ult; break;
case SpvOpSLessThan: op = nir_op_ilt; break;
case SpvOpFOrdLessThan: op = nir_op_flt; break;
case SpvOpFUnordLessThan: op = nir_op_flt; break;
case SpvOpUGreaterThan: op = nir_op_ult; swap = true; break;
case SpvOpSGreaterThan: op = nir_op_ilt; swap = true; break;
case SpvOpFOrdGreaterThan: op = nir_op_flt; swap = true; break;
case SpvOpFUnordGreaterThan: op = nir_op_flt; swap = true; break;
case SpvOpULessThanEqual: op = nir_op_uge; swap = true; break;
case SpvOpSLessThanEqual: op = nir_op_ige; swap = true; break;
case SpvOpFOrdLessThanEqual: op = nir_op_fge; swap = true; break;
case SpvOpFUnordLessThanEqual: op = nir_op_fge; swap = true; break;
case SpvOpUGreaterThanEqual: op = nir_op_uge; break;
case SpvOpSGreaterThanEqual: op = nir_op_ige; break;
case SpvOpFOrdGreaterThanEqual: op = nir_op_fge; break;
case SpvOpFUnordGreaterThanEqual: op = nir_op_fge; break;
/* Conversions: */
case SpvOpConvertFToU: op = nir_op_f2u; break;
case SpvOpConvertFToS: op = nir_op_f2i; break;
case SpvOpConvertSToF: op = nir_op_i2f; break;
case SpvOpConvertUToF: op = nir_op_u2f; break;
case SpvOpBitcast: op = nir_op_imov; break;
case SpvOpUConvert:
case SpvOpSConvert:
op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
break;
case SpvOpFConvert:
op = nir_op_fmov;
break;
/* Derivatives: */
case SpvOpDPdx: op = nir_op_fddx; break;
case SpvOpDPdy: op = nir_op_fddy; break;
case SpvOpDPdxFine: op = nir_op_fddx_fine; break;
case SpvOpDPdyFine: op = nir_op_fddy_fine; break;
case SpvOpDPdxCoarse: op = nir_op_fddx_coarse; break;
case SpvOpDPdyCoarse: op = nir_op_fddy_coarse; break;
case SpvOpFwidth:
val->ssa = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
return;
case SpvOpFwidthFine:
val->ssa = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
return;
case SpvOpFwidthCoarse:
val->ssa = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
return;
case SpvOpVectorTimesScalar:
/* The builder will take care of splatting for us. */
val->ssa = nir_fmul(&b->nb, src[0], src[1]);
return;
case SpvOpSRem:
case SpvOpFRem:
unreachable("No NIR equivalent");
case SpvOpIsNan:
case SpvOpIsInf:
case SpvOpIsFinite:
case SpvOpIsNormal:
case SpvOpSignBitSet:
case SpvOpLessOrGreater:
case SpvOpOrdered:
case SpvOpUnordered:
default:
unreachable("Unhandled opcode");
}
if (swap) {
nir_ssa_def *tmp = src[0];
src[0] = src[1];
src[1] = tmp;
}
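/* Hand-built path: the destination takes the result type's vector width,
* each NIR source is wired to the collected SSA defs, and the instruction
* is appended to the current cf_list.
*/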
nir_alu_instr *instr = nir_alu_instr_create(b->shader, op);
nir_ssa_dest_init(&instr->instr, &instr->dest.dest,
glsl_get_vector_elements(val->type), val->name);
val->ssa = &instr->dest.dest.ssa;
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++)
instr->src[i].src = nir_src_for_ssa(src[i]);
nir_instr_insert_after_cf_list(b->cf_list, &instr->instr);
}
static bool
@@ -993,7 +1177,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpPtrCastToGeneric:
case SpvOpGenericCastToPtr:
case SpvOpBitcast:
case SpvOpTranspose:
case SpvOpIsNan:
case SpvOpIsInf:
case SpvOpIsFinite:
@@ -1017,11 +1200,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpFRem:
case SpvOpFMod:
case SpvOpVectorTimesScalar:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
case SpvOpOuterProduct:
case SpvOpDot:
case SpvOpShiftRightLogical:
case SpvOpShiftRightArithmetic:
@@ -1067,6 +1245,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
vtn_handle_alu(b, opcode, w, count);
break;
case SpvOpTranspose:
case SpvOpOuterProduct:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
vtn_handle_matrix_alu(b, opcode, w, count);
break;
default:
unreachable("Unhandled opcode");
}
@@ -1163,6 +1350,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
foreach_list_typed(struct vtn_function, func, node, &b->functions) {
b->impl = nir_function_impl_create(func->overload);
nir_builder_init(&b->nb, b->impl);
b->cf_list = &b->impl->body;
vtn_walk_blocks(b, func->start_block, NULL);
}