aco/gfx11.5: select s_(ceil|floor|trunc|rndne)

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29245>
This commit is contained in:
Georg Lehmann 2023-09-21 19:53:12 +02:00 committed by Marge Bot
parent 33a719b3e2
commit 1efb7754fc
2 changed files with 55 additions and 33 deletions

View file

@ -2680,6 +2680,13 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f32, dst);
} else if (dst.regClass() == v2) {
emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f64, dst);
} else if (dst.regClass() == s1) {
Temp src = get_alu_src(ctx, instr->src[0]);
aco_opcode op =
instr->def.bit_size == 16 ? aco_opcode::s_floor_f16 : aco_opcode::s_floor_f32;
Temp floor = bld.sop1(op, bld.def(s1), src);
op = instr->def.bit_size == 16 ? aco_opcode::s_sub_f16 : aco_opcode::s_sub_f32;
bld.sop2(op, Definition(dst), src, floor);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
@ -2693,6 +2700,11 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
} else if (dst.regClass() == v2) {
Temp src = get_alu_src(ctx, instr->src[0]);
emit_floor_f64(ctx, bld, Definition(dst), src);
} else if (dst.regClass() == s1) {
Temp src = get_alu_src(ctx, instr->src[0]);
aco_opcode op =
instr->def.bit_size == 16 ? aco_opcode::s_floor_f16 : aco_opcode::s_floor_f32;
bld.sop1(op, Definition(dst), src);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
@ -2725,6 +2737,11 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bld.copy(bld.def(v1), Operand::zero()), add);
bld.vop3(aco_opcode::v_add_f64_e64, Definition(dst), trunc, add);
}
} else if (dst.regClass() == s1) {
Temp src = get_alu_src(ctx, instr->src[0]);
aco_opcode op =
instr->def.bit_size == 16 ? aco_opcode::s_ceil_f16 : aco_opcode::s_ceil_f32;
bld.sop1(op, Definition(dst), src);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
@ -2738,6 +2755,11 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
} else if (dst.regClass() == v2) {
Temp src = get_alu_src(ctx, instr->src[0]);
emit_trunc_f64(ctx, bld, Definition(dst), src);
} else if (dst.regClass() == s1) {
Temp src = get_alu_src(ctx, instr->src[0]);
aco_opcode op =
instr->def.bit_size == 16 ? aco_opcode::s_trunc_f16 : aco_opcode::s_trunc_f32;
bld.sop1(op, Definition(dst), src);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
@ -2786,6 +2808,11 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), dst0, dst1);
}
} else if (dst.regClass() == s1) {
Temp src = get_alu_src(ctx, instr->src[0]);
aco_opcode op =
instr->def.bit_size == 16 ? aco_opcode::s_rndne_f16 : aco_opcode::s_rndne_f32;
bld.sop1(op, Definition(dst), src);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}

View file

@ -310,7 +310,24 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_instr_type_alu: {
nir_alu_instr* alu_instr = nir_instr_as_alu(instr);
RegType type = alu_instr->def.divergent ? RegType::vgpr : RegType::sgpr;
/* packed 16bit instructions have to be VGPR */
if (alu_instr->def.num_components == 2 &&
nir_op_infos[alu_instr->op].output_size == 0)
type = RegType::vgpr;
switch (alu_instr->op) {
case nir_op_f2i16:
case nir_op_f2u16:
case nir_op_f2i32:
case nir_op_f2u32:
case nir_op_b2i8:
case nir_op_b2i16:
case nir_op_b2i32:
case nir_op_b2b32:
case nir_op_b2f16:
case nir_op_b2f32:
case nir_op_mov: break;
case nir_op_fmul:
case nir_op_fmulz:
case nir_op_fadd:
@ -328,11 +345,6 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_fsqrt:
case nir_op_fexp2:
case nir_op_flog2:
case nir_op_ffract:
case nir_op_ffloor:
case nir_op_fceil:
case nir_op_ftrunc:
case nir_op_fround_even:
case nir_op_fsin_amd:
case nir_op_fcos_amd:
case nir_op_f2f16:
@ -377,35 +389,18 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_sdot_2x16_iadd:
case nir_op_udot_2x16_uadd_sat:
case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break;
case nir_op_f2i16:
case nir_op_f2u16:
case nir_op_f2i32:
case nir_op_f2u32:
case nir_op_b2i8:
case nir_op_b2i16:
case nir_op_b2i32:
case nir_op_b2b32:
case nir_op_b2f16:
case nir_op_b2f32:
case nir_op_mov: break;
case nir_op_iabs:
case nir_op_iadd:
case nir_op_iadd_sat:
case nir_op_uadd_sat:
case nir_op_isub:
case nir_op_isub_sat:
case nir_op_usub_sat:
case nir_op_imul:
case nir_op_imin:
case nir_op_imax:
case nir_op_umin:
case nir_op_umax:
case nir_op_ishl:
case nir_op_ishr:
case nir_op_ushr:
/* packed 16bit instructions have to be VGPR */
type = alu_instr->def.num_components == 2 ? RegType::vgpr : type;
case nir_op_ffract:
case nir_op_ffloor:
case nir_op_fceil:
case nir_op_ftrunc:
case nir_op_fround_even: {
if (ctx->program->gfx_level < GFX11_5 ||
alu_instr->src[0].src.ssa->bit_size > 32) {
type = RegType::vgpr;
break;
}
FALLTHROUGH;
}
default:
for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
if (regclasses[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr)