Commit 704cb19f authored by nagayama15's avatar nagayama15

🚧 Implement the operand swappers

parent dc4af29c
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
(i32.const 1) (i32.const 1)
(i32.const 2) (i32.const 2)
) )
(i32.add (i32.lt_s
(i32.const 4) (i32.const 4)
(i32.const 3) (i32.const 3)
) )
......
...@@ -51,6 +51,265 @@ namespace kyut::watermarker { ...@@ -51,6 +51,265 @@ namespace kyut::watermarker {
write = 2, write = 2,
}; };
std::optional<wasm::BinaryOp> getSwappedPredicate(wasm::BinaryOp op) {
switch (op) {
// Commutative instructions
case wasm::AddInt32:
case wasm::MulInt32:
case wasm::AndInt32:
case wasm::OrInt32:
case wasm::XorInt32:
case wasm::EqInt32:
case wasm::NeInt32:
case wasm::AddInt64:
case wasm::MulInt64:
case wasm::AndInt64:
case wasm::OrInt64:
case wasm::XorInt64:
case wasm::EqInt64:
case wasm::NeInt64:
case wasm::AddFloat32:
case wasm::MulFloat32:
case wasm::MinFloat32:
case wasm::MaxFloat32:
case wasm::EqFloat32:
case wasm::NeFloat32:
case wasm::AddFloat64:
case wasm::MulFloat64:
case wasm::MinFloat64:
case wasm::MaxFloat64:
case wasm::EqFloat64:
case wasm::NeFloat64:
return op;
// Comparators
case wasm::LtSInt32:
return wasm::GtSInt32;
case wasm::LtUInt32:
return wasm::GtUInt32;
case wasm::LeSInt32:
return wasm::GeSInt32;
case wasm::LeUInt32:
return wasm::GeUInt32;
case wasm::GtSInt32:
return wasm::LtSInt32;
case wasm::GtUInt32:
return wasm::LtUInt32;
case wasm::GeSInt32:
return wasm::LeSInt32;
case wasm::GeUInt32:
return wasm::LeUInt32;
case wasm::LtSInt64:
return wasm::GtSInt64;
case wasm::LtUInt64:
return wasm::GtUInt64;
case wasm::LeSInt64:
return wasm::GeSInt64;
case wasm::LeUInt64:
return wasm::GeUInt64;
case wasm::GtSInt64:
return wasm::LtSInt64;
case wasm::GtUInt64:
return wasm::LtUInt64;
case wasm::GeSInt64:
return wasm::LeSInt64;
case wasm::GeUInt64:
return wasm::LeUInt64;
case wasm::LtFloat32:
return wasm::GtFloat32;
case wasm::LeFloat32:
return wasm::GeFloat32;
case wasm::GtFloat32:
return wasm::LtFloat32;
case wasm::GeFloat32:
return wasm::LeFloat32;
case wasm::LtFloat64:
return wasm::GtFloat64;
case wasm::LeFloat64:
return wasm::GeFloat64;
case wasm::GtFloat64:
return wasm::LtFloat64;
case wasm::GeFloat64:
return wasm::LeFloat64;
// Noncommutative instructions
case wasm::SubInt32:
case wasm::DivSInt32:
case wasm::DivUInt32:
case wasm::RemSInt32:
case wasm::RemUInt32:
case wasm::ShlInt32:
case wasm::ShrUInt32:
case wasm::ShrSInt32:
case wasm::RotLInt32:
case wasm::RotRInt32:
case wasm::SubInt64:
case wasm::DivSInt64:
case wasm::DivUInt64:
case wasm::RemSInt64:
case wasm::RemUInt64:
case wasm::ShlInt64:
case wasm::ShrUInt64:
case wasm::ShrSInt64:
case wasm::RotLInt64:
case wasm::RotRInt64:
case wasm::SubFloat32:
case wasm::DivFloat32:
case wasm::CopySignFloat32:
case wasm::SubFloat64:
case wasm::DivFloat64:
case wasm::CopySignFloat64:
return std::nullopt;
// Commutative SIMD instructions
case wasm::EqVecI8x16:
case wasm::NeVecI8x16:
case wasm::EqVecI16x8:
case wasm::NeVecI16x8:
case wasm::EqVecI32x4:
case wasm::NeVecI32x4:
case wasm::EqVecF32x4:
case wasm::NeVecF32x4:
case wasm::EqVecF64x2:
case wasm::NeVecF64x2:
case wasm::AndVec128:
case wasm::OrVec128:
case wasm::XorVec128:
case wasm::AddVecI8x16:
case wasm::AddSatSVecI8x16:
case wasm::AddSatUVecI8x16:
case wasm::MulVecI8x16:
case wasm::AddVecI16x8:
case wasm::AddSatSVecI16x8:
case wasm::AddSatUVecI16x8:
case wasm::MulVecI16x8:
case wasm::AddVecI32x4:
case wasm::MulVecI32x4:
case wasm::AddVecI64x2:
case wasm::AddVecF32x4:
case wasm::MulVecF32x4:
case wasm::MinVecF32x4:
case wasm::MaxVecF32x4:
case wasm::AddVecF64x2:
case wasm::MulVecF64x2:
case wasm::MinVecF64x2:
case wasm::MaxVecF64x2:
return op;
// SIMD comparators
case wasm::LtSVecI8x16:
return wasm::GtSVecI8x16;
case wasm::LtUVecI8x16:
return wasm::GtUVecI8x16;
case wasm::GtSVecI8x16:
return wasm::LtSVecI8x16;
case wasm::GtUVecI8x16:
return wasm::LtUVecI8x16;
case wasm::LeSVecI8x16:
return wasm::GeSVecI8x16;
case wasm::LeUVecI8x16:
return wasm::GeUVecI8x16;
case wasm::GeSVecI8x16:
return wasm::LeSVecI8x16;
case wasm::GeUVecI8x16:
return wasm::LeUVecI8x16;
case wasm::LtSVecI16x8:
return wasm::GtSVecI16x8;
case wasm::LtUVecI16x8:
return wasm::GtUVecI16x8;
case wasm::GtSVecI16x8:
return wasm::LtSVecI16x8;
case wasm::GtUVecI16x8:
return wasm::LtUVecI16x8;
case wasm::LeSVecI16x8:
return wasm::GeSVecI16x8;
case wasm::LeUVecI16x8:
return wasm::GeUVecI16x8;
case wasm::GeSVecI16x8:
return wasm::LeSVecI16x8;
case wasm::GeUVecI16x8:
return wasm::LeUVecI16x8;
case wasm::LtSVecI32x4:
return wasm::GtSVecI32x4;
case wasm::LtUVecI32x4:
return wasm::GtUVecI32x4;
case wasm::GtSVecI32x4:
return wasm::LtSVecI32x4;
case wasm::GtUVecI32x4:
return wasm::LtUVecI32x4;
case wasm::LeSVecI32x4:
return wasm::GeSVecI32x4;
case wasm::LeUVecI32x4:
return wasm::GeUVecI32x4;
case wasm::GeSVecI32x4:
return wasm::LeSVecI32x4;
case wasm::GeUVecI32x4:
return wasm::LeUVecI32x4;
case wasm::LtVecF32x4:
return wasm::GtVecF32x4;
case wasm::GtVecF32x4:
return wasm::LtVecF32x4;
case wasm::LeVecF32x4:
return wasm::GeVecF32x4;
case wasm::GeVecF32x4:
return wasm::LeVecF32x4;
case wasm::LtVecF64x2:
return wasm::GtVecF64x2;
case wasm::GtVecF64x2:
return wasm::LtVecF64x2;
case wasm::LeVecF64x2:
return wasm::GeVecF64x2;
case wasm::GeVecF64x2:
return wasm::LeVecF64x2;
// Noncommutative SIMD instructions
case wasm::SubVecI8x16:
case wasm::SubSatSVecI8x16:
case wasm::SubSatUVecI8x16:
case wasm::SubVecI16x8:
case wasm::SubSatSVecI16x8:
case wasm::SubSatUVecI16x8:
case wasm::SubVecI32x4:
case wasm::SubVecI64x2:
case wasm::SubVecF32x4:
case wasm::DivVecF32x4:
case wasm::SubVecF64x2:
case wasm::DivVecF64x2:
return std::nullopt;
default:
WASM_UNREACHABLE();
}
}
bool isCommutative(wasm::BinaryOp op) {
return getSwappedPredicate(op).has_value();
}
bool swapOperands(wasm::Binary &expr) {
if (const auto newOp = getSwappedPredicate(expr.op)) {
expr.op = *newOp;
std::swap(expr.left, expr.right);
return true;
}
return false;
}
// Watermark embedder // Watermark embedder
SideEffect embedExpression(wasm::Expression *expr, CircularBitStreamReader &stream); SideEffect embedExpression(wasm::Expression *expr, CircularBitStreamReader &stream);
...@@ -148,8 +407,18 @@ namespace kyut::watermarker { ...@@ -148,8 +407,18 @@ namespace kyut::watermarker {
} }
SideEffect embedBinary(wasm::Binary &expr, CircularBitStreamReader &stream) { SideEffect embedBinary(wasm::Binary &expr, CircularBitStreamReader &stream) {
if (!isCommutative(expr.op)) {
// The operands of noncommutative instructions cannot be swapped
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream));
}
// TODO: implement watermarking // TODO: implement watermarking
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream)); const auto leftSideEffect = embedExpression(expr.left, stream);
const auto rightSideEffect = embedExpression(expr.right, stream);
(void)swapOperands;
return (std::max)(leftSideEffect, rightSideEffect);
} }
SideEffect embedSelect(wasm::Select &expr, CircularBitStreamReader &stream) { SideEffect embedSelect(wasm::Select &expr, CircularBitStreamReader &stream) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment