diff --git a/src/gallium/drivers/ethosu/ethosu_cmd.c b/src/gallium/drivers/ethosu/ethosu_cmd.c index d092603cd96..32719c05cb0 100644 --- a/src/gallium/drivers/ethosu/ethosu_cmd.c +++ b/src/gallium/drivers/ethosu/ethosu_cmd.c @@ -137,23 +137,9 @@ emit_strides( struct ethosu_feature_map *feature_map, uint32_t cmd_stride_c, uint32_t cmd_stride_y, uint32_t cmd_stride_x) { - unsigned elem_size = 1; - unsigned tensor_x, tensor_y, tensor_c; - struct ethosu_tensor *tensor = feature_map->tensor; - - if (tensor->layout == ETHOSU_LAYOUT_NHCWB16) { - tensor_x = 16 * elem_size; - tensor_c = tensor_x * tensor->shape.width; - tensor_y = elem_size * tensor->shape.width * align(tensor->shape.depth, 16); - } else { - tensor_c = elem_size; - tensor_x = tensor->shape.depth * tensor_c; - tensor_y = tensor->shape.width * tensor_x; - } - - EMIT1(cmd_stride_y, 0x0, tensor_y); - EMIT1(cmd_stride_x, 0x0, tensor_x); - EMIT1(cmd_stride_c, 0x0, tensor_c); + EMIT1(cmd_stride_y, 0x0, feature_map->stride.y); + EMIT1(cmd_stride_x, 0x0, feature_map->stride.x); + EMIT1(cmd_stride_c, 0x0, feature_map->stride.c); } static void diff --git a/src/gallium/drivers/ethosu/ethosu_lower.c b/src/gallium/drivers/ethosu/ethosu_lower.c index 7ab313aadab..3342a53aa39 100644 --- a/src/gallium/drivers/ethosu/ethosu_lower.c +++ b/src/gallium/drivers/ethosu/ethosu_lower.c @@ -30,6 +30,22 @@ needed_total_padding(int input_size, int stride, int filter_size) return MAX2(filter_size - (input_size % stride), 0); } +static void +set_feature_map_strides(struct ethosu_feature_map *fm, bool is_nhcwb16) +{ + unsigned elem_size = 1; + + if (is_nhcwb16) { + fm->stride.x = 16 * elem_size; + fm->stride.c = fm->stride.x * fm->shape.width; + fm->stride.y = elem_size * fm->shape.width * align(fm->shape.depth, 16); + } else { + fm->stride.c = elem_size; + fm->stride.x = fm->shape.depth * fm->stride.c; + fm->stride.y = fm->shape.width * fm->stride.x; + } +} + static void set_feature_map(struct ethosu_subgraph *subgraph, struct pipe_tensor *tensor, @@ -43,6 +59,8 @@ set_feature_map(struct ethosu_subgraph *subgraph, fm->scale = tensor->scale; fm->is_signed = tensor->is_signed; fm->precision = log2(tensor->type_size); + + set_feature_map_strides(fm, fm->tensor->layout == ETHOSU_LAYOUT_NHCWB16); } static void diff --git a/src/gallium/drivers/ethosu/ethosu_ml.h b/src/gallium/drivers/ethosu/ethosu_ml.h index e4dca781211..bb4b7be8ab4 100644 --- a/src/gallium/drivers/ethosu/ethosu_ml.h +++ b/src/gallium/drivers/ethosu/ethosu_ml.h @@ -66,11 +66,18 @@ enum ethosu_upscale_mode { ETHOSU_UPSCALE_ZEROS = 2, }; +struct ethosu_stride { + unsigned x; + unsigned y; + unsigned c; +}; + struct ethosu_tensor; struct ethosu_feature_map { struct ethosu_tensor *tensor; struct ethosu_block shape; + struct ethosu_stride stride; bool is_signed; uint8_t precision; struct ethosu_tile_box tiles;