diff --git a/src/gallium/drivers/ethosu/ethosu_lower.c b/src/gallium/drivers/ethosu/ethosu_lower.c index 2cd7de5047a..0e185b5d652 100644 --- a/src/gallium/drivers/ethosu/ethosu_lower.c +++ b/src/gallium/drivers/ethosu/ethosu_lower.c @@ -306,6 +306,115 @@ ethos_create_lut(struct ethosu_operation *operation, uint8_t *lut, double (*func } } +// Implementation from Vela and TensorFlow Lite Micro kernel +static int16_t +saturating_left_shift_16(int16_t value, int amount) +{ + int32_t result = value << amount; + return CLAMP(result, INT16_MIN, INT16_MAX); +} + +// Implementation from Vela and TensorFlow Lite Micro kernel +// Similar to ARM instruction SQDMULH. +// Similar to gemmlowp::SaturatingRoundingDoublingHighMul except +// rounding to zero instead of to nearest (SQRDMULH). +static int16_t +saturating_doubling_high_mul_16(int16_t a, int16_t b) +{ + bool overflow = a == b && a == INT16_MIN; + int32_t a_32 = a; + int32_t b_32 = b; + int32_t ab_32 = a_32 * b_32; + int16_t ab_x2_high16 = (int16_t)(ab_32 / (1 << 15)); + return overflow ? INT16_MAX : ab_x2_high16; +} + +static int16_t +saturating_rounding_doubling_high_mul_16(int16_t a, int16_t b) +{ + bool overflow = a == b && a == INT16_MIN; + int32_t a_32 = a; + int32_t b_32 = b; + int32_t ab_32 = a_32 * b_32; + int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); + int16_t ab_x2_high16 = ((ab_32 + nudge) / (1 << 15)); + return overflow ? INT16_MAX : ab_x2_high16; +} + +static int16_t +rounding_divide_by_pow2_16(int16_t x, int exponent) +{ + const int16_t mask = (1 << exponent) - 1; + const int16_t remainder = x & mask; + const int16_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0); + return (x >> exponent) + ((remainder > threshold) ? 1 : 0); +} + +static int16_t +downscale_int32_to_int16_multiplier(int32_t multiplier) +{ + return CLAMP(((multiplier / 32768) + 1) / 2, INT16_MIN, INT16_MAX); +} + +static void +ethos_create_hswish_lut(struct ethosu_operation *operation, uint8_t *lut) +{ + const double ifm_scale = operation->ifm.scale; + const double ofm_scale = operation->ofm.scale; + const unsigned zpIn = operation->ifm.zero_point; + const unsigned zpOut = operation->ofm.zero_point; + + const int qMin = operation->ifm.is_signed ? -128 : 0; + const int qMax = operation->ifm.is_signed ? 127 : 255; + + const double ifmScaleHires = (1.0 / 128.0) * ifm_scale; + const double reluMultiplier = 3.0 / 32768.0; + + int32_t out_shift; + int32_t relu_shift; + int32_t out_scale = ethosu_quantize_scale(ifmScaleHires / ofm_scale, &out_shift, false); + int32_t relu_scale = ethosu_quantize_scale(ifmScaleHires / reluMultiplier, &relu_shift, false); + int16_t outScale16 = downscale_int32_to_int16_multiplier(out_scale); + int16_t reluScale16 = downscale_int32_to_int16_multiplier(relu_scale); + // convert to left shift-positive notation + int outShift = 31 - out_shift; + int reluShift = 31 - relu_shift; + + for (int x = qMin; x <= qMax; ++x, lut++) { + const int16_t inputValue = (int16_t)(x - zpIn); + const int16_t inputValueOnHiresInputScale = (int16_t)(inputValue << 7); + const int16_t inputValueOnPreshiftOutputScale = saturating_rounding_doubling_high_mul_16(inputValueOnHiresInputScale, outScale16); + int16_t reluValue = inputValueOnHiresInputScale; + + if (reluShift > 0) + reluValue = saturating_left_shift_16(reluValue, reluShift - 1); + + reluValue = saturating_rounding_doubling_high_mul_16(reluValue, reluScale16); + + if (reluShift > 0) + reluValue = saturating_left_shift_16(reluValue, 1); + + // Try to get reluShift into the [-31, 0] range + if (reluShift < -31) { + reluValue = reluValue >> (-31 - reluShift); + reluShift = -31; + } + + if (reluShift < 0) + reluValue = rounding_divide_by_pow2_16(reluValue, -reluShift); + + reluValue = (int16_t)((reluValue + (1 << 15)) >> 1); + + const int16_t preshiftOutputValue = saturating_doubling_high_mul_16(reluValue, inputValueOnPreshiftOutputScale); + + int16_t outputValue = rounding_divide_by_pow2_16(preshiftOutputValue, -outShift); + + int lutVal = outputValue + zpOut; + lutVal = MIN2(qMax, MAX2(qMin, lutVal)); + *lut = lutVal; + } +} + static void ethosu_lower_lut_dma(struct ethosu_subgraph *subgraph, const struct pipe_ml_operation *poperation, @@ -344,6 +453,31 @@ ethosu_lower_lut(struct ethosu_subgraph *subgraph, ethosu_sched_operation(subgraph, operation); } +static void +ethosu_lower_hswish(struct ethosu_subgraph *subgraph, + const struct pipe_ml_operation *poperation, + struct ethosu_operation *operation) +{ + uint8_t lut[LUT8_SIZE]; + + operation->type = ETHOSU_OPERATION_TYPE_POOLING; + operation->round_mode = ETHOSU_ROUNDING_NATURAL; + operation->pooling.type = ETHOSU_POOLING_TYPE_AVG; + operation->pooling.activation = ETHOSU_POOLING_ACTIVATION_LUT(0); + + set_feature_maps(subgraph, poperation->input_tensors[0], poperation->output_tensors[0], operation); + + ethos_create_hswish_lut(operation, lut); + fill_lut(subgraph, operation, lut); + + /* The LUT handles 0 point and scale, so make them equal */ + operation->ofm.zero_point = operation->ifm.zero_point; + operation->ofm.scale = operation->ifm.scale; + + allocate_feature_maps(subgraph, operation); + ethosu_sched_operation(subgraph, operation); +} + static void ethosu_lower_concatenation(struct ethosu_subgraph *subgraph, const struct pipe_ml_operation *poperation, @@ -629,6 +763,17 @@ ethosu_lower_graph(struct ethosu_subgraph *subgraph, break; } + case PIPE_ML_OPERATION_TYPE_HSWISH: { + ethosu_lower_hswish(subgraph, &poperations[i], &operation); + + struct ethosu_operation dma_operation = {0}; + ethosu_lower_lut_dma(subgraph, &poperations[i], &operation, &dma_operation); + util_dynarray_append(&subgraph->operations, dma_operation); + + util_dynarray_append(&subgraph->operations, operation); + break; + } + case PIPE_ML_OPERATION_TYPE_STRIDED_SLICE: { ethosu_lower_strided_slice(subgraph, &poperations[i], &operation); util_dynarray_append(&subgraph->operations, operation); diff --git a/src/gallium/drivers/ethosu/ethosu_ml.c b/src/gallium/drivers/ethosu/ethosu_ml.c index 212b55664d1..a2fbed926bb 100644 --- a/src/gallium/drivers/ethosu/ethosu_ml.c +++ b/src/gallium/drivers/ethosu/ethosu_ml.c @@ -147,6 +147,7 @@ ethosu_ml_operation_supported(struct pipe_ml_device *pdevice, case PIPE_ML_OPERATION_TYPE_PAD: case PIPE_ML_OPERATION_TYPE_LOGISTIC: case PIPE_ML_OPERATION_TYPE_TANH: + case PIPE_ML_OPERATION_TYPE_HSWISH: supported = true; break; case PIPE_ML_OPERATION_TYPE_RESIZE: {