From c6b74705dd4c4fa21435ceb023ab689b8d36883c Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Sun, 19 Oct 2025 16:51:55 +0200 Subject: [PATCH] aco/optimizer: support fma_mix with rtz MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_optimizer.cpp | 41 +++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 00f82aa5532..c8ec3b0222e 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -364,6 +364,7 @@ struct alu_opt_info { uint8_t omod = 0; bool clamp = false; bool f32_to_f16 = false; + bool f32_to_f16_rtz = false; SubdwordSel insert = SubdwordSel::dword; bool try_swap_operands(unsigned idx0, unsigned idx1) @@ -600,7 +601,8 @@ format_is(Format f1, Format f2) bool try_vinterp_inreg(opt_ctx& ctx, alu_opt_info& info) { - if (ctx.program->gfx_level < GFX11 || info.opcode != aco_opcode::v_fma_f32 || info.omod) + if (ctx.program->gfx_level < GFX11 || info.opcode != aco_opcode::v_fma_f32 || info.omod || + info.f32_to_f16_rtz) return false; bool fp16 = info.f32_to_f16; @@ -691,8 +693,6 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info) case aco_opcode::v_fma_legacy_f32: case aco_opcode::v_fma_f16: case aco_opcode::v_fma_legacy_f16: - case aco_opcode::v_fma_mix_f32: - case aco_opcode::v_fma_mixlo_f16: case aco_opcode::v_pk_mul_f16: case aco_opcode::v_pk_fma_f16: case aco_opcode::s_mul_f32: @@ -834,6 +834,10 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info) } } + assert(!info.f32_to_f16_rtz || info.f32_to_f16); + if (info.f32_to_f16_rtz && ctx.fp_mode.round16_64 == fp_round_tz) + info.f32_to_f16_rtz = false; + /* convert to VINTERP_INREG */ try_vinterp_inreg(ctx, info); @@ -871,8 +875,13 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info) default: return false; } - info.opcode = info.f32_to_f16 ? 
aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32; info.format = Format::VOP3P; + if (info.f32_to_f16_rtz) + info.opcode = aco_opcode::p_v_fma_mixlo_f16_rtz; + else if (info.f32_to_f16) + info.opcode = aco_opcode::v_fma_mixlo_f16; + else + info.opcode = aco_opcode::v_fma_mix_f32; } /* remove negate modifiers by converting to subtract */ @@ -1126,8 +1135,9 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info) info.defs[0].setPrecolored(vcc); } } else if (format_is(info.format, Format::VOP3P)) { - bool fmamix = - info.opcode == aco_opcode::v_fma_mix_f32 || info.opcode == aco_opcode::v_fma_mixlo_f16; + bool fmamix = info.opcode == aco_opcode::v_fma_mix_f32 || + info.opcode == aco_opcode::v_fma_mixlo_f16 || + info.opcode == aco_opcode::p_v_fma_mixlo_f16_rtz; bool dot2_f32 = info.opcode == aco_opcode::v_dot2_f32_f16 || info.opcode == aco_opcode::v_dot2_f32_bf16; bool supports_dpp = (fmamix || dot2_f32) && ctx.program->gfx_level >= GFX11; @@ -1253,9 +1263,11 @@ alu_opt_gather_info(opt_ctx& ctx, Instruction* instr, alu_opt_info& info) opsel = 0; } - if (instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16) { + if (instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) { info.opcode = ctx.program->dev.fused_mad_mix ? 
aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32; - info.f32_to_f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16; + info.f32_to_f16_rtz = instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz; + info.f32_to_f16 = info.f32_to_f16_rtz || instr->opcode == aco_opcode::v_fma_mixlo_f16; } if (instr->isSDWA()) @@ -1268,7 +1280,8 @@ alu_opt_gather_info(opt_ctx& ctx, Instruction* instr, alu_opt_info& info) alu_opt_op op_info = {}; op_info.op = instr->operands[i]; if (instr->opcode == aco_opcode::v_fma_mix_f32 || - instr->opcode == aco_opcode::v_fma_mixlo_f16) { + instr->opcode == aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) { op_info.neg[0] = instr->valu().neg[i]; op_info.abs[0] = instr->valu().abs[i]; if (instr->valu().opsel_hi[i]) { @@ -1490,7 +1503,8 @@ alu_opt_info_to_instr(opt_ctx& ctx, alu_opt_info& info, Instruction* old_instr) for (unsigned i = 0; i < info.operands.size(); i++) { instr->operands[i] = info.operands[i].op; if (instr->opcode == aco_opcode::v_fma_mix_f32 || - instr->opcode == aco_opcode::v_fma_mixlo_f16) { + instr->opcode == aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) { instr->valu().neg[i] = info.operands[i].neg[0]; instr->valu().abs[i] = info.operands[i].abs[0]; instr->valu().opsel_hi[i] = info.operands[i].f16_to_f32; @@ -5207,7 +5221,8 @@ static void opt_fma_mix_acc(opt_ctx& ctx, aco_ptr<Instruction>& instr) { /* fma_mix is only dual issued on gfx11 if dst and acc type match */ - bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16; + bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz; if (instr->valu().opsel_hi[2] == f2f16 || instr->isDPP()) return; @@ -5358,7 +5373,9 @@ apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr) if (instr->isSOPC() && ctx.program->gfx_level < GFX12) try_convert_sopc_to_sopk(instr); - if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mix_f32) + if (instr->opcode == 
aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + instr->opcode == aco_opcode::v_fma_mix_f32) opt_fma_mix_acc(ctx, instr); if (instr->opcode == aco_opcode::v_mul_f64 || instr->opcode == aco_opcode::v_mul_f64_e64)