nir: add a cmat call instruction type.

This adds a new instruction type to handle cooperative matrix calls.

This clones the call instr, drops callee, and adds a single metadata
slot and a call operation (dummy only for now).

(Not NACKed by Alyssa)

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38389>
This commit is contained in:
Dave Airlie 2025-11-12 12:17:10 +10:00 committed by Marge Bot
parent 6ba0797a06
commit 26eaba935d
17 changed files with 168 additions and 1 deletions

View file

@ -1348,6 +1348,7 @@ v3d_instr_delay_cb(nir_instr *instr, void *data)
case nir_instr_type_deref:
case nir_instr_type_jump:
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_phi:
return 1;

View file

@ -905,6 +905,32 @@ nir_call_instr_create(nir_shader *shader, nir_function *callee)
return instr;
}
int
nir_cmat_call_op_params(nir_cmat_call_op op, nir_function *callee)
{
switch (op) {
default:
return callee->num_params;
}
}
nir_cmat_call_instr *
nir_cmat_call_instr_create(nir_shader *shader, nir_cmat_call_op op, nir_function *callee)
{
const unsigned num_params = nir_cmat_call_op_params(op, callee);
nir_cmat_call_instr *instr =
nir_instr_create(shader, nir_instr_type_cmat_call,
sizeof(nir_cmat_call_instr) + sizeof(nir_src) * num_params);
instr->callee = callee;
instr->op = op;
instr->num_params = num_params;
for (unsigned i = 0; i < num_params; i++)
src_init(&instr->params[i]);
return instr;
}
static int8_t default_tg4_offsets[4][2] = {
{ 0, 1 },
{ 1, 1 },
@ -1481,6 +1507,7 @@ nir_instr_def(nir_instr *instr)
case nir_instr_type_call:
case nir_instr_type_jump:
case nir_instr_type_cmat_call:
return NULL;
}
@ -2879,6 +2906,7 @@ nir_instr_can_speculate(nir_instr *instr)
case nir_instr_type_call:
case nir_instr_type_jump:
case nir_instr_type_phi:
case nir_instr_type_cmat_call:
return false;
}

View file

@ -947,6 +947,7 @@ typedef enum ENUM_PACKED {
nir_instr_type_jump,
nir_instr_type_undef,
nir_instr_type_phi,
nir_instr_type_cmat_call,
} nir_instr_type;
typedef struct nir_instr {
@ -1833,6 +1834,24 @@ typedef struct nir_call_instr {
nir_src params[];
} nir_call_instr;
#define NIR_CMAT_CALL_MAX_CONST_INDEX 1
typedef enum {
nir_cmat_call_op_none,
} nir_cmat_call_op;
typedef struct nir_cmat_call_instr {
nir_instr instr;
nir_cmat_call_op op;
nir_function *callee;
int const_index[NIR_CMAT_CALL_MAX_CONST_INDEX];
unsigned num_params;
nir_src params[];
} nir_cmat_call_instr;
#include "nir_intrinsics.h"
#define NIR_INTRINSIC_MAX_CONST_INDEX 8
@ -2872,6 +2891,8 @@ NIR_DEFINE_CAST(nir_instr_as_undef, nir_instr, nir_undef_instr, instr,
type, nir_instr_type_undef)
NIR_DEFINE_CAST(nir_instr_as_phi, nir_instr, nir_phi_instr, instr,
type, nir_instr_type_phi)
NIR_DEFINE_CAST(nir_instr_as_cmat_call, nir_instr, nir_cmat_call_instr, instr,
type, nir_instr_type_cmat_call)
#define NIR_DEFINE_DEF_AS_INSTR(instr_type, suffix, cast) \
static inline instr_type *nir_def_as_##cast(const nir_def *def) \
@ -4129,6 +4150,11 @@ nir_intrinsic_instr *nir_intrinsic_instr_create(nir_shader *shader,
nir_call_instr *nir_call_instr_create(nir_shader *shader,
nir_function *callee);
int nir_cmat_call_op_params(nir_cmat_call_op op, nir_function *callee);
nir_cmat_call_instr *nir_cmat_call_instr_create(nir_shader *shader,
nir_cmat_call_op op,
nir_function *callee);
/** Creates a NIR texture instruction */
nir_tex_instr *nir_tex_instr_create(nir_shader *shader, unsigned num_srcs);

View file

@ -489,6 +489,19 @@ clone_call(clone_state *state, const nir_call_instr *call)
return ncall;
}
static nir_cmat_call_instr *
clone_cmat_call(clone_state *state, const nir_cmat_call_instr *call)
{
nir_function *ncallee = remap_global(state, call->callee);
nir_cmat_call_instr *ncall = nir_cmat_call_instr_create(state->ns, call->op, ncallee);
clone_debug_info(state, &ncall->instr, &call->instr);
for (unsigned i = 0; i < ncall->num_params; i++)
__clone_src(state, ncall, &ncall->params[i], &call->params[i]);
memcpy(ncall->const_index, call->const_index, sizeof(ncall->const_index));
return ncall;
}
static nir_instr *
clone_instr(clone_state *state, const nir_instr *instr)
{
@ -511,6 +524,8 @@ clone_instr(clone_state *state, const nir_instr *instr)
return &clone_jump(state, nir_instr_as_jump(instr))->instr;
case nir_instr_type_call:
return &clone_call(state, nir_instr_as_call(instr))->instr;
case nir_instr_type_cmat_call:
return &clone_cmat_call(state, nir_instr_as_cmat_call(instr))->instr;
default:
UNREACHABLE("bad instr type");
return NULL;

View file

@ -1218,6 +1218,7 @@ instr_is_loop_invariant(nir_instr *instr, struct divergence_state *state)
case nir_instr_type_tex:
return nir_foreach_src(instr, src_invariant, state->loop);
case nir_instr_type_call:
case nir_instr_type_cmat_call:
return false;
case nir_instr_type_phi:
default:

View file

@ -25,6 +25,7 @@ _nir_foreach_def(nir_instr *instr, nir_foreach_def_cb cb, void *state)
return cb(&nir_instr_as_undef(instr)->def, state);
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_jump:
return true;
@ -115,7 +116,14 @@ nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state)
return false;
return true;
}
case nir_instr_type_cmat_call: {
nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
for (unsigned i = 0; i < call->num_params; i++) {
if (!_nir_visit_src(&call->params[i], cb, state))
return false;
}
break;
}
case nir_instr_type_load_const:
case nir_instr_type_undef:
return true;

View file

@ -75,6 +75,7 @@ instr_can_rewrite(const nir_instr *instr)
}
}
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_jump:
case nir_instr_type_undef:
return false;

View file

@ -68,6 +68,7 @@ is_live(BITSET_WORD *defs_live, nir_instr *instr)
{
switch (instr->type) {
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_jump:
return true;
case nir_instr_type_alu: {

View file

@ -241,6 +241,7 @@ opt_move_discards_to_top_impl(nir_function_impl *impl)
continue;
case nir_instr_type_call:
case nir_instr_type_cmat_call:
instr->pass_flags = STOP_PROCESSING_INSTR_FLAG;
/* We don't know what the function will do */
goto break_all;

View file

@ -92,6 +92,7 @@ block_check_for_allowed_instrs(nir_block *block, unsigned *count,
}
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_jump:
return false;
}

View file

@ -2051,6 +2051,26 @@ print_call_instr(nir_call_instr *instr, print_state *state)
}
}
static void
print_cmat_call_instr(nir_cmat_call_instr *instr, print_state *state)
{
FILE *fp = state->fp;
print_no_dest_padding(state);
fprintf(fp, "cmat_call %s ", instr->callee->name);
for (unsigned i = 0; i < instr->num_params; i++) {
if (i != 0)
fprintf(fp, ", ");
if (instr->callee->params[i].name)
fprintf(fp, "%s ", instr->callee->params[i].name);
print_src(&instr->params[i], state, nir_type_invalid);
}
}
static void
print_jump_instr(nir_jump_instr *instr, print_state *state)
{
@ -2168,6 +2188,10 @@ print_instr(const nir_instr *instr, print_state *state, unsigned tabs)
print_call_instr(nir_instr_as_call(instr), state);
break;
case nir_instr_type_cmat_call:
print_cmat_call_instr(nir_instr_as_cmat_call(instr), state);
break;
case nir_instr_type_intrinsic:
print_intrinsic_instr(nir_instr_as_intrinsic(instr), state);
break;
@ -2226,6 +2250,7 @@ block_has_instruction_with_dest(nir_block *block)
case nir_instr_type_jump:
case nir_instr_type_call:
case nir_instr_type_cmat_call:
/* Doesn't define a new value. */
break;
}

View file

@ -477,6 +477,7 @@ nir_schedule_calculate_deps(nir_deps_state *state, nir_schedule_node *n)
break;
case nir_instr_type_call:
case nir_instr_type_cmat_call:
UNREACHABLE("Calls should have been lowered");
break;
@ -1090,6 +1091,7 @@ nir_schedule_get_delay(nir_schedule_scoreboard *scoreboard, nir_instr *instr)
case nir_instr_type_deref:
case nir_instr_type_jump:
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_phi:
return 1;

View file

@ -1612,6 +1612,36 @@ read_call(read_ctx *ctx)
return call;
}
static void
write_cmat_call(write_ctx *ctx, const nir_cmat_call_instr *call)
{
blob_write_uint32(ctx->blob, write_lookup_object(ctx, call->callee));
blob_write_uint32(ctx->blob, call->op);
for (unsigned i = 0; i < call->num_params; i++)
write_src(ctx, &call->params[i]);
for (unsigned i = 0; i < NIR_CMAT_CALL_MAX_CONST_INDEX; i++)
blob_write_uint32(ctx->blob, call->const_index[i]);
}
static nir_cmat_call_instr *
read_cmat_call(read_ctx *ctx)
{
nir_function *callee = read_object(ctx);
nir_cmat_call_op op = blob_read_uint32(ctx->blob);
nir_cmat_call_instr *call = nir_cmat_call_instr_create(ctx->nir, op, callee);
for (unsigned i = 0; i < call->num_params; i++)
read_src(ctx, &call->params[i]);
for (unsigned i = 0; i < NIR_CMAT_CALL_MAX_CONST_INDEX; i++)
call->const_index[i] = blob_read_uint32(ctx->blob);
return call;
}
enum nir_serialize_debug_info_flags {
NIR_SERIALIZE_FILENAME = 1 << 0,
NIR_SERIALIZE_VARIABLE_NAME = 1 << 1,
@ -1717,6 +1747,10 @@ write_instr(write_ctx *ctx, const nir_instr *instr)
blob_write_uint32(ctx->blob, instr->type);
write_call(ctx, nir_instr_as_call(instr));
break;
case nir_instr_type_cmat_call:
blob_write_uint32(ctx->blob, instr->type);
write_cmat_call(ctx, nir_instr_as_cmat_call(instr));
break;
default:
UNREACHABLE("bad instr type");
}
@ -1769,6 +1803,9 @@ read_instr(read_ctx *ctx, nir_block *block)
case nir_instr_type_call:
instr = &read_call(ctx)->instr;
break;
case nir_instr_type_cmat_call:
instr = &read_cmat_call(ctx)->instr;
break;
default:
UNREACHABLE("bad instr type");
}

View file

@ -1080,6 +1080,16 @@ validate_call_instr(nir_call_instr *instr, validate_state *state)
}
}
static void
validate_cmat_call_instr(nir_cmat_call_instr *instr, validate_state *state)
{
validate_assert(state, instr->num_params == nir_cmat_call_op_params(instr->op, instr->callee));
for (unsigned i = 0; i < instr->num_params; i++) {
validate_src(&instr->params[i], state);
}
}
static void
validate_const_value(nir_const_value *val, unsigned bit_size,
bool is_null_constant, validate_state *state)
@ -1228,6 +1238,10 @@ validate_instr(nir_instr *instr, validate_state *state)
validate_call_instr(nir_instr_as_call(instr), state);
break;
case nir_instr_type_cmat_call:
validate_cmat_call_instr(nir_instr_as_cmat_call(instr), state);
break;
case nir_instr_type_intrinsic:
validate_intrinsic_instr(nir_instr_as_intrinsic(instr), state);
break;

View file

@ -4299,6 +4299,7 @@ emit_instr(struct ir3_context *ctx, nir_instr *instr)
emit_phi(ctx, nir_instr_as_phi(instr));
break;
case nir_instr_type_call:
case nir_instr_type_cmat_call:
ir3_context_error(ctx, "Unhandled NIR instruction type: %d\n",
instr->type);
break;
@ -4470,6 +4471,7 @@ instr_can_be_predicated(nir_instr *instr)
case nir_instr_type_phi:
return true;
case nir_instr_type_call:
case nir_instr_type_cmat_call:
case nir_instr_type_jump:
return false;
case nir_instr_type_intrinsic: {

View file

@ -4522,6 +4522,9 @@ emit_block(struct ntv_context *ctx, struct nir_block *block)
case nir_instr_type_call:
UNREACHABLE("nir_instr_type_call not supported");
break;
case nir_instr_type_cmat_call:
UNREACHABLE("nir_instr_type_cmat_call not supported");
break;
case nir_instr_type_deref:
emit_deref(ctx, nir_instr_as_deref(instr));
break;

View file

@ -1729,6 +1729,7 @@ instr_to_msl(struct nir_to_msl_ctx *ctx, nir_instr *instr)
assert(!"We should have lowered derefs by now");
break;
case nir_instr_type_call:
case nir_instr_type_cmat_call:
assert(!"We should have inlined all functions by now");
break;
case nir_instr_type_tex: