aco/optimizer: use new helper functions to create med3

Foz-DB Navi48:
Totals from 9659 (11.72% of 82419) affected shaders:
Instrs: 17301747 -> 17301735 (-0.00%); split: -0.00%, +0.00%
CodeSize: 93378108 -> 93378184 (+0.00%); split: -0.00%, +0.00%
Latency: 145441784 -> 145441791 (+0.00%); split: -0.00%, +0.00%
InvThroughput: 25768777 -> 25768778 (+0.00%)
Copies: 1370123 -> 1370124 (+0.00%)
VALU: 9705655 -> 9705656 (+0.00%)

Foz-DB Navi21:
Totals from 22 (0.03% of 82387) affected shaders:
Instrs: 27433 -> 27406 (-0.10%)
CodeSize: 146440 -> 146352 (-0.06%); split: -0.06%, +0.00%
Latency: 305857 -> 305806 (-0.02%); split: -0.02%, +0.00%
InvThroughput: 63634 -> 63580 (-0.08%)
VALU: 19109 -> 19082 (-0.14%)

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38150>
This commit is contained in:
Georg Lehmann 2024-12-13 20:41:55 +01:00 committed by Marge Bot
parent 6fc250fc06
commit d21734e024

View file

@ -3930,184 +3930,6 @@ combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
return false;
}
bool
get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
{
switch (op) {
#define MINMAX(type, gfx9) \
case aco_opcode::v_min_##type: \
case aco_opcode::v_max_##type: \
*min = aco_opcode::v_min_##type; \
*max = aco_opcode::v_max_##type; \
*med3 = aco_opcode::v_med3_##type; \
*min3 = aco_opcode::v_min3_##type; \
*max3 = aco_opcode::v_max3_##type; \
*minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type; \
*some_gfx9_only = gfx9; \
return true;
#define MINMAX_INT16(type, gfx9) \
case aco_opcode::v_min_##type: \
case aco_opcode::v_max_##type: \
*min = aco_opcode::v_min_##type; \
*max = aco_opcode::v_max_##type; \
*med3 = aco_opcode::v_med3_##type; \
*min3 = aco_opcode::v_min3_##type; \
*max3 = aco_opcode::v_max3_##type; \
*minmax = aco_opcode::num_opcodes; \
*some_gfx9_only = gfx9; \
return true;
#define MINMAX_INT16_E64(type, gfx9) \
case aco_opcode::v_min_##type##_e64: \
case aco_opcode::v_max_##type##_e64: \
*min = aco_opcode::v_min_##type##_e64; \
*max = aco_opcode::v_max_##type##_e64; \
*med3 = aco_opcode::v_med3_##type; \
*min3 = aco_opcode::v_min3_##type; \
*max3 = aco_opcode::v_max3_##type; \
*minmax = aco_opcode::num_opcodes; \
*some_gfx9_only = gfx9; \
return true;
MINMAX(f32, false)
MINMAX(u32, false)
MINMAX(i32, false)
MINMAX(f16, true)
MINMAX_INT16(u16, true)
MINMAX_INT16(i16, true)
MINMAX_INT16_E64(u16, true)
MINMAX_INT16_E64(i16, true)
#undef MINMAX_INT16_E64
#undef MINMAX_INT16
#undef MINMAX
default: return false;
}
}
/* when ub > lb:
* v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
* v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
*/
bool
combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
aco_opcode med)
{
/* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
* FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
* minVal > maxVal, which means we can always select it to a v_med3_f32 */
aco_opcode other_op;
if (instr->opcode == min)
other_op = max;
else if (instr->opcode == max)
other_op = min;
else
return false;
for (unsigned swap = 0; swap < 2; swap++) {
Operand operands[3];
bool clamp, precise;
bitarray8 opsel = 0, neg = 0, abs = 0;
uint8_t omod = 0;
if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
/* max(min(src, upper), lower) returns upper if src is NaN, but
* med3(src, lower, upper) returns lower.
*/
if (precise && instr->opcode != min &&
(min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
continue;
int const0_idx = -1, const1_idx = -1;
uint32_t const0 = 0, const1 = 0;
for (int i = 0; i < 3; i++) {
uint32_t val;
bool hi16 = opsel & (1 << i);
if (operands[i].isConstant()) {
val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
} else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant()) {
val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
} else {
continue;
}
if (const0_idx >= 0) {
const1_idx = i;
const1 = val;
} else {
const0_idx = i;
const0 = val;
}
}
if (const0_idx < 0 || const1_idx < 0)
continue;
int lower_idx = const0_idx;
switch (min) {
case aco_opcode::v_min_f32:
case aco_opcode::v_min_f16: {
float const0_f, const1_f;
if (min == aco_opcode::v_min_f32) {
memcpy(&const0_f, &const0, 4);
memcpy(&const1_f, &const1, 4);
} else {
const0_f = _mesa_half_to_float(const0);
const1_f = _mesa_half_to_float(const1);
}
if (abs[const0_idx])
const0_f = fabsf(const0_f);
if (abs[const1_idx])
const1_f = fabsf(const1_f);
if (neg[const0_idx])
const0_f = -const0_f;
if (neg[const1_idx])
const1_f = -const1_f;
lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
break;
}
case aco_opcode::v_min_u32: {
lower_idx = const0 < const1 ? const0_idx : const1_idx;
break;
}
case aco_opcode::v_min_u16:
case aco_opcode::v_min_u16_e64: {
lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
break;
}
case aco_opcode::v_min_i32: {
int32_t const0_i =
const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
int32_t const1_i =
const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
break;
}
case aco_opcode::v_min_i16:
case aco_opcode::v_min_i16_e64: {
int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
break;
}
default: break;
}
int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
if (instr->opcode == min) {
if (upper_idx != 0 || lower_idx == 0)
return false;
} else {
if (upper_idx == 0 || lower_idx != 0)
return false;
}
ctx.uses[instr->operands[swap].tempId()]--;
create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
return true;
}
}
return false;
}
bool
interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
@ -4780,6 +4602,41 @@ create_fma_cb(opt_ctx& ctx, alu_opt_info& info)
return false;
}
template <bool max_first>
bool
create_med3_cb(opt_ctx& ctx, alu_opt_info& info)
{
aco_type type = instr_info.alu_opcode_infos[(int)info.opcode].def_types[0];
/* NaN correctness needs max first, then min. */
if (!max_first && type.base_type == aco_base_type_float && info.defs[0].isPrecise())
return false;
uint64_t upper = 0;
uint64_t lower = 0;
if (!op_info_get_constant(ctx, info.operands[0], type, &upper))
return false;
if (!op_info_get_constant(ctx, info.operands[1], type, &lower) &&
!op_info_get_constant(ctx, info.operands[2], type, &lower))
return false;
if (!max_first)
std::swap(upper, lower);
switch (info.opcode) {
case aco_opcode::v_med3_f32: return uif(lower) <= uif(upper);
case aco_opcode::v_med3_f16: return _mesa_half_to_float(lower) <= _mesa_half_to_float(upper);
case aco_opcode::v_med3_u32: return uint32_t(lower) <= uint32_t(upper);
case aco_opcode::v_med3_u16: return uint16_t(lower) <= uint16_t(upper);
case aco_opcode::v_med3_i32: return int32_t(lower) <= int32_t(upper);
case aco_opcode::v_med3_i16: return int16_t(lower) <= int16_t(upper);
default: UNREACHABLE("invalid clamp");
}
return false;
}
bool
is_mul(Instruction* instr)
{
@ -5003,14 +4860,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
combine_sabsdiff(ctx, instr);
} else if (instr->opcode == aco_opcode::v_and_b32) {
combine_v_andor_not(ctx, instr);
} else {
aco_opcode min, max, min3, max3, med3, minmax;
bool some_gfx9_only;
if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
&some_gfx9_only) &&
(!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
combine_clamp(ctx, instr, min, max, med3);
}
}
alu_opt_info info;
@ -5060,50 +4909,74 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
add_opt(v_max_f32, v_max3_f32, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_min_f32, v_minmax_f32, 0x3, "120", nullptr, true);
else
add_opt(v_min_f32, v_med3_f32, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_f32) {
add_opt(v_min_f32, v_min3_f32, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_max_f32, v_maxmin_f32, 0x3, "120", nullptr, true);
else
add_opt(v_max_f32, v_med3_f32, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_u32) {
add_opt(v_max_u32, v_max3_u32, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_min_u32, v_minmax_u32, 0x3, "120", nullptr, true);
else
add_opt(v_min_u32, v_med3_u32, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_u32) {
add_opt(v_min_u32, v_min3_u32, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_max_u32, v_maxmin_u32, 0x3, "120", nullptr, true);
else
add_opt(v_max_u32, v_med3_u32, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_i32) {
add_opt(v_max_i32, v_max3_i32, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_min_i32, v_minmax_i32, 0x3, "120", nullptr, true);
else
add_opt(v_min_i32, v_med3_i32, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_i32) {
add_opt(v_min_i32, v_min3_i32, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_max_i32, v_maxmin_i32, 0x3, "120", nullptr, true);
else
add_opt(v_max_i32, v_med3_i32, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_f16 && ctx.program->gfx_level >= GFX9) {
add_opt(v_max_f16, v_max3_f16, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_min_f16, v_minmax_f16, 0x3, "120", nullptr, true);
else
add_opt(v_min_f16, v_med3_f16, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_f16 && ctx.program->gfx_level >= GFX9) {
add_opt(v_min_f16, v_min3_f16, 0x3, "120", nullptr, true);
if (ctx.program->gfx_level >= GFX11)
add_opt(v_max_f16, v_maxmin_f16, 0x3, "120", nullptr, true);
else
add_opt(v_max_f16, v_med3_f16, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_u16 && ctx.program->gfx_level >= GFX9) {
add_opt(v_max_u16, v_max3_u16, 0x3, "120", nullptr, true);
add_opt(v_min_u16, v_med3_u16, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_u16 && ctx.program->gfx_level >= GFX9) {
add_opt(v_min_u16, v_min3_u16, 0x3, "120", nullptr, true);
add_opt(v_max_u16, v_med3_u16, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_i16 && ctx.program->gfx_level >= GFX9) {
add_opt(v_max_i16, v_max3_i16, 0x3, "120", nullptr, true);
add_opt(v_min_i16, v_med3_i16, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_i16 && ctx.program->gfx_level >= GFX9) {
add_opt(v_min_i16, v_min3_i16, 0x3, "120", nullptr, true);
add_opt(v_max_i16, v_med3_i16, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_u16_e64) {
add_opt(v_max_u16_e64, v_max3_u16, 0x3, "120", nullptr, true);
add_opt(v_min_u16_e64, v_med3_u16, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_u16_e64) {
add_opt(v_min_u16_e64, v_min3_u16, 0x3, "120", nullptr, true);
add_opt(v_max_u16_e64, v_med3_u16, 0x3, "012", create_med3_cb<true>, true);
} else if (info.opcode == aco_opcode::v_max_i16_e64) {
add_opt(v_max_i16_e64, v_max3_i16, 0x3, "120", nullptr, true);
add_opt(v_min_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb<false>, true);
} else if (info.opcode == aco_opcode::v_min_i16_e64) {
add_opt(v_min_i16_e64, v_min3_i16, 0x3, "120", nullptr, true);
add_opt(v_max_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb<true>, true);
}
if (match_and_apply_patterns(ctx, info, patterns)) {