radv/rt: Use function call structure in NIR lowering

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29580>
This commit is contained in:
Natalie Vock 2025-02-17 18:42:47 +01:00 committed by Marge Bot
parent d3cb8b4046
commit c5d796c902
25 changed files with 1777 additions and 457 deletions

View file

@ -198,20 +198,6 @@ struct ac_shader_args {
struct ac_arg shader_addr; struct ac_arg shader_addr;
struct ac_arg shader_record; struct ac_arg shader_record;
struct ac_arg payload_offset; struct ac_arg payload_offset;
struct ac_arg ray_origin;
struct ac_arg ray_tmin;
struct ac_arg ray_direction;
struct ac_arg ray_tmax;
struct ac_arg cull_mask_and_flags;
struct ac_arg sbt_offset;
struct ac_arg sbt_stride;
struct ac_arg miss_index;
struct ac_arg accel_struct;
struct ac_arg primitive_id;
struct ac_arg instance_addr;
struct ac_arg primitive_addr;
struct ac_arg geometry_id_and_flags;
struct ac_arg hit_kind;
} rt; } rt;
}; };

View file

@ -26,4 +26,87 @@ enum aco_nir_parameter_attribs {
ACO_NIR_PARAM_ATTRIB_DISCARDABLE = 0x1, ACO_NIR_PARAM_ATTRIB_DISCARDABLE = 0x1,
}; };
/* System parameters prepended to every callable function's signature by
 * radv_nir_lower_callee_signature(). They carry the "program counter" used
 * to select the next shader to execute.
 */
enum aco_nir_call_system_args {
   /* Per-invocation (possibly divergent) address of the shader to execute. */
   ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC,
   /* Wave-uniform shader address; compared against the divergent PC to guard
    * which invocations actually execute the callee body. */
   ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC,
   ACO_NIR_CALL_SYSTEM_ARG_COUNT,
};
/* Parameters common to all RT pipeline functions. The RT_ARG_* indices are
 * used with nir_load_param (e.g. by radv_nir_apply_pipeline_layout) to fetch
 * these values in RT stages instead of reading shader args.
 */
enum aco_nir_rt_function_arg {
   RT_ARG_LAUNCH_ID = 0,
   RT_ARG_LAUNCH_SIZE,
   /* Pointer to the indirect descriptor sets. */
   RT_ARG_DESCRIPTORS,
   /* Pointer to the dynamic descriptor data. */
   RT_ARG_DYNAMIC_DESCRIPTORS,
   /* Pointer to the push constant data. */
   RT_ARG_PUSH_CONSTANTS,
   RT_ARG_SBT_DESCRIPTORS,
   /* Number of common args; the ABI-specific enums below continue from here. */
   RT_ARG_COUNT,
};
/* Raygen function parameters; extends aco_nir_rt_function_arg. */
enum aco_nir_raygen_function_arg {
   RAYGEN_ARG_TRAVERSAL_ADDR = RT_ARG_COUNT,
   RAYGEN_ARG_SHADER_RECORD_PTR,
   RAYGEN_ARG_COUNT,
};
/* Traversal function parameters; extends aco_nir_rt_function_arg with the
 * ray state and hit information.
 */
enum aco_nir_traversal_function_arg {
   TRAVERSAL_ARG_TRAVERSAL_ADDR = RT_ARG_COUNT,
   TRAVERSAL_ARG_SHADER_RECORD_PTR,
   TRAVERSAL_ARG_ACCEL_STRUCT,
   TRAVERSAL_ARG_CULL_MASK_AND_FLAGS,
   TRAVERSAL_ARG_SBT_OFFSET,
   TRAVERSAL_ARG_SBT_STRIDE,
   TRAVERSAL_ARG_MISS_INDEX,
   TRAVERSAL_ARG_RAY_ORIGIN,
   TRAVERSAL_ARG_RAY_TMIN,
   TRAVERSAL_ARG_RAY_DIRECTION,
   TRAVERSAL_ARG_RAY_TMAX,
   TRAVERSAL_ARG_PRIMITIVE_ADDR,
   TRAVERSAL_ARG_PRIMITIVE_ID,
   TRAVERSAL_ARG_INSTANCE_ADDR,
   TRAVERSAL_ARG_GEOMETRY_ID_AND_FLAGS,
   TRAVERSAL_ARG_HIT_KIND,
   /* First ray payload parameter; payload args follow from here. */
   TRAVERSAL_ARG_PAYLOAD_BASE,
};
/* Closest-hit/miss function parameters; extends aco_nir_rt_function_arg.
 * NOTE: deliberately laid out identically to aco_nir_traversal_function_arg
 * so that traversal can tail-call into closest-hit/miss shaders (see
 * is_tail_call_compatible in radv_nir_lower_call_abi.c).
 */
enum aco_nir_chit_miss_function_arg {
   CHIT_MISS_ARG_TRAVERSAL_ADDR = RT_ARG_COUNT,
   CHIT_MISS_ARG_SHADER_RECORD_PTR,
   CHIT_MISS_ARG_ACCEL_STRUCT,
   CHIT_MISS_ARG_CULL_MASK_AND_FLAGS,
   CHIT_MISS_ARG_SBT_OFFSET,
   CHIT_MISS_ARG_SBT_STRIDE,
   CHIT_MISS_ARG_MISS_INDEX,
   CHIT_MISS_ARG_RAY_ORIGIN,
   CHIT_MISS_ARG_RAY_TMIN,
   CHIT_MISS_ARG_RAY_DIRECTION,
   CHIT_MISS_ARG_RAY_TMAX,
   CHIT_MISS_ARG_PRIMITIVE_ADDR,
   CHIT_MISS_ARG_PRIMITIVE_ID,
   CHIT_MISS_ARG_INSTANCE_ADDR,
   CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS,
   CHIT_MISS_ARG_HIT_KIND,
   /* First ray payload parameter; payload args follow from here. */
   CHIT_MISS_ARG_PAYLOAD_BASE,
};
/* aco_nir_cps_function_arg extends aco_nir_raygen_function_arg */
enum aco_nir_cps_function_arg {
   CPS_ARG_PAYLOAD_SCRATCH_OFFSET = RAYGEN_ARG_COUNT,
   CPS_ARG_STACK_PTR,
   CPS_ARG_ACCEL_STRUCT,
   CPS_ARG_CULL_MASK_AND_FLAGS,
   CPS_ARG_SBT_OFFSET,
   CPS_ARG_SBT_STRIDE,
   CPS_ARG_MISS_INDEX,
   CPS_ARG_RAY_ORIGIN,
   CPS_ARG_RAY_TMIN,
   CPS_ARG_RAY_DIRECTION,
   CPS_ARG_RAY_TMAX,
   CPS_ARG_PRIMITIVE_ID,
   CPS_ARG_INSTANCE_ADDR,
   CPS_ARG_GEOMETRY_ID_AND_FLAGS,
   CPS_ARG_HIT_KIND,
   /* NOTE: unlike the traversal/chit-miss layouts, PRIMITIVE_ADDR comes last
    * here — presumably for compatibility of the preceding args; confirm before
    * relying on cross-ABI register overlap. */
   CPS_ARG_PRIMITIVE_ADDR,
   CPS_ARG_COUNT,
};
#endif /* ACO_NIR_CALL_ATTRIBS_H */ #endif /* ACO_NIR_CALL_ATTRIBS_H */

View file

@ -71,6 +71,7 @@ libradv_files = files(
'nir/radv_nir_apply_pipeline_layout.c', 'nir/radv_nir_apply_pipeline_layout.c',
'nir/radv_nir_export_multiview.c', 'nir/radv_nir_export_multiview.c',
'nir/radv_nir_lower_abi.c', 'nir/radv_nir_lower_abi.c',
'nir/radv_nir_lower_call_abi.c',
'nir/radv_nir_lower_cooperative_matrix.c', 'nir/radv_nir_lower_cooperative_matrix.c',
'nir/radv_nir_lower_fs_barycentric.c', 'nir/radv_nir_lower_fs_barycentric.c',
'nir/radv_nir_lower_fs_intrinsics.c', 'nir/radv_nir_lower_fs_intrinsics.c',
@ -90,6 +91,7 @@ libradv_files = files(
'nir/radv_nir_rt_common.c', 'nir/radv_nir_rt_common.c',
'nir/radv_nir_rt_stage_common.c', 'nir/radv_nir_rt_stage_common.c',
'nir/radv_nir_rt_stage_cps.c', 'nir/radv_nir_rt_stage_cps.c',
'nir/radv_nir_rt_stage_functions.c',
'nir/radv_nir_rt_stage_monolithic.c', 'nir/radv_nir_rt_stage_monolithic.c',
'nir/radv_nir_rt_traversal_shader.c', 'nir/radv_nir_rt_traversal_shader.c',
'nir/radv_nir_trim_fs_color_exports.c', 'nir/radv_nir_trim_fs_color_exports.c',

View file

@ -103,6 +103,10 @@ bool radv_nir_opt_fs_builtins(nir_shader *shader, const struct radv_graphics_sta
bool radv_nir_lower_immediate_samplers(nir_shader *shader, struct radv_device *device, bool radv_nir_lower_immediate_samplers(nir_shader *shader, struct radv_device *device,
const struct radv_shader_stage *stage); const struct radv_shader_stage *stage);
void radv_nir_lower_callee_signature(nir_function *function);
bool radv_nir_lower_call_abi(nir_shader *shader, unsigned wave_size);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View file

@ -5,6 +5,7 @@
*/ */
#include "ac_descriptors.h" #include "ac_descriptors.h"
#include "ac_shader_util.h" #include "ac_shader_util.h"
#include "aco_nir_call_attribs.h"
#include "nir.h" #include "nir.h"
#include "nir_builder.h" #include "nir_builder.h"
#include "radv_descriptor_set.h" #include "radv_descriptor_set.h"
@ -37,6 +38,30 @@ get_scalar_arg(nir_builder *b, unsigned size, struct ac_arg arg)
return nir_load_scalar_arg_amd(b, size, .base = arg.arg_index); return nir_load_scalar_arg_amd(b, size, .base = arg.arg_index);
} }
/* Returns the address of the indirect descriptor sets. RT stages receive it
 * as a function parameter instead of a shader argument. */
static nir_def *
get_indirect_descriptors_addr(nir_builder *b, apply_layout_state *state)
{
   if (mesa_shader_stage_is_rt(b->shader->info.stage))
      return nir_load_param(b, RT_ARG_DESCRIPTORS);

   return get_scalar_arg(b, 1, state->args->descriptors[0]);
}
/* Returns the address of the push constant data. RT stages receive it as a
 * function parameter instead of a shader argument. */
static nir_def *
get_indirect_push_constants_addr(nir_builder *b, apply_layout_state *state)
{
   if (mesa_shader_stage_is_rt(b->shader->info.stage))
      return nir_load_param(b, RT_ARG_PUSH_CONSTANTS);

   return get_scalar_arg(b, 1, state->args->ac.push_constants);
}
/* Returns the address of the dynamic descriptor data. RT stages receive it
 * as a function parameter instead of a shader argument. */
static nir_def *
get_dynamic_descriptors_addr(nir_builder *b, apply_layout_state *state)
{
   if (mesa_shader_stage_is_rt(b->shader->info.stage))
      return nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS);

   return get_scalar_arg(b, 1, state->args->ac.dynamic_descriptors);
}
static nir_def * static nir_def *
convert_pointer_to_64_bit(nir_builder *b, apply_layout_state *state, nir_def *ptr) convert_pointer_to_64_bit(nir_builder *b, apply_layout_state *state, nir_def *ptr)
{ {
@ -47,8 +72,9 @@ static nir_def *
load_desc_ptr(nir_builder *b, apply_layout_state *state, unsigned set) load_desc_ptr(nir_builder *b, apply_layout_state *state, unsigned set)
{ {
const struct radv_userdata_locations *user_sgprs_locs = &state->info->user_sgprs_locs; const struct radv_userdata_locations *user_sgprs_locs = &state->info->user_sgprs_locs;
if (user_sgprs_locs->shader_data[AC_UD_INDIRECT_DESCRIPTORS].sgpr_idx != -1) { if (user_sgprs_locs->shader_data[AC_UD_INDIRECT_DESCRIPTORS].sgpr_idx != -1 ||
nir_def *addr = get_scalar_arg(b, 1, state->args->descriptors[0]); mesa_shader_stage_is_rt(b->shader->info.stage)) {
nir_def *addr = get_indirect_descriptors_addr(b, state);
addr = convert_pointer_to_64_bit(b, state, addr); addr = convert_pointer_to_64_bit(b, state, addr);
return ac_nir_load_smem(b, 1, addr, nir_imm_int(b, set * 4), 4, 0); return ac_nir_load_smem(b, 1, addr, nir_imm_int(b, set * 4), 4, 0);
} }
@ -69,7 +95,7 @@ visit_vulkan_resource_index(nir_builder *b, apply_layout_state *state, nir_intri
nir_def *set_ptr; nir_def *set_ptr;
if (vk_descriptor_type_is_dynamic(layout->binding[binding].type)) { if (vk_descriptor_type_is_dynamic(layout->binding[binding].type)) {
unsigned idx = state->layout->set[desc_set].dynamic_offset_start + layout->binding[binding].dynamic_offset_offset; unsigned idx = state->layout->set[desc_set].dynamic_offset_start + layout->binding[binding].dynamic_offset_offset;
set_ptr = get_scalar_arg(b, 1, state->args->ac.dynamic_descriptors); set_ptr = get_dynamic_descriptors_addr(b, state);
offset = idx * 16; offset = idx * 16;
stride = 16; stride = 16;
} else { } else {
@ -341,8 +367,10 @@ load_push_constant(nir_builder *b, apply_layout_state *state, nir_intrinsic_inst
continue; continue;
} }
if (!state->args->ac.push_constants.used) { if (!state->args->ac.push_constants.used && !mesa_shader_stage_is_rt(b->shader->info.stage)) {
/* Assume this is an inlined push constant load which was expanded to include dwords which are not inlined. */ /* Assume this is an inlined push constant load which was expanded to include dwords which are not inlined.
* RT stages use neither shader args nor inlined push constants, so skip this for RT shaders.
*/
assert(const_offset != -1); assert(const_offset != -1);
data[num_loads++] = nir_undef(b, 1, 32); data[num_loads++] = nir_undef(b, 1, 32);
start += 1; start += 1;
@ -350,7 +378,7 @@ load_push_constant(nir_builder *b, apply_layout_state *state, nir_intrinsic_inst
} }
if (!offset) { if (!offset) {
addr = get_scalar_arg(b, 1, state->args->ac.push_constants); addr = get_indirect_push_constants_addr(b, state);
addr = convert_pointer_to_64_bit(b, state, addr); addr = convert_pointer_to_64_bit(b, state, addr);
offset = nir_iadd_imm_nuw(b, intrin->src[0].ssa, base); offset = nir_iadd_imm_nuw(b, intrin->src[0].ssa, base);
} }

View file

@ -0,0 +1,439 @@
/*
* Copyright © 2023 Valve Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including the next
* paragraph) shall be included in all copies or substantial portions of the
* Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*/
#include "aco_nir_call_attribs.h"
#include "nir_builder.h"
#include "radv_nir.h"
/* Prepends the ACO call-system parameters (divergent + uniform PC) to the
 * function's signature, shifting all existing parameters — and all
 * load_param indices in the body — up by ACO_NIR_CALL_SYSTEM_ARG_COUNT.
 * Also derives bit_size/num_components of return parameters from their GLSL
 * type.
 */
void
radv_nir_lower_callee_signature(nir_function *function)
{
   nir_parameter *old_params = function->params;
   unsigned old_num_params = function->num_params;

   /* Reallocate the parameter array with room for the system args in front. */
   function->num_params += ACO_NIR_CALL_SYSTEM_ARG_COUNT;
   function->params = rzalloc_array_size(function->shader, function->num_params, sizeof(nir_parameter));
   memcpy(function->params + ACO_NIR_CALL_SYSTEM_ARG_COUNT, old_params, old_num_params * sizeof(nir_parameter));

   /* These are not return params, but each callee will modify these registers
    * as part of the next callee selection. Make sure modification is allowed by
    * marking the parameters as DISCARDABLE. Unlike other discardable parameters,
    * ACO makes sure correct values are always written to them.
    */
   function->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC].num_components = 1;
   function->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC].bit_size = 64;
   function->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC].driver_attributes = ACO_NIR_PARAM_ATTRIB_DISCARDABLE;

   function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].num_components = 1;
   function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].bit_size = 64;
   function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].is_uniform = true;
   function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].driver_attributes = ACO_NIR_PARAM_ATTRIB_DISCARDABLE;

   /* Fill in the physical size of return parameters from their deref type. */
   for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) {
      if (!function->params[i].is_return)
         continue;
      function->params[i].bit_size = glsl_get_bit_size(function->params[i].type);
      function->params[i].num_components = glsl_get_vector_elements(function->params[i].type);
   }

   nir_function_impl *impl = function->impl;
   if (!impl)
      return;

   /* Shift every load_param index in the body to match the new signature. */
   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
         if (intr->intrinsic == nir_intrinsic_load_param)
            nir_intrinsic_set_param_idx(intr, nir_intrinsic_param_idx(intr) + ACO_NIR_CALL_SYSTEM_ARG_COUNT);
      }
   }
}
/* Returns whether "caller" may invoke "callee" as a tail call.
 *
 * A tail call requires the callee's return values to land in exactly the
 * registers the caller would return them in; if the ABIs place returns
 * differently, moves would be needed after the call, which a tail call
 * cannot perform.
 */
static bool
is_tail_call_compatible(nir_function *caller, nir_function *callee)
{
   /* A caller that never returns has no return-register layout to violate. */
   if (caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_NORETURN)
      return true;

   const unsigned caller_abi = caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK;
   const unsigned callee_abi = callee->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK;

   /* Identical ABIs always agree on return-register placement. */
   if (caller_abi == callee_abi)
      return true;

   /* The recursive shader ABI and the traversal shader ABI are constructed so
    * that their return parameters occupy exactly the same registers, which
    * allows the traversal shader to tail-call recursive RT shaders. */
   return caller_abi == ACO_NIR_CALL_ABI_TRAVERSAL && callee_abi == ACO_NIR_CALL_ABI_RT_RECURSIVE;
}
/* Walks a block backwards from its end and adds to "tail_calls" every call
 * instruction that can be lowered as a tail call. A call only qualifies if
 * nothing but other qualifying calls stands between it and the function end,
 * its ABI is tail-call compatible with the caller's, and its parameter list
 * matches the caller's exactly. If the whole block qualifies, recurses into
 * the block's predecessors.
 */
static void
gather_tail_call_instrs_block(nir_function *caller, const struct nir_block *block, struct set *tail_calls)
{
   nir_foreach_instr_reverse (instr, block) {
      /* Making an instruction a tail call effectively moves it beyond the last block. If there are any instructions
       * in the way, this reordering may be incorrect.
       */
      if (instr->type != nir_instr_type_call)
         return;

      nir_call_instr *call = nir_instr_as_call(instr);
      if (!is_tail_call_compatible(caller, call->callee))
         return;

      if (call->callee->num_params != caller->num_params)
         return;

      for (unsigned i = 0; i < call->num_params; ++i) {
         if (call->callee->params[i].is_return != caller->params[i].is_return)
            return;

         /* We can only do tail calls if the caller returns exactly the callee return values */
         if (caller->params[i].is_return) {
            assert(nir_def_as_deref_or_null(call->params[i].ssa));
            /* Walk to the root of the deref chain passed as the return arg... */
            nir_deref_instr *deref_root = nir_def_as_deref(call->params[i].ssa);
            while (nir_deref_instr_parent(deref_root))
               deref_root = nir_deref_instr_parent(deref_root);

            if (!deref_root->parent.ssa)
               return;

            /* ...it must be a load_param of the caller's own return parameter
             * at the same index. */
            nir_intrinsic_instr *intrin = nir_def_as_intrinsic_or_null(deref_root->parent.ssa);
            if (!intrin || intrin->intrinsic != nir_intrinsic_load_param)
               return;

            /* The call parameters aren't lowered at this point, we need to add the call arg count here */
            if (nir_intrinsic_param_idx(intrin) != i + ACO_NIR_CALL_SYSTEM_ARG_COUNT)
               return;
         }
         if (call->callee->params[i].is_uniform != caller->params[i].is_uniform)
            return;
         if (call->callee->params[i].bit_size != caller->params[i].bit_size)
            return;
         if (call->callee->params[i].num_components != caller->params[i].num_components)
            return;
      }

      _mesa_set_add(tail_calls, instr);
   }

   /* Every instruction in this block qualified — calls ending the predecessor
    * blocks may be tail calls too. */
   set_foreach (&block->predecessors, pred) {
      gather_tail_call_instrs_block(caller, pred->key, tail_calls);
   }
}
/* Per-parameter bookkeeping used while lowering a callee's ABI. */
struct lower_param_info {
   /* Deref of param_var; only set for return parameters, NULL otherwise. */
   nir_def *return_deref;

   /* Local variable holding the parameter's current value. */
   nir_variable *param_var;
};

/* Replaces all uses of a return parameter's load_param (including uses behind
 * deref_casts of it) with the deref of the local variable created for that
 * parameter, removing the now-dead cast instructions along the way.
 */
static void
rewrite_return_param_uses(nir_def *def, unsigned param_idx, struct lower_param_info *param_defs)
{
   nir_foreach_use_safe (use, def) {
      nir_instr *use_instr = nir_src_parent_instr(use);
      if (use_instr->type == nir_instr_type_deref) {
         /* Return params are only ever cast back to derefs by the callee. */
         assert(nir_instr_as_deref(use_instr)->deref_type == nir_deref_type_cast);
         rewrite_return_param_uses(&nir_instr_as_deref(use_instr)->def, param_idx, param_defs);
         nir_instr_remove(use_instr);
      }
   }

   nir_def_rewrite_uses(def, param_defs[param_idx].return_deref);
}
/* Lowers a callee function to the physical call ABI:
 * - optionally guards the body so only invocations whose divergent PC matches
 *   the wave-uniform PC execute it,
 * - replaces return-parameter derefs with local variables,
 * - turns eligible calls at the end of the function into tail calls, and
 * - emits the epilogue that selects and jumps to the next shader.
 */
static void
lower_call_abi_for_callee(nir_function *function, unsigned wave_size)
{
   nir_function_impl *impl = function->impl;
   nir_builder b = nir_builder_create(impl);
   b.cursor = nir_before_impl(impl);

   /* Address of the next shader to jump to after the body; 0 = none pending. */
   nir_variable *tail_call_pc =
      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "_tail_call_pc");

   struct set *tail_call_instrs = _mesa_set_create(b.shader, _mesa_hash_pointer, _mesa_key_pointer_equal);
   gather_tail_call_instrs_block(function, nir_impl_last_block(impl), tail_call_instrs);

   /* guard the shader, so that only the correct invocations execute it */
   nir_def *guard_condition = NULL;
   nir_def *shader_addr;
   nir_def *uniform_shader_addr;
   if (function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL) {
      /* Extract the whole body and re-insert it under an if that compares the
       * per-invocation PC against the wave-uniform PC. */
      nir_cf_list list;
      nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));

      b.cursor = nir_before_impl(impl);

      shader_addr = nir_load_param(&b, ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC);
      uniform_shader_addr = nir_load_param(&b, ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC);

      /* Invocations that skip this callee keep their own PC as the pending
       * jump target; those that execute it clear it inside the guard. */
      nir_store_var(&b, tail_call_pc, shader_addr, 0x1);

      guard_condition = nir_ieq(&b, uniform_shader_addr, shader_addr);
      nir_if *shader_guard = nir_push_if(&b, guard_condition);
      shader_guard->control = nir_selection_control_divergent_always_taken;
      nir_store_var(&b, tail_call_pc, nir_imm_int64(&b, 0), 0x1);
      nir_cf_reinsert(&list, b.cursor);
      nir_pop_if(&b, shader_guard);
   } else {
      nir_store_var(&b, tail_call_pc, nir_imm_int64(&b, 0), 0x1);
   }

   b.cursor = nir_before_impl(impl);

   /* One local variable per parameter to hold its current value. */
   struct lower_param_info *param_infos = ralloc_size(b.shader, function->num_params * sizeof(struct lower_param_info));
   for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) {
      param_infos[i].param_var = nir_local_variable_create(impl, function->params[i].type, "_param");

      if (function->params[i].is_return) {
         assert(!glsl_type_is_array(function->params[i].type) && !glsl_type_is_struct(function->params[i].type));
         param_infos[i].return_deref = &nir_build_deref_var(&b, param_infos[i].param_var)->def;
      } else {
         param_infos[i].return_deref = NULL;
      }
   }

   /* Lower everything related to call parameters and dispatch, particularly return parameters and tail calls.
    *
    * Return parameters in NIR are represented by having the parameter value actually be a deref. Callers pass
    * a deref value to the call, and the callee can cast the parameter value back to a deref. Replace these deref_casts
    * with a deref of a variable we declared further above, so the shader can be lowered to SSA.
    * One simple example:
    *
    * %1 = load_param 0
    * %2 = deref_cast %1
    * %3 = load_deref %2
    * %4 = inot %3
    * store_deref %2, %4
    *
    * becomes
    *
    * decl_var _param;
    *
    * %1 = load_param 0
    * store_var _param, %1
    * %3 = load_var _param
    * %4 = inot %3
    * store_var _param, %4
    * %5 = load_var _param
    * store_param_amd %5
    *
    * If tail calls are detected, the call instruction is replaced with a sequence of writing the parameters to the new
    * values they should have at callee entry, and updating the tail_call_pc value so that the callee is jumped to next.
    */
   bool has_tail_call = false;
   nir_foreach_block (block, impl) {
      bool progress;
      /* rewrite_return_param_uses may remove multiple instructions (not just the current one), which not even
       * nir_foreach_instr_safe can safely iterate over. Therefore, if we made progress, we need to restart iteration.
       */
      do {
         progress = false;
         nir_foreach_instr (instr, block) {
            if (instr->type == nir_instr_type_call && _mesa_set_search(tail_call_instrs, instr)) {
               nir_call_instr *call = nir_instr_as_call(instr);
               b.cursor = nir_before_instr(instr);
               /* Replace the call with stores that set up the callee's entry
                * state, then record the callee address as the next jump. */
               for (unsigned i = 0; i < call->num_params; ++i) {
                  if (call->callee->params[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].is_return)
                     nir_store_var(&b, param_infos[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].param_var,
                                   nir_load_deref(&b, nir_def_as_deref(call->params[i].ssa)),
                                   (0x1 << glsl_get_vector_elements(
                                       call->callee->params[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].type)) -
                                      1);
                  else
                     nir_store_var(&b, param_infos[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].param_var, call->params[i].ssa,
                                   (0x1 << call->params[i].ssa->num_components) - 1);
               }
               nir_store_var(&b, tail_call_pc, call->indirect_callee.ssa, 0x1);

               nir_instr_remove(instr);
               has_tail_call = true;
               progress = true;
               break;
            }

            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_param) {
               unsigned param_idx = nir_intrinsic_param_idx(intr);
               if (param_idx >= ACO_NIR_CALL_SYSTEM_ARG_COUNT && function->params[param_idx].is_return) {
                  rewrite_return_param_uses(&intr->def, param_idx, param_infos);
                  nir_instr_remove(instr);
                  progress = true;
                  break;
               }
            }
         }
      } while (progress);
   }

   b.cursor = nir_before_impl(impl);

   /* Initialize the parameter variables from the incoming parameter values. */
   for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) {
      unsigned num_components = glsl_get_vector_elements(function->params[i].type);
      nir_store_var(&b, param_infos[i].param_var, nir_load_param(&b, i), (0x1 << num_components) - 1);
   }

   /* Setup a jump to a different shader in the cases where there is a next shader to be called. */
   if (!(function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_NORETURN) ||
       (function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL) || has_tail_call) {
      b.cursor = nir_after_impl(impl);

      /* Flush the parameter variables back to the physical parameter slots. */
      for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i)
         nir_store_param_amd(&b, nir_load_var(&b, param_infos[i].param_var), .param_idx = i);

      /* Pick the next uniform jump target from the first invocation that has
       * a pending (non-zero) tail-call address. */
      shader_addr = nir_load_var(&b, tail_call_pc);
      nir_def *ballot = nir_ballot(&b, 1, wave_size, nir_ine_imm(&b, shader_addr, 0));
      nir_def *ballot_addr = nir_read_invocation(&b, shader_addr, nir_find_lsb(&b, ballot));
      nir_def *no_next_shader = nir_ieq_imm(&b, ballot, 0);

      nir_def *terminate_cond;
      /* In functions marked noreturn, we don't need to bother checking the call return address. */
      if (function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_NORETURN) {
         uniform_shader_addr = ballot_addr;
         terminate_cond = no_next_shader;
      } else {
         nir_def *return_address = nir_load_call_return_address_amd(&b);
         uniform_shader_addr = nir_bcsel(&b, no_next_shader, return_address, ballot_addr);
         /* If the next shader address is zero for every invocation, return. */
         terminate_cond = nir_ieq_imm(&b, uniform_shader_addr, 0);
      }

      nir_push_if(&b, terminate_cond);
      nir_terminate(&b);
      nir_pop_if(&b, NULL);
      nir_set_next_call_pc_amd(&b, shader_addr, uniform_shader_addr);
   }
}
/* Lowers one call site to the physical call ABI: return-parameter derefs are
 * replaced by an explicit load_deref before the call and a
 * load_return_param_amd + store_deref after it, and the call is re-created
 * with the system PC parameters (divergent + uniform) prepended.
 */
static void
lower_call_abi_for_call(nir_builder *b, nir_call_instr *call, unsigned *cur_call_idx)
{
   /* Each call in the impl gets a unique index so its return values can be
    * matched up by load_return_param_amd. */
   unsigned call_idx = (*cur_call_idx)++;

   for (unsigned i = 0; i < call->num_params; ++i) {
      unsigned callee_param_idx = i + ACO_NIR_CALL_SYSTEM_ARG_COUNT;
      if (!call->callee->params[callee_param_idx].is_return)
         continue;

      b->cursor = nir_before_instr(&call->instr);

      /* Pass the current value of the return variable into the call... */
      nir_src *old_src = &call->params[i];
      assert(nir_def_as_deref_or_null(old_src->ssa));
      nir_deref_instr *param_deref = nir_def_as_deref(old_src->ssa);
      assert(param_deref->deref_type == nir_deref_type_var);
      nir_src_rewrite(old_src, nir_load_deref(b, param_deref));

      b->cursor = nir_after_instr(&call->instr);

      /* ...and store the value returned by the callee back afterwards. */
      unsigned num_components = glsl_get_vector_elements(param_deref->type);
      nir_store_deref(
         b, param_deref,
         nir_load_return_param_amd(b, num_components, glsl_base_type_get_bit_size(param_deref->type->base_type),
                                   .call_idx = call_idx, .param_idx = callee_param_idx),
         (1u << num_components) - 1);
   }

   b->cursor = nir_before_instr(&call->instr);

   /* Re-create the call with the system args prepended: the divergent PC is
    * the indirect callee address, the uniform PC its first-invocation value. */
   nir_call_instr *new_call = nir_call_instr_create(b->shader, call->callee);
   new_call->indirect_callee = nir_src_for_ssa(call->indirect_callee.ssa);
   new_call->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC] = nir_src_for_ssa(call->indirect_callee.ssa);
   new_call->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC] =
      nir_src_for_ssa(nir_read_first_invocation(b, call->indirect_callee.ssa));
   for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < new_call->num_params; ++i)
      new_call->params[i] = nir_src_for_ssa(call->params[i - ACO_NIR_CALL_SYSTEM_ARG_COUNT].ssa);
   nir_builder_instr_insert(b, &new_call->instr);
   nir_instr_remove(&call->instr);
}
/* Walks the impl and lowers every call to an external callee (one without an
 * implementation in this shader) to the physical call ABI. Returns whether
 * any call was rewritten.
 */
static bool
lower_call_abi_for_caller(nir_function_impl *impl)
{
   unsigned next_call_idx = 0;
   bool lowered_any = false;

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_call)
            continue;

         nir_call_instr *call = nir_instr_as_call(instr);
         /* Calls to functions with a local implementation are left alone. */
         if (call->callee->impl)
            continue;

         nir_builder builder = nir_builder_create(impl);
         lower_call_abi_for_call(&builder, call, &next_call_idx);
         lowered_any = true;
      }
   }

   return lowered_any;
}
/* Entry point: lowers the NIR function-call representation to the physical
 * call ABI. First rewrites the signatures of all exported (callable)
 * functions, then lowers each impl both as a callee (parameter/tail-call
 * handling and epilogue) and as a caller (call-site rewriting).
 */
bool
radv_nir_lower_call_abi(nir_shader *shader, unsigned wave_size)
{
   bool progress = false;
   /* Signatures must be adjusted for every exported function first, since
    * call sites consult the callee's (already shifted) parameter list. */
   nir_foreach_function (function, shader) {
      if (function->is_exported) {
         radv_nir_lower_callee_signature(function);
         if (function->impl)
            progress |= nir_progress(true, function->impl, nir_metadata_none);
      }
   }

   nir_foreach_function_with_impl (function, impl, shader) {
      bool func_progress = false;

      if (function->is_exported) {
         lower_call_abi_for_callee(function, wave_size);
         func_progress = true;
      }

      func_progress |= lower_call_abi_for_caller(impl);

      progress |= nir_progress(func_progress, impl, nir_metadata_none);
   }

   return progress;
}

View file

@ -10,13 +10,19 @@
#include "radv_constants.h" #include "radv_constants.h"
#include "radv_nir.h" #include "radv_nir.h"
typedef nir_def *(*load_intrin_cb)(nir_builder *b, unsigned base);
typedef void (*store_intrin_cb)(nir_builder *b, nir_def *val, unsigned base);
struct lower_hit_attrib_deref_args { struct lower_hit_attrib_deref_args {
nir_variable_mode mode; nir_variable_mode mode;
uint32_t base_offset; uint32_t base_offset;
load_intrin_cb load_cb;
store_intrin_cb store_cb;
}; };
static bool static bool
lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) lower_rt_var_deref(nir_builder *b, nir_instr *instr, void *data)
{ {
if (instr->type != nir_instr_type_intrinsic) if (instr->type != nir_instr_type_intrinsic)
return false; return false;
@ -29,6 +35,8 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
if (!nir_deref_mode_is(deref, args->mode)) if (!nir_deref_mode_is(deref, args->mode))
return false; return false;
if (deref->deref_type == nir_deref_type_cast)
return false;
b->cursor = nir_after_instr(instr); b->cursor = nir_after_instr(instr);
@ -48,19 +56,16 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
uint32_t comp_offset = offset % 4; uint32_t comp_offset = offset % 4;
if (bit_size == 64) { if (bit_size == 64) {
components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base), components[comp] = nir_pack_64_2x32_split(b, args->load_cb(b, base), args->load_cb(b, base + 1));
nir_load_hit_attrib_amd(b, .base = base + 1));
} else if (bit_size == 32) { } else if (bit_size == 32) {
components[comp] = nir_load_hit_attrib_amd(b, .base = base); components[comp] = args->load_cb(b, base);
} else if (bit_size == 16) { } else if (bit_size == 16) {
components[comp] = components[comp] = nir_channel(b, nir_unpack_32_2x16(b, args->load_cb(b, base)), comp_offset / 2);
nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
} else if (bit_size == 8) { } else if (bit_size == 8) {
components[comp] = components[comp] = nir_channel(b, nir_unpack_bits(b, args->load_cb(b, base), 8), comp_offset);
nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
} else { } else {
assert(bit_size == 1); assert(bit_size == 1);
components[comp] = nir_i2b(b, nir_load_hit_attrib_amd(b, .base = base)); components[comp] = nir_i2b(b, args->load_cb(b, base));
} }
} }
@ -78,25 +83,25 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
nir_def *component = nir_channel(b, value, comp); nir_def *component = nir_channel(b, value, comp);
if (bit_size == 64) { if (bit_size == 64) {
nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base); args->store_cb(b, nir_unpack_64_2x32_split_x(b, component), base);
nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1); args->store_cb(b, nir_unpack_64_2x32_split_y(b, component), base + 1);
} else if (bit_size == 32) { } else if (bit_size == 32) {
nir_store_hit_attrib_amd(b, component, .base = base); args->store_cb(b, component, base);
} else if (bit_size == 16) { } else if (bit_size == 16) {
nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)); nir_def *prev = nir_unpack_32_2x16(b, args->load_cb(b, base));
nir_def *components[2]; nir_def *components[2];
for (uint32_t word = 0; word < 2; word++) for (uint32_t word = 0; word < 2; word++)
components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word); components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base); args->store_cb(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), base);
} else if (bit_size == 8) { } else if (bit_size == 8) {
nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8); nir_def *prev = nir_unpack_bits(b, args->load_cb(b, base), 8);
nir_def *components[4]; nir_def *components[4];
for (uint32_t byte = 0; byte < 4; byte++) for (uint32_t byte = 0; byte < 4; byte++)
components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte); components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base); args->store_cb(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), base);
} else { } else {
assert(bit_size == 1); assert(bit_size == 1);
nir_store_hit_attrib_amd(b, nir_b2i32(b, component), .base = base); args->store_cb(b, nir_b2i32(b, component), base);
} }
} }
} }
@ -123,23 +128,24 @@ radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, voi
} }
static bool static bool
radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset) radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, load_intrin_cb load_cb, store_intrin_cb store_cb,
uint32_t base_offset)
{ {
bool progress = false; bool progress = false;
progress |= nir_lower_indirect_derefs_to_if_else_trees(shader, mode, UINT32_MAX); progress |= nir_lower_indirect_derefs_to_if_else_trees(shader, mode, UINT32_MAX);
progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes);
if (shader->info.stage == MESA_SHADER_RAYGEN && mode == nir_var_function_temp) if (shader->info.stage == MESA_SHADER_RAYGEN && mode == nir_var_function_temp)
progress |= nir_shader_intrinsics_pass(shader, radv_lower_payload_arg_to_offset, nir_metadata_control_flow, NULL); progress |= nir_shader_intrinsics_pass(shader, radv_lower_payload_arg_to_offset, nir_metadata_control_flow, NULL);
struct lower_hit_attrib_deref_args args = { struct lower_hit_attrib_deref_args args = {
.mode = mode, .mode = mode,
.base_offset = base_offset, .base_offset = base_offset,
.load_cb = load_cb,
.store_cb = store_cb,
}; };
progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref, nir_metadata_control_flow, &args); progress |= nir_shader_instructions_pass(shader, lower_rt_var_deref, nir_metadata_control_flow, &args);
if (progress) { if (progress) {
nir_remove_dead_derefs(shader); nir_remove_dead_derefs(shader);
@ -149,16 +155,57 @@ radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base
return progress; return progress;
} }
/* Callbacks mapping generic RT-var accesses to hit-attribute intrinsics. */
static nir_def *
load_hit_attrib_cb(nir_builder *b, unsigned base)
{
   return nir_load_hit_attrib_amd(b, .base = base);
}

static void
store_hit_attrib_cb(nir_builder *b, nir_def *val, unsigned base)
{
   nir_store_hit_attrib_amd(b, val, .base = base);
}
bool bool
radv_nir_lower_hit_attrib_derefs(nir_shader *shader) radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
{ {
return radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, 0); bool progress = false;
progress |= nir_lower_vars_to_explicit_types(shader, nir_var_ray_hit_attrib, glsl_get_natural_size_align_bytes);
progress |= radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, load_hit_attrib_cb, store_hit_attrib_cb, 0);
return progress;
}
static nir_def *
load_incoming_payload_cb(nir_builder *b, unsigned base)
{
return nir_load_incoming_ray_payload_amd(b, .base = base);
}
static void
store_incoming_payload_cb(nir_builder *b, nir_def *val, unsigned base)
{
nir_store_incoming_ray_payload_amd(b, val, .base = base);
}
static nir_def *
load_outgoing_payload_cb(nir_builder *b, unsigned base)
{
return nir_load_outgoing_ray_payload_amd(b, .base = base);
}
static void
store_outgoing_payload_cb(nir_builder *b, nir_def *val, unsigned base)
{
nir_store_outgoing_ray_payload_amd(b, val, .base = base);
} }
bool bool
radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset) radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset)
{ {
bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, RADV_MAX_HIT_ATTRIB_SIZE + offset); bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, load_outgoing_payload_cb,
progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, RADV_MAX_HIT_ATTRIB_SIZE + offset); store_outgoing_payload_cb, offset);
progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, load_incoming_payload_cb,
store_incoming_payload_cb, offset);
return progress; return progress;
} }

View file

@ -6,19 +6,21 @@
*/ */
#include "nir/radv_nir_rt_stage_common.h" #include "nir/radv_nir_rt_stage_common.h"
#include "nir/radv_nir_rt_stage_functions.h"
#include "aco_nir_call_attribs.h"
#include "nir_builder.h" #include "nir_builder.h"
struct radv_nir_sbt_data struct radv_nir_sbt_data
radv_nir_load_sbt_entry(nir_builder *b, nir_def *idx, enum radv_nir_sbt_type binding, enum radv_nir_sbt_entry offset) radv_nir_load_sbt_entry(nir_builder *b, nir_def *base, nir_def *idx, enum radv_nir_sbt_type binding,
enum radv_nir_sbt_entry offset)
{ {
struct radv_nir_sbt_data data; struct radv_nir_sbt_data data;
nir_def *desc_base_addr = nir_load_sbt_base_amd(b); nir_def *desc = nir_pack_64_2x32(b, ac_nir_load_smem(b, 2, base, nir_imm_int(b, binding), 4, 0));
nir_def *desc = nir_pack_64_2x32(b, ac_nir_load_smem(b, 2, desc_base_addr, nir_imm_int(b, binding), 4, 0));
nir_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16)); nir_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
nir_def *stride = ac_nir_load_smem(b, 1, desc_base_addr, stride_offset, 4, 0); nir_def *stride = ac_nir_load_smem(b, 1, base, stride_offset, 4, 0);
nir_def *addr = nir_iadd(b, desc, nir_u2u64(b, nir_iadd_imm(b, nir_imul(b, idx, stride), offset))); nir_def *addr = nir_iadd(b, desc, nir_u2u64(b, nir_iadd_imm(b, nir_imul(b, idx, stride), offset)));
@ -149,52 +151,14 @@ radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_
free(cases); free(cases);
} }
bool /* Lowers RT I/O vars to registers or shared memory. If hit_attribs is NULL, attributes are
radv_nir_lower_rt_derefs(nir_shader *shader)
{
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
bool progress = false;
nir_builder b;
nir_def *arg_offset = NULL;
nir_foreach_block (block, impl) {
nir_foreach_instr_safe (instr, block) {
if (instr->type != nir_instr_type_deref)
continue;
nir_deref_instr *deref = nir_instr_as_deref(instr);
if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
continue;
deref->modes = nir_var_function_temp;
progress = true;
if (deref->deref_type == nir_deref_type_var) {
if (!arg_offset) {
b = nir_builder_at(nir_before_impl(impl));
arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
}
b.cursor = nir_before_instr(&deref->instr);
nir_deref_instr *replacement =
nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
nir_def_replace(&deref->def, &replacement->def);
}
}
}
return nir_progress(progress, impl, nir_metadata_control_flow);
}
/* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are
* lowered to shared memory. */ * lowered to shared memory. */
bool bool
radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size) radv_nir_lower_rt_storage(nir_shader *shader, nir_variable **hit_attribs, nir_deref_instr **payload_in,
nir_variable **payload_out, uint32_t workgroup_size)
{ {
bool progress = false; bool progress = false;
nir_function_impl *impl = nir_shader_get_entrypoint(shader); nir_function_impl *impl = radv_get_rt_shader_entrypoint(shader);
nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib) { nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib) {
attrib->data.mode = nir_var_shader_temp; attrib->data.mode = nir_var_shader_temp;
@ -210,30 +174,52 @@ radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint3
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd && if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd &&
intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd) intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd &&
intrin->intrinsic != nir_intrinsic_load_incoming_ray_payload_amd &&
intrin->intrinsic != nir_intrinsic_store_incoming_ray_payload_amd &&
intrin->intrinsic != nir_intrinsic_load_outgoing_ray_payload_amd &&
intrin->intrinsic != nir_intrinsic_store_outgoing_ray_payload_amd)
continue; continue;
progress = true; progress = true;
b.cursor = nir_after_instr(instr); b.cursor = nir_after_instr(instr);
nir_def *offset; if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd ||
if (!hit_attribs) intrin->intrinsic == nir_intrinsic_store_hit_attrib_amd) {
offset = nir_imul_imm( nir_def *offset;
&b, nir_iadd_imm(&b, nir_load_subgroup_invocation(&b), nir_intrinsic_base(intrin) * workgroup_size), if (!hit_attribs)
sizeof(uint32_t)); offset = nir_imul_imm(
&b, nir_iadd_imm(&b, nir_load_subgroup_invocation(&b), nir_intrinsic_base(intrin) * workgroup_size),
sizeof(uint32_t));
if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) { if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) {
nir_def *ret; nir_def *ret;
if (hit_attribs) if (hit_attribs)
ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]); ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]);
else
ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4);
nir_def_rewrite_uses(nir_instr_def(instr), ret);
} else {
if (hit_attribs)
nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
else
nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4);
}
} else if (intrin->intrinsic == nir_intrinsic_load_incoming_ray_payload_amd ||
intrin->intrinsic == nir_intrinsic_store_incoming_ray_payload_amd) {
if (!payload_in)
continue;
if (intrin->intrinsic == nir_intrinsic_load_incoming_ray_payload_amd)
nir_def_rewrite_uses(nir_instr_def(instr), nir_load_deref(&b, payload_in[nir_intrinsic_base(intrin)]));
else else
ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4); nir_store_deref(&b, payload_in[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
nir_def_rewrite_uses(nir_instr_def(instr), ret);
} else { } else {
if (hit_attribs) if (!payload_out)
nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); continue;
if (intrin->intrinsic == nir_intrinsic_load_outgoing_ray_payload_amd)
nir_def_rewrite_uses(nir_instr_def(instr), nir_load_var(&b, payload_out[nir_intrinsic_base(intrin)]));
else else
nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4); nir_store_var(&b, payload_out[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
} }
nir_instr_remove(instr); nir_instr_remove(instr);
} }
@ -244,3 +230,13 @@ radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint3
return nir_progress(progress, impl, nir_metadata_control_flow); return nir_progress(progress, impl, nir_metadata_control_flow);
} }
void
radv_nir_param_from_type(nir_parameter *param, const glsl_type *type, bool uniform, unsigned driver_attribs)
{
param->num_components = glsl_get_vector_elements(type);
param->bit_size = glsl_get_bit_size(type);
param->type = type;
param->is_uniform = uniform;
param->driver_attributes = driver_attribs;
}

View file

@ -15,6 +15,9 @@
#include "radv_pipeline_cache.h" #include "radv_pipeline_cache.h"
#include "radv_pipeline_rt.h" #include "radv_pipeline_rt.h"
typedef struct nir_parameter nir_parameter;
typedef struct glsl_type glsl_type;
/* /*
* *
* Common Constants * Common Constants
@ -85,8 +88,8 @@ enum radv_nir_sbt_entry {
SBT_ANY_HIT_IDX = offsetof(struct radv_pipeline_group_handle, any_hit_index), SBT_ANY_HIT_IDX = offsetof(struct radv_pipeline_group_handle, any_hit_index),
}; };
struct radv_nir_sbt_data radv_nir_load_sbt_entry(nir_builder *b, nir_def *idx, enum radv_nir_sbt_type binding, struct radv_nir_sbt_data radv_nir_load_sbt_entry(nir_builder *b, nir_def *base, nir_def *idx,
enum radv_nir_sbt_entry offset); enum radv_nir_sbt_type binding, enum radv_nir_sbt_entry offset);
/* /*
* *
@ -94,10 +97,10 @@ struct radv_nir_sbt_data radv_nir_load_sbt_entry(nir_builder *b, nir_def *idx, e
* *
*/ */
bool radv_nir_lower_rt_storage(nir_shader *shader, nir_variable **hit_attribs, nir_deref_instr **payload_in,
nir_variable **payload_out, uint32_t workgroup_size);
bool radv_nir_lower_rt_derefs(nir_shader *shader); void radv_nir_param_from_type(nir_parameter *param, const glsl_type *type, bool uniform, unsigned driver_attribs);
bool radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size);
/* /*
* *

View file

@ -14,7 +14,9 @@
#include "nir/radv_nir_rt_stage_cps.h" #include "nir/radv_nir_rt_stage_cps.h"
#include "ac_nir.h" #include "ac_nir.h"
#include "aco_nir_call_attribs.h"
#include "radv_device.h" #include "radv_device.h"
#include "radv_nir_rt_stage_functions.h"
#include "radv_physical_device.h" #include "radv_physical_device.h"
#include "radv_pipeline_rt.h" #include "radv_pipeline_rt.h"
#include "radv_shader.h" #include "radv_shader.h"
@ -24,12 +26,9 @@ radv_arg_def_is_unused(nir_def *def)
{ {
nir_foreach_use (use, def) { nir_foreach_use (use, def) {
nir_instr *use_instr = nir_src_parent_instr(use); nir_instr *use_instr = nir_src_parent_instr(use);
if (use_instr->type == nir_instr_type_intrinsic) { if (use_instr->type == nir_instr_type_call)
nir_intrinsic_instr *use_intr = nir_instr_as_intrinsic(use_instr); continue;
if (use_intr->intrinsic == nir_intrinsic_store_scalar_arg_amd || if (use_instr->type == nir_instr_type_phi) {
use_intr->intrinsic == nir_intrinsic_store_vector_arg_amd)
continue;
} else if (use_instr->type == nir_instr_type_phi) {
nir_cf_node *prev_node = nir_cf_node_prev(&use_instr->block->cf_node); nir_cf_node *prev_node = nir_cf_node_prev(&use_instr->block->cf_node);
if (!prev_node) if (!prev_node)
return false; return false;
@ -48,13 +47,13 @@ radv_arg_def_is_unused(nir_def *def)
static bool static bool
radv_gather_unused_args_instr(nir_builder *b, nir_intrinsic_instr *instr, void *data) radv_gather_unused_args_instr(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{ {
if (instr->intrinsic != nir_intrinsic_load_scalar_arg_amd && instr->intrinsic != nir_intrinsic_load_vector_arg_amd) if (instr->intrinsic != nir_intrinsic_load_param)
return false; return false;
if (!radv_arg_def_is_unused(&instr->def)) { if (!radv_arg_def_is_unused(&instr->def)) {
/* This arg is used for more than passing data to the next stage. */ /* This arg is used for more than passing data to the next stage. */
struct radv_ray_tracing_stage_info *info = data; struct radv_ray_tracing_stage_info *info = data;
BITSET_CLEAR(info->unused_args, nir_intrinsic_base(instr)); BITSET_CLEAR(info->unused_args, nir_intrinsic_param_idx(instr));
} }
return false; return false;
@ -193,14 +192,13 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
case nir_intrinsic_rt_execute_callable: { case nir_intrinsic_rt_execute_callable: {
uint32_t size = align(nir_intrinsic_stack_size(intr), 16); uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr)); nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1); nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16); nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1); nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
struct radv_nir_sbt_data sbt_data = struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS),
radv_nir_load_sbt_entry(b, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR); intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1); nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1);
nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1); nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
@ -212,7 +210,6 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
case nir_intrinsic_rt_trace_ray: { case nir_intrinsic_rt_trace_ray: {
uint32_t size = align(nir_intrinsic_stack_size(intr), 16); uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr)); nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1); nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16); nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
@ -361,6 +358,10 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24); ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24);
break; break;
} }
case nir_intrinsic_load_sbt_base_amd: {
ret = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
break;
}
case nir_intrinsic_load_sbt_offset_amd: { case nir_intrinsic_load_sbt_offset_amd: {
ret = nir_load_var(b, vars->sbt_offset); ret = nir_load_var(b, vars->sbt_offset);
break; break;
@ -385,8 +386,8 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
nir_store_var(b, vars->geometry_id_and_flags, intr->src[5].ssa, 0x1); nir_store_var(b, vars->geometry_id_and_flags, intr->src[5].ssa, 0x1);
nir_store_var(b, vars->hit_kind, intr->src[6].ssa, 0x1); nir_store_var(b, vars->hit_kind, intr->src[6].ssa, 0x1);
struct radv_nir_sbt_data sbt_data = struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS),
radv_nir_load_sbt_entry(b, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR); intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1); nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1);
nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1); nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
@ -414,7 +415,7 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
nir_def *miss_index = nir_load_var(b, vars->miss_index); nir_def *miss_index = nir_load_var(b, vars->miss_index);
struct radv_nir_sbt_data sbt_data = struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, miss_index, SBT_MISS, SBT_RECURSIVE_PTR); radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1); nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1);
nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1); nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
@ -455,226 +456,203 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, struct radv
return nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data); return nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
} }
static bool
lower_rt_derefs_cps(nir_shader *shader)
{
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
bool progress = false;
nir_builder b;
nir_def *arg_offset = NULL;
nir_foreach_block (block, impl) {
nir_foreach_instr_safe (instr, block) {
if (instr->type != nir_instr_type_deref)
continue;
nir_deref_instr *deref = nir_instr_as_deref(instr);
if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
continue;
deref->modes = nir_var_function_temp;
progress = true;
if (deref->deref_type == nir_deref_type_var) {
if (!arg_offset) {
b = nir_builder_at(nir_before_impl(impl));
arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
}
b.cursor = nir_before_instr(&deref->instr);
nir_deref_instr *replacement =
nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
nir_def_replace(&deref->def, &replacement->def);
}
}
}
return nir_progress(progress, impl, nir_metadata_control_flow);
}
void void
radv_nir_lower_rt_io_cps(nir_shader *nir) radv_nir_lower_rt_io_cps(nir_shader *nir)
{ {
NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data, NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
glsl_get_natural_size_align_bytes); glsl_get_natural_size_align_bytes);
NIR_PASS(_, nir, radv_nir_lower_rt_derefs); NIR_PASS(_, nir, lower_rt_derefs_cps);
NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset); NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
} }
/** Select the next shader based on priorities:
*
* Detect the priority of the shader stage by the lowest bits in the address (low to high):
* - Raygen - idx 0
* - Traversal - idx 1
* - Closest Hit / Miss - idx 2
* - Callable - idx 3
*
*
* This gives us the following priorities:
* Raygen : Callable > > Traversal > Raygen
* Traversal : > Chit / Miss > > Raygen
* CHit / Miss : Callable > Chit / Miss > Traversal > Raygen
* Callable : Callable > Chit / Miss > > Raygen
*/
static nir_def *
select_next_shader(nir_builder *b, nir_def *shader_addr, unsigned wave_size)
{
mesa_shader_stage stage = b->shader->info.stage;
nir_def *prio = nir_iand_imm(b, shader_addr, radv_rt_priority_mask);
nir_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true));
nir_def *ballot_traversal = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal));
nir_def *ballot_hit_miss = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss));
nir_def *ballot_callable = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_callable));
if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION)
ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot);
if (stage != MESA_SHADER_RAYGEN)
ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot);
if (stage != MESA_SHADER_INTERSECTION)
ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot);
nir_def *lsb = nir_find_lsb(b, ballot);
nir_def *next = nir_read_invocation(b, shader_addr, lsb);
return nir_iand_imm(b, next, ~radv_rt_priority_mask);
}
static void static void
radv_store_arg(nir_builder *b, const struct radv_shader_args *args, const struct radv_ray_tracing_stage_info *info, init_cps_function(nir_function *function, bool has_position_fetch)
struct ac_arg arg, nir_def *value)
{ {
/* Do not pass unused data to the next stage. */ function->num_params = has_position_fetch ? CPS_ARG_COUNT : CPS_ARG_COUNT - 1;
if (!info || !BITSET_TEST(info->unused_args, arg.arg_index)) function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params);
ac_nir_store_arg(b, &args->ac, arg, value);
radv_nir_param_from_type(function->params + RT_ARG_LAUNCH_ID, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0);
radv_nir_param_from_type(function->params + RT_ARG_LAUNCH_SIZE, glsl_vector_type(GLSL_TYPE_UINT, 3), true, 0);
radv_nir_param_from_type(function->params + RT_ARG_DESCRIPTORS, glsl_uint_type(), true, 0);
radv_nir_param_from_type(function->params + RT_ARG_DYNAMIC_DESCRIPTORS, glsl_uint_type(), true, 0);
radv_nir_param_from_type(function->params + RT_ARG_PUSH_CONSTANTS, glsl_uint_type(), true, 0);
radv_nir_param_from_type(function->params + RT_ARG_SBT_DESCRIPTORS, glsl_uint64_t_type(), true, 0);
radv_nir_param_from_type(function->params + RAYGEN_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0);
radv_nir_param_from_type(function->params + RAYGEN_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_PAYLOAD_SCRATCH_OFFSET, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_STACK_PTR, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_ACCEL_STRUCT, glsl_uint64_t_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_CULL_MASK_AND_FLAGS, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_SBT_OFFSET, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_SBT_STRIDE, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_MISS_INDEX, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_RAY_ORIGIN, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_RAY_TMIN, glsl_float_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_RAY_DIRECTION, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_RAY_TMAX, glsl_float_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_PRIMITIVE_ID, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_INSTANCE_ADDR, glsl_uint64_t_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_GEOMETRY_ID_AND_FLAGS, glsl_uint_type(), false, 0);
radv_nir_param_from_type(function->params + CPS_ARG_HIT_KIND, glsl_uint_type(), false, 0);
if (has_position_fetch)
radv_nir_param_from_type(function->params + CPS_ARG_PRIMITIVE_ADDR, glsl_uint64_t_type(), false, 0);
function->driver_attributes =
(uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL | ACO_NIR_FUNCTION_ATTRIB_NORETURN;
/* Entrypoints can't have parameters. Consider RT stages as callable functions */
function->is_exported = true;
function->is_entrypoint = false;
} }
void void
radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_args *args, const struct radv_shader_info *info, radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_info *info, bool resume_shader,
uint32_t *stack_size, bool resume_shader, struct radv_device *device, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
struct radv_ray_tracing_pipeline *pipeline, bool has_position_fetch, bool has_position_fetch, const struct radv_ray_tracing_stage_info *traversal_info)
const struct radv_ray_tracing_stage_info *traversal_info)
{ {
const struct radv_physical_device *pdev = radv_device_physical(device);
nir_function_impl *impl = nir_shader_get_entrypoint(shader); nir_function_impl *impl = nir_shader_get_entrypoint(shader);
/* The first raygen shader gets called by the prolog with the standard raygen signature. Only shaders called by the
* first shader can use the CPS function signature.
*/
if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader)
init_cps_function(impl->function, has_position_fetch);
else
radv_nir_init_rt_function_params(impl->function, MESA_SHADER_RAYGEN, 0);
if (traversal_info) {
unsigned idx;
BITSET_FOREACH_SET (idx, traversal_info->unused_args, impl->function->num_params)
impl->function->params[idx].driver_attributes |= ACO_NIR_PARAM_ATTRIB_DISCARDABLE;
}
struct rt_variables vars = create_rt_variables(shader, device, pipeline->base.base.create_flags); struct rt_variables vars = create_rt_variables(shader, device, pipeline->base.base.create_flags);
struct radv_rt_shader_info rt_info = {0}; struct radv_rt_shader_info rt_info = {0};
lower_rt_instructions(shader, &vars, &rt_info); lower_rt_instructions(shader, &vars, &rt_info);
if (stack_size) { shader->scratch_size = MAX2(shader->scratch_size, vars.stack_size);
vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
*stack_size = MAX2(*stack_size, vars.stack_size);
}
shader->scratch_size = 0;
/* This can't use NIR_PASS because NIR_DEBUG=serialize,clone invalidates pointers. */ /* This can't use NIR_PASS because NIR_DEBUG=serialize,clone invalidates pointers. */
nir_lower_returns(shader); nir_lower_returns(shader);
nir_cf_list list;
nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
/* initialize variables */ /* initialize variables */
nir_builder b = nir_builder_at(nir_before_impl(impl)); nir_builder b = nir_builder_at(nir_before_impl(impl));
nir_def *descriptors = ac_nir_load_arg(&b, &args->ac, args->descriptors[0]); nir_def *launch_size_vec = nir_load_param(&b, RT_ARG_LAUNCH_SIZE);
nir_def *push_constants = ac_nir_load_arg(&b, &args->ac, args->ac.push_constants); nir_def *launch_id_vec = nir_load_param(&b, RT_ARG_LAUNCH_ID);
nir_def *dynamic_descriptors = ac_nir_load_arg(&b, &args->ac, args->ac.dynamic_descriptors); for (unsigned i = 0; i < 3; ++i) {
nir_def *sbt_descriptors = ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_descriptors); nir_store_var(&b, vars.launch_sizes[i], nir_channel(&b, launch_size_vec, i), 0x1);
nir_store_var(&b, vars.launch_ids[i], nir_channel(&b, launch_id_vec, i), 0x1);
nir_def *launch_sizes[3];
for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) {
launch_sizes[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_sizes[i]);
nir_store_var(&b, vars.launch_sizes[i], launch_sizes[i], 1);
} }
nir_store_var(&b, vars.traversal_addr, nir_load_param(&b, RAYGEN_ARG_TRAVERSAL_ADDR), 0x1);
nir_store_var(&b, vars.shader_record_ptr, nir_load_param(&b, RAYGEN_ARG_SHADER_RECORD_PTR), 0x1);
nir_store_var(&b, vars.shader_addr, nir_imm_int64(&b, 0), 0x1);
nir_def *scratch_offset = NULL; if (shader->info.stage == MESA_SHADER_RAYGEN && !resume_shader) {
if (args->ac.scratch_offset.used) impl->function->driver_attributes &= ~ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL;
scratch_offset = ac_nir_load_arg(&b, &args->ac, args->ac.scratch_offset); nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
nir_def *ring_offsets = NULL; } else {
if (args->ac.ring_offsets.used) nir_store_var(&b, vars.stack_ptr, nir_load_param(&b, CPS_ARG_STACK_PTR), 0x1);
ring_offsets = ac_nir_load_arg(&b, &args->ac, args->ac.ring_offsets); nir_store_var(&b, vars.arg, nir_load_param(&b, CPS_ARG_PAYLOAD_SCRATCH_OFFSET), 0x1);
nir_store_var(&b, vars.origin, nir_load_param(&b, CPS_ARG_RAY_ORIGIN), 0x7);
nir_store_var(&b, vars.tmin, nir_load_param(&b, CPS_ARG_RAY_TMIN), 0x1);
nir_store_var(&b, vars.direction, nir_load_param(&b, CPS_ARG_RAY_DIRECTION), 0x7);
nir_store_var(&b, vars.tmax, nir_load_param(&b, CPS_ARG_RAY_TMAX), 0x1);
nir_store_var(&b, vars.cull_mask_and_flags, nir_load_param(&b, CPS_ARG_CULL_MASK_AND_FLAGS), 0x1);
nir_store_var(&b, vars.sbt_offset, nir_load_param(&b, CPS_ARG_SBT_OFFSET), 0x1);
nir_store_var(&b, vars.sbt_stride, nir_load_param(&b, CPS_ARG_SBT_STRIDE), 0x1);
nir_store_var(&b, vars.accel_struct, nir_load_param(&b, CPS_ARG_ACCEL_STRUCT), 0x1);
nir_store_var(&b, vars.primitive_id, nir_load_param(&b, CPS_ARG_PRIMITIVE_ID), 0x1);
nir_store_var(&b, vars.instance_addr, nir_load_param(&b, CPS_ARG_INSTANCE_ADDR), 0x1);
if (has_position_fetch)
nir_store_var(&b, vars.primitive_addr, nir_load_param(&b, CPS_ARG_PRIMITIVE_ADDR), 0x1);
nir_store_var(&b, vars.geometry_id_and_flags, nir_load_param(&b, CPS_ARG_GEOMETRY_ID_AND_FLAGS), 0x1);
nir_store_var(&b, vars.hit_kind, nir_load_param(&b, CPS_ARG_HIT_KIND), 0x1);
nir_def *launch_ids[3]; if (traversal_info && traversal_info->miss_index.state == RADV_RT_CONST_ARG_STATE_VALID)
for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) { nir_store_var(&b, vars.miss_index, nir_imm_int(&b, traversal_info->miss_index.value), 0x1);
launch_ids[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_ids[i]); else
nir_store_var(&b, vars.launch_ids[i], launch_ids[i], 1); nir_store_var(&b, vars.miss_index, nir_load_param(&b, CPS_ARG_MISS_INDEX), 0x1);
} }
nir_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr);
nir_store_var(&b, vars.traversal_addr,
nir_pack_64_2x32_split(&b, traversal_addr, nir_imm_int(&b, pdev->info.address32_hi)), 1);
nir_def *shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_addr);
shader_addr = nir_pack_64_2x32(&b, shader_addr);
nir_store_var(&b, vars.shader_addr, shader_addr, 1);
nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
nir_store_var(&b, vars.cull_mask_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction), 0x7);
nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
if (traversal_info && traversal_info->miss_index.state == RADV_RT_CONST_ARG_STATE_VALID)
nir_store_var(&b, vars.miss_index, nir_imm_int(&b, traversal_info->miss_index.value), 0x1);
else
nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 0x1);
nir_def *primitive_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_addr);
nir_store_var(&b, vars.primitive_addr, nir_pack_64_2x32(&b, primitive_addr), 1);
nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), 1);
nir_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
nir_store_var(&b, vars.geometry_id_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
/* guard the shader, so that only the correct invocations execute it */
nir_if *shader_guard = NULL;
if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader) {
nir_def *uniform_shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr);
uniform_shader_addr = nir_pack_64_2x32(&b, uniform_shader_addr);
uniform_shader_addr = nir_ior_imm(&b, uniform_shader_addr, radv_get_rt_priority(shader->info.stage));
shader_guard = nir_push_if(&b, nir_ieq(&b, uniform_shader_addr, shader_addr));
shader_guard->control = nir_selection_control_divergent_always_taken;
}
nir_cf_reinsert(&list, b.cursor);
if (shader_guard)
nir_pop_if(&b, shader_guard);
b.cursor = nir_after_impl(impl); b.cursor = nir_after_impl(impl);
/* select next shader */ /* tail-call next shader */
shader_addr = nir_load_var(&b, vars.shader_addr); nir_def *shader_addr = nir_load_var(&b, vars.shader_addr);
nir_def *next = select_next_shader(&b, shader_addr, info->wave_size); nir_function *continuation_func = nir_function_create(shader, "continuation_func");
ac_nir_store_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr, next); init_cps_function(continuation_func, has_position_fetch);
ac_nir_store_arg(&b, &args->ac, args->descriptors[0], descriptors); unsigned param_count = continuation_func->num_params;
ac_nir_store_arg(&b, &args->ac, args->ac.push_constants, push_constants); nir_def **next_args = rzalloc_array_size(b.shader, sizeof(nir_def *), param_count);
ac_nir_store_arg(&b, &args->ac, args->ac.dynamic_descriptors, dynamic_descriptors); next_args[RT_ARG_LAUNCH_ID] = nir_load_param(&b, RT_ARG_LAUNCH_ID);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_descriptors, sbt_descriptors); next_args[RT_ARG_LAUNCH_SIZE] = nir_load_param(&b, RT_ARG_LAUNCH_SIZE);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr, traversal_addr); next_args[RT_ARG_DESCRIPTORS] = nir_load_param(&b, RT_ARG_DESCRIPTORS);
next_args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(&b, RT_ARG_DYNAMIC_DESCRIPTORS);
for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) { next_args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(&b, RT_ARG_PUSH_CONSTANTS);
if (rt_info.uses_launch_size) next_args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(&b, RT_ARG_SBT_DESCRIPTORS);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_sizes[i], launch_sizes[i]); next_args[RAYGEN_ARG_TRAVERSAL_ADDR] = nir_load_var(&b, vars.traversal_addr);
else next_args[RAYGEN_ARG_SHADER_RECORD_PTR] = nir_load_var(&b, vars.shader_record_ptr);
radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_sizes[i], launch_sizes[i]); next_args[CPS_ARG_PAYLOAD_SCRATCH_OFFSET] = nir_load_var(&b, vars.arg);
} next_args[CPS_ARG_STACK_PTR] = nir_load_var(&b, vars.stack_ptr);
next_args[CPS_ARG_RAY_ORIGIN] = nir_load_var(&b, vars.origin);
if (scratch_offset) next_args[CPS_ARG_RAY_TMIN] = nir_load_var(&b, vars.tmin);
ac_nir_store_arg(&b, &args->ac, args->ac.scratch_offset, scratch_offset); next_args[CPS_ARG_RAY_DIRECTION] = nir_load_var(&b, vars.direction);
if (ring_offsets) next_args[CPS_ARG_RAY_TMAX] = nir_load_var(&b, vars.tmax);
ac_nir_store_arg(&b, &args->ac, args->ac.ring_offsets, ring_offsets); next_args[CPS_ARG_CULL_MASK_AND_FLAGS] = nir_load_var(&b, vars.cull_mask_and_flags);
next_args[CPS_ARG_SBT_OFFSET] = nir_load_var(&b, vars.sbt_offset);
for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) { next_args[CPS_ARG_SBT_STRIDE] = nir_load_var(&b, vars.sbt_stride);
if (rt_info.uses_launch_id) next_args[CPS_ARG_MISS_INDEX] = nir_load_var(&b, vars.miss_index);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_ids[i], launch_ids[i]); next_args[CPS_ARG_ACCEL_STRUCT] = nir_load_var(&b, vars.accel_struct);
else next_args[CPS_ARG_PRIMITIVE_ID] = nir_load_var(&b, vars.primitive_id);
radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_ids[i], launch_ids[i]); next_args[CPS_ARG_INSTANCE_ADDR] = nir_load_var(&b, vars.instance_addr);
} next_args[CPS_ARG_PRIMITIVE_ADDR] = nir_load_var(&b, vars.primitive_addr);
next_args[CPS_ARG_GEOMETRY_ID_AND_FLAGS] = nir_load_var(&b, vars.geometry_id_and_flags);
/* store back all variables to registers */ next_args[CPS_ARG_HIT_KIND] = nir_load_var(&b, vars.hit_kind);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base, nir_load_var(&b, vars.stack_ptr)); nir_build_indirect_call(&b, continuation_func, shader_addr, param_count, next_args);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_addr, shader_addr);
radv_store_arg(&b, args, traversal_info, args->ac.rt.shader_record, nir_load_var(&b, vars.shader_record_ptr));
radv_store_arg(&b, args, traversal_info, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
radv_store_arg(&b, args, traversal_info, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
radv_store_arg(&b, args, traversal_info, args->ac.rt.cull_mask_and_flags,
nir_load_var(&b, vars.cull_mask_and_flags));
radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
radv_store_arg(&b, args, traversal_info, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
if (has_position_fetch)
radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_addr, nir_load_var(&b, vars.primitive_addr));
radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
radv_store_arg(&b, args, traversal_info, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
radv_store_arg(&b, args, traversal_info, args->ac.rt.geometry_id_and_flags,
nir_load_var(&b, vars.geometry_id_and_flags));
radv_store_arg(&b, args, traversal_info, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
nir_progress(true, impl, nir_metadata_none); nir_progress(true, impl, nir_metadata_none);
@ -683,5 +661,5 @@ radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_args *arg
NIR_PASS(_, shader, nir_lower_vars_to_ssa); NIR_PASS(_, shader, nir_lower_vars_to_ssa);
if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION) if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION)
NIR_PASS(_, shader, radv_nir_lower_hit_attribs, NULL, info->wave_size); NIR_PASS(_, shader, radv_nir_lower_rt_storage, NULL, NULL, NULL, info->wave_size);
} }

View file

@ -13,8 +13,7 @@
void radv_gather_unused_args(struct radv_ray_tracing_stage_info *info, nir_shader *nir); void radv_gather_unused_args(struct radv_ray_tracing_stage_info *info, nir_shader *nir);
void radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_args *args, void radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_info *info, bool resume_shader,
const struct radv_shader_info *info, uint32_t *stack_size, bool resume_shader,
struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
bool has_position_fetch, const struct radv_ray_tracing_stage_info *traversal_info); bool has_position_fetch, const struct radv_ray_tracing_stage_info *traversal_info);
void radv_nir_lower_rt_io_cps(nir_shader *shader); void radv_nir_lower_rt_io_cps(nir_shader *shader);

View file

@ -0,0 +1,726 @@
/*
* Copyright © 2025 Valve Corporation
* Copyright © 2021 Google
*
* SPDX-License-Identifier: MIT
*/
#include "nir/nir.h"
#include "nir/nir_builder.h"
#include "nir/radv_nir.h"
#include "nir/radv_nir_rt_common.h"
#include "nir/radv_nir_rt_stage_common.h"
#include "nir/radv_nir_rt_stage_functions.h"
#include "radv_device.h"
#include "radv_physical_device.h"
#include "radv_shader.h"
#include "aco_nir_call_attribs.h"
#include "vk_pipeline.h"
/* Initialize the parameters shared by every RT stage entry function:
 * launch id/size, descriptor and push-constant pointers, and the SBT
 * descriptor address.
 */
static void
radv_nir_init_common_rt_params(nir_function *function)
{
   static const unsigned indices[] = {
      RT_ARG_LAUNCH_ID,   RT_ARG_LAUNCH_SIZE,    RT_ARG_DESCRIPTORS,
      RT_ARG_DYNAMIC_DESCRIPTORS, RT_ARG_PUSH_CONSTANTS, RT_ARG_SBT_DESCRIPTORS,
   };

   for (unsigned i = 0; i < ARRAY_SIZE(indices); ++i) {
      const struct glsl_type *type;
      /* The two launch parameters are uvec3s; the pointer-like parameters are
       * scalar, with the SBT descriptor address being 64-bit.
       */
      switch (indices[i]) {
      case RT_ARG_LAUNCH_ID:
      case RT_ARG_LAUNCH_SIZE:
         type = glsl_vector_type(GLSL_TYPE_UINT, 3);
         break;
      case RT_ARG_SBT_DESCRIPTORS:
         type = glsl_uint64_t_type();
         break;
      default:
         type = glsl_uint_type();
         break;
      }

      /* Only the launch ID is divergent; everything else is uniform. */
      bool uniform = indices[i] != RT_ARG_LAUNCH_ID;
      radv_nir_param_from_type(function->params + indices[i], type, uniform, 0);
   }
}
/* Fill in the nir_parameter array and call-ABI driver attributes for an RT
 * stage entry function. The parameter layout follows the per-stage
 * aco_nir_*_function_arg enums; payload_size is the ray payload size in
 * bytes, lowered to individual 32-bit in/out parameters.
 */
void
radv_nir_init_rt_function_params(nir_function *function, mesa_shader_stage stage, unsigned payload_size)
{
   /* Index of the first payload dword parameter; -1u when the stage takes no
    * payload parameters (raygen). */
   unsigned payload_base = -1u;
   switch (stage) {
   case MESA_SHADER_RAYGEN:
      function->num_params = RAYGEN_ARG_COUNT;
      function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params);
      radv_nir_init_common_rt_params(function);
      radv_nir_param_from_type(function->params + RAYGEN_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0);
      radv_nir_param_from_type(function->params + RAYGEN_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0);
      /* Raygen is the top of the call tree and never returns to a caller. */
      function->driver_attributes = (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_NORETURN;
      break;
   case MESA_SHADER_CALLABLE:
      /* Callable shaders reuse the raygen parameter layout plus trailing
       * payload dwords. */
      function->num_params = RAYGEN_ARG_COUNT + DIV_ROUND_UP(payload_size, 4);
      function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params);
      radv_nir_init_common_rt_params(function);
      radv_nir_param_from_type(function->params + RAYGEN_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0);
      radv_nir_param_from_type(function->params + RAYGEN_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0);
      function->driver_attributes = (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL;
      payload_base = RAYGEN_ARG_COUNT;
      break;
   case MESA_SHADER_INTERSECTION:
      /* This layout is used for the traversal function (called via
       * trace_ray), which carries the full ray/hit state. */
      function->num_params = TRAVERSAL_ARG_PAYLOAD_BASE + DIV_ROUND_UP(payload_size, 4);
      function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params);
      radv_nir_init_common_rt_params(function);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_ACCEL_STRUCT, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_CULL_MASK_AND_FLAGS, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_SBT_OFFSET, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_SBT_STRIDE, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_MISS_INDEX, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_ORIGIN, glsl_vector_type(GLSL_TYPE_UINT, 3), false,
                               0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_TMIN, glsl_float_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_DIRECTION, glsl_vector_type(GLSL_TYPE_UINT, 3),
                               false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_TMAX, glsl_float_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_PRIMITIVE_ADDR, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_PRIMITIVE_ID, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_INSTANCE_ADDR, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_GEOMETRY_ID_AND_FLAGS, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + TRAVERSAL_ARG_HIT_KIND, glsl_uint_type(), false, 0);
      function->driver_attributes = ACO_NIR_CALL_ABI_TRAVERSAL;
      payload_base = TRAVERSAL_ARG_PAYLOAD_BASE;
      break;
   case MESA_SHADER_CLOSEST_HIT:
   case MESA_SHADER_MISS:
      /* Same state as traversal, but several values the shader may never
       * read are marked discardable so the compiler can drop them. */
      function->num_params = CHIT_MISS_ARG_PAYLOAD_BASE + DIV_ROUND_UP(payload_size, 4);
      function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params);
      radv_nir_init_common_rt_params(function);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_ACCEL_STRUCT, glsl_uint64_t_type(), false,
                               ACO_NIR_PARAM_ATTRIB_DISCARDABLE);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_CULL_MASK_AND_FLAGS, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_SBT_OFFSET, glsl_uint_type(), false,
                               ACO_NIR_PARAM_ATTRIB_DISCARDABLE);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_SBT_STRIDE, glsl_uint_type(), false,
                               ACO_NIR_PARAM_ATTRIB_DISCARDABLE);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_MISS_INDEX, glsl_uint_type(), false,
                               ACO_NIR_PARAM_ATTRIB_DISCARDABLE);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_ORIGIN, glsl_vector_type(GLSL_TYPE_UINT, 3), false,
                               0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_TMIN, glsl_float_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_DIRECTION, glsl_vector_type(GLSL_TYPE_UINT, 3),
                               false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_TMAX, glsl_float_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_PRIMITIVE_ADDR, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_PRIMITIVE_ID, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_INSTANCE_ADDR, glsl_uint64_t_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS, glsl_uint_type(), false, 0);
      radv_nir_param_from_type(function->params + CHIT_MISS_ARG_HIT_KIND, glsl_uint_type(), false, 0);
      function->driver_attributes = (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL;
      payload_base = CHIT_MISS_ARG_PAYLOAD_BASE;
      break;
   default:
      UNREACHABLE("invalid RT stage");
   }
   /* Payload dwords are scalar 32-bit in/out ("return") parameters. */
   if (payload_base != -1u) {
      for (unsigned i = 0; i < DIV_ROUND_UP(payload_size, 4); ++i) {
         function->params[payload_base + i].num_components = 1;
         function->params[payload_base + i].bit_size = 32;
         function->params[payload_base + i].is_return = true;
         function->params[payload_base + i].type = glsl_uint_type();
      }
   }
   /* Entrypoints can't have parameters. Consider RT stages as callable functions */
   function->is_exported = true;
   function->is_entrypoint = false;
}
/*
 * Global variables for an RT pipeline
 */
struct rt_variables {
   struct radv_device *device;
   /* Pipeline create flags; used to honor the NO_NULL_*_SHADERS bits. */
   const VkPipelineCreateFlags2 flags;
   /* Stage-dependent parameter indices */
   /* Each maps a logical RT value to the current stage's function parameter
    * index, or -1u if the stage does not receive that value. */
   unsigned shader_record_ptr_param;
   unsigned traversal_addr_param;
   unsigned accel_struct_param;
   unsigned cull_mask_and_flags_param;
   unsigned sbt_offset_param;
   unsigned sbt_stride_param;
   unsigned miss_index_param;
   unsigned ray_origin_param;
   unsigned ray_tmin_param;
   unsigned ray_direction_param;
   unsigned ray_tmax_param;
   unsigned primitive_id_param;
   unsigned instance_addr_param;
   unsigned primitive_addr_param;
   unsigned geometry_id_and_flags_param;
   unsigned hit_kind_param;
   /* Index of the first incoming payload dword parameter. */
   unsigned in_payload_base_param;
   /* One uint shader-temp variable per outgoing payload dword; their derefs
    * are passed as in/out arguments on calls. */
   nir_variable **out_payload_storage;
   /* Maximum payload size in bytes. */
   unsigned payload_size;
   /* Indirect-call targets for trace_ray, closest-hit/miss and callable. */
   nir_function *trace_ray_func;
   nir_function *chit_miss_func;
   nir_function *callable_func;
   unsigned stack_size;
};
/* Build the per-shader rt_variables state: the outgoing payload spill
 * variables, the indirect-call helper functions (traversal, chit/miss,
 * callable) and the stage-dependent mapping from logical RT values to this
 * stage's function parameter indices (-1u where the stage does not receive
 * the value).
 */
static struct rt_variables
create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2 flags,
                    unsigned max_payload_size)
{
   struct rt_variables vars = {
      .device = device,
      .flags = flags,
   };
   if (max_payload_size)
      /* rzalloc_array_size takes (ctx, size, count) — keep the argument order
       * consistent with the other call sites in this file. */
      vars.out_payload_storage =
         rzalloc_array_size(shader, sizeof(nir_variable *), DIV_ROUND_UP(max_payload_size, 4));
   vars.payload_size = max_payload_size;
   /* One scalar uint temp per outgoing payload dword. */
   for (unsigned i = 0; i < DIV_ROUND_UP(max_payload_size, 4); ++i) {
      vars.out_payload_storage[i] =
         nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "out_payload_storage");
   }

   nir_function *trace_ray_func = nir_function_create(shader, "trace_ray_func");
   radv_nir_init_rt_function_params(trace_ray_func, MESA_SHADER_INTERSECTION, max_payload_size);
   vars.trace_ray_func = trace_ray_func;
   nir_function *chit_miss_func = nir_function_create(shader, "chit_miss_func");
   radv_nir_init_rt_function_params(chit_miss_func, MESA_SHADER_CLOSEST_HIT, max_payload_size);
   vars.chit_miss_func = chit_miss_func;
   nir_function *callable_func = nir_function_create(shader, "callable_func");
   radv_nir_init_rt_function_params(callable_func, MESA_SHADER_CALLABLE, max_payload_size);
   vars.callable_func = callable_func;

   /* Default every parameter index to "not available for this stage". */
   vars.shader_record_ptr_param = -1u;
   vars.traversal_addr_param = -1u;
   vars.accel_struct_param = -1u;
   vars.cull_mask_and_flags_param = -1u;
   vars.sbt_offset_param = -1u;
   vars.sbt_stride_param = -1u;
   vars.miss_index_param = -1u;
   vars.ray_origin_param = -1u;
   vars.ray_tmin_param = -1u;
   vars.ray_direction_param = -1u;
   vars.ray_tmax_param = -1u;
   vars.primitive_id_param = -1u;
   vars.instance_addr_param = -1u;
   vars.primitive_addr_param = -1u;
   vars.geometry_id_and_flags_param = -1u;
   vars.hit_kind_param = -1u;
   vars.in_payload_base_param = -1u;

   switch (shader->info.stage) {
   case MESA_SHADER_CALLABLE:
      vars.in_payload_base_param = RAYGEN_ARG_COUNT;
      vars.shader_record_ptr_param = RAYGEN_ARG_SHADER_RECORD_PTR;
      vars.traversal_addr_param = RAYGEN_ARG_TRAVERSAL_ADDR;
      break;
   case MESA_SHADER_RAYGEN:
      vars.shader_record_ptr_param = RAYGEN_ARG_SHADER_RECORD_PTR;
      vars.traversal_addr_param = RAYGEN_ARG_TRAVERSAL_ADDR;
      break;
   case MESA_SHADER_INTERSECTION:
      vars.traversal_addr_param = TRAVERSAL_ARG_TRAVERSAL_ADDR;
      vars.shader_record_ptr_param = TRAVERSAL_ARG_SHADER_RECORD_PTR;
      vars.accel_struct_param = TRAVERSAL_ARG_ACCEL_STRUCT;
      vars.cull_mask_and_flags_param = TRAVERSAL_ARG_CULL_MASK_AND_FLAGS;
      vars.sbt_offset_param = TRAVERSAL_ARG_SBT_OFFSET;
      vars.sbt_stride_param = TRAVERSAL_ARG_SBT_STRIDE;
      vars.miss_index_param = TRAVERSAL_ARG_MISS_INDEX;
      vars.ray_origin_param = TRAVERSAL_ARG_RAY_ORIGIN;
      vars.ray_tmin_param = TRAVERSAL_ARG_RAY_TMIN;
      vars.ray_direction_param = TRAVERSAL_ARG_RAY_DIRECTION;
      vars.ray_tmax_param = TRAVERSAL_ARG_RAY_TMAX;
      vars.in_payload_base_param = TRAVERSAL_ARG_PAYLOAD_BASE;
      break;
   case MESA_SHADER_CLOSEST_HIT:
   case MESA_SHADER_MISS:
      vars.traversal_addr_param = CHIT_MISS_ARG_TRAVERSAL_ADDR;
      vars.shader_record_ptr_param = CHIT_MISS_ARG_SHADER_RECORD_PTR;
      vars.accel_struct_param = CHIT_MISS_ARG_ACCEL_STRUCT;
      vars.cull_mask_and_flags_param = CHIT_MISS_ARG_CULL_MASK_AND_FLAGS;
      vars.sbt_offset_param = CHIT_MISS_ARG_SBT_OFFSET;
      vars.sbt_stride_param = CHIT_MISS_ARG_SBT_STRIDE;
      vars.miss_index_param = CHIT_MISS_ARG_MISS_INDEX;
      vars.ray_origin_param = CHIT_MISS_ARG_RAY_ORIGIN;
      vars.ray_tmin_param = CHIT_MISS_ARG_RAY_TMIN;
      vars.ray_direction_param = CHIT_MISS_ARG_RAY_DIRECTION;
      vars.ray_tmax_param = CHIT_MISS_ARG_RAY_TMAX;
      vars.primitive_id_param = CHIT_MISS_ARG_PRIMITIVE_ID;
      vars.instance_addr_param = CHIT_MISS_ARG_INSTANCE_ADDR;
      vars.primitive_addr_param = CHIT_MISS_ARG_PRIMITIVE_ADDR;
      vars.geometry_id_and_flags_param = CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS;
      vars.hit_kind_param = CHIT_MISS_ARG_HIT_KIND;
      vars.in_payload_base_param = CHIT_MISS_ARG_PAYLOAD_BASE;
      break;
   default:
      break;
   }

   return vars;
}
/* Instruction callback that lowers RT intrinsics into loads of the RT stage's
 * function parameters and into indirect calls to the traversal/chit-miss/
 * callable helper functions. Returns true if the instruction was rewritten.
 */
static bool
lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_vars)
{
   /* Shader-end "halt" becomes a plain return from the RT stage function. */
   if (instr->type == nir_instr_type_jump) {
      nir_jump_instr *jump = nir_instr_as_jump(instr);
      if (jump->type == nir_jump_halt) {
         jump->type = nir_jump_return;
         return true;
      }
      return false;
   } else if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   struct rt_variables *vars = _vars;
   b->cursor = nir_before_instr(&intr->instr);

   /* If non-NULL after the switch, replaces all uses of the intrinsic. */
   nir_def *ret = NULL;

   switch (intr->intrinsic) {
   case nir_intrinsic_execute_callable: {
      /* Resolve the callable's SBT entry, then call it with the common args,
       * its shader record pointer, and derefs of the payload temps. */
      struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS),
                                                                  intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
      unsigned param_count = RAYGEN_ARG_COUNT + DIV_ROUND_UP(vars->payload_size, 4);
      nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count);
      args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID);
      args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE);
      args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS);
      args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS);
      args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS);
      args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
      args[RAYGEN_ARG_TRAVERSAL_ADDR] = nir_undef(b, 1, 64);
      args[RAYGEN_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr;
      for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) {
         args[RAYGEN_ARG_COUNT + i] = nir_instr_def(&nir_build_deref_var(b, vars->out_payload_storage[i])->instr);
      }
      nir_build_indirect_call(b, vars->callable_func, sbt_data.shader_addr, param_count, args);
      break;
   }
   case nir_intrinsic_trace_ray: {
      nir_def *undef = nir_undef(b, 1, 32);
      /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
      nir_def *cull_mask_and_flags = nir_ior(b, nir_ishl_imm(b, intr->src[2].ssa, 24), intr->src[1].ssa);
      nir_def *traversal_addr = nir_load_param(b, vars->traversal_addr_param);
      unsigned param_count = TRAVERSAL_ARG_PAYLOAD_BASE + DIV_ROUND_UP(vars->payload_size, 4);
      nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count);
      args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID);
      args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE);
      args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS);
      args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS);
      args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS);
      args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
      args[TRAVERSAL_ARG_TRAVERSAL_ADDR] = traversal_addr;
      /* Traversal does not have a shader record. */
      args[TRAVERSAL_ARG_SHADER_RECORD_PTR] = nir_undef(b, 1, 64);
      args[TRAVERSAL_ARG_ACCEL_STRUCT] = intr->src[0].ssa;
      args[TRAVERSAL_ARG_CULL_MASK_AND_FLAGS] = cull_mask_and_flags;
      args[TRAVERSAL_ARG_SBT_OFFSET] = nir_iand_imm(b, intr->src[3].ssa, 0xf);
      args[TRAVERSAL_ARG_SBT_STRIDE] = nir_iand_imm(b, intr->src[4].ssa, 0xf);
      args[TRAVERSAL_ARG_MISS_INDEX] = nir_iand_imm(b, intr->src[5].ssa, 0xffff);
      args[TRAVERSAL_ARG_RAY_ORIGIN] = intr->src[6].ssa;
      args[TRAVERSAL_ARG_RAY_TMIN] = intr->src[7].ssa;
      args[TRAVERSAL_ARG_RAY_DIRECTION] = intr->src[8].ssa;
      args[TRAVERSAL_ARG_RAY_TMAX] = intr->src[9].ssa;
      /* Hit state is produced by traversal itself; pass undef. */
      args[TRAVERSAL_ARG_PRIMITIVE_ADDR] = nir_undef(b, 1, 64);
      args[TRAVERSAL_ARG_PRIMITIVE_ID] = undef;
      args[TRAVERSAL_ARG_INSTANCE_ADDR] = nir_undef(b, 1, 64);
      args[TRAVERSAL_ARG_GEOMETRY_ID_AND_FLAGS] = undef;
      args[TRAVERSAL_ARG_HIT_KIND] = undef;
      for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) {
         args[TRAVERSAL_ARG_PAYLOAD_BASE + i] =
            nir_instr_def(&nir_build_deref_var(b, vars->out_payload_storage[i])->instr);
      }
      nir_build_indirect_call(b, vars->trace_ray_func, traversal_addr, param_count, args);
      break;
   }
   /* The simple cases below just read the corresponding stage parameter. */
   case nir_intrinsic_load_shader_record_ptr: {
      ret = nir_load_param(b, vars->shader_record_ptr_param);
      break;
   }
   case nir_intrinsic_load_ray_launch_size: {
      ret = nir_load_param(b, RT_ARG_LAUNCH_SIZE);
      break;
   };
   case nir_intrinsic_load_ray_launch_id: {
      ret = nir_load_param(b, RT_ARG_LAUNCH_ID);
      break;
   }
   case nir_intrinsic_load_ray_t_min: {
      ret = nir_load_param(b, vars->ray_tmin_param);
      break;
   }
   case nir_intrinsic_load_ray_t_max: {
      ret = nir_load_param(b, vars->ray_tmax_param);
      break;
   }
   case nir_intrinsic_load_ray_world_origin: {
      ret = nir_load_param(b, vars->ray_origin_param);
      break;
   }
   case nir_intrinsic_load_ray_world_direction: {
      ret = nir_load_param(b, vars->ray_direction_param);
      break;
   }
   case nir_intrinsic_load_ray_instance_custom_index: {
      /* Fetched from the instance node in memory rather than a parameter. */
      ret = radv_load_custom_instance(vars->device, b, nir_load_param(b, vars->instance_addr_param));
      break;
   }
   case nir_intrinsic_load_primitive_id: {
      ret = nir_load_param(b, vars->primitive_id_param);
      break;
   }
   case nir_intrinsic_load_ray_geometry_index: {
      ret = nir_load_param(b, vars->geometry_id_and_flags_param);
      /* Low 28 bits hold the geometry index; the rest are flags. */
      ret = nir_iand_imm(b, ret, 0xFFFFFFF);
      break;
   }
   case nir_intrinsic_load_instance_id: {
      ret = radv_load_instance_id(vars->device, b, nir_load_param(b, vars->instance_addr_param));
      break;
   }
   case nir_intrinsic_load_ray_flags: {
      /* Low 24 bits are the ray flags; the top byte is the cull mask. */
      ret = nir_iand_imm(b, nir_load_param(b, vars->cull_mask_and_flags_param), 0xFFFFFF);
      break;
   }
   case nir_intrinsic_load_ray_hit_kind: {
      ret = nir_load_param(b, vars->hit_kind_param);
      break;
   }
   case nir_intrinsic_load_ray_world_to_object: {
      /* Returns one column of the world-to-object matrix loaded from the
       * instance node. */
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_param(b, vars->instance_addr_param);
      nir_def *wto_matrix[3];
      radv_load_wto_matrix(vars->device, b, instance_node_addr, wto_matrix);
      nir_def *vals[3];
      for (unsigned i = 0; i < 3; ++i)
         vals[i] = nir_channel(b, wto_matrix[i], c);
      ret = nir_vec(b, vals, 3);
      break;
   }
   case nir_intrinsic_load_ray_object_to_world: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *otw_matrix[3];
      radv_load_otw_matrix(vars->device, b, nir_load_param(b, vars->instance_addr_param), otw_matrix);
      ret = nir_vec3(b, nir_channel(b, otw_matrix[0], c), nir_channel(b, otw_matrix[1], c),
                     nir_channel(b, otw_matrix[2], c));
      break;
   }
   case nir_intrinsic_load_ray_object_origin: {
      /* World-space origin transformed into object space (translation applied). */
      nir_def *wto_matrix[3];
      radv_load_wto_matrix(vars->device, b, nir_load_param(b, vars->instance_addr_param), wto_matrix);
      ret = nir_build_vec3_mat_mult(b, nir_load_param(b, vars->ray_origin_param), wto_matrix, true);
      break;
   }
   case nir_intrinsic_load_ray_object_direction: {
      /* Direction transform does not apply the translation component. */
      nir_def *wto_matrix[3];
      radv_load_wto_matrix(vars->device, b, nir_load_param(b, vars->instance_addr_param), wto_matrix);
      ret = nir_build_vec3_mat_mult(b, nir_load_param(b, vars->ray_direction_param), wto_matrix, false);
      break;
   }
   case nir_intrinsic_load_cull_mask: {
      /* Cull mask lives in the top byte of cull_mask_and_flags. */
      ret = nir_ushr_imm(b, nir_load_param(b, vars->cull_mask_and_flags_param), 24);
      break;
   }
   case nir_intrinsic_load_sbt_base_amd: {
      ret = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
      break;
   }
   case nir_intrinsic_load_sbt_offset_amd: {
      ret = nir_load_param(b, vars->sbt_offset_param);
      break;
   }
   case nir_intrinsic_load_sbt_stride_amd: {
      ret = nir_load_param(b, vars->sbt_stride_param);
      break;
   }
   case nir_intrinsic_load_accel_struct_amd: {
      ret = nir_load_param(b, vars->accel_struct_param);
      break;
   }
   case nir_intrinsic_load_cull_mask_and_flags_amd: {
      ret = nir_load_param(b, vars->cull_mask_and_flags_param);
      break;
   }
   case nir_intrinsic_execute_closest_hit_amd: {
      struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS),
                                                                  intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
      /* Skip the call when the ray flags ask to skip closest-hit, or (unless
       * the pipeline promises no NULL chit shaders) when the SBT entry is NULL. */
      nir_def *should_return =
         nir_test_mask(b, nir_load_param(b, vars->cull_mask_and_flags_param), SpvRayFlagsSkipClosestHitShaderKHRMask);
      if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
         should_return = nir_ior(b, should_return, nir_ieq_imm(b, sbt_data.shader_addr, 0));
      }
      /* should_return is set if we had a hit but we won't be calling the closest hit
       * shader and hence need to return immediately to the calling shader. */
      nir_push_if(b, nir_inot(b, should_return));
      unsigned param_count = CHIT_MISS_ARG_PAYLOAD_BASE + DIV_ROUND_UP(vars->payload_size, 4);
      nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count);
      args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID);
      args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE);
      args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS);
      args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS);
      args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS);
      args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
      args[CHIT_MISS_ARG_TRAVERSAL_ADDR] = nir_load_param(b, vars->traversal_addr_param);
      args[CHIT_MISS_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr;
      args[CHIT_MISS_ARG_ACCEL_STRUCT] = nir_load_param(b, vars->accel_struct_param);
      args[CHIT_MISS_ARG_CULL_MASK_AND_FLAGS] = nir_load_param(b, vars->cull_mask_and_flags_param);
      args[CHIT_MISS_ARG_SBT_OFFSET] = nir_load_param(b, vars->sbt_offset_param);
      args[CHIT_MISS_ARG_SBT_STRIDE] = nir_load_param(b, vars->sbt_stride_param);
      args[CHIT_MISS_ARG_MISS_INDEX] = nir_load_param(b, vars->miss_index_param);
      args[CHIT_MISS_ARG_RAY_ORIGIN] = nir_load_param(b, vars->ray_origin_param);
      args[CHIT_MISS_ARG_RAY_TMIN] = nir_load_param(b, vars->ray_tmin_param);
      args[CHIT_MISS_ARG_RAY_DIRECTION] = nir_load_param(b, vars->ray_direction_param);
      args[CHIT_MISS_ARG_RAY_TMAX] = intr->src[1].ssa;
      args[CHIT_MISS_ARG_PRIMITIVE_ADDR] = intr->src[2].ssa;
      args[CHIT_MISS_ARG_PRIMITIVE_ID] = intr->src[3].ssa;
      args[CHIT_MISS_ARG_INSTANCE_ADDR] = intr->src[4].ssa;
      args[CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS] = intr->src[5].ssa;
      args[CHIT_MISS_ARG_HIT_KIND] = intr->src[6].ssa;
      /* NOTE(review): payload params are read at TRAVERSAL_ARG_PAYLOAD_BASE —
       * assumes this intrinsic only occurs in the traversal stage; confirm. */
      for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) {
         args[CHIT_MISS_ARG_PAYLOAD_BASE + i] =
            nir_instr_def(&nir_build_deref_cast(b, nir_load_param(b, TRAVERSAL_ARG_PAYLOAD_BASE + i),
                                                nir_var_shader_call_data, glsl_uint_type(), 4)
                             ->instr);
      }
      nir_build_indirect_call(b, vars->chit_miss_func, sbt_data.shader_addr, param_count, args);
      nir_pop_if(b, NULL);
      break;
   }
   case nir_intrinsic_execute_miss_amd: {
      nir_def *undef = nir_undef(b, 1, 32);
      nir_def *miss_index = nir_load_param(b, vars->miss_index_param);
      struct radv_nir_sbt_data sbt_data =
         radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
      if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
         /* In case of a NULL miss shader, do nothing and just return. */
         nir_push_if(b, nir_ine_imm(b, sbt_data.shader_addr, 0));
      }
      unsigned param_count = CHIT_MISS_ARG_PAYLOAD_BASE + DIV_ROUND_UP(vars->payload_size, 4);
      nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count);
      args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID);
      args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE);
      args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS);
      args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS);
      args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS);
      args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
      args[CHIT_MISS_ARG_TRAVERSAL_ADDR] = nir_load_param(b, vars->traversal_addr_param);
      args[CHIT_MISS_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr;
      args[CHIT_MISS_ARG_ACCEL_STRUCT] = nir_load_param(b, vars->accel_struct_param);
      args[CHIT_MISS_ARG_CULL_MASK_AND_FLAGS] = nir_load_param(b, vars->cull_mask_and_flags_param);
      args[CHIT_MISS_ARG_SBT_OFFSET] = nir_load_param(b, vars->sbt_offset_param);
      args[CHIT_MISS_ARG_SBT_STRIDE] = nir_load_param(b, vars->sbt_stride_param);
      args[CHIT_MISS_ARG_MISS_INDEX] = nir_load_param(b, vars->miss_index_param);
      args[CHIT_MISS_ARG_RAY_ORIGIN] = nir_load_param(b, vars->ray_origin_param);
      args[CHIT_MISS_ARG_RAY_TMIN] = nir_load_param(b, vars->ray_tmin_param);
      args[CHIT_MISS_ARG_RAY_DIRECTION] = nir_load_param(b, vars->ray_direction_param);
      args[CHIT_MISS_ARG_RAY_TMAX] = intr->src[0].ssa;
      /* A miss has no hit state; pass undef. */
      args[CHIT_MISS_ARG_PRIMITIVE_ADDR] = nir_undef(b, 1, 64);
      args[CHIT_MISS_ARG_PRIMITIVE_ID] = undef;
      args[CHIT_MISS_ARG_INSTANCE_ADDR] = nir_undef(b, 1, 64);
      args[CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS] = undef;
      args[CHIT_MISS_ARG_HIT_KIND] = undef;
      for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) {
         args[CHIT_MISS_ARG_PAYLOAD_BASE + i] =
            nir_instr_def(&nir_build_deref_cast(b, nir_load_param(b, TRAVERSAL_ARG_PAYLOAD_BASE + i),
                                                nir_var_shader_call_data, glsl_uint_type(), 4)
                             ->instr);
      }
      nir_build_indirect_call(b, vars->chit_miss_func, sbt_data.shader_addr, param_count, args);
      if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR))
         nir_pop_if(b, NULL);
      break;
   }
   case nir_intrinsic_load_ray_triangle_vertex_positions: {
      nir_def *primitive_addr = nir_load_param(b, vars->primitive_addr_param);
      ret = radv_load_vertex_position(vars->device, b, primitive_addr, nir_intrinsic_column(intr));
      break;
   }
   default:
      return false;
   }

   if (ret)
      nir_def_rewrite_uses(&intr->def, ret);
   nir_instr_remove(&intr->instr);
   return true;
}
/* Lower aliased ray payload variables. See radv_nir_lower_rt_io for a more detailed
* explanation on ray payload storage.
*/
/* Re-home one aliased payload variable and redirect its deref.
 *
 * Mechanics (see the comment block above for the overall scheme):
 *  - the original variable `var` is moved from the impl's locals to the
 *    shader's global variable list and re-marked nir_var_shader_temp,
 *  - a clone of it is created (once per variable, tracked in `cloned_vars`)
 *    and inserted into impl->locals as the non-aliased function_temp copy,
 *  - for every use of this deref that feeds a call-like intrinsic
 *    (trace_ray / execute_callable / execute_closest_hit_amd /
 *    execute_miss_amd), the intrinsic source is rewritten to a deref of the
 *    clone, with a copy original->clone inserted before the call and
 *    clone->original after it, keeping the two variables in sync across the
 *    call. Non-call uses keep referencing the original deref (whose mode is
 *    switched to shader_temp below).
 */
static void
lower_rt_deref_var(nir_shader *shader, nir_function_impl *impl, nir_instr *instr, struct hash_table *cloned_vars)
{
   nir_deref_instr *deref = nir_instr_as_deref(instr);
   nir_variable *var = deref->var;
   struct hash_entry *entry = _mesa_hash_table_search(cloned_vars, var);

   /* Only function_temp variables (or originals already re-homed by a
    * previous visit, which remain as keys in cloned_vars) are lowered. */
   if (!(var->data.mode & nir_var_function_temp) && !entry)
      return;

   /* Derefs of a clone itself are already in their final form. */
   hash_table_foreach (cloned_vars, cloned_entry) {
      if (var == cloned_entry->data)
         return;
   }

   nir_variable *new_var;
   if (entry) {
      new_var = entry->data;
   } else {
      new_var = nir_variable_clone(var, shader);
      _mesa_hash_table_insert(cloned_vars, var, new_var);
      /* Move the original to the shader's global list as shader_temp and
       * keep the clone as a function-local temporary. */
      exec_node_remove(&var->node);
      var->data.mode = nir_var_shader_temp;
      exec_list_push_tail(&shader->variables, &var->node);
      exec_list_push_tail(&impl->locals, &new_var->node);
   }

   deref->modes = nir_var_shader_temp;

   nir_foreach_use_safe (use, nir_instr_def(instr)) {
      if (nir_src_is_if(use))
         continue;

      nir_instr *parent = nir_src_parent_instr(use);
      if (parent->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(parent);
      if (intrin->intrinsic != nir_intrinsic_trace_ray && intrin->intrinsic != nir_intrinsic_execute_callable &&
          intrin->intrinsic != nir_intrinsic_execute_closest_hit_amd &&
          intrin->intrinsic != nir_intrinsic_execute_miss_amd)
         continue;

      /* Sync clone with the original around the call and make the call
       * operate on the clone. */
      nir_builder b = nir_builder_at(nir_before_instr(parent));
      nir_deref_instr *old_deref = nir_build_deref_var(&b, var);
      nir_deref_instr *new_deref = nir_build_deref_var(&b, new_var);
      nir_copy_deref(&b, new_deref, old_deref);

      b.cursor = nir_after_instr(parent);
      nir_copy_deref(&b, old_deref, new_deref);

      nir_src_rewrite(use, nir_instr_def(&new_deref->instr));
   }
}
/* Walk the entrypoint and lower every function_temp payload deref.
 * Variable derefs are re-homed via lower_rt_deref_var(); derived derefs
 * (struct/array members) just inherit the possibly-updated parent modes.
 */
static bool
lower_rt_derefs_functions(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   struct hash_table *cloned_vars = _mesa_pointer_hash_table_create(shader);
   bool progress = false;

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_is(deref, nir_var_function_temp))
            continue;

         if (deref->deref_type != nir_deref_type_var) {
            assert(deref->deref_type != nir_deref_type_cast);
            /* Parent modes might have changed, propagate change */
            nir_deref_instr *parent = nir_src_as_deref(deref->parent);
            if (parent->modes != deref->modes)
               deref->modes = parent->modes;
            continue;
         }

         lower_rt_deref_var(shader, impl, instr, cloned_vars);
         progress = true;
      }
   }

   return nir_progress(progress, impl, nir_metadata_control_flow);
}
/* Lower ray payload I/O for the function-call compilation mode.
 * The pass order below is significant: dealiasing must happen before copies
 * are split/lowered, which in turn must happen before payload derefs are
 * converted to loads/stores.
 */
void
radv_nir_lower_rt_io_functions(nir_shader *nir)
{
   /* When compiling separately and using function calls, function parameters for ray payloads store the currently
    * active ray payload. The same parameters are reused for different ray payload types - the function signature
    * only allows for passing one ray payload (the active one) at a time. We model this in NIR by designating all ray
    * payload variables as aliased (every ray payload variable's driver location is 0).
    *
    * This doesn't quite match the SPIR-V semantics of different ray payload variables - each payload variable is in
    * a different location and can be written/read independently. It's lower_rt_derefs's job to accomodate this.
    * lower_rt_derefs duplicates all ray payload variables and marks the original one as a shader_temp variable,
    * in order to make the shader's payload read/writes operate on temporary copies that do not alias.
    * radv_nir_lower_ray_payload_derefs will then convert the aliased variables to proper payload loads/stores, which
    * later get lowered to function call parameters by `lower_rt_storage`.
    */
   NIR_PASS(_, nir, lower_rt_derefs_functions);

   /* Turn the copy_derefs inserted around calls into plain load/store pairs
    * so the deref lowering below only sees scalar accesses. */
   NIR_PASS(_, nir, nir_split_var_copies);
   NIR_PASS(_, nir, nir_lower_var_copies);

   NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, 0);
}
/* Return the impl of the shader's RT entrypoint: either the real entrypoint
 * or an exported function (separately-compiled RT stages are exported rather
 * than marked as entrypoints). Returns NULL if neither exists.
 */
nir_function_impl *
radv_get_rt_shader_entrypoint(nir_shader *shader)
{
   nir_foreach_function_impl (impl, shader) {
      nir_function *func = impl->function;
      if (func->is_entrypoint || func->is_exported)
         return impl;
   }

   return NULL;
}
/* Lower the RT ABI for the function-call compilation mode: set up the
 * entrypoint's call signature, replace RT intrinsics, and wire payload
 * storage to function parameters.
 */
void
radv_nir_lower_rt_abi_functions(nir_shader *shader, const struct radv_shader_info *info, uint32_t payload_size,
                                struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_function *entrypoint_function = impl->function;

   /* Assign the stage-specific function parameter layout before lowering,
    * so lower_rt_instruction can reference parameters by index. */
   radv_nir_init_rt_function_params(entrypoint_function, shader->info.stage, payload_size);

   struct rt_variables vars = create_rt_variables(shader, device, pipeline->base.base.create_flags, payload_size);
   nir_shader_instructions_pass(shader, lower_rt_instruction, nir_metadata_none, &vars);

   /* This can't use NIR_PASS because NIR_DEBUG=serialize,clone invalidates pointers. */
   nir_lower_returns(shader);

   /* The lowering above rewrote control flow and instructions; invalidate
    * all metadata. */
   nir_progress(true, impl, nir_metadata_none);

   /* Materialize the incoming payload storage: one dword-sized
    * shader_call_data deref per payload function parameter, built at the
    * top of the entrypoint. rzalloc zero-fills, so the array stays all-NULL
    * when the stage has no incoming payload parameters. */
   nir_builder b = nir_builder_at(nir_before_impl(impl));
   nir_deref_instr **payload_in_storage =
      rzalloc_array_size(shader, sizeof(nir_deref_instr *), DIV_ROUND_UP(payload_size, 4));
   if (vars.in_payload_base_param != -1u) {
      for (unsigned i = 0; i < DIV_ROUND_UP(payload_size, 4); ++i) {
         payload_in_storage[i] = nir_build_deref_cast(&b, nir_load_param(&b, vars.in_payload_base_param + i),
                                                      nir_var_shader_call_data, glsl_uint_type(), 4);
      }
   }

   /* Lower payload/hit-attrib storage to the derefs built above, then clean
    * up the now-dead derefs and variables. */
   NIR_PASS(_, shader, radv_nir_lower_rt_storage, NULL, payload_in_storage, vars.out_payload_storage, info->wave_size);

   NIR_PASS(_, shader, nir_remove_dead_derefs);
   NIR_PASS(_, shader, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_call_data, NULL);
   NIR_PASS(_, shader, nir_lower_global_vars_to_local);
   NIR_PASS(_, shader, nir_lower_vars_to_ssa);
}

View file

@ -0,0 +1,22 @@
/*
* Copyright © 2025 Valve Corporation
*
* SPDX-License-Identifier: MIT
*/
/* This file contains the public interface for all RT pipeline stage lowering. */
#ifndef RADV_NIR_RT_STAGE_FUNCTIONS_H
#define RADV_NIR_RT_STAGE_FUNCTIONS_H
#include "radv_pipeline_rt.h"
nir_function_impl *radv_get_rt_shader_entrypoint(nir_shader *shader);
void radv_nir_init_rt_function_params(nir_function *function, mesa_shader_stage stage, unsigned payload_size);
void radv_nir_lower_rt_abi_functions(nir_shader *shader, const struct radv_shader_info *info, uint32_t payload_size,
struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline);
void radv_nir_lower_rt_io_functions(nir_shader *shader);
#endif // RADV_NIR_RT_STAGE_FUNCTIONS_H

View file

@ -8,8 +8,11 @@
#include "nir/radv_nir_rt_common.h" #include "nir/radv_nir_rt_common.h"
#include "nir/radv_nir_rt_stage_common.h" #include "nir/radv_nir_rt_stage_common.h"
#include "nir/radv_nir_rt_stage_monolithic.h" #include "nir/radv_nir_rt_stage_monolithic.h"
#include "aco_nir_call_attribs.h"
#include "nir_builder.h" #include "nir_builder.h"
#include "radv_device.h" #include "radv_device.h"
#include "radv_nir_rt_stage_functions.h"
#include "radv_physical_device.h" #include "radv_physical_device.h"
struct chit_miss_inlining_params { struct chit_miss_inlining_params {
@ -20,6 +23,7 @@ struct chit_miss_inlining_params {
struct radv_nir_sbt_data *sbt; struct radv_nir_sbt_data *sbt;
unsigned payload_offset; unsigned payload_offset;
unsigned stack_base;
}; };
struct chit_miss_inlining_vars { struct chit_miss_inlining_vars {
@ -217,7 +221,13 @@ preprocess_shader_cb_monolithic(nir_shader *nir, void *_data)
{ {
uint32_t *payload_offset = _data; uint32_t *payload_offset = _data;
NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, *payload_offset); /* When compiling in monolithic mode, ray payloads are lowered to registers. Each ray payload gets a separate
* space in the register storage. Call nir_lower_vars_to_explicit_types to assign separate locations to payload
* variables, then lower to load/store instructions which will later be lowered to variable loads/stores.
*/
nir_lower_vars_to_explicit_types(nir, nir_var_function_temp, glsl_get_natural_size_align_bytes);
nir_lower_vars_to_explicit_types(nir, nir_var_shader_temp, glsl_get_natural_size_align_bytes);
radv_nir_lower_ray_payload_derefs(nir, *payload_offset);
} }
void void
@ -234,10 +244,6 @@ struct rt_variables {
uint32_t payload_offset; uint32_t payload_offset;
unsigned stack_size; unsigned stack_size;
nir_def *launch_sizes[3];
nir_def *launch_ids[3];
nir_def *shader_record_ptr;
nir_variable *stack_ptr; nir_variable *stack_ptr;
}; };
@ -255,19 +261,20 @@ radv_build_recursive_case(nir_builder *b, nir_def *idx, struct radv_ray_tracing_
.device = params->device, .device = params->device,
}; };
struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
setup_chit_miss_inlining(&vars, params, b, shader, var_remap);
nir_opt_dead_cf(shader); nir_opt_dead_cf(shader);
preprocess_shader_cb_monolithic(shader, &params->payload_offset); preprocess_shader_cb_monolithic(shader, &params->payload_offset);
struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
setup_chit_miss_inlining(&vars, params, b, shader, var_remap);
nir_shader_intrinsics_pass(shader, lower_rt_instruction_chit_miss, nir_metadata_control_flow, &vars); nir_shader_intrinsics_pass(shader, lower_rt_instruction_chit_miss, nir_metadata_control_flow, &vars);
nir_lower_returns(shader); nir_lower_returns(shader);
nir_opt_dce(shader); nir_opt_dce(shader);
radv_nir_inline_constants(b->shader, shader); radv_nir_inline_constants(b->shader, shader);
b->shader->scratch_size = MAX2(b->shader->scratch_size, shader->scratch_size + params->stack_base);
nir_push_if(b, nir_ieq_imm(b, idx, group->handle.general_index)); nir_push_if(b, nir_ieq_imm(b, idx, group->handle.general_index));
nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
@ -323,7 +330,7 @@ lower_rt_call_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void *data)
}; };
nir_def *stack_ptr = nir_load_var(b, vars->stack_ptr); nir_def *stack_ptr = nir_load_var(b, vars->stack_ptr);
nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, b->shader->scratch_size), 0x1); nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, vars->stack_size), 0x1);
struct radv_nir_rt_traversal_result result = struct radv_nir_rt_traversal_result result =
radv_build_traversal(state->device, state->pipeline, b, &params, NULL); radv_build_traversal(state->device, state->pipeline, b, &params, NULL);
@ -346,7 +353,8 @@ lower_rt_call_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void *data)
nir_push_if(b, nir_load_var(b, result.hit)); nir_push_if(b, nir_load_var(b, result.hit));
{ {
struct radv_nir_sbt_data hit_sbt = struct radv_nir_sbt_data hit_sbt =
radv_nir_load_sbt_entry(b, nir_load_var(b, result.sbt_index), SBT_HIT, SBT_CLOSEST_HIT_IDX); radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), nir_load_var(b, result.sbt_index),
SBT_HIT, SBT_CLOSEST_HIT_IDX);
inline_params.sbt = &hit_sbt; inline_params.sbt = &hit_sbt;
nir_def *should_return = nir_test_mask(b, params.cull_mask_and_flags, SpvRayFlagsSkipClosestHitShaderKHRMask); nir_def *should_return = nir_test_mask(b, params.cull_mask_and_flags, SpvRayFlagsSkipClosestHitShaderKHRMask);
@ -364,7 +372,8 @@ lower_rt_call_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void *data)
} }
nir_push_else(b, NULL); nir_push_else(b, NULL);
{ {
struct radv_nir_sbt_data miss_sbt = radv_nir_load_sbt_entry(b, params.miss_index, SBT_MISS, SBT_GENERAL_IDX); struct radv_nir_sbt_data miss_sbt = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS),
params.miss_index, SBT_MISS, SBT_GENERAL_IDX);
inline_params.sbt = &miss_sbt; inline_params.sbt = &miss_sbt;
radv_visit_inlined_shaders(b, miss_sbt.shader_addr, radv_visit_inlined_shaders(b, miss_sbt.shader_addr,
@ -394,15 +403,19 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void
switch (intr->intrinsic) { switch (intr->intrinsic) {
case nir_intrinsic_load_shader_record_ptr: { case nir_intrinsic_load_shader_record_ptr: {
nir_def_replace(&intr->def, vars->shader_record_ptr); nir_def_replace(&intr->def, nir_load_param(b, RAYGEN_ARG_SHADER_RECORD_PTR));
return true; return true;
} }
case nir_intrinsic_load_ray_launch_size: { case nir_intrinsic_load_ray_launch_size: {
nir_def_replace(&intr->def, nir_vec(b, vars->launch_sizes, 3)); nir_def_replace(&intr->def, nir_load_param(b, RT_ARG_LAUNCH_SIZE));
return true; return true;
}; };
case nir_intrinsic_load_ray_launch_id: { case nir_intrinsic_load_ray_launch_id: {
nir_def_replace(&intr->def, nir_vec(b, vars->launch_ids, 3)); nir_def_replace(&intr->def, nir_load_param(b, RT_ARG_LAUNCH_ID));
return true;
}
case nir_intrinsic_load_sbt_base_amd: {
nir_def_replace(&intr->def, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS));
return true; return true;
} }
case nir_intrinsic_load_scratch: { case nir_intrinsic_load_scratch: {
@ -428,30 +441,36 @@ radv_count_hit_attrib_slots(nir_builder *b, nir_intrinsic_instr *instr, void *da
return false; return false;
} }
static bool
radv_count_ray_payload_size(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
uint32_t *count = data;
if (instr->intrinsic == nir_intrinsic_load_incoming_ray_payload_amd ||
instr->intrinsic == nir_intrinsic_load_outgoing_ray_payload_amd ||
instr->intrinsic == nir_intrinsic_store_incoming_ray_payload_amd ||
instr->intrinsic == nir_intrinsic_store_outgoing_ray_payload_amd)
*count = MAX2(*count, (nir_intrinsic_base(instr) + 1) * 4);
return false;
}
void void
radv_nir_lower_rt_abi_monolithic(nir_shader *shader, const struct radv_shader_args *args, uint32_t *stack_size, radv_nir_lower_rt_abi_monolithic(nir_shader *shader, struct radv_device *device,
struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline) struct radv_ray_tracing_pipeline *pipeline)
{ {
nir_function_impl *impl = nir_shader_get_entrypoint(shader); nir_function_impl *impl = nir_shader_get_entrypoint(shader);
radv_nir_init_rt_function_params(impl->function, MESA_SHADER_RAYGEN, 0);
nir_builder b = nir_builder_at(nir_before_impl(impl)); nir_builder b = nir_builder_at(nir_before_impl(impl));
struct rt_variables vars = { struct rt_variables vars = {
.device = device, .device = device,
.flags = pipeline->base.base.create_flags, .flags = pipeline->base.base.create_flags,
.stack_size = b.shader->scratch_size,
}; };
vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr"); vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
for (uint32_t i = 0; i < ARRAY_SIZE(vars.launch_sizes); i++)
vars.launch_sizes[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_sizes[i]);
for (uint32_t i = 0; i < ARRAY_SIZE(vars.launch_sizes); i++) {
vars.launch_ids[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_ids[i]);
}
nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
vars.shader_record_ptr = nir_pack_64_2x32(&b, record_ptr);
nir_def *stack_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base);
nir_store_var(&b, vars.stack_ptr, stack_ptr, 0x1);
struct lower_rt_instruction_monolithic_state state = { struct lower_rt_instruction_monolithic_state state = {
.device = device, .device = device,
@ -464,17 +483,23 @@ radv_nir_lower_rt_abi_monolithic(nir_shader *shader, const struct radv_shader_ar
nir_index_ssa_defs(impl); nir_index_ssa_defs(impl);
uint32_t hit_attrib_count = 0; uint32_t hit_attrib_count = 0;
uint32_t payload_size = 0;
nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count); nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count);
nir_shader_intrinsics_pass(shader, radv_count_ray_payload_size, nir_metadata_all, &payload_size);
/* Register storage for hit attributes */ /* Register storage for hit attributes */
STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count); STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count);
STACK_ARRAY(nir_variable *, payload_vars, DIV_ROUND_UP(payload_size, 4));
STACK_ARRAY(nir_deref_instr *, payload_derefs, DIV_ROUND_UP(payload_size, 4));
for (uint32_t i = 0; i < hit_attrib_count; i++) for (uint32_t i = 0; i < hit_attrib_count; i++)
hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib"); hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib");
radv_nir_lower_hit_attribs(shader, hit_attribs, 0); b.cursor = nir_before_impl(impl);
for (uint32_t i = 0; i < DIV_ROUND_UP(payload_size, 4); i++) {
payload_vars[i] = nir_local_variable_create(impl, glsl_uint_type(), "payload_storage");
payload_derefs[i] = nir_build_deref_var(&b, payload_vars[i]);
}
vars.stack_size = MAX2(vars.stack_size, shader->scratch_size); radv_nir_lower_rt_storage(shader, hit_attribs, payload_derefs, payload_vars, 0);
*stack_size = MAX2(*stack_size, vars.stack_size);
shader->scratch_size = 0;
nir_progress(true, impl, nir_metadata_none); nir_progress(true, impl, nir_metadata_none);

View file

@ -11,8 +11,8 @@
#include "radv_pipeline_rt.h" #include "radv_pipeline_rt.h"
void radv_nir_lower_rt_abi_monolithic(nir_shader *shader, const struct radv_shader_args *args, uint32_t *stack_size, void radv_nir_lower_rt_abi_monolithic(nir_shader *shader, struct radv_device *device,
struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline); struct radv_ray_tracing_pipeline *pipeline);
void radv_nir_lower_rt_io_monolithic(nir_shader *shader); void radv_nir_lower_rt_io_monolithic(nir_shader *shader);
#endif // RADV_NIR_RT_STAGE_MONOLITHIC_H #endif // RADV_NIR_RT_STAGE_MONOLITHIC_H

View file

@ -385,6 +385,7 @@ insert_inlined_shader(nir_builder *b, struct traversal_inlining_params *params,
nir_opt_dce(shader); nir_opt_dce(shader);
radv_nir_inline_constants(b->shader, shader); radv_nir_inline_constants(b->shader, shader);
b->shader->scratch_size = MAX2(b->shader->scratch_size, shader->scratch_size);
nir_push_if(b, nir_ieq_imm(b, idx, call_idx)); nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
@ -607,6 +608,9 @@ nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
/* Eliminate the casts introduced for the commit return of the any-hit shader. */ /* Eliminate the casts introduced for the commit return of the any-hit shader. */
NIR_PASS(_, intersection, nir_opt_deref); NIR_PASS(_, intersection, nir_opt_deref);
/* Reflect the scratch memory required by the inlined any-hit shader */
intersection->scratch_size += any_hit->scratch_size;
ralloc_free(dead_ctx); ralloc_free(dead_ctx);
} }
@ -666,9 +670,6 @@ radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g
params->preprocess(any_hit_stage, params->preprocess_data); params->preprocess(any_hit_stage, params->preprocess_data);
/* reserve stack size for any_hit before it is inlined */
data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
nir_lower_intersection_shader(nir_stage, any_hit_stage); nir_lower_intersection_shader(nir_stage, any_hit_stage);
ralloc_free(any_hit_stage); ralloc_free(any_hit_stage);
} }
@ -857,7 +858,8 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
nir_push_if(b, nir_inot(b, intersection->base.opaque)); nir_push_if(b, nir_inot(b, intersection->base.opaque));
{ {
struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX); struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1); nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
struct traversal_inlining_params inlining_params = { struct traversal_inlining_params inlining_params = {
@ -942,7 +944,8 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
nir_store_var(b, data->trav_vars.ahit_isec_count, nir_store_var(b, data->trav_vars.ahit_isec_count,
nir_iadd_imm(b, nir_load_var(b, data->trav_vars.ahit_isec_count), 1 << 16), 0x1); nir_iadd_imm(b, nir_load_var(b, data->trav_vars.ahit_isec_count), 1 << 16), 0x1);
struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX); struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX);
nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1); nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
struct traversal_inlining_params inlining_params = { struct traversal_inlining_params inlining_params = {
@ -1120,22 +1123,22 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
if (device->rra_trace.ray_history_addr) if (device->rra_trace.ray_history_addr)
radv_build_end_trace_token(b, &data, nir_load_var(b, iteration_instance_count)); radv_build_end_trace_token(b, &data, nir_load_var(b, iteration_instance_count));
nir_progress(true, nir_shader_get_entrypoint(b->shader), nir_metadata_none); nir_progress(true, b->impl, nir_metadata_none);
radv_nir_lower_hit_attrib_derefs(b->shader); radv_nir_lower_hit_attrib_derefs(b->shader);
return data.trav_vars.result; return data.trav_vars.result;
} }
static void static void
preprocess_traversal_shader_ahit_isec(nir_shader *nir, void *_) preprocess_traversal_shader_ahit_isec(nir_shader *nir, void *cb)
{ {
/* Compiling a separate traversal shader is always done in CPS mode. */ radv_nir_traversal_preprocess_cb preprocess_cb = cb;
radv_nir_lower_rt_io_cps(nir); preprocess_cb(nir);
} }
nir_shader * nir_shader *
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
struct radv_ray_tracing_stage_info *info) struct radv_ray_tracing_stage_info *info, radv_nir_traversal_preprocess_cb preprocess)
{ {
const struct radv_physical_device *pdev = radv_device_physical(device); const struct radv_physical_device *pdev = radv_device_physical(device);
@ -1184,11 +1187,12 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
params.direction = nir_load_ray_world_direction(&b); params.direction = nir_load_ray_world_direction(&b);
params.preprocess_ahit_isec = preprocess_traversal_shader_ahit_isec; params.preprocess_ahit_isec = preprocess_traversal_shader_ahit_isec;
params.cb_data = preprocess;
params.ignore_cull_mask = false; params.ignore_cull_mask = false;
struct radv_nir_rt_traversal_result result = radv_build_traversal(device, pipeline, &b, &params, info); struct radv_nir_rt_traversal_result result = radv_build_traversal(device, pipeline, &b, &params, info);
radv_nir_lower_hit_attribs(b.shader, hit_attribs, pdev->rt_wave_size); radv_nir_lower_rt_storage(b.shader, hit_attribs, NULL, NULL, pdev->rt_wave_size);
nir_push_if(&b, nir_load_var(&b, result.hit)); nir_push_if(&b, nir_load_var(&b, result.hit));
{ {

View file

@ -9,7 +9,10 @@
#include "radv_pipeline_rt.h" #include "radv_pipeline_rt.h"
typedef void (*radv_nir_traversal_preprocess_cb)(nir_shader *nir);
nir_shader *radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, nir_shader *radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
struct radv_ray_tracing_stage_info *info); struct radv_ray_tracing_stage_info *info,
radv_nir_traversal_preprocess_cb preprocess);
#endif // RADV_NIR_RT_TRAVERSAL_SHADER_H #endif // RADV_NIR_RT_TRAVERSAL_SHADER_H

View file

@ -8186,8 +8186,7 @@ radv_emit_ray_tracing_pipeline(struct radv_cmd_buffer *cmd_buffer, struct radv_r
const uint32_t traversal_shader_addr_offset = radv_get_user_sgpr_loc(rt_prolog, AC_UD_CS_TRAVERSAL_SHADER_ADDR); const uint32_t traversal_shader_addr_offset = radv_get_user_sgpr_loc(rt_prolog, AC_UD_CS_TRAVERSAL_SHADER_ADDR);
struct radv_shader *traversal_shader = cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION]; struct radv_shader *traversal_shader = cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION];
if (traversal_shader_addr_offset && traversal_shader) { if (traversal_shader_addr_offset && traversal_shader) {
uint64_t traversal_va = traversal_shader->va | radv_rt_priority_traversal; uint64_t traversal_va = traversal_shader->va;
radeon_begin(cs); radeon_begin(cs);
if (pdev->info.gfx_level >= GFX12) { if (pdev->info.gfx_level >= GFX12) {
gfx12_push_32bit_pointer(traversal_shader_addr_offset, traversal_va, &pdev->info); gfx12_push_32bit_pointer(traversal_shader_addr_offset, traversal_va, &pdev->info);
@ -13740,10 +13739,9 @@ radv_emit_rt_stack_size(struct radv_cmd_buffer *cmd_buffer)
unsigned rsrc2 = rt_prolog->config.rsrc2; unsigned rsrc2 = rt_prolog->config.rsrc2;
/* Reserve scratch for stacks manually since it is not handled by the compute path. */ /* Reserve scratch for stacks manually since it is not handled by the compute path. */
uint32_t scratch_bytes_per_wave = rt_prolog->config.scratch_bytes_per_wave;
const uint32_t wave_size = rt_prolog->info.wave_size; const uint32_t wave_size = rt_prolog->info.wave_size;
scratch_bytes_per_wave += uint32_t scratch_bytes_per_wave =
align(cmd_buffer->state.rt_stack_size * wave_size, pdev->info.scratch_wavesize_granularity); align(cmd_buffer->state.rt_stack_size * wave_size, pdev->info.scratch_wavesize_granularity);
cmd_buffer->compute_scratch_size_per_wave_needed = cmd_buffer->compute_scratch_size_per_wave_needed =

View file

@ -13,6 +13,7 @@
#include "nir/radv_nir.h" #include "nir/radv_nir.h"
#include "nir/radv_nir_rt_stage_cps.h" #include "nir/radv_nir_rt_stage_cps.h"
#include "nir/radv_nir_rt_stage_functions.h"
#include "nir/radv_nir_rt_stage_monolithic.h" #include "nir/radv_nir_rt_stage_monolithic.h"
#include "nir/radv_nir_rt_traversal_shader.h" #include "nir/radv_nir_rt_traversal_shader.h"
#include "ac_nir.h" #include "ac_nir.h"
@ -323,7 +324,6 @@ should_move_rt_instruction(nir_intrinsic_instr *instr)
switch (instr->intrinsic) { switch (instr->intrinsic) {
case nir_intrinsic_load_hit_attrib_amd: case nir_intrinsic_load_hit_attrib_amd:
return nir_intrinsic_base(instr) < RADV_MAX_HIT_ATTRIB_DWORDS; return nir_intrinsic_base(instr) < RADV_MAX_HIT_ATTRIB_DWORDS;
case nir_intrinsic_load_rt_arg_scratch_offset_amd:
case nir_intrinsic_load_ray_flags: case nir_intrinsic_load_ray_flags:
case nir_intrinsic_load_ray_object_origin: case nir_intrinsic_load_ray_object_origin:
case nir_intrinsic_load_ray_world_origin: case nir_intrinsic_load_ray_world_origin:
@ -364,7 +364,7 @@ move_rt_instructions(nir_shader *shader)
static VkResult static VkResult
radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode, struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode,
struct radv_shader_stage *stage, uint32_t *stack_size, struct radv_shader_stage *stage, uint32_t *payload_size, uint32_t *stack_size,
struct radv_ray_tracing_stage_info *stage_info, struct radv_ray_tracing_stage_info *stage_info,
const struct radv_ray_tracing_stage_info *traversal_stage_info, const struct radv_ray_tracing_stage_info *traversal_stage_info,
struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache, struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache,
@ -384,6 +384,9 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
case RADV_RT_LOWERING_MODE_CPS: case RADV_RT_LOWERING_MODE_CPS:
radv_nir_lower_rt_io_cps(stage->nir); radv_nir_lower_rt_io_cps(stage->nir);
break; break;
case RADV_RT_LOWERING_MODE_FUNCTION_CALLS:
radv_nir_lower_rt_io_functions(stage->nir);
break;
} }
/* Gather shader info. */ /* Gather shader info. */
@ -440,29 +443,39 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
switch (mode) { switch (mode) {
case RADV_RT_LOWERING_MODE_MONOLITHIC: case RADV_RT_LOWERING_MODE_MONOLITHIC:
assert(num_shaders == 1); assert(num_shaders == 1);
radv_nir_lower_rt_abi_monolithic(temp_stage.nir, &temp_stage.args, stack_size, device, pipeline); radv_nir_lower_rt_abi_monolithic(temp_stage.nir, device, pipeline);
break; break;
case RADV_RT_LOWERING_MODE_CPS: case RADV_RT_LOWERING_MODE_CPS:
radv_nir_lower_rt_abi_cps(temp_stage.nir, &temp_stage.args, &stage->info, stack_size, i > 0, device, pipeline, radv_nir_lower_rt_abi_cps(temp_stage.nir, &stage->info, i > 0, device, pipeline, has_position_fetch,
has_position_fetch, traversal_stage_info); traversal_stage_info);
break;
case RADV_RT_LOWERING_MODE_FUNCTION_CALLS:
assert(num_shaders == 1);
radv_nir_lower_rt_abi_functions(temp_stage.nir, &temp_stage.info, *payload_size, device, pipeline);
break; break;
} }
/* Info might be out-of-date after inlining in radv_nir_lower_rt_abi(). */ /* Info might be out-of-date after inlining in radv_nir_lower_rt_abi(). */
nir_shader_gather_info(temp_stage.nir, nir_shader_get_entrypoint(temp_stage.nir)); nir_shader_gather_info(temp_stage.nir, radv_get_rt_shader_entrypoint(temp_stage.nir));
radv_nir_shader_info_pass(device, temp_stage.nir, &stage->layout, &stage->key, NULL, RADV_PIPELINE_RAY_TRACING, radv_nir_shader_info_pass(device, temp_stage.nir, &stage->layout, &stage->key, NULL, RADV_PIPELINE_RAY_TRACING,
false, &stage->info); false, &stage->info);
radv_optimize_nir(temp_stage.nir, stage->key.optimisations_disabled); radv_optimize_nir(temp_stage.nir, temp_stage.key.optimisations_disabled);
radv_postprocess_nir(device, NULL, &temp_stage); radv_postprocess_nir(device, NULL, &temp_stage);
stage->info.nir_shared_size = MAX2(stage->info.nir_shared_size, temp_stage.info.nir_shared_size);
if (stage_info) NIR_PASS(_, stage->nir, radv_nir_lower_call_abi, stage->info.wave_size);
radv_gather_unused_args(stage_info, shaders[i]); NIR_PASS(_, stage->nir, nir_lower_global_vars_to_local);
NIR_PASS(_, stage->nir, nir_lower_vars_to_ssa);
if (!stage->key.optimisations_disabled)
NIR_PASS(_, stage->nir, nir_minimize_call_live_states);
stage->info.nir_shared_size = MAX2(stage->info.nir_shared_size, temp_stage.info.nir_shared_size);
if (stage_info && mode == RADV_RT_LOWERING_MODE_CPS)
radv_gather_unused_args(stage_info, temp_stage.nir);
} }
bool dump_shader = radv_can_dump_shader(device, shaders[0]); bool dump_shader = radv_can_dump_shader(device, stage->nir);
bool dump_nir = dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR); bool dump_nir = dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR);
bool replayable = (pipeline->base.base.create_flags & bool replayable = (pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) && VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) &&
@ -493,7 +506,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
if (dump_shader) if (dump_shader)
simple_mtx_unlock(&instance->shader_dump_mtx); simple_mtx_unlock(&instance->shader_dump_mtx);
ralloc_free(mem_ctx);
free(binary); free(binary);
return result; return result;
} }
@ -503,6 +515,9 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
if (shader) { if (shader) {
shader->nir_string = nir_string; shader->nir_string = nir_string;
if (stack_size)
*stack_size = DIV_ROUND_UP(shader->config.scratch_bytes_per_wave, shader->info.wave_size);
radv_shader_dump_debug_info(device, dump_shader, binary, shader, shaders, num_shaders, &stage->info); radv_shader_dump_debug_info(device, dump_shader, binary, shader, shaders, num_shaders, &stage->info);
if (shader && keep_executable_info && stage->spirv.size) { if (shader && keep_executable_info && stage->spirv.size) {
@ -515,7 +530,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
if (dump_shader) if (dump_shader)
simple_mtx_unlock(&instance->shader_dump_mtx); simple_mtx_unlock(&instance->shader_dump_mtx);
ralloc_free(mem_ctx);
free(binary); free(binary);
*out_shader = shader; *out_shader = shader;
@ -633,6 +647,10 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
if (!stages) if (!stages)
return VK_ERROR_OUT_OF_HOST_MEMORY; return VK_ERROR_OUT_OF_HOST_MEMORY;
uint32_t payload_size = 0;
if (pCreateInfo->pLibraryInterface)
payload_size = pCreateInfo->pLibraryInterface->maxPipelineRayPayloadSize;
bool library = pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR; bool library = pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR;
/* Beyond 50 shader stages, inlining everything bloats the shader a ton, increasing compile times and /* Beyond 50 shader stages, inlining everything bloats the shader a ton, increasing compile times and
@ -657,6 +675,19 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs); NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs);
nir_foreach_variable_with_modes (var, stage->nir, nir_var_shader_call_data) {
unsigned size, alignment;
glsl_get_natural_size_align_bytes(var->type, &size, &alignment);
payload_size = MAX2(payload_size, size);
}
nir_foreach_function_impl (impl, stage->nir) {
nir_foreach_variable_in_list (var, &impl->locals) {
unsigned size, alignment;
glsl_get_natural_size_align_bytes(var->type, &size, &alignment);
payload_size = MAX2(payload_size, size);
}
}
rt_stages[i].info = radv_gather_ray_tracing_stage_info(stage->nir); rt_stages[i].info = radv_gather_ray_tracing_stage_info(stage->nir);
stage->feedback.duration = os_time_get_nano() - stage_start; stage->feedback.duration = os_time_get_nano() - stage_start;
@ -736,8 +767,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
enum radv_rt_lowering_mode mode = enum radv_rt_lowering_mode mode =
stage->stage == MESA_SHADER_RAYGEN ? raygen_lowering_mode : recursive_lowering_mode; stage->stage == MESA_SHADER_RAYGEN ? raygen_lowering_mode : recursive_lowering_mode;
result = radv_rt_nir_to_asm(device, cache, pipeline, mode, stage, &stack_size, &rt_stages[idx].info, NULL, result =
replay_block, skip_shaders_cache, has_position_fetch, &rt_stages[idx].shader); radv_rt_nir_to_asm(device, cache, pipeline, mode, stage, &payload_size, &stack_size, &rt_stages[idx].info,
NULL, replay_block, skip_shaders_cache, has_position_fetch, &rt_stages[idx].shader);
if (result != VK_SUCCESS) if (result != VK_SUCCESS)
goto cleanup; goto cleanup;
@ -787,17 +819,20 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
traversal_info.unset_flags &= info->unset_flags; traversal_info.unset_flags &= info->unset_flags;
} }
radv_nir_traversal_preprocess_cb preprocess =
recursive_lowering_mode == RADV_RT_LOWERING_MODE_CPS ? radv_nir_lower_rt_io_cps : radv_nir_lower_rt_io_functions;
/* create traversal shader */ /* create traversal shader */
nir_shader *traversal_nir = radv_build_traversal_shader(device, pipeline, &traversal_info); nir_shader *traversal_nir = radv_build_traversal_shader(device, pipeline, &traversal_info, preprocess);
struct radv_shader_stage traversal_stage = { struct radv_shader_stage traversal_stage = {
.stage = MESA_SHADER_INTERSECTION, .stage = MESA_SHADER_INTERSECTION,
.nir = traversal_nir, .nir = traversal_nir,
.key = stage_keys[MESA_SHADER_INTERSECTION], .key = stage_keys[MESA_SHADER_INTERSECTION],
}; };
radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout); radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout);
result = radv_rt_nir_to_asm(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, NULL, NULL, result = radv_rt_nir_to_asm(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, &payload_size,
&traversal_info, NULL, skip_shaders_cache, has_position_fetch, &pipeline->traversal_stack_size, NULL, &traversal_info, NULL, skip_shaders_cache,
&pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]); has_position_fetch, &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
ralloc_free(traversal_nir); ralloc_free(traversal_nir);
cleanup: cleanup:
@ -858,10 +893,11 @@ compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, stru
UNREACHABLE("Invalid stage type in RT shader"); UNREACHABLE("Invalid stage type in RT shader");
} }
} }
pipeline->stack_size = pipeline->stack_size = raygen_size +
raygen_size + MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) * MAX2(chit_miss_size, intersection_size + any_hit_size) + (chit_miss_size + intersection_size + any_hit_size + pipeline->traversal_stack_size) +
MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * chit_miss_size + 2 * callable_size; MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * chit_miss_size +
2 * callable_size;
} }
static void static void
@ -1216,7 +1252,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, const VkRayTra
if (pipeline->groups[i].recursive_shader != VK_SHADER_UNUSED_KHR) { if (pipeline->groups[i].recursive_shader != VK_SHADER_UNUSED_KHR) {
struct radv_shader *shader = pipeline->stages[pipeline->groups[i].recursive_shader].shader; struct radv_shader *shader = pipeline->stages[pipeline->groups[i].recursive_shader].shader;
if (shader) if (shader)
pipeline->groups[i].handle.recursive_shader_ptr = shader->va | radv_get_rt_priority(shader->info.stage); pipeline->groups[i].handle.recursive_shader_ptr = shader->va;
} }
} }

View file

@ -12,6 +12,7 @@
#define RADV_PIPELINE_RT_H #define RADV_PIPELINE_RT_H
#include "util/bitset.h" #include "util/bitset.h"
#include "aco_nir_call_attribs.h"
#include "radv_pipeline_compute.h" #include "radv_pipeline_compute.h"
#include "radv_shader.h" #include "radv_shader.h"
@ -27,6 +28,7 @@ struct radv_ray_tracing_pipeline {
unsigned group_count; unsigned group_count;
uint32_t stack_size; uint32_t stack_size;
uint32_t traversal_stack_size;
/* set if any shaders from this pipeline require robustness2 in the merged traversal shader */ /* set if any shaders from this pipeline require robustness2 in the merged traversal shader */
bool traversal_storage_robustness2 : 1; bool traversal_storage_robustness2 : 1;
@ -76,7 +78,7 @@ struct radv_ray_tracing_stage_info {
bool can_inline; bool can_inline;
bool has_position_fetch; bool has_position_fetch;
BITSET_DECLARE(unused_args, AC_MAX_ARGS); BITSET_DECLARE(unused_args, CPS_ARG_COUNT);
struct radv_rt_const_arg_info tmin; struct radv_rt_const_arg_info tmin;
struct radv_rt_const_arg_info tmax; struct radv_rt_const_arg_info tmax;

View file

@ -45,6 +45,7 @@
#include "vk_sync.h" #include "vk_sync.h"
#include "vk_ycbcr_conversion.h" #include "vk_ycbcr_conversion.h"
#include "nir/radv_nir_rt_stage_functions.h"
#include "aco_shader_info.h" #include "aco_shader_info.h"
#include "radv_aco_shader_info.h" #include "radv_aco_shader_info.h"
#if AMD_LLVM_AVAILABLE #if AMD_LLVM_AVAILABLE
@ -220,7 +221,9 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively)
NIR_PASS(progress, shader, nir_opt_move, nir_move_load_ubo); NIR_PASS(progress, shader, nir_opt_move, nir_move_load_ubo);
nir_shader_gather_info(shader, nir_shader_get_entrypoint(shader)); /* radv_get_rt_shader_entrypoint returns the entrypoint for non-RT shaders too. */
nir_function_impl *entrypoint = radv_get_rt_shader_entrypoint(shader);
nir_shader_gather_info(shader, entrypoint);
} }
void void

View file

@ -654,35 +654,9 @@ void radv_get_nir_options(struct radv_physical_device *pdev);
enum radv_rt_lowering_mode { enum radv_rt_lowering_mode {
RADV_RT_LOWERING_MODE_MONOLITHIC, RADV_RT_LOWERING_MODE_MONOLITHIC,
RADV_RT_LOWERING_MODE_CPS, RADV_RT_LOWERING_MODE_CPS,
RADV_RT_LOWERING_MODE_FUNCTION_CALLS,
}; };
enum radv_rt_priority {
radv_rt_priority_raygen = 0,
radv_rt_priority_traversal = 1,
radv_rt_priority_hit_miss = 2,
radv_rt_priority_callable = 3,
radv_rt_priority_mask = 0x3,
};
static inline enum radv_rt_priority
radv_get_rt_priority(mesa_shader_stage stage)
{
switch (stage) {
case MESA_SHADER_RAYGEN:
return radv_rt_priority_raygen;
case MESA_SHADER_INTERSECTION:
case MESA_SHADER_ANY_HIT:
return radv_rt_priority_traversal;
case MESA_SHADER_CLOSEST_HIT:
case MESA_SHADER_MISS:
return radv_rt_priority_hit_miss;
case MESA_SHADER_CALLABLE:
return radv_rt_priority_callable;
default:
UNREACHABLE("Unimplemented RT shader stage.");
}
}
struct radv_shader_layout; struct radv_shader_layout;
enum radv_pipeline_type; enum radv_pipeline_type;

View file

@ -320,50 +320,6 @@ radv_init_shader_args(const struct radv_device *device, mesa_shader_stage stage,
args->user_sgprs_locs.shader_data[i].sgpr_idx = -1; args->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
} }
void
radv_declare_rt_shader_args(enum amd_gfx_level gfx_level, struct radv_shader_args *args)
{
add_ud_arg(args, 2, AC_ARG_CONST_ADDR, &args->ac.rt.uniform_shader_addr, AC_UD_SCRATCH_RING_OFFSETS);
add_ud_arg(args, 1, AC_ARG_CONST_ADDR, &args->descriptors[0], AC_UD_INDIRECT_DESCRIPTORS);
ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_ADDR, &args->ac.push_constants);
ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_ADDR, &args->ac.dynamic_descriptors);
ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_ADDR, &args->ac.rt.traversal_shader_addr);
ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.sbt_descriptors);
for (uint32_t i = 0; i < ARRAY_SIZE(args->ac.rt.launch_sizes); i++)
ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_VALUE, &args->ac.rt.launch_sizes[i]);
if (gfx_level < GFX9) {
ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_VALUE, &args->ac.scratch_offset);
ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_ADDR, &args->ac.ring_offsets);
}
for (uint32_t i = 0; i < ARRAY_SIZE(args->ac.rt.launch_ids); i++)
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.launch_ids[i]);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.dynamic_callable_stack_base);
ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.shader_addr);
ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.shader_record);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.payload_offset);
ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_VALUE, &args->ac.rt.ray_origin);
ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_VALUE, &args->ac.rt.ray_direction);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.ray_tmin);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.ray_tmax);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.cull_mask_and_flags);
ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.accel_struct);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.sbt_offset);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.sbt_stride);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.miss_index);
ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.instance_addr);
ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.primitive_addr);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.primitive_id);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.geometry_id_and_flags);
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.hit_kind);
}
static bool static bool
radv_tcs_needs_state_sgpr(const struct radv_shader_info *info, const struct radv_graphics_state_key *gfx_state) radv_tcs_needs_state_sgpr(const struct radv_shader_info *info, const struct radv_graphics_state_key *gfx_state)
{ {
@ -563,7 +519,6 @@ declare_shader_args(const struct radv_device *device, const struct radv_graphics
radv_init_shader_args(device, stage, args); radv_init_shader_args(device, stage, args);
if (mesa_shader_stage_is_rt(stage)) { if (mesa_shader_stage_is_rt(stage)) {
radv_declare_rt_shader_args(gfx_level, args);
return; return;
} }

View file

@ -950,7 +950,6 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_load_packed_passthrough_primitive_amd: case nir_intrinsic_load_packed_passthrough_primitive_amd:
case nir_intrinsic_load_initial_edgeflags_amd: case nir_intrinsic_load_initial_edgeflags_amd:
case nir_intrinsic_gds_atomic_add_amd: case nir_intrinsic_gds_atomic_add_amd:
case nir_intrinsic_load_rt_arg_scratch_offset_amd:
case nir_intrinsic_load_intersection_opaque_amd: case nir_intrinsic_load_intersection_opaque_amd:
case nir_intrinsic_load_vector_arg_amd: case nir_intrinsic_load_vector_arg_amd:
case nir_intrinsic_load_btd_stack_id_intel: case nir_intrinsic_load_btd_stack_id_intel:
@ -1024,6 +1023,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_load_blend_input_pan: case nir_intrinsic_load_blend_input_pan:
case nir_intrinsic_atest_pan: case nir_intrinsic_atest_pan:
case nir_intrinsic_zs_emit_pan: case nir_intrinsic_zs_emit_pan:
case nir_intrinsic_load_return_param_amd:
is_divergent = true; is_divergent = true;
break; break;

View file

@ -2152,6 +2152,13 @@ intrinsic("sleep_amd", indices=[BASE])
# s_nop BASE (sleep for BASE+1 cycles, BASE must be in [0, 15]). # s_nop BASE (sleep for BASE+1 cycles, BASE must be in [0, 15]).
intrinsic("nop_amd", indices=[BASE]) intrinsic("nop_amd", indices=[BASE])
intrinsic("store_param_amd", src_comp=[-1], indices=[PARAM_IDX])
intrinsic("load_return_param_amd", dest_comp=0, indices=[CALL_IDX, PARAM_IDX])
system_value("call_return_address_amd", 1, bit_sizes=[64])
# src[0] is the divergent call target for each lane, src[1] is the (uniform) address to jump to next
intrinsic("set_next_call_pc_amd", src_comp=[1, 1], bit_sizes=[64])
# Return the FMASK descriptor of color buffer 0. # Return the FMASK descriptor of color buffer 0.
system_value("fbfetch_image_fmask_desc_amd", 8) system_value("fbfetch_image_fmask_desc_amd", 8)
# Return the image descriptor of color buffer 0. # Return the image descriptor of color buffer 0.