diff --git a/src/gallium/frontends/teflon/tfl_device.c b/src/gallium/frontends/teflon/tfl_device.c index 4a082d0c76e..1fc2238f77d 100644 --- a/src/gallium/frontends/teflon/tfl_device.c +++ b/src/gallium/frontends/teflon/tfl_device.c @@ -334,14 +334,28 @@ fill_tensor(struct teflon_delegate *delegate, TfLiteContext *tf_context, struct tensor->scale = quant->scale->data[0]; tensor->zero_point = quant->zero_point->data[0]; - assert(quant->scale->size == quant->zero_point->size); - if (quant->scale->size > 1 && - (!all_scales_equal(quant) || !all_zero_points_equal(quant))) { + /* Handle per-channel quantization */ + if (quant->scale->size > 1 && !all_scales_equal(quant)) { tensor->scales = calloc(quant->scale->size, sizeof(*tensor->scales)); memcpy(tensor->scales, quant->scale->data, quant->scale->size * sizeof(*tensor->scales)); - tensor->zero_points = calloc(quant->zero_point->size, sizeof(*tensor->zero_points)); - memcpy(tensor->zero_points, quant->zero_point->data, quant->zero_point->size * sizeof(*tensor->zero_points)); + tensor->zero_points = calloc(quant->scale->size, sizeof(*tensor->zero_points)); + if (quant->zero_point->size == quant->scale->size) { + /* Same number of zero_points as scales - copy directly */ + memcpy(tensor->zero_points, quant->zero_point->data, quant->scale->size * sizeof(*tensor->zero_points)); + } else if (quant->zero_point->size == 1) { + /* Single zero_point for all channels (common for symmetric quantization) - replicate it */ + for (int i = 0; i < quant->scale->size; i++) { + tensor->zero_points[i] = quant->zero_point->data[0]; + } + } else { + /* Unexpected case - use first zero_point for all */ + fprintf(stderr, "teflon: WARNING: tensor %d has %d scales but %d zero_points, using first zero_point for all\n", + index, quant->scale->size, quant->zero_point->size); + for (int i = 0; i < quant->scale->size; i++) { + tensor->zero_points[i] = quant->zero_point->data[0]; + } + } } } @@ -720,9 +734,10 @@ fused_relu6_supported(TfLiteTensor *tensor) assert(tensor->quantization.type == kTfLiteAffineQuantization); affine = (TfLiteAffineQuantization *)tensor->quantization.params; - assert(affine->scale->size == affine->zero_point->size); - for (int i = 0; i < affine->zero_point->size; i++) { - if ((quantized_max - affine->zero_point->data[i]) * affine->scale->data[i] > 6.0f) + /* Handle per-channel quantization where zero_point->size may be 1 */ + for (int i = 0; i < affine->scale->size; i++) { + int zp_idx = (affine->zero_point->size == 1) ? 0 : i; + if ((quantized_max - affine->zero_point->data[zp_idx]) * affine->scale->data[i] > 6.0f) return false; } return true;