diff --git a/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c b/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c index 67ebd2ee6b7..bc80645cae1 100644 --- a/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c +++ b/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c @@ -770,11 +770,6 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation &output_width, &output_height, &output_channels); } - if (input_height > input_width) { - SWAP(input_width, input_height); - SWAP(output_width, output_height); - } - if (operation->fully_connected) { unsigned original_input_width = input_width; input_width = 15; @@ -787,6 +782,9 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation input_channels = original_input_height / input_height; weight_width = input_width; weight_height = input_height; + } else { + SWAP(input_width, input_height); + SWAP(output_width, output_height); } etna_bo_cpu_prep(bo, DRM_ETNA_PREP_WRITE); diff --git a/src/gallium/drivers/etnaviv/etnaviv_ml_tp.c b/src/gallium/drivers/etnaviv/etnaviv_ml_tp.c index c02f63a727e..6bbfc0495f0 100644 --- a/src/gallium/drivers/etnaviv/etnaviv_ml_tp.c +++ b/src/gallium/drivers/etnaviv/etnaviv_ml_tp.c @@ -349,11 +349,13 @@ split_reshuffle(struct etna_ml_subgraph *subgraph, const struct etna_operation * unsigned remaining_out_size, remaining_in_size; unsigned dim_to_split = 0; - if (out_dims[1] >= out_dims[dim_to_split]) - dim_to_split = 1; + if (operation->input_channels >= out_dims[dim_to_split]) { + if (out_dims[1] >= out_dims[dim_to_split]) + dim_to_split = 1; - if (out_dims[2] >= out_dims[dim_to_split]) - dim_to_split = 2; + if (out_dims[2] >= out_dims[dim_to_split]) + dim_to_split = 2; + } remaining_in_size = in_dims[dim_to_split]; remaining_out_size = out_dims[dim_to_split]; @@ -429,10 +431,8 @@ create_reshuffle_config(struct etna_ml_subgraph *subgraph, const struct etna_ope set_default_tp_config(map); - if (input_height > input_width) { - SWAP(input_width, input_height); - SWAP(output_width, output_height); - } + SWAP(input_width, input_height); + SWAP(output_width, output_height); in_dims[0] = input_width; in_dims[1] = input_height;