aco/optimizer: support fma_mix with rtz

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38815>
This commit is contained in:
Georg Lehmann 2025-10-19 16:51:55 +02:00 committed by Marge Bot
parent 6b9d28ab9b
commit c6b74705dd

View file

@ -364,6 +364,7 @@ struct alu_opt_info {
uint8_t omod = 0;
bool clamp = false;
bool f32_to_f16 = false;
bool f32_to_f16_rtz = false;
SubdwordSel insert = SubdwordSel::dword;
bool try_swap_operands(unsigned idx0, unsigned idx1)
@ -600,7 +601,8 @@ format_is(Format f1, Format f2)
bool
try_vinterp_inreg(opt_ctx& ctx, alu_opt_info& info)
{
if (ctx.program->gfx_level < GFX11 || info.opcode != aco_opcode::v_fma_f32 || info.omod)
if (ctx.program->gfx_level < GFX11 || info.opcode != aco_opcode::v_fma_f32 || info.omod ||
info.f32_to_f16_rtz)
return false;
bool fp16 = info.f32_to_f16;
@ -691,8 +693,6 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info)
case aco_opcode::v_fma_legacy_f32:
case aco_opcode::v_fma_f16:
case aco_opcode::v_fma_legacy_f16:
case aco_opcode::v_fma_mix_f32:
case aco_opcode::v_fma_mixlo_f16:
case aco_opcode::v_pk_mul_f16:
case aco_opcode::v_pk_fma_f16:
case aco_opcode::s_mul_f32:
@ -834,6 +834,10 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info)
}
}
assert(!info.f32_to_f16_rtz || info.f32_to_f16);
if (info.f32_to_f16_rtz && ctx.fp_mode.round16_64 == fp_round_tz)
info.f32_to_f16_rtz = false;
/* convert to VINTERP_INREG */
try_vinterp_inreg(ctx, info);
@ -871,8 +875,13 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info)
default: return false;
}
info.opcode = info.f32_to_f16 ? aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32;
info.format = Format::VOP3P;
if (info.f32_to_f16_rtz)
info.opcode = aco_opcode::p_v_fma_mixlo_f16_rtz;
else if (info.f32_to_f16)
info.opcode = aco_opcode::v_fma_mixlo_f16;
else
info.opcode = aco_opcode::v_fma_mix_f32;
}
/* remove negate modifiers by converting to subtract */
@ -1126,8 +1135,9 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info)
info.defs[0].setPrecolored(vcc);
}
} else if (format_is(info.format, Format::VOP3P)) {
bool fmamix =
info.opcode == aco_opcode::v_fma_mix_f32 || info.opcode == aco_opcode::v_fma_mixlo_f16;
bool fmamix = info.opcode == aco_opcode::v_fma_mix_f32 ||
info.opcode == aco_opcode::v_fma_mixlo_f16 ||
info.opcode == aco_opcode::p_v_fma_mixlo_f16_rtz;
bool dot2_f32 =
info.opcode == aco_opcode::v_dot2_f32_f16 || info.opcode == aco_opcode::v_dot2_f32_bf16;
bool supports_dpp = (fmamix || dot2_f32) && ctx.program->gfx_level >= GFX11;
@ -1253,9 +1263,11 @@ alu_opt_gather_info(opt_ctx& ctx, Instruction* instr, alu_opt_info& info)
opsel = 0;
}
if (instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16) {
if (instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) {
info.opcode = ctx.program->dev.fused_mad_mix ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
info.f32_to_f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16;
info.f32_to_f16_rtz = instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz;
info.f32_to_f16 = info.f32_to_f16_rtz || instr->opcode == aco_opcode::v_fma_mixlo_f16;
}
if (instr->isSDWA())
@ -1268,7 +1280,8 @@ alu_opt_gather_info(opt_ctx& ctx, Instruction* instr, alu_opt_info& info)
alu_opt_op op_info = {};
op_info.op = instr->operands[i];
if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
instr->opcode == aco_opcode::v_fma_mixlo_f16) {
instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) {
op_info.neg[0] = instr->valu().neg[i];
op_info.abs[0] = instr->valu().abs[i];
if (instr->valu().opsel_hi[i]) {
@ -1490,7 +1503,8 @@ alu_opt_info_to_instr(opt_ctx& ctx, alu_opt_info& info, Instruction* old_instr)
for (unsigned i = 0; i < info.operands.size(); i++) {
instr->operands[i] = info.operands[i].op;
if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
instr->opcode == aco_opcode::v_fma_mixlo_f16) {
instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) {
instr->valu().neg[i] = info.operands[i].neg[0];
instr->valu().abs[i] = info.operands[i].abs[0];
instr->valu().opsel_hi[i] = info.operands[i].f16_to_f32;
@ -5207,7 +5221,8 @@ static void
opt_fma_mix_acc(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
/* fma_mix is only dual issued on gfx11 if dst and acc type match */
bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16;
bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz;
if (instr->valu().opsel_hi[2] == f2f16 || instr->isDPP())
return;
@ -5358,7 +5373,9 @@ apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
if (instr->isSOPC() && ctx.program->gfx_level < GFX12)
try_convert_sopc_to_sopk(instr);
if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mix_f32)
if (instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
instr->opcode == aco_opcode::v_fma_mix_f32)
opt_fma_mix_acc(ctx, instr);
if (instr->opcode == aco_opcode::v_mul_f64 || instr->opcode == aco_opcode::v_mul_f64_e64)