Commit d0443f96 authored by nagayama15's avatar nagayama15

Merge branch 'extracter-operands-swapping' into 'master'

Implement watermark extractor of operands swapping method

See merge request !6
parents 38f1c1e2 06e2f428
add_executable(kyut
add_library(elder
kyut.cpp
kyut/pass/operand_swapping_watermarker.cpp
kyut/pass/operand_swapping_extractor.cpp
)
target_link_libraries(kyut
target_link_libraries(elder
binaryen::binaryen
fmt::fmt
)
add_executable(kyut
kyut.cpp
)
target_link_libraries(kyut
elder
)
add_executable(pisn
pisn.cpp
)
target_link_libraries(pisn
elder
)
......@@ -22,7 +22,7 @@ int main(int argc, char *argv[]) {
kyut::pass::embedWatermarkOperandSwapping(module, *stream);
wasm::ModuleWriter{}.writeText(module, argv[3]);
wasm::ModuleWriter{}.write(module, argv[3]);
} catch (wasm::ParseException &e) {
fmt::print(std::cerr, "parse error\n");
e.dump(std::cerr);
......
#ifndef INCLUDE_kyut_bit_writer_hpp
#define INCLUDE_kyut_bit_writer_hpp
#include <cassert>
#include <cstdint>
#include <vector>
namespace kyut {
class BitWriter {
public:
explicit BitWriter(std::size_t reserved = 256)
: bytes_()
, pos_bits_(0) {
bytes_.reserve(reserved);
}
BitWriter(const BitWriter &) = delete;
BitWriter(BitWriter &&) = delete;
BitWriter &operator=(const BitWriter &) = delete;
BitWriter &operator=(BitWriter &&) = delete;
~BitWriter() noexcept = default;
[[nodiscard]] const std::vector<std::uint8_t> &bytes() const noexcept {
return bytes_;
}
[[nodiscard]] std::size_t size_bytes() const noexcept {
return bytes_.size();
}
[[nodiscard]] std::size_t size_bits() const noexcept {
return size_bytes() * 8 + pos_bits_ % 8;
}
void write_bit(bool bit) {
if (pos_bits_ % 8 == 0) {
bytes_.emplace_back(std::uint8_t{0x00});
}
if (bit) {
bytes_.back() |= 1 << (pos_bits_ % 8);
}
pos_bits_++;
}
void write(std::uint64_t bits, std::size_t size_bits) {
assert(size_bits < 64);
for (std::size_t i = 0; i < size_bits; i++) {
write_bit((bits & (1 << i)) != 0);
}
}
private:
std::vector<std::uint8_t> bytes_;
std::size_t pos_bits_;
};
} // namespace kyut
#endif // INCLUDE_kyut_bit_writer_hpp
#include "operand_swapping_extractor.hpp"
#include <pass.h>
#include "../commutativity.hpp"
#include "../comparison.hpp"
#include "side_effect_checker.hpp"
namespace kyut::pass {
class OperandSwappingExtractingVisitor
: public wasm::UnifiedExpressionVisitor<OperandSwappingExtractingVisitor, SideEffect> {
public:
explicit OperandSwappingExtractingVisitor(BitWriter &writer)
: writer_(writer) {}
OperandSwappingExtractingVisitor(const OperandSwappingExtractingVisitor &) = delete;
OperandSwappingExtractingVisitor(OperandSwappingExtractingVisitor &&) = delete;
OperandSwappingExtractingVisitor &operator=(const OperandSwappingExtractingVisitor &) = delete;
OperandSwappingExtractingVisitor &operator=(OperandSwappingExtractingVisitor &&) = delete;
~OperandSwappingExtractingVisitor() noexcept = default;
SideEffect visitExpression(wasm::Expression *curr) {
return SideEffectCheckingVisitor{}.visit(curr);
}
SideEffect visitBinary(wasm::Binary *curr) {
auto side_effect_left = visit(curr->left);
auto side_effect_right = visit(curr->right);
// operands can be swapped if [write(=2), none(=0)] or [read(=1), read(=1)]
auto can_swap_operands =
(static_cast<std::int32_t>(side_effect_left) + static_cast<std::int32_t>(side_effect_left) <= 2);
if (isCommutative(curr->op) && can_swap_operands) {
const auto bit = !(*curr->left < *curr->right);
writer_.write_bit(bit);
}
return (std::max)(side_effect_left, side_effect_right);
}
private:
BitWriter &writer_;
};
class OperandSwappingExtractingPass : public wasm::Pass {
public:
explicit OperandSwappingExtractingPass(BitWriter &writer)
: writer_(writer) {}
OperandSwappingExtractingPass(const OperandSwappingExtractingPass &) = delete;
OperandSwappingExtractingPass(OperandSwappingExtractingPass &&) = delete;
OperandSwappingExtractingPass &operator=(const OperandSwappingExtractingPass &) = delete;
OperandSwappingExtractingPass &operator=(OperandSwappingExtractingPass &&) = delete;
~OperandSwappingExtractingPass() noexcept = default;
bool modifiesBinaryenIR() noexcept override {
return false;
}
void run([[maybe_unused]] wasm::PassRunner *runner, wasm::Module *module) override {
OperandSwappingExtractingVisitor visitor{writer_};
for (const auto &func : module->functions) {
visitor.visit(func->body);
}
}
private:
BitWriter &writer_;
};
void extractWatermarkOperandSwapping(wasm::Module &module, BitWriter &writer) {
wasm::PassRunner runner{&module};
runner.add<OperandSwappingExtractingPass>(std::ref(writer));
runner.run();
}
} // namespace kyut::pass
#ifndef INCLUDE_kyut_pass_operand_swapping_extractor_cpp
#define INCLUDE_kyut_pass_operand_swapping_extractor_cpp
#include <wasm.h>
#include "../bit_writer.hpp"
namespace kyut::pass {
void extractWatermarkOperandSwapping(wasm::Module &module, BitWriter &writer);
}
#endif // INCLUDE_kyut_pass_operand_swapping_extractor_cpp
......@@ -4,11 +4,12 @@
#include "../commutativity.hpp"
#include "../comparison.hpp"
#include "../side_effect.hpp"
#include "side_effect_checker.hpp"
namespace kyut::pass {
class OperandSwappingWatermarkingVisitor
: public wasm::OverriddenVisitor<OperandSwappingWatermarkingVisitor, SideEffect> {
: public wasm::UnifiedExpressionVisitor<OperandSwappingWatermarkingVisitor, SideEffect> {
public:
explicit OperandSwappingWatermarkingVisitor(CircularBitStream &stream)
: stream_(stream) {}
......@@ -21,199 +22,8 @@ namespace kyut::pass {
~OperandSwappingWatermarkingVisitor() noexcept = default;
SideEffect visitBlock(wasm::Block *curr) {
auto side_effect = SideEffect::none;
for (const auto &expr : curr->list) {
side_effect = (std::max)(visit(expr), side_effect);
}
return side_effect;
}
SideEffect visitIf(wasm::If *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->condition), side_effect);
side_effect = (std::max)(visit(curr->ifTrue), side_effect);
if (curr->ifFalse) {
side_effect = (std::max)(visit(curr->ifFalse), side_effect);
}
return side_effect;
}
SideEffect visitLoop(wasm::Loop *curr) {
return visit(curr->body);
}
SideEffect visitBreak(wasm::Break *curr) {
if (curr->condition) {
visit(curr->condition);
}
if (curr->value) {
visit(curr->value);
}
return SideEffect::write;
}
SideEffect visitSwitch(wasm::Switch *curr) {
visit(curr->condition);
if (curr->value) {
visit(curr->value);
}
return SideEffect::write;
}
SideEffect visitCall(wasm::Call *curr) {
for (const auto &expr : curr->operands) {
visit(expr);
}
return SideEffect::write;
}
SideEffect visitCallIndirect(wasm::CallIndirect *curr) {
visit(curr->target);
for (const auto &expr : curr->operands) {
visit(expr);
}
return SideEffect::write;
}
SideEffect visitGetLocal([[maybe_unused]] wasm::GetLocal *curr) {
return SideEffect::read;
}
SideEffect visitSetLocal(wasm::SetLocal *curr) {
visit(curr->value);
return SideEffect::write;
}
SideEffect visitGetGlobal([[maybe_unused]] wasm::GetGlobal *curr) {
return SideEffect::read;
}
SideEffect visitSetGlobal(wasm::SetGlobal *curr) {
visit(curr->value);
return SideEffect::write;
}
SideEffect visitLoad(wasm::Load *curr) {
return (std::max)(visit(curr->ptr), SideEffect::read);
}
SideEffect visitStore(wasm::Store *curr) {
visit(curr->ptr);
visit(curr->value);
return SideEffect::write;
}
SideEffect visitAtomicRMW(wasm::AtomicRMW *curr) {
visit(curr->ptr);
visit(curr->value);
return SideEffect::write;
}
SideEffect visitAtomicCmpxchg(wasm::AtomicCmpxchg *curr) {
visit(curr->ptr);
visit(curr->expected);
visit(curr->replacement);
return SideEffect::write;
}
SideEffect visitAtomicWait(wasm::AtomicWait *curr) {
visit(curr->ptr);
visit(curr->expected);
visit(curr->timeout);
return SideEffect::write;
}
SideEffect visitAtomicNotify(wasm::AtomicNotify *curr) {
visit(curr->ptr);
visit(curr->notifyCount);
return SideEffect::write;
}
SideEffect visitSIMDExtract(wasm::SIMDExtract *curr) {
return visit(curr->vec);
}
SideEffect visitSIMDReplace(wasm::SIMDReplace *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->vec), side_effect);
side_effect = (std::max)(visit(curr->value), side_effect);
return side_effect;
}
SideEffect visitSIMDShuffle(wasm::SIMDShuffle *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->left), side_effect);
side_effect = (std::max)(visit(curr->right), side_effect);
return side_effect;
}
SideEffect visitSIMDBitselect(wasm::SIMDBitselect *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->cond), side_effect);
side_effect = (std::max)(visit(curr->left), side_effect);
side_effect = (std::max)(visit(curr->right), side_effect);
return side_effect;
}
SideEffect visitSIMDShift(wasm::SIMDShift *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->vec), side_effect);
side_effect = (std::max)(visit(curr->shift), side_effect);
return side_effect;
}
SideEffect visitMemoryInit(wasm::MemoryInit *curr) {
visit(curr->dest);
visit(curr->offset);
visit(curr->size);
return SideEffect::write;
}
SideEffect visitDataDrop([[maybe_unused]] wasm::DataDrop *curr) {
return SideEffect::write;
}
SideEffect visitMemoryCopy(wasm::MemoryCopy *curr) {
visit(curr->dest);
visit(curr->source);
visit(curr->size);
return SideEffect::write;
}
SideEffect visitMemoryFill(wasm::MemoryFill *curr) {
visit(curr->dest);
visit(curr->value);
visit(curr->size);
return SideEffect::write;
}
SideEffect visitConst([[maybe_unused]] wasm::Const *curr) {
return SideEffect::none;
}
SideEffect visitUnary(wasm::Unary *curr) {
return visit(curr->value);
SideEffect visitExpression(wasm::Expression *curr) {
return SideEffectCheckingVisitor{}.visit(curr);
}
SideEffect visitBinary(wasm::Binary *curr) {
......@@ -235,69 +45,6 @@ namespace kyut::pass {
return (std::max)(side_effect_left, side_effect_right);
}
SideEffect visitSelect(wasm::Select *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->condition), side_effect);
side_effect = (std::max)(visit(curr->ifTrue), side_effect);
side_effect = (std::max)(visit(curr->ifFalse), side_effect);
return side_effect;
}
SideEffect visitDrop(wasm::Drop *curr) {
return visit(curr->value);
}
SideEffect visitReturn(wasm::Return *curr) {
if (curr->value) {
visit(curr->value);
}
return SideEffect::write;
}
SideEffect visitHost(wasm::Host *curr) {
for (const auto &expr : curr->operands) {
visit(expr);
}
if (curr->op == wasm::HostOp::CurrentMemory) {
return SideEffect::read;
} else {
return SideEffect::write;
}
}
SideEffect visitNop([[maybe_unused]] wasm::Nop *curr) {
return SideEffect::none;
}
SideEffect visitUnreachable([[maybe_unused]] wasm::Unreachable *curr) {
WASM_UNREACHABLE();
}
SideEffect visitFunctionType([[maybe_unused]] wasm::FunctionType *curr) {
WASM_UNREACHABLE();
}
SideEffect visitExport([[maybe_unused]] wasm::Export *curr) {
WASM_UNREACHABLE();
}
SideEffect visitGlobal([[maybe_unused]] wasm::Global *curr) {
WASM_UNREACHABLE();
}
SideEffect visitFunction([[maybe_unused]] wasm::Function *curr) {
WASM_UNREACHABLE();
}
SideEffect visitTable([[maybe_unused]] wasm::Table *curr) {
WASM_UNREACHABLE();
}
SideEffect visitMemory([[maybe_unused]] wasm::Memory *curr) {
WASM_UNREACHABLE();
}
SideEffect visitModule([[maybe_unused]] wasm::Module *curr) {
WASM_UNREACHABLE();
}
private:
CircularBitStream &stream_;
};
......
#ifndef INCLUDE_kyut_pass_side_effect_checker_hpp
#define INCLUDE_kyut_pass_side_effect_checker_hpp
#include <wasm-traversal.h>
#include "../side_effect.hpp"
namespace kyut::pass {
class SideEffectCheckingVisitor : public wasm::OverriddenVisitor<SideEffectCheckingVisitor, SideEffect> {
public:
explicit SideEffectCheckingVisitor() = default;
SideEffectCheckingVisitor(const SideEffectCheckingVisitor &) = delete;
SideEffectCheckingVisitor(SideEffectCheckingVisitor &&) = delete;
SideEffectCheckingVisitor &operator=(const SideEffectCheckingVisitor &) = delete;
SideEffectCheckingVisitor &operator=(SideEffectCheckingVisitor &&) = delete;
~SideEffectCheckingVisitor() noexcept = default;
SideEffect visitBlock(wasm::Block *curr) {
auto side_effect = SideEffect::none;
for (const auto &expr : curr->list) {
side_effect = (std::max)(visit(expr), side_effect);
}
return side_effect;
}
SideEffect visitIf(wasm::If *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->condition), side_effect);
side_effect = (std::max)(visit(curr->ifTrue), side_effect);
if (curr->ifFalse) {
side_effect = (std::max)(visit(curr->ifFalse), side_effect);
}
return side_effect;
}
SideEffect visitLoop(wasm::Loop *curr) {
return visit(curr->body);
}
SideEffect visitBreak(wasm::Break *curr) {
if (curr->condition) {
visit(curr->condition);
}
if (curr->value) {
visit(curr->value);
}
return SideEffect::write;
}
SideEffect visitSwitch(wasm::Switch *curr) {
visit(curr->condition);
if (curr->value) {
visit(curr->value);
}
return SideEffect::write;
}
SideEffect visitCall(wasm::Call *curr) {
for (const auto &expr : curr->operands) {
visit(expr);
}
return SideEffect::write;
}
SideEffect visitCallIndirect(wasm::CallIndirect *curr) {
visit(curr->target);
for (const auto &expr : curr->operands) {
visit(expr);
}
return SideEffect::write;
}
SideEffect visitGetLocal([[maybe_unused]] wasm::GetLocal *curr) {
return SideEffect::read;
}
SideEffect visitSetLocal(wasm::SetLocal *curr) {
visit(curr->value);
return SideEffect::write;
}
SideEffect visitGetGlobal([[maybe_unused]] wasm::GetGlobal *curr) {
return SideEffect::read;
}
SideEffect visitSetGlobal(wasm::SetGlobal *curr) {
visit(curr->value);
return SideEffect::write;
}
SideEffect visitLoad(wasm::Load *curr) {
return (std::max)(visit(curr->ptr), SideEffect::read);
}
SideEffect visitStore(wasm::Store *curr) {
visit(curr->ptr);
visit(curr->value);
return SideEffect::write;
}
SideEffect visitAtomicRMW(wasm::AtomicRMW *curr) {
visit(curr->ptr);
visit(curr->value);
return SideEffect::write;
}
SideEffect visitAtomicCmpxchg(wasm::AtomicCmpxchg *curr) {
visit(curr->ptr);
visit(curr->expected);
visit(curr->replacement);
return SideEffect::write;
}
SideEffect visitAtomicWait(wasm::AtomicWait *curr) {
visit(curr->ptr);
visit(curr->expected);
visit(curr->timeout);
return SideEffect::write;
}
SideEffect visitAtomicNotify(wasm::AtomicNotify *curr) {
visit(curr->ptr);
visit(curr->notifyCount);
return SideEffect::write;
}
SideEffect visitSIMDExtract(wasm::SIMDExtract *curr) {
return visit(curr->vec);
}
SideEffect visitSIMDReplace(wasm::SIMDReplace *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->vec), side_effect);
side_effect = (std::max)(visit(curr->value), side_effect);
return side_effect;
}
SideEffect visitSIMDShuffle(wasm::SIMDShuffle *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->left), side_effect);
side_effect = (std::max)(visit(curr->right), side_effect);
return side_effect;
}
SideEffect visitSIMDBitselect(wasm::SIMDBitselect *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->cond), side_effect);
side_effect = (std::max)(visit(curr->left), side_effect);
side_effect = (std::max)(visit(curr->right), side_effect);
return side_effect;
}
SideEffect visitSIMDShift(wasm::SIMDShift *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->vec), side_effect);
side_effect = (std::max)(visit(curr->shift), side_effect);
return side_effect;
}
SideEffect visitMemoryInit(wasm::MemoryInit *curr) {
visit(curr->dest);
visit(curr->offset);
visit(curr->size);
return SideEffect::write;
}
SideEffect visitDataDrop([[maybe_unused]] wasm::DataDrop *curr) {
return SideEffect::write;
}
SideEffect visitMemoryCopy(wasm::MemoryCopy *curr) {
visit(curr->dest);
visit(curr->source);
visit(curr->size);
return SideEffect::write;
}
SideEffect visitMemoryFill(wasm::MemoryFill *curr) {
visit(curr->dest);
visit(curr->value);
visit(curr->size);
return SideEffect::write;
}
SideEffect visitConst([[maybe_unused]] wasm::Const *curr) {
return SideEffect::none;
}
SideEffect visitUnary(wasm::Unary *curr) {
return visit(curr->value);
}
SideEffect visitBinary(wasm::Binary *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->left), side_effect);
side_effect = (std::max)(visit(curr->right), side_effect);
return side_effect;
}
SideEffect visitSelect(wasm::Select *curr) {
auto side_effect = SideEffect::none;
side_effect = (std::max)(visit(curr->condition), side_effect);
side_effect = (std::max)(visit(curr->ifTrue), side_effect);
side_effect = (std::max)(visit(curr->ifFalse), side_effect);
return side_effect;
}
SideEffect visitDrop(wasm::Drop *curr) {
return visit(curr->value);
}
SideEffect visitReturn(wasm::Return *curr) {
if (curr->value) {
visit(curr->value);
}
return SideEffect::write;
}
SideEffect visitHost(wasm::Host *curr) {
for (const auto &expr : curr->operands) {
visit(expr);
}
if (curr->op == wasm::HostOp::CurrentMemory) {
return SideEffect::read;
} else {
return SideEffect::write;
}
}
SideEffect visitNop([[maybe_unused]] wasm::Nop *curr) {
return SideEffect::none;
}
SideEffect visitUnreachable([[maybe_unused]] wasm::Unreachable *curr) {
WASM_UNREACHABLE();
}
SideEffect visitFunctionType([[maybe_unused]] wasm::FunctionType *curr) {
WASM_UNREACHABLE();
}
SideEffect visitExport([[maybe_unused]] wasm::Export *curr) {
WASM_UNREACHABLE();
}
SideEffect visitGlobal([[maybe_unused]] wasm::Global *curr) {
WASM_UNREACHABLE();
}
SideEffect visitFunction([[maybe_unused]] wasm::Function *curr) {
WASM_UNREACHABLE();
}
SideEffect visitTable([[maybe_unused]] wasm::Table *curr) {
WASM_UNREACHABLE();
}
SideEffect visitMemory([[maybe_unused]] wasm::Memory *curr) {
WASM_UNREACHABLE();
}
SideEffect visitModule([[maybe_unused]] wasm::Module *curr) {
WASM_UNREACHABLE();
}
};
} // namespace kyut::pass
#endif // INCLUDE_kyut_pass_side_effect_checker_hpp
#include <iostream>
#include <fmt/ostream.h>
#include <wasm-io.h>
#include "kyut/pass/operand_swapping_extractor.hpp"
int main(int argc, char *argv[]) {
try {
if (argc != 2) {
fmt::print(std::cerr,
"WebAssembly digital watermark extractor.\n"
"usage: pisn <input file>\n");
return 1;
}
wasm::Module module;
wasm::ModuleReader{}.read(argv[1], module);
auto writer = kyut::BitWriter{};
kyut::pass::extractWatermarkOperandSwapping(module, writer);
for (const auto byte : writer.bytes()) {
fmt::print("{:02X} ", byte);
}
} catch (wasm::ParseException &e) {
fmt::print(std::cerr, "parse error\n");
e.dump(std::cerr);
return 1;
} catch (const std::exception &e) {
fmt::print(std::cerr, "error: {}\n", e.what());
return 1;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment