brw: Fix single patch thread dispatch masks in NIR

Arguably a little more code but it brings us a bit closer to not
needing separate per-stage "run" functions.

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40328>
This commit is contained in:
Kenneth Graunke 2026-03-09 00:58:58 -07:00 committed by Marge Bot
parent 4a9aa3ecc4
commit 66fbfe7bf3

View file

@ -138,7 +138,6 @@ run_tcs(brw_shader &s)
assert(s.stage == MESA_SHADER_TESS_CTRL);
struct brw_vue_prog_data *vue_prog_data = brw_vue_prog_data(s.prog_data);
const brw_builder bld = brw_builder(&s);
assert(vue_prog_data->dispatch_mode == INTEL_DISPATCH_MODE_TCS_SINGLE_PATCH ||
vue_prog_data->dispatch_mode == INTEL_DISPATCH_MODE_TCS_MULTI_PATCH);
@ -148,23 +147,8 @@ run_tcs(brw_shader &s)
/* Initialize gl_InvocationID */
brw_set_tcs_invocation_id(s);
const bool fix_dispatch_mask =
vue_prog_data->dispatch_mode == INTEL_DISPATCH_MODE_TCS_SINGLE_PATCH &&
(s.nir->info.tess.tcs_vertices_out % 8) != 0;
/* Fix the disptach mask */
if (fix_dispatch_mask) {
bld.CMP(bld.null_reg_ud(), s.invocation_id,
brw_imm_ud(s.nir->info.tess.tcs_vertices_out), BRW_CONDITIONAL_L);
bld.IF(BRW_PREDICATE_NORMAL);
}
brw_from_nir(&s);
if (fix_dispatch_mask) {
bld.emit(BRW_OPCODE_ENDIF);
}
if (s.failed)
return false;
@ -187,6 +171,30 @@ run_tcs(brw_shader &s)
return !s.failed;
}
static bool
fix_single_patch_thread_dispatch(nir_shader *nir)
{
if ((nir->info.tess.tcs_vertices_out % 8) == 0)
return false;
/* Wrap shader inside if (InvocationID < VerticesOut) */
nir_function_impl *impl = nir_shader_get_entrypoint(nir);
if (nir_cf_list_is_empty_block(&impl->body))
return false;
nir_cf_list body;
nir_cf_list_extract(&body, &impl->body);
nir_builder b = nir_builder_at(nir_after_impl(impl));
nir_push_if(&b, nir_ilt_imm(&b, nir_load_invocation_id(&b),
nir->info.tess.tcs_vertices_out));
nir_cf_reinsert(&body, b.cursor);
nir_pop_if(&b, NULL);
return true;
}
extern "C" const unsigned *
brw_compile_tcs(const struct brw_compiler *compiler,
struct brw_compile_tcs_params *params)
@ -234,6 +242,9 @@ brw_compile_tcs(const struct brw_compiler *compiler,
brw_nir_opt_vectorize_urb(pt);
BRW_NIR_PASS(intel_nir_lower_patch_vertices_in, key->input_vertices);
if (!intel_use_tcs_multi_patch(devinfo))
BRW_NIR_PASS(fix_single_patch_thread_dispatch);
brw_postprocess_nir(pt, debug_enabled, key->base.robust_flags);
bool has_primitive_id =