Commit 0d6a8dfd authored by nagayama15's avatar nagayama15

🚧 Define the node orders

parent 704cb19f
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
#include <algorithm> #include <algorithm>
#include <boost/optional.hpp>
#include "../BitStreamWriter.hpp" #include "../BitStreamWriter.hpp"
#include "../CircularBitStreamReader.hpp" #include "../CircularBitStreamReader.hpp"
namespace kyut::watermarker {
namespace {
// Expression types // Expression types
#define EXPR_TYPES() \ #define EXPR_TYPES() \
EXPR_TYPE(Block) \ EXPR_TYPE(Block) \
...@@ -45,6 +45,12 @@ namespace kyut::watermarker { ...@@ -45,6 +45,12 @@ namespace kyut::watermarker {
EXPR_TYPE(MemoryCopy) \ EXPR_TYPE(MemoryCopy) \
EXPR_TYPE(MemoryFill) EXPR_TYPE(MemoryFill)
namespace wasm {
bool operator<(const wasm::Expression &a, const wasm::Expression &b);
} // namespace wasm
namespace kyut::watermarker {
namespace {
enum class SideEffect : std::uint32_t { enum class SideEffect : std::uint32_t {
none = 0, none = 0,
readOnly = 1, readOnly = 1,
...@@ -412,6 +418,11 @@ namespace kyut::watermarker { ...@@ -412,6 +418,11 @@ namespace kyut::watermarker {
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream)); return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream));
} }
if (!(*expr.left < *expr.right) && !(*expr.right < *expr.left)) {
// If both sides are the same or cannot be ordered, skip embedding
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream));
}
// TODO: implement watermarking // TODO: implement watermarking
const auto leftSideEffect = embedExpression(expr.left, stream); const auto leftSideEffect = embedExpression(expr.left, stream);
const auto rightSideEffect = embedExpression(expr.right, stream); const auto rightSideEffect = embedExpression(expr.right, stream);
...@@ -589,3 +600,187 @@ namespace kyut::watermarker { ...@@ -589,3 +600,187 @@ namespace kyut::watermarker {
WASM_UNREACHABLE(); WASM_UNREACHABLE();
} }
} // namespace kyut::watermarker } // namespace kyut::watermarker
namespace wasm {
// Comparator
boost::optional<const wasm::Expression &> opt(const wasm::Expression *expr) {
if (expr == nullptr) {
return boost::none;
}
return *expr;
}
bool operator<(const wasm::Literal &a, const wasm::Literal &b) {
if (a.type != b.type) {
return a.type < b.type;
}
switch (a.type) {
case wasm::none:
return false;
case wasm::i32:
return a.geti32() < b.geti32();
case wasm::i64:
return a.geti64() < b.geti64();
case wasm::f32:
return a.getf32() < b.getf32();
case wasm::f64:
return a.getf64() < b.getf64();
case wasm::v128:
return a.getv128() < b.getv128();
case wasm::except_ref:
return false;
case wasm::unreachable:
return false;
default:
WASM_UNREACHABLE();
}
}
bool operator<(const wasm::ExpressionList &a, const wasm::ExpressionList &b) {
return std::lexicographical_compare(
std::begin(a), std::end(a), std::begin(b), std::end(b), [](const auto &x, const auto &y) {
return *x < *y;
});
}
bool operator<(const wasm::Block &a, const wasm::Block &b) {
return a.list < b.list;
}
bool operator<(const wasm::If &a, const wasm::If &b) {
return std::forward_as_tuple(*a.condition, *a.ifTrue, opt(a.ifFalse)) <
std::forward_as_tuple(*b.condition, *b.ifTrue, opt(b.ifFalse));
}
bool operator<(const wasm::Loop &a, const wasm::Loop &b) {
return *a.body < *b.body;
}
bool operator<(const wasm::Break &a, const wasm::Break &b) {
return std::forward_as_tuple(opt(a.value), opt(a.condition)) <
std::forward_as_tuple(opt(b.value), opt(b.condition));
}
bool operator<(const wasm::Switch &a, const wasm::Switch &b) {
return std::forward_as_tuple(opt(a.value), opt(a.condition)) <
std::forward_as_tuple(opt(b.value), opt(b.condition));
}
bool operator<(const wasm::Call &a, const wasm::Call &b) {
return std::forward_as_tuple(a.operands, a.target) < std::forward_as_tuple(b.operands, b.target);
}
bool operator<(const wasm::CallIndirect &a, const wasm::CallIndirect &b) {
return std::forward_as_tuple(a.operands, *a.target) < std::forward_as_tuple(b.operands, *b.target);
}
bool operator<(const wasm::GetLocal &a, const wasm::GetLocal &b) {
return a.index < b.index;
}
bool operator<(const wasm::SetLocal &a, const wasm::SetLocal &b) {
return std::tie(a.index, *a.value) < std::tie(b.index, *b.value);
}
bool operator<(const wasm::GetGlobal &a, const wasm::GetGlobal &b) {
return a.name < b.name;
}
bool operator<(const wasm::SetGlobal &a, const wasm::SetGlobal &b) {
return std::tie(*a.value, a.name) < std::tie(*b.value, b.name);
}
bool operator<(const wasm::Load &a, const wasm::Load &b) {
return *a.ptr < *b.ptr;
}
bool operator<(const wasm::Store &a, const wasm::Store &b) {
return std::tie(*a.ptr, *a.value) < std::tie(*b.ptr, *b.value);
}
bool operator<(const wasm::Const &a, const wasm::Const &b) {
return a.value < b.value;
}
bool operator<(const wasm::Unary &a, const wasm::Unary &b) {
return std::tie(a.op, *a.value) < std::tie(b.op, *b.value);
}
bool operator<(const wasm::Binary &a, const wasm::Binary &b) {
if (a.op != b.op) {
return a.op < b.op;
}
if (!kyut::watermarker::isCommutative(a.op)) {
// Noncommutative
return std::tie(*a.left, *a.right) < std::tie(*b.left, *b.right);
}
// Commutative
return std::minmax(*a.left, *a.right) < std::minmax(*b.left, *b.right);
}
bool operator<(const wasm::Select &a, const wasm::Select &b) {
return std::tie(a.condition, *a.ifTrue, *a.ifFalse) < std::tie(b.condition, *b.ifTrue, *b.ifFalse);
}
bool operator<(const wasm::Drop &a, const wasm::Drop &b) {
return *a.value < *b.value;
}
bool operator<(const wasm::Return &a, const wasm::Return &b) {
return opt(a.value) < opt(b.value);
}
bool operator<(const wasm::Host &a, const wasm::Host &b) {
return std::tie(a.op, a.operands) < std::tie(b.op, b.operands);
}
bool operator<([[maybe_unused]] const wasm::Nop &a, [[maybe_unused]] const wasm::Nop &b) {
return false;
}
bool operator<([[maybe_unused]] const wasm::Unreachable &a, [[maybe_unused]] const wasm::Unreachable &b) {
return false;
}
bool operator<(const wasm::AtomicRMW &a, const wasm::AtomicRMW &b) {
return std::tie(a.op, a.bytes, a.offset, *a.ptr, *a.value) <
std::tie(b.op, b.bytes, b.offset, *b.ptr, *b.value);
}
bool operator<(const wasm::AtomicCmpxchg &a, const wasm::AtomicCmpxchg &b) {
return std::tie(a.bytes, a.offset, *a.ptr, *a.expected, *a.replacement) <
std::tie(b.bytes, b.offset, *b.ptr, *b.expected, *b.replacement);
}
bool operator<(const wasm::AtomicWait &a, const wasm::AtomicWait &b) {
return std::tie(a.offset, *a.ptr, *a.expected, *a.timeout, a.expectedType) <
std::tie(b.offset, *b.ptr, *b.expected, *b.timeout, b.expectedType);
}
bool operator<(const wasm::AtomicNotify &a, const wasm::AtomicNotify &b) {
return std::tie(a.offset, *a.ptr, *a.notifyCount) < std::tie(b.offset, *b.ptr, *b.notifyCount);
}
bool operator<(const wasm::SIMDExtract &a, const wasm::SIMDExtract &b) {
return std::tie(a.op, *a.vec, a.index) < std::tie(b.op, *b.vec, b.index);
}
bool operator<(const wasm::SIMDReplace &a, const wasm::SIMDReplace &b) {
return std::tie(a.op, *a.vec, a.index, *a.value) < std::tie(b.op, *b.vec, b.index, *b.value);
}
bool operator<(const wasm::SIMDShuffle &a, const wasm::SIMDShuffle &b) {
return std::tie(*a.left, *a.right, a.mask) < std::tie(*b.left, *b.right, b.mask);
}
bool operator<(const wasm::SIMDBitselect &a, const wasm::SIMDBitselect &b) {
return std::tie(*a.left, *a.right, *a.cond) < std::tie(*b.left, *b.right, *b.cond);
}
bool operator<(const wasm::SIMDShift &a, const wasm::SIMDShift &b) {
return std::tie(a.op, *a.vec, *a.shift) < std::tie(b.op, *b.vec, *b.shift);
}
bool operator<(const wasm::MemoryInit &a, const wasm::MemoryInit &b) {
return std::tie(a.segment, *a.dest, *a.offset, *a.size) < std::tie(b.segment, *b.dest, *b.offset, *b.size);
}
bool operator<(const wasm::DataDrop &a, const wasm::DataDrop &b) {
return a.segment < b.segment;
}
bool operator<(const wasm::MemoryCopy &a, const wasm::MemoryCopy &b) {
return std::tie(*a.dest, *a.source, *a.size) < std::tie(*b.dest, *b.source, *b.size);
}
bool operator<(const wasm::MemoryFill &a, const wasm::MemoryFill &b) {
return std::tie(*a.dest, *a.value, *a.size) < std::tie(*b.dest, *b.value, *b.size);
}
bool operator<(const wasm::Expression &a, const wasm::Expression &b) {
if (std::tie(a._id, a.type) != std::tie(b._id, b.type)) {
return std::tie(a._id, a.type) < std::tie(b._id, b.type);
}
switch (a._id) {
#define EXPR_TYPE(name) \
case ::wasm::Expression::name##Id: \
return (*a.cast<::wasm::name>()) < (*b.cast<::wasm::name>());
EXPR_TYPES()
#undef EXPR_TYPE
default:
WASM_UNREACHABLE();
}
}
} // namespace wasm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment