ethosu: Add hard swish operation

Hard swish lowers to a pooling operation with a LUT.

Signed-off-by: Rob Herring (Arm) <robh@kernel.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39975>
This commit is contained in:
Rob Herring (Arm) 2026-03-18 17:23:08 -05:00 committed by Marge Bot
parent f2800fe13b
commit 3487b15312
2 changed files with 146 additions and 0 deletions

View file

@ -306,6 +306,115 @@ ethos_create_lut(struct ethosu_operation *operation, uint8_t *lut, double (*func
}
}
// Implementation from Vela and TensorFlow Lite Micro kernel
static int16_t
saturating_left_shift_16(int16_t value, int amount)
{
int32_t result = value << amount;
return CLAMP(result, INT16_MIN, INT16_MAX);
}
// Implementation from Vela and TensorFlow Lite Micro kernel
// Similar to ARM instruction SQDMULH.
// Similar to gemmlowp::SaturatingRoundingDoublingHighMul except
// rounding to zero instead of to nearest (SQRDMULH).
static int16_t
saturating_doubling_high_mul_16(int16_t a, int16_t b)
{
bool overflow = a == b && a == INT16_MIN;
int32_t a_32 = a;
int32_t b_32 = b;
int32_t ab_32 = a_32 * b_32;
int16_t ab_x2_high16 = (int16_t)(ab_32 / (1 << 15));
return overflow ? INT16_MAX : ab_x2_high16;
}
static int16_t
saturating_rounding_doubling_high_mul_16(int16_t a, int16_t b)
{
bool overflow = a == b && a == INT16_MIN;
int32_t a_32 = a;
int32_t b_32 = b;
int32_t ab_32 = a_32 * b_32;
int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
int16_t ab_x2_high16 = ((ab_32 + nudge) / (1 << 15));
return overflow ? INT16_MAX : ab_x2_high16;
}
static int16_t
rounding_divide_by_pow2_16(int16_t x, int exponent)
{
const int16_t mask = (1 << exponent) - 1;
const int16_t remainder = x & mask;
const int16_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}
static int16_t
downscale_int32_to_int16_multiplier(int32_t multiplier)
{
return CLAMP(((multiplier / 32768) + 1) / 2, INT16_MIN, INT16_MAX);
}
static void
ethos_create_hswish_lut(struct ethosu_operation *operation, uint8_t *lut)
{
const double ifm_scale = operation->ifm.scale;
const double ofm_scale = operation->ofm.scale;
const unsigned zpIn = operation->ifm.zero_point;
const unsigned zpOut = operation->ofm.zero_point;
const int qMin = operation->ifm.is_signed ? -128 : 0;
const int qMax = operation->ifm.is_signed ? 127 : 255;
const double ifmScaleHires = (1.0 / 128.0) * ifm_scale;
const double reluMultiplier = 3.0 / 32768.0;
int32_t out_shift;
int32_t relu_shift;
int32_t out_scale = ethosu_quantize_scale(ifmScaleHires / ofm_scale, &out_shift, false);
int32_t relu_scale = ethosu_quantize_scale(ifmScaleHires / reluMultiplier, &relu_shift, false);
int16_t outScale16 = downscale_int32_to_int16_multiplier(out_scale);
int16_t reluScale16 = downscale_int32_to_int16_multiplier(relu_scale);
// convert to left shift-positive notation
int outShift = 31 - out_shift;
int reluShift = 31 - relu_shift;
for (int x = qMin; x <= qMax; ++x, lut++) {
const int16_t inputValue = (int16_t)(x - zpIn);
const int16_t inputValueOnHiresInputScale = (int16_t)(inputValue << 7);
const int16_t inputValueOnPreshiftOutputScale = saturating_rounding_doubling_high_mul_16(inputValueOnHiresInputScale, outScale16);
int16_t reluValue = inputValueOnHiresInputScale;
if (reluShift > 0)
reluValue = saturating_left_shift_16(reluValue, reluShift - 1);
reluValue = saturating_rounding_doubling_high_mul_16(reluValue, reluScale16);
if (reluShift > 0)
reluValue = saturating_left_shift_16(reluValue, 1);
// Try to get reluShift into the [-31, 0] range
if (reluShift < -31) {
reluValue = reluValue >> (-31 - reluShift);
reluShift = -31;
}
if (reluShift < 0)
reluValue = rounding_divide_by_pow2_16(reluValue, -reluShift);
reluValue = (int16_t)((reluValue + (1 << 15)) >> 1);
const int16_t preshiftOutputValue = saturating_doubling_high_mul_16(reluValue, inputValueOnPreshiftOutputScale);
int16_t outputValue = rounding_divide_by_pow2_16(preshiftOutputValue, -outShift);
int lutVal = outputValue + zpOut;
lutVal = MIN2(qMax, MAX2(qMin, lutVal));
*lut = lutVal;
}
}
static void
ethosu_lower_lut_dma(struct ethosu_subgraph *subgraph,
const struct pipe_ml_operation *poperation,
@ -344,6 +453,31 @@ ethosu_lower_lut(struct ethosu_subgraph *subgraph,
ethosu_sched_operation(subgraph, operation);
}
static void
ethosu_lower_hswish(struct ethosu_subgraph *subgraph,
const struct pipe_ml_operation *poperation,
struct ethosu_operation *operation)
{
uint8_t lut[LUT8_SIZE];
operation->type = ETHOSU_OPERATION_TYPE_POOLING;
operation->round_mode = ETHOSU_ROUNDING_NATURAL;
operation->pooling.type = ETHOSU_POOLING_TYPE_AVG;
operation->pooling.activation = ETHOSU_POOLING_ACTIVATION_LUT(0);
set_feature_maps(subgraph, poperation->input_tensors[0], poperation->output_tensors[0], operation);
ethos_create_hswish_lut(operation, lut);
fill_lut(subgraph, operation, lut);
/* The LUT handles 0 point and scale, so make them equal */
operation->ofm.zero_point = operation->ifm.zero_point;
operation->ofm.scale = operation->ifm.scale;
allocate_feature_maps(subgraph, operation);
ethosu_sched_operation(subgraph, operation);
}
static void
ethosu_lower_concatenation(struct ethosu_subgraph *subgraph,
const struct pipe_ml_operation *poperation,
@ -629,6 +763,17 @@ ethosu_lower_graph(struct ethosu_subgraph *subgraph,
break;
}
case PIPE_ML_OPERATION_TYPE_HSWISH: {
ethosu_lower_hswish(subgraph, &poperations[i], &operation);
struct ethosu_operation dma_operation = {0};
ethosu_lower_lut_dma(subgraph, &poperations[i], &operation, &dma_operation);
util_dynarray_append(&subgraph->operations, dma_operation);
util_dynarray_append(&subgraph->operations, operation);
break;
}
case PIPE_ML_OPERATION_TYPE_STRIDED_SLICE: {
ethosu_lower_strided_slice(subgraph, &poperations[i], &operation);
util_dynarray_append(&subgraph->operations, operation);

View file

@ -147,6 +147,7 @@ ethosu_ml_operation_supported(struct pipe_ml_device *pdevice,
case PIPE_ML_OPERATION_TYPE_PAD:
case PIPE_ML_OPERATION_TYPE_LOGISTIC:
case PIPE_ML_OPERATION_TYPE_TANH:
case PIPE_ML_OPERATION_TYPE_HSWISH:
supported = true;
break;
case PIPE_ML_OPERATION_TYPE_RESIZE: {