diff --git a/.pick_status.json b/.pick_status.json index ceaef0e523e..8f04f1575dc 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -544,7 +544,7 @@ "description": "vtn: Remove transpose(m0)*m1 fast path", "nominated": true, "nomination_type": 0, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 04d71ae6eb2..a6b327f6a02 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -94,38 +94,16 @@ matrix_multiply(struct vtn_builder *b, transpose_result = true; } - if (src0_transpose && !src1_transpose && - glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) { - /* We already have the rows of src0 and the columns of src1 available, - * so we can just take the dot product of each row with each column to - * get the result. - */ - - for (unsigned i = 0; i < src1_columns; i++) { - nir_def *vec_src[4]; - for (unsigned j = 0; j < src0_rows; j++) { - vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def, - src1->elems[i]->def); - } - dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows); - } - } else { - /* We don't handle the case where src1 is transposed but not src0, since - * the general case only uses individual components of src1 so the - * optimizer should chew through the transpose we emitted for src1. - */ - - for (unsigned i = 0; i < src1_columns; i++) { - /* dest[i] = sum(src0[j] * src1[i][j] for all j) */ + for (unsigned i = 0; i < src1_columns; i++) { + /* dest[i] = sum(src0[j] * src1[i][j] for all j) */ + dest->elems[i]->def = + nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def, + nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1)); + for (int j = src0_columns - 2; j >= 0; j--) { dest->elems[i]->def = - nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def, - nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1)); - for (int j = src0_columns - 2; j >= 0; j--) { - dest->elems[i]->def = - nir_ffma(&b->nb, src0->elems[j]->def, - nir_channel(&b->nb, src1->elems[i]->def, j), - dest->elems[i]->def); - } + nir_ffma(&b->nb, src0->elems[j]->def, + nir_channel(&b->nb, src1->elems[i]->def, j), + dest->elems[i]->def); } }