mesa/st: fix unlower_io_to_vars to work with mesh shaders

cc: mesa-stable

Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/15034
Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/15040

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37408>
This commit is contained in:
Mike Blumenkrantz 2025-09-16 10:06:46 -04:00 committed by Marge Bot
parent e604a8f617
commit 3dbb7e896d

View file

@ -7,6 +7,7 @@
struct io_desc {
bool is_per_vertex;
bool is_per_primitive;
bool is_output;
bool is_store;
bool is_indirect;
@ -28,10 +29,11 @@ static bool var_is_per_vertex(mesa_shader_stage stage, nir_variable *var)
return ((stage == MESA_SHADER_TESS_CTRL ||
stage == MESA_SHADER_GEOMETRY) &&
var->data.mode & nir_var_shader_in) ||
(stage == MESA_SHADER_MESH && var->data.mode & nir_var_shader_out) ||
(((stage == MESA_SHADER_TESS_CTRL && var->data.mode & nir_var_shader_out) ||
(stage == MESA_SHADER_TESS_EVAL && var->data.mode & nir_var_shader_in)) &&
!(var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER ||
!((stage != MESA_SHADER_MESH && var->data.location == VARYING_SLOT_TESS_LEVEL_INNER) ||
(stage != MESA_SHADER_MESH && var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER) ||
(var->data.location >= VARYING_SLOT_PATCH0 &&
var->data.location <= VARYING_SLOT_PATCH31)));
}
@ -84,6 +86,9 @@ parse_intrinsic(nir_shader *nir, nir_intrinsic_instr *intr,
case nir_intrinsic_load_output:
desc->is_output = true;
break;
case nir_intrinsic_load_per_primitive_input:
desc->is_per_primitive = true;
break;
case nir_intrinsic_load_per_vertex_output:
desc->is_output = true;
desc->is_per_vertex = true;
@ -92,6 +97,11 @@ parse_intrinsic(nir_shader *nir, nir_intrinsic_instr *intr,
desc->is_output = true;
desc->is_store = true;
break;
case nir_intrinsic_store_per_primitive_output:
desc->is_output = true;
desc->is_per_primitive = true;
desc->is_store = true;
break;
case nir_intrinsic_store_per_vertex_output:
desc->is_output = true;
desc->is_per_vertex = true;
@ -265,10 +275,33 @@ create_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
switch (desc.sem.location) {
case VARYING_SLOT_TESS_LEVEL_OUTER:
var_type = glsl_array_type(glsl_float_type(), 4, sizeof(float));
if (nir->info.stage == MESA_SHADER_TESS_CTRL || nir->info.stage == MESA_SHADER_TESS_EVAL)
var_type = glsl_array_type(glsl_float_type(), 4, sizeof(float));
else
/* VARYING_SLOT_PRIMITIVE_COUNT */
var_type = glsl_uint_type();
break;
case VARYING_SLOT_TESS_LEVEL_INNER:
var_type = glsl_array_type(glsl_float_type(), 2, sizeof(float));
if (nir->info.stage == MESA_SHADER_TESS_CTRL || nir->info.stage == MESA_SHADER_TESS_EVAL) {
var_type = glsl_array_type(glsl_float_type(), 2, sizeof(float));
} else {
/* VARYING_SLOT_PRIMITIVE_INDICES */
unsigned num_components = 0;
switch (nir->info.mesh.primitive_type) {
case MESA_PRIM_POINTS:
num_components = 1;
break;
case MESA_PRIM_LINES:
num_components = 2;
break;
case MESA_PRIM_TRIANGLES:
num_components = 3;
break;
default:
UNREACHABLE("impossible prim type");
}
var_type = glsl_vector_type(GLSL_TYPE_UINT, num_components);
}
break;
case VARYING_SLOT_CLIP_DIST0:
case VARYING_SLOT_CLIP_DIST1:
@ -302,6 +335,9 @@ create_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
num_components = 1;
break;
case VARYING_SLOT_TESS_LEVEL_INNER:
if (nir->info.stage == MESA_SHADER_MESH)
break;
FALLTHROUGH;
case VARYING_SLOT_PNTC:
num_components = 2;
break;
@ -327,20 +363,25 @@ create_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
var_type = glsl_array_type(var_type, desc.sem.num_slots, 0);
}
unsigned num_vertices = 0;
unsigned num_array_elements = 0;
if (desc.is_per_vertex) {
if (nir->info.stage == MESA_SHADER_TESS_CTRL)
num_vertices = desc.is_output ? nir->info.tess.tcs_vertices_out : 32;
num_array_elements = desc.is_output ? nir->info.tess.tcs_vertices_out : 32;
else if (nir->info.stage == MESA_SHADER_TESS_EVAL && !desc.is_output)
num_vertices = 32;
num_array_elements = 32;
else if (nir->info.stage == MESA_SHADER_GEOMETRY && !desc.is_output)
num_vertices = mesa_vertices_per_prim(nir->info.gs.input_primitive);
num_array_elements = mesa_vertices_per_prim(nir->info.gs.input_primitive);
else if (nir->info.stage == MESA_SHADER_MESH && desc.is_output)
num_array_elements = nir->info.mesh.max_vertices_out;
else
UNREACHABLE("unexpected shader stage for per-vertex IO");
var_type = glsl_array_type(var_type, num_vertices, 0);
} else if (desc.is_per_primitive) {
if (nir->info.stage == MESA_SHADER_MESH && desc.is_output)
num_array_elements = nir->info.mesh.max_primitives_out;
}
if (num_array_elements)
var_type = glsl_array_type(var_type, num_array_elements, 0);
const char *name = intr->name;
if (!name) {
@ -360,6 +401,7 @@ create_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
var->data.driver_location = nir_intrinsic_base(intr) -
(desc.sem.high_dvec2 ? 1 : 0);
var->data.compact = desc.is_compact;
var->data.per_primitive = desc.is_per_primitive;
var->data.precision = desc.sem.medium_precision ? GLSL_PRECISION_MEDIUM
: GLSL_PRECISION_HIGH;
var->data.index = desc.sem.dual_source_blend_index;
@ -419,8 +461,8 @@ create_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
var->type = glsl_array_type(elem_type, new_num_slots, 0);
if (var_is_per_vertex(nir->info.stage, var)) {
assert(num_vertices);
var->type = glsl_array_type(var->type, num_vertices, 0);
assert(num_array_elements);
var->type = glsl_array_type(var->type, num_array_elements, 0);
}
}
@ -526,7 +568,8 @@ unlower_io_to_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
assert(var);
nir_deref_instr *deref = nir_build_deref_var(b, var);
if (desc.is_per_vertex) {
/* fragment shaders can have non-arrayed per-primitive mesh inputs */
if (b->shader->info.stage != MESA_SHADER_FRAGMENT && (desc.is_per_vertex || desc.is_per_primitive)) {
deref = nir_build_deref_array(b, deref,
nir_get_io_arrayed_index_src(intr)->ssa);
}
@ -599,8 +642,9 @@ unlower_io_to_vars(nir_builder *b, nir_intrinsic_instr *intr, void *opaque)
* the GLSL compiler never vectorized them. Doing 1 store per bit of
* the writemask is enough to make virgl work.
*/
if (desc.sem.location == VARYING_SLOT_TESS_LEVEL_OUTER ||
desc.sem.location == VARYING_SLOT_TESS_LEVEL_INNER) {
if (b->shader->info.stage != MESA_SHADER_MESH &&
(desc.sem.location == VARYING_SLOT_TESS_LEVEL_OUTER ||
desc.sem.location == VARYING_SLOT_TESS_LEVEL_INNER)) {
u_foreach_bit(i, writemask) {
nir_build_store_deref(b, &deref->def, value,
.write_mask = BITFIELD_BIT(i),