nir: test loop analyze sets exact trip flags correctly

Introduces new test helper to create loop with multiple terminators
and tests some scenaros to make sure exact trip flags are set
correctly.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32473>
This commit is contained in:
Timothy Arceri 2026-03-02 14:40:17 +11:00 committed by Marge Bot
parent 82b474c3fb
commit 06fc27b5a4

View file

@ -201,6 +201,108 @@ loop_builder_invert(nir_builder *b, loop_builder_invert_param p)
return loop;
}
struct loop_builder_two_terminators_param {
bool j_from_uniform;
bool j_from_input;
uint32_t i_init;
uint32_t j_init_const;
uint32_t i_limit;
uint32_t j_limit;
uint32_t j_incr;
nir_def *(*cond_instr)(nir_builder *,
nir_def *,
nir_def *);
};
static nir_loop *
loop_builder_two_terminators(nir_builder *b,
loop_builder_two_terminators_param p)
{
/* Create IR:
*
* uint i = i_init;
* uint j = (uniform/input/or const);
*
* while (true) {
* if (i >= i_limit)
* break;
*
* if (j >= j_limit)
* break;
*
* i++;
* j += j_incr;
* }
*/
nir_def *i0 = nir_imm_int(b, p.i_init);
nir_def *j0;
if (p.j_from_uniform) {
nir_def *one = nir_imm_int(b, 1);
nir_def *twelve = nir_imm_int(b, 12);
j0 = nir_load_ubo(b, 1, 32, one, twelve, (gl_access_qualifier)0, 0, 0, 0, 16);
} else if (p.j_from_input) {
nir_def *zero = nir_imm_int(b, 0);
j0 = nir_load_input(b, 1, 32, zero);
} else {
j0 = nir_imm_int(b, p.j_init_const);
}
nir_def *i_limit = nir_imm_int(b, p.i_limit);
nir_def *j_limit = nir_imm_int(b, p.j_limit);
nir_def *j_step = nir_imm_int(b, p.j_incr);
nir_def *one = nir_imm_int(b, 1);
nir_phi_instr *phi_i = nir_phi_instr_create(b->shader);
nir_phi_instr *phi_j = nir_phi_instr_create(b->shader);
nir_loop *loop = nir_push_loop(b);
{
nir_def_init(&phi_i->instr, &phi_i->def, 1, 32);
nir_def_init(&phi_j->instr, &phi_j->def, 1, 32);
nir_phi_instr_add_src(phi_i, nir_def_block(i0), i0);
nir_phi_instr_add_src(phi_j, nir_def_block(j0), j0);
nir_def *i = &phi_i->def;
nir_def *j = &phi_j->def;
nir_def *cond_i = p.cond_instr(b, i, i_limit);
nir_if *if_i = nir_push_if(b, cond_i);
{
nir_jump_instr *jump =
nir_jump_instr_create(b->shader, nir_jump_break);
nir_builder_instr_insert(b, &jump->instr);
}
nir_pop_if(b, if_i);
nir_def *cond_j = p.cond_instr(b, j, j_limit);
nir_if *if_j = nir_push_if(b, cond_j);
{
nir_jump_instr *jump =
nir_jump_instr_create(b->shader, nir_jump_break);
nir_builder_instr_insert(b, &jump->instr);
}
nir_pop_if(b, if_j);
nir_def *i_next = nir_iadd(b, i, one);
nir_def *j_next = nir_iadd(b, j, j_step);
nir_phi_instr_add_src(phi_i, nir_def_block(i_next), i_next);
nir_phi_instr_add_src(phi_j, nir_def_block(j_next), j_next);
}
nir_pop_loop(b, loop);
b->cursor = nir_before_block(nir_loop_first_block(loop));
nir_builder_instr_insert(b, &phi_i->instr);
nir_builder_instr_insert(b, &phi_j->instr);
return loop;
}
TEST_F(nir_loop_analyze_test, one_iteration_fneu)
{
/* Create IR:
@ -558,6 +660,49 @@ INOT_COMPARE(ilt_imin_rev)
EXPECT_FALSE(loop->info->exact_trip_count_known); \
}
#define TWO_TERMINATOR_INEXACT_COUNT_TEST(_name, _j_uniform, _j_input, _term_trip_known, _other_term_trip_known, cond, _i_limit, _j_limit, _j_step, _count) \
TEST_F(nir_loop_analyze_test, two_terminator_##_name) \
{ \
nir_loop *loop = \
loop_builder_two_terminators(&b, { \
.j_from_uniform = _j_uniform, \
.j_from_input = _j_input, \
.i_init = 0, \
.j_init_const = 0, \
.i_limit = _i_limit, \
.j_limit = _j_limit, \
.j_incr = _j_step, \
.cond_instr = nir_ ## cond, \
}); \
\
nir_validate_shader(b.shader, "input"); \
\
nir_loop_analyze_impl(b.impl, nir_var_all, false); \
\
ASSERT_NE((void *)0, loop->info); \
EXPECT_NE((void *)0, loop->info->limiting_terminator); \
EXPECT_EQ(_count, loop->info->max_trip_count); \
EXPECT_FALSE(loop->info->exact_trip_count_known); \
if (!_term_trip_known) \
EXPECT_TRUE(loop->info->limiting_terminator->exact_trip_count_unknown); \
else \
EXPECT_FALSE(loop->info->limiting_terminator->exact_trip_count_unknown); \
\
list_for_each_entry_safe(nir_loop_terminator, t, \
&loop->info->loop_terminator_list, \
loop_terminator_link) { \
if (t != loop->info->limiting_terminator) { \
if (!_other_term_trip_known) \
EXPECT_TRUE(t->exact_trip_count_unknown); \
else \
EXPECT_FALSE(t->exact_trip_count_unknown); \
} \
} \
\
ASSERT_NE((void *)0, loop->info->induction_vars); \
EXPECT_GE(_mesa_hash_table_num_entries(loop->info->induction_vars), 2); \
}
/* float i = 0.0;
* while (true) {
* if (i == 0.9)
@ -1736,3 +1881,131 @@ INEXACT_COUNT_TEST_UNKNOWN_INIT(0x00000004, 0x00000006, uge, iadd, 1, 1, 1)
* }
*/
INEXACT_COUNT_TEST_UNKNOWN_INIT(0x00000004, 0x00000006, uge, iadd, 1, 0, 0)
/* uniform uint x;
* uint i = 0;
* uint j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 12)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_j_limit, true, false, false, true, uge, 4, 12, 6, 2)
/* uniform uint x;
* uint i = 0;
* uint j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 30)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_j_limit2, true, false, true, false, uge, 4, 30, 6, 4)
/* uniform int x;
* int i = 0;
* int j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 12)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_signed_j_limit, true, false, true, false, ige, 4, 12, 6, 4)
/* uniform int x;
* int i = 0;
* int j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 30)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_signed_j_limit2, true, false, true, false, ige, 4, 30, 6, 4)
/* in int x;
* int i = 0;
* int j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 30)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(input_signed_j_limit, false, true, true, false, ige, 4, 30, 6, 4)
/* in int x;
* int i = 0;
* int j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 12)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(input_signed_limit_j_limit2, false, true, true, false, ige, 4, 12, 6, 4)
/* in uint x;
* uint i = 0;
* uint j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 12)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(input_unsigned_limit_j_limit, false, true, false, true, uge, 4, 12, 6, 2)
/* in uint x;
* uint i = 0;
* uint j = x;
* while (true) {
* if (i >= 4)
* break;
*
* if (j >= 30)
* break;
*
* i++;
* j += 6;
* }
*/
TWO_TERMINATOR_INEXACT_COUNT_TEST(input_unsigned_j_limit2, false, true, true, false, uge, 4, 30, 6, 4)