radv/nir: add a struct for parameters to cooperative matrix lowering

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33378>
This commit is contained in:
Samuel Pitoiset 2025-01-22 03:46:01 -08:00 committed by Marge Bot
parent baa09cb94a
commit ad611adeb7

View file

@ -7,12 +7,16 @@
#include "nir_builder.h"
#include "radv_nir.h"
/* Parameters handed to the cooperative matrix lowering helpers in this file. */
typedef struct {
/* Wave size; radv_nir_cmat_length() divides cols * rows by it when sizing accumulator matrices. */
unsigned wave_size;
} lower_cmat_params;
/* Returns the number of vector components used to hold a cooperative matrix described by `desc`:
 * 16 for every use other than GLSL_CMAT_USE_ACCUMULATOR, otherwise
 * cols * rows / wave_size * 32 / element bit size, evaluated left to right in unsigned integer arithmetic.
 *
 * NOTE(review): this text is a diff rendering without +/- markers, so two signature lines and two accumulator
 * expressions appear back to back. The `unsigned wave_size` forms look like the removed lines and the
 * `params` / `params->wave_size` forms like their replacements -- confirm against the commit.
 */
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
return desc.use != GLSL_CMAT_USE_ACCUMULATOR
? 16
: (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
}
/* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement
@ -32,28 +36,29 @@ radv_nir_cmat_bits(struct glsl_cmat_description desc)
}
/* Builds a deref load of the cooperative matrix referenced by `src` and returns the loaded value.
 *
 * `src` is treated as the result of a deref instruction: its parent instruction is converted with
 * nir_instr_as_deref() and the matrix description is read from that deref's type. The load uses
 * radv_nir_cmat_length() components, each as wide as the matrix element type.
 *
 * NOTE(review): this text is a diff rendering without +/- markers; the `unsigned wave_size` signature and the
 * first `return` statement look like the removed lines, with the `params` variants replacing them -- confirm
 * against the commit.
 */
static nir_def *
radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
radv_nir_load_cmat(nir_builder *b, const lower_cmat_params *params, nir_def *src)
{
/* The deref is needed only for its type, which carries the cooperative matrix description. */
nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
/* NOTE(review): the trailing 0 is presumably the access-flags argument (no special access) -- confirm. */
return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
src, 0);
return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), glsl_base_type_bit_size(desc.element_type), src,
0);
}
static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map,
const lower_cmat_params *params)
{
struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
if (entry) {
return entry->data;
} else if (glsl_type_is_cmat(orig_type)) {
struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
unsigned length = radv_nir_cmat_length(desc, wave_size);
unsigned length = radv_nir_cmat_length(desc, params);
return glsl_vector_type(desc.element_type, length);
} else if (glsl_type_is_array(orig_type)) {
const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);
const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, params);
if (elem_type == new_elem_type)
return orig_type;
@ -65,7 +70,7 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
bool change = false;
for (unsigned i = 0; i < num_fields; ++i) {
const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);
const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, params);
if (field_type != new_field_type) {
change = true;
@ -81,7 +86,7 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
for (unsigned i = 0; i < num_fields; ++i) {
fields[i] = *glsl_get_struct_field_data(orig_type, i);
fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, params);
}
const struct glsl_type *ret =
@ -102,11 +107,15 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
if (!shader->info.cs.has_cooperative_matrix)
return false;
const lower_cmat_params params = {
.wave_size = wave_size,
};
struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);
nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, &params);
if (new_type != var->type) {
var->type = new_type;
progress = true;
@ -114,7 +123,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
}
nir_foreach_function_temp_variable (var, func->impl) {
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, &params);
if (new_type != var->type) {
var->type = new_type;
progress = true;
@ -134,7 +143,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
switch (intr->intrinsic) {
case nir_intrinsic_cmat_length: {
struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc);
nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
nir_instr_remove(instr);
progress = true;
@ -143,7 +152,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
case nir_intrinsic_cmat_extract: {
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);
nir_def *src0 = radv_nir_load_cmat(&b, &params, intr->src[0].ssa);
nir_def *index = intr->src[1].ssa;
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
@ -156,7 +165,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
break;
}
case nir_intrinsic_cmat_insert: {
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
nir_def *index = intr->src[3].ssa;
@ -174,7 +183,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
nir_def *elem = intr->src[1].ssa;
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, &params));
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
nir_instr_remove(instr);
@ -197,9 +206,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
unsigned length = radv_nir_cmat_length(desc, wave_size);
unsigned length = radv_nir_cmat_length(desc, &params);
unsigned mul = radv_nir_cmat_length_mul(desc);
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
if (mul > 1) {
for (unsigned i = 0; i < length; ++i)
@ -250,7 +259,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
src = radv_nir_load_cmat(&b, wave_size, src);
src = radv_nir_load_cmat(&b, &params, src);
nir_def *local_idx = nir_load_subgroup_invocation(&b);
@ -264,9 +273,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
unsigned length = radv_nir_cmat_length(desc, wave_size);
unsigned length = radv_nir_cmat_length(desc, &params);
unsigned mul = radv_nir_cmat_length_mul(desc);
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
for (unsigned i = 0; i < length; ++i)
vars[i] = nir_channel(&b, src, i);
@ -307,9 +316,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
break;
}
case nir_intrinsic_cmat_muladd: {
nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
nir_def *A = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_def *B = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
nir_def *C = radv_nir_load_cmat(&b, &params, intr->src[3].ssa);
nir_def *ret;
ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
@ -326,7 +335,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
@ -356,7 +365,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
break;
}
case nir_intrinsic_cmat_scalar_op: {
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
@ -366,8 +375,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
break;
}
case nir_intrinsic_cmat_binary_op: {
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_def *src2 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, src2);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
@ -377,7 +386,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
break;
}
case nir_intrinsic_cmat_bitcast: {
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
nir_component_mask(src1->num_components));
nir_instr_remove(instr);
@ -397,7 +406,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
}
case nir_instr_type_deref: {
nir_deref_instr *deref = nir_instr_as_deref(instr);
const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, &params);
if (new_type != deref->type) {
deref->type = new_type;
progress = true;