mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-16 06:10:45 +01:00
spirv: add initial support for cooperative matrix per-element ops
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36992>
This commit is contained in:
parent
a57753dca2
commit
1ba49c3594
2 changed files with 37 additions and 0 deletions
|
|
@ -74,6 +74,7 @@ static const struct spirv_capabilities implemented_capabilities = {
|
|||
.CooperativeMatrixKHR = true,
|
||||
.CooperativeMatrixConversionsNV = true,
|
||||
.CooperativeMatrixReductionsNV = true,
|
||||
.CooperativeMatrixPerElementOperationsNV = true,
|
||||
.CoreBuiltinsARM = true,
|
||||
.CullDistance = true,
|
||||
.DemoteToHelperInvocation = true,
|
||||
|
|
@ -7011,6 +7012,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpCooperativeMatrixConvertNV:
|
||||
case SpvOpCooperativeMatrixTransposeNV:
|
||||
case SpvOpCooperativeMatrixReduceNV:
|
||||
case SpvOpCooperativeMatrixPerElementOpNV:
|
||||
vtn_handle_cooperative_instruction(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
|
|
|
|||
|
|
@ -238,6 +238,41 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||
break;
|
||||
}
|
||||
|
||||
case SpvOpCooperativeMatrixPerElementOpNV: {
|
||||
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
|
||||
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
|
||||
|
||||
struct vtn_function *per_element_fn = vtn_value(b, w[4], vtn_value_type_function)->func;
|
||||
|
||||
per_element_fn->referenced = true;
|
||||
per_element_fn->nir_func->cmat_call = true;
|
||||
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_per_element_nv");
|
||||
|
||||
nir_cmat_call_instr *call = nir_cmat_call_instr_create(b->nb.shader, nir_cmat_call_op_per_element_op, per_element_fn->nir_func);
|
||||
call->params[0] = nir_src_for_ssa(&dst->def);
|
||||
call->params[1] = nir_src_for_ssa(nir_imm_zero(&b->nb, 1, 32));
|
||||
call->params[2] = nir_src_for_ssa(nir_imm_zero(&b->nb, 1, 32));
|
||||
call->params[3] = nir_src_for_ssa(&src->def);
|
||||
|
||||
for (unsigned i = 0; i < count - 5; i++) {
|
||||
struct vtn_ssa_value *ssa = vtn_ssa_value(b, w[5 + i]);
|
||||
nir_def *def;
|
||||
nir_deref_instr *deref = NULL;
|
||||
|
||||
if (ssa->is_variable) {
|
||||
deref = nir_build_deref_var(&b->nb, ssa->var);
|
||||
def = &deref->def;
|
||||
} else if (glsl_type_is_vector_or_scalar(ssa->type)) {
|
||||
def = ssa->def;
|
||||
} else
|
||||
def = ssa->elems[0]->def;
|
||||
|
||||
call->params[4 + i] = nir_src_for_ssa(def);
|
||||
}
|
||||
nir_builder_instr_insert(&b->nb, &call->instr);
|
||||
vtn_push_var_ssa(b, w[2], dst->var);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
UNREACHABLE("Unexpected opcode for cooperative matrix instruction");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue