diff --git a/src/gallium/targets/teflon/test_executor.cpp b/src/gallium/targets/teflon/test_executor.cpp index 4e9cbcff659..604282302f8 100644 --- a/src/gallium/targets/teflon/test_executor.cpp +++ b/src/gallium/targets/teflon/test_executor.cpp @@ -250,6 +250,91 @@ add_generate_model(int input_size, return buf; } +static void +patch_fully_connected(unsigned operation_index, + tflite::ModelT *model, + int input_size, + int output_channels, + bool is_signed) +{ + unsigned input_index; + unsigned weights_index; + unsigned bias_index; + unsigned output_index; + unsigned weights_buffer_index; + unsigned bias_buffer_index; + + auto subgraph = model->subgraphs[0]; + + /* Operation */ + auto value = new tflite::FullyConnectedOptionsT(); + subgraph->operators[operation_index]->builtin_options.value = value; + + input_index = subgraph->operators[operation_index]->inputs.data()[0]; + weights_index = subgraph->operators[operation_index]->inputs.data()[1]; + bias_index = subgraph->operators[operation_index]->inputs.data()[2]; + output_index = subgraph->operators[operation_index]->outputs.data()[0]; + + /* Input */ + auto input_tensor = subgraph->tensors[input_index]; + input_tensor->shape.data()[0] = 1; + input_tensor->shape.data()[1] = input_size; + input_tensor->type = is_signed ? tflite::TensorType_INT8 : tflite::TensorType_UINT8; + + /* Bias */ + auto bias_tensor = subgraph->tensors[bias_index]; + bias_buffer_index = bias_tensor->buffer; + bias_tensor->shape.data()[0] = output_channels; + + auto bias_data = &model->buffers[bias_buffer_index]->data; + xt::xarray bias_array = xt::random::randint({output_channels}, -20000, 20000); + bias_data->resize(bias_array.size() * sizeof(int32_t)); + memcpy(bias_data->data(), bias_array.data(), bias_array.size() * sizeof(int32_t)); + + /* Weight */ + auto weight_tensor = subgraph->tensors[weights_index]; + weights_buffer_index = weight_tensor->buffer; + weight_tensor->shape.data()[0] = output_channels; + weight_tensor->shape.data()[1] = input_size; + weight_tensor->type = is_signed ? tflite::TensorType_INT8 : tflite::TensorType_UINT8; + + auto weights_data = &model->buffers[weights_buffer_index]->data; + std::vector weight_shape; + weight_shape = {output_channels, input_size}; + + xt::xarray weights_array = xt::random::randint(weight_shape, 0, 255); + weights_data->resize(weights_array.size()); + memcpy(weights_data->data(), weights_array.data(), weights_array.size()); + + /* Output */ + auto output_tensor = subgraph->tensors[output_index]; + output_tensor->shape.data()[0] = 1; + output_tensor->shape.data()[1] = output_channels; + output_tensor->type = is_signed ? tflite::TensorType_INT8 : tflite::TensorType_UINT8; +} + +void * +fully_connected_generate_model(int input_size, + int output_channels, + bool is_signed, + size_t *buf_size) +{ + void *buf; + tflite::ModelT model; + read_model("fully_connected.tflite", model); + + patch_fully_connected(0, &model, input_size, output_channels, is_signed); + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(tflite::Model::Pack(builder, &model), "TFL3"); + + *buf_size = builder.GetSize(); + buf = malloc(*buf_size); + memcpy(buf, builder.GetBufferPointer(), builder.GetSize()); + + return buf; +} + static void tflite_error_cb(void *user_data, const char *format, va_list args) { diff --git a/src/gallium/targets/teflon/test_executor.h b/src/gallium/targets/teflon/test_executor.h index 9240d7342bf..c3da340a285 100644 --- a/src/gallium/targets/teflon/test_executor.h +++ b/src/gallium/targets/teflon/test_executor.h @@ -31,6 +31,11 @@ void *add_generate_model(int input_size, bool depthwise, size_t *buf_size); +void *fully_connected_generate_model(int input_size, + int output_channels, + bool is_signed, + size_t *buf_size); + void run_model(TfLiteModel *model, enum executor executor, void ***input, size_t *num_inputs, void ***output, size_t **output_sizes, TfLiteType **output_types, size_t *num_outputs, std::string cache_dir); diff --git a/src/gallium/targets/teflon/test_teflon.cpp b/src/gallium/targets/teflon/test_teflon.cpp index 58744f327f2..3d3ec918942 100644 --- a/src/gallium/targets/teflon/test_teflon.cpp +++ b/src/gallium/targets/teflon/test_teflon.cpp @@ -17,6 +17,7 @@ #define TEST_CONV2D 1 #define TEST_DEPTHWISE 1 #define TEST_ADD 1 +#define TEST_FULLY_CONNECTED 1 #define TEST_MOBILENETV1 1 #define TEST_MOBILEDET 1 #define TEST_YOLOX 1 @@ -35,6 +36,8 @@ std::vector dw_channels{1, 32, 120, 128, 256}; std::vector dw_weight_size{3, 5}; std::vector weight_size{1, 3, 5}; std::vector input_size{3, 5, 8, 80, 112}; +std::vector fc_channels{23, 46, 128, 256, 512}; +std::vector fc_size{128, 1280, 25088, 62720}; static void set_seed(unsigned seed) @@ -223,6 +226,41 @@ test_add(int input_size, int weight_size, int input_channels, int output_channel free(buf); } +void +test_fully_connected(int input_size, int output_channels, bool is_signed, int seed) +{ + void *buf = NULL; + size_t buf_size; + std::ostringstream cache_dir, model_cache; + cache_dir << "/var/cache/teflon_tests/fc_" << input_size << "_" << output_channels << "_" << is_signed << "_" << seed; + model_cache << cache_dir.str() << "/" + << "model.tflite"; + + set_seed(seed); + + if (cache_is_enabled()) { + if (access(model_cache.str().c_str(), F_OK) == 0) { + buf = read_buf(model_cache.str().c_str(), &buf_size); + } + } + + if (buf == 0) { + buf = fully_connected_generate_model(input_size, output_channels, is_signed, &buf_size); + + if (cache_is_enabled()) { + if (access(cache_dir.str().c_str(), F_OK) != 0) { + ASSERT_TRUE(std::filesystem::create_directories(cache_dir.str().c_str())); + } + std::ofstream file(model_cache.str().c_str(), std::ios::out | std::ios::binary); + file.write(reinterpret_cast(buf), buf_size); + file.close(); + } + } + + test_model(buf, buf_size, cache_dir.str(), TOLERANCE); + free(buf); +} + #if TEST_CONV2D class Conv2D : public testing::TestWithParam> {}; @@ -383,6 +421,41 @@ INSTANTIATE_TEST_SUITE_P( #endif +#if TEST_FULLY_CONNECTED + +class FullyConnected : public testing::TestWithParam> {}; + +TEST_P(FullyConnected, Op) +{ + test_fully_connected( + std::get<2>(GetParam()), + std::get<1>(GetParam()), + std::get<0>(GetParam()), + 4); +} + +static inline std::string +FullyConnectedTestCaseName( + const testing::TestParamInfo> &info) +{ + std::string name = ""; + + name += "input_size_" + std::to_string(std::get<2>(info.param)); + name += "_output_channels_" + std::to_string(std::get<1>(info.param)); + name += "_is_signed_" + std::to_string(std::get<0>(info.param)); + + return name; +} + +INSTANTIATE_TEST_SUITE_P( + , FullyConnected, + ::testing::Combine(::testing::ValuesIn(is_signed), + ::testing::ValuesIn(output_channels), + ::testing::ValuesIn(fc_size)), + FullyConnectedTestCaseName); + +#endif + #if TEST_MOBILENETV1 class MobileNetV1 : public ::testing::Test {}; diff --git a/src/gallium/targets/teflon/tests/fully_connected.tflite b/src/gallium/targets/teflon/tests/fully_connected.tflite new file mode 100644 index 00000000000..d781a7ddb34 Binary files /dev/null and b/src/gallium/targets/teflon/tests/fully_connected.tflite differ