ir3: Late lowering of fmul+fadd to ffma

Since we know our mad.f16/mad.f32 are unfused, we can also apply this
optimization in the exact case.

Signed-off-by: Rob Clark <rob.clark@oss.qualcomm.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40271>
This commit is contained in:
Rob Clark 2026-03-11 11:00:34 -07:00 committed by Marge Bot
parent ec18b2d28a
commit 92d2671af6
6 changed files with 84 additions and 3 deletions

View file

@ -135,7 +135,9 @@ static const nir_shader_compiler_options ir3_base_options = {
* SPIRV, and NIR don't require either fused or unfused behavior from
* fma, and we'll turn mul+adds back into nir_op_ffma (again, implemented
* as unfused) during nir_opt_algebraic_late() (assuming it's not
* decorated with GLSL's precise, or SPIRV's NoContraction).
* decorated with GLSL's precise, or SPIRV's NoContraction), or during
* ir3_nir_opt_algebraic_late() (if it is so decorated, since ir3's
* unfused mul-add is precise).
*/
.lower_ffma16 = true,
.lower_ffma32 = true,

View file

@ -90,6 +90,7 @@ ir3_context_init(struct ir3_compiler *compiler, struct ir3_shader *shader,
/* nir_opt_algebraic() above would have unfused our ffmas, re-fuse them. */
if (needs_late_alg) {
NIR_PASS(progress, ctx->s, nir_opt_algebraic_late);
NIR_PASS(progress, ctx->s, ir3_nir_opt_algebraic_late);
NIR_PASS(progress, ctx->s, nir_opt_dce);
}

View file

@ -1639,7 +1639,8 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so,
*/
bool more_late_algebraic = true;
while (more_late_algebraic) {
more_late_algebraic = OPT(s, nir_opt_algebraic_late);
more_late_algebraic = OPT(s, nir_opt_algebraic_late) ||
OPT(s, ir3_nir_opt_algebraic_late);
if (!more_late_algebraic && so->compiler->gen >= 5) {
/* Lowers texture operations that have only f2f16 or u2u16 called on
* them to have a 16-bit destination. Also, lower 16-bit texture

View file

@ -62,6 +62,7 @@ nir_mem_access_size_align ir3_mem_access_size_align(
bool ir3_nir_opt_branch_and_or_not(nir_shader *nir);
bool ir3_nir_opt_triops_bitwise(nir_shader *nir);
bool ir3_nir_opt_algebraic_late(nir_shader *nir);
struct ir3_optimize_options {
nir_opt_uub_options opt_uub_options;

View file

@ -0,0 +1,57 @@
#
# Copyright © 2016 Intel Corporation
#
# SPDX-License-Identifier: MIT
import argparse
import sys
# Fuse fadd+fmul late to get something we can turn into mad.f32/f16.  The
# common nir_opt_algebraic_late pass only does this for non-exact patterns.
# Since for us mad is not fused, we don't have this restriction.
#
# Each entry appended below is a (search, replace) pair of nested tuples.
late_optimizations = []

# Names for the operands matched by a search expression; the same name in
# the replacement expression refers to the same matched operand.
a = 'a'
b = 'b'
c = 'c'

for sz in [16, 32]:
    # Fuse the correct fmul.  Only consider fmuls where the only users are
    # fadd (or fneg/fabs which are assumed to be propagated away), as a
    # heuristic to avoid fusing in cases where it's harmful.
    fmul = 'fmul(is_only_used_by_fadd)'
    ffma = 'ffma'

    # No '~' prefix on the fadd: the point of this pass is that the pattern
    # is not limited to the non-exact case.
    fadd = 'fadd@{}'.format(sz)

    late_optimizations.extend([
        # (a * b) + c
        ((fadd, (fmul, a, b), c), (ffma, a, b, c)),

        # -(a * b) + c  ->  (-a) * b + c
        ((fadd, ('fneg(is_only_used_by_fadd)', (fmul, a, b)), c),
         (ffma, ('fneg', a), b, c)),

        # |a * b| + c  ->  |a| * |b| + c
        ((fadd, ('fabs(is_only_used_by_fadd)', (fmul, a, b)), c),
         (ffma, ('fabs', a), ('fabs', b), c)),

        # -|a * b| + c  ->  (-|a|) * |b| + c
        ((fadd, ('fneg(is_only_used_by_fadd)', ('fabs', (fmul, a, b))), c),
         (ffma, ('fneg', ('fabs', a)), ('fabs', b), c)),
    ])
def main():
    """Parse the command line, extend the import path, and emit the pass.

    The required ``-p``/``--import-path`` directory is pushed to the front
    of ``sys.path`` before ``run()`` is called, because ``run()`` imports
    ``nir_algebraic`` from it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--import-path', required=True,
                        help='directory to import nir_algebraic from')
    args = parser.parse_args()

    # Must happen before run(): its import is resolved through sys.path.
    sys.path.insert(0, args.import_path)
    run()
def run():
    """Print the generated C source for the pass to stdout.

    Output is an ``#include "ir3_nir.h"`` line followed by the rendered
    ``ir3_nir_opt_algebraic_late`` pass built from ``late_optimizations``.
    """
    # Imported here rather than at file scope: main() adds the directory
    # given on the command line to sys.path before calling this function.
    import nir_algebraic  # pylint: disable=import-error

    print('#include "ir3_nir.h"')
    print(nir_algebraic.AlgebraicPass("ir3_nir_opt_algebraic_late",
                                      late_optimizations).render())
# Only generate output when executed as a script, not when imported.
if __name__ == '__main__':
    main()

View file

@ -45,6 +45,17 @@ ir3_nir_triop_bitwise_c = custom_target(
depend_files : nir_algebraic_depends,
)
# Generate ir3_nir_opt_algebraic_late.c by running the pass description
# script; the script prints the C source and its stdout is captured as the
# output file.
ir3_nir_opt_algebraic_late_c = custom_target(
  'ir3_nir_opt_algebraic_late.c',
  input : 'ir3_nir_opt_algebraic_late.py',
  output : 'ir3_nir_opt_algebraic_late.c',
  command : [prog_python, '@INPUT@', '-p', dir_compiler_nir],
  capture : true,
  depend_files : nir_algebraic_depends,
)
ir3_parser = custom_target(
'ir3_parser.[ch]',
input: 'ir3_parser.y',
@ -134,7 +145,15 @@ libfreedreno_ir3_files = files(
libfreedreno_ir3 = static_library(
'freedreno_ir3',
[libfreedreno_ir3_files, ir3_nir_trig_c, ir3_nir_imul_c, ir3_nir_branch_and_or_not_c, ir3_nir_triop_bitwise_c, ir3_parser[0], ir3_parser[1], ir3_lexer],
[libfreedreno_ir3_files,
ir3_nir_trig_c,
ir3_nir_imul_c,
ir3_nir_branch_and_or_not_c,
ir3_nir_triop_bitwise_c,
ir3_nir_opt_algebraic_late_c,
ir3_parser[0], ir3_parser[1],
ir3_lexer,
],
include_directories : [inc_freedreno, inc_include, inc_src],
c_args : [no_override_init_args],
gnu_symbol_visibility : 'hidden',