mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-18 07:18:06 +02:00
This suppresses the following compile warning: - warning: variable 'idx' is used uninitialized whenever 'if' condition is false [-Wsometimes-uninitialized] Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38835>
956 lines
35 KiB
C
956 lines
35 KiB
C
/*
|
|
* Copyright © 2025 Red Hat Inc.
|
|
* SPDX-License-Identifier: MIT
|
|
*/
|
|
|
|
#include "nir.h"
|
|
#include "nir_deref.h"
|
|
#include "nir_builder.h"
|
|
|
|
/*
|
|
* Lower flexible-size cooperative matrix operations to operations at the supported M/N/K granularity, by splitting each oversized matrix variable into a grid of granularity-sized tiles.
|
|
*/
|
|
struct split_mat {
|
|
unsigned num_row_splits;
|
|
unsigned num_col_splits;
|
|
nir_variable **split_vars;
|
|
};
|
|
|
|
/* State shared by all of the per-instruction split callbacks. */
struct split_info {
   /* Maps an original nir_variable * to its struct split_mat *. */
   struct hash_table *split_mats;
   /* Target M, N and K granularities. */
   unsigned m_gran, n_gran, k_gran;
};
|
|
|
|
static struct split_mat *find_split(struct hash_table *split_mats,
|
|
nir_intrinsic_instr *intr, int idx)
|
|
{
|
|
nir_variable *var = nir_deref_instr_get_variable(nir_src_as_deref(intr->src[idx]));
|
|
struct hash_entry *entry = _mesa_hash_table_search(split_mats, var);
|
|
return entry ? entry->data : NULL;
|
|
}
|
|
|
|
static struct split_mat *find_call_split(struct hash_table *split_mats,
|
|
nir_cmat_call_instr *call, int idx)
|
|
{
|
|
nir_deref_instr *deref = nir_src_as_deref(call->params[idx]);
|
|
if (!deref)
|
|
return NULL;
|
|
nir_variable *var = nir_deref_instr_get_variable(deref);
|
|
struct hash_entry *entry = _mesa_hash_table_search(split_mats, var);
|
|
return entry ? entry->data : NULL;
|
|
}
|
|
|
|
static struct nir_deref_instr *recreate_derefs(nir_builder *b, nir_src *src,
|
|
nir_variable *var)
|
|
{
|
|
nir_deref_instr *deref = nir_src_as_deref(*src);
|
|
nir_deref_path path;
|
|
nir_deref_path_init(&path, deref, NULL);
|
|
|
|
nir_deref_instr *old_head = path.path[0];
|
|
b->cursor = nir_after_instr(&old_head->instr);
|
|
nir_deref_instr *head = nir_build_deref_var(b, var);
|
|
for (int i = 1; path.path[i]; i++) {
|
|
nir_deref_instr *old = path.path[i];
|
|
b->cursor = nir_after_instr(&old->instr);
|
|
head = nir_build_deref_follower(b, head, old);
|
|
}
|
|
|
|
nir_deref_path_finish(&path);
|
|
return head;
|
|
}
|
|
|
|
static void
|
|
get_rowcol_gran(struct glsl_cmat_description desc, unsigned m_gran,
|
|
unsigned n_gran, unsigned k_gran,
|
|
unsigned *row_gran, unsigned *col_gran)
|
|
{
|
|
switch (desc.use) {
|
|
case GLSL_CMAT_USE_A:
|
|
default:
|
|
*row_gran = m_gran;
|
|
*col_gran = k_gran;
|
|
break;
|
|
case GLSL_CMAT_USE_B:
|
|
*row_gran = k_gran;
|
|
*col_gran = n_gran;
|
|
break;
|
|
case GLSL_CMAT_USE_ACCUMULATOR:
|
|
*row_gran = m_gran;
|
|
*col_gran = n_gran;
|
|
break;
|
|
}
|
|
}
|
|
|
|
static void
|
|
get_lower_sizes(struct glsl_cmat_description desc, unsigned m_gran,
|
|
unsigned n_gran, unsigned k_gran,
|
|
unsigned *split_rows_out, unsigned *split_cols_out)
|
|
{
|
|
unsigned split_rows = 0, split_cols = 0;
|
|
|
|
unsigned row_gran, col_gran;
|
|
|
|
get_rowcol_gran(desc, m_gran, n_gran, k_gran, &row_gran, &col_gran);
|
|
|
|
if (desc.rows && desc.rows != row_gran) {
|
|
split_rows = row_gran;
|
|
assert(desc.rows % split_rows == 0);
|
|
}
|
|
if (desc.cols && desc.cols != col_gran) {
|
|
split_cols = col_gran;
|
|
assert(desc.cols % split_cols == 0);
|
|
}
|
|
|
|
*split_rows_out = split_rows;
|
|
*split_cols_out = split_cols;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_construct(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_construct(b, &dst_deref->def, intr->src[1].ssa);
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_copy(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 1);
|
|
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
assert(src_split);
|
|
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_copy(b, &dst_deref->def, &src_deref->def);
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_length(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
|
|
unsigned row_gran, col_gran;
|
|
unsigned split_rows = 0, split_cols = 0;
|
|
unsigned splits = 1;
|
|
|
|
get_rowcol_gran(desc, info->m_gran, info->n_gran, info->k_gran, &row_gran, &col_gran);
|
|
|
|
if (desc.rows == row_gran &&
|
|
desc.cols == col_gran)
|
|
return false;
|
|
|
|
get_lower_sizes(desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols);
|
|
|
|
if (split_rows) {
|
|
splits = desc.rows / split_rows;
|
|
desc.rows = split_rows;
|
|
}
|
|
|
|
if (split_cols) {
|
|
splits *= desc.cols / split_cols;
|
|
desc.cols = split_cols;
|
|
}
|
|
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_def *def = nir_cmat_length(b, .cmat_desc = desc);
|
|
def = nir_imul_imm(b, def, splits);
|
|
nir_def_replace(&intr->def, def);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_insert(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 2);
|
|
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
assert(src_split);
|
|
b->cursor = nir_before_instr(instr);
|
|
|
|
nir_def *len = nir_cmat_length(b, .cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type));
|
|
nir_def *arr_idx = nir_udiv(b, intr->src[3].ssa, len);
|
|
nir_def *base_idx = nir_umod(b, intr->src[3].ssa, len);
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[2], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
|
|
nir_def *new_def = nir_cmat_extract(b, nir_src_bit_size(intr->src[1]), &src_deref->def, base_idx);
|
|
nir_def *cond = nir_ieq_imm(b, arr_idx, i);
|
|
|
|
new_def = nir_bcsel(b, cond, intr->src[1].ssa, new_def);
|
|
nir_cmat_insert(b, &dst_deref->def, new_def, &src_deref->def, base_idx);
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_extract(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 0);
|
|
|
|
if (!src_split)
|
|
return false;
|
|
|
|
unsigned splits = src_split->num_col_splits * src_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
b->cursor = nir_before_instr(instr);
|
|
|
|
nir_def *len = nir_cmat_length(b, .cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type));
|
|
nir_def *arr_idx = nir_udiv(b, intr->src[1].ssa, len);
|
|
nir_def *base_idx = nir_umod(b, intr->src[1].ssa, len);
|
|
nir_def *last_def = nir_undef(b, 1, intr->def.bit_size);
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[0], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_def *cond = nir_ieq_imm(b, arr_idx, i);
|
|
nir_def *new_def = nir_cmat_extract(b, intr->def.bit_size, &src_deref->def, base_idx);
|
|
|
|
last_def = nir_bcsel(b, cond, new_def, last_def);
|
|
}
|
|
nir_def_replace(&intr->def, last_def);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_convert(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 1);
|
|
|
|
if (!dst_split && !src_split)
|
|
return false;
|
|
|
|
assert(dst_split && src_split);
|
|
|
|
unsigned splits = src_split->num_col_splits * src_split->num_row_splits;
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_convert(b, &dst_deref->def, &src_deref->def, .saturate = nir_intrinsic_saturate(intr),
|
|
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
|
|
static bool
|
|
split_cmat_transpose(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 1);
|
|
|
|
if (!dst_split && !src_split)
|
|
return false;
|
|
|
|
assert(dst_split && src_split);
|
|
|
|
for (unsigned r = 0; r < src_split->num_row_splits; r++) {
|
|
for (unsigned c = 0; c < src_split->num_col_splits; c++) {
|
|
int in_idx = r * src_split->num_col_splits + c;
|
|
int out_idx = c * dst_split->num_col_splits + r;
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[in_idx]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_transpose(b, &dst_deref->def, &src_deref->def);
|
|
}
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_bitcast(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 1);
|
|
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
assert(src_split);
|
|
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_bitcast(b, &dst_deref->def, &src_deref->def);
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_binary_op(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src0_split = find_split(info->split_mats, intr, 1);
|
|
struct split_mat *src1_split = find_split(info->split_mats, intr, 2);
|
|
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
assert(src0_split);
|
|
assert(src1_split);
|
|
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src0_deref = recreate_derefs(b, &intr->src[1], src0_split->split_vars[i]);
|
|
nir_deref_instr *src1_deref = recreate_derefs(b, &intr->src[2], src1_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_binary_op(b, &dst_deref->def, &src0_deref->def, &src1_deref->def,
|
|
.alu_op = nir_intrinsic_alu_op(intr));
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_unary_op(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 1);
|
|
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
assert(src_split);
|
|
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_unary_op(b, &dst_deref->def, &src_deref->def, .alu_op = nir_intrinsic_alu_op(intr));
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_scalar_op(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *dst_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *src_split = find_split(info->split_mats, intr, 1);
|
|
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
assert(src_split);
|
|
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_scalar_op(b, &dst_deref->def, &src_deref->def, intr->src[2].ssa,
|
|
.alu_op = nir_intrinsic_alu_op(intr));
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_muladd(nir_builder *b,
|
|
nir_function_impl *impl,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
struct split_mat *result_split = find_split(info->split_mats, intr, 0);
|
|
struct split_mat *a_split = find_split(info->split_mats, intr, 1);
|
|
struct split_mat *b_split = find_split(info->split_mats, intr, 2);
|
|
struct split_mat *c_split = find_split(info->split_mats, intr, 3);
|
|
|
|
unsigned m_splits = 1;
|
|
unsigned n_splits = 1;
|
|
unsigned k_splits = 1;
|
|
|
|
if (!result_split && !a_split && !b_split && !c_split)
|
|
return false;
|
|
|
|
if (result_split) {
|
|
assert(c_split);
|
|
m_splits = result_split->num_row_splits;
|
|
n_splits = result_split->num_col_splits;
|
|
|
|
assert(c_split->num_row_splits == m_splits);
|
|
assert(c_split->num_col_splits == n_splits);
|
|
if (a_split)
|
|
assert(a_split->num_row_splits == m_splits);
|
|
|
|
if (b_split)
|
|
assert(b_split->num_col_splits == n_splits);
|
|
}
|
|
|
|
if (a_split && a_split->num_col_splits > 1) {
|
|
assert(b_split);
|
|
assert(b_split->num_row_splits == a_split->num_col_splits);
|
|
k_splits = a_split->num_col_splits;
|
|
}
|
|
|
|
for (unsigned m = 0; m < m_splits; m++) {
|
|
for (unsigned n = 0; n < n_splits; n++) {
|
|
unsigned idx = m * n_splits + n;
|
|
nir_deref_instr *dst_deref = result_split ? recreate_derefs(b, &intr->src[0], result_split->split_vars[idx]) : nir_src_as_deref(intr->src[0]);
|
|
nir_deref_instr *c_deref = c_split ? recreate_derefs(b, &intr->src[3], c_split->split_vars[idx]) : nir_src_as_deref(intr->src[3]);
|
|
|
|
for (unsigned k = 0; k < k_splits; k++) {
|
|
unsigned a_idx = m * k_splits + k;
|
|
unsigned b_idx = k * n_splits + n;
|
|
nir_deref_instr *a_deref = a_split ? recreate_derefs(b, &intr->src[1], a_split->split_vars[a_idx]) : nir_src_as_deref(intr->src[1]);
|
|
nir_deref_instr *b_deref = b_split ? recreate_derefs(b, &intr->src[2], b_split->split_vars[b_idx]) : nir_src_as_deref(intr->src[2]);
|
|
nir_deref_instr *k_dst_deref = k == k_splits - 1 ? dst_deref : c_deref;
|
|
b->cursor = nir_before_instr(instr);
|
|
|
|
nir_cmat_muladd(b, &k_dst_deref->def, &a_deref->def, &b_deref->def, &c_deref->def,
|
|
.saturate = nir_intrinsic_saturate(intr),
|
|
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
|
|
}
|
|
}
|
|
}
|
|
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static void
|
|
call_reduce(nir_builder *b,
|
|
nir_cmat_call_instr *call,
|
|
nir_cmat_reduce reduce,
|
|
nir_def *dst, nir_def *src0)
|
|
{
|
|
nir_cmat_call_instr *ncall = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_reduce, call->callee);
|
|
ncall->params[0] = nir_src_for_ssa(dst);
|
|
ncall->params[1] = nir_src_for_ssa(src0);
|
|
ncall->const_index[0] = reduce;
|
|
nir_builder_instr_insert(b, &ncall->instr);
|
|
}
|
|
|
|
static void
|
|
call_reduce_finish(nir_builder *b,
|
|
nir_cmat_call_instr *call,
|
|
nir_cmat_reduce reduce,
|
|
nir_def *dst, nir_def *src0, nir_def *src1)
|
|
{
|
|
nir_cmat_call_instr *ncall = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_reduce_finish, call->callee);
|
|
ncall->params[0] = nir_src_for_ssa(dst);
|
|
ncall->params[1] = nir_src_for_ssa(src0);
|
|
ncall->params[2] = nir_src_for_ssa(src1);
|
|
ncall->const_index[0] = reduce;
|
|
nir_builder_instr_insert(b, &ncall->instr);
|
|
}
|
|
|
|
static void
|
|
call_reduce_2x2(nir_builder *b,
|
|
nir_cmat_call_instr *call,
|
|
nir_def *dst,
|
|
nir_def *src0, nir_def *src1,
|
|
nir_def *src2, nir_def *src3)
|
|
{
|
|
nir_cmat_call_instr *ncall = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_reduce_2x2, call->callee);
|
|
ncall->params[0] = nir_src_for_ssa(dst);
|
|
ncall->params[1] = nir_src_for_ssa(src0);
|
|
ncall->params[2] = nir_src_for_ssa(src1);
|
|
ncall->params[3] = nir_src_for_ssa(src2);
|
|
ncall->params[4] = nir_src_for_ssa(src3);
|
|
nir_builder_instr_insert(b, &ncall->instr);
|
|
}
|
|
|
|
static bool
|
|
split_cmat_call_reduce(nir_builder *b,
|
|
nir_function_impl *impl,
|
|
nir_cmat_call_instr *call,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &call->instr;
|
|
nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
|
|
struct split_mat *dst_split = find_call_split(info->split_mats, call, 0);
|
|
struct split_mat *src_split = find_call_split(info->split_mats, call, 1);
|
|
|
|
if (reduce & (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) {
|
|
assert(!(reduce & ~(NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)));
|
|
|
|
/* for each source split - reduce it by itself. */
|
|
int src_splits = 1;
|
|
if (src_split)
|
|
src_splits = src_split->num_col_splits * src_split->num_row_splits;
|
|
nir_deref_instr **temp_derefs = ralloc_array(NULL, nir_deref_instr *, src_splits);
|
|
|
|
const struct glsl_type *temp_type = nir_deref_instr_get_variable(nir_src_as_deref(call->params[1]))->type;
|
|
if (src_splits > 1)
|
|
temp_type = src_split->split_vars[0]->type;
|
|
for (unsigned i = 0; i < src_splits; i++) {
|
|
nir_variable *temp_var = nir_local_variable_create(impl, temp_type,
|
|
"reduce_split_srcs");
|
|
temp_derefs[i] = nir_build_deref_var(b, temp_var);
|
|
}
|
|
|
|
if (src_splits > 1) {
|
|
/* reduce each individual src matrix */
|
|
for (unsigned i = 0; i < src_splits; i++) {
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[i]);
|
|
b->cursor = nir_before_instr(instr);
|
|
call_reduce(b, call, reduce, &temp_derefs[i]->def, &src_deref->def);
|
|
}
|
|
|
|
if ((reduce & (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) == (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) {
|
|
for (unsigned i = 1; i < src_splits; i++) {
|
|
nir_deref_instr *second_deref = temp_derefs[i];
|
|
b->cursor = nir_before_instr(instr);
|
|
call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
|
|
}
|
|
} else if (reduce & NIR_CMAT_REDUCE_ROW) {
|
|
for (unsigned i = 1; i < src_split->num_col_splits; i++) {
|
|
nir_deref_instr *second_deref = temp_derefs[i];
|
|
b->cursor = nir_before_instr(instr);
|
|
call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
|
|
}
|
|
} else if (reduce & NIR_CMAT_REDUCE_COLUMN) {
|
|
for (unsigned i = 1; i < src_split->num_row_splits; i++) {
|
|
nir_deref_instr *second_deref = temp_derefs[i * src_split->num_col_splits];
|
|
b->cursor = nir_before_instr(instr);
|
|
call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
|
|
}
|
|
}
|
|
} else {
|
|
call_reduce(b, call, reduce, &temp_derefs[0]->def, &nir_src_as_deref(call->params[1])->def);
|
|
}
|
|
|
|
/* at this point temp_derefs should contain all the split reduced src matrices
|
|
now to store them */
|
|
if (dst_split) {
|
|
for (unsigned r = 0; r < dst_split->num_row_splits; r++) {
|
|
for (unsigned c = 0; c < dst_split->num_col_splits; c++) {
|
|
int didx = r * dst_split->num_col_splits + c;
|
|
int idx;
|
|
if ((reduce & (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) == (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN))
|
|
idx = 0;
|
|
else if (reduce & NIR_CMAT_REDUCE_ROW)
|
|
idx = r % (src_split ? src_split->num_row_splits : 1);
|
|
else if (reduce & NIR_CMAT_REDUCE_COLUMN)
|
|
idx = c % (src_split ? src_split->num_col_splits : 1);
|
|
else
|
|
UNREACHABLE("Unknown NIR_CMAT_REDUCE_*");
|
|
|
|
nir_deref_instr *deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[didx]);
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_cmat_copy(b, &deref->def, &temp_derefs[idx]->def);
|
|
}
|
|
}
|
|
} else {
|
|
nir_cmat_copy(b, call->params[0].ssa, &temp_derefs[0]->def);
|
|
}
|
|
|
|
ralloc_free(temp_derefs);
|
|
} else if (reduce & NIR_CMAT_REDUCE_2X2) {
|
|
assert(reduce == NIR_CMAT_REDUCE_2X2);
|
|
|
|
/* dst can have target dimensions, but src but be at least twice as large */
|
|
assert (src_split);
|
|
|
|
int rows = 1, cols = 1;
|
|
if (dst_split) {
|
|
rows = dst_split->num_row_splits;
|
|
cols = dst_split->num_col_splits;
|
|
}
|
|
|
|
for (unsigned r = 0; r < rows; r++) {
|
|
for (unsigned c = 0; c < cols; c++) {
|
|
int d_idx = c + r * cols;
|
|
int src_top_left_col = c * 2;
|
|
int src_top_left_row = r * 2;
|
|
int src_top_idx = src_top_left_col + src_top_left_row * src_split->num_col_splits;
|
|
int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->num_col_splits;
|
|
nir_deref_instr *src0_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx]);
|
|
nir_deref_instr *src1_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx + 1]);
|
|
nir_deref_instr *src2_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx]);
|
|
nir_deref_instr *src3_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx + 1]);
|
|
nir_deref_instr *dst_deref = dst_split ? recreate_derefs(b, &call->params[0], dst_split->split_vars[d_idx]) : nir_src_as_deref(call->params[0]);
|
|
b->cursor = nir_before_instr(instr);
|
|
call_reduce_2x2(b, call, &dst_deref->def, &src0_deref->def, &src1_deref->def, &src2_deref->def, &src3_deref->def);
|
|
}
|
|
}
|
|
}
|
|
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_load_store(nir_builder *b,
|
|
nir_intrinsic_instr *intr,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &intr->instr;
|
|
const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load;
|
|
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
|
|
nir_variable *var = nir_deref_instr_get_variable(nir_src_as_deref(intr->src[!is_load]));
|
|
struct hash_entry *entry = _mesa_hash_table_search(info->split_mats, var);
|
|
if (!entry)
|
|
return false;
|
|
|
|
struct split_mat *split = entry->data;
|
|
unsigned splits = split->num_row_splits * split->num_col_splits;
|
|
for (unsigned i = 0; i < splits; i++) {
|
|
nir_deref_instr *new_deref = recreate_derefs(b, &intr->src[!is_load], split->split_vars[i]);
|
|
nir_deref_instr *ptr_deref;
|
|
nir_def *stride = intr->src[2].ssa;
|
|
nir_def *ptr = intr->src[is_load].ssa;
|
|
|
|
b->cursor = nir_before_instr(instr);
|
|
if (i > 0) {
|
|
nir_deref_instr *addr_deref = nir_src_as_deref(intr->src[is_load]);
|
|
unsigned dst_bit_size = addr_deref->def.bit_size;
|
|
nir_def *this_index = nir_imm_zero(b, 1, dst_bit_size);
|
|
unsigned deref_bytes_size = glsl_get_explicit_size(addr_deref->type, false);
|
|
const struct glsl_type *scalar_type = glsl_get_scalar_type(glsl_get_cmat_element(var->type));
|
|
unsigned elem_size = glsl_get_explicit_size(scalar_type, false);
|
|
struct glsl_cmat_description desc = *glsl_get_cmat_description(split->split_vars[i]->type);
|
|
unsigned row_offset, col_offset;
|
|
|
|
row_offset = (i % split->num_col_splits) * desc.cols;
|
|
col_offset = (i / split->num_col_splits) * desc.rows;
|
|
|
|
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR)
|
|
SWAP(row_offset, col_offset);
|
|
|
|
ptr_deref = nir_build_deref_cast(b, &addr_deref->def, addr_deref->modes, scalar_type, elem_size);
|
|
stride = nir_udiv_imm(b, nir_imul_imm(b, stride, deref_bytes_size), elem_size);
|
|
|
|
if (col_offset)
|
|
this_index = nir_imm_intN_t(b, col_offset, dst_bit_size);
|
|
if (row_offset)
|
|
this_index = nir_iadd(b, this_index, nir_u2uN(b, nir_imul_imm(b, stride, row_offset), dst_bit_size));
|
|
ptr_deref = nir_build_deref_ptr_as_array(b, ptr_deref, this_index);
|
|
ptr = &ptr_deref->def;
|
|
}
|
|
if (is_load)
|
|
nir_cmat_load(b, &new_deref->def, ptr, stride,
|
|
.matrix_layout = layout);
|
|
else
|
|
nir_cmat_store(b, ptr, &new_deref->def, stride,
|
|
.matrix_layout = layout);
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_cmat_call_per_element_op(nir_builder *b,
|
|
nir_cmat_call_instr *call,
|
|
struct split_info *info)
|
|
{
|
|
nir_instr *instr = &call->instr;
|
|
struct split_mat *dst_split = find_call_split(info->split_mats, call, 0);
|
|
struct split_mat *src_split = find_call_split(info->split_mats, call, 3);
|
|
if (!dst_split)
|
|
return false;
|
|
|
|
assert(src_split);
|
|
int splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
|
if (splits <= 1)
|
|
return false;
|
|
|
|
for (unsigned r = 0; r < dst_split->num_row_splits; r++) {
|
|
for (unsigned c = 0; c < dst_split->num_col_splits; c++) {
|
|
int idx = r * dst_split->num_col_splits + c;
|
|
nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[idx]);
|
|
nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[idx]);
|
|
struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type);
|
|
nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee);
|
|
new_call->params[0] = nir_src_for_ssa(&dst_deref->def);
|
|
new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r));
|
|
new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c));
|
|
new_call->params[3] = nir_src_for_ssa(&src_deref->def);
|
|
|
|
for (unsigned i = 4; i < call->num_params; i++) {
|
|
if (nir_src_as_deref(call->params[i])) {
|
|
struct split_mat *src1_split = find_call_split(info->split_mats, call, i);
|
|
nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[i], src1_split->split_vars[idx]) : nir_src_as_deref(call->params[i]);
|
|
new_call->params[i] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[i];
|
|
} else
|
|
new_call->params[i] = call->params[i];
|
|
}
|
|
b->cursor = nir_before_instr(instr);
|
|
nir_builder_instr_insert(b, &new_call->instr);
|
|
}
|
|
}
|
|
nir_instr_remove(instr);
|
|
return true;
|
|
}
|
|
|
|
static bool
|
|
split_matrix_impl(nir_function_impl *impl, struct split_info *info)
|
|
{
|
|
bool progress = false;
|
|
nir_builder b = nir_builder_create(impl);
|
|
nir_foreach_block_reverse (block, impl) {
|
|
nir_foreach_instr_reverse_safe (instr, block) {
|
|
b.cursor = nir_before_instr(instr);
|
|
switch (instr->type) {
|
|
case nir_instr_type_intrinsic: {
|
|
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
|
|
switch (intr->intrinsic) {
|
|
case nir_intrinsic_cmat_construct:
|
|
progress |= split_cmat_construct(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_copy:
|
|
progress |= split_cmat_copy(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_length:
|
|
progress |= split_cmat_length(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_insert:
|
|
progress |= split_cmat_insert(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_extract:
|
|
progress |= split_cmat_extract(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_convert:
|
|
progress |= split_cmat_convert(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_transpose:
|
|
progress |= split_cmat_transpose(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_bitcast:
|
|
progress |= split_cmat_bitcast(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_binary_op:
|
|
progress |= split_cmat_binary_op(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_unary_op:
|
|
progress |= split_cmat_unary_op(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_scalar_op:
|
|
progress |= split_cmat_scalar_op(&b, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_muladd:
|
|
progress |= split_cmat_muladd(&b, impl, intr, info);
|
|
break;
|
|
case nir_intrinsic_cmat_load:
|
|
case nir_intrinsic_cmat_store:
|
|
progress |= split_cmat_load_store(&b, intr, info);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
case nir_instr_type_cmat_call: {
|
|
nir_cmat_call_instr *cmat_call = nir_instr_as_cmat_call(instr);
|
|
switch (cmat_call->op) {
|
|
case nir_cmat_call_op_reduce:
|
|
progress |= split_cmat_call_reduce(&b, impl, cmat_call, info);
|
|
break;
|
|
case nir_cmat_call_op_per_element_op:
|
|
progress |= split_cmat_call_per_element_op(&b, cmat_call, info);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
return progress;
|
|
}
|
|
|
|
static struct split_mat *
|
|
split_var(nir_shader *shader,
|
|
nir_function_impl *impl,
|
|
void *mem_ctx,
|
|
nir_variable *var,
|
|
unsigned m_gran,
|
|
unsigned n_gran,
|
|
unsigned k_gran)
|
|
{
|
|
if (!glsl_type_is_cmat(glsl_without_array(var->type)))
|
|
return NULL;
|
|
|
|
const struct glsl_type *type = var->type;
|
|
if (glsl_type_is_array(type)) {
|
|
type = glsl_without_array(var->type);
|
|
}
|
|
|
|
struct glsl_cmat_description desc = *glsl_get_cmat_description(type);
|
|
unsigned split_rows = 0, split_cols = 0;
|
|
|
|
get_lower_sizes(desc, m_gran, n_gran, k_gran, &split_rows, &split_cols);
|
|
|
|
unsigned num_row_split = 1, num_col_split = 1;
|
|
|
|
if (split_rows) {
|
|
num_row_split = desc.rows / split_rows;
|
|
desc.rows = split_rows;
|
|
}
|
|
if (split_cols) {
|
|
num_col_split = desc.cols / split_cols;
|
|
desc.cols = split_cols;
|
|
}
|
|
|
|
if (num_row_split == 1 && num_col_split == 1)
|
|
return NULL;
|
|
|
|
const struct glsl_type *new_type = glsl_type_wrap_in_arrays(glsl_cmat_type(&desc), var->type);
|
|
|
|
struct split_mat *split_mat = ralloc(mem_ctx, struct split_mat);
|
|
if (!split_mat)
|
|
return NULL;
|
|
|
|
unsigned num_split = num_row_split * num_col_split;
|
|
split_mat->num_row_splits = num_row_split;
|
|
split_mat->num_col_splits = num_col_split;
|
|
split_mat->split_vars = ralloc_array(split_mat, struct nir_variable *, num_split);
|
|
for (unsigned i = 0; i < num_split; i++) {
|
|
if (!nir_variable_is_global(var)) {
|
|
split_mat->split_vars[i] = nir_local_variable_create(impl,
|
|
new_type, var->name);
|
|
} else {
|
|
split_mat->split_vars[i] = nir_variable_create(shader, var->data.mode,
|
|
new_type, var->name);
|
|
}
|
|
}
|
|
return split_mat;
|
|
}
|
|
|
|
static bool
|
|
lower_dimensions(nir_shader *shader, nir_function_impl *impl,
|
|
unsigned m_gran, unsigned n_gran, unsigned k_gran)
|
|
{
|
|
struct hash_table *split_mats = _mesa_pointer_hash_table_create(NULL);
|
|
void *mem_ctx = ralloc_context(NULL);
|
|
bool progress = false;
|
|
|
|
nir_foreach_variable_in_shader(var, shader) {
|
|
struct split_mat *split_mat = split_var(shader, NULL, mem_ctx, var, m_gran, n_gran, k_gran);
|
|
if (split_mat)
|
|
_mesa_hash_table_insert(split_mats, var, split_mat);
|
|
}
|
|
nir_foreach_function_temp_variable (var, impl) {
|
|
struct split_mat *split_mat = split_var(shader, impl, mem_ctx, var, m_gran, n_gran, k_gran);
|
|
if (split_mat)
|
|
_mesa_hash_table_insert(split_mats, var, split_mat);
|
|
}
|
|
|
|
struct split_info split_info = {
|
|
.split_mats = split_mats,
|
|
.m_gran = m_gran,
|
|
.n_gran = n_gran,
|
|
.k_gran = k_gran,
|
|
};
|
|
progress = split_matrix_impl(impl, &split_info);
|
|
_mesa_hash_table_destroy(split_mats, NULL);
|
|
ralloc_free(mem_ctx);
|
|
return progress;
|
|
}
|
|
|
|
bool
|
|
nir_lower_cooperative_matrix_flexible_dimensions(nir_shader *shader, unsigned m_gran, unsigned n_gran, unsigned k_gran)
|
|
{
|
|
bool progress = false;
|
|
|
|
if (!shader->info.cs.has_cooperative_matrix)
|
|
return false;
|
|
|
|
struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
|
|
|
|
progress |= lower_dimensions(shader, func->impl, m_gran, n_gran, k_gran);
|
|
|
|
nir_foreach_function_impl(fnim, shader)
|
|
nir_progress(progress, fnim, 0);
|
|
return progress;
|
|
}
|