Commit 4f2ec866 authored by nagayama15's avatar nagayama15

Fix comparator bug

parent 1e3f6211
...@@ -439,7 +439,7 @@ namespace kyut::watermarker { ...@@ -439,7 +439,7 @@ namespace kyut::watermarker {
return (std::max)(loEffect, hiEffect); return (std::max)(loEffect, hiEffect);
} }
f(*expr); f(*expr, *lo, *hi);
return (std::max)(loEffect, hiEffect); return (std::max)(loEffect, hiEffect);
} }
...@@ -587,15 +587,15 @@ namespace kyut::watermarker { ...@@ -587,15 +587,15 @@ namespace kyut::watermarker {
// Embed watermarks // Embed watermarks
const auto posStart = stream.tell(); const auto posStart = stream.tell();
OperandSwappingVisitor visitor{[&stream](wasm::Binary &expr) { OperandSwappingVisitor visitor{
// Embed watermarks by swapping operands [&](wasm::Binary &expr, wasm::Expression &lo, [[maybe_unused]] wasm::Expression &hi) {
const bool bit = stream.readBit(); // Embed watermarks by swapping operands
const auto &lo = (std::min)(*expr.left, *expr.right); const bool bit = stream.readBit();
if (bit == (expr.left == &lo)) { if (bit == (expr.left == &lo)) {
swapOperands(expr); swapOperands(expr);
} }
}}; }};
for (const auto f : functions) { for (const auto f : functions) {
visitor.visitFunction(f); visitor.visitFunction(f);
...@@ -620,12 +620,11 @@ namespace kyut::watermarker { ...@@ -620,12 +620,11 @@ namespace kyut::watermarker {
// Extract watermarks // Extract watermarks
const auto posStart = stream.tell(); const auto posStart = stream.tell();
OperandSwappingVisitor visitor{[&stream](wasm::Binary &expr) { OperandSwappingVisitor visitor{
// Extract watermarks from the order of operands [&](wasm::Binary &expr, wasm::Expression &lo, [[maybe_unused]] wasm::Expression &hi) {
const auto &lo = (std::min)(*expr.left, *expr.right); // Extract watermarks from the order of operands
stream.writeBit(expr.left != &lo);
stream.writeBit(expr.left != &lo); }};
}};
for (const auto f : functions) { for (const auto f : functions) {
visitor.visitFunction(f); visitor.visitFunction(f);
...@@ -728,17 +727,23 @@ namespace wasm { ...@@ -728,17 +727,23 @@ namespace wasm {
return std::tie(a.op, *a.value) < std::tie(b.op, *b.value); return std::tie(a.op, *a.value) < std::tie(b.op, *b.value);
} }
bool operator<(const wasm::Binary &a, const wasm::Binary &b) { bool operator<(const wasm::Binary &a, const wasm::Binary &b) {
if (a.op != b.op) { // Normalize expression
return a.op < b.op; constexpr auto normalize =
} [](const wasm::Binary &x) -> std::tuple<wasm::BinaryOp, wasm::Expression &, wasm::Expression &> {
if (!kyut::watermarker::isCommutative(x.op)) {
// Noncommutative
return {x.op, *x.left, *x.right};
}
if (!kyut::watermarker::isCommutative(a.op)) { // Commutative
// Noncommutative if (*x.right < *x.left) {
return std::tie(*a.left, *a.right) < std::tie(*b.left, *b.right); return {*kyut::watermarker::getSwappedPredicate(x.op), *x.right, *x.left};
} } else {
return {x.op, *x.left, *x.right};
}
};
// Commutative return normalize(a) < normalize(b);
return std::minmax(*a.left, *a.right) < std::minmax(*b.left, *b.right);
} }
bool operator<(const wasm::Select &a, const wasm::Select &b) { bool operator<(const wasm::Select &a, const wasm::Select &b) {
return std::tie(*a.condition, *a.ifTrue, *a.ifFalse) < std::tie(*b.condition, *b.ifTrue, *b.ifFalse); return std::tie(*a.condition, *a.ifTrue, *a.ifFalse) < std::tie(*b.condition, *b.ifTrue, *b.ifFalse);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment