Commit 4bf1e9f1 authored by nagayama15's avatar nagayama15

Merge branch 'method/operand-swapping' into 'master'

Implement operand swapping method

See merge request !3
parents b880f2e2 b241ef19
......@@ -7,10 +7,10 @@
int main(int argc, char *argv[]) {
try {
if (argc != 3) {
if (argc != 4) {
fmt::print(std::cerr,
"WebAssembly digital watermarker.\n"
"usage: kyut <input file> <watermark>\n");
"usage: kyut <input file> <watermark> <output file>\n");
return 1;
}
......@@ -18,10 +18,16 @@ int main(int argc, char *argv[]) {
wasm::Module module;
wasm::ModuleReader{}.read(argv[1], module);
kyut::pass::embedWatermarkOperandSwapping(module);
const auto stream = kyut::CircularBitStream::from_string(argv[2]);
kyut::pass::embedWatermarkOperandSwapping(module, *stream);
wasm::ModuleWriter{}.writeText(module, argv[3]);
} catch (wasm::ParseException &e) {
fmt::print(std::cerr, "parse error\n");
e.dump(std::cerr);
return 1;
} catch (const std::exception &e) {
fmt::print(std::cerr, "error: {}\n", e.what());
......
#ifndef INCLUDE_kyut_circular_bit_stream_hpp
#define INCLUDE_kyut_circular_bit_stream_hpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string_view>
#include <vector>
namespace kyut {
class CircularBitStream {
public:
explicit CircularBitStream(const std::uint8_t *data, std::size_t size_bytes) noexcept
: bytes_(data, data + size_bytes)
, pos_bits_(0) {
assert(data != nullptr || size_bytes == 0);
}
CircularBitStream(const CircularBitStream &) = delete;
CircularBitStream(CircularBitStream &&) = delete;
CircularBitStream &operator=(const CircularBitStream &) = delete;
CircularBitStream &operator=(CircularBitStream &&) = delete;
~CircularBitStream() noexcept = default;
[[nodiscard]] const std::vector<std::uint8_t> &bytes() const noexcept {
return bytes_;
}
[[nodiscard]] std::size_t size_bytes() const noexcept {
return bytes_.size();
}
[[nodiscard]] std::size_t size_bits() const noexcept {
return size_bytes() * 8;
}
[[nodiscard]] std::size_t pos_bits() const noexcept {
return pos_bits_;
}
bool read_bit() noexcept {
if (size_bits() == 0) {
return false;
}
const auto value = (bytes_[pos_bits_ >> 3] >> (pos_bits_ & 0x7)) & 0x1;
if (++pos_bits_ >= size_bits()) {
pos_bits_ = 0;
}
return value;
}
std::uint64_t read(std::size_t size_bits) {
assert(size_bits <= 64);
std::uint64_t value = 0;
for (size_t i = 0; i < size_bits; i++) {
if (read_bit()) {
value |= std::uint64_t{1} << i;
}
}
return value;
}
static std::unique_ptr<CircularBitStream> from_string(std::string_view s) {
return std::make_unique<CircularBitStream>(reinterpret_cast<const std::uint8_t *>(s.data()), s.length());
}
private:
std::vector<std::uint8_t> bytes_;
std::size_t pos_bits_;
};
} // namespace kyut
#endif // INCLUDE_kyut_circular_bit_stream_hpp
#ifndef INCLUDE_kyut_commutativity_hpp
#define INCLUDE_kyut_commutativity_hpp
#include <wasm.h>
namespace kyut {
constexpr bool isCommutative(wasm::BinaryOp op) noexcept {
switch (op) {
case wasm::BinaryOp::AddInt32:
case wasm::BinaryOp::MulInt32:
case wasm::BinaryOp::AndInt32:
case wasm::BinaryOp::OrInt32:
case wasm::BinaryOp::XorInt32:
case wasm::BinaryOp::EqInt32:
case wasm::BinaryOp::NeInt32:
case wasm::BinaryOp::AddInt64:
case wasm::BinaryOp::MulInt64:
case wasm::BinaryOp::AndInt64:
case wasm::BinaryOp::OrInt64:
case wasm::BinaryOp::XorInt64:
case wasm::BinaryOp::EqInt64:
case wasm::BinaryOp::NeInt64:
case wasm::BinaryOp::AddFloat32:
case wasm::BinaryOp::MulFloat32:
case wasm::BinaryOp::MinFloat32:
case wasm::BinaryOp::MaxFloat32:
case wasm::BinaryOp::EqFloat32:
case wasm::BinaryOp::NeFloat32:
case wasm::BinaryOp::AddFloat64:
case wasm::BinaryOp::MulFloat64:
case wasm::BinaryOp::MinFloat64:
case wasm::BinaryOp::MaxFloat64:
case wasm::BinaryOp::EqFloat64:
case wasm::BinaryOp::NeFloat64:
return true;
// Relational operators
case wasm::BinaryOp::LtSInt32:
case wasm::BinaryOp::LtUInt32:
case wasm::BinaryOp::LeSInt32:
case wasm::BinaryOp::LeUInt32:
case wasm::BinaryOp::GtSInt32:
case wasm::BinaryOp::GtUInt32:
case wasm::BinaryOp::GeSInt32:
case wasm::BinaryOp::GeUInt32:
case wasm::BinaryOp::LtSInt64:
case wasm::BinaryOp::LtUInt64:
case wasm::BinaryOp::LeSInt64:
case wasm::BinaryOp::LeUInt64:
case wasm::BinaryOp::GtSInt64:
case wasm::BinaryOp::GtUInt64:
case wasm::BinaryOp::GeSInt64:
case wasm::BinaryOp::GeUInt64:
case wasm::BinaryOp::LtFloat32:
case wasm::BinaryOp::LeFloat32:
case wasm::BinaryOp::GtFloat32:
case wasm::BinaryOp::GeFloat32:
return true;
// TODO: SIMD operators
default:
return false;
}
}
inline bool swapOperands(wasm::Binary &expr) noexcept {
if (!isCommutative(expr.op)) {
return false;
}
// Invert relational operator
expr.op = [](wasm::BinaryOp op) noexcept {
switch (op) {
case wasm::BinaryOp::LtSInt32:
return wasm::BinaryOp::GtSInt32;
case wasm::BinaryOp::LtUInt32:
return wasm::BinaryOp::GtUInt32;
case wasm::BinaryOp::LeSInt32:
return wasm::BinaryOp::GeSInt32;
case wasm::BinaryOp::LeUInt32:
return wasm::BinaryOp::GeUInt32;
case wasm::BinaryOp::GtSInt32:
return wasm::BinaryOp::LtSInt32;
case wasm::BinaryOp::GtUInt32:
return wasm::BinaryOp::LtUInt32;
case wasm::BinaryOp::GeSInt32:
return wasm::BinaryOp::LeSInt32;
case wasm::BinaryOp::GeUInt32:
return wasm::BinaryOp::LeUInt32;
case wasm::BinaryOp::LtSInt64:
return wasm::BinaryOp::GtSInt64;
case wasm::BinaryOp::LtUInt64:
return wasm::BinaryOp::GtUInt64;
case wasm::BinaryOp::LeSInt64:
return wasm::BinaryOp::GeSInt64;
case wasm::BinaryOp::LeUInt64:
return wasm::BinaryOp::GeUInt64;
case wasm::BinaryOp::GtSInt64:
return wasm::BinaryOp::LtSInt64;
case wasm::BinaryOp::GtUInt64:
return wasm::BinaryOp::LtUInt64;
case wasm::BinaryOp::GeSInt64:
return wasm::BinaryOp::LeSInt64;
case wasm::BinaryOp::GeUInt64:
return wasm::BinaryOp::LeUInt64;
case wasm::BinaryOp::LtFloat32:
return wasm::BinaryOp::GtFloat32;
case wasm::BinaryOp::LeFloat32:
return wasm::BinaryOp::GeFloat32;
case wasm::BinaryOp::GtFloat32:
return wasm::BinaryOp::LtFloat32;
case wasm::BinaryOp::GeFloat32:
return wasm::BinaryOp::LeFloat32;
// TODO: SIMD operations
default:
return op;
}
}
(expr.op);
// Swap operands
std::swap(expr.left, expr.right);
return true;
}
} // namespace kyut
#endif // INCLUDE_kyut_commutativity_hpp
......@@ -2,11 +2,15 @@
#include <pass.h>
#include "../commutativity.hpp"
#include "../comparison.hpp"
namespace kyut::pass {
class OperandSwappingWatermarkingVisitor
: public wasm::OverriddenVisitor<OperandSwappingWatermarkingVisitor, bool> {
public:
explicit OperandSwappingWatermarkingVisitor() = default;
explicit OperandSwappingWatermarkingVisitor(CircularBitStream &stream)
: stream_(stream) {}
OperandSwappingWatermarkingVisitor(const OperandSwappingWatermarkingVisitor &) = delete;
OperandSwappingWatermarkingVisitor(OperandSwappingWatermarkingVisitor &&) = delete;
......@@ -217,6 +221,14 @@ namespace kyut::pass {
has_side_effect = visit(curr->left) || has_side_effect;
has_side_effect = visit(curr->right) || has_side_effect;
if (isCommutative(curr->op) && !has_side_effect) {
const auto bit = stream_.read_bit();
if (bit == (*curr->left < *curr->right)) {
swapOperands(*curr);
}
}
return has_side_effect;
}
......@@ -278,11 +290,15 @@ namespace kyut::pass {
bool visitModule([[maybe_unused]] wasm::Module *curr) {
WASM_UNREACHABLE();
}
private:
CircularBitStream &stream_;
};
class OperandSwappingWatermarkingPass : public wasm::Pass {
public:
explicit OperandSwappingWatermarkingPass() = default;
explicit OperandSwappingWatermarkingPass(CircularBitStream &stream)
: stream_(stream) {}
OperandSwappingWatermarkingPass(const OperandSwappingWatermarkingPass &) = delete;
OperandSwappingWatermarkingPass(OperandSwappingWatermarkingPass &&) = delete;
......@@ -297,17 +313,20 @@ namespace kyut::pass {
}
void run([[maybe_unused]] wasm::PassRunner *runner, wasm::Module *module) override {
OperandSwappingWatermarkingVisitor visitor{};
OperandSwappingWatermarkingVisitor visitor{stream_};
for (const auto &func : module->functions) {
visitor.visit(func->body);
}
}
private:
CircularBitStream &stream_;
};
void embedWatermarkOperandSwapping(wasm::Module &module) {
void embedWatermarkOperandSwapping(wasm::Module &module, CircularBitStream &stream) {
wasm::PassRunner runner{&module};
runner.add<OperandSwappingWatermarkingPass>();
runner.add<OperandSwappingWatermarkingPass>(std::ref(stream));
runner.run();
}
} // namespace kyut::pass
......@@ -3,8 +3,10 @@
#include <wasm.h>
#include "../circular_bit_stream.hpp"
namespace kyut::pass {
void embedWatermarkOperandSwapping(wasm::Module &module);
void embedWatermarkOperandSwapping(wasm::Module &module, CircularBitStream &stream);
}
#endif // INCLUDE_kyut_pass_operand_swapping_watermarker_hpp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment