mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-09 08:58:02 +02:00
radv/nir: add a struct for parameters to cooperative matrix lowering
Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33378>
This commit is contained in:
parent
baa09cb94a
commit
ad611adeb7
1 changed files with 39 additions and 30 deletions
|
|
@ -7,12 +7,16 @@
|
|||
#include "nir_builder.h"
|
||||
#include "radv_nir.h"
|
||||
|
||||
typedef struct {
|
||||
unsigned wave_size;
|
||||
} lower_cmat_params;
|
||||
|
||||
static unsigned
|
||||
radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
|
||||
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
|
||||
{
|
||||
return desc.use != GLSL_CMAT_USE_ACCUMULATOR
|
||||
? 16
|
||||
: (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
|
||||
: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
|
||||
}
|
||||
|
||||
/* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement
|
||||
|
|
@ -32,28 +36,29 @@ radv_nir_cmat_bits(struct glsl_cmat_description desc)
|
|||
}
|
||||
|
||||
static nir_def *
|
||||
radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
|
||||
radv_nir_load_cmat(nir_builder *b, const lower_cmat_params *params, nir_def *src)
|
||||
{
|
||||
nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
|
||||
return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
|
||||
src, 0);
|
||||
return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), glsl_base_type_bit_size(desc.element_type), src,
|
||||
0);
|
||||
}
|
||||
|
||||
static const struct glsl_type *
|
||||
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
|
||||
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map,
|
||||
const lower_cmat_params *params)
|
||||
{
|
||||
struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
|
||||
if (entry) {
|
||||
return entry->data;
|
||||
} else if (glsl_type_is_cmat(orig_type)) {
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
|
||||
unsigned length = radv_nir_cmat_length(desc, wave_size);
|
||||
unsigned length = radv_nir_cmat_length(desc, params);
|
||||
|
||||
return glsl_vector_type(desc.element_type, length);
|
||||
} else if (glsl_type_is_array(orig_type)) {
|
||||
const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
|
||||
const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);
|
||||
const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, params);
|
||||
|
||||
if (elem_type == new_elem_type)
|
||||
return orig_type;
|
||||
|
|
@ -65,7 +70,7 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
|
|||
bool change = false;
|
||||
for (unsigned i = 0; i < num_fields; ++i) {
|
||||
const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
|
||||
const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);
|
||||
const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, params);
|
||||
|
||||
if (field_type != new_field_type) {
|
||||
change = true;
|
||||
|
|
@ -81,7 +86,7 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
|
|||
for (unsigned i = 0; i < num_fields; ++i) {
|
||||
fields[i] = *glsl_get_struct_field_data(orig_type, i);
|
||||
|
||||
fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
|
||||
fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, params);
|
||||
}
|
||||
|
||||
const struct glsl_type *ret =
|
||||
|
|
@ -102,11 +107,15 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
if (!shader->info.cs.has_cooperative_matrix)
|
||||
return false;
|
||||
|
||||
const lower_cmat_params params = {
|
||||
.wave_size = wave_size,
|
||||
};
|
||||
|
||||
struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
|
||||
struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);
|
||||
|
||||
nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
|
||||
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
|
||||
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, &params);
|
||||
if (new_type != var->type) {
|
||||
var->type = new_type;
|
||||
progress = true;
|
||||
|
|
@ -114,7 +123,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
}
|
||||
|
||||
nir_foreach_function_temp_variable (var, func->impl) {
|
||||
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
|
||||
const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, &params);
|
||||
if (new_type != var->type) {
|
||||
var->type = new_type;
|
||||
progress = true;
|
||||
|
|
@ -134,7 +143,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_cmat_length: {
|
||||
struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
|
||||
unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
|
||||
unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc);
|
||||
nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
|
|
@ -143,7 +152,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
case nir_intrinsic_cmat_extract: {
|
||||
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);
|
||||
nir_def *src0 = radv_nir_load_cmat(&b, &params, intr->src[0].ssa);
|
||||
|
||||
nir_def *index = intr->src[1].ssa;
|
||||
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
|
||||
|
|
@ -156,7 +165,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_insert: {
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
|
||||
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
nir_def *index = intr->src[3].ssa;
|
||||
|
|
@ -174,7 +183,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
nir_def *elem = intr->src[1].ssa;
|
||||
|
||||
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
|
||||
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, &params));
|
||||
|
||||
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
|
||||
nir_instr_remove(instr);
|
||||
|
|
@ -197,9 +206,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
|
||||
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
|
||||
unsigned length = radv_nir_cmat_length(desc, wave_size);
|
||||
unsigned length = radv_nir_cmat_length(desc, &params);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
|
||||
nir_def *vars[16];
|
||||
if (mul > 1) {
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
|
|
@ -250,7 +259,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
|
||||
nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
|
||||
src = radv_nir_load_cmat(&b, wave_size, src);
|
||||
src = radv_nir_load_cmat(&b, &params, src);
|
||||
|
||||
nir_def *local_idx = nir_load_subgroup_invocation(&b);
|
||||
|
||||
|
|
@ -264,9 +273,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
|
||||
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
|
||||
unsigned length = radv_nir_cmat_length(desc, wave_size);
|
||||
unsigned length = radv_nir_cmat_length(desc, &params);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
|
||||
nir_def *vars[16];
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
vars[i] = nir_channel(&b, src, i);
|
||||
|
|
@ -307,9 +316,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_muladd: {
|
||||
nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
|
||||
nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
|
||||
nir_def *A = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
|
||||
nir_def *B = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
|
||||
nir_def *C = radv_nir_load_cmat(&b, &params, intr->src[3].ssa);
|
||||
nir_def *ret;
|
||||
|
||||
ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
|
||||
|
|
@ -326,7 +335,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
||||
if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
|
||||
|
|
@ -356,7 +365,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_scalar_op: {
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
|
||||
|
|
@ -366,8 +375,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_binary_op: {
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
|
||||
nir_def *src2 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
nir_def *ret = nir_build_alu2(&b, op, src1, src2);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
|
||||
|
|
@ -377,7 +386,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_bitcast: {
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
|
||||
nir_component_mask(src1->num_components));
|
||||
nir_instr_remove(instr);
|
||||
|
|
@ -397,7 +406,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
}
|
||||
case nir_instr_type_deref: {
|
||||
nir_deref_instr *deref = nir_instr_as_deref(instr);
|
||||
const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
|
||||
const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, &params);
|
||||
if (new_type != deref->type) {
|
||||
deref->type = new_type;
|
||||
progress = true;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue