radv/nir/lower_cmat: use nir_src_as_deref

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35633>
This commit is contained in:
Georg Lehmann 2025-06-24 13:51:20 +02:00 committed by Marge Bot
parent 48fc8c8d1c
commit 21523dad96

View file

@ -378,7 +378,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
break;
}
case nir_intrinsic_cmat_extract: {
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
nir_deref_instr *src_deref = nir_src_as_deref(intr->src[0]);
struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
nir_def *src0 = radv_nir_load_cmat(&b, &params, intr->src[0].ssa);
@ -394,7 +394,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
}
case nir_intrinsic_cmat_insert: {
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
nir_def *index = intr->src[3].ssa;
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));
@ -407,7 +407,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
break;
}
case nir_intrinsic_cmat_construct: {
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
nir_def *elem = intr->src[1].ssa;
@ -422,11 +422,11 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
case nir_intrinsic_cmat_store: {
const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load;
nir_deref_instr *cmat_deref = nir_instr_as_deref(intr->src[!is_load].ssa->parent_instr);
nir_deref_instr *cmat_deref = nir_src_as_deref(intr->src[!is_load]);
struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type);
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
nir_deref_instr *deref = nir_instr_as_deref(intr->src[is_load].ssa->parent_instr);
nir_deref_instr *deref = nir_src_as_deref(intr->src[is_load]);
nir_def *stride = intr->src[2].ssa;
nir_def *local_idx = nir_load_subgroup_invocation(&b);
@ -555,8 +555,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
}
case nir_intrinsic_cmat_transpose:
case nir_intrinsic_cmat_convert: {
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
nir_deref_instr *src_deref = nir_src_as_deref(intr->src[1]);
struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_deref->type);
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
@ -630,8 +630,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu1(&b, op, src);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_store_deref(&b, nir_src_as_deref(intr->src[0]), ret, nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@ -640,8 +639,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_store_deref(&b, nir_src_as_deref(intr->src[0]), ret, nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@ -651,16 +649,14 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_def *src2 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, src2);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_store_deref(&b, nir_src_as_deref(intr->src[0]), ret, nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
}
case nir_intrinsic_cmat_bitcast: {
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
nir_component_mask(src1->num_components));
nir_store_deref(&b, nir_src_as_deref(intr->src[0]), src1, nir_component_mask(src1->num_components));
nir_instr_remove(instr);
progress = true;
break;