Commit 7ecdae7b authored by nagayama15's avatar nagayama15

Split comparators into source files

parent 6210551c
#ifndef INCLUDE_kyut_Comparator_hpp
#define INCLUDE_kyut_Comparator_hpp
#include <wasm.h>
namespace wasm {
[[nodiscard]] bool operator<(const Literal &a, const Literal &b);
[[nodiscard]] bool operator<(const Expression &a, const Expression &b);
[[nodiscard]] bool operator<(const ExpressionList &a, const ExpressionList &b);
} // namespace wasm
#include "Comparator.inl.hpp"
#endif // INCLUDE_kyut_Comparator_hpp
#ifndef INCLUDE_kyut_Comparator_inl_hpp
#define INCLUDE_kyut_Comparator_inl_hpp
#include <wasm.h>
#include <boost/optional.hpp>
namespace wasm {
[[nodiscard]] inline bool operator<(const Literal &a, const Literal &b) {
if (a.type != b.type) {
return a.type < b.type;
}
switch (a.type) {
case none:
return false;
case i32:
return a.geti32() < b.geti32();
case i64:
return a.geti64() < b.geti64();
case f32:
return a.getf32() < b.getf32();
case f64:
return a.getf64() < b.getf64();
case v128:
return a.getv128() < b.getv128();
case except_ref:
return false;
case unreachable:
return false;
default:
WASM_UNREACHABLE();
}
}
[[nodiscard]] inline bool operator<(const Expression &a, const Expression &b) {
constexpr auto opt = [](const Expression *p) -> boost::optional<const Expression &> {
if (p == nullptr) {
return boost::none;
}
return *p;
};
if (std::tie(a._id, a.type) != std::tie(b._id, b.type)) {
return std::tie(a._id, a.type) < std::tie(b._id, b.type);
}
switch (a._id) {
case Expression::BlockId: {
const auto &x = *a.cast<Block>();
const auto &y = *b.cast<Block>();
return x.list < y.list;
}
case Expression::IfId: {
const auto &x = *a.cast<If>();
const auto &y = *b.cast<If>();
return std::forward_as_tuple(*x.condition, *x.ifTrue, opt(x.ifFalse)) <
std::forward_as_tuple(*y.condition, *y.ifTrue, opt(y.ifFalse));
}
case Expression::LoopId: {
const auto &x = *a.cast<Loop>();
const auto &y = *b.cast<Loop>();
return *x.body < *y.body;
}
case Expression::BreakId: {
const auto &x = *a.cast<Break>();
const auto &y = *b.cast<Break>();
return std::forward_as_tuple(opt(x.value), opt(x.condition)) <
std::forward_as_tuple(opt(y.value), opt(y.condition));
}
case Expression::SwitchId: {
const auto &x = *a.cast<Switch>();
const auto &y = *b.cast<Switch>();
return std::forward_as_tuple(opt(x.value), opt(x.condition)) <
std::forward_as_tuple(opt(y.value), opt(y.condition));
}
case Expression::CallId: {
const auto &x = *a.cast<Call>();
const auto &y = *b.cast<Call>();
return std::forward_as_tuple(x.operands, x.target) < std::forward_as_tuple(y.operands, y.target);
}
case Expression::CallIndirectId: {
const auto &x = *a.cast<CallIndirect>();
const auto &y = *b.cast<CallIndirect>();
return std::forward_as_tuple(x.operands, *x.target) < std::forward_as_tuple(y.operands, *y.target);
}
case Expression::GetLocalId: {
const auto &x = *a.cast<GetLocal>();
const auto &y = *b.cast<GetLocal>();
return x.index < y.index;
}
case Expression::SetLocalId: {
const auto &x = *a.cast<SetLocal>();
const auto &y = *b.cast<SetLocal>();
return std::tie(x.index, *x.value) < std::tie(y.index, *y.value);
}
case Expression::GetGlobalId: {
const auto &x = *a.cast<GetGlobal>();
const auto &y = *b.cast<GetGlobal>();
return x.name < y.name;
}
case Expression::SetGlobalId: {
const auto &x = *a.cast<SetGlobal>();
const auto &y = *b.cast<SetGlobal>();
return std::tie(*x.value, x.name) < std::tie(*y.value, y.name);
}
case Expression::LoadId: {
const auto &x = *a.cast<Load>();
const auto &y = *b.cast<Load>();
return *x.ptr < *y.ptr;
}
case Expression::StoreId: {
const auto &x = *a.cast<Store>();
const auto &y = *b.cast<Store>();
return std::tie(*x.ptr, *x.value) < std::tie(*y.ptr, *y.value);
}
case Expression::ConstId: {
const auto &x = *a.cast<Const>();
const auto &y = *b.cast<Const>();
return x.value < y.value;
}
case Expression::UnaryId: {
const auto &x = *a.cast<Unary>();
const auto &y = *b.cast<Unary>();
return std::tie(x.op, *x.value) < std::tie(y.op, *y.value);
}
case Expression::BinaryId: {
const auto &x = *a.cast<Binary>();
const auto &y = *b.cast<Binary>();
// Normalize expression
constexpr auto normalize =
[](const wasm::Binary &bin) -> std::tuple<wasm::BinaryOp, wasm::Expression &, wasm::Expression &> {
if (!kyut::isCommutative(bin.op)) {
// Noncommutative
return {bin.op, *bin.left, *bin.right};
}
// Commutative
if (*bin.right < *bin.left) {
return {*kyut::getSwappedPredicate(bin.op), *bin.right, *bin.left};
} else {
return {bin.op, *bin.left, *bin.right};
}
};
return normalize(x) < normalize(y);
}
case Expression::SelectId: {
const auto &x = *a.cast<Select>();
const auto &y = *b.cast<Select>();
return std::tie(*x.ifTrue, *x.ifFalse, *x.condition) < std::tie(*y.ifTrue, *y.ifFalse, *y.condition);
}
case Expression::DropId: {
const auto &x = *a.cast<Drop>();
const auto &y = *b.cast<Drop>();
return *x.value < *y.value;
}
case Expression::ReturnId: {
const auto &x = *a.cast<Return>();
const auto &y = *b.cast<Return>();
return opt(x.value) < opt(y.value);
}
case Expression::HostId: {
const auto &x = *a.cast<Host>();
const auto &y = *b.cast<Host>();
return std::tie(x.op, x.operands) < std::tie(y.op, y.operands);
}
case Expression::NopId: {
return false;
}
case Expression::UnreachableId: {
return false;
}
case Expression::AtomicRMWId: {
const auto &x = *a.cast<AtomicRMW>();
const auto &y = *b.cast<AtomicRMW>();
return std::tie(x.op, x.bytes, x.offset.addr, *x.ptr, *x.value) <
std::tie(y.op, y.bytes, y.offset.addr, *y.ptr, *y.value);
}
case Expression::AtomicCmpxchgId: {
const auto &x = *a.cast<AtomicCmpxchg>();
const auto &y = *b.cast<AtomicCmpxchg>();
return std::tie(x.bytes, x.offset.addr, *x.ptr, *x.expected, *x.replacement) <
std::tie(y.bytes, y.offset.addr, *y.ptr, *y.expected, *y.replacement);
}
case Expression::AtomicWaitId: {
const auto &x = *a.cast<AtomicWait>();
const auto &y = *b.cast<AtomicWait>();
return std::tie(x.offset.addr, *x.ptr, *x.expected, *x.timeout, x.expectedType) <
std::tie(y.offset.addr, *y.ptr, *y.expected, *y.timeout, y.expectedType);
}
case Expression::AtomicNotifyId: {
const auto &x = *a.cast<AtomicNotify>();
const auto &y = *b.cast<AtomicNotify>();
return std::tie(x.offset.addr, *x.ptr, *x.notifyCount) < std::tie(y.offset.addr, *y.ptr, *y.notifyCount);
}
case Expression::SIMDExtractId: {
const auto &x = *a.cast<SIMDExtract>();
const auto &y = *b.cast<SIMDExtract>();
return std::tie(x.op, *x.vec, x.index) < std::tie(y.op, *y.vec, y.index);
}
case Expression::SIMDReplaceId: {
const auto &x = *a.cast<SIMDReplace>();
const auto &y = *b.cast<SIMDReplace>();
return std::tie(x.op, *x.vec, x.index, *x.value) < std::tie(y.op, *y.vec, y.index, *y.value);
}
case Expression::SIMDShuffleId: {
const auto &x = *a.cast<SIMDShuffle>();
const auto &y = *b.cast<SIMDShuffle>();
return std::tie(*x.left, *x.right, x.mask) < std::tie(*y.left, *y.right, y.mask);
}
case Expression::SIMDBitselectId: {
const auto &x = *a.cast<SIMDBitselect>();
const auto &y = *b.cast<SIMDBitselect>();
return std::tie(*x.left, *x.right, *x.cond) < std::tie(*y.left, *y.right, *y.cond);
}
case Expression::SIMDShiftId: {
const auto &x = *a.cast<SIMDShift>();
const auto &y = *b.cast<SIMDShift>();
return std::tie(x.op, *x.vec, *x.shift) < std::tie(y.op, *y.vec, *y.shift);
}
case Expression::MemoryInitId: {
const auto &x = *a.cast<MemoryInit>();
const auto &y = *b.cast<MemoryInit>();
return std::tie(x.segment, *x.dest, *x.offset, *x.size) < std::tie(y.segment, *y.dest, *y.offset, *y.size);
}
case Expression::DataDropId: {
const auto &x = *a.cast<DataDrop>();
const auto &y = *b.cast<DataDrop>();
return x.segment < y.segment;
}
case Expression::MemoryCopyId: {
const auto &x = *a.cast<MemoryCopy>();
const auto &y = *b.cast<MemoryCopy>();
return std::tie(*x.dest, *x.source, *x.size) < std::tie(*y.dest, *y.source, *y.size);
}
case Expression::MemoryFillId: {
const auto &x = *a.cast<MemoryFill>();
const auto &y = *b.cast<MemoryFill>();
return std::tie(*x.dest, *x.value, *x.size) < std::tie(*y.dest, *y.value, *y.size);
}
default:
WASM_UNREACHABLE();
}
} // namespace wasm
[[nodiscard]] inline bool operator<(const ExpressionList &a, const ExpressionList &b) {
return std::lexicographical_compare(
std::begin(a), std::end(a), std::begin(b), std::end(b), [](const auto &x, const auto &y) {
return *x < *y;
});
}
} // namespace wasm
#endif // INCLUDE_kyut_Comparator_inl_hpp
...@@ -7,48 +7,7 @@ ...@@ -7,48 +7,7 @@
#include "../BitStreamWriter.hpp" #include "../BitStreamWriter.hpp"
#include "../CircularBitStreamReader.hpp" #include "../CircularBitStreamReader.hpp"
#include "../Commutativity.hpp" #include "../Commutativity.hpp"
#include "../Comparator.hpp"
// Expression types
#define EXPR_TYPES() \
EXPR_TYPE(Block) \
EXPR_TYPE(If) \
EXPR_TYPE(Loop) \
EXPR_TYPE(Break) \
EXPR_TYPE(Switch) \
EXPR_TYPE(Call) \
EXPR_TYPE(CallIndirect) \
EXPR_TYPE(GetLocal) \
EXPR_TYPE(SetLocal) \
EXPR_TYPE(GetGlobal) \
EXPR_TYPE(SetGlobal) \
EXPR_TYPE(Load) \
EXPR_TYPE(Store) \
EXPR_TYPE(Const) \
EXPR_TYPE(Unary) \
EXPR_TYPE(Binary) \
EXPR_TYPE(Select) \
EXPR_TYPE(Drop) \
EXPR_TYPE(Return) \
EXPR_TYPE(Host) \
EXPR_TYPE(Nop) \
EXPR_TYPE(Unreachable) \
EXPR_TYPE(AtomicRMW) \
EXPR_TYPE(AtomicCmpxchg) \
EXPR_TYPE(AtomicWait) \
EXPR_TYPE(AtomicNotify) \
EXPR_TYPE(SIMDExtract) \
EXPR_TYPE(SIMDReplace) \
EXPR_TYPE(SIMDShuffle) \
EXPR_TYPE(SIMDBitselect) \
EXPR_TYPE(SIMDShift) \
EXPR_TYPE(MemoryInit) \
EXPR_TYPE(DataDrop) \
EXPR_TYPE(MemoryCopy) \
EXPR_TYPE(MemoryFill)
namespace wasm {
bool operator<(const wasm::Expression &a, const wasm::Expression &b);
} // namespace wasm
namespace kyut::watermarker { namespace kyut::watermarker {
namespace { namespace {
...@@ -384,193 +343,3 @@ namespace kyut::watermarker { ...@@ -384,193 +343,3 @@ namespace kyut::watermarker {
return stream.tell() - posStart; return stream.tell() - posStart;
} }
} // namespace kyut::watermarker } // namespace kyut::watermarker
namespace wasm {
// Comparator
boost::optional<const wasm::Expression &> opt(const wasm::Expression *expr) {
if (expr == nullptr) {
return boost::none;
}
return *expr;
}
bool operator<(const wasm::Literal &a, const wasm::Literal &b) {
if (a.type != b.type) {
return a.type < b.type;
}
switch (a.type) {
case wasm::none:
return false;
case wasm::i32:
return a.geti32() < b.geti32();
case wasm::i64:
return a.geti64() < b.geti64();
case wasm::f32:
return a.getf32() < b.getf32();
case wasm::f64:
return a.getf64() < b.getf64();
case wasm::v128:
return a.getv128() < b.getv128();
case wasm::except_ref:
return false;
case wasm::unreachable:
return false;
default:
WASM_UNREACHABLE();
}
}
bool operator<(const wasm::ExpressionList &a, const wasm::ExpressionList &b) {
return std::lexicographical_compare(
std::begin(a), std::end(a), std::begin(b), std::end(b), [](const auto &x, const auto &y) {
return *x < *y;
});
}
bool operator<(const wasm::Block &a, const wasm::Block &b) {
return a.list < b.list;
}
bool operator<(const wasm::If &a, const wasm::If &b) {
return std::forward_as_tuple(*a.condition, *a.ifTrue, opt(a.ifFalse)) <
std::forward_as_tuple(*b.condition, *b.ifTrue, opt(b.ifFalse));
}
bool operator<(const wasm::Loop &a, const wasm::Loop &b) {
return *a.body < *b.body;
}
bool operator<(const wasm::Break &a, const wasm::Break &b) {
return std::forward_as_tuple(opt(a.value), opt(a.condition)) <
std::forward_as_tuple(opt(b.value), opt(b.condition));
}
bool operator<(const wasm::Switch &a, const wasm::Switch &b) {
return std::forward_as_tuple(opt(a.value), opt(a.condition)) <
std::forward_as_tuple(opt(b.value), opt(b.condition));
}
bool operator<(const wasm::Call &a, const wasm::Call &b) {
return std::forward_as_tuple(a.operands, a.target) < std::forward_as_tuple(b.operands, b.target);
}
bool operator<(const wasm::CallIndirect &a, const wasm::CallIndirect &b) {
return std::forward_as_tuple(a.operands, *a.target) < std::forward_as_tuple(b.operands, *b.target);
}
bool operator<(const wasm::GetLocal &a, const wasm::GetLocal &b) {
return a.index < b.index;
}
bool operator<(const wasm::SetLocal &a, const wasm::SetLocal &b) {
return std::tie(a.index, *a.value) < std::tie(b.index, *b.value);
}
bool operator<(const wasm::GetGlobal &a, const wasm::GetGlobal &b) {
return a.name < b.name;
}
bool operator<(const wasm::SetGlobal &a, const wasm::SetGlobal &b) {
return std::tie(*a.value, a.name) < std::tie(*b.value, b.name);
}
bool operator<(const wasm::Load &a, const wasm::Load &b) {
return *a.ptr < *b.ptr;
}
bool operator<(const wasm::Store &a, const wasm::Store &b) {
return std::tie(*a.ptr, *a.value) < std::tie(*b.ptr, *b.value);
}
bool operator<(const wasm::Const &a, const wasm::Const &b) {
return a.value < b.value;
}
bool operator<(const wasm::Unary &a, const wasm::Unary &b) {
return std::tie(a.op, *a.value) < std::tie(b.op, *b.value);
}
bool operator<(const wasm::Binary &a, const wasm::Binary &b) {
// Normalize expression
constexpr auto normalize =
[](const wasm::Binary &x) -> std::tuple<wasm::BinaryOp, wasm::Expression &, wasm::Expression &> {
if (!kyut::isCommutative(x.op)) {
// Noncommutative
return {x.op, *x.left, *x.right};
}
// Commutative
if (*x.right < *x.left) {
return {*kyut::getSwappedPredicate(x.op), *x.right, *x.left};
} else {
return {x.op, *x.left, *x.right};
}
};
return normalize(a) < normalize(b);
}
bool operator<(const wasm::Select &a, const wasm::Select &b) {
return std::tie(*a.condition, *a.ifTrue, *a.ifFalse) < std::tie(*b.condition, *b.ifTrue, *b.ifFalse);
}
bool operator<(const wasm::Drop &a, const wasm::Drop &b) {
return *a.value < *b.value;
}
bool operator<(const wasm::Return &a, const wasm::Return &b) {
return opt(a.value) < opt(b.value);
}
bool operator<(const wasm::Host &a, const wasm::Host &b) {
return std::tie(a.op, a.operands) < std::tie(b.op, b.operands);
}
bool operator<([[maybe_unused]] const wasm::Nop &a, [[maybe_unused]] const wasm::Nop &b) {
return false;
}
bool operator<([[maybe_unused]] const wasm::Unreachable &a, [[maybe_unused]] const wasm::Unreachable &b) {
return false;
}
bool operator<(const wasm::AtomicRMW &a, const wasm::AtomicRMW &b) {
return std::tie(a.op, a.bytes, a.offset.addr, *a.ptr, *a.value) <
std::tie(b.op, b.bytes, b.offset.addr, *b.ptr, *b.value);
}
bool operator<(const wasm::AtomicCmpxchg &a, const wasm::AtomicCmpxchg &b) {
return std::tie(a.bytes, a.offset.addr, *a.ptr, *a.expected, *a.replacement) <
std::tie(b.bytes, b.offset.addr, *b.ptr, *b.expected, *b.replacement);
}
bool operator<(const wasm::AtomicWait &a, const wasm::AtomicWait &b) {
return std::tie(a.offset.addr, *a.ptr, *a.expected, *a.timeout, a.expectedType) <
std::tie(b.offset.addr, *b.ptr, *b.expected, *b.timeout, b.expectedType);
}
bool operator<(const wasm::AtomicNotify &a, const wasm::AtomicNotify &b) {
return std::tie(a.offset.addr, *a.ptr, *a.notifyCount) < std::tie(b.offset.addr, *b.ptr, *b.notifyCount);
}
bool operator<(const wasm::SIMDExtract &a, const wasm::SIMDExtract &b) {
return std::tie(a.op, *a.vec, a.index) < std::tie(b.op, *b.vec, b.index);
}
bool operator<(const wasm::SIMDReplace &a, const wasm::SIMDReplace &b) {
return std::tie(a.op, *a.vec, a.index, *a.value) < std::tie(b.op, *b.vec, b.index, *b.value);
}
bool operator<(const wasm::SIMDShuffle &a, const wasm::SIMDShuffle &b) {
return std::tie(*a.left, *a.right, a.mask) < std::tie(*b.left, *b.right, b.mask);
}
bool operator<(const wasm::SIMDBitselect &a, const wasm::SIMDBitselect &b) {
return std::tie(*a.left, *a.right, *a.cond) < std::tie(*b.left, *b.right, *b.cond);
}
bool operator<(const wasm::SIMDShift &a, const wasm::SIMDShift &b) {
return std::tie(a.op, *a.vec, *a.shift) < std::tie(b.op, *b.vec, *b.shift);
}
bool operator<(const wasm::MemoryInit &a, const wasm::MemoryInit &b) {
return std::tie(a.segment, *a.dest, *a.offset, *a.size) < std::tie(b.segment, *b.dest, *b.offset, *b.size);
}
bool operator<(const wasm::DataDrop &a, const wasm::DataDrop &b) {
return a.segment < b.segment;
}
bool operator<(const wasm::MemoryCopy &a, const wasm::MemoryCopy &b) {
return std::tie(*a.dest, *a.source, *a.size) < std::tie(*b.dest, *b.source, *b.size);
}
bool operator<(const wasm::MemoryFill &a, const wasm::MemoryFill &b) {
return std::tie(*a.dest, *a.value, *a.size) < std::tie(*b.dest, *b.value, *b.size);
}
bool operator<(const wasm::Expression &a, const wasm::Expression &b) {
if (std::tie(a._id, a.type) != std::tie(b._id, b.type)) {
return std::tie(a._id, a.type) < std::tie(b._id, b.type);
}
switch (a._id) {
#define EXPR_TYPE(name) \
case ::wasm::Expression::name##Id: \
return (*a.cast<::wasm::name>()) < (*b.cast<::wasm::name>());
EXPR_TYPES()
#undef EXPR_TYPE
default:
WASM_UNREACHABLE();
}
}
} // namespace wasm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment