diff --git a/src/amd/common/ac_shader_args.h b/src/amd/common/ac_shader_args.h index a483a3dcdc9..32a73ad5618 100644 --- a/src/amd/common/ac_shader_args.h +++ b/src/amd/common/ac_shader_args.h @@ -198,20 +198,6 @@ struct ac_shader_args { struct ac_arg shader_addr; struct ac_arg shader_record; struct ac_arg payload_offset; - struct ac_arg ray_origin; - struct ac_arg ray_tmin; - struct ac_arg ray_direction; - struct ac_arg ray_tmax; - struct ac_arg cull_mask_and_flags; - struct ac_arg sbt_offset; - struct ac_arg sbt_stride; - struct ac_arg miss_index; - struct ac_arg accel_struct; - struct ac_arg primitive_id; - struct ac_arg instance_addr; - struct ac_arg primitive_addr; - struct ac_arg geometry_id_and_flags; - struct ac_arg hit_kind; } rt; }; diff --git a/src/amd/compiler/aco_nir_call_attribs.h b/src/amd/compiler/aco_nir_call_attribs.h index 33dc011914c..3d46ab88671 100644 --- a/src/amd/compiler/aco_nir_call_attribs.h +++ b/src/amd/compiler/aco_nir_call_attribs.h @@ -26,4 +26,87 @@ enum aco_nir_parameter_attribs { ACO_NIR_PARAM_ATTRIB_DISCARDABLE = 0x1, }; +enum aco_nir_call_system_args { + ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC, + ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC, + ACO_NIR_CALL_SYSTEM_ARG_COUNT, +}; + +enum aco_nir_rt_function_arg { + RT_ARG_LAUNCH_ID = 0, + RT_ARG_LAUNCH_SIZE, + RT_ARG_DESCRIPTORS, + RT_ARG_DYNAMIC_DESCRIPTORS, + RT_ARG_PUSH_CONSTANTS, + RT_ARG_SBT_DESCRIPTORS, + RT_ARG_COUNT, +}; + +enum aco_nir_raygen_function_arg { + RAYGEN_ARG_TRAVERSAL_ADDR = RT_ARG_COUNT, + RAYGEN_ARG_SHADER_RECORD_PTR, + RAYGEN_ARG_COUNT, +}; + +enum aco_nir_traversal_function_arg { + TRAVERSAL_ARG_TRAVERSAL_ADDR = RT_ARG_COUNT, + TRAVERSAL_ARG_SHADER_RECORD_PTR, + TRAVERSAL_ARG_ACCEL_STRUCT, + TRAVERSAL_ARG_CULL_MASK_AND_FLAGS, + TRAVERSAL_ARG_SBT_OFFSET, + TRAVERSAL_ARG_SBT_STRIDE, + TRAVERSAL_ARG_MISS_INDEX, + TRAVERSAL_ARG_RAY_ORIGIN, + TRAVERSAL_ARG_RAY_TMIN, + TRAVERSAL_ARG_RAY_DIRECTION, + TRAVERSAL_ARG_RAY_TMAX, + TRAVERSAL_ARG_PRIMITIVE_ADDR, + 
TRAVERSAL_ARG_PRIMITIVE_ID, + TRAVERSAL_ARG_INSTANCE_ADDR, + TRAVERSAL_ARG_GEOMETRY_ID_AND_FLAGS, + TRAVERSAL_ARG_HIT_KIND, + TRAVERSAL_ARG_PAYLOAD_BASE, +}; + +enum aco_nir_chit_miss_function_arg { + CHIT_MISS_ARG_TRAVERSAL_ADDR = RT_ARG_COUNT, + CHIT_MISS_ARG_SHADER_RECORD_PTR, + CHIT_MISS_ARG_ACCEL_STRUCT, + CHIT_MISS_ARG_CULL_MASK_AND_FLAGS, + CHIT_MISS_ARG_SBT_OFFSET, + CHIT_MISS_ARG_SBT_STRIDE, + CHIT_MISS_ARG_MISS_INDEX, + CHIT_MISS_ARG_RAY_ORIGIN, + CHIT_MISS_ARG_RAY_TMIN, + CHIT_MISS_ARG_RAY_DIRECTION, + CHIT_MISS_ARG_RAY_TMAX, + CHIT_MISS_ARG_PRIMITIVE_ADDR, + CHIT_MISS_ARG_PRIMITIVE_ID, + CHIT_MISS_ARG_INSTANCE_ADDR, + CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS, + CHIT_MISS_ARG_HIT_KIND, + CHIT_MISS_ARG_PAYLOAD_BASE, +}; + +/* aco_nir_cps_function_arg extends aco_nir_raygen_function_arg */ +enum aco_nir_cps_function_arg { + CPS_ARG_PAYLOAD_SCRATCH_OFFSET = RAYGEN_ARG_COUNT, + CPS_ARG_STACK_PTR, + CPS_ARG_ACCEL_STRUCT, + CPS_ARG_CULL_MASK_AND_FLAGS, + CPS_ARG_SBT_OFFSET, + CPS_ARG_SBT_STRIDE, + CPS_ARG_MISS_INDEX, + CPS_ARG_RAY_ORIGIN, + CPS_ARG_RAY_TMIN, + CPS_ARG_RAY_DIRECTION, + CPS_ARG_RAY_TMAX, + CPS_ARG_PRIMITIVE_ID, + CPS_ARG_INSTANCE_ADDR, + CPS_ARG_GEOMETRY_ID_AND_FLAGS, + CPS_ARG_HIT_KIND, + CPS_ARG_PRIMITIVE_ADDR, + CPS_ARG_COUNT, +}; + #endif /* ACO_NIR_CALL_ATTRIBS_H */ diff --git a/src/amd/vulkan/meson.build b/src/amd/vulkan/meson.build index ddb9b56d3d6..3a20baca003 100644 --- a/src/amd/vulkan/meson.build +++ b/src/amd/vulkan/meson.build @@ -71,6 +71,7 @@ libradv_files = files( 'nir/radv_nir_apply_pipeline_layout.c', 'nir/radv_nir_export_multiview.c', 'nir/radv_nir_lower_abi.c', + 'nir/radv_nir_lower_call_abi.c', 'nir/radv_nir_lower_cooperative_matrix.c', 'nir/radv_nir_lower_fs_barycentric.c', 'nir/radv_nir_lower_fs_intrinsics.c', @@ -90,6 +91,7 @@ libradv_files = files( 'nir/radv_nir_rt_common.c', 'nir/radv_nir_rt_stage_common.c', 'nir/radv_nir_rt_stage_cps.c', + 'nir/radv_nir_rt_stage_functions.c', 'nir/radv_nir_rt_stage_monolithic.c', 
'nir/radv_nir_rt_traversal_shader.c', 'nir/radv_nir_trim_fs_color_exports.c', diff --git a/src/amd/vulkan/nir/radv_nir.h b/src/amd/vulkan/nir/radv_nir.h index 20248781b47..e6fc3f32b32 100644 --- a/src/amd/vulkan/nir/radv_nir.h +++ b/src/amd/vulkan/nir/radv_nir.h @@ -103,6 +103,10 @@ bool radv_nir_opt_fs_builtins(nir_shader *shader, const struct radv_graphics_sta bool radv_nir_lower_immediate_samplers(nir_shader *shader, struct radv_device *device, const struct radv_shader_stage *stage); +void radv_nir_lower_callee_signature(nir_function *function); + +bool radv_nir_lower_call_abi(nir_shader *shader, unsigned wave_size); + #ifdef __cplusplus } #endif diff --git a/src/amd/vulkan/nir/radv_nir_apply_pipeline_layout.c b/src/amd/vulkan/nir/radv_nir_apply_pipeline_layout.c index a8d00566876..1395c38e5dd 100644 --- a/src/amd/vulkan/nir/radv_nir_apply_pipeline_layout.c +++ b/src/amd/vulkan/nir/radv_nir_apply_pipeline_layout.c @@ -5,6 +5,7 @@ */ #include "ac_descriptors.h" #include "ac_shader_util.h" +#include "aco_nir_call_attribs.h" #include "nir.h" #include "nir_builder.h" #include "radv_descriptor_set.h" @@ -37,6 +38,30 @@ get_scalar_arg(nir_builder *b, unsigned size, struct ac_arg arg) return nir_load_scalar_arg_amd(b, size, .base = arg.arg_index); } +static nir_def * +get_indirect_descriptors_addr(nir_builder *b, apply_layout_state *state) +{ + if (mesa_shader_stage_is_rt(b->shader->info.stage)) + return nir_load_param(b, RT_ARG_DESCRIPTORS); + return get_scalar_arg(b, 1, state->args->descriptors[0]); +} + +static nir_def * +get_indirect_push_constants_addr(nir_builder *b, apply_layout_state *state) +{ + if (mesa_shader_stage_is_rt(b->shader->info.stage)) + return nir_load_param(b, RT_ARG_PUSH_CONSTANTS); + return get_scalar_arg(b, 1, state->args->ac.push_constants); +} + +static nir_def * +get_dynamic_descriptors_addr(nir_builder *b, apply_layout_state *state) +{ + if (mesa_shader_stage_is_rt(b->shader->info.stage)) + return nir_load_param(b, 
RT_ARG_DYNAMIC_DESCRIPTORS); + return get_scalar_arg(b, 1, state->args->ac.dynamic_descriptors); +} + static nir_def * convert_pointer_to_64_bit(nir_builder *b, apply_layout_state *state, nir_def *ptr) { @@ -47,8 +72,9 @@ static nir_def * load_desc_ptr(nir_builder *b, apply_layout_state *state, unsigned set) { const struct radv_userdata_locations *user_sgprs_locs = &state->info->user_sgprs_locs; - if (user_sgprs_locs->shader_data[AC_UD_INDIRECT_DESCRIPTORS].sgpr_idx != -1) { - nir_def *addr = get_scalar_arg(b, 1, state->args->descriptors[0]); + if (user_sgprs_locs->shader_data[AC_UD_INDIRECT_DESCRIPTORS].sgpr_idx != -1 || + mesa_shader_stage_is_rt(b->shader->info.stage)) { + nir_def *addr = get_indirect_descriptors_addr(b, state); addr = convert_pointer_to_64_bit(b, state, addr); return ac_nir_load_smem(b, 1, addr, nir_imm_int(b, set * 4), 4, 0); } @@ -69,7 +95,7 @@ visit_vulkan_resource_index(nir_builder *b, apply_layout_state *state, nir_intri nir_def *set_ptr; if (vk_descriptor_type_is_dynamic(layout->binding[binding].type)) { unsigned idx = state->layout->set[desc_set].dynamic_offset_start + layout->binding[binding].dynamic_offset_offset; - set_ptr = get_scalar_arg(b, 1, state->args->ac.dynamic_descriptors); + set_ptr = get_dynamic_descriptors_addr(b, state); offset = idx * 16; stride = 16; } else { @@ -341,8 +367,10 @@ load_push_constant(nir_builder *b, apply_layout_state *state, nir_intrinsic_inst continue; } - if (!state->args->ac.push_constants.used) { - /* Assume this is an inlined push constant load which was expanded to include dwords which are not inlined. */ + if (!state->args->ac.push_constants.used && !mesa_shader_stage_is_rt(b->shader->info.stage)) { + /* Assume this is an inlined push constant load which was expanded to include dwords which are not inlined. + * RT stages use neither shader args nor inlined push constants, so skip this for RT shaders. 
+ */ assert(const_offset != -1); data[num_loads++] = nir_undef(b, 1, 32); start += 1; @@ -350,7 +378,7 @@ load_push_constant(nir_builder *b, apply_layout_state *state, nir_intrinsic_inst } if (!offset) { - addr = get_scalar_arg(b, 1, state->args->ac.push_constants); + addr = get_indirect_push_constants_addr(b, state); addr = convert_pointer_to_64_bit(b, state, addr); offset = nir_iadd_imm_nuw(b, intrin->src[0].ssa, base); } diff --git a/src/amd/vulkan/nir/radv_nir_lower_call_abi.c b/src/amd/vulkan/nir/radv_nir_lower_call_abi.c new file mode 100644 index 00000000000..e943b357659 --- /dev/null +++ b/src/amd/vulkan/nir/radv_nir_lower_call_abi.c @@ -0,0 +1,439 @@ +/* + * Copyright © 2023 Valve Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice (including the next + * paragraph) shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ + +#include "aco_nir_call_attribs.h" +#include "nir_builder.h" +#include "radv_nir.h" + +void +radv_nir_lower_callee_signature(nir_function *function) +{ + nir_parameter *old_params = function->params; + unsigned old_num_params = function->num_params; + + function->num_params += ACO_NIR_CALL_SYSTEM_ARG_COUNT; + function->params = rzalloc_array_size(function->shader, function->num_params, sizeof(nir_parameter)); + + memcpy(function->params + ACO_NIR_CALL_SYSTEM_ARG_COUNT, old_params, old_num_params * sizeof(nir_parameter)); + + /* These are not return params, but each callee will modify these registers + * as part of the next callee selection. Make sure modification is allowed by + * marking the parameters as DISCARDABLE. Unlike other discardable parameters, + * ACO makes sure correct values are always written to them. + */ + function->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC].num_components = 1; + function->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC].bit_size = 64; + function->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC].driver_attributes = ACO_NIR_PARAM_ATTRIB_DISCARDABLE; + function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].num_components = 1; + function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].bit_size = 64; + function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].is_uniform = true; + function->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC].driver_attributes = ACO_NIR_PARAM_ATTRIB_DISCARDABLE; + + for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) { + if (!function->params[i].is_return) + continue; + + function->params[i].bit_size = glsl_get_bit_size(function->params[i].type); + function->params[i].num_components = glsl_get_vector_elements(function->params[i].type); + } + + nir_function_impl *impl = function->impl; + + if (!impl) + return; + + nir_foreach_block (block, impl) { + nir_foreach_instr_safe (instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intr = 
nir_instr_as_intrinsic(instr); + + if (intr->intrinsic == nir_intrinsic_load_param) + nir_intrinsic_set_param_idx(intr, nir_intrinsic_param_idx(intr) + ACO_NIR_CALL_SYSTEM_ARG_COUNT); + } + } +} + +/* Checks if caller can call callee using tail calls. + * + * If the ABIs mismatch, we might need to insert move instructions to move return values from callee return registers to + * caller return registers after the call. In that case, tail-calls are impossible to do correctly. + */ +static bool +is_tail_call_compatible(nir_function *caller, nir_function *callee) +{ + /* If the caller doesn't return at all, we don't need to care if return params are compatible */ + if (caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_NORETURN) + return true; + /* The same ABI can't mismatch */ + if ((caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK) == + (callee->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK)) + return true; + /* The recursive shader ABI and the traversal shader ABI are built so that return parameters occupy exactly + * the same registers, to allow tail calls from the traversal shader. */ + if ((caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK) == ACO_NIR_CALL_ABI_TRAVERSAL && + (callee->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK) == ACO_NIR_CALL_ABI_RT_RECURSIVE) + return true; + return false; +} + +static void +gather_tail_call_instrs_block(nir_function *caller, const struct nir_block *block, struct set *tail_calls) +{ + nir_foreach_instr_reverse (instr, block) { + /* Making an instruction a tail call effectively moves it beyond the last block. If there are any instructions + * in the way, this reordering may be incorrect. 
+ */ + if (instr->type != nir_instr_type_call) + return; + nir_call_instr *call = nir_instr_as_call(instr); + + if (!is_tail_call_compatible(caller, call->callee)) + return; + if (call->callee->num_params != caller->num_params) + return; + + for (unsigned i = 0; i < call->num_params; ++i) { + if (call->callee->params[i].is_return != caller->params[i].is_return) + return; + /* We can only do tail calls if the caller returns exactly the callee return values */ + if (caller->params[i].is_return) { + assert(nir_def_as_deref_or_null(call->params[i].ssa)); + nir_deref_instr *deref_root = nir_def_as_deref(call->params[i].ssa); + while (nir_deref_instr_parent(deref_root)) + deref_root = nir_deref_instr_parent(deref_root); + + if (!deref_root->parent.ssa) + return; + nir_intrinsic_instr *intrin = nir_def_as_intrinsic_or_null(deref_root->parent.ssa); + if (!intrin || intrin->intrinsic != nir_intrinsic_load_param) + return; + /* The call parameters aren't lowered at this point, we need to add the call arg count here */ + if (nir_intrinsic_param_idx(intrin) != i + ACO_NIR_CALL_SYSTEM_ARG_COUNT) + return; + } + if (call->callee->params[i].is_uniform != caller->params[i].is_uniform) + return; + if (call->callee->params[i].bit_size != caller->params[i].bit_size) + return; + if (call->callee->params[i].num_components != caller->params[i].num_components) + return; + } + + _mesa_set_add(tail_calls, instr); + } + + set_foreach (&block->predecessors, pred) { + gather_tail_call_instrs_block(caller, pred->key, tail_calls); + } +} + +struct lower_param_info { + nir_def *return_deref; + nir_variable *param_var; +}; + +static void +rewrite_return_param_uses(nir_def *def, unsigned param_idx, struct lower_param_info *param_defs) +{ + nir_foreach_use_safe (use, def) { + nir_instr *use_instr = nir_src_parent_instr(use); + + if (use_instr->type == nir_instr_type_deref) { + assert(nir_instr_as_deref(use_instr)->deref_type == nir_deref_type_cast); + 
rewrite_return_param_uses(&nir_instr_as_deref(use_instr)->def, param_idx, param_defs); + nir_instr_remove(use_instr); + } + } + nir_def_rewrite_uses(def, param_defs[param_idx].return_deref); +} + +static void +lower_call_abi_for_callee(nir_function *function, unsigned wave_size) +{ + nir_function_impl *impl = function->impl; + + nir_builder b = nir_builder_create(impl); + b.cursor = nir_before_impl(impl); + + nir_variable *tail_call_pc = + nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "_tail_call_pc"); + + struct set *tail_call_instrs = _mesa_set_create(b.shader, _mesa_hash_pointer, _mesa_key_pointer_equal); + gather_tail_call_instrs_block(function, nir_impl_last_block(impl), tail_call_instrs); + + /* guard the shader, so that only the correct invocations execute it */ + + nir_def *guard_condition = NULL; + nir_def *shader_addr; + nir_def *uniform_shader_addr; + if (function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL) { + nir_cf_list list; + nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl)); + + b.cursor = nir_before_impl(impl); + + shader_addr = nir_load_param(&b, ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC); + uniform_shader_addr = nir_load_param(&b, ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC); + + nir_store_var(&b, tail_call_pc, shader_addr, 0x1); + + guard_condition = nir_ieq(&b, uniform_shader_addr, shader_addr); + nir_if *shader_guard = nir_push_if(&b, guard_condition); + shader_guard->control = nir_selection_control_divergent_always_taken; + nir_store_var(&b, tail_call_pc, nir_imm_int64(&b, 0), 0x1); + nir_cf_reinsert(&list, b.cursor); + nir_pop_if(&b, shader_guard); + } else { + nir_store_var(&b, tail_call_pc, nir_imm_int64(&b, 0), 0x1); + } + + b.cursor = nir_before_impl(impl); + struct lower_param_info *param_infos = ralloc_size(b.shader, function->num_params * sizeof(struct lower_param_info)); + + for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) { + param_infos[i].param_var = 
nir_local_variable_create(impl, function->params[i].type, "_param"); + + if (function->params[i].is_return) { + assert(!glsl_type_is_array(function->params[i].type) && !glsl_type_is_struct(function->params[i].type)); + + param_infos[i].return_deref = &nir_build_deref_var(&b, param_infos[i].param_var)->def; + } else { + param_infos[i].return_deref = NULL; + } + } + + /* Lower everything related to call parameters and dispatch, particularly return parameters and tail calls. + * + * Return parameters in NIR are represented by having the parameter value actually be a deref. Callers pass + * a deref value to the call, and the callee can cast the parameter value back to a deref. Replace these deref_casts + * with a deref of a variable we declared further above, so the shader can be lowered to SSA. + * One simple example: + * + * %1 = load_param 0 + * %2 = deref_cast %1 + * %3 = load_deref %2 + * %4 = inot %3 + * store_deref %2, %4 + * + * becomes + * + * decl_var _param; + * + * %1 = load_param 0 + * store_var _param, %1 + * %3 = load_var _param + * %4 = inot %3 + * store_var _param, %4 + * %5 = load_var _param + * store_param_amd %5 + * + * If tail calls are detected, the call instruction is replaced with a sequence of writing the parameters to the new + * values they should have at callee entry, and updating the tail_call_pc value so that the callee is jumped to next. + */ + bool has_tail_call = false; + nir_foreach_block (block, impl) { + bool progress; + /* rewrite_return_param_uses may remove multiple instructions (not just the current one), which not even + * nir_foreach_instr_safe can safely iterate over. Therefore, if we made progress, we need to restart iteration. 
+ */ + do { + progress = false; + nir_foreach_instr (instr, block) { + if (instr->type == nir_instr_type_call && _mesa_set_search(tail_call_instrs, instr)) { + nir_call_instr *call = nir_instr_as_call(instr); + b.cursor = nir_before_instr(instr); + + for (unsigned i = 0; i < call->num_params; ++i) { + if (call->callee->params[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].is_return) + nir_store_var(&b, param_infos[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].param_var, + nir_load_deref(&b, nir_def_as_deref(call->params[i].ssa)), + (0x1 << glsl_get_vector_elements( + call->callee->params[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].type)) - + 1); + else + nir_store_var(&b, param_infos[i + ACO_NIR_CALL_SYSTEM_ARG_COUNT].param_var, call->params[i].ssa, + (0x1 << call->params[i].ssa->num_components) - 1); + } + + nir_store_var(&b, tail_call_pc, call->indirect_callee.ssa, 0x1); + + nir_instr_remove(instr); + + has_tail_call = true; + progress = true; + break; + } + + if (instr->type != nir_instr_type_intrinsic) + continue; + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_param) { + unsigned param_idx = nir_intrinsic_param_idx(intr); + + if (param_idx >= ACO_NIR_CALL_SYSTEM_ARG_COUNT && function->params[param_idx].is_return) { + rewrite_return_param_uses(&intr->def, param_idx, param_infos); + nir_instr_remove(instr); + progress = true; + break; + } + } + } + } while (progress); + } + + b.cursor = nir_before_impl(impl); + + for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) { + unsigned num_components = glsl_get_vector_elements(function->params[i].type); + nir_store_var(&b, param_infos[i].param_var, nir_load_param(&b, i), (0x1 << num_components) - 1); + } + + /* Setup a jump to a different shader in the cases where there is a next shader to be called. 
*/ + if (!(function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_NORETURN) || + (function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL) || has_tail_call) { + b.cursor = nir_after_impl(impl); + + for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < function->num_params; ++i) + nir_store_param_amd(&b, nir_load_var(&b, param_infos[i].param_var), .param_idx = i); + shader_addr = nir_load_var(&b, tail_call_pc); + nir_def *ballot = nir_ballot(&b, 1, wave_size, nir_ine_imm(&b, shader_addr, 0)); + nir_def *ballot_addr = nir_read_invocation(&b, shader_addr, nir_find_lsb(&b, ballot)); + + nir_def *no_next_shader = nir_ieq_imm(&b, ballot, 0); + nir_def *terminate_cond; + + /* In functions marked noreturn, we don't need to bother checking the call return address. */ + if (function->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_NORETURN) { + uniform_shader_addr = ballot_addr; + terminate_cond = no_next_shader; + } else { + nir_def *return_address = nir_load_call_return_address_amd(&b); + uniform_shader_addr = nir_bcsel(&b, no_next_shader, return_address, ballot_addr); + /* If the next shader address is zero for every invocation, return. 
*/ + terminate_cond = nir_ieq_imm(&b, uniform_shader_addr, 0); + } + + nir_push_if(&b, terminate_cond); + nir_terminate(&b); + nir_pop_if(&b, NULL); + + nir_set_next_call_pc_amd(&b, shader_addr, uniform_shader_addr); + } +} + +static void +lower_call_abi_for_call(nir_builder *b, nir_call_instr *call, unsigned *cur_call_idx) +{ + unsigned call_idx = (*cur_call_idx)++; + + for (unsigned i = 0; i < call->num_params; ++i) { + unsigned callee_param_idx = i + ACO_NIR_CALL_SYSTEM_ARG_COUNT; + + if (!call->callee->params[callee_param_idx].is_return) + continue; + + b->cursor = nir_before_instr(&call->instr); + + nir_src *old_src = &call->params[i]; + + assert(nir_def_as_deref_or_null(old_src->ssa)); + nir_deref_instr *param_deref = nir_def_as_deref(old_src->ssa); + assert(param_deref->deref_type == nir_deref_type_var); + + nir_src_rewrite(old_src, nir_load_deref(b, param_deref)); + + b->cursor = nir_after_instr(&call->instr); + + unsigned num_components = glsl_get_vector_elements(param_deref->type); + + nir_store_deref( + b, param_deref, + nir_load_return_param_amd(b, num_components, glsl_base_type_get_bit_size(param_deref->type->base_type), + .call_idx = call_idx, .param_idx = callee_param_idx), + (1u << num_components) - 1); + } + + b->cursor = nir_before_instr(&call->instr); + + nir_call_instr *new_call = nir_call_instr_create(b->shader, call->callee); + new_call->indirect_callee = nir_src_for_ssa(call->indirect_callee.ssa); + new_call->params[ACO_NIR_CALL_SYSTEM_ARG_DIVERGENT_PC] = nir_src_for_ssa(call->indirect_callee.ssa); + new_call->params[ACO_NIR_CALL_SYSTEM_ARG_UNIFORM_PC] = + nir_src_for_ssa(nir_read_first_invocation(b, call->indirect_callee.ssa)); + for (unsigned i = ACO_NIR_CALL_SYSTEM_ARG_COUNT; i < new_call->num_params; ++i) + new_call->params[i] = nir_src_for_ssa(call->params[i - ACO_NIR_CALL_SYSTEM_ARG_COUNT].ssa); + + nir_builder_instr_insert(b, &new_call->instr); + nir_instr_remove(&call->instr); +} + +static bool 
+lower_call_abi_for_caller(nir_function_impl *impl) +{ + bool progress = false; + unsigned cur_call_idx = 0; + + nir_foreach_block (block, impl) { + nir_foreach_instr_safe (instr, block) { + if (instr->type != nir_instr_type_call) + continue; + nir_call_instr *call = nir_instr_as_call(instr); + if (call->callee->impl) + continue; + + nir_builder b = nir_builder_create(impl); + lower_call_abi_for_call(&b, call, &cur_call_idx); + progress = true; + } + } + + return progress; +} + +bool +radv_nir_lower_call_abi(nir_shader *shader, unsigned wave_size) +{ + bool progress = false; + nir_foreach_function (function, shader) { + if (function->is_exported) { + radv_nir_lower_callee_signature(function); + if (function->impl) + progress |= nir_progress(true, function->impl, nir_metadata_none); + } + } + + nir_foreach_function_with_impl (function, impl, shader) { + bool func_progress = false; + if (function->is_exported) { + lower_call_abi_for_callee(function, wave_size); + func_progress = true; + } + func_progress |= lower_call_abi_for_caller(impl); + + progress |= nir_progress(func_progress, impl, nir_metadata_none); + } + + return progress; +} diff --git a/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c index bb3d9e3dec4..526b55ac85e 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c +++ b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c @@ -10,13 +10,19 @@ #include "radv_constants.h" #include "radv_nir.h" +typedef nir_def *(*load_intrin_cb)(nir_builder *b, unsigned base); +typedef void (*store_intrin_cb)(nir_builder *b, nir_def *val, unsigned base); + struct lower_hit_attrib_deref_args { nir_variable_mode mode; uint32_t base_offset; + + load_intrin_cb load_cb; + store_intrin_cb store_cb; }; static bool -lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) +lower_rt_var_deref(nir_builder *b, nir_instr *instr, void *data) { if (instr->type != nir_instr_type_intrinsic) return 
false; @@ -29,6 +35,8 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); if (!nir_deref_mode_is(deref, args->mode)) return false; + if (deref->deref_type == nir_deref_type_cast) + return false; b->cursor = nir_after_instr(instr); @@ -48,19 +56,16 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) uint32_t comp_offset = offset % 4; if (bit_size == 64) { - components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base), - nir_load_hit_attrib_amd(b, .base = base + 1)); + components[comp] = nir_pack_64_2x32_split(b, args->load_cb(b, base), args->load_cb(b, base + 1)); } else if (bit_size == 32) { - components[comp] = nir_load_hit_attrib_amd(b, .base = base); + components[comp] = args->load_cb(b, base); } else if (bit_size == 16) { - components[comp] = - nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2); + components[comp] = nir_channel(b, nir_unpack_32_2x16(b, args->load_cb(b, base)), comp_offset / 2); } else if (bit_size == 8) { - components[comp] = - nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset); + components[comp] = nir_channel(b, nir_unpack_bits(b, args->load_cb(b, base), 8), comp_offset); } else { assert(bit_size == 1); - components[comp] = nir_i2b(b, nir_load_hit_attrib_amd(b, .base = base)); + components[comp] = nir_i2b(b, args->load_cb(b, base)); } } @@ -78,25 +83,25 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) nir_def *component = nir_channel(b, value, comp); if (bit_size == 64) { - nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base); - nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1); + args->store_cb(b, nir_unpack_64_2x32_split_x(b, component), base); + args->store_cb(b, nir_unpack_64_2x32_split_y(b, component), base + 1); } else if (bit_size == 32) 
{ - nir_store_hit_attrib_amd(b, component, .base = base); + args->store_cb(b, component, base); } else if (bit_size == 16) { - nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)); + nir_def *prev = nir_unpack_32_2x16(b, args->load_cb(b, base)); nir_def *components[2]; for (uint32_t word = 0; word < 2; word++) components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word); - nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base); + args->store_cb(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), base); } else if (bit_size == 8) { - nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8); + nir_def *prev = nir_unpack_bits(b, args->load_cb(b, base), 8); nir_def *components[4]; for (uint32_t byte = 0; byte < 4; byte++) components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte); - nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base); + args->store_cb(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), base); } else { assert(bit_size == 1); - nir_store_hit_attrib_amd(b, nir_b2i32(b, component), .base = base); + args->store_cb(b, nir_b2i32(b, component), base); } } } @@ -123,23 +128,24 @@ radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, voi } static bool -radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset) +radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, load_intrin_cb load_cb, store_intrin_cb store_cb, + uint32_t base_offset) { bool progress = false; progress |= nir_lower_indirect_derefs_to_if_else_trees(shader, mode, UINT32_MAX); - progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes); - if (shader->info.stage == MESA_SHADER_RAYGEN && mode == nir_var_function_temp) progress |= nir_shader_intrinsics_pass(shader, radv_lower_payload_arg_to_offset, 
nir_metadata_control_flow, NULL); struct lower_hit_attrib_deref_args args = { .mode = mode, .base_offset = base_offset, + .load_cb = load_cb, + .store_cb = store_cb, }; - progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref, nir_metadata_control_flow, &args); + progress |= nir_shader_instructions_pass(shader, lower_rt_var_deref, nir_metadata_control_flow, &args); if (progress) { nir_remove_dead_derefs(shader); @@ -149,16 +155,57 @@ radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base return progress; } +static nir_def * +load_hit_attrib_cb(nir_builder *b, unsigned base) +{ + return nir_load_hit_attrib_amd(b, .base = base); +} + +static void +store_hit_attrib_cb(nir_builder *b, nir_def *val, unsigned base) +{ + nir_store_hit_attrib_amd(b, val, .base = base); +} + bool radv_nir_lower_hit_attrib_derefs(nir_shader *shader) { - return radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, 0); + bool progress = false; + progress |= nir_lower_vars_to_explicit_types(shader, nir_var_ray_hit_attrib, glsl_get_natural_size_align_bytes); + progress |= radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, load_hit_attrib_cb, store_hit_attrib_cb, 0); + return progress; +} + +static nir_def * +load_incoming_payload_cb(nir_builder *b, unsigned base) +{ + return nir_load_incoming_ray_payload_amd(b, .base = base); +} + +static void +store_incoming_payload_cb(nir_builder *b, nir_def *val, unsigned base) +{ + nir_store_incoming_ray_payload_amd(b, val, .base = base); +} + +static nir_def * +load_outgoing_payload_cb(nir_builder *b, unsigned base) +{ + return nir_load_outgoing_ray_payload_amd(b, .base = base); +} + +static void +store_outgoing_payload_cb(nir_builder *b, nir_def *val, unsigned base) +{ + nir_store_outgoing_ray_payload_amd(b, val, .base = base); } bool radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset) { - bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, RADV_MAX_HIT_ATTRIB_SIZE + 
offset); - progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, RADV_MAX_HIT_ATTRIB_SIZE + offset); + bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, load_outgoing_payload_cb, + store_outgoing_payload_cb, offset); + progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, load_incoming_payload_cb, + store_incoming_payload_cb, offset); return progress; } diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_common.c b/src/amd/vulkan/nir/radv_nir_rt_stage_common.c index af13bc7e930..5ce95852d08 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_stage_common.c +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_common.c @@ -6,19 +6,21 @@ */ #include "nir/radv_nir_rt_stage_common.h" + +#include "nir/radv_nir_rt_stage_functions.h" +#include "aco_nir_call_attribs.h" #include "nir_builder.h" struct radv_nir_sbt_data -radv_nir_load_sbt_entry(nir_builder *b, nir_def *idx, enum radv_nir_sbt_type binding, enum radv_nir_sbt_entry offset) +radv_nir_load_sbt_entry(nir_builder *b, nir_def *base, nir_def *idx, enum radv_nir_sbt_type binding, + enum radv_nir_sbt_entry offset) { struct radv_nir_sbt_data data; - nir_def *desc_base_addr = nir_load_sbt_base_amd(b); - - nir_def *desc = nir_pack_64_2x32(b, ac_nir_load_smem(b, 2, desc_base_addr, nir_imm_int(b, binding), 4, 0)); + nir_def *desc = nir_pack_64_2x32(b, ac_nir_load_smem(b, 2, base, nir_imm_int(b, binding), 4, 0)); nir_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 
8 : 16)); - nir_def *stride = ac_nir_load_smem(b, 1, desc_base_addr, stride_offset, 4, 0); + nir_def *stride = ac_nir_load_smem(b, 1, base, stride_offset, 4, 0); nir_def *addr = nir_iadd(b, desc, nir_u2u64(b, nir_iadd_imm(b, nir_imul(b, idx, stride), offset))); @@ -149,52 +151,14 @@ radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_ free(cases); } -bool -radv_nir_lower_rt_derefs(nir_shader *shader) -{ - nir_function_impl *impl = nir_shader_get_entrypoint(shader); - - bool progress = false; - - nir_builder b; - nir_def *arg_offset = NULL; - - nir_foreach_block (block, impl) { - nir_foreach_instr_safe (instr, block) { - if (instr->type != nir_instr_type_deref) - continue; - - nir_deref_instr *deref = nir_instr_as_deref(instr); - if (!nir_deref_mode_is(deref, nir_var_shader_call_data)) - continue; - - deref->modes = nir_var_function_temp; - progress = true; - - if (deref->deref_type == nir_deref_type_var) { - if (!arg_offset) { - b = nir_builder_at(nir_before_impl(impl)); - arg_offset = nir_load_rt_arg_scratch_offset_amd(&b); - } - - b.cursor = nir_before_instr(&deref->instr); - nir_deref_instr *replacement = - nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0); - nir_def_replace(&deref->def, &replacement->def); - } - } - } - - return nir_progress(progress, impl, nir_metadata_control_flow); -} - -/* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are +/* Lowers RT I/O vars to registers or shared memory. If hit_attribs is NULL, attributes are * lowered to shared memory. 
*/ bool -radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size) +radv_nir_lower_rt_storage(nir_shader *shader, nir_variable **hit_attribs, nir_deref_instr **payload_in, + nir_variable **payload_out, uint32_t workgroup_size) { bool progress = false; - nir_function_impl *impl = nir_shader_get_entrypoint(shader); + nir_function_impl *impl = radv_get_rt_shader_entrypoint(shader); nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib) { attrib->data.mode = nir_var_shader_temp; @@ -210,30 +174,52 @@ radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint3 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd && - intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd) + intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd && + intrin->intrinsic != nir_intrinsic_load_incoming_ray_payload_amd && + intrin->intrinsic != nir_intrinsic_store_incoming_ray_payload_amd && + intrin->intrinsic != nir_intrinsic_load_outgoing_ray_payload_amd && + intrin->intrinsic != nir_intrinsic_store_outgoing_ray_payload_amd) continue; progress = true; b.cursor = nir_after_instr(instr); - nir_def *offset; - if (!hit_attribs) - offset = nir_imul_imm( - &b, nir_iadd_imm(&b, nir_load_subgroup_invocation(&b), nir_intrinsic_base(intrin) * workgroup_size), - sizeof(uint32_t)); + if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd || + intrin->intrinsic == nir_intrinsic_store_hit_attrib_amd) { + nir_def *offset; + if (!hit_attribs) + offset = nir_imul_imm( + &b, nir_iadd_imm(&b, nir_load_subgroup_invocation(&b), nir_intrinsic_base(intrin) * workgroup_size), + sizeof(uint32_t)); - if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) { - nir_def *ret; - if (hit_attribs) - ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]); + if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) { + nir_def *ret; + if (hit_attribs) + ret 
= nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]); + else + ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4); + nir_def_rewrite_uses(nir_instr_def(instr), ret); + } else { + if (hit_attribs) + nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); + else + nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4); + } + } else if (intrin->intrinsic == nir_intrinsic_load_incoming_ray_payload_amd || + intrin->intrinsic == nir_intrinsic_store_incoming_ray_payload_amd) { + if (!payload_in) + continue; + if (intrin->intrinsic == nir_intrinsic_load_incoming_ray_payload_amd) + nir_def_rewrite_uses(nir_instr_def(instr), nir_load_deref(&b, payload_in[nir_intrinsic_base(intrin)])); else - ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4); - nir_def_rewrite_uses(nir_instr_def(instr), ret); + nir_store_deref(&b, payload_in[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); } else { - if (hit_attribs) - nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); + if (!payload_out) + continue; + if (intrin->intrinsic == nir_intrinsic_load_outgoing_ray_payload_amd) + nir_def_rewrite_uses(nir_instr_def(instr), nir_load_var(&b, payload_out[nir_intrinsic_base(intrin)])); else - nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4); + nir_store_var(&b, payload_out[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); } nir_instr_remove(instr); } @@ -244,3 +230,13 @@ radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint3 return nir_progress(progress, impl, nir_metadata_control_flow); } + +void +radv_nir_param_from_type(nir_parameter *param, const glsl_type *type, bool uniform, unsigned driver_attribs) +{ + param->num_components = glsl_get_vector_elements(type); + param->bit_size = glsl_get_bit_size(type); + param->type = type; + param->is_uniform = uniform; + param->driver_attributes = driver_attribs; +} diff --git 
a/src/amd/vulkan/nir/radv_nir_rt_stage_common.h b/src/amd/vulkan/nir/radv_nir_rt_stage_common.h index 251daead8ac..fb46fa90321 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_stage_common.h +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_common.h @@ -15,6 +15,9 @@ #include "radv_pipeline_cache.h" #include "radv_pipeline_rt.h" +typedef struct nir_parameter nir_parameter; +typedef struct glsl_type glsl_type; + /* * * Common Constants @@ -85,8 +88,8 @@ enum radv_nir_sbt_entry { SBT_ANY_HIT_IDX = offsetof(struct radv_pipeline_group_handle, any_hit_index), }; -struct radv_nir_sbt_data radv_nir_load_sbt_entry(nir_builder *b, nir_def *idx, enum radv_nir_sbt_type binding, - enum radv_nir_sbt_entry offset); +struct radv_nir_sbt_data radv_nir_load_sbt_entry(nir_builder *b, nir_def *base, nir_def *idx, + enum radv_nir_sbt_type binding, enum radv_nir_sbt_entry offset); /* * @@ -94,10 +97,10 @@ struct radv_nir_sbt_data radv_nir_load_sbt_entry(nir_builder *b, nir_def *idx, e * */ +bool radv_nir_lower_rt_storage(nir_shader *shader, nir_variable **hit_attribs, nir_deref_instr **payload_in, + nir_variable **payload_out, uint32_t workgroup_size); -bool radv_nir_lower_rt_derefs(nir_shader *shader); -bool radv_nir_lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size); - +void radv_nir_param_from_type(nir_parameter *param, const glsl_type *type, bool uniform, unsigned driver_attribs); /* * diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_cps.c b/src/amd/vulkan/nir/radv_nir_rt_stage_cps.c index 073fae3de8b..ae13ef28b3d 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_stage_cps.c +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_cps.c @@ -14,7 +14,9 @@ #include "nir/radv_nir_rt_stage_cps.h" #include "ac_nir.h" +#include "aco_nir_call_attribs.h" #include "radv_device.h" +#include "radv_nir_rt_stage_functions.h" #include "radv_physical_device.h" #include "radv_pipeline_rt.h" #include "radv_shader.h" @@ -24,12 +26,9 @@ radv_arg_def_is_unused(nir_def *def) { nir_foreach_use 
(use, def) { nir_instr *use_instr = nir_src_parent_instr(use); - if (use_instr->type == nir_instr_type_intrinsic) { - nir_intrinsic_instr *use_intr = nir_instr_as_intrinsic(use_instr); - if (use_intr->intrinsic == nir_intrinsic_store_scalar_arg_amd || - use_intr->intrinsic == nir_intrinsic_store_vector_arg_amd) - continue; - } else if (use_instr->type == nir_instr_type_phi) { + if (use_instr->type == nir_instr_type_call) + continue; + if (use_instr->type == nir_instr_type_phi) { nir_cf_node *prev_node = nir_cf_node_prev(&use_instr->block->cf_node); if (!prev_node) return false; @@ -48,13 +47,13 @@ radv_arg_def_is_unused(nir_def *def) static bool radv_gather_unused_args_instr(nir_builder *b, nir_intrinsic_instr *instr, void *data) { - if (instr->intrinsic != nir_intrinsic_load_scalar_arg_amd && instr->intrinsic != nir_intrinsic_load_vector_arg_amd) + if (instr->intrinsic != nir_intrinsic_load_param) return false; if (!radv_arg_def_is_unused(&instr->def)) { /* This arg is used for more than passing data to the next stage. 
*/ struct radv_ray_tracing_stage_info *info = data; - BITSET_CLEAR(info->unused_args, nir_intrinsic_base(instr)); + BITSET_CLEAR(info->unused_args, nir_intrinsic_param_idx(instr)); } return false; @@ -193,14 +192,13 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data) case nir_intrinsic_rt_execute_callable: { uint32_t size = align(nir_intrinsic_stack_size(intr), 16); nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr)); - ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage)); nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1); nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16); nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1); - struct radv_nir_sbt_data sbt_data = - radv_nir_load_sbt_entry(b, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR); + struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), + intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR); nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1); nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1); @@ -212,7 +210,6 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data) case nir_intrinsic_rt_trace_ray: { uint32_t size = align(nir_intrinsic_stack_size(intr), 16); nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr)); - ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage)); nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1); nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16); @@ -361,6 +358,10 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data) ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24); break; } + case 
nir_intrinsic_load_sbt_base_amd: { + ret = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS); + break; + } case nir_intrinsic_load_sbt_offset_amd: { ret = nir_load_var(b, vars->sbt_offset); break; @@ -385,8 +386,8 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data) nir_store_var(b, vars->geometry_id_and_flags, intr->src[5].ssa, 0x1); nir_store_var(b, vars->hit_kind, intr->src[6].ssa, 0x1); - struct radv_nir_sbt_data sbt_data = - radv_nir_load_sbt_entry(b, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR); + struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), + intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR); nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1); nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1); @@ -414,7 +415,7 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data) nir_def *miss_index = nir_load_var(b, vars->miss_index); struct radv_nir_sbt_data sbt_data = - radv_nir_load_sbt_entry(b, miss_index, SBT_MISS, SBT_RECURSIVE_PTR); + radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), miss_index, SBT_MISS, SBT_RECURSIVE_PTR); nir_store_var(b, vars->shader_addr, sbt_data.shader_addr, 0x1); nir_store_var(b, vars->shader_record_ptr, sbt_data.shader_record_ptr, 0x1); @@ -455,226 +456,203 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, struct radv return nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data); } +static bool +lower_rt_derefs_cps(nir_shader *shader) +{ + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + + bool progress = false; + + nir_builder b; + nir_def *arg_offset = NULL; + + nir_foreach_block (block, impl) { + nir_foreach_instr_safe (instr, block) { + if (instr->type != nir_instr_type_deref) + continue; + + nir_deref_instr *deref = nir_instr_as_deref(instr); + if (!nir_deref_mode_is(deref, nir_var_shader_call_data)) + continue; + + 
deref->modes = nir_var_function_temp; + progress = true; + + if (deref->deref_type == nir_deref_type_var) { + if (!arg_offset) { + b = nir_builder_at(nir_before_impl(impl)); + arg_offset = nir_load_rt_arg_scratch_offset_amd(&b); + } + + b.cursor = nir_before_instr(&deref->instr); + nir_deref_instr *replacement = + nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0); + nir_def_replace(&deref->def, &replacement->def); + } + } + } + + return nir_progress(progress, impl, nir_metadata_control_flow); +} + void radv_nir_lower_rt_io_cps(nir_shader *nir) { NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data, glsl_get_natural_size_align_bytes); - NIR_PASS(_, nir, radv_nir_lower_rt_derefs); + NIR_PASS(_, nir, lower_rt_derefs_cps); NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset); } -/** Select the next shader based on priorities: - * - * Detect the priority of the shader stage by the lowest bits in the address (low to high): - * - Raygen - idx 0 - * - Traversal - idx 1 - * - Closest Hit / Miss - idx 2 - * - Callable - idx 3 - * - * - * This gives us the following priorities: - * Raygen : Callable > > Traversal > Raygen - * Traversal : > Chit / Miss > > Raygen - * CHit / Miss : Callable > Chit / Miss > Traversal > Raygen - * Callable : Callable > Chit / Miss > > Raygen - */ -static nir_def * -select_next_shader(nir_builder *b, nir_def *shader_addr, unsigned wave_size) -{ - mesa_shader_stage stage = b->shader->info.stage; - nir_def *prio = nir_iand_imm(b, shader_addr, radv_rt_priority_mask); - nir_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true)); - nir_def *ballot_traversal = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal)); - nir_def *ballot_hit_miss = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss)); - nir_def *ballot_callable = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, 
radv_rt_priority_callable)); - - if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION) - ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot); - if (stage != MESA_SHADER_RAYGEN) - ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot); - if (stage != MESA_SHADER_INTERSECTION) - ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot); - - nir_def *lsb = nir_find_lsb(b, ballot); - nir_def *next = nir_read_invocation(b, shader_addr, lsb); - return nir_iand_imm(b, next, ~radv_rt_priority_mask); -} - static void -radv_store_arg(nir_builder *b, const struct radv_shader_args *args, const struct radv_ray_tracing_stage_info *info, - struct ac_arg arg, nir_def *value) +init_cps_function(nir_function *function, bool has_position_fetch) { - /* Do not pass unused data to the next stage. */ - if (!info || !BITSET_TEST(info->unused_args, arg.arg_index)) - ac_nir_store_arg(b, &args->ac, arg, value); + function->num_params = has_position_fetch ? 
CPS_ARG_COUNT : CPS_ARG_COUNT - 1; + function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params); + + radv_nir_param_from_type(function->params + RT_ARG_LAUNCH_ID, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0); + radv_nir_param_from_type(function->params + RT_ARG_LAUNCH_SIZE, glsl_vector_type(GLSL_TYPE_UINT, 3), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_DESCRIPTORS, glsl_uint_type(), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_DYNAMIC_DESCRIPTORS, glsl_uint_type(), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_PUSH_CONSTANTS, glsl_uint_type(), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_SBT_DESCRIPTORS, glsl_uint64_t_type(), true, 0); + radv_nir_param_from_type(function->params + RAYGEN_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0); + radv_nir_param_from_type(function->params + RAYGEN_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_PAYLOAD_SCRATCH_OFFSET, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_STACK_PTR, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_ACCEL_STRUCT, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_CULL_MASK_AND_FLAGS, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_SBT_OFFSET, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_SBT_STRIDE, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_MISS_INDEX, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_RAY_ORIGIN, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_RAY_TMIN, glsl_float_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_RAY_DIRECTION, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0); + 
radv_nir_param_from_type(function->params + CPS_ARG_RAY_TMAX, glsl_float_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_PRIMITIVE_ID, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_INSTANCE_ADDR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_GEOMETRY_ID_AND_FLAGS, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CPS_ARG_HIT_KIND, glsl_uint_type(), false, 0); + + if (has_position_fetch) + radv_nir_param_from_type(function->params + CPS_ARG_PRIMITIVE_ADDR, glsl_uint64_t_type(), false, 0); + + function->driver_attributes = + (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL | ACO_NIR_FUNCTION_ATTRIB_NORETURN; + + /* Entrypoints can't have parameters. Consider RT stages as callable functions */ + function->is_exported = true; + function->is_entrypoint = false; } void -radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_args *args, const struct radv_shader_info *info, - uint32_t *stack_size, bool resume_shader, struct radv_device *device, - struct radv_ray_tracing_pipeline *pipeline, bool has_position_fetch, - const struct radv_ray_tracing_stage_info *traversal_info) +radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_info *info, bool resume_shader, + struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, + bool has_position_fetch, const struct radv_ray_tracing_stage_info *traversal_info) { - const struct radv_physical_device *pdev = radv_device_physical(device); - nir_function_impl *impl = nir_shader_get_entrypoint(shader); + /* The first raygen shader gets called by the prolog with the standard raygen signature. Only shaders called by the + * first shader can use the CPS function signature. 
+ */ + if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader) + init_cps_function(impl->function, has_position_fetch); + else + radv_nir_init_rt_function_params(impl->function, MESA_SHADER_RAYGEN, 0); + + if (traversal_info) { + unsigned idx; + BITSET_FOREACH_SET (idx, traversal_info->unused_args, impl->function->num_params) + impl->function->params[idx].driver_attributes |= ACO_NIR_PARAM_ATTRIB_DISCARDABLE; + } + struct rt_variables vars = create_rt_variables(shader, device, pipeline->base.base.create_flags); struct radv_rt_shader_info rt_info = {0}; lower_rt_instructions(shader, &vars, &rt_info); - if (stack_size) { - vars.stack_size = MAX2(vars.stack_size, shader->scratch_size); - *stack_size = MAX2(*stack_size, vars.stack_size); - } - shader->scratch_size = 0; + shader->scratch_size = MAX2(shader->scratch_size, vars.stack_size); /* This can't use NIR_PASS because NIR_DEBUG=serialize,clone invalidates pointers. */ nir_lower_returns(shader); - nir_cf_list list; - nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl)); - /* initialize variables */ nir_builder b = nir_builder_at(nir_before_impl(impl)); - nir_def *descriptors = ac_nir_load_arg(&b, &args->ac, args->descriptors[0]); - nir_def *push_constants = ac_nir_load_arg(&b, &args->ac, args->ac.push_constants); - nir_def *dynamic_descriptors = ac_nir_load_arg(&b, &args->ac, args->ac.dynamic_descriptors); - nir_def *sbt_descriptors = ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_descriptors); - - nir_def *launch_sizes[3]; - for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) { - launch_sizes[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_sizes[i]); - nir_store_var(&b, vars.launch_sizes[i], launch_sizes[i], 1); + nir_def *launch_size_vec = nir_load_param(&b, RT_ARG_LAUNCH_SIZE); + nir_def *launch_id_vec = nir_load_param(&b, RT_ARG_LAUNCH_ID); + for (unsigned i = 0; i < 3; ++i) { + nir_store_var(&b, vars.launch_sizes[i], nir_channel(&b, launch_size_vec, i), 0x1); + nir_store_var(&b, 
vars.launch_ids[i], nir_channel(&b, launch_id_vec, i), 0x1); } + nir_store_var(&b, vars.traversal_addr, nir_load_param(&b, RAYGEN_ARG_TRAVERSAL_ADDR), 0x1); + nir_store_var(&b, vars.shader_record_ptr, nir_load_param(&b, RAYGEN_ARG_SHADER_RECORD_PTR), 0x1); + nir_store_var(&b, vars.shader_addr, nir_imm_int64(&b, 0), 0x1); - nir_def *scratch_offset = NULL; - if (args->ac.scratch_offset.used) - scratch_offset = ac_nir_load_arg(&b, &args->ac, args->ac.scratch_offset); - nir_def *ring_offsets = NULL; - if (args->ac.ring_offsets.used) - ring_offsets = ac_nir_load_arg(&b, &args->ac, args->ac.ring_offsets); + if (shader->info.stage == MESA_SHADER_RAYGEN && !resume_shader) { + impl->function->driver_attributes &= ~ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL; + nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1); + } else { + nir_store_var(&b, vars.stack_ptr, nir_load_param(&b, CPS_ARG_STACK_PTR), 0x1); + nir_store_var(&b, vars.arg, nir_load_param(&b, CPS_ARG_PAYLOAD_SCRATCH_OFFSET), 0x1); + nir_store_var(&b, vars.origin, nir_load_param(&b, CPS_ARG_RAY_ORIGIN), 0x7); + nir_store_var(&b, vars.tmin, nir_load_param(&b, CPS_ARG_RAY_TMIN), 0x1); + nir_store_var(&b, vars.direction, nir_load_param(&b, CPS_ARG_RAY_DIRECTION), 0x7); + nir_store_var(&b, vars.tmax, nir_load_param(&b, CPS_ARG_RAY_TMAX), 0x1); + nir_store_var(&b, vars.cull_mask_and_flags, nir_load_param(&b, CPS_ARG_CULL_MASK_AND_FLAGS), 0x1); + nir_store_var(&b, vars.sbt_offset, nir_load_param(&b, CPS_ARG_SBT_OFFSET), 0x1); + nir_store_var(&b, vars.sbt_stride, nir_load_param(&b, CPS_ARG_SBT_STRIDE), 0x1); + nir_store_var(&b, vars.accel_struct, nir_load_param(&b, CPS_ARG_ACCEL_STRUCT), 0x1); + nir_store_var(&b, vars.primitive_id, nir_load_param(&b, CPS_ARG_PRIMITIVE_ID), 0x1); + nir_store_var(&b, vars.instance_addr, nir_load_param(&b, CPS_ARG_INSTANCE_ADDR), 0x1); + if (has_position_fetch) + nir_store_var(&b, vars.primitive_addr, nir_load_param(&b, CPS_ARG_PRIMITIVE_ADDR), 0x1); + nir_store_var(&b, 
vars.geometry_id_and_flags, nir_load_param(&b, CPS_ARG_GEOMETRY_ID_AND_FLAGS), 0x1); + nir_store_var(&b, vars.hit_kind, nir_load_param(&b, CPS_ARG_HIT_KIND), 0x1); - nir_def *launch_ids[3]; - for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) { - launch_ids[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_ids[i]); - nir_store_var(&b, vars.launch_ids[i], launch_ids[i], 1); + if (traversal_info && traversal_info->miss_index.state == RADV_RT_CONST_ARG_STATE_VALID) + nir_store_var(&b, vars.miss_index, nir_imm_int(&b, traversal_info->miss_index.value), 0x1); + else + nir_store_var(&b, vars.miss_index, nir_load_param(&b, CPS_ARG_MISS_INDEX), 0x1); } - nir_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr); - nir_store_var(&b, vars.traversal_addr, - nir_pack_64_2x32_split(&b, traversal_addr, nir_imm_int(&b, pdev->info.address32_hi)), 1); - - nir_def *shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_addr); - shader_addr = nir_pack_64_2x32(&b, shader_addr); - nir_store_var(&b, vars.shader_addr, shader_addr, 1); - - nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1); - nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record); - nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1); - nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1); - - nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct); - nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1); - nir_store_var(&b, vars.cull_mask_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1); - nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1); - nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1); - nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, 
args->ac.rt.ray_origin), 0x7); - nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1); - nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction), 0x7); - nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1); - - if (traversal_info && traversal_info->miss_index.state == RADV_RT_CONST_ARG_STATE_VALID) - nir_store_var(&b, vars.miss_index, nir_imm_int(&b, traversal_info->miss_index.value), 0x1); - else - nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 0x1); - - nir_def *primitive_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_addr); - nir_store_var(&b, vars.primitive_addr, nir_pack_64_2x32(&b, primitive_addr), 1); - nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), 1); - nir_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr); - nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1); - nir_store_var(&b, vars.geometry_id_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1); - nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1); - - /* guard the shader, so that only the correct invocations execute it */ - nir_if *shader_guard = NULL; - if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader) { - nir_def *uniform_shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr); - uniform_shader_addr = nir_pack_64_2x32(&b, uniform_shader_addr); - uniform_shader_addr = nir_ior_imm(&b, uniform_shader_addr, radv_get_rt_priority(shader->info.stage)); - - shader_guard = nir_push_if(&b, nir_ieq(&b, uniform_shader_addr, shader_addr)); - shader_guard->control = nir_selection_control_divergent_always_taken; - } - - nir_cf_reinsert(&list, b.cursor); - - if (shader_guard) - nir_pop_if(&b, shader_guard); - b.cursor = nir_after_impl(impl); - /* select 
next shader */ - shader_addr = nir_load_var(&b, vars.shader_addr); - nir_def *next = select_next_shader(&b, shader_addr, info->wave_size); - ac_nir_store_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr, next); + /* tail-call next shader */ + nir_def *shader_addr = nir_load_var(&b, vars.shader_addr); + nir_function *continuation_func = nir_function_create(shader, "continuation_func"); + init_cps_function(continuation_func, has_position_fetch); - ac_nir_store_arg(&b, &args->ac, args->descriptors[0], descriptors); - ac_nir_store_arg(&b, &args->ac, args->ac.push_constants, push_constants); - ac_nir_store_arg(&b, &args->ac, args->ac.dynamic_descriptors, dynamic_descriptors); - ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_descriptors, sbt_descriptors); - ac_nir_store_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr, traversal_addr); - - for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) { - if (rt_info.uses_launch_size) - ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_sizes[i], launch_sizes[i]); - else - radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_sizes[i], launch_sizes[i]); - } - - if (scratch_offset) - ac_nir_store_arg(&b, &args->ac, args->ac.scratch_offset, scratch_offset); - if (ring_offsets) - ac_nir_store_arg(&b, &args->ac, args->ac.ring_offsets, ring_offsets); - - for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) { - if (rt_info.uses_launch_id) - ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_ids[i], launch_ids[i]); - else - radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_ids[i], launch_ids[i]); - } - - /* store back all variables to registers */ - ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base, nir_load_var(&b, vars.stack_ptr)); - ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_addr, shader_addr); - radv_store_arg(&b, args, traversal_info, args->ac.rt.shader_record, nir_load_var(&b, vars.shader_record_ptr)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.payload_offset, 
nir_load_var(&b, vars.arg)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.cull_mask_and_flags, - nir_load_var(&b, vars.cull_mask_and_flags)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax)); - - if (has_position_fetch) - radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_addr, nir_load_var(&b, vars.primitive_addr)); - - radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.geometry_id_and_flags, - nir_load_var(&b, vars.geometry_id_and_flags)); - radv_store_arg(&b, args, traversal_info, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind)); + unsigned param_count = continuation_func->num_params; + nir_def **next_args = rzalloc_array_size(b.shader, sizeof(nir_def *), param_count); + next_args[RT_ARG_LAUNCH_ID] = nir_load_param(&b, RT_ARG_LAUNCH_ID); + next_args[RT_ARG_LAUNCH_SIZE] = nir_load_param(&b, RT_ARG_LAUNCH_SIZE); + next_args[RT_ARG_DESCRIPTORS] = nir_load_param(&b, RT_ARG_DESCRIPTORS); + next_args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(&b, RT_ARG_DYNAMIC_DESCRIPTORS); + 
next_args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(&b, RT_ARG_PUSH_CONSTANTS); +   next_args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(&b, RT_ARG_SBT_DESCRIPTORS); +   next_args[RAYGEN_ARG_TRAVERSAL_ADDR] = nir_load_var(&b, vars.traversal_addr); +   next_args[RAYGEN_ARG_SHADER_RECORD_PTR] = nir_load_var(&b, vars.shader_record_ptr); +   next_args[CPS_ARG_PAYLOAD_SCRATCH_OFFSET] = nir_load_var(&b, vars.arg); +   next_args[CPS_ARG_STACK_PTR] = nir_load_var(&b, vars.stack_ptr); +   next_args[CPS_ARG_RAY_ORIGIN] = nir_load_var(&b, vars.origin); +   next_args[CPS_ARG_RAY_TMIN] = nir_load_var(&b, vars.tmin); +   next_args[CPS_ARG_RAY_DIRECTION] = nir_load_var(&b, vars.direction); +   next_args[CPS_ARG_RAY_TMAX] = nir_load_var(&b, vars.tmax); +   next_args[CPS_ARG_CULL_MASK_AND_FLAGS] = nir_load_var(&b, vars.cull_mask_and_flags); +   next_args[CPS_ARG_SBT_OFFSET] = nir_load_var(&b, vars.sbt_offset); +   next_args[CPS_ARG_SBT_STRIDE] = nir_load_var(&b, vars.sbt_stride); +   next_args[CPS_ARG_MISS_INDEX] = nir_load_var(&b, vars.miss_index); +   next_args[CPS_ARG_ACCEL_STRUCT] = nir_load_var(&b, vars.accel_struct); +   next_args[CPS_ARG_PRIMITIVE_ID] = nir_load_var(&b, vars.primitive_id); +   next_args[CPS_ARG_INSTANCE_ADDR] = nir_load_var(&b, vars.instance_addr); +   /* CPS_ARG_PRIMITIVE_ADDR is the last parameter and only exists when position fetch is enabled; without the guard this write overruns the (CPS_ARG_COUNT - 1)-element next_args array. */ +   if (has_position_fetch) +      next_args[CPS_ARG_PRIMITIVE_ADDR] = nir_load_var(&b, vars.primitive_addr); +   next_args[CPS_ARG_GEOMETRY_ID_AND_FLAGS] = nir_load_var(&b, vars.geometry_id_and_flags); +   next_args[CPS_ARG_HIT_KIND] = nir_load_var(&b, vars.hit_kind); +   nir_build_indirect_call(&b, continuation_func, shader_addr, param_count, next_args); nir_progress(true, impl, nir_metadata_none); @@ -683,5 +661,5 @@ radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_args *arg NIR_PASS(_, shader, nir_lower_vars_to_ssa); if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION) -      NIR_PASS(_, shader, radv_nir_lower_hit_attribs, NULL, info->wave_size); +      NIR_PASS(_, shader, radv_nir_lower_rt_storage, NULL, NULL, NULL, 
info->wave_size); } diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_cps.h b/src/amd/vulkan/nir/radv_nir_rt_stage_cps.h index a797ef084cb..5c776137b0d 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_stage_cps.h +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_cps.h @@ -13,8 +13,7 @@ void radv_gather_unused_args(struct radv_ray_tracing_stage_info *info, nir_shader *nir); -void radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_args *args, - const struct radv_shader_info *info, uint32_t *stack_size, bool resume_shader, +void radv_nir_lower_rt_abi_cps(nir_shader *shader, const struct radv_shader_info *info, bool resume_shader, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, bool has_position_fetch, const struct radv_ray_tracing_stage_info *traversal_info); void radv_nir_lower_rt_io_cps(nir_shader *shader); diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_functions.c b/src/amd/vulkan/nir/radv_nir_rt_stage_functions.c new file mode 100644 index 00000000000..454a7121cff --- /dev/null +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_functions.c @@ -0,0 +1,726 @@ +/* + * Copyright © 2025 Valve Corporation + * Copyright © 2021 Google + * + * SPDX-License-Identifier: MIT + */ + +#include "nir/nir.h" +#include "nir/nir_builder.h" + +#include "nir/radv_nir.h" +#include "nir/radv_nir_rt_common.h" +#include "nir/radv_nir_rt_stage_common.h" +#include "nir/radv_nir_rt_stage_functions.h" + +#include "radv_device.h" +#include "radv_physical_device.h" +#include "radv_shader.h" + +#include "aco_nir_call_attribs.h" + +#include "vk_pipeline.h" + +static void +radv_nir_init_common_rt_params(nir_function *function) +{ + radv_nir_param_from_type(function->params + RT_ARG_LAUNCH_ID, glsl_vector_type(GLSL_TYPE_UINT, 3), false, 0); + radv_nir_param_from_type(function->params + RT_ARG_LAUNCH_SIZE, glsl_vector_type(GLSL_TYPE_UINT, 3), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_DESCRIPTORS, glsl_uint_type(), true, 0); + 
radv_nir_param_from_type(function->params + RT_ARG_DYNAMIC_DESCRIPTORS, glsl_uint_type(), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_PUSH_CONSTANTS, glsl_uint_type(), true, 0); + radv_nir_param_from_type(function->params + RT_ARG_SBT_DESCRIPTORS, glsl_uint64_t_type(), true, 0); +} + +void +radv_nir_init_rt_function_params(nir_function *function, mesa_shader_stage stage, unsigned payload_size) +{ + unsigned payload_base = -1u; + + switch (stage) { + case MESA_SHADER_RAYGEN: + function->num_params = RAYGEN_ARG_COUNT; + function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params); + radv_nir_init_common_rt_params(function); + radv_nir_param_from_type(function->params + RAYGEN_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0); + radv_nir_param_from_type(function->params + RAYGEN_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0); + function->driver_attributes = (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_NORETURN; + break; + case MESA_SHADER_CALLABLE: + function->num_params = RAYGEN_ARG_COUNT + DIV_ROUND_UP(payload_size, 4); + function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params); + radv_nir_init_common_rt_params(function); + radv_nir_param_from_type(function->params + RAYGEN_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0); + radv_nir_param_from_type(function->params + RAYGEN_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0); + + function->driver_attributes = (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL; + payload_base = RAYGEN_ARG_COUNT; + break; + case MESA_SHADER_INTERSECTION: + function->num_params = TRAVERSAL_ARG_PAYLOAD_BASE + DIV_ROUND_UP(payload_size, 4); + function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), function->num_params); + radv_nir_init_common_rt_params(function); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_TRAVERSAL_ADDR, 
glsl_uint64_t_type(), true, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_ACCEL_STRUCT, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_CULL_MASK_AND_FLAGS, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_SBT_OFFSET, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_SBT_STRIDE, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_MISS_INDEX, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_ORIGIN, glsl_vector_type(GLSL_TYPE_UINT, 3), false, + 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_TMIN, glsl_float_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_DIRECTION, glsl_vector_type(GLSL_TYPE_UINT, 3), + false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_RAY_TMAX, glsl_float_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_PRIMITIVE_ADDR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_PRIMITIVE_ID, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_INSTANCE_ADDR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_GEOMETRY_ID_AND_FLAGS, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + TRAVERSAL_ARG_HIT_KIND, glsl_uint_type(), false, 0); + + function->driver_attributes = ACO_NIR_CALL_ABI_TRAVERSAL; + payload_base = TRAVERSAL_ARG_PAYLOAD_BASE; + break; + case MESA_SHADER_CLOSEST_HIT: + case MESA_SHADER_MISS: + function->num_params = CHIT_MISS_ARG_PAYLOAD_BASE + DIV_ROUND_UP(payload_size, 4); + function->params = rzalloc_array_size(function->shader, sizeof(nir_parameter), 
function->num_params); + radv_nir_init_common_rt_params(function); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_TRAVERSAL_ADDR, glsl_uint64_t_type(), true, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_SHADER_RECORD_PTR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_ACCEL_STRUCT, glsl_uint64_t_type(), false, + ACO_NIR_PARAM_ATTRIB_DISCARDABLE); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_CULL_MASK_AND_FLAGS, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_SBT_OFFSET, glsl_uint_type(), false, + ACO_NIR_PARAM_ATTRIB_DISCARDABLE); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_SBT_STRIDE, glsl_uint_type(), false, + ACO_NIR_PARAM_ATTRIB_DISCARDABLE); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_MISS_INDEX, glsl_uint_type(), false, + ACO_NIR_PARAM_ATTRIB_DISCARDABLE); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_ORIGIN, glsl_vector_type(GLSL_TYPE_UINT, 3), false, + 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_TMIN, glsl_float_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_DIRECTION, glsl_vector_type(GLSL_TYPE_UINT, 3), + false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_RAY_TMAX, glsl_float_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_PRIMITIVE_ADDR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_PRIMITIVE_ID, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_INSTANCE_ADDR, glsl_uint64_t_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS, glsl_uint_type(), false, 0); + radv_nir_param_from_type(function->params + CHIT_MISS_ARG_HIT_KIND, glsl_uint_type(), false, 0); + + function->driver_attributes = (uint32_t)ACO_NIR_CALL_ABI_RT_RECURSIVE | 
ACO_NIR_FUNCTION_ATTRIB_DIVERGENT_CALL; + payload_base = CHIT_MISS_ARG_PAYLOAD_BASE; + break; + default: + UNREACHABLE("invalid RT stage"); + } + + if (payload_base != -1u) { + for (unsigned i = 0; i < DIV_ROUND_UP(payload_size, 4); ++i) { + function->params[payload_base + i].num_components = 1; + function->params[payload_base + i].bit_size = 32; + function->params[payload_base + i].is_return = true; + function->params[payload_base + i].type = glsl_uint_type(); + } + } + + /* Entrypoints can't have parameters. Consider RT stages as callable functions */ + function->is_exported = true; + function->is_entrypoint = false; +} + +/* + * Global variables for an RT pipeline + */ +struct rt_variables { + struct radv_device *device; + const VkPipelineCreateFlags2 flags; + + /* Stage-dependent parameter indices */ + unsigned shader_record_ptr_param; + unsigned traversal_addr_param; + unsigned accel_struct_param; + unsigned cull_mask_and_flags_param; + unsigned sbt_offset_param; + unsigned sbt_stride_param; + unsigned miss_index_param; + unsigned ray_origin_param; + unsigned ray_tmin_param; + unsigned ray_direction_param; + unsigned ray_tmax_param; + unsigned primitive_id_param; + unsigned instance_addr_param; + unsigned primitive_addr_param; + unsigned geometry_id_and_flags_param; + unsigned hit_kind_param; + unsigned in_payload_base_param; + + nir_variable **out_payload_storage; + unsigned payload_size; + + nir_function *trace_ray_func; + nir_function *chit_miss_func; + nir_function *callable_func; + + unsigned stack_size; +}; + +static struct rt_variables +create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2 flags, + unsigned max_payload_size) +{ + struct rt_variables vars = { + .device = device, + .flags = flags, + }; + + if (max_payload_size) + vars.out_payload_storage = rzalloc_array_size(shader, DIV_ROUND_UP(max_payload_size, 4), sizeof(nir_variable *)); + vars.payload_size = max_payload_size; + for (unsigned i = 0; i < 
DIV_ROUND_UP(max_payload_size, 4); ++i) { + vars.out_payload_storage[i] = + nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "out_payload_storage"); + } + + nir_function *trace_ray_func = nir_function_create(shader, "trace_ray_func"); + radv_nir_init_rt_function_params(trace_ray_func, MESA_SHADER_INTERSECTION, max_payload_size); + vars.trace_ray_func = trace_ray_func; + nir_function *chit_miss_func = nir_function_create(shader, "chit_miss_func"); + radv_nir_init_rt_function_params(chit_miss_func, MESA_SHADER_CLOSEST_HIT, max_payload_size); + vars.chit_miss_func = chit_miss_func; + nir_function *callable_func = nir_function_create(shader, "callable_func"); + radv_nir_init_rt_function_params(callable_func, MESA_SHADER_CALLABLE, max_payload_size); + vars.callable_func = callable_func; + + vars.shader_record_ptr_param = -1u; + vars.traversal_addr_param = -1u; + vars.accel_struct_param = -1u; + vars.cull_mask_and_flags_param = -1u; + vars.sbt_offset_param = -1u; + vars.sbt_stride_param = -1u; + vars.miss_index_param = -1u; + vars.ray_origin_param = -1u; + vars.ray_tmin_param = -1u; + vars.ray_direction_param = -1u; + vars.ray_tmax_param = -1u; + vars.primitive_id_param = -1u; + vars.instance_addr_param = -1u; + vars.primitive_addr_param = -1u; + vars.geometry_id_and_flags_param = -1u; + vars.hit_kind_param = -1u; + vars.in_payload_base_param = -1u; + + switch (shader->info.stage) { + case MESA_SHADER_CALLABLE: + vars.in_payload_base_param = RAYGEN_ARG_COUNT; + vars.shader_record_ptr_param = RAYGEN_ARG_SHADER_RECORD_PTR; + vars.traversal_addr_param = RAYGEN_ARG_TRAVERSAL_ADDR; + break; + case MESA_SHADER_RAYGEN: + vars.shader_record_ptr_param = RAYGEN_ARG_SHADER_RECORD_PTR; + vars.traversal_addr_param = RAYGEN_ARG_TRAVERSAL_ADDR; + break; + case MESA_SHADER_INTERSECTION: + vars.traversal_addr_param = TRAVERSAL_ARG_TRAVERSAL_ADDR; + vars.shader_record_ptr_param = TRAVERSAL_ARG_SHADER_RECORD_PTR; + vars.accel_struct_param = TRAVERSAL_ARG_ACCEL_STRUCT; + 
vars.cull_mask_and_flags_param = TRAVERSAL_ARG_CULL_MASK_AND_FLAGS; + vars.sbt_offset_param = TRAVERSAL_ARG_SBT_OFFSET; + vars.sbt_stride_param = TRAVERSAL_ARG_SBT_STRIDE; + vars.miss_index_param = TRAVERSAL_ARG_MISS_INDEX; + vars.ray_origin_param = TRAVERSAL_ARG_RAY_ORIGIN; + vars.ray_tmin_param = TRAVERSAL_ARG_RAY_TMIN; + vars.ray_direction_param = TRAVERSAL_ARG_RAY_DIRECTION; + vars.ray_tmax_param = TRAVERSAL_ARG_RAY_TMAX; + vars.in_payload_base_param = TRAVERSAL_ARG_PAYLOAD_BASE; + break; + case MESA_SHADER_CLOSEST_HIT: + case MESA_SHADER_MISS: + vars.traversal_addr_param = CHIT_MISS_ARG_TRAVERSAL_ADDR; + vars.shader_record_ptr_param = CHIT_MISS_ARG_SHADER_RECORD_PTR; + vars.accel_struct_param = CHIT_MISS_ARG_ACCEL_STRUCT; + vars.cull_mask_and_flags_param = CHIT_MISS_ARG_CULL_MASK_AND_FLAGS; + vars.sbt_offset_param = CHIT_MISS_ARG_SBT_OFFSET; + vars.sbt_stride_param = CHIT_MISS_ARG_SBT_STRIDE; + vars.miss_index_param = CHIT_MISS_ARG_MISS_INDEX; + vars.ray_origin_param = CHIT_MISS_ARG_RAY_ORIGIN; + vars.ray_tmin_param = CHIT_MISS_ARG_RAY_TMIN; + vars.ray_direction_param = CHIT_MISS_ARG_RAY_DIRECTION; + vars.ray_tmax_param = CHIT_MISS_ARG_RAY_TMAX; + vars.primitive_id_param = CHIT_MISS_ARG_PRIMITIVE_ID; + vars.instance_addr_param = CHIT_MISS_ARG_INSTANCE_ADDR; + vars.primitive_addr_param = CHIT_MISS_ARG_PRIMITIVE_ADDR; + vars.geometry_id_and_flags_param = CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS; + vars.hit_kind_param = CHIT_MISS_ARG_HIT_KIND; + vars.in_payload_base_param = CHIT_MISS_ARG_PAYLOAD_BASE; + break; + default: + break; + } + return vars; +} + +static bool +lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_vars) +{ + if (instr->type == nir_instr_type_jump) { + nir_jump_instr *jump = nir_instr_as_jump(instr); + if (jump->type == nir_jump_halt) { + jump->type = nir_jump_return; + return true; + } + return false; + } else if (instr->type != nir_instr_type_intrinsic) { + return false; + } + + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); 
+ + struct rt_variables *vars = _vars; + + b->cursor = nir_before_instr(&intr->instr); + + nir_def *ret = NULL; + switch (intr->intrinsic) { + case nir_intrinsic_execute_callable: { + struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), + intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR); + + unsigned param_count = RAYGEN_ARG_COUNT + DIV_ROUND_UP(vars->payload_size, 4); + nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count); + args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID); + args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE); + args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS); + args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS); + args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS); + args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS); + args[RAYGEN_ARG_TRAVERSAL_ADDR] = nir_undef(b, 1, 64); + args[RAYGEN_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr; + for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) { + args[RAYGEN_ARG_COUNT + i] = nir_instr_def(&nir_build_deref_var(b, vars->out_payload_storage[i])->instr); + } + nir_build_indirect_call(b, vars->callable_func, sbt_data.shader_addr, param_count, args); + break; + } + case nir_intrinsic_trace_ray: { + nir_def *undef = nir_undef(b, 1, 32); + /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. 
*/ + nir_def *cull_mask_and_flags = nir_ior(b, nir_ishl_imm(b, intr->src[2].ssa, 24), intr->src[1].ssa); + nir_def *traversal_addr = nir_load_param(b, vars->traversal_addr_param); + + unsigned param_count = TRAVERSAL_ARG_PAYLOAD_BASE + DIV_ROUND_UP(vars->payload_size, 4); + nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count); + args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID); + args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE); + args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS); + args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS); + args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS); + args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS); + args[TRAVERSAL_ARG_TRAVERSAL_ADDR] = traversal_addr; + /* Traversal does not have a shader record. */ + args[TRAVERSAL_ARG_SHADER_RECORD_PTR] = nir_undef(b, 1, 64); + args[TRAVERSAL_ARG_ACCEL_STRUCT] = intr->src[0].ssa; + args[TRAVERSAL_ARG_CULL_MASK_AND_FLAGS] = cull_mask_and_flags; + args[TRAVERSAL_ARG_SBT_OFFSET] = nir_iand_imm(b, intr->src[3].ssa, 0xf); + args[TRAVERSAL_ARG_SBT_STRIDE] = nir_iand_imm(b, intr->src[4].ssa, 0xf); + args[TRAVERSAL_ARG_MISS_INDEX] = nir_iand_imm(b, intr->src[5].ssa, 0xffff); + args[TRAVERSAL_ARG_RAY_ORIGIN] = intr->src[6].ssa; + args[TRAVERSAL_ARG_RAY_TMIN] = intr->src[7].ssa; + args[TRAVERSAL_ARG_RAY_DIRECTION] = intr->src[8].ssa; + args[TRAVERSAL_ARG_RAY_TMAX] = intr->src[9].ssa; + args[TRAVERSAL_ARG_PRIMITIVE_ADDR] = nir_undef(b, 1, 64); + args[TRAVERSAL_ARG_PRIMITIVE_ID] = undef; + args[TRAVERSAL_ARG_INSTANCE_ADDR] = nir_undef(b, 1, 64); + args[TRAVERSAL_ARG_GEOMETRY_ID_AND_FLAGS] = undef; + args[TRAVERSAL_ARG_HIT_KIND] = undef; + for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) { + args[TRAVERSAL_ARG_PAYLOAD_BASE + i] = + nir_instr_def(&nir_build_deref_var(b, vars->out_payload_storage[i])->instr); + } + nir_build_indirect_call(b, 
vars->trace_ray_func, traversal_addr, param_count, args); + break; + } + case nir_intrinsic_load_shader_record_ptr: { + ret = nir_load_param(b, vars->shader_record_ptr_param); + break; + } + case nir_intrinsic_load_ray_launch_size: { + ret = nir_load_param(b, RT_ARG_LAUNCH_SIZE); + break; + }; + case nir_intrinsic_load_ray_launch_id: { + ret = nir_load_param(b, RT_ARG_LAUNCH_ID); + break; + } + case nir_intrinsic_load_ray_t_min: { + ret = nir_load_param(b, vars->ray_tmin_param); + break; + } + case nir_intrinsic_load_ray_t_max: { + ret = nir_load_param(b, vars->ray_tmax_param); + break; + } + case nir_intrinsic_load_ray_world_origin: { + ret = nir_load_param(b, vars->ray_origin_param); + break; + } + case nir_intrinsic_load_ray_world_direction: { + ret = nir_load_param(b, vars->ray_direction_param); + break; + } + case nir_intrinsic_load_ray_instance_custom_index: { + ret = radv_load_custom_instance(vars->device, b, nir_load_param(b, vars->instance_addr_param)); + break; + } + case nir_intrinsic_load_primitive_id: { + ret = nir_load_param(b, vars->primitive_id_param); + break; + } + case nir_intrinsic_load_ray_geometry_index: { + ret = nir_load_param(b, vars->geometry_id_and_flags_param); + ret = nir_iand_imm(b, ret, 0xFFFFFFF); + break; + } + case nir_intrinsic_load_instance_id: { + ret = radv_load_instance_id(vars->device, b, nir_load_param(b, vars->instance_addr_param)); + break; + } + case nir_intrinsic_load_ray_flags: { + ret = nir_iand_imm(b, nir_load_param(b, vars->cull_mask_and_flags_param), 0xFFFFFF); + break; + } + case nir_intrinsic_load_ray_hit_kind: { + ret = nir_load_param(b, vars->hit_kind_param); + break; + } + case nir_intrinsic_load_ray_world_to_object: { + unsigned c = nir_intrinsic_column(intr); + nir_def *instance_node_addr = nir_load_param(b, vars->instance_addr_param); + nir_def *wto_matrix[3]; + radv_load_wto_matrix(vars->device, b, instance_node_addr, wto_matrix); + + nir_def *vals[3]; + for (unsigned i = 0; i < 3; ++i) + vals[i] = 
nir_channel(b, wto_matrix[i], c); + + ret = nir_vec(b, vals, 3); + break; + } + case nir_intrinsic_load_ray_object_to_world: { + unsigned c = nir_intrinsic_column(intr); + nir_def *otw_matrix[3]; + radv_load_otw_matrix(vars->device, b, nir_load_param(b, vars->instance_addr_param), otw_matrix); + ret = nir_vec3(b, nir_channel(b, otw_matrix[0], c), nir_channel(b, otw_matrix[1], c), + nir_channel(b, otw_matrix[2], c)); + break; + } + case nir_intrinsic_load_ray_object_origin: { + nir_def *wto_matrix[3]; + radv_load_wto_matrix(vars->device, b, nir_load_param(b, vars->instance_addr_param), wto_matrix); + ret = nir_build_vec3_mat_mult(b, nir_load_param(b, vars->ray_origin_param), wto_matrix, true); + break; + } + case nir_intrinsic_load_ray_object_direction: { + nir_def *wto_matrix[3]; + radv_load_wto_matrix(vars->device, b, nir_load_param(b, vars->instance_addr_param), wto_matrix); + ret = nir_build_vec3_mat_mult(b, nir_load_param(b, vars->ray_direction_param), wto_matrix, false); + break; + } + case nir_intrinsic_load_cull_mask: { + ret = nir_ushr_imm(b, nir_load_param(b, vars->cull_mask_and_flags_param), 24); + break; + } + case nir_intrinsic_load_sbt_base_amd: { + ret = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS); + break; + } + case nir_intrinsic_load_sbt_offset_amd: { + ret = nir_load_param(b, vars->sbt_offset_param); + break; + } + case nir_intrinsic_load_sbt_stride_amd: { + ret = nir_load_param(b, vars->sbt_stride_param); + break; + } + case nir_intrinsic_load_accel_struct_amd: { + ret = nir_load_param(b, vars->accel_struct_param); + break; + } + case nir_intrinsic_load_cull_mask_and_flags_amd: { + ret = nir_load_param(b, vars->cull_mask_and_flags_param); + break; + } + case nir_intrinsic_execute_closest_hit_amd: { + struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), + intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR); + + nir_def *should_return = + nir_test_mask(b, nir_load_param(b, 
vars->cull_mask_and_flags_param), SpvRayFlagsSkipClosestHitShaderKHRMask); + + if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) { + should_return = nir_ior(b, should_return, nir_ieq_imm(b, sbt_data.shader_addr, 0)); + } + + /* should_return is set if we had a hit but we won't be calling the closest hit + * shader and hence need to return immediately to the calling shader. */ + nir_push_if(b, nir_inot(b, should_return)); + unsigned param_count = CHIT_MISS_ARG_PAYLOAD_BASE + DIV_ROUND_UP(vars->payload_size, 4); + nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count); + args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID); + args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE); + args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS); + args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS); + args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS); + args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS); + args[CHIT_MISS_ARG_TRAVERSAL_ADDR] = nir_load_param(b, vars->traversal_addr_param); + args[CHIT_MISS_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr; + args[CHIT_MISS_ARG_ACCEL_STRUCT] = nir_load_param(b, vars->accel_struct_param); + args[CHIT_MISS_ARG_CULL_MASK_AND_FLAGS] = nir_load_param(b, vars->cull_mask_and_flags_param); + args[CHIT_MISS_ARG_SBT_OFFSET] = nir_load_param(b, vars->sbt_offset_param); + args[CHIT_MISS_ARG_SBT_STRIDE] = nir_load_param(b, vars->sbt_stride_param); + args[CHIT_MISS_ARG_MISS_INDEX] = nir_load_param(b, vars->miss_index_param); + args[CHIT_MISS_ARG_RAY_ORIGIN] = nir_load_param(b, vars->ray_origin_param); + args[CHIT_MISS_ARG_RAY_TMIN] = nir_load_param(b, vars->ray_tmin_param); + args[CHIT_MISS_ARG_RAY_DIRECTION] = nir_load_param(b, vars->ray_direction_param); + args[CHIT_MISS_ARG_RAY_TMAX] = intr->src[1].ssa; + args[CHIT_MISS_ARG_PRIMITIVE_ADDR] = intr->src[2].ssa; + 
args[CHIT_MISS_ARG_PRIMITIVE_ID] = intr->src[3].ssa; + args[CHIT_MISS_ARG_INSTANCE_ADDR] = intr->src[4].ssa; + args[CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS] = intr->src[5].ssa; + args[CHIT_MISS_ARG_HIT_KIND] = intr->src[6].ssa; + for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) { + args[CHIT_MISS_ARG_PAYLOAD_BASE + i] = + nir_instr_def(&nir_build_deref_cast(b, nir_load_param(b, TRAVERSAL_ARG_PAYLOAD_BASE + i), + nir_var_shader_call_data, glsl_uint_type(), 4) + ->instr); + } + nir_build_indirect_call(b, vars->chit_miss_func, sbt_data.shader_addr, param_count, args); + nir_pop_if(b, NULL); + break; + } + case nir_intrinsic_execute_miss_amd: { + nir_def *undef = nir_undef(b, 1, 32); + nir_def *miss_index = nir_load_param(b, vars->miss_index_param); + struct radv_nir_sbt_data sbt_data = + radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), miss_index, SBT_MISS, SBT_RECURSIVE_PTR); + + if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) { + /* In case of a NULL miss shader, do nothing and just return. 
*/ + nir_push_if(b, nir_ine_imm(b, sbt_data.shader_addr, 0)); + } + + unsigned param_count = CHIT_MISS_ARG_PAYLOAD_BASE + DIV_ROUND_UP(vars->payload_size, 4); + nir_def **args = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count); + args[RT_ARG_LAUNCH_ID] = nir_load_param(b, RT_ARG_LAUNCH_ID); + args[RT_ARG_LAUNCH_SIZE] = nir_load_param(b, RT_ARG_LAUNCH_SIZE); + args[RT_ARG_DESCRIPTORS] = nir_load_param(b, RT_ARG_DESCRIPTORS); + args[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS); + args[RT_ARG_PUSH_CONSTANTS] = nir_load_param(b, RT_ARG_PUSH_CONSTANTS); + args[RT_ARG_SBT_DESCRIPTORS] = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS); + args[CHIT_MISS_ARG_TRAVERSAL_ADDR] = nir_load_param(b, vars->traversal_addr_param); + args[CHIT_MISS_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr; + args[CHIT_MISS_ARG_ACCEL_STRUCT] = nir_load_param(b, vars->accel_struct_param); + args[CHIT_MISS_ARG_CULL_MASK_AND_FLAGS] = nir_load_param(b, vars->cull_mask_and_flags_param); + args[CHIT_MISS_ARG_SBT_OFFSET] = nir_load_param(b, vars->sbt_offset_param); + args[CHIT_MISS_ARG_SBT_STRIDE] = nir_load_param(b, vars->sbt_stride_param); + args[CHIT_MISS_ARG_MISS_INDEX] = nir_load_param(b, vars->miss_index_param); + args[CHIT_MISS_ARG_RAY_ORIGIN] = nir_load_param(b, vars->ray_origin_param); + args[CHIT_MISS_ARG_RAY_TMIN] = nir_load_param(b, vars->ray_tmin_param); + args[CHIT_MISS_ARG_RAY_DIRECTION] = nir_load_param(b, vars->ray_direction_param); + args[CHIT_MISS_ARG_RAY_TMAX] = intr->src[0].ssa; + args[CHIT_MISS_ARG_PRIMITIVE_ADDR] = nir_undef(b, 1, 64); + args[CHIT_MISS_ARG_PRIMITIVE_ID] = undef; + args[CHIT_MISS_ARG_INSTANCE_ADDR] = nir_undef(b, 1, 64); + args[CHIT_MISS_ARG_GEOMETRY_ID_AND_FLAGS] = undef; + args[CHIT_MISS_ARG_HIT_KIND] = undef; + for (unsigned i = 0; i < DIV_ROUND_UP(vars->payload_size, 4); ++i) { + args[CHIT_MISS_ARG_PAYLOAD_BASE + i] = + nir_instr_def(&nir_build_deref_cast(b, nir_load_param(b, TRAVERSAL_ARG_PAYLOAD_BASE + i), + 
nir_var_shader_call_data, glsl_uint_type(), 4) + ->instr); + } + nir_build_indirect_call(b, vars->chit_miss_func, sbt_data.shader_addr, param_count, args); + + if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) + nir_pop_if(b, NULL); + + break; + } + case nir_intrinsic_load_ray_triangle_vertex_positions: { + nir_def *primitive_addr = nir_load_param(b, vars->primitive_addr_param); + ret = radv_load_vertex_position(vars->device, b, primitive_addr, nir_intrinsic_column(intr)); + break; + } + default: + return false; + } + + if (ret) + nir_def_rewrite_uses(&intr->def, ret); + nir_instr_remove(&intr->instr); + + return true; +} + +/* Lower aliased ray payload variables. See radv_nir_lower_rt_io for a more detailed + * explanation on ray payload storage. + */ +static void +lower_rt_deref_var(nir_shader *shader, nir_function_impl *impl, nir_instr *instr, struct hash_table *cloned_vars) +{ + nir_deref_instr *deref = nir_instr_as_deref(instr); + nir_variable *var = deref->var; + struct hash_entry *entry = _mesa_hash_table_search(cloned_vars, var); + if (!(var->data.mode & nir_var_function_temp) && !entry) + return; + + hash_table_foreach (cloned_vars, cloned_entry) { + if (var == cloned_entry->data) + return; + } + + nir_variable *new_var; + if (entry) { + new_var = entry->data; + } else { + new_var = nir_variable_clone(var, shader); + _mesa_hash_table_insert(cloned_vars, var, new_var); + + exec_node_remove(&var->node); + var->data.mode = nir_var_shader_temp; + exec_list_push_tail(&shader->variables, &var->node); + + exec_list_push_tail(&impl->locals, &new_var->node); + } + + deref->modes = nir_var_shader_temp; + + nir_foreach_use_safe (use, nir_instr_def(instr)) { + if (nir_src_is_if(use)) + continue; + + nir_instr *parent = nir_src_parent_instr(use); + if (parent->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(parent); + if (intrin->intrinsic != nir_intrinsic_trace_ray && 
intrin->intrinsic != nir_intrinsic_execute_callable && + intrin->intrinsic != nir_intrinsic_execute_closest_hit_amd && + intrin->intrinsic != nir_intrinsic_execute_miss_amd) + continue; + + nir_builder b = nir_builder_at(nir_before_instr(parent)); + nir_deref_instr *old_deref = nir_build_deref_var(&b, var); + nir_deref_instr *new_deref = nir_build_deref_var(&b, new_var); + + nir_copy_deref(&b, new_deref, old_deref); + b.cursor = nir_after_instr(parent); + nir_copy_deref(&b, old_deref, new_deref); + + nir_src_rewrite(use, nir_instr_def(&new_deref->instr)); + } +} + +static bool +lower_rt_derefs_functions(nir_shader *shader) +{ + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + + bool progress = false; + + struct hash_table *cloned_vars = _mesa_pointer_hash_table_create(shader); + + nir_foreach_block (block, impl) { + nir_foreach_instr_safe (instr, block) { + if (instr->type != nir_instr_type_deref) + continue; + + nir_deref_instr *deref = nir_instr_as_deref(instr); + if (!nir_deref_mode_is(deref, nir_var_function_temp)) + continue; + + if (deref->deref_type == nir_deref_type_var) { + lower_rt_deref_var(shader, impl, instr, cloned_vars); + progress = true; + } else { + assert(deref->deref_type != nir_deref_type_cast); + /* Parent modes might have changed, propagate change */ + nir_deref_instr *parent = nir_src_as_deref(deref->parent); + if (parent->modes != deref->modes) + deref->modes = parent->modes; + } + } + } + + return nir_progress(progress, impl, nir_metadata_control_flow); +} + +void +radv_nir_lower_rt_io_functions(nir_shader *nir) +{ + /* When compiling separately and using function calls, function parameters for ray payloads store the currently + * active ray payload. The same parameters are reused for different ray payload types - the function signature + * only allows for passing one ray payload (the active one) at a time. 
We model this in NIR by designating all ray + * payload variables as aliased (every ray payload variable's driver location is 0). + * + * This doesn't quite match the SPIR-V semantics of different ray payload variables - each payload variable is in + * a different location and can be written/read independently. It's lower_rt_derefs's job to accommodate this. + * lower_rt_derefs duplicates all ray payload variables and marks the original one as a shader_temp variable, + * in order to make the shader's payload read/writes operate on temporary copies that do not alias. + * radv_nir_lower_ray_payload_derefs will then convert the aliased variables to proper payload loads/stores, which + * later get lowered to function call parameters by `lower_rt_storage`. + */ + NIR_PASS(_, nir, lower_rt_derefs_functions); + NIR_PASS(_, nir, nir_split_var_copies); + NIR_PASS(_, nir, nir_lower_var_copies); + NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, 0); +} + +nir_function_impl * +radv_get_rt_shader_entrypoint(nir_shader *shader) +{ + nir_foreach_function_impl (impl, shader) + if (impl->function->is_entrypoint || impl->function->is_exported) + return impl; + return NULL; +} + +void +radv_nir_lower_rt_abi_functions(nir_shader *shader, const struct radv_shader_info *info, uint32_t payload_size, + struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline) +{ + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + nir_function *entrypoint_function = impl->function; + + radv_nir_init_rt_function_params(entrypoint_function, shader->info.stage, payload_size); + + struct rt_variables vars = create_rt_variables(shader, device, pipeline->base.base.create_flags, payload_size); + + nir_shader_instructions_pass(shader, lower_rt_instruction, nir_metadata_none, &vars); + + /* This can't use NIR_PASS because NIR_DEBUG=serialize,clone invalidates pointers. 
*/ + nir_lower_returns(shader); + + /* initialize variables */ + + nir_progress(true, impl, nir_metadata_none); + + /* cleanup passes */ + nir_builder b = nir_builder_at(nir_before_impl(impl)); + nir_deref_instr **payload_in_storage = + rzalloc_array_size(shader, sizeof(nir_deref_instr *), DIV_ROUND_UP(payload_size, 4)); + if (vars.in_payload_base_param != -1u) { + for (unsigned i = 0; i < DIV_ROUND_UP(payload_size, 4); ++i) { + payload_in_storage[i] = nir_build_deref_cast(&b, nir_load_param(&b, vars.in_payload_base_param + i), + nir_var_shader_call_data, glsl_uint_type(), 4); + } + } + + NIR_PASS(_, shader, radv_nir_lower_rt_storage, NULL, payload_in_storage, vars.out_payload_storage, info->wave_size); + NIR_PASS(_, shader, nir_remove_dead_derefs); + NIR_PASS(_, shader, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_call_data, NULL); + NIR_PASS(_, shader, nir_lower_global_vars_to_local); + NIR_PASS(_, shader, nir_lower_vars_to_ssa); +} diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_functions.h b/src/amd/vulkan/nir/radv_nir_rt_stage_functions.h new file mode 100644 index 00000000000..dbeb164a8aa --- /dev/null +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_functions.h @@ -0,0 +1,22 @@ +/* + * Copyright © 2025 Valve Corporation + * + * SPDX-License-Identifier: MIT + */ + +/* This file contains the public interface for all RT pipeline stage lowering. 
*/ + +#ifndef RADV_NIR_RT_STAGE_FUNCTIONS_H +#define RADV_NIR_RT_STAGE_FUNCTIONS_H + +#include "radv_pipeline_rt.h" + +nir_function_impl *radv_get_rt_shader_entrypoint(nir_shader *shader); + +void radv_nir_init_rt_function_params(nir_function *function, mesa_shader_stage stage, unsigned payload_size); + +void radv_nir_lower_rt_abi_functions(nir_shader *shader, const struct radv_shader_info *info, uint32_t payload_size, + struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline); +void radv_nir_lower_rt_io_functions(nir_shader *shader); + +#endif // RADV_NIR_RT_STAGE_FUNCTIONS_H diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.c b/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.c index cc449ecf775..f54a155be44 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.c +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.c @@ -8,8 +8,11 @@ #include "nir/radv_nir_rt_common.h" #include "nir/radv_nir_rt_stage_common.h" #include "nir/radv_nir_rt_stage_monolithic.h" + +#include "aco_nir_call_attribs.h" #include "nir_builder.h" #include "radv_device.h" +#include "radv_nir_rt_stage_functions.h" #include "radv_physical_device.h" struct chit_miss_inlining_params { @@ -20,6 +23,7 @@ struct chit_miss_inlining_params { struct radv_nir_sbt_data *sbt; unsigned payload_offset; + unsigned stack_base; }; struct chit_miss_inlining_vars { @@ -217,7 +221,13 @@ preprocess_shader_cb_monolithic(nir_shader *nir, void *_data) { uint32_t *payload_offset = _data; - NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, *payload_offset); + /* When compiling in monolithic mode, ray payloads are lowered to registers. Each ray payload gets a separate + * space in the register storage. Call nir_lower_vars_to_explicit_types to assign separate locations to payload + * variables, then lower to load/store instructions which will later be lowered to variable loads/stores. 
+ */ + nir_lower_vars_to_explicit_types(nir, nir_var_function_temp, glsl_get_natural_size_align_bytes); + nir_lower_vars_to_explicit_types(nir, nir_var_shader_temp, glsl_get_natural_size_align_bytes); + radv_nir_lower_ray_payload_derefs(nir, *payload_offset); } void @@ -234,10 +244,6 @@ struct rt_variables { uint32_t payload_offset; unsigned stack_size; - nir_def *launch_sizes[3]; - nir_def *launch_ids[3]; - nir_def *shader_record_ptr; - nir_variable *stack_ptr; }; @@ -255,19 +261,20 @@ radv_build_recursive_case(nir_builder *b, nir_def *idx, struct radv_ray_tracing_ .device = params->device, }; - struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL); - setup_chit_miss_inlining(&vars, params, b, shader, var_remap); - nir_opt_dead_cf(shader); preprocess_shader_cb_monolithic(shader, ¶ms->payload_offset); + struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL); + setup_chit_miss_inlining(&vars, params, b, shader, var_remap); + nir_shader_intrinsics_pass(shader, lower_rt_instruction_chit_miss, nir_metadata_control_flow, &vars); nir_lower_returns(shader); nir_opt_dce(shader); radv_nir_inline_constants(b->shader, shader); + b->shader->scratch_size = MAX2(b->shader->scratch_size, shader->scratch_size + params->stack_base); nir_push_if(b, nir_ieq_imm(b, idx, group->handle.general_index)); nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); @@ -323,7 +330,7 @@ lower_rt_call_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void *data) }; nir_def *stack_ptr = nir_load_var(b, vars->stack_ptr); - nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, b->shader->scratch_size), 0x1); + nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, vars->stack_size), 0x1); struct radv_nir_rt_traversal_result result = radv_build_traversal(state->device, state->pipeline, b, ¶ms, NULL); @@ -346,7 +353,8 @@ lower_rt_call_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void *data) nir_push_if(b, nir_load_var(b, 
result.hit)); { struct radv_nir_sbt_data hit_sbt = - radv_nir_load_sbt_entry(b, nir_load_var(b, result.sbt_index), SBT_HIT, SBT_CLOSEST_HIT_IDX); + radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), nir_load_var(b, result.sbt_index), + SBT_HIT, SBT_CLOSEST_HIT_IDX); inline_params.sbt = &hit_sbt; nir_def *should_return = nir_test_mask(b, params.cull_mask_and_flags, SpvRayFlagsSkipClosestHitShaderKHRMask); @@ -364,7 +372,8 @@ lower_rt_call_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void *data) } nir_push_else(b, NULL); { - struct radv_nir_sbt_data miss_sbt = radv_nir_load_sbt_entry(b, params.miss_index, SBT_MISS, SBT_GENERAL_IDX); + struct radv_nir_sbt_data miss_sbt = radv_nir_load_sbt_entry(b, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS), + params.miss_index, SBT_MISS, SBT_GENERAL_IDX); inline_params.sbt = &miss_sbt; radv_visit_inlined_shaders(b, miss_sbt.shader_addr, @@ -394,15 +403,19 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_intrinsic_instr *intr, void switch (intr->intrinsic) { case nir_intrinsic_load_shader_record_ptr: { - nir_def_replace(&intr->def, vars->shader_record_ptr); + nir_def_replace(&intr->def, nir_load_param(b, RAYGEN_ARG_SHADER_RECORD_PTR)); return true; } case nir_intrinsic_load_ray_launch_size: { - nir_def_replace(&intr->def, nir_vec(b, vars->launch_sizes, 3)); + nir_def_replace(&intr->def, nir_load_param(b, RT_ARG_LAUNCH_SIZE)); return true; }; case nir_intrinsic_load_ray_launch_id: { - nir_def_replace(&intr->def, nir_vec(b, vars->launch_ids, 3)); + nir_def_replace(&intr->def, nir_load_param(b, RT_ARG_LAUNCH_ID)); + return true; + } + case nir_intrinsic_load_sbt_base_amd: { + nir_def_replace(&intr->def, nir_load_param(b, RT_ARG_SBT_DESCRIPTORS)); return true; } case nir_intrinsic_load_scratch: { @@ -428,30 +441,36 @@ radv_count_hit_attrib_slots(nir_builder *b, nir_intrinsic_instr *instr, void *da return false; } +static bool +radv_count_ray_payload_size(nir_builder *b, nir_intrinsic_instr *instr, void 
*data) +{ + uint32_t *count = data; + if (instr->intrinsic == nir_intrinsic_load_incoming_ray_payload_amd || + instr->intrinsic == nir_intrinsic_load_outgoing_ray_payload_amd || + instr->intrinsic == nir_intrinsic_store_incoming_ray_payload_amd || + instr->intrinsic == nir_intrinsic_store_outgoing_ray_payload_amd) + *count = MAX2(*count, (nir_intrinsic_base(instr) + 1) * 4); + + return false; +} + void -radv_nir_lower_rt_abi_monolithic(nir_shader *shader, const struct radv_shader_args *args, uint32_t *stack_size, - struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline) +radv_nir_lower_rt_abi_monolithic(nir_shader *shader, struct radv_device *device, + struct radv_ray_tracing_pipeline *pipeline) { nir_function_impl *impl = nir_shader_get_entrypoint(shader); + radv_nir_init_rt_function_params(impl->function, MESA_SHADER_RAYGEN, 0); nir_builder b = nir_builder_at(nir_before_impl(impl)); struct rt_variables vars = { .device = device, .flags = pipeline->base.base.create_flags, + .stack_size = b.shader->scratch_size, }; vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr"); - - for (uint32_t i = 0; i < ARRAY_SIZE(vars.launch_sizes); i++) - vars.launch_sizes[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_sizes[i]); - for (uint32_t i = 0; i < ARRAY_SIZE(vars.launch_sizes); i++) { - vars.launch_ids[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_ids[i]); - } - nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record); - vars.shader_record_ptr = nir_pack_64_2x32(&b, record_ptr); - nir_def *stack_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base); - nir_store_var(&b, vars.stack_ptr, stack_ptr, 0x1); + nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1); struct lower_rt_instruction_monolithic_state state = { .device = device, @@ -464,17 +483,23 @@ radv_nir_lower_rt_abi_monolithic(nir_shader *shader, const struct radv_shader_ar 
nir_index_ssa_defs(impl); uint32_t hit_attrib_count = 0; + uint32_t payload_size = 0; nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count); + nir_shader_intrinsics_pass(shader, radv_count_ray_payload_size, nir_metadata_all, &payload_size); /* Register storage for hit attributes */ STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count); + STACK_ARRAY(nir_variable *, payload_vars, DIV_ROUND_UP(payload_size, 4)); + STACK_ARRAY(nir_deref_instr *, payload_derefs, DIV_ROUND_UP(payload_size, 4)); for (uint32_t i = 0; i < hit_attrib_count; i++) hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib"); - radv_nir_lower_hit_attribs(shader, hit_attribs, 0); + b.cursor = nir_before_impl(impl); + for (uint32_t i = 0; i < DIV_ROUND_UP(payload_size, 4); i++) { + payload_vars[i] = nir_local_variable_create(impl, glsl_uint_type(), "payload_storage"); + payload_derefs[i] = nir_build_deref_var(&b, payload_vars[i]); + } - vars.stack_size = MAX2(vars.stack_size, shader->scratch_size); - *stack_size = MAX2(*stack_size, vars.stack_size); - shader->scratch_size = 0; + radv_nir_lower_rt_storage(shader, hit_attribs, payload_derefs, payload_vars, 0); nir_progress(true, impl, nir_metadata_none); diff --git a/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.h b/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.h index 8a981ad540a..24d8ed2c97f 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.h +++ b/src/amd/vulkan/nir/radv_nir_rt_stage_monolithic.h @@ -11,8 +11,8 @@ #include "radv_pipeline_rt.h" -void radv_nir_lower_rt_abi_monolithic(nir_shader *shader, const struct radv_shader_args *args, uint32_t *stack_size, - struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline); +void radv_nir_lower_rt_abi_monolithic(nir_shader *shader, struct radv_device *device, + struct radv_ray_tracing_pipeline *pipeline); void radv_nir_lower_rt_io_monolithic(nir_shader *shader); #endif // 
RADV_NIR_RT_STAGE_MONOLITHIC_H diff --git a/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.c b/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.c index cd2775ce6cb..c68540224c3 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.c +++ b/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.c @@ -385,6 +385,7 @@ insert_inlined_shader(nir_builder *b, struct traversal_inlining_params *params, nir_opt_dce(shader); radv_nir_inline_constants(b->shader, shader); + b->shader->scratch_size = MAX2(b->shader->scratch_size, shader->scratch_size); nir_push_if(b, nir_ieq_imm(b, idx, call_idx)); nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); @@ -607,6 +608,9 @@ nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit) /* Eliminate the casts introduced for the commit return of the any-hit shader. */ NIR_PASS(_, intersection, nir_opt_deref); + /* Reflect the scratch memory required by the inlined any-hit shader */ + intersection->scratch_size += any_hit->scratch_size; + ralloc_free(dead_ctx); } @@ -666,9 +670,6 @@ radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g params->preprocess(any_hit_stage, params->preprocess_data); - /* reserve stack size for any_hit before it is inlined */ - data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size; - nir_lower_intersection_shader(nir_stage, any_hit_stage); ralloc_free(any_hit_stage); } @@ -857,7 +858,8 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int nir_push_if(b, nir_inot(b, intersection->base.opaque)); { - struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX); + struct radv_nir_sbt_data sbt_data = + radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX); nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1); struct traversal_inlining_params inlining_params = { @@ -942,7 +944,8 @@ 
handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio nir_store_var(b, data->trav_vars.ahit_isec_count, nir_iadd_imm(b, nir_load_var(b, data->trav_vars.ahit_isec_count), 1 << 16), 0x1); - struct radv_nir_sbt_data sbt_data = radv_nir_load_sbt_entry(b, sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX); + struct radv_nir_sbt_data sbt_data = + radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX); nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1); struct traversal_inlining_params inlining_params = { @@ -1120,22 +1123,22 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin if (device->rra_trace.ray_history_addr) radv_build_end_trace_token(b, &data, nir_load_var(b, iteration_instance_count)); - nir_progress(true, nir_shader_get_entrypoint(b->shader), nir_metadata_none); + nir_progress(true, b->impl, nir_metadata_none); radv_nir_lower_hit_attrib_derefs(b->shader); return data.trav_vars.result; } static void -preprocess_traversal_shader_ahit_isec(nir_shader *nir, void *_) +preprocess_traversal_shader_ahit_isec(nir_shader *nir, void *cb) { - /* Compiling a separate traversal shader is always done in CPS mode. 
*/ - radv_nir_lower_rt_io_cps(nir); + radv_nir_traversal_preprocess_cb preprocess_cb = cb; + preprocess_cb(nir); } nir_shader * radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, - struct radv_ray_tracing_stage_info *info) + struct radv_ray_tracing_stage_info *info, radv_nir_traversal_preprocess_cb preprocess) { const struct radv_physical_device *pdev = radv_device_physical(device); @@ -1184,11 +1187,12 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_ params.direction = nir_load_ray_world_direction(&b); params.preprocess_ahit_isec = preprocess_traversal_shader_ahit_isec; + params.cb_data = preprocess; params.ignore_cull_mask = false; struct radv_nir_rt_traversal_result result = radv_build_traversal(device, pipeline, &b, ¶ms, info); - radv_nir_lower_hit_attribs(b.shader, hit_attribs, pdev->rt_wave_size); + radv_nir_lower_rt_storage(b.shader, hit_attribs, NULL, NULL, pdev->rt_wave_size); nir_push_if(&b, nir_load_var(&b, result.hit)); { diff --git a/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.h b/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.h index a2218ffa0fd..362975eb866 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.h +++ b/src/amd/vulkan/nir/radv_nir_rt_traversal_shader.h @@ -9,7 +9,10 @@ #include "radv_pipeline_rt.h" +typedef void (*radv_nir_traversal_preprocess_cb)(nir_shader *nir); + nir_shader *radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, - struct radv_ray_tracing_stage_info *info); + struct radv_ray_tracing_stage_info *info, + radv_nir_traversal_preprocess_cb preprocess); #endif // RADV_NIR_RT_TRAVERSAL_SHADER_H diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index da129f007e6..35ad8912e12 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -8186,8 +8186,7 @@ radv_emit_ray_tracing_pipeline(struct radv_cmd_buffer *cmd_buffer, struct 
radv_r const uint32_t traversal_shader_addr_offset = radv_get_user_sgpr_loc(rt_prolog, AC_UD_CS_TRAVERSAL_SHADER_ADDR); struct radv_shader *traversal_shader = cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION]; if (traversal_shader_addr_offset && traversal_shader) { - uint64_t traversal_va = traversal_shader->va | radv_rt_priority_traversal; - + uint64_t traversal_va = traversal_shader->va; radeon_begin(cs); if (pdev->info.gfx_level >= GFX12) { gfx12_push_32bit_pointer(traversal_shader_addr_offset, traversal_va, &pdev->info); @@ -13740,10 +13739,9 @@ radv_emit_rt_stack_size(struct radv_cmd_buffer *cmd_buffer) unsigned rsrc2 = rt_prolog->config.rsrc2; /* Reserve scratch for stacks manually since it is not handled by the compute path. */ - uint32_t scratch_bytes_per_wave = rt_prolog->config.scratch_bytes_per_wave; const uint32_t wave_size = rt_prolog->info.wave_size; - scratch_bytes_per_wave += + uint32_t scratch_bytes_per_wave = align(cmd_buffer->state.rt_stack_size * wave_size, pdev->info.scratch_wavesize_granularity); cmd_buffer->compute_scratch_size_per_wave_needed = diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 8bd02f07757..373813f025e 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -13,6 +13,7 @@ #include "nir/radv_nir.h" #include "nir/radv_nir_rt_stage_cps.h" +#include "nir/radv_nir_rt_stage_functions.h" #include "nir/radv_nir_rt_stage_monolithic.h" #include "nir/radv_nir_rt_traversal_shader.h" #include "ac_nir.h" @@ -323,7 +324,6 @@ should_move_rt_instruction(nir_intrinsic_instr *instr) switch (instr->intrinsic) { case nir_intrinsic_load_hit_attrib_amd: return nir_intrinsic_base(instr) < RADV_MAX_HIT_ATTRIB_DWORDS; - case nir_intrinsic_load_rt_arg_scratch_offset_amd: case nir_intrinsic_load_ray_flags: case nir_intrinsic_load_ray_object_origin: case nir_intrinsic_load_ray_world_origin: @@ -364,7 +364,7 @@ move_rt_instructions(nir_shader *shader) static VkResult 
radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode, - struct radv_shader_stage *stage, uint32_t *stack_size, + struct radv_shader_stage *stage, uint32_t *payload_size, uint32_t *stack_size, struct radv_ray_tracing_stage_info *stage_info, const struct radv_ray_tracing_stage_info *traversal_stage_info, struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache, @@ -384,6 +384,9 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, case RADV_RT_LOWERING_MODE_CPS: radv_nir_lower_rt_io_cps(stage->nir); break; + case RADV_RT_LOWERING_MODE_FUNCTION_CALLS: + radv_nir_lower_rt_io_functions(stage->nir); + break; } /* Gather shader info. */ @@ -440,29 +443,39 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, switch (mode) { case RADV_RT_LOWERING_MODE_MONOLITHIC: assert(num_shaders == 1); - radv_nir_lower_rt_abi_monolithic(temp_stage.nir, &temp_stage.args, stack_size, device, pipeline); + radv_nir_lower_rt_abi_monolithic(temp_stage.nir, device, pipeline); break; case RADV_RT_LOWERING_MODE_CPS: - radv_nir_lower_rt_abi_cps(temp_stage.nir, &temp_stage.args, &stage->info, stack_size, i > 0, device, pipeline, - has_position_fetch, traversal_stage_info); + radv_nir_lower_rt_abi_cps(temp_stage.nir, &stage->info, i > 0, device, pipeline, has_position_fetch, + traversal_stage_info); + break; + case RADV_RT_LOWERING_MODE_FUNCTION_CALLS: + assert(num_shaders == 1); + radv_nir_lower_rt_abi_functions(temp_stage.nir, &temp_stage.info, *payload_size, device, pipeline); break; } /* Info might be out-of-date after inlining in radv_nir_lower_rt_abi(). 
*/ - nir_shader_gather_info(temp_stage.nir, nir_shader_get_entrypoint(temp_stage.nir)); + nir_shader_gather_info(temp_stage.nir, radv_get_rt_shader_entrypoint(temp_stage.nir)); radv_nir_shader_info_pass(device, temp_stage.nir, &stage->layout, &stage->key, NULL, RADV_PIPELINE_RAY_TRACING, false, &stage->info); - radv_optimize_nir(temp_stage.nir, stage->key.optimisations_disabled); + radv_optimize_nir(temp_stage.nir, temp_stage.key.optimisations_disabled); radv_postprocess_nir(device, NULL, &temp_stage); - stage->info.nir_shared_size = MAX2(stage->info.nir_shared_size, temp_stage.info.nir_shared_size); - if (stage_info) - radv_gather_unused_args(stage_info, shaders[i]); + NIR_PASS(_, stage->nir, radv_nir_lower_call_abi, stage->info.wave_size); + NIR_PASS(_, stage->nir, nir_lower_global_vars_to_local); + NIR_PASS(_, stage->nir, nir_lower_vars_to_ssa); + if (!stage->key.optimisations_disabled) + NIR_PASS(_, stage->nir, nir_minimize_call_live_states); + + stage->info.nir_shared_size = MAX2(stage->info.nir_shared_size, temp_stage.info.nir_shared_size); + if (stage_info && mode == RADV_RT_LOWERING_MODE_CPS) + radv_gather_unused_args(stage_info, temp_stage.nir); } - bool dump_shader = radv_can_dump_shader(device, shaders[0]); + bool dump_shader = radv_can_dump_shader(device, stage->nir); bool dump_nir = dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR); bool replayable = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) && @@ -493,7 +506,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, if (dump_shader) simple_mtx_unlock(&instance->shader_dump_mtx); - ralloc_free(mem_ctx); free(binary); return result; } @@ -503,6 +515,9 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, if (shader) { shader->nir_string = nir_string; + if (stack_size) + *stack_size = DIV_ROUND_UP(shader->config.scratch_bytes_per_wave, shader->info.wave_size); + 
radv_shader_dump_debug_info(device, dump_shader, binary, shader, shaders, num_shaders, &stage->info); if (shader && keep_executable_info && stage->spirv.size) { @@ -515,7 +530,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, if (dump_shader) simple_mtx_unlock(&instance->shader_dump_mtx); - ralloc_free(mem_ctx); free(binary); *out_shader = shader; @@ -633,6 +647,10 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca if (!stages) return VK_ERROR_OUT_OF_HOST_MEMORY; + uint32_t payload_size = 0; + if (pCreateInfo->pLibraryInterface) + payload_size = pCreateInfo->pLibraryInterface->maxPipelineRayPayloadSize; + bool library = pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR; /* Beyond 50 shader stages, inlining everything bloats the shader a ton, increasing compile times and @@ -657,6 +675,19 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs); + nir_foreach_variable_with_modes (var, stage->nir, nir_var_shader_call_data) { + unsigned size, alignment; + glsl_get_natural_size_align_bytes(var->type, &size, &alignment); + payload_size = MAX2(payload_size, size); + } + nir_foreach_function_impl (impl, stage->nir) { + nir_foreach_variable_in_list (var, &impl->locals) { + unsigned size, alignment; + glsl_get_natural_size_align_bytes(var->type, &size, &alignment); + payload_size = MAX2(payload_size, size); + } + } + rt_stages[i].info = radv_gather_ray_tracing_stage_info(stage->nir); stage->feedback.duration = os_time_get_nano() - stage_start; @@ -736,8 +767,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca enum radv_rt_lowering_mode mode = stage->stage == MESA_SHADER_RAYGEN ? 
raygen_lowering_mode : recursive_lowering_mode; - result = radv_rt_nir_to_asm(device, cache, pipeline, mode, stage, &stack_size, &rt_stages[idx].info, NULL, - replay_block, skip_shaders_cache, has_position_fetch, &rt_stages[idx].shader); + result = + radv_rt_nir_to_asm(device, cache, pipeline, mode, stage, &payload_size, &stack_size, &rt_stages[idx].info, + NULL, replay_block, skip_shaders_cache, has_position_fetch, &rt_stages[idx].shader); if (result != VK_SUCCESS) goto cleanup; @@ -787,17 +819,20 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca traversal_info.unset_flags &= info->unset_flags; } + radv_nir_traversal_preprocess_cb preprocess = + recursive_lowering_mode == RADV_RT_LOWERING_MODE_CPS ? radv_nir_lower_rt_io_cps : radv_nir_lower_rt_io_functions; + /* create traversal shader */ - nir_shader *traversal_nir = radv_build_traversal_shader(device, pipeline, &traversal_info); + nir_shader *traversal_nir = radv_build_traversal_shader(device, pipeline, &traversal_info, preprocess); struct radv_shader_stage traversal_stage = { .stage = MESA_SHADER_INTERSECTION, .nir = traversal_nir, .key = stage_keys[MESA_SHADER_INTERSECTION], }; radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout); - result = radv_rt_nir_to_asm(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, NULL, NULL, - &traversal_info, NULL, skip_shaders_cache, has_position_fetch, - &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]); + result = radv_rt_nir_to_asm(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, &payload_size, + &pipeline->traversal_stack_size, NULL, &traversal_info, NULL, skip_shaders_cache, + has_position_fetch, &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]); ralloc_free(traversal_nir); cleanup: @@ -858,10 +893,11 @@ compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, stru UNREACHABLE("Invalid stage type in RT shader"); } } - 
pipeline->stack_size = - raygen_size + - MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) * MAX2(chit_miss_size, intersection_size + any_hit_size) + - MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * chit_miss_size + 2 * callable_size; + pipeline->stack_size = raygen_size + + MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) * + (chit_miss_size + intersection_size + any_hit_size + pipeline->traversal_stack_size) + + MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * chit_miss_size + + 2 * callable_size; } static void @@ -1216,7 +1252,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, const VkRayTra if (pipeline->groups[i].recursive_shader != VK_SHADER_UNUSED_KHR) { struct radv_shader *shader = pipeline->stages[pipeline->groups[i].recursive_shader].shader; if (shader) - pipeline->groups[i].handle.recursive_shader_ptr = shader->va | radv_get_rt_priority(shader->info.stage); + pipeline->groups[i].handle.recursive_shader_ptr = shader->va; } } diff --git a/src/amd/vulkan/radv_pipeline_rt.h b/src/amd/vulkan/radv_pipeline_rt.h index 787ead3eb3f..00785efb4fd 100644 --- a/src/amd/vulkan/radv_pipeline_rt.h +++ b/src/amd/vulkan/radv_pipeline_rt.h @@ -12,6 +12,7 @@ #define RADV_PIPELINE_RT_H #include "util/bitset.h" +#include "aco_nir_call_attribs.h" #include "radv_pipeline_compute.h" #include "radv_shader.h" @@ -27,6 +28,7 @@ struct radv_ray_tracing_pipeline { unsigned group_count; uint32_t stack_size; + uint32_t traversal_stack_size; /* set if any shaders from this pipeline require robustness2 in the merged traversal shader */ bool traversal_storage_robustness2 : 1; @@ -76,7 +78,7 @@ struct radv_ray_tracing_stage_info { bool can_inline; bool has_position_fetch; - BITSET_DECLARE(unused_args, AC_MAX_ARGS); + BITSET_DECLARE(unused_args, CPS_ARG_COUNT); struct radv_rt_const_arg_info tmin; struct radv_rt_const_arg_info tmax; diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 9d4bb3422a6..57244adc58b 
100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -45,6 +45,7 @@ #include "vk_sync.h" #include "vk_ycbcr_conversion.h" +#include "nir/radv_nir_rt_stage_functions.h" #include "aco_shader_info.h" #include "radv_aco_shader_info.h" #if AMD_LLVM_AVAILABLE @@ -220,7 +221,9 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively) NIR_PASS(progress, shader, nir_opt_move, nir_move_load_ubo); - nir_shader_gather_info(shader, nir_shader_get_entrypoint(shader)); + /* radv_get_rt_shader_entrypoint returns the entrypoint for non-RT shaders too. */ + nir_function_impl *entrypoint = radv_get_rt_shader_entrypoint(shader); + nir_shader_gather_info(shader, entrypoint); } void diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h index ca7ef199a7b..d9bd47a4408 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -654,35 +654,9 @@ void radv_get_nir_options(struct radv_physical_device *pdev); enum radv_rt_lowering_mode { RADV_RT_LOWERING_MODE_MONOLITHIC, RADV_RT_LOWERING_MODE_CPS, + RADV_RT_LOWERING_MODE_FUNCTION_CALLS, }; -enum radv_rt_priority { - radv_rt_priority_raygen = 0, - radv_rt_priority_traversal = 1, - radv_rt_priority_hit_miss = 2, - radv_rt_priority_callable = 3, - radv_rt_priority_mask = 0x3, -}; - -static inline enum radv_rt_priority -radv_get_rt_priority(mesa_shader_stage stage) -{ - switch (stage) { - case MESA_SHADER_RAYGEN: - return radv_rt_priority_raygen; - case MESA_SHADER_INTERSECTION: - case MESA_SHADER_ANY_HIT: - return radv_rt_priority_traversal; - case MESA_SHADER_CLOSEST_HIT: - case MESA_SHADER_MISS: - return radv_rt_priority_hit_miss; - case MESA_SHADER_CALLABLE: - return radv_rt_priority_callable; - default: - UNREACHABLE("Unimplemented RT shader stage."); - } -} - struct radv_shader_layout; enum radv_pipeline_type; diff --git a/src/amd/vulkan/radv_shader_args.c b/src/amd/vulkan/radv_shader_args.c index b65cf086979..8588b59141a 100644 --- 
a/src/amd/vulkan/radv_shader_args.c +++ b/src/amd/vulkan/radv_shader_args.c @@ -320,50 +320,6 @@ radv_init_shader_args(const struct radv_device *device, mesa_shader_stage stage, args->user_sgprs_locs.shader_data[i].sgpr_idx = -1; } -void -radv_declare_rt_shader_args(enum amd_gfx_level gfx_level, struct radv_shader_args *args) -{ - add_ud_arg(args, 2, AC_ARG_CONST_ADDR, &args->ac.rt.uniform_shader_addr, AC_UD_SCRATCH_RING_OFFSETS); - add_ud_arg(args, 1, AC_ARG_CONST_ADDR, &args->descriptors[0], AC_UD_INDIRECT_DESCRIPTORS); - ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_ADDR, &args->ac.push_constants); - ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_ADDR, &args->ac.dynamic_descriptors); - ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_ADDR, &args->ac.rt.traversal_shader_addr); - ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.sbt_descriptors); - - for (uint32_t i = 0; i < ARRAY_SIZE(args->ac.rt.launch_sizes); i++) - ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_VALUE, &args->ac.rt.launch_sizes[i]); - - if (gfx_level < GFX9) { - ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_VALUE, &args->ac.scratch_offset); - ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_ADDR, &args->ac.ring_offsets); - } - - for (uint32_t i = 0; i < ARRAY_SIZE(args->ac.rt.launch_ids); i++) - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.launch_ids[i]); - - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.dynamic_callable_stack_base); - ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.shader_addr); - ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.shader_record); - - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.payload_offset); - ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_VALUE, &args->ac.rt.ray_origin); - ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_VALUE, &args->ac.rt.ray_direction); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.ray_tmin); - 
ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.ray_tmax); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.cull_mask_and_flags); - - ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.accel_struct); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.sbt_offset); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.sbt_stride); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.miss_index); - - ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.instance_addr); - ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_ADDR, &args->ac.rt.primitive_addr); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.primitive_id); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.geometry_id_and_flags); - ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_VALUE, &args->ac.rt.hit_kind); -} - static bool radv_tcs_needs_state_sgpr(const struct radv_shader_info *info, const struct radv_graphics_state_key *gfx_state) { @@ -563,7 +519,6 @@ declare_shader_args(const struct radv_device *device, const struct radv_graphics radv_init_shader_args(device, stage, args); if (mesa_shader_stage_is_rt(stage)) { - radv_declare_rt_shader_args(gfx_level, args); return; } diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c index 6075f9efd65..423bb2bf96d 100644 --- a/src/compiler/nir/nir_divergence_analysis.c +++ b/src/compiler/nir/nir_divergence_analysis.c @@ -950,7 +950,6 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state) case nir_intrinsic_load_packed_passthrough_primitive_amd: case nir_intrinsic_load_initial_edgeflags_amd: case nir_intrinsic_gds_atomic_add_amd: - case nir_intrinsic_load_rt_arg_scratch_offset_amd: case nir_intrinsic_load_intersection_opaque_amd: case nir_intrinsic_load_vector_arg_amd: case nir_intrinsic_load_btd_stack_id_intel: @@ -1024,6 +1023,7 @@ 
visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state) case nir_intrinsic_load_blend_input_pan: case nir_intrinsic_atest_pan: case nir_intrinsic_zs_emit_pan: + case nir_intrinsic_load_return_param_amd: is_divergent = true; break; diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 67f6df73abf..9911dab3c4d 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -2152,6 +2152,13 @@ intrinsic("sleep_amd", indices=[BASE]) # s_nop BASE (sleep for BASE+1 cycles, BASE must be in [0, 15]). intrinsic("nop_amd", indices=[BASE]) +intrinsic("store_param_amd", src_comp=[-1], indices=[PARAM_IDX]) +intrinsic("load_return_param_amd", dest_comp=0, indices=[CALL_IDX, PARAM_IDX]) + +system_value("call_return_address_amd", 1, bit_sizes=[64]) +# src[0] is the divergent call target for each lane, src[1] is the (uniform) address to jump to next +intrinsic("set_next_call_pc_amd", src_comp=[1, 1], bit_sizes=[64]) + # Return the FMASK descriptor of color buffer 0. system_value("fbfetch_image_fmask_desc_amd", 8) # Return the image descriptor of color buffer 0.