mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-25 08:40:11 +01:00
aco/optimizer: use new helper functions to create med3
Foz-DB Navi48: Totals from 9659 (11.72% of 82419) affected shaders: Instrs: 17301747 -> 17301735 (-0.00%); split: -0.00%, +0.00% CodeSize: 93378108 -> 93378184 (+0.00%); split: -0.00%, +0.00% Latency: 145441784 -> 145441791 (+0.00%); split: -0.00%, +0.00% InvThroughput: 25768777 -> 25768778 (+0.00%) Copies: 1370123 -> 1370124 (+0.00%) VALU: 9705655 -> 9705656 (+0.00%) Foz-DB Navi21: Totals from 22 (0.03% of 82387) affected shaders: Instrs: 27433 -> 27406 (-0.10%) CodeSize: 146440 -> 146352 (-0.06%); split: -0.06%, +0.00% Latency: 305857 -> 305806 (-0.02%); split: -0.02%, +0.00% InvThroughput: 63634 -> 63580 (-0.08%) VALU: 19109 -> 19082 (-0.14%) Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38150>
This commit is contained in:
parent
6fc250fc06
commit
d21734e024
1 changed files with 59 additions and 186 deletions
|
|
@ -3930,184 +3930,6 @@ combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
|||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
|
||||
aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
|
||||
{
|
||||
switch (op) {
|
||||
#define MINMAX(type, gfx9) \
|
||||
case aco_opcode::v_min_##type: \
|
||||
case aco_opcode::v_max_##type: \
|
||||
*min = aco_opcode::v_min_##type; \
|
||||
*max = aco_opcode::v_max_##type; \
|
||||
*med3 = aco_opcode::v_med3_##type; \
|
||||
*min3 = aco_opcode::v_min3_##type; \
|
||||
*max3 = aco_opcode::v_max3_##type; \
|
||||
*minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type; \
|
||||
*some_gfx9_only = gfx9; \
|
||||
return true;
|
||||
#define MINMAX_INT16(type, gfx9) \
|
||||
case aco_opcode::v_min_##type: \
|
||||
case aco_opcode::v_max_##type: \
|
||||
*min = aco_opcode::v_min_##type; \
|
||||
*max = aco_opcode::v_max_##type; \
|
||||
*med3 = aco_opcode::v_med3_##type; \
|
||||
*min3 = aco_opcode::v_min3_##type; \
|
||||
*max3 = aco_opcode::v_max3_##type; \
|
||||
*minmax = aco_opcode::num_opcodes; \
|
||||
*some_gfx9_only = gfx9; \
|
||||
return true;
|
||||
#define MINMAX_INT16_E64(type, gfx9) \
|
||||
case aco_opcode::v_min_##type##_e64: \
|
||||
case aco_opcode::v_max_##type##_e64: \
|
||||
*min = aco_opcode::v_min_##type##_e64; \
|
||||
*max = aco_opcode::v_max_##type##_e64; \
|
||||
*med3 = aco_opcode::v_med3_##type; \
|
||||
*min3 = aco_opcode::v_min3_##type; \
|
||||
*max3 = aco_opcode::v_max3_##type; \
|
||||
*minmax = aco_opcode::num_opcodes; \
|
||||
*some_gfx9_only = gfx9; \
|
||||
return true;
|
||||
MINMAX(f32, false)
|
||||
MINMAX(u32, false)
|
||||
MINMAX(i32, false)
|
||||
MINMAX(f16, true)
|
||||
MINMAX_INT16(u16, true)
|
||||
MINMAX_INT16(i16, true)
|
||||
MINMAX_INT16_E64(u16, true)
|
||||
MINMAX_INT16_E64(i16, true)
|
||||
#undef MINMAX_INT16_E64
|
||||
#undef MINMAX_INT16
|
||||
#undef MINMAX
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
/* when ub > lb:
|
||||
* v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
|
||||
* v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
|
||||
*/
|
||||
bool
|
||||
combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
|
||||
aco_opcode med)
|
||||
{
|
||||
/* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
|
||||
* FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
|
||||
* minVal > maxVal, which means we can always select it to a v_med3_f32 */
|
||||
aco_opcode other_op;
|
||||
if (instr->opcode == min)
|
||||
other_op = max;
|
||||
else if (instr->opcode == max)
|
||||
other_op = min;
|
||||
else
|
||||
return false;
|
||||
|
||||
for (unsigned swap = 0; swap < 2; swap++) {
|
||||
Operand operands[3];
|
||||
bool clamp, precise;
|
||||
bitarray8 opsel = 0, neg = 0, abs = 0;
|
||||
uint8_t omod = 0;
|
||||
if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
|
||||
abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
|
||||
/* max(min(src, upper), lower) returns upper if src is NaN, but
|
||||
* med3(src, lower, upper) returns lower.
|
||||
*/
|
||||
if (precise && instr->opcode != min &&
|
||||
(min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
|
||||
continue;
|
||||
|
||||
int const0_idx = -1, const1_idx = -1;
|
||||
uint32_t const0 = 0, const1 = 0;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
uint32_t val;
|
||||
bool hi16 = opsel & (1 << i);
|
||||
if (operands[i].isConstant()) {
|
||||
val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
|
||||
} else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant()) {
|
||||
val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
if (const0_idx >= 0) {
|
||||
const1_idx = i;
|
||||
const1 = val;
|
||||
} else {
|
||||
const0_idx = i;
|
||||
const0 = val;
|
||||
}
|
||||
}
|
||||
if (const0_idx < 0 || const1_idx < 0)
|
||||
continue;
|
||||
|
||||
int lower_idx = const0_idx;
|
||||
switch (min) {
|
||||
case aco_opcode::v_min_f32:
|
||||
case aco_opcode::v_min_f16: {
|
||||
float const0_f, const1_f;
|
||||
if (min == aco_opcode::v_min_f32) {
|
||||
memcpy(&const0_f, &const0, 4);
|
||||
memcpy(&const1_f, &const1, 4);
|
||||
} else {
|
||||
const0_f = _mesa_half_to_float(const0);
|
||||
const1_f = _mesa_half_to_float(const1);
|
||||
}
|
||||
if (abs[const0_idx])
|
||||
const0_f = fabsf(const0_f);
|
||||
if (abs[const1_idx])
|
||||
const1_f = fabsf(const1_f);
|
||||
if (neg[const0_idx])
|
||||
const0_f = -const0_f;
|
||||
if (neg[const1_idx])
|
||||
const1_f = -const1_f;
|
||||
lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
|
||||
break;
|
||||
}
|
||||
case aco_opcode::v_min_u32: {
|
||||
lower_idx = const0 < const1 ? const0_idx : const1_idx;
|
||||
break;
|
||||
}
|
||||
case aco_opcode::v_min_u16:
|
||||
case aco_opcode::v_min_u16_e64: {
|
||||
lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
|
||||
break;
|
||||
}
|
||||
case aco_opcode::v_min_i32: {
|
||||
int32_t const0_i =
|
||||
const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
|
||||
int32_t const1_i =
|
||||
const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
|
||||
lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
|
||||
break;
|
||||
}
|
||||
case aco_opcode::v_min_i16:
|
||||
case aco_opcode::v_min_i16_e64: {
|
||||
int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
|
||||
int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
|
||||
lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
|
||||
|
||||
if (instr->opcode == min) {
|
||||
if (upper_idx != 0 || lower_idx == 0)
|
||||
return false;
|
||||
} else {
|
||||
if (upper_idx == 0 || lower_idx != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
ctx.uses[instr->operands[swap].tempId()]--;
|
||||
create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
{
|
||||
|
|
@ -4780,6 +4602,41 @@ create_fma_cb(opt_ctx& ctx, alu_opt_info& info)
|
|||
return false;
|
||||
}
|
||||
|
||||
template <bool max_first>
|
||||
bool
|
||||
create_med3_cb(opt_ctx& ctx, alu_opt_info& info)
|
||||
{
|
||||
aco_type type = instr_info.alu_opcode_infos[(int)info.opcode].def_types[0];
|
||||
|
||||
/* NaN correctness needs max first, then min. */
|
||||
if (!max_first && type.base_type == aco_base_type_float && info.defs[0].isPrecise())
|
||||
return false;
|
||||
|
||||
uint64_t upper = 0;
|
||||
uint64_t lower = 0;
|
||||
|
||||
if (!op_info_get_constant(ctx, info.operands[0], type, &upper))
|
||||
return false;
|
||||
|
||||
if (!op_info_get_constant(ctx, info.operands[1], type, &lower) &&
|
||||
!op_info_get_constant(ctx, info.operands[2], type, &lower))
|
||||
return false;
|
||||
|
||||
if (!max_first)
|
||||
std::swap(upper, lower);
|
||||
|
||||
switch (info.opcode) {
|
||||
case aco_opcode::v_med3_f32: return uif(lower) <= uif(upper);
|
||||
case aco_opcode::v_med3_f16: return _mesa_half_to_float(lower) <= _mesa_half_to_float(upper);
|
||||
case aco_opcode::v_med3_u32: return uint32_t(lower) <= uint32_t(upper);
|
||||
case aco_opcode::v_med3_u16: return uint16_t(lower) <= uint16_t(upper);
|
||||
case aco_opcode::v_med3_i32: return int32_t(lower) <= int32_t(upper);
|
||||
case aco_opcode::v_med3_i16: return int16_t(lower) <= int16_t(upper);
|
||||
default: UNREACHABLE("invalid clamp");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
is_mul(Instruction* instr)
|
||||
{
|
||||
|
|
@ -5003,14 +4860,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
|||
combine_sabsdiff(ctx, instr);
|
||||
} else if (instr->opcode == aco_opcode::v_and_b32) {
|
||||
combine_v_andor_not(ctx, instr);
|
||||
} else {
|
||||
aco_opcode min, max, min3, max3, med3, minmax;
|
||||
bool some_gfx9_only;
|
||||
if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
|
||||
&some_gfx9_only) &&
|
||||
(!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
|
||||
combine_clamp(ctx, instr, min, max, med3);
|
||||
}
|
||||
}
|
||||
|
||||
alu_opt_info info;
|
||||
|
|
@ -5060,50 +4909,74 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
|||
add_opt(v_max_f32, v_max3_f32, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_min_f32, v_minmax_f32, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_min_f32, v_med3_f32, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_f32) {
|
||||
add_opt(v_min_f32, v_min3_f32, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_max_f32, v_maxmin_f32, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_max_f32, v_med3_f32, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_u32) {
|
||||
add_opt(v_max_u32, v_max3_u32, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_min_u32, v_minmax_u32, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_min_u32, v_med3_u32, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_u32) {
|
||||
add_opt(v_min_u32, v_min3_u32, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_max_u32, v_maxmin_u32, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_max_u32, v_med3_u32, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_i32) {
|
||||
add_opt(v_max_i32, v_max3_i32, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_min_i32, v_minmax_i32, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_min_i32, v_med3_i32, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_i32) {
|
||||
add_opt(v_min_i32, v_min3_i32, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_max_i32, v_maxmin_i32, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_max_i32, v_med3_i32, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_f16 && ctx.program->gfx_level >= GFX9) {
|
||||
add_opt(v_max_f16, v_max3_f16, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_min_f16, v_minmax_f16, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_min_f16, v_med3_f16, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_f16 && ctx.program->gfx_level >= GFX9) {
|
||||
add_opt(v_min_f16, v_min3_f16, 0x3, "120", nullptr, true);
|
||||
if (ctx.program->gfx_level >= GFX11)
|
||||
add_opt(v_max_f16, v_maxmin_f16, 0x3, "120", nullptr, true);
|
||||
else
|
||||
add_opt(v_max_f16, v_med3_f16, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_u16 && ctx.program->gfx_level >= GFX9) {
|
||||
add_opt(v_max_u16, v_max3_u16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_min_u16, v_med3_u16, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_u16 && ctx.program->gfx_level >= GFX9) {
|
||||
add_opt(v_min_u16, v_min3_u16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_max_u16, v_med3_u16, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_i16 && ctx.program->gfx_level >= GFX9) {
|
||||
add_opt(v_max_i16, v_max3_i16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_min_i16, v_med3_i16, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_i16 && ctx.program->gfx_level >= GFX9) {
|
||||
add_opt(v_min_i16, v_min3_i16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_max_i16, v_med3_i16, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_u16_e64) {
|
||||
add_opt(v_max_u16_e64, v_max3_u16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_min_u16_e64, v_med3_u16, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_u16_e64) {
|
||||
add_opt(v_min_u16_e64, v_min3_u16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_max_u16_e64, v_med3_u16, 0x3, "012", create_med3_cb<true>, true);
|
||||
} else if (info.opcode == aco_opcode::v_max_i16_e64) {
|
||||
add_opt(v_max_i16_e64, v_max3_i16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_min_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb<false>, true);
|
||||
} else if (info.opcode == aco_opcode::v_min_i16_e64) {
|
||||
add_opt(v_min_i16_e64, v_min3_i16, 0x3, "120", nullptr, true);
|
||||
add_opt(v_max_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb<true>, true);
|
||||
}
|
||||
|
||||
if (match_and_apply_patterns(ctx, info, patterns)) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue