diff --git a/src/compiler/nir/tests/loop_analyze_tests.cpp b/src/compiler/nir/tests/loop_analyze_tests.cpp index 4de9863fe70..5ffb08e9fc3 100644 --- a/src/compiler/nir/tests/loop_analyze_tests.cpp +++ b/src/compiler/nir/tests/loop_analyze_tests.cpp @@ -201,6 +201,108 @@ loop_builder_invert(nir_builder *b, loop_builder_invert_param p) return loop; } +struct loop_builder_two_terminators_param { + bool j_from_uniform; + bool j_from_input; + uint32_t i_init; + uint32_t j_init_const; + uint32_t i_limit; + uint32_t j_limit; + uint32_t j_incr; + nir_def *(*cond_instr)(nir_builder *, + nir_def *, + nir_def *); +}; + +static nir_loop * +loop_builder_two_terminators(nir_builder *b, + loop_builder_two_terminators_param p) +{ + /* Create IR: + * + * uint i = i_init; + * uint j = (uniform/input/or const); + * + * while (true) { + * if (i >= i_limit) + * break; + * + * if (j >= j_limit) + * break; + * + * i++; + * j += j_incr; + * } + */ + + nir_def *i0 = nir_imm_int(b, p.i_init); + nir_def *j0; + + if (p.j_from_uniform) { + nir_def *one = nir_imm_int(b, 1); + nir_def *twelve = nir_imm_int(b, 12); + j0 = nir_load_ubo(b, 1, 32, one, twelve, (gl_access_qualifier)0, 0, 0, 0, 16); + } else if (p.j_from_input) { + nir_def *zero = nir_imm_int(b, 0); + j0 = nir_load_input(b, 1, 32, zero); + } else { + j0 = nir_imm_int(b, p.j_init_const); + } + + nir_def *i_limit = nir_imm_int(b, p.i_limit); + nir_def *j_limit = nir_imm_int(b, p.j_limit); + nir_def *j_step = nir_imm_int(b, p.j_incr); + nir_def *one = nir_imm_int(b, 1); + + nir_phi_instr *phi_i = nir_phi_instr_create(b->shader); + nir_phi_instr *phi_j = nir_phi_instr_create(b->shader); + + nir_loop *loop = nir_push_loop(b); + { + nir_def_init(&phi_i->instr, &phi_i->def, 1, 32); + nir_def_init(&phi_j->instr, &phi_j->def, 1, 32); + + nir_phi_instr_add_src(phi_i, nir_def_block(i0), i0); + nir_phi_instr_add_src(phi_j, nir_def_block(j0), j0); + + nir_def *i = &phi_i->def; + nir_def *j = &phi_j->def; + + nir_def *cond_i = p.cond_instr(b, i, i_limit); + + nir_if *if_i = nir_push_if(b, cond_i); + { + nir_jump_instr *jump = + nir_jump_instr_create(b->shader, nir_jump_break); + nir_builder_instr_insert(b, &jump->instr); + } + nir_pop_if(b, if_i); + + nir_def *cond_j = p.cond_instr(b, j, j_limit); + + nir_if *if_j = nir_push_if(b, cond_j); + { + nir_jump_instr *jump = + nir_jump_instr_create(b->shader, nir_jump_break); + nir_builder_instr_insert(b, &jump->instr); + } + nir_pop_if(b, if_j); + + nir_def *i_next = nir_iadd(b, i, one); + nir_def *j_next = nir_iadd(b, j, j_step); + + nir_phi_instr_add_src(phi_i, nir_def_block(i_next), i_next); + nir_phi_instr_add_src(phi_j, nir_def_block(j_next), j_next); + } + nir_pop_loop(b, loop); + + b->cursor = nir_before_block(nir_loop_first_block(loop)); + nir_builder_instr_insert(b, &phi_i->instr); + nir_builder_instr_insert(b, &phi_j->instr); + + return loop; +} + TEST_F(nir_loop_analyze_test, one_iteration_fneu) { /* Create IR: @@ -558,6 +660,49 @@ INOT_COMPARE(ilt_imin_rev) EXPECT_FALSE(loop->info->exact_trip_count_known); \ } +#define TWO_TERMINATOR_INEXACT_COUNT_TEST(_name, _j_uniform, _j_input, _term_trip_known, _other_term_trip_known, cond, _i_limit, _j_limit, _j_step, _count) \ + TEST_F(nir_loop_analyze_test, two_terminator_##_name) \ + { \ + nir_loop *loop = \ + loop_builder_two_terminators(&b, { \ + .j_from_uniform = _j_uniform, \ + .j_from_input = _j_input, \ + .i_init = 0, \ + .j_init_const = 0, \ + .i_limit = _i_limit, \ + .j_limit = _j_limit, \ + .j_incr = _j_step, \ + .cond_instr = nir_ ## cond, \ + }); \ + \ + nir_validate_shader(b.shader, "input"); \ + \ + nir_loop_analyze_impl(b.impl, nir_var_all, false); \ + \ + ASSERT_NE((void *)0, loop->info); \ + EXPECT_NE((void *)0, loop->info->limiting_terminator); \ + EXPECT_EQ(_count, loop->info->max_trip_count); \ + EXPECT_FALSE(loop->info->exact_trip_count_known); \ + if (!_term_trip_known) \ + EXPECT_TRUE(loop->info->limiting_terminator->exact_trip_count_unknown); \ + else \ + EXPECT_FALSE(loop->info->limiting_terminator->exact_trip_count_unknown); \ + \ + list_for_each_entry_safe(nir_loop_terminator, t, \ + &loop->info->loop_terminator_list, \ + loop_terminator_link) { \ + if (t != loop->info->limiting_terminator) { \ + if (!_other_term_trip_known) \ + EXPECT_TRUE(t->exact_trip_count_unknown); \ + else \ + EXPECT_FALSE(t->exact_trip_count_unknown); \ + } \ + } \ + \ + ASSERT_NE((void *)0, loop->info->induction_vars); \ + EXPECT_GE(_mesa_hash_table_num_entries(loop->info->induction_vars), 2); \ + } + /* float i = 0.0; * while (true) { * if (i == 0.9) @@ -1736,3 +1881,131 @@ INEXACT_COUNT_TEST_UNKNOWN_INIT(0x00000004, 0x00000006, uge, iadd, 1, 1, 1) * } */ INEXACT_COUNT_TEST_UNKNOWN_INIT(0x00000004, 0x00000006, uge, iadd, 1, 0, 0) + +/* uniform uint x; + * uint i = 0; + * uint j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 12) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_j_limit, true, false, false, true, uge, 4, 12, 6, 2) + +/* uniform uint x; + * uint i = 0; + * uint j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 30) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_j_limit2, true, false, true, false, uge, 4, 30, 6, 4) + +/* uniform int x; + * int i = 0; + * int j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 12) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_signed_j_limit, true, false, true, false, ige, 4, 12, 6, 4) + +/* uniform int x; + * int i = 0; + * int j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 30) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(uniform_signed_j_limit2, true, false, true, false, ige, 4, 30, 6, 4) + +/* in int x; + * int i = 0; + * int j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 30) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(input_signed_j_limit, false, true, true, false, ige, 4, 30, 6, 4) + +/* in int x; + * int i = 0; + * int j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 12) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(input_signed_limit_j_limit2, false, true, true, false, ige, 4, 12, 6, 4) + +/* in uint x; + * uint i = 0; + * uint j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 12) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(input_unsigned_limit_j_limit, false, true, false, true, uge, 4, 12, 6, 2) + +/* in uint x; + * uint i = 0; + * uint j = x; + * while (true) { + * if (i >= 4) + * break; + * + * if (j >= 30) + * break; + * + * i++; + * j += 6; + * } + */ +TWO_TERMINATOR_INEXACT_COUNT_TEST(input_unsigned_j_limit2, false, true, true, false, uge, 4, 30, 6, 4)