mastertrainer_test.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. // (C) Copyright 2017, Google Inc.
  2. // Licensed under the Apache License, Version 2.0 (the "License");
  3. // you may not use this file except in compliance with the License.
  4. // You may obtain a copy of the License at
  5. // http://www.apache.org/licenses/LICENSE-2.0
  6. // Unless required by applicable law or agreed to in writing, software
  7. // distributed under the License is distributed on an "AS IS" BASIS,
  8. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. // See the License for the specific language governing permissions and
  10. // limitations under the License.
  11. // Although this is a trivial-looking test, it exercises a lot of code:
  12. // SampleIterator has to correctly iterate over the correct characters, or
  13. // it will fail.
  14. // The canonical and cloud features computed by TrainingSampleSet need to
  15. // be correct, along with the distance caches, organizing samples by font
  16. // and class, indexing of features, distance calculations.
  17. // IntFeatureDist has to work, or the canonical samples won't work.
// MasterTrainer's ability to read tr files and set itself up is tested.
  19. // Finally the serialize/deserialize test ensures that MasterTrainer,
  20. // TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
  21. // enough to reproduce the same results.
#include "include_gunit.h"
#include "commontraining.h"
#include "errorcounter.h"
#include "log.h" // for LOG
#include "mastertrainer.h"
#include "shapeclassifier.h"
#include "shapetable.h"
#include "trainingsample.h"
#include "unicharset.h"
#include <cctype>  // for std::isspace
#include <cerrno>  // for errno, ERANGE
#include <cstdlib> // for std::strtol
#include <limits>  // for std::numeric_limits
#include <string>
#include <utility>
#include <vector>
  34. using namespace tesseract;
  35. // Specs of the MockClassifier.
  36. static const int kNumTopNErrs = 10;
  37. static const int kNumTop2Errs = kNumTopNErrs + 20;
  38. static const int kNumTop1Errs = kNumTop2Errs + 30;
  39. static const int kNumTopTopErrs = kNumTop1Errs + 25;
  40. static const int kNumNonReject = 1000;
  41. static const int kNumCorrect = kNumNonReject - kNumTop1Errs;
  42. // The total number of answers is given by the number of non-rejects plus
  43. // all the multiple answers.
  44. static const int kNumAnswers = kNumNonReject + 2 * (kNumTop2Errs - kNumTopNErrs) +
  45. (kNumTop1Errs - kNumTop2Errs) + (kNumTopTopErrs - kNumTop1Errs);
  46. #ifndef DISABLED_LEGACY_ENGINE
  47. static bool safe_strto32(const std::string &str, int *pResult) {
  48. long n = strtol(str.c_str(), nullptr, 0);
  49. *pResult = n;
  50. return true;
  51. }
  52. #endif
  53. // Mock ShapeClassifier that cheats by looking at the correct answer, and
  54. // creates a specific pattern of errors that can be tested.
  55. class MockClassifier : public ShapeClassifier {
  56. public:
  57. explicit MockClassifier(ShapeTable *shape_table)
  58. : shape_table_(shape_table), num_done_(0), done_bad_font_(false) {
  59. // Add a false font answer to the shape table. We pick a random unichar_id,
  60. // add a new shape for it with a false font. Font must actually exist in
  61. // the font table, but not match anything in the first 1000 samples.
  62. false_unichar_id_ = 67;
  63. false_shape_ = shape_table_->AddShape(false_unichar_id_, 25);
  64. }
  65. ~MockClassifier() override = default;
  66. // Classifies the given [training] sample, writing to results.
  67. // If debug is non-zero, then various degrees of classifier dependent debug
  68. // information is provided.
  69. // If keep_this (a shape index) is >= 0, then the results should always
  70. // contain keep_this, and (if possible) anything of intermediate confidence.
  71. // The return value is the number of classes saved in results.
  72. int ClassifySample(const TrainingSample &sample, Image page_pix, int debug, UNICHAR_ID keep_this,
  73. std::vector<ShapeRating> *results) override {
  74. results->clear();
  75. // Everything except the first kNumNonReject is a reject.
  76. if (++num_done_ > kNumNonReject) {
  77. return 0;
  78. }
  79. int class_id = sample.class_id();
  80. int font_id = sample.font_id();
  81. int shape_id = shape_table_->FindShape(class_id, font_id);
  82. // Get ids of some wrong answers.
  83. int wrong_id1 = shape_id > 10 ? shape_id - 1 : shape_id + 1;
  84. int wrong_id2 = shape_id > 10 ? shape_id - 2 : shape_id + 2;
  85. if (num_done_ <= kNumTopNErrs) {
  86. // The first kNumTopNErrs are top-n errors.
  87. results->push_back(ShapeRating(wrong_id1, 1.0f));
  88. } else if (num_done_ <= kNumTop2Errs) {
  89. // The next kNumTop2Errs - kNumTopNErrs are top-2 errors.
  90. results->push_back(ShapeRating(wrong_id1, 1.0f));
  91. results->push_back(ShapeRating(wrong_id2, 0.875f));
  92. results->push_back(ShapeRating(shape_id, 0.75f));
  93. } else if (num_done_ <= kNumTop1Errs) {
  94. // The next kNumTop1Errs - kNumTop2Errs are top-1 errors.
  95. results->push_back(ShapeRating(wrong_id1, 1.0f));
  96. results->push_back(ShapeRating(shape_id, 0.8f));
  97. } else if (num_done_ <= kNumTopTopErrs) {
  98. // The next kNumTopTopErrs - kNumTop1Errs are cases where the actual top
  99. // is not correct, but do not count as a top-1 error because the rating
  100. // is close enough to the top answer.
  101. results->push_back(ShapeRating(wrong_id1, 1.0f));
  102. results->push_back(ShapeRating(shape_id, 0.99f));
  103. } else if (!done_bad_font_ && class_id == false_unichar_id_) {
  104. // There is a single character with a bad font.
  105. results->push_back(ShapeRating(false_shape_, 1.0f));
  106. done_bad_font_ = true;
  107. } else {
  108. // Everything else is correct.
  109. results->push_back(ShapeRating(shape_id, 1.0f));
  110. }
  111. return results->size();
  112. }
  113. // Provides access to the ShapeTable that this classifier works with.
  114. const ShapeTable *GetShapeTable() const override {
  115. return shape_table_;
  116. }
  117. private:
  118. // Borrowed pointer to the ShapeTable.
  119. ShapeTable *shape_table_;
  120. // Unichar_id of a random character that occurs after the first 60 samples.
  121. int false_unichar_id_;
  122. // Shape index of prepared false answer for false_unichar_id.
  123. int false_shape_;
  124. // The number of classifications we have processed.
  125. int num_done_;
  126. // True after the false font has been emitted.
  127. bool done_bad_font_;
  128. };
  129. const double kMin1lDistance = 0.25;
  130. // The fixture for testing Tesseract.
  131. class MasterTrainerTest : public testing::Test {
  132. #ifndef DISABLED_LEGACY_ENGINE
  133. protected:
  134. void SetUp() override {
  135. std::locale::global(std::locale(""));
  136. file::MakeTmpdir();
  137. }
  138. std::string TestDataNameToPath(const std::string &name) {
  139. return file::JoinPath(TESTING_DIR, name);
  140. }
  141. std::string TmpNameToPath(const std::string &name) {
  142. return file::JoinPath(FLAGS_test_tmpdir, name);
  143. }
  144. MasterTrainerTest() :
  145. shape_table_(nullptr),
  146. master_trainer_(nullptr) {
  147. }
  148. ~MasterTrainerTest() override {
  149. delete shape_table_;
  150. }
  151. // Initializes the master_trainer_ and shape_table_.
  152. // if load_from_tmp, then reloads a master trainer that was saved by a
  153. // previous call in which it was false.
  154. void LoadMasterTrainer() {
  155. FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
  156. FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
  157. FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
  158. FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
  159. std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
  160. const char *filelist[] = {tr_file_name.c_str(), nullptr};
  161. std::string file_prefix;
  162. delete shape_table_;
  163. shape_table_ = nullptr;
  164. master_trainer_ = LoadTrainingData(filelist, false, &shape_table_, file_prefix);
  165. EXPECT_TRUE(master_trainer_ != nullptr);
  166. EXPECT_TRUE(shape_table_ != nullptr);
  167. }
  168. // EXPECTs that the distance between I and l in Arial is 0 and that the
  169. // distance to 1 is significantly not 0.
  170. void VerifyIl1() {
  171. // Find the font id for Arial.
  172. int font_id = master_trainer_->GetFontInfoId("Arial");
  173. EXPECT_GE(font_id, 0);
  174. // Track down the characters we are interested in.
  175. int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
  176. EXPECT_GT(unichar_I, 0);
  177. int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
  178. EXPECT_GT(unichar_l, 0);
  179. int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
  180. EXPECT_GT(unichar_1, 0);
  181. // Now get the shape ids.
  182. int shape_I = shape_table_->FindShape(unichar_I, font_id);
  183. EXPECT_GE(shape_I, 0);
  184. int shape_l = shape_table_->FindShape(unichar_l, font_id);
  185. EXPECT_GE(shape_l, 0);
  186. int shape_1 = shape_table_->FindShape(unichar_1, font_id);
  187. EXPECT_GE(shape_1, 0);
  188. float dist_I_l = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
  189. // No tolerance here. We expect that I and l should match exactly.
  190. EXPECT_EQ(0.0f, dist_I_l);
  191. float dist_l_I = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
  192. // BOTH ways.
  193. EXPECT_EQ(0.0f, dist_l_I);
  194. // l/1 on the other hand should be distinct.
  195. float dist_l_1 = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
  196. EXPECT_GT(dist_l_1, kMin1lDistance);
  197. float dist_1_l = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
  198. EXPECT_GT(dist_1_l, kMin1lDistance);
  199. // So should I/1.
  200. float dist_I_1 = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
  201. EXPECT_GT(dist_I_1, kMin1lDistance);
  202. float dist_1_I = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
  203. EXPECT_GT(dist_1_I, kMin1lDistance);
  204. }
  205. // Objects declared here can be used by all tests in the test case for Foo.
  206. ShapeTable *shape_table_;
  207. std::unique_ptr<MasterTrainer> master_trainer_;
  208. #endif
  209. };
  210. // Tests that the MasterTrainer correctly loads its data and reaches the correct
  211. // conclusion over the distance between Arial I l and 1.
  212. TEST_F(MasterTrainerTest, Il1Test) {
  213. #ifdef DISABLED_LEGACY_ENGINE
  214. // Skip test because LoadTrainingData is missing.
  215. GTEST_SKIP();
  216. #else
  217. // Initialize the master_trainer_ and load the Arial tr file.
  218. LoadMasterTrainer();
  219. VerifyIl1();
  220. #endif
  221. }
  222. // Tests the ErrorCounter using a MockClassifier to check that it counts
  223. // error categories correctly.
  224. TEST_F(MasterTrainerTest, ErrorCounterTest) {
  225. #ifdef DISABLED_LEGACY_ENGINE
  226. // Skip test because LoadTrainingData is missing.
  227. GTEST_SKIP();
  228. #else
  229. // Initialize the master_trainer_ from the saved tmp file.
  230. LoadMasterTrainer();
  231. // Add the space character to the shape_table_ if not already present to
  232. // count junk.
  233. if (shape_table_->FindShape(0, -1) < 0) {
  234. shape_table_->AddShape(0, 0);
  235. }
  236. // Make a mock classifier.
  237. auto shape_classifier = std::make_unique<MockClassifier>(shape_table_);
  238. // Get the accuracy report.
  239. std::string accuracy_report;
  240. master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0, false,
  241. shape_classifier.get(), &accuracy_report);
  242. LOG(INFO) << accuracy_report.c_str();
  243. std::string result_string = accuracy_report.c_str();
  244. std::vector<std::string> results = split(result_string, '\t');
  245. EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
  246. int result_values[tesseract::CT_SIZE];
  247. for (int i = 0; i < tesseract::CT_SIZE; ++i) {
  248. EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
  249. }
  250. // These tests are more-or-less immune to additions to the number of
  251. // categories or changes in the training data.
  252. int num_samples = master_trainer_->GetSamples()->num_raw_samples();
  253. EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
  254. EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
  255. EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
  256. EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
  257. EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
  258. EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
  259. // Each of the TOPTOP errs also counts as a multi-unichar.
  260. EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs, result_values[tesseract::CT_OK_MULTI_UNICHAR]);
  261. EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
  262. EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
  263. #endif
  264. }