Commit 167c5e81 by Abseil Team Committed by Gennadiy Civil

Googletest export

Fix Theta(N^2) memory usage of EXPECT_EQ(string) when the strings don't match. The underlying CalculateOptimalEdits() implementation used a simple dynamic-programming approach that always used N^2 memory and time. This meant that tests for equality of large strings were ticking time bombs: They'd work fine as long as the test passed, but as soon as the strings differed the test would OOM, which is very hard to debug. I switched it out for a Dijkstra search, which is still worst-case O(N^2), but in the usual case of mostly-matching strings, it is much closer to linear. PiperOrigin-RevId: 210405025
parent 1bb76182
...@@ -159,15 +159,15 @@ GTEST_DISABLE_MSC_WARNINGS_POP_() // 4275 ...@@ -159,15 +159,15 @@ GTEST_DISABLE_MSC_WARNINGS_POP_() // 4275
#endif // GTEST_HAS_EXCEPTIONS #endif // GTEST_HAS_EXCEPTIONS
namespace edit_distance {
// Returns the optimal edits to go from 'left' to 'right'. // Returns the optimal edits to go from 'left' to 'right'.
// All edits cost the same, with replace having lower priority than // All edits cost the same, with replace having lower priority than
// add/remove. // add/remove. Returns an approximation of the maximum memory used in
// Simple implementation of the Wagner-Fischer algorithm. // 'memory_usage' if non-null.
// See http://en.wikipedia.org/wiki/Wagner-Fischer_algorithm // Uses a Dijkstra search, with a couple of simple bells and whistles added on.
enum EditType { kMatch, kAdd, kRemove, kReplace }; enum EditType { kEditMatch, kEditAdd, kEditRemove, kEditReplace };
GTEST_API_ std::vector<EditType> CalculateOptimalEdits( GTEST_API_ std::vector<EditType> CalculateOptimalEdits(
const std::vector<size_t>& left, const std::vector<size_t>& right); const std::vector<size_t>& left, const std::vector<size_t>& right,
size_t* memory_usage = NULL);
// Same as above, but the input is represented as strings. // Same as above, but the input is represented as strings.
GTEST_API_ std::vector<EditType> CalculateOptimalEdits( GTEST_API_ std::vector<EditType> CalculateOptimalEdits(
...@@ -179,8 +179,6 @@ GTEST_API_ std::string CreateUnifiedDiff(const std::vector<std::string>& left, ...@@ -179,8 +179,6 @@ GTEST_API_ std::string CreateUnifiedDiff(const std::vector<std::string>& left,
const std::vector<std::string>& right, const std::vector<std::string>& right,
size_t context = 2); size_t context = 2);
} // namespace edit_distance
// Calculate the diff between 'left' and 'right' and return it in unified diff // Calculate the diff between 'left' and 'right' and return it in unified diff
// format. // format.
// If not null, stores in 'total_line_count' the total number of lines found // If not null, stores in 'total_line_count' the total number of lines found
......
...@@ -38,10 +38,11 @@ ...@@ -38,10 +38,11 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
// The following lines pull in the real gtest *.cc files. // The following lines pull in the real gtest *.cc files.
#include "src/gtest.cc"
#include "src/gtest-death-test.cc" #include "src/gtest-death-test.cc"
#include "src/gtest-edit-distance.cc"
#include "src/gtest-filepath.cc" #include "src/gtest-filepath.cc"
#include "src/gtest-port.cc" #include "src/gtest-port.cc"
#include "src/gtest-printers.cc" #include "src/gtest-printers.cc"
#include "src/gtest-test-part.cc" #include "src/gtest-test-part.cc"
#include "src/gtest-typed-test.cc" #include "src/gtest-typed-test.cc"
#include "src/gtest.cc"
...@@ -46,7 +46,6 @@ ...@@ -46,7 +46,6 @@
#include <algorithm> #include <algorithm>
#include <iomanip> #include <iomanip>
#include <limits> #include <limits>
#include <list>
#include <map> #include <map>
#include <ostream> // NOLINT #include <ostream> // NOLINT
#include <sstream> #include <sstream>
...@@ -1068,246 +1067,6 @@ AssertionResult AssertionFailure(const Message& message) { ...@@ -1068,246 +1067,6 @@ AssertionResult AssertionFailure(const Message& message) {
namespace internal { namespace internal {
namespace edit_distance {
std::vector<EditType> CalculateOptimalEdits(const std::vector<size_t>& left,
const std::vector<size_t>& right) {
std::vector<std::vector<double> > costs(
left.size() + 1, std::vector<double>(right.size() + 1));
std::vector<std::vector<EditType> > best_move(
left.size() + 1, std::vector<EditType>(right.size() + 1));
// Populate for empty right.
for (size_t l_i = 0; l_i < costs.size(); ++l_i) {
costs[l_i][0] = static_cast<double>(l_i);
best_move[l_i][0] = kRemove;
}
// Populate for empty left.
for (size_t r_i = 1; r_i < costs[0].size(); ++r_i) {
costs[0][r_i] = static_cast<double>(r_i);
best_move[0][r_i] = kAdd;
}
for (size_t l_i = 0; l_i < left.size(); ++l_i) {
for (size_t r_i = 0; r_i < right.size(); ++r_i) {
if (left[l_i] == right[r_i]) {
// Found a match. Consume it.
costs[l_i + 1][r_i + 1] = costs[l_i][r_i];
best_move[l_i + 1][r_i + 1] = kMatch;
continue;
}
const double add = costs[l_i + 1][r_i];
const double remove = costs[l_i][r_i + 1];
const double replace = costs[l_i][r_i];
if (add < remove && add < replace) {
costs[l_i + 1][r_i + 1] = add + 1;
best_move[l_i + 1][r_i + 1] = kAdd;
} else if (remove < add && remove < replace) {
costs[l_i + 1][r_i + 1] = remove + 1;
best_move[l_i + 1][r_i + 1] = kRemove;
} else {
// We make replace a little more expensive than add/remove to lower
// their priority.
costs[l_i + 1][r_i + 1] = replace + 1.00001;
best_move[l_i + 1][r_i + 1] = kReplace;
}
}
}
// Reconstruct the best path. We do it in reverse order.
std::vector<EditType> best_path;
for (size_t l_i = left.size(), r_i = right.size(); l_i > 0 || r_i > 0;) {
EditType move = best_move[l_i][r_i];
best_path.push_back(move);
l_i -= move != kAdd;
r_i -= move != kRemove;
}
std::reverse(best_path.begin(), best_path.end());
return best_path;
}
namespace {

// Helper class that converts strings into ids with deduplication: equal
// strings always map to the same id, distinct strings to distinct ids.
class InternalStrings {
 public:
  // Returns the id of 'str'. A string that has not been seen before is
  // assigned the next unused id; ids are handed out as 0, 1, 2, ... in order
  // of first appearance.
  size_t GetId(const std::string& str) {
    // insert() leaves an already-present entry untouched and returns an
    // iterator to it, so one lookup serves both the seen and the unseen case.
    // The candidate id is the map size before insertion, i.e. the number of
    // distinct strings registered so far.
    return ids_.insert(IdMap::value_type(str, ids_.size())).first->second;
  }

 private:
  typedef std::map<std::string, size_t> IdMap;
  IdMap ids_;
};

}  // namespace
std::vector<EditType> CalculateOptimalEdits(
const std::vector<std::string>& left,
const std::vector<std::string>& right) {
std::vector<size_t> left_ids, right_ids;
{
InternalStrings intern_table;
for (size_t i = 0; i < left.size(); ++i) {
left_ids.push_back(intern_table.GetId(left[i]));
}
for (size_t i = 0; i < right.size(); ++i) {
right_ids.push_back(intern_table.GetId(right[i]));
}
}
return CalculateOptimalEdits(left_ids, right_ids);
}
namespace {

// Collects the lines of a single diff hunk and writes them to a stream,
// preceded by the hunk header.
// Within a run of consecutive edits the removed lines are emitted before the
// added lines, regardless of the order in which they were pushed.
class Hunk {
 public:
  // 'left_start' and 'right_start' are the line numbers printed in the
  // header for the left and right side of the hunk.
  Hunk(size_t left_start, size_t right_start)
      : left_start_(left_start),
        right_start_(right_start),
        adds_(0),
        removes_(0),
        common_(0) {}

  // Records one line of the hunk. 'edit' is ' ' for an unchanged line, '-'
  // for a removed line and '+' for an added line; any other value is ignored.
  // Only the pointer 'line' is stored, so the text it points to must remain
  // valid until the hunk has been printed.
  void PushLine(char edit, const char* line) {
    if (edit == ' ') {
      ++common_;
      // An unchanged line ends the current run of edits: emit that run first.
      FlushEdits();
      lines_.push_back(Line(' ', line));
    } else if (edit == '-') {
      ++removes_;
      pending_removes_.push_back(Line('-', line));
    } else if (edit == '+') {
      ++adds_;
      pending_adds_.push_back(Line('+', line));
    }
  }

  // Writes the header followed by every recorded line, each prefixed with
  // its edit character and terminated by a newline.
  void PrintTo(std::ostream* os) {
    PrintHeader(os);
    FlushEdits();
    for (LineList::const_iterator it = lines_.begin(); it != lines_.end();
         ++it) {
      *os << it->first << it->second << "\n";
    }
  }

  // True once at least one added or removed line has been pushed.
  bool has_edits() const { return adds_ != 0 || removes_ != 0; }

 private:
  typedef std::pair<char, const char*> Line;
  typedef std::list<Line> LineList;

  // Moves the pending edits to the end of the output: removes, then adds.
  void FlushEdits() {
    lines_.splice(lines_.end(), pending_removes_);
    lines_.splice(lines_.end(), pending_adds_);
  }

  // Print a unified diff header for one hunk.
  // The format is
  //   "@@ -<left_start>,<left_length> +<right_start>,<right_length> @@"
  // where the left part is omitted when the hunk has no removed lines and
  // the right part is omitted when it has no added lines.
  void PrintHeader(std::ostream* ss) const {
    const bool show_left = removes_ != 0;
    const bool show_right = adds_ != 0;
    *ss << "@@ ";
    if (show_left) {
      *ss << "-" << left_start_ << "," << (removes_ + common_);
      if (show_right) {
        *ss << " ";
      }
    }
    if (show_right) {
      *ss << "+" << right_start_ << "," << (adds_ + common_);
    }
    *ss << " @@\n";
  }

  size_t left_start_, right_start_;
  size_t adds_, removes_, common_;
  LineList lines_, pending_adds_, pending_removes_;
};

}  // namespace
// Create a list of diff hunks in Unified diff format.
// Each hunk has a header generated by PrintHeader above plus a body with
// lines prefixed with ' ' for no change, '-' for deletion and '+' for
// addition.
// 'context' represents the desired unchanged prefix/suffix around the diff.
// If two hunks are close enough that their contexts overlap, then they are
// joined into one hunk.
//
// 'left' and 'right' are the two inputs being compared, one entry per line.
// Returns the text of all hunks concatenated; the result is empty when the
// computed edits contain nothing but kMatch.
std::string CreateUnifiedDiff(const std::vector<std::string>& left,
const std::vector<std::string>& right,
size_t context) {
const std::vector<EditType> edits = CalculateOptimalEdits(left, right);
// l_i / r_i index the next unconsumed line of 'left' / 'right'; edit_i
// indexes the next unconsumed entry of 'edits'.
size_t l_i = 0, r_i = 0, edit_i = 0;
std::stringstream ss;
// Each iteration of this loop produces at most one hunk.
while (edit_i < edits.size()) {
// Find first edit.
while (edit_i < edits.size() && edits[edit_i] == kMatch) {
++l_i;
++r_i;
++edit_i;
}
// Find the first line to include in the hunk.
// At most 'context' lines are used as prefix, and never more than the
// number of 'left' lines that precede the current position.
const size_t prefix_context = std::min(l_i, context);
// The indices are 0-based; the + 1 makes the hunk's start positions 1-based.
Hunk hunk(l_i - prefix_context + 1, r_i - prefix_context + 1);
for (size_t i = prefix_context; i > 0; --i) {
hunk.PushLine(' ', left[l_i - i].c_str());
}
// Iterate the edits until we found enough suffix for the hunk or the input
// is over.
// n_suffix counts the consecutive kMatch entries pushed since the last edit.
size_t n_suffix = 0;
for (; edit_i < edits.size(); ++edit_i) {
if (n_suffix >= context) {
// Continue only if the next hunk is very close.
// 'it' is advanced to the next non-match entry at or after edit_i.
std::vector<EditType>::const_iterator it = edits.begin() + edit_i;
while (it != edits.end() && *it == kMatch) ++it;
if (it == edits.end() || (it - edits.begin()) - edit_i >= context) {
// There is no next edit or it is too far away.
break;
}
}
EditType edit = edits[edit_i];
// Reset count when a non match is found.
n_suffix = edit == kMatch ? n_suffix + 1 : 0;
// A kReplace is pushed as two lines: the 'left' line as a removal here,
// then the 'right' line as an addition below.
if (edit == kMatch || edit == kRemove || edit == kReplace) {
hunk.PushLine(edit == kMatch ? ' ' : '-', left[l_i].c_str());
}
if (edit == kAdd || edit == kReplace) {
hunk.PushLine('+', right[r_i].c_str());
}
// Advance indices, depending on edit type.
// kAdd consumes no 'left' line and kRemove consumes no 'right' line.
l_i += edit != kAdd;
r_i += edit != kRemove;
}
if (!hunk.has_edits()) {
// We are done. We don't want this hunk.
break;
}
hunk.PrintTo(&ss);
}
return ss.str();
}
} // namespace edit_distance
namespace { namespace {
// The string representation of the values received in EqFailure() are already // The string representation of the values received in EqFailure() are already
...@@ -1379,8 +1138,7 @@ AssertionResult EqFailure(const char* lhs_expression, ...@@ -1379,8 +1138,7 @@ AssertionResult EqFailure(const char* lhs_expression,
const std::vector<std::string> rhs_lines = const std::vector<std::string> rhs_lines =
SplitEscapedString(rhs_value); SplitEscapedString(rhs_value);
if (lhs_lines.size() > 1 || rhs_lines.size() > 1) { if (lhs_lines.size() > 1 || rhs_lines.size() > 1) {
msg << "\nWith diff:\n" msg << "\nWith diff:\n" << CreateUnifiedDiff(lhs_lines, rhs_lines);
<< edit_distance::CreateUnifiedDiff(lhs_lines, rhs_lines);
} }
} }
......
...@@ -215,6 +215,7 @@ using testing::GTEST_FLAG(stream_result_to); ...@@ -215,6 +215,7 @@ using testing::GTEST_FLAG(stream_result_to);
using testing::GTEST_FLAG(throw_on_failure); using testing::GTEST_FLAG(throw_on_failure);
using testing::IsNotSubstring; using testing::IsNotSubstring;
using testing::IsSubstring; using testing::IsSubstring;
using testing::kMaxStackTraceDepth;
using testing::Message; using testing::Message;
using testing::ScopedFakeTestPartResultReporter; using testing::ScopedFakeTestPartResultReporter;
using testing::StaticAssertTypeEq; using testing::StaticAssertTypeEq;
...@@ -234,16 +235,18 @@ using testing::internal::AlwaysTrue; ...@@ -234,16 +235,18 @@ using testing::internal::AlwaysTrue;
using testing::internal::AppendUserMessage; using testing::internal::AppendUserMessage;
using testing::internal::ArrayAwareFind; using testing::internal::ArrayAwareFind;
using testing::internal::ArrayEq; using testing::internal::ArrayEq;
using testing::internal::CalculateOptimalEdits;
using testing::internal::CodePointToUtf8; using testing::internal::CodePointToUtf8;
using testing::internal::CompileAssertTypesEqual; using testing::internal::CompileAssertTypesEqual;
using testing::internal::CopyArray; using testing::internal::CopyArray;
using testing::internal::CountIf; using testing::internal::CountIf;
using testing::internal::CreateUnifiedDiff;
using testing::internal::EditType;
using testing::internal::EqFailure; using testing::internal::EqFailure;
using testing::internal::FloatingPoint; using testing::internal::FloatingPoint;
using testing::internal::ForEach; using testing::internal::ForEach;
using testing::internal::FormatEpochTimeInMillisAsIso8601; using testing::internal::FormatEpochTimeInMillisAsIso8601;
using testing::internal::FormatTimeInMillisAsSeconds; using testing::internal::FormatTimeInMillisAsSeconds;
using testing::internal::GTestFlagSaver;
using testing::internal::GetCurrentOsStackTraceExceptTop; using testing::internal::GetCurrentOsStackTraceExceptTop;
using testing::internal::GetElementOr; using testing::internal::GetElementOr;
using testing::internal::GetNextRandomSeed; using testing::internal::GetNextRandomSeed;
...@@ -252,6 +255,7 @@ using testing::internal::GetTestTypeId; ...@@ -252,6 +255,7 @@ using testing::internal::GetTestTypeId;
using testing::internal::GetTimeInMillis; using testing::internal::GetTimeInMillis;
using testing::internal::GetTypeId; using testing::internal::GetTypeId;
using testing::internal::GetUnitTestImpl; using testing::internal::GetUnitTestImpl;
using testing::internal::GTestFlagSaver;
using testing::internal::ImplicitlyConvertible; using testing::internal::ImplicitlyConvertible;
using testing::internal::Int32; using testing::internal::Int32;
using testing::internal::Int32FromEnvOrDie; using testing::internal::Int32FromEnvOrDie;
...@@ -259,6 +263,8 @@ using testing::internal::IsAProtocolMessage; ...@@ -259,6 +263,8 @@ using testing::internal::IsAProtocolMessage;
using testing::internal::IsContainer; using testing::internal::IsContainer;
using testing::internal::IsContainerTest; using testing::internal::IsContainerTest;
using testing::internal::IsNotContainer; using testing::internal::IsNotContainer;
using testing::internal::kMaxRandomSeed;
using testing::internal::kTestTypeIdInGoogleTest;
using testing::internal::NativeArray; using testing::internal::NativeArray;
using testing::internal::OsStackTraceGetter; using testing::internal::OsStackTraceGetter;
using testing::internal::OsStackTraceGetterInterface; using testing::internal::OsStackTraceGetterInterface;
...@@ -280,12 +286,6 @@ using testing::internal::TestResultAccessor; ...@@ -280,12 +286,6 @@ using testing::internal::TestResultAccessor;
using testing::internal::UInt32; using testing::internal::UInt32;
using testing::internal::UnitTestImpl; using testing::internal::UnitTestImpl;
using testing::internal::WideStringToUtf8; using testing::internal::WideStringToUtf8;
using testing::internal::edit_distance::CalculateOptimalEdits;
using testing::internal::edit_distance::CreateUnifiedDiff;
using testing::internal::edit_distance::EditType;
using testing::internal::kMaxRandomSeed;
using testing::internal::kTestTypeIdInGoogleTest;
using testing::kMaxStackTraceDepth;
#if GTEST_HAS_STREAM_REDIRECTION #if GTEST_HAS_STREAM_REDIRECTION
using testing::internal::CaptureStdout; using testing::internal::CaptureStdout;
...@@ -3517,14 +3517,14 @@ TEST(EditDistance, TestCases) { ...@@ -3517,14 +3517,14 @@ TEST(EditDistance, TestCases) {
{__LINE__, "ABCD", "abcd", "////", {__LINE__, "ABCD", "abcd", "////",
"@@ -1,4 +1,4 @@\n-A\n-B\n-C\n-D\n+a\n+b\n+c\n+d\n"}, "@@ -1,4 +1,4 @@\n-A\n-B\n-C\n-D\n+a\n+b\n+c\n+d\n"},
// Path finding. // Path finding.
{__LINE__, "ABCDEFGH", "ABXEGH1", " -/ - +", {__LINE__, "ABCDEFGH", "ABXEGH1", " /- - +",
"@@ -1,8 +1,7 @@\n A\n B\n-C\n-D\n+X\n E\n-F\n G\n H\n+1\n"}, "@@ -1,8 +1,7 @@\n A\n B\n-C\n-D\n+X\n E\n-F\n G\n H\n+1\n"},
{__LINE__, "AAAABCCCC", "ABABCDCDC", "- / + / ", {__LINE__, "AAAABCCCC", "ABABCDCDC", " -/ + / ",
"@@ -1,9 +1,9 @@\n-A\n A\n-A\n+B\n A\n B\n C\n+D\n C\n-C\n+D\n C\n"}, "@@ -1,9 +1,9 @@\n A\n-A\n-A\n+B\n A\n B\n C\n+D\n C\n-C\n+D\n C\n"},
{__LINE__, "ABCDE", "BCDCD", "- +/", {__LINE__, "ABCDE", "BCDCD", "- /+",
"@@ -1,5 +1,5 @@\n-A\n B\n C\n D\n-E\n+C\n+D\n"}, "@@ -1,5 +1,5 @@\n-A\n B\n C\n D\n-E\n+C\n+D\n"},
{__LINE__, "ABCDEFGHIJKL", "BCDCDEFGJKLJK", "- ++ -- ++", {__LINE__, "ABCDEFGHIJKL", "BGDCDEFGJKLJK", "- ++ -- ++",
"@@ -1,4 +1,5 @@\n-A\n B\n+C\n+D\n C\n D\n" "@@ -1,4 +1,5 @@\n-A\n B\n+G\n+D\n C\n D\n"
"@@ -6,7 +7,7 @@\n F\n G\n-H\n-I\n J\n K\n L\n+J\n+K\n"}, "@@ -6,7 +7,7 @@\n F\n G\n-H\n-I\n J\n K\n L\n+J\n+K\n"},
{}}; {}};
for (const Case* c = kCases; c->left; ++c) { for (const Case* c = kCases; c->left; ++c) {
...@@ -3542,6 +3542,57 @@ TEST(EditDistance, TestCases) { ...@@ -3542,6 +3542,57 @@ TEST(EditDistance, TestCases) {
} }
} }
// Tests that we can run CalculateOptimalEdits for a large vector, i.e. we can
// compute diffs for large strings.
TEST(EditDistance, LargeVectorWithDiffs) {
  // Sizes and indices are size_t throughout so that neither the element
  // assignments nor the EXPECT_GT/EXPECT_LT comparisons below mix signed and
  // unsigned operands.
  const size_t kSize = 300000;
  // Number of elements replaced at each end of 'right'.
  const size_t kEditsPerEnd = 10;
  std::vector<size_t> left;
  std::vector<size_t> right;
  std::vector<EditType> expected(kSize, testing::internal::kEditMatch);
  left.reserve(kSize);
  right.reserve(kSize);
  for (size_t i = 0; i < kSize; ++i) {
    // Make the contents of the vectors unique. This greatly speeds up
    // the algorithm, since it doesn't spend time finding matches for
    // different alignments.
    left.push_back(i);
    right.push_back(i);
  }
  for (size_t i = 0; i < kEditsPerEnd; ++i) {
    // The replacement values start at kSize and kSize * 2, so they cannot
    // collide with any value in [0, kSize) or with each other.
    right[i] = kSize + i;
    expected[i] = testing::internal::kEditReplace;
    right[kSize - i - 1] = kSize * 2 + i;
    expected[kSize - i - 1] = testing::internal::kEditReplace;
  }
  // Initialized so that the bounds checks below read a defined value even if
  // the call under test never writes to it.
  size_t memory_usage = 0;
  EXPECT_EQ(CalculateOptimalEdits(left, right, &memory_usage), expected);
  EXPECT_GT(memory_usage, kSize);
  EXPECT_LT(memory_usage, kSize * 2);
}
// Tests that we can run CalculateOptimalEdits for two vectors N and M, where
// M = N plus additional junk at the end. The current algorithm only does O(M)
// "real" work in this case, but allocates some extra memory. We test that this
// is still fast enough for common cases, and we aren't allocating an
// excessive amount of extra memory.
TEST(EditDistance, LargeVectorWithTrailingJunk) {
  // Sizes, indices and bounds are size_t throughout so that the
  // EXPECT_GT/EXPECT_LT comparisons below do not mix signed and unsigned
  // operands.
  const size_t kSize = 200000;
  const size_t kAdditionalSize = 2000;
  // Upper bound accepted for the reported memory usage.
  const size_t kMaxMemoryUsage = 6000000;
  std::vector<size_t> left(kSize, 0);
  std::vector<size_t> right(kSize + kAdditionalSize, 0);
  // The common prefix is expected to match; the extra tail of 'right' is
  // expected to be reported as additions.
  std::vector<EditType> expected(kSize + kAdditionalSize,
                                 testing::internal::kEditMatch);
  for (size_t i = 0; i < kAdditionalSize; ++i) {
    expected[i + kSize] = testing::internal::kEditAdd;
  }
  // Initialized so that the bounds checks below read a defined value even if
  // the call under test never writes to it.
  size_t memory_usage = 0;
  EXPECT_EQ(CalculateOptimalEdits(left, right, &memory_usage), expected);
  EXPECT_GT(memory_usage, kSize);
  EXPECT_LT(memory_usage, kMaxMemoryUsage);
}
// Tests EqFailure(), used for implementing *EQ* assertions. // Tests EqFailure(), used for implementing *EQ* assertions.
TEST(AssertionTest, EqFailure) { TEST(AssertionTest, EqFailure) {
const std::string foo_val("5"), bar_val("6"); const std::string foo_val("5"), bar_val("6");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment