diff --git a/src/gallium/frontends/teflon/tfl_device.c b/src/gallium/frontends/teflon/tfl_device.c index a21f95ba68d..6394bb3285a 100644 --- a/src/gallium/frontends/teflon/tfl_device.c +++ b/src/gallium/frontends/teflon/tfl_device.c @@ -166,9 +166,22 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi operation->conv.weight_tensor->dims[2] == 1; break; } - case kTfLiteBuiltinAveragePool2d: + case kTfLiteBuiltinMaxPool2d: + operation->pooling.type = PIPE_ML_POOLING_TYPE_MAX; + __attribute__((fallthrough)); + case kTfLiteBuiltinAveragePool2d: { + TfLitePoolParams *params = node->builtin_data; + operation->type = PIPE_ML_OPERATION_TYPE_POOLING; + + /* Skip setting operation->pooling.type for PIPE_ML_POOLING_TYPE_AVG==0 */ + operation->pooling.filter_height = params->filter_height; + operation->pooling.filter_width = params->filter_width; + operation->pooling.stride_x = params->stride_width; + operation->pooling.stride_y = params->stride_height; + operation->pooling.padding_same = params->padding == kTfLitePaddingSame; break; + } case kTfLiteBuiltinAdd: operation->type = PIPE_ML_OPERATION_TYPE_ADD; break; @@ -573,6 +586,8 @@ tflite_builtin_op_name(TfLiteBuiltinOperator op) return "ADD"; case kTfLiteBuiltinAveragePool2d: return "AVGPOOL"; + case kTfLiteBuiltinMaxPool2d: + return "MAXPOOL"; case kTfLiteBuiltinConv2d: return "CONV"; case kTfLiteBuiltinDepthwiseConv2d: diff --git a/src/gallium/include/pipe/p_state.h b/src/gallium/include/pipe/p_state.h index 208ac317400..d9c887bab9f 100644 --- a/src/gallium/include/pipe/p_state.h +++ b/src/gallium/include/pipe/p_state.h @@ -1056,6 +1056,11 @@ enum pipe_ml_operation_type { PIPE_ML_OPERATION_TYPE_TRANSPOSE, }; +enum pipe_ml_pooling_type { + PIPE_ML_POOLING_TYPE_AVG, + PIPE_ML_POOLING_TYPE_MAX, +}; + /** * Information about a single operation inside a ML subgraph. */ @@ -1124,6 +1129,12 @@ struct pipe_ml_operation unsigned dilation_height_factor; } conv; struct { + + /** + * Type of pooling operation. + */ + enum pipe_ml_pooling_type type; + /** * Stride used to access the input tensor on the x axis. */