diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index e43062fbacc..ef77bba4d0e 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -1944,7 +1944,7 @@ typedef struct { unsigned _pad:7; } nir_io_semantics; -#define NIR_INTRINSIC_MAX_INPUTS 5 +#define NIR_INTRINSIC_MAX_INPUTS 11 typedef struct { const char *name; diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index d93b551c785..df9f22f0607 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -373,6 +373,28 @@ intrinsic("end_primitive_with_counter", src_comp=[1, 1], indices=[STREAM_ID]) # Contains the final total vertex and primitive counts in the current GS thread. intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID]) +# Trace a ray through an acceleration structure +# +# This instruction has a lot of parameters: +# 0. Acceleration Structure +# 1. Ray Flags +# 2. Cull Mask +# 3. SBT Offset +# 4. SBT Stride +# 5. Miss shader index +# 6. Ray Origin +# 7. Ray Tmin +# 8. Ray Direction +# 9. Ray Tmax +# 10. Payload +intrinsic("trace_ray", src_comp=[-1, 1, 1, 1, 1, 1, 3, 1, 3, 1, -1]) +# src[] = { hit_t, hit_kind } +intrinsic("report_ray_intersection", src_comp=[1, 1], dest_comp=1) +intrinsic("ignore_ray_intersection") +intrinsic("terminate_ray") +# src[] = { sbt_index, payload } +intrinsic("execute_callable", src_comp=[1, -1]) + # Atomic counters # # The *_var variants take an atomic_uint nir_variable, while the other, diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index e0eda4ccb3d..9c92bd8c0e2 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -5067,6 +5067,65 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode, vtn_push_nir_ssa(b, w[2], def); } +static void +vtn_handle_ray_intrinsic(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count) +{ + nir_intrinsic_instr *intrin; + + switch (opcode) { + case SpvOpTraceRayKHR: { + intrin = nir_intrinsic_instr_create(b->nb.shader, + nir_intrinsic_trace_ray); + + /* The sources are in the same order in the NIR intrinsic */ + for (unsigned i = 0; i < 10; i++) + intrin->src[i] = nir_src_for_ssa(vtn_ssa_value(b, w[i + 1])->def); + + nir_deref_instr *payload = vtn_get_call_payload_for_location(b, w[11]); + intrin->src[10] = nir_src_for_ssa(&payload->dest.ssa); + nir_builder_instr_insert(&b->nb, &intrin->instr); + break; + } + + case SpvOpReportIntersectionKHR: { + intrin = nir_intrinsic_instr_create(b->nb.shader, + nir_intrinsic_report_ray_intersection); + intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def); + intrin->src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def); + nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 1, NULL); + nir_builder_instr_insert(&b->nb, &intrin->instr); + vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa); + break; + } + + case SpvOpIgnoreIntersectionKHR: + intrin = nir_intrinsic_instr_create(b->nb.shader, + nir_intrinsic_ignore_ray_intersection); + nir_builder_instr_insert(&b->nb, &intrin->instr); + break; + + case SpvOpTerminateRayKHR: + intrin = nir_intrinsic_instr_create(b->nb.shader, + nir_intrinsic_terminate_ray); + nir_builder_instr_insert(&b->nb, &intrin->instr); + break; + + case SpvOpExecuteCallableKHR: { + intrin = nir_intrinsic_instr_create(b->nb.shader, + nir_intrinsic_execute_callable); + intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[1])->def); + nir_deref_instr *payload = vtn_get_call_payload_for_location(b, w[2]); + intrin->src[1] = nir_src_for_ssa(&payload->dest.ssa); + nir_builder_instr_insert(&b->nb, &intrin->instr); + break; + } + + default: + vtn_fail_with_opcode("Unhandled opcode", opcode); + } +} + static bool vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -5476,6 +5535,14 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpTraceRayKHR: + case SpvOpReportIntersectionKHR: + case SpvOpIgnoreIntersectionKHR: + case SpvOpTerminateRayKHR: + case SpvOpExecuteCallableKHR: + vtn_handle_ray_intrinsic(b, opcode, w, count); + break; + case SpvOpLifetimeStart: case SpvOpLifetimeStop: break; diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 5dec9af1c72..45187a092c7 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -844,6 +844,9 @@ nir_ssa_def * vtn_pointer_to_offset(struct vtn_builder *b, struct vtn_pointer *ptr, nir_ssa_def **index_out); +nir_deref_instr * +vtn_get_call_payload_for_location(struct vtn_builder *b, uint32_t location_id); + struct vtn_ssa_value * vtn_local_load(struct vtn_builder *b, nir_deref_instr *src, enum gl_access_qualifier access); diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index 191192c727f..168d0e5156f 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -1716,6 +1716,18 @@ assign_missing_member_locations(struct vtn_variable *var) } } +nir_deref_instr * +vtn_get_call_payload_for_location(struct vtn_builder *b, uint32_t location_id) +{ + uint32_t location = vtn_constant_uint(b, location_id); + nir_foreach_variable_with_modes(var, b->nb.shader, nir_var_shader_temp) { + if (var->data.explicit_location && + var->data.location == location) + return nir_build_deref_var(&b->nb, var); + } + vtn_fail("Couldn't find variable with a storage class of CallableDataKHR " + "or RayPayloadKHR and location %d", location); +} static void vtn_create_variable(struct vtn_builder *b, struct vtn_value *val, @@ -1813,6 +1825,14 @@ vtn_create_variable(struct vtn_builder *b, struct vtn_value *val, var->var = rzalloc(b->shader, nir_variable); var->var->name = ralloc_strdup(var->var, val->name); var->var->type = vtn_type_get_nir_type(b, var->type, var->mode); + + /* This is a total hack but we need some way to flag variables which are + * going to be call payloads. See get_call_payload_deref. + */ + if (storage_class == SpvStorageClassCallableDataKHR || + storage_class == SpvStorageClassRayPayloadKHR) + var->var->data.explicit_location = true; + var->var->data.mode = nir_mode; var->var->data.location = -1; var->var->interface_type = NULL;