Commit 6210551c authored by nagayama15's avatar nagayama15

Split commutativity function into source files

parent 4f2ec866
#ifndef INCLUDE_kyut_Commutativity_hpp
#define INCLUDE_kyut_Commutativity_hpp
#include <wasm.h>
#include <boost/optional.hpp>
#include <boost/optional/optional_io.hpp>
namespace kyut {
[[nodiscard]]
inline boost::optional<wasm::BinaryOp> getSwappedPredicate(wasm::BinaryOp op) {
static_assert(wasm::InvalidBinary == 152);
switch (op) {
// Commutative instructions
case wasm::AddInt32:
case wasm::MulInt32:
case wasm::AndInt32:
case wasm::OrInt32:
case wasm::XorInt32:
case wasm::EqInt32:
case wasm::NeInt32:
case wasm::AddInt64:
case wasm::MulInt64:
case wasm::AndInt64:
case wasm::OrInt64:
case wasm::XorInt64:
case wasm::EqInt64:
case wasm::NeInt64:
case wasm::AddFloat32:
case wasm::MulFloat32:
case wasm::MinFloat32:
case wasm::MaxFloat32:
case wasm::EqFloat32:
case wasm::NeFloat32:
case wasm::AddFloat64:
case wasm::MulFloat64:
case wasm::MinFloat64:
case wasm::MaxFloat64:
case wasm::EqFloat64:
case wasm::NeFloat64:
return op;
// Comparators
case wasm::LtSInt32:
return wasm::GtSInt32;
case wasm::LtUInt32:
return wasm::GtUInt32;
case wasm::LeSInt32:
return wasm::GeSInt32;
case wasm::LeUInt32:
return wasm::GeUInt32;
case wasm::GtSInt32:
return wasm::LtSInt32;
case wasm::GtUInt32:
return wasm::LtUInt32;
case wasm::GeSInt32:
return wasm::LeSInt32;
case wasm::GeUInt32:
return wasm::LeUInt32;
case wasm::LtSInt64:
return wasm::GtSInt64;
case wasm::LtUInt64:
return wasm::GtUInt64;
case wasm::LeSInt64:
return wasm::GeSInt64;
case wasm::LeUInt64:
return wasm::GeUInt64;
case wasm::GtSInt64:
return wasm::LtSInt64;
case wasm::GtUInt64:
return wasm::LtUInt64;
case wasm::GeSInt64:
return wasm::LeSInt64;
case wasm::GeUInt64:
return wasm::LeUInt64;
case wasm::LtFloat32:
return wasm::GtFloat32;
case wasm::LeFloat32:
return wasm::GeFloat32;
case wasm::GtFloat32:
return wasm::LtFloat32;
case wasm::GeFloat32:
return wasm::LeFloat32;
case wasm::LtFloat64:
return wasm::GtFloat64;
case wasm::LeFloat64:
return wasm::GeFloat64;
case wasm::GtFloat64:
return wasm::LtFloat64;
case wasm::GeFloat64:
return wasm::LeFloat64;
// Commutative SIMD instructions
case wasm::EqVecI8x16:
case wasm::NeVecI8x16:
case wasm::EqVecI16x8:
case wasm::NeVecI16x8:
case wasm::EqVecI32x4:
case wasm::NeVecI32x4:
case wasm::EqVecF32x4:
case wasm::NeVecF32x4:
case wasm::EqVecF64x2:
case wasm::NeVecF64x2:
case wasm::AndVec128:
case wasm::OrVec128:
case wasm::XorVec128:
case wasm::AddVecI8x16:
case wasm::AddSatSVecI8x16:
case wasm::AddSatUVecI8x16:
case wasm::MulVecI8x16:
case wasm::AddVecI16x8:
case wasm::AddSatSVecI16x8:
case wasm::AddSatUVecI16x8:
case wasm::MulVecI16x8:
case wasm::AddVecI32x4:
case wasm::MulVecI32x4:
case wasm::AddVecI64x2:
case wasm::AddVecF32x4:
case wasm::MulVecF32x4:
case wasm::MinVecF32x4:
case wasm::MaxVecF32x4:
case wasm::AddVecF64x2:
case wasm::MulVecF64x2:
case wasm::MinVecF64x2:
case wasm::MaxVecF64x2:
return op;
// SIMD comparators
case wasm::LtSVecI8x16:
return wasm::GtSVecI8x16;
case wasm::LtUVecI8x16:
return wasm::GtUVecI8x16;
case wasm::GtSVecI8x16:
return wasm::LtSVecI8x16;
case wasm::GtUVecI8x16:
return wasm::LtUVecI8x16;
case wasm::LeSVecI8x16:
return wasm::GeSVecI8x16;
case wasm::LeUVecI8x16:
return wasm::GeUVecI8x16;
case wasm::GeSVecI8x16:
return wasm::LeSVecI8x16;
case wasm::GeUVecI8x16:
return wasm::LeUVecI8x16;
case wasm::LtSVecI16x8:
return wasm::GtSVecI16x8;
case wasm::LtUVecI16x8:
return wasm::GtUVecI16x8;
case wasm::GtSVecI16x8:
return wasm::LtSVecI16x8;
case wasm::GtUVecI16x8:
return wasm::LtUVecI16x8;
case wasm::LeSVecI16x8:
return wasm::GeSVecI16x8;
case wasm::LeUVecI16x8:
return wasm::GeUVecI16x8;
case wasm::GeSVecI16x8:
return wasm::LeSVecI16x8;
case wasm::GeUVecI16x8:
return wasm::LeUVecI16x8;
case wasm::LtSVecI32x4:
return wasm::GtSVecI32x4;
case wasm::LtUVecI32x4:
return wasm::GtUVecI32x4;
case wasm::GtSVecI32x4:
return wasm::LtSVecI32x4;
case wasm::GtUVecI32x4:
return wasm::LtUVecI32x4;
case wasm::LeSVecI32x4:
return wasm::GeSVecI32x4;
case wasm::LeUVecI32x4:
return wasm::GeUVecI32x4;
case wasm::GeSVecI32x4:
return wasm::LeSVecI32x4;
case wasm::GeUVecI32x4:
return wasm::LeUVecI32x4;
case wasm::LtVecF32x4:
return wasm::GtVecF32x4;
case wasm::GtVecF32x4:
return wasm::LtVecF32x4;
case wasm::LeVecF32x4:
return wasm::GeVecF32x4;
case wasm::GeVecF32x4:
return wasm::LeVecF32x4;
case wasm::LtVecF64x2:
return wasm::GtVecF64x2;
case wasm::GtVecF64x2:
return wasm::LtVecF64x2;
case wasm::LeVecF64x2:
return wasm::GeVecF64x2;
case wasm::GeVecF64x2:
return wasm::LeVecF64x2;
default:
return boost::none;
}
}
[[nodiscard]]
inline bool isCommutative(wasm::BinaryOp op) {
return getSwappedPredicate(op).has_value();
}
} // namespace kyut
#endif // INCLUDE_kyut_Commutativity_hpp
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
#include <algorithm> #include <algorithm>
#include <boost/optional.hpp>
#include <wasm-traversal.h> #include <wasm-traversal.h>
#include "../BitStreamWriter.hpp" #include "../BitStreamWriter.hpp"
#include "../CircularBitStreamReader.hpp" #include "../CircularBitStreamReader.hpp"
#include "../Commutativity.hpp"
// Expression types // Expression types
#define EXPR_TYPES() \ #define EXPR_TYPES() \
...@@ -59,254 +58,6 @@ namespace kyut::watermarker { ...@@ -59,254 +58,6 @@ namespace kyut::watermarker {
write = 2, write = 2,
}; };
std::optional<wasm::BinaryOp> getSwappedPredicate(wasm::BinaryOp op) {
switch (op) {
// Commutative instructions
case wasm::AddInt32:
case wasm::MulInt32:
case wasm::AndInt32:
case wasm::OrInt32:
case wasm::XorInt32:
case wasm::EqInt32:
case wasm::NeInt32:
case wasm::AddInt64:
case wasm::MulInt64:
case wasm::AndInt64:
case wasm::OrInt64:
case wasm::XorInt64:
case wasm::EqInt64:
case wasm::NeInt64:
case wasm::AddFloat32:
case wasm::MulFloat32:
case wasm::MinFloat32:
case wasm::MaxFloat32:
case wasm::EqFloat32:
case wasm::NeFloat32:
case wasm::AddFloat64:
case wasm::MulFloat64:
case wasm::MinFloat64:
case wasm::MaxFloat64:
case wasm::EqFloat64:
case wasm::NeFloat64:
return op;
// Comparators
case wasm::LtSInt32:
return wasm::GtSInt32;
case wasm::LtUInt32:
return wasm::GtUInt32;
case wasm::LeSInt32:
return wasm::GeSInt32;
case wasm::LeUInt32:
return wasm::GeUInt32;
case wasm::GtSInt32:
return wasm::LtSInt32;
case wasm::GtUInt32:
return wasm::LtUInt32;
case wasm::GeSInt32:
return wasm::LeSInt32;
case wasm::GeUInt32:
return wasm::LeUInt32;
case wasm::LtSInt64:
return wasm::GtSInt64;
case wasm::LtUInt64:
return wasm::GtUInt64;
case wasm::LeSInt64:
return wasm::GeSInt64;
case wasm::LeUInt64:
return wasm::GeUInt64;
case wasm::GtSInt64:
return wasm::LtSInt64;
case wasm::GtUInt64:
return wasm::LtUInt64;
case wasm::GeSInt64:
return wasm::LeSInt64;
case wasm::GeUInt64:
return wasm::LeUInt64;
case wasm::LtFloat32:
return wasm::GtFloat32;
case wasm::LeFloat32:
return wasm::GeFloat32;
case wasm::GtFloat32:
return wasm::LtFloat32;
case wasm::GeFloat32:
return wasm::LeFloat32;
case wasm::LtFloat64:
return wasm::GtFloat64;
case wasm::LeFloat64:
return wasm::GeFloat64;
case wasm::GtFloat64:
return wasm::LtFloat64;
case wasm::GeFloat64:
return wasm::LeFloat64;
// Noncommutative instructions
case wasm::SubInt32:
case wasm::DivSInt32:
case wasm::DivUInt32:
case wasm::RemSInt32:
case wasm::RemUInt32:
case wasm::ShlInt32:
case wasm::ShrUInt32:
case wasm::ShrSInt32:
case wasm::RotLInt32:
case wasm::RotRInt32:
case wasm::SubInt64:
case wasm::DivSInt64:
case wasm::DivUInt64:
case wasm::RemSInt64:
case wasm::RemUInt64:
case wasm::ShlInt64:
case wasm::ShrUInt64:
case wasm::ShrSInt64:
case wasm::RotLInt64:
case wasm::RotRInt64:
case wasm::SubFloat32:
case wasm::DivFloat32:
case wasm::CopySignFloat32:
case wasm::SubFloat64:
case wasm::DivFloat64:
case wasm::CopySignFloat64:
return std::nullopt;
// Commutative SIMD instructions
case wasm::EqVecI8x16:
case wasm::NeVecI8x16:
case wasm::EqVecI16x8:
case wasm::NeVecI16x8:
case wasm::EqVecI32x4:
case wasm::NeVecI32x4:
case wasm::EqVecF32x4:
case wasm::NeVecF32x4:
case wasm::EqVecF64x2:
case wasm::NeVecF64x2:
case wasm::AndVec128:
case wasm::OrVec128:
case wasm::XorVec128:
case wasm::AddVecI8x16:
case wasm::AddSatSVecI8x16:
case wasm::AddSatUVecI8x16:
case wasm::MulVecI8x16:
case wasm::AddVecI16x8:
case wasm::AddSatSVecI16x8:
case wasm::AddSatUVecI16x8:
case wasm::MulVecI16x8:
case wasm::AddVecI32x4:
case wasm::MulVecI32x4:
case wasm::AddVecI64x2:
case wasm::AddVecF32x4:
case wasm::MulVecF32x4:
case wasm::MinVecF32x4:
case wasm::MaxVecF32x4:
case wasm::AddVecF64x2:
case wasm::MulVecF64x2:
case wasm::MinVecF64x2:
case wasm::MaxVecF64x2:
return op;
// SIMD comparators
case wasm::LtSVecI8x16:
return wasm::GtSVecI8x16;
case wasm::LtUVecI8x16:
return wasm::GtUVecI8x16;
case wasm::GtSVecI8x16:
return wasm::LtSVecI8x16;
case wasm::GtUVecI8x16:
return wasm::LtUVecI8x16;
case wasm::LeSVecI8x16:
return wasm::GeSVecI8x16;
case wasm::LeUVecI8x16:
return wasm::GeUVecI8x16;
case wasm::GeSVecI8x16:
return wasm::LeSVecI8x16;
case wasm::GeUVecI8x16:
return wasm::LeUVecI8x16;
case wasm::LtSVecI16x8:
return wasm::GtSVecI16x8;
case wasm::LtUVecI16x8:
return wasm::GtUVecI16x8;
case wasm::GtSVecI16x8:
return wasm::LtSVecI16x8;
case wasm::GtUVecI16x8:
return wasm::LtUVecI16x8;
case wasm::LeSVecI16x8:
return wasm::GeSVecI16x8;
case wasm::LeUVecI16x8:
return wasm::GeUVecI16x8;
case wasm::GeSVecI16x8:
return wasm::LeSVecI16x8;
case wasm::GeUVecI16x8:
return wasm::LeUVecI16x8;
case wasm::LtSVecI32x4:
return wasm::GtSVecI32x4;
case wasm::LtUVecI32x4:
return wasm::GtUVecI32x4;
case wasm::GtSVecI32x4:
return wasm::LtSVecI32x4;
case wasm::GtUVecI32x4:
return wasm::LtUVecI32x4;
case wasm::LeSVecI32x4:
return wasm::GeSVecI32x4;
case wasm::LeUVecI32x4:
return wasm::GeUVecI32x4;
case wasm::GeSVecI32x4:
return wasm::LeSVecI32x4;
case wasm::GeUVecI32x4:
return wasm::LeUVecI32x4;
case wasm::LtVecF32x4:
return wasm::GtVecF32x4;
case wasm::GtVecF32x4:
return wasm::LtVecF32x4;
case wasm::LeVecF32x4:
return wasm::GeVecF32x4;
case wasm::GeVecF32x4:
return wasm::LeVecF32x4;
case wasm::LtVecF64x2:
return wasm::GtVecF64x2;
case wasm::GtVecF64x2:
return wasm::LtVecF64x2;
case wasm::LeVecF64x2:
return wasm::GeVecF64x2;
case wasm::GeVecF64x2:
return wasm::LeVecF64x2;
// Noncommutative SIMD instructions
case wasm::SubVecI8x16:
case wasm::SubSatSVecI8x16:
case wasm::SubSatUVecI8x16:
case wasm::SubVecI16x8:
case wasm::SubSatSVecI16x8:
case wasm::SubSatUVecI16x8:
case wasm::SubVecI32x4:
case wasm::SubVecI64x2:
case wasm::SubVecF32x4:
case wasm::DivVecF32x4:
case wasm::SubVecF64x2:
case wasm::DivVecF64x2:
return std::nullopt;
default:
WASM_UNREACHABLE();
}
}
bool isCommutative(wasm::BinaryOp op) {
return getSwappedPredicate(op).has_value();
}
bool swapOperands(wasm::Binary &expr) { bool swapOperands(wasm::Binary &expr) {
if (const auto newOp = getSwappedPredicate(expr.op)) { if (const auto newOp = getSwappedPredicate(expr.op)) {
expr.op = *newOp; expr.op = *newOp;
...@@ -730,14 +481,14 @@ namespace wasm { ...@@ -730,14 +481,14 @@ namespace wasm {
// Normalize expression // Normalize expression
constexpr auto normalize = constexpr auto normalize =
[](const wasm::Binary &x) -> std::tuple<wasm::BinaryOp, wasm::Expression &, wasm::Expression &> { [](const wasm::Binary &x) -> std::tuple<wasm::BinaryOp, wasm::Expression &, wasm::Expression &> {
if (!kyut::watermarker::isCommutative(x.op)) { if (!kyut::isCommutative(x.op)) {
// Noncommutative // Noncommutative
return {x.op, *x.left, *x.right}; return {x.op, *x.left, *x.right};
} }
// Commutative // Commutative
if (*x.right < *x.left) { if (*x.right < *x.left) {
return {*kyut::watermarker::getSwappedPredicate(x.op), *x.right, *x.left}; return {*kyut::getSwappedPredicate(x.op), *x.right, *x.left};
} else { } else {
return {x.op, *x.left, *x.right}; return {x.op, *x.left, *x.right};
} }
......
...@@ -2,6 +2,7 @@ add_executable(test_kyut ...@@ -2,6 +2,7 @@ add_executable(test_kyut
test_kyut.cpp test_kyut.cpp
kyut/test_CircularBitStreamReader.cpp kyut/test_CircularBitStreamReader.cpp
kyut/test_BitStreamWriter.cpp kyut/test_BitStreamWriter.cpp
kyut/test_Commutativity.cpp
kyut/watermarker/test_FunctionOrderingWatermarker.cpp kyut/watermarker/test_FunctionOrderingWatermarker.cpp
kyut/watermarker/test_OperandSwappingWatermarker.cpp kyut/watermarker/test_OperandSwappingWatermarker.cpp
) )
......
#include <kyut/Commutativity.hpp>
#include <boost/test/unit_test.hpp>
BOOST_AUTO_TEST_SUITE(kyut)
BOOST_AUTO_TEST_SUITE(commutativity)
BOOST_AUTO_TEST_CASE(get_swapped_predicate) {
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::AddInt32), wasm::AddInt32);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::AddInt64), wasm::AddInt64);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::SubInt32), boost::none);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::EqFloat32), wasm::EqFloat32);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::NeFloat32), wasm::NeFloat32);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::LtFloat32), wasm::GtFloat32);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::LeFloat32), wasm::GeFloat32);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::GtFloat32), wasm::LtFloat32);
BOOST_CHECK_EQUAL(getSwappedPredicate(wasm::GeFloat32), wasm::LeFloat32);
}
BOOST_AUTO_TEST_CASE(is_commutative) {
BOOST_CHECK_EQUAL(isCommutative(wasm::AddInt32), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::AddInt64), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::SubInt32), false);
BOOST_CHECK_EQUAL(isCommutative(wasm::EqFloat32), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::NeFloat32), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::LtFloat32), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::LeFloat32), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::GtFloat32), true);
BOOST_CHECK_EQUAL(isCommutative(wasm::GeFloat32), true);
}
BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment