Verified Commit 5b0c3ad2 authored by nagayama15's avatar nagayama15

feat: implement expression order and fix the function reordering method to use...

feat: implement expression order and fix the function reordering method to use it instead of their n
parent d4e25c8d
#ifndef INCLUDE_kyut_Commutativity_hpp
#define INCLUDE_kyut_Commutativity_hpp
#include <boost/optional.hpp>
#include "wasm.h"
namespace kyut {
inline boost::optional<wasm::BinaryOp> swapped_binary_op(wasm::BinaryOp op) {
static_assert(wasm::InvalidBinary == 178);
switch (op) {
// Commutative operators
case wasm::AddInt32:
case wasm::MulInt32:
case wasm::AndInt32:
case wasm::OrInt32:
case wasm::XorInt32:
case wasm::EqInt32:
case wasm::NeInt32:
case wasm::AddInt64:
case wasm::MulInt64:
case wasm::AndInt64:
case wasm::OrInt64:
case wasm::XorInt64:
case wasm::EqInt64:
case wasm::NeInt64:
case wasm::AddFloat32:
case wasm::MulFloat32:
case wasm::MinFloat32:
case wasm::MaxFloat32:
case wasm::EqFloat32:
case wasm::NeFloat32:
case wasm::AddFloat64:
case wasm::MulFloat64:
case wasm::MinFloat64:
case wasm::MaxFloat64:
case wasm::EqFloat64:
case wasm::NeFloat64:
return op;
// Relarational operators
case wasm::LtSInt32:
return wasm::GtSInt32;
case wasm::LtUInt32:
return wasm::GtUInt32;
case wasm::LeSInt32:
return wasm::GeSInt32;
case wasm::LeUInt32:
return wasm::GeUInt32;
case wasm::GtSInt32:
return wasm::LtSInt32;
case wasm::GtUInt32:
return wasm::LtUInt32;
case wasm::GeSInt32:
return wasm::LeSInt32;
case wasm::GeUInt32:
return wasm::LeUInt32;
case wasm::LtSInt64:
return wasm::GtSInt64;
case wasm::LtUInt64:
return wasm::GtUInt64;
case wasm::LeSInt64:
return wasm::GeSInt64;
case wasm::LeUInt64:
return wasm::GeUInt64;
case wasm::GtSInt64:
return wasm::LtSInt64;
case wasm::GtUInt64:
return wasm::LtUInt64;
case wasm::GeSInt64:
return wasm::LeSInt64;
case wasm::GeUInt64:
return wasm::LeUInt64;
case wasm::LtFloat32:
return wasm::GtFloat32;
case wasm::LeFloat32:
return wasm::GeFloat32;
case wasm::GtFloat32:
return wasm::LtFloat32;
case wasm::GeFloat32:
return wasm::LeFloat32;
case wasm::LtFloat64:
return wasm::GtFloat64;
case wasm::LeFloat64:
return wasm::GeFloat64;
case wasm::GtFloat64:
return wasm::LtFloat64;
case wasm::GeFloat64:
return wasm::LeFloat64;
// Commutative SIMD operators
case wasm::EqVecI8x16:
case wasm::NeVecI8x16:
case wasm::EqVecI16x8:
case wasm::NeVecI16x8:
case wasm::EqVecI32x4:
case wasm::NeVecI32x4:
case wasm::EqVecF32x4:
case wasm::NeVecF32x4:
case wasm::EqVecF64x2:
case wasm::NeVecF64x2:
case wasm::AndVec128:
case wasm::OrVec128:
case wasm::XorVec128:
case wasm::AddVecI8x16:
case wasm::AddSatSVecI8x16:
case wasm::AddSatUVecI8x16:
case wasm::MulVecI8x16:
case wasm::MinSVecI8x16:
case wasm::MinUVecI8x16:
case wasm::MaxSVecI8x16:
case wasm::MaxUVecI8x16:
case wasm::AddVecI16x8:
case wasm::AddSatSVecI16x8:
case wasm::AddSatUVecI16x8:
case wasm::MulVecI16x8:
case wasm::MinSVecI16x8:
case wasm::MinUVecI16x8:
case wasm::MaxSVecI16x8:
case wasm::MaxUVecI16x8:
case wasm::AddVecI32x4:
case wasm::MulVecI32x4:
case wasm::MinSVecI32x4:
case wasm::MinUVecI32x4:
case wasm::MaxSVecI32x4:
case wasm::MaxUVecI32x4:
case wasm::AddVecI64x2:
case wasm::MulVecI64x2:
case wasm::AddVecF32x4:
case wasm::MulVecF32x4:
case wasm::MinVecF32x4:
case wasm::MaxVecF32x4:
case wasm::PMinVecF32x4:
case wasm::PMaxVecF32x4:
case wasm::AddVecF64x2:
case wasm::MulVecF64x2:
case wasm::MinVecF64x2:
case wasm::MaxVecF64x2:
case wasm::PMinVecF64x2:
case wasm::PMaxVecF64x2:
return op;
// Relarational SIMD operators
case wasm::LtSVecI8x16:
return wasm::GtSVecI8x16;
case wasm::LtUVecI8x16:
return wasm::GtUVecI8x16;
case wasm::GtSVecI8x16:
return wasm::LtSVecI8x16;
case wasm::GtUVecI8x16:
return wasm::LtUVecI8x16;
case wasm::LeSVecI8x16:
return wasm::GeSVecI8x16;
case wasm::LeUVecI8x16:
return wasm::GeUVecI8x16;
case wasm::GeSVecI8x16:
return wasm::LeSVecI8x16;
case wasm::GeUVecI8x16:
return wasm::LeUVecI8x16;
case wasm::LtSVecI16x8:
return wasm::GtSVecI16x8;
case wasm::LtUVecI16x8:
return wasm::GtUVecI16x8;
case wasm::GtSVecI16x8:
return wasm::LtSVecI16x8;
case wasm::GtUVecI16x8:
return wasm::LtUVecI16x8;
case wasm::LeSVecI16x8:
return wasm::GeSVecI16x8;
case wasm::LeUVecI16x8:
return wasm::GeUVecI16x8;
case wasm::GeSVecI16x8:
return wasm::LeSVecI16x8;
case wasm::GeUVecI16x8:
return wasm::LeUVecI16x8;
case wasm::LtSVecI32x4:
return wasm::GtSVecI32x4;
case wasm::LtUVecI32x4:
return wasm::GtUVecI32x4;
case wasm::GtSVecI32x4:
return wasm::LtSVecI32x4;
case wasm::GtUVecI32x4:
return wasm::LtUVecI32x4;
case wasm::LeSVecI32x4:
return wasm::GeSVecI32x4;
case wasm::LeUVecI32x4:
return wasm::GeUVecI32x4;
case wasm::GeSVecI32x4:
return wasm::LeSVecI32x4;
case wasm::GeUVecI32x4:
return wasm::LeUVecI32x4;
case wasm::LtVecF32x4:
return wasm::GtVecF32x4;
case wasm::GtVecF32x4:
return wasm::LtVecF32x4;
case wasm::LeVecF32x4:
return wasm::GeVecF32x4;
case wasm::GeVecF32x4:
return wasm::LeVecF32x4;
case wasm::LtVecF64x2:
return wasm::GtVecF64x2;
case wasm::GtVecF64x2:
return wasm::LtVecF64x2;
case wasm::LeVecF64x2:
return wasm::GeVecF64x2;
case wasm::GeVecF64x2:
return wasm::LeVecF64x2;
default:
return boost::none;
}
}
inline bool is_commutative(wasm::BinaryOp op) {
return swapped_binary_op(op).has_value();
}
} // namespace kyut
#endif // INCLUDE_kyut_Commutativity_hpp
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define INCLUDE_kyut_methods_FunctionOrdering_hpp #define INCLUDE_kyut_methods_FunctionOrdering_hpp
#include "../Ordering.hpp" #include "../Ordering.hpp"
#include "wasm.h" #include "../wasm-ext/Compare.hpp"
namespace kyut { namespace kyut {
class CircularBitStreamReader; class CircularBitStreamReader;
...@@ -24,7 +24,7 @@ namespace kyut::methods::function_ordering { ...@@ -24,7 +24,7 @@ namespace kyut::methods::function_ordering {
start, start,
end, end,
[](const auto& a, const auto& b) { [](const auto& a, const auto& b) {
return a->name < b->name; // TODO: ordered by body return *a->body < *b->body;
}); });
return size_bits; return size_bits;
...@@ -44,7 +44,7 @@ namespace kyut::methods::function_ordering { ...@@ -44,7 +44,7 @@ namespace kyut::methods::function_ordering {
start, start,
end, end,
[](const auto& a, const auto& b) { [](const auto& a, const auto& b) {
return a->name < b->name; // TODO: ordered by body return *a->body < *b->body;
}); });
return size_bits; return size_bits;
......
#ifndef INCLUDE_kyut_wasm_ext_Compare_inl_hpp
#define INCLUDE_kyut_wasm_ext_Compare_inl_hpp
#include "Compare.hpp"
#include <algorithm>
#include "../Commutativity.hpp"
namespace wasm {
inline bool operator<(const Literal& a, const Literal& b) {
return std::less<Literal>{}(a, b);
}
inline bool operator<(const Expression& a, const Expression& b) {
static_assert(Expression::NumExpressionIds == 49);
// Expr* -> optional<Expr&>
constexpr auto opt = [](const Expression* p) -> boost::optional<const Expression&> {
if (p == nullptr) {
return boost::none;
}
return *p;
};
if (a._id != b._id) {
return a._id < b._id;
}
if (a.type != b.type) {
return a.type < b.type;
}
switch (a._id) {
case Expression::Id::BlockId: {
const auto& x = *a.cast<Block>();
const auto& y = *b.cast<Block>();
return x.list < y.list;
}
case Expression::Id::IfId: {
const auto& x = *a.cast<If>();
const auto& y = *b.cast<If>();
return std::forward_as_tuple(*x.condition, *x.ifTrue, opt(x.ifFalse)) <
std::forward_as_tuple(*y.condition, *y.ifTrue, opt(y.ifFalse));
}
case Expression::Id::LoopId: {
const auto& x = *a.cast<Loop>();
const auto& y = *b.cast<Loop>();
return *x.body < *y.body;
}
case Expression::Id::BreakId: {
const auto& x = *a.cast<Break>();
const auto& y = *b.cast<Break>();
return std::forward_as_tuple(opt(x.value), opt(x.condition)) <
std::forward_as_tuple(opt(y.value), opt(y.condition));
}
case Expression::Id::SwitchId: {
const auto& x = *a.cast<Switch>();
const auto& y = *b.cast<Switch>();
return std::forward_as_tuple(opt(x.value), opt(x.condition)) <
std::forward_as_tuple(opt(y.value), opt(y.condition));
}
case Expression::Id::CallId: {
const auto& x = *a.cast<Call>();
const auto& y = *b.cast<Call>();
return x.operands < y.operands;
}
case Expression::Id::CallIndirectId: {
const auto& x = *a.cast<CallIndirect>();
const auto& y = *b.cast<CallIndirect>();
return std::forward_as_tuple(x.operands, *x.target) < std::forward_as_tuple(y.operands, *y.target);
}
case Expression::Id::LocalGetId: {
const auto& x = *a.cast<LocalGet>();
const auto& y = *b.cast<LocalGet>();
return x.index < y.index;
}
case Expression::Id::LocalSetId: {
const auto& x = *a.cast<LocalSet>();
const auto& y = *b.cast<LocalSet>();
return std::tie(x.index, *x.value) < std::tie(y.index, *y.value);
}
case Expression::Id::GlobalGetId: {
const auto& x = *a.cast<GlobalGet>();
const auto& y = *b.cast<GlobalGet>();
return x.name < y.name;
}
case Expression::Id::GlobalSetId: {
const auto& x = *a.cast<GlobalSet>();
const auto& y = *b.cast<GlobalSet>();
return std::tie(*x.value, x.name) < std::tie(*y.value, y.name);
}
case Expression::Id::LoadId: {
const auto& x = *a.cast<Load>();
const auto& y = *b.cast<Load>();
return *x.ptr < *y.ptr;
}
case Expression::Id::StoreId: {
const auto& x = *a.cast<Store>();
const auto& y = *b.cast<Store>();
return std::tie(*x.ptr, *x.value) < std::tie(*y.ptr, *y.value);
}
case Expression::Id::ConstId: {
const auto& x = *a.cast<Const>();
const auto& y = *b.cast<Const>();
return x.value < y.value;
}
case Expression::Id::UnaryId: {
const auto& x = *a.cast<Unary>();
const auto& y = *b.cast<Unary>();
return std::tie(x.op, *x.value) < std::tie(y.op, *y.value);
}
case Expression::Id::BinaryId: {
const auto& x = *a.cast<Binary>();
const auto& y = *b.cast<Binary>();
constexpr auto normalize = [](const wasm::Binary& node)
-> std::tuple<wasm::BinaryOp, wasm::Expression&, wasm::Expression&> {
if (!kyut::is_commutative(node.op)) {
// Non-commutative binary expr
return {node.op, *node.left, *node.right};
}
// Commutative binary expr
if (*node.left < *node.right) {
return {node.op, *node.left, *node.right};
} else {
return {*kyut::swapped_binary_op(node.op), *node.right, *node.left};
}
};
return normalize(x) < normalize(y);
}
case Expression::Id::SelectId: {
const auto& x = *a.cast<Select>();
const auto& y = *b.cast<Select>();
return std::tie(*x.ifTrue, *x.ifFalse, *x.condition) < std::tie(*y.ifTrue, *y.ifFalse, *y.condition);
}
case Expression::Id::DropId: {
const auto& x = *a.cast<Drop>();
const auto& y = *b.cast<Drop>();
return *x.value < *y.value;
}
case Expression::Id::ReturnId: {
const auto& x = *a.cast<Return>();
const auto& y = *b.cast<Return>();
return opt(x.value) < opt(y.value);
}
case Expression::Id::MemorySizeId: {
[[maybe_unused]] const auto& x = *a.cast<MemorySize>();
[[maybe_unused]] const auto& y = *b.cast<MemorySize>();
return false;
}
case Expression::Id::MemoryGrowId: {
const auto& x = *a.cast<MemoryGrow>();
const auto& y = *b.cast<MemoryGrow>();
return opt(x.delta) < opt(y.delta);
}
case Expression::Id::NopId: {
[[maybe_unused]] const auto& x = *a.cast<Nop>();
[[maybe_unused]] const auto& y = *b.cast<Nop>();
return false;
}
case Expression::Id::UnreachableId: {
[[maybe_unused]] const auto& x = *a.cast<Unreachable>();
[[maybe_unused]] const auto& y = *b.cast<Unreachable>();
return false;
}
case Expression::Id::AtomicRMWId: {
const auto& x = *a.cast<AtomicRMW>();
const auto& y = *b.cast<AtomicRMW>();
return std::tie(x.op, x.bytes, x.offset, *x.ptr, *x.value) <
std::tie(y.op, y.bytes, y.offset, *y.ptr, *y.value);
}
case Expression::Id::AtomicCmpxchgId: {
const auto& x = *a.cast<AtomicCmpxchg>();
const auto& y = *b.cast<AtomicCmpxchg>();
return std::tie(x.bytes, x.offset, *x.ptr, *x.expected, *x.replacement) <
std::tie(y.bytes, y.offset, *y.ptr, *y.expected, *y.replacement);
}
case Expression::Id::AtomicWaitId: {
const auto& x = *a.cast<AtomicWait>();
const auto& y = *b.cast<AtomicWait>();
return std::tie(x.offset, *x.ptr, *x.expected, *x.timeout, x.expectedType) <
std::tie(y.offset, *y.ptr, *y.expected, *y.timeout, y.expectedType);
}
case Expression::Id::AtomicNotifyId: {
const auto& x = *a.cast<AtomicNotify>();
const auto& y = *b.cast<AtomicNotify>();
return std::tie(x.offset, *x.ptr, *x.notifyCount) < std::tie(y.offset, *y.ptr, *y.notifyCount);
}
case Expression::Id::AtomicFenceId: {
[[maybe_unused]] const auto& x = *a.cast<AtomicFence>();
[[maybe_unused]] const auto& y = *b.cast<AtomicFence>();
return false;
}
case Expression::Id::SIMDExtractId: {
const auto& x = *a.cast<SIMDExtract>();
const auto& y = *b.cast<SIMDExtract>();
return std::tie(x.op, *x.vec, x.index) < std::tie(y.op, *y.vec, y.index);
}
case Expression::Id::SIMDReplaceId: {
const auto& x = *a.cast<SIMDReplace>();
const auto& y = *b.cast<SIMDReplace>();
return std::tie(x.op, *x.vec, x.index, *x.value) < std::tie(y.op, *y.vec, y.index, *y.value);
}
case Expression::Id::SIMDShuffleId: {
const auto& x = *a.cast<SIMDShuffle>();
const auto& y = *b.cast<SIMDShuffle>();
return std::tie(*x.left, *x.right, x.mask) < std::tie(*y.left, *y.right, y.mask);
}
case Expression::Id::SIMDTernaryId: {
const auto& x = *a.cast<SIMDTernary>();
const auto& y = *b.cast<SIMDTernary>();
return std::tie(*x.a, *x.b, *x.c) < std::tie(*y.a, *y.b, *y.c);
}
case Expression::Id::SIMDShiftId: {
const auto& x = *a.cast<SIMDShift>();
const auto& y = *b.cast<SIMDShift>();
return std::tie(x.op, *x.vec, *x.shift) < std::tie(y.op, *y.vec, *y.shift);
}
case Expression::Id::SIMDLoadId: {
const auto& x = *a.cast<SIMDLoad>();
const auto& y = *b.cast<SIMDLoad>();
return std::tie(x.op, x.offset, x.align, *x.ptr) < std::tie(y.op, y.offset, y.align, *y.ptr);
}
case Expression::Id::MemoryInitId: {
const auto& x = *a.cast<MemoryInit>();
const auto& y = *b.cast<MemoryInit>();
return std::tie(x.segment, *x.dest, *x.offset, *x.size) < std::tie(y.segment, *y.dest, *y.offset, *y.size);
}
case Expression::Id::DataDropId: {
const auto& x = *a.cast<DataDrop>();
const auto& y = *b.cast<DataDrop>();
return x.segment < y.segment;
}
case Expression::Id::MemoryCopyId: {
const auto& x = *a.cast<MemoryCopy>();
const auto& y = *b.cast<MemoryCopy>();
return std::tie(*x.dest, *x.source, *x.size) < std::tie(*y.dest, *y.source, *y.size);
}
case Expression::Id::MemoryFillId: {
const auto& x = *a.cast<MemoryFill>();
const auto& y = *b.cast<MemoryFill>();
return std::tie(*x.dest, *x.value, *x.size) < std::tie(*y.dest, *y.value, *y.size);
}
case Expression::Id::PopId: {
[[maybe_unused]] const auto& x = *a.cast<Pop>();
[[maybe_unused]] const auto& y = *b.cast<Pop>();
return false;
}
case Expression::Id::RefNullId: {
[[maybe_unused]] const auto& x = *a.cast<RefNull>();
[[maybe_unused]] const auto& y = *b.cast<RefNull>();
return false;
}
case Expression::Id::RefIsNullId: {
const auto& x = *a.cast<RefIsNull>();
const auto& y = *b.cast<RefIsNull>();
return *x.value < *y.value;
}
case Expression::Id::RefFuncId: {
[[maybe_unused]] const auto& x = *a.cast<RefFunc>();
[[maybe_unused]] const auto& y = *b.cast<RefFunc>();
return false;
}
case Expression::Id::TryId: {
const auto& x = *a.cast<Try>();
const auto& y = *b.cast<Try>();
return std::forward_as_tuple(opt(x.body), opt(x.catchBody)) < std::forward_as_tuple(opt(y.body), opt(y.catchBody));
}
case Expression::Id::ThrowId: {
const auto& x = *a.cast<Throw>();
const auto& y = *b.cast<Throw>();
return x.operands < y.operands;
}
case Expression::Id::RethrowId: {
const auto& x = *a.cast<Rethrow>();
const auto& y = *b.cast<Rethrow>();
return std::forward_as_tuple(opt(x.exnref)) < std::forward_as_tuple(opt(y.exnref));
}
case Expression::Id::BrOnExnId: {
const auto& x = *a.cast<BrOnExn>();
const auto& y = *b.cast<BrOnExn>();
return std::forward_as_tuple(opt(x.exnref)) < std::forward_as_tuple(opt(y.exnref));
}
case Expression::Id::TupleMakeId: {
const auto& x = *a.cast<TupleMake>();
const auto& y = *b.cast<TupleMake>();
return x.operands < y.operands;
}
case Expression::Id::TupleExtractId: {
const auto& x = *a.cast<TupleExtract>();
const auto& y = *b.cast<TupleExtract>();
return std::tie(*x.tuple, x.index) < std::tie(*y.tuple, y.index);
}
default: {
WASM_UNREACHABLE("unknown expression id");
}
}
}
inline bool operator<(const ExpressionList& a, const ExpressionList& b) {
return std::lexicographical_compare(
std::begin(a),
std::end(a),
std::begin(b),
std::end(b),
[](const Expression* a, const Expression* b) {
return *a < *b;
});
}
} // namespace wasm
#endif // INCLUDE_kyut_wasm_ext_Compare_inl_hpp
#ifndef INCLUDE_kyut_wasm_ext_Compare_hpp
#define INCLUDE_kyut_wasm_ext_Compare_hpp
#include "wasm.h"
namespace wasm {
bool operator<(const Literal& a, const Literal& b);
bool operator<(const Expression& a, const Expression& b);
bool operator<(const ExpressionList& a, const ExpressionList& b);
} // namespace wasm
#include "Compare-inl.hpp"
#endif // INCLUDE_kyut_wasm_ext_Compare_hpp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment