diff --git a/src/gallium/frontends/teflon/tfl_device.c b/src/gallium/frontends/teflon/tfl_device.c
index e6b26197ec7..5dcc47ce697 100644
--- a/src/gallium/frontends/teflon/tfl_device.c
+++ b/src/gallium/frontends/teflon/tfl_device.c
@@ -155,6 +155,12 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi
       case kTfLiteBuiltinAdd:
          operation->type = PIPE_ML_OPERATION_TYPE_ADD;
         break;
+      case kTfLiteBuiltinConcatenation:
+         operation->type = PIPE_ML_OPERATION_TYPE_CONCATENATION;
+         break;
+      case kTfLiteBuiltinSplit:
+         operation->type = PIPE_ML_OPERATION_TYPE_SPLIT;
+         break;
       default:
          unreachable("Unsupported ML operation type");
    }
@@ -224,9 +230,15 @@ dump_graph(struct pipe_tensor *tensors, unsigned tensor_count, struct pipe_ml_op
       case PIPE_ML_OPERATION_TYPE_CONVOLUTION:
          teflon_debug("%-6s ", operations[i].conv.depthwise ? "DWCONV" : "CONV");
          break;
+      case PIPE_ML_OPERATION_TYPE_CONCATENATION:
+         teflon_debug("%-6s ", "CONCAT");
+         break;
       case PIPE_ML_OPERATION_TYPE_POOLING:
          teflon_debug("%-6s ", "POOL");
          break;
+      case PIPE_ML_OPERATION_TYPE_SPLIT:
+         teflon_debug("%-6s ", "SPLIT");
+         break;
       }
 
       for (unsigned j = 0; j < operations[i].input_count; j++) {
@@ -516,6 +528,36 @@ PrepareDelegate(TfLiteContext *context, TfLiteDelegate *delegate)
          case kTfLiteBuiltinAdd:
             supported = true;
             break;
+         case kTfLiteBuiltinConcatenation: {
+            TfLiteConcatenationParams *params = node->builtin_data;
+            supported = true;
+
+            if (params->axis != 3 &&
+                params->axis != -1)
+               supported = false;
+
+            unsigned input_channels = context->tensors[node->inputs->data[0]].dims->data[3];
+            for (unsigned i = 1; i < node->inputs->size; i++)
+               if (input_channels != context->tensors[node->inputs->data[i]].dims->data[3])
+                  supported = false;
+
+            break;
+         }
+         case kTfLiteBuiltinSplit: {
+            int32_t axis = context->tensors[node->inputs->data[0]].data.i32[0];
+            supported = true;
+
+            if (axis != 3 &&
+                axis != -1)
+               supported = false;
+
+            unsigned output_channels = context->tensors[node->outputs->data[0]].dims->data[3];
+            for (unsigned i = 1; i < node->outputs->size; i++)
+               if (output_channels != context->tensors[node->outputs->data[i]].dims->data[3])
+                  supported = false;
+
+            break;
+         }
          }
 
          if (supported)
diff --git a/src/gallium/include/pipe/p_state.h b/src/gallium/include/pipe/p_state.h
index e1819e8c396..a497198708c 100644
--- a/src/gallium/include/pipe/p_state.h
+++ b/src/gallium/include/pipe/p_state.h
@@ -1059,6 +1059,8 @@ enum pipe_ml_operation_type {
    PIPE_ML_OPERATION_TYPE_ADD,
    PIPE_ML_OPERATION_TYPE_CONVOLUTION,
    PIPE_ML_OPERATION_TYPE_POOLING,
+   PIPE_ML_OPERATION_TYPE_CONCATENATION,
+   PIPE_ML_OPERATION_TYPE_SPLIT,
 };
 
 /**