diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.c b/src/gallium/drivers/d3d12/d3d12_nir_passes.c index c2cc4d70473..85a79d14575 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.c +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.c @@ -36,12 +36,12 @@ * so we need to lower the flip into the NIR shader. */ -static nir_ssa_def * -get_state_var(nir_builder *b, - enum d3d12_state_var var_enum, - const char *var_name, - const struct glsl_type *var_type, - nir_variable **out_var) +nir_ssa_def * +d3d12_get_state_var(nir_builder *b, + enum d3d12_state_var var_enum, + const char *var_name, + const struct glsl_type *var_type, + nir_variable **out_var) { const gl_state_index16 tokens[STATE_LENGTH] = { STATE_INTERNAL_DRIVER, var_enum }; if (*out_var == NULL) { @@ -79,8 +79,8 @@ lower_pos_write(nir_builder *b, struct nir_instr *instr, nir_variable **flip) b->cursor = nir_before_instr(&intr->instr); nir_ssa_def *pos = nir_ssa_for_src(b, intr->src[1], 4); - nir_ssa_def *flip_y = get_state_var(b, D3D12_STATE_VAR_Y_FLIP, "d3d12_FlipY", - glsl_float_type(), flip); + nir_ssa_def *flip_y = d3d12_get_state_var(b, D3D12_STATE_VAR_Y_FLIP, "d3d12_FlipY", + glsl_float_type(), flip); nir_ssa_def *def = nir_vec4(b, nir_channel(b, pos, 0), nir_fmul(b, nir_channel(b, pos, 1), flip_y), @@ -184,10 +184,10 @@ lower_pos_read(nir_builder *b, struct nir_instr *instr, nir_ssa_def *depth = nir_channel(b, pos, 2); assert(depth_transform_var); - nir_ssa_def *depth_transform = get_state_var(b, D3D12_STATE_VAR_DEPTH_TRANSFORM, - "d3d12_DepthTransform", - glsl_vec_type(2), - depth_transform_var); + nir_ssa_def *depth_transform = d3d12_get_state_var(b, D3D12_STATE_VAR_DEPTH_TRANSFORM, + "d3d12_DepthTransform", + glsl_vec_type(2), + depth_transform_var); depth = nir_fmad(b, depth, nir_channel(b, depth_transform, 0), nir_channel(b, depth_transform, 1)); @@ -236,7 +236,7 @@ lower_compute_state_vars(nir_builder *b, nir_instr *instr, void *_state) nir_ssa_def *result = NULL; switch (intr->intrinsic) { case nir_intrinsic_load_num_workgroups: - result = get_state_var(b, D3D12_STATE_VAR_NUM_WORKGROUPS, "d3d12_NumWorkgroups", + result = d3d12_get_state_var(b, D3D12_STATE_VAR_NUM_WORKGROUPS, "d3d12_NumWorkgroups", glsl_vec_type(3), &vars->num_workgroups); break; default: @@ -330,8 +330,8 @@ lower_load_first_vertex(nir_builder *b, nir_instr *instr, nir_variable **first_v b->cursor = nir_before_instr(&intr->instr); - nir_ssa_def *load = get_state_var(b, D3D12_STATE_VAR_FIRST_VERTEX, "d3d12_FirstVertex", - glsl_uint_type(), first_vertex); + nir_ssa_def *load = d3d12_get_state_var(b, D3D12_STATE_VAR_FIRST_VERTEX, "d3d12_FirstVertex", + glsl_uint_type(), first_vertex); nir_ssa_def_rewrite_uses(&intr->dest.ssa, load); nir_instr_remove(instr); diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.h b/src/gallium/drivers/d3d12/d3d12_nir_passes.h index 03f7454572f..0c81d2e262b 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.h +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.h @@ -25,6 +25,7 @@ #define D3D12_NIR_PASSES_H #include "nir.h" +#include "nir_builder.h" #ifdef __cplusplus extern "C" { @@ -32,6 +33,14 @@ extern "C" { struct d3d12_shader; struct d3d12_image_format_conversion_info; +enum d3d12_state_var; + +nir_ssa_def * +d3d12_get_state_var(nir_builder *b, + enum d3d12_state_var var_enum, + const char *var_name, + const struct glsl_type *var_type, + nir_variable **out_var); bool d3d12_lower_point_sprite(nir_shader *shader,