teflon: Add tests for FullyConnected

Same as we do with convolutions and additions.

Reviewed-by: Philipp Zabel <p.zabel@pengutronix.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32510>
This commit is contained in:
Tomeu Vizoso 2024-11-13 14:48:23 +01:00 committed by Marge Bot
parent 3d8f108514
commit ad82a7c388
4 changed files with 163 additions and 0 deletions

View file

@ -250,6 +250,91 @@ add_generate_model(int input_size,
return buf;
}
static void
patch_fully_connected(unsigned operation_index,
tflite::ModelT *model,
int input_size,
int output_channels,
bool is_signed)
{
unsigned input_index;
unsigned weights_index;
unsigned bias_index;
unsigned output_index;
unsigned weights_buffer_index;
unsigned bias_buffer_index;
auto subgraph = model->subgraphs[0];
/* Operation */
auto value = new tflite::FullyConnectedOptionsT();
subgraph->operators[operation_index]->builtin_options.value = value;
input_index = subgraph->operators[operation_index]->inputs.data()[0];
weights_index = subgraph->operators[operation_index]->inputs.data()[1];
bias_index = subgraph->operators[operation_index]->inputs.data()[2];
output_index = subgraph->operators[operation_index]->outputs.data()[0];
/* Input */
auto input_tensor = subgraph->tensors[input_index];
input_tensor->shape.data()[0] = 1;
input_tensor->shape.data()[1] = input_size;
input_tensor->type = is_signed ? tflite::TensorType_INT8 : tflite::TensorType_UINT8;
/* Bias */
auto bias_tensor = subgraph->tensors[bias_index];
bias_buffer_index = bias_tensor->buffer;
bias_tensor->shape.data()[0] = output_channels;
auto bias_data = &model->buffers[bias_buffer_index]->data;
xt::xarray<int32_t> bias_array = xt::random::randint<int32_t>({output_channels}, -20000, 20000);
bias_data->resize(bias_array.size() * sizeof(int32_t));
memcpy(bias_data->data(), bias_array.data(), bias_array.size() * sizeof(int32_t));
/* Weight */
auto weight_tensor = subgraph->tensors[weights_index];
weights_buffer_index = weight_tensor->buffer;
weight_tensor->shape.data()[0] = output_channels;
weight_tensor->shape.data()[1] = input_size;
weight_tensor->type = is_signed ? tflite::TensorType_INT8 : tflite::TensorType_UINT8;
auto weights_data = &model->buffers[weights_buffer_index]->data;
std::vector<int> weight_shape;
weight_shape = {output_channels, input_size};
xt::xarray<uint8_t> weights_array = xt::random::randint<uint8_t>(weight_shape, 0, 255);
weights_data->resize(weights_array.size());
memcpy(weights_data->data(), weights_array.data(), weights_array.size());
/* Output */
auto output_tensor = subgraph->tensors[output_index];
output_tensor->shape.data()[0] = 1;
output_tensor->shape.data()[1] = output_channels;
output_tensor->type = is_signed ? tflite::TensorType_INT8 : tflite::TensorType_UINT8;
}
void *
fully_connected_generate_model(int input_size,
int output_channels,
bool is_signed,
size_t *buf_size)
{
void *buf;
tflite::ModelT model;
read_model("fully_connected.tflite", model);
patch_fully_connected(0, &model, input_size, output_channels, is_signed);
flatbuffers::FlatBufferBuilder builder;
builder.Finish(tflite::Model::Pack(builder, &model), "TFL3");
*buf_size = builder.GetSize();
buf = malloc(*buf_size);
memcpy(buf, builder.GetBufferPointer(), builder.GetSize());
return buf;
}
static void
tflite_error_cb(void *user_data, const char *format, va_list args)
{

View file

@ -31,6 +31,11 @@ void *add_generate_model(int input_size,
bool depthwise,
size_t *buf_size);
void *fully_connected_generate_model(int input_size,
int output_channels,
bool is_signed,
size_t *buf_size);
void run_model(TfLiteModel *model, enum executor executor, void ***input, size_t *num_inputs,
void ***output, size_t **output_sizes, TfLiteType **output_types,
size_t *num_outputs, std::string cache_dir);

View file

@ -17,6 +17,7 @@
#define TEST_CONV2D 1
#define TEST_DEPTHWISE 1
#define TEST_ADD 1
#define TEST_FULLY_CONNECTED 1
#define TEST_MOBILENETV1 1
#define TEST_MOBILEDET 1
#define TEST_YOLOX 1
@ -35,6 +36,8 @@ std::vector<int> dw_channels{1, 32, 120, 128, 256};
std::vector<int> dw_weight_size{3, 5};
std::vector<int> weight_size{1, 3, 5};
std::vector<int> input_size{3, 5, 8, 80, 112};
std::vector<int> fc_channels{23, 46, 128, 256, 512};
std::vector<int> fc_size{128, 1280, 25088, 62720};
static void
set_seed(unsigned seed)
@ -223,6 +226,41 @@ test_add(int input_size, int weight_size, int input_channels, int output_channel
free(buf);
}
void
test_fully_connected(int input_size, int output_channels, bool is_signed, int seed)
{
void *buf = NULL;
size_t buf_size;
std::ostringstream cache_dir, model_cache;
cache_dir << "/var/cache/teflon_tests/fc_" << input_size << "_" << output_channels << "_" << is_signed << "_" << seed;
model_cache << cache_dir.str() << "/"
<< "model.tflite";
set_seed(seed);
if (cache_is_enabled()) {
if (access(model_cache.str().c_str(), F_OK) == 0) {
buf = read_buf(model_cache.str().c_str(), &buf_size);
}
}
if (buf == 0) {
buf = fully_connected_generate_model(input_size, output_channels, is_signed, &buf_size);
if (cache_is_enabled()) {
if (access(cache_dir.str().c_str(), F_OK) != 0) {
ASSERT_TRUE(std::filesystem::create_directories(cache_dir.str().c_str()));
}
std::ofstream file(model_cache.str().c_str(), std::ios::out | std::ios::binary);
file.write(reinterpret_cast<const char *>(buf), buf_size);
file.close();
}
}
test_model(buf, buf_size, cache_dir.str(), TOLERANCE);
free(buf);
}
#if TEST_CONV2D
class Conv2D : public testing::TestWithParam<std::tuple<bool, bool, int, int, int, int, int>> {};
@ -383,6 +421,41 @@ INSTANTIATE_TEST_SUITE_P(
#endif
#if TEST_FULLY_CONNECTED
class FullyConnected : public testing::TestWithParam<std::tuple<bool, int, int>> {};
TEST_P(FullyConnected, Op)
{
test_fully_connected(
std::get<2>(GetParam()),
std::get<1>(GetParam()),
std::get<0>(GetParam()),
4);
}
static inline std::string
FullyConnectedTestCaseName(
const testing::TestParamInfo<std::tuple<bool, int, int>> &info)
{
std::string name = "";
name += "input_size_" + std::to_string(std::get<2>(info.param));
name += "_output_channels_" + std::to_string(std::get<1>(info.param));
name += "_is_signed_" + std::to_string(std::get<0>(info.param));
return name;
}
INSTANTIATE_TEST_SUITE_P(
, FullyConnected,
::testing::Combine(::testing::ValuesIn(is_signed),
::testing::ValuesIn(output_channels),
::testing::ValuesIn(fc_size)),
FullyConnectedTestCaseName);
#endif
#if TEST_MOBILENETV1
class MobileNetV1 : public ::testing::Test {};