spirv: use sdot_2x16 and udot_2x16 opcodes

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12617>
This commit is contained in:
Rhys Perry 2021-08-30 13:56:17 +01:00 committed by Marge Bot
parent 41ecef7855
commit 137974fabb

View file

@ -892,6 +892,7 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
spirv_op_to_string(opcode));
}
unsigned packed_bit_size = 8;
if (glsl_type_is_vector(vtn_src[0]->type)) {
/* FINISHME: Is this actually as good or better for platforms that don't
* have the special instructions (i.e., one or both of has_dot_4x8 or
@ -902,6 +903,14 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
glsl_get_bit_size(dest_type) <= 32) {
src[0] = nir_pack_32_4x8(&b->nb, src[0]);
src[1] = nir_pack_32_4x8(&b->nb, src[1]);
} else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
glsl_get_bit_size(vtn_src[0]->type) == 16 &&
glsl_get_bit_size(dest_type) <= 32 &&
opcode != SpvOpSUDotKHR &&
opcode != SpvOpSUDotAccSatKHR) {
src[0] = nir_pack_32_2x16(&b->nb, src[0]);
src[1] = nir_pack_32_2x16(&b->nb, src[1]);
packed_bit_size = 16;
}
} else if (glsl_type_is_scalar(vtn_src[0]->type) &&
glsl_type_is_32bit(vtn_src[0]->type)) {
@ -1012,53 +1021,64 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
bool is_signed;
bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
switch (opcode) {
case SpvOpSDotKHR:
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
case SpvOpUDotKHR:
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
is_signed = false;
break;
case SpvOpSUDotKHR:
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
case SpvOpSDotAccSatKHR:
if (dest_size == 32)
dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
if (packed_bit_size == 16) {
switch (opcode) {
case SpvOpSDotKHR:
dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
break;
case SpvOpUDotKHR:
dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
break;
case SpvOpSDotAccSatKHR:
if (dest_size == 32)
dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
break;
case SpvOpUDotAccSatKHR:
if (dest_size == 32)
dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
break;
default:
unreachable("Invalid opcode.");
}
} else {
switch (opcode) {
case SpvOpSDotKHR:
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
case SpvOpUDotAccSatKHR:
if (dest_size == 32)
dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
else
break;
case SpvOpUDotKHR:
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
is_signed = false;
break;
case SpvOpSUDotAccSatKHR:
if (dest_size == 32)
dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
break;
case SpvOpSUDotKHR:
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
default:
unreachable("Invalid opcode.");
break;
case SpvOpSDotAccSatKHR:
if (dest_size == 32)
dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
break;
case SpvOpUDotAccSatKHR:
if (dest_size == 32)
dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
break;
case SpvOpSUDotAccSatKHR:
if (dest_size == 32)
dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
break;
default:
unreachable("Invalid opcode.");
}
}
if (dest_size != 32) {