Commit 8a26c932 authored by nagayama15's avatar nagayama15

Rewrite embedder using visitor

parent ff48a66b
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <wasm-traversal.h>
#include "../BitStreamWriter.hpp" #include "../BitStreamWriter.hpp"
#include "../CircularBitStreamReader.hpp" #include "../CircularBitStreamReader.hpp"
...@@ -317,269 +319,260 @@ namespace kyut::watermarker { ...@@ -317,269 +319,260 @@ namespace kyut::watermarker {
} }
// Watermark embedder // Watermark embedder
SideEffect embedExpression(wasm::Expression *expr, CircularBitStreamReader &stream); struct EmbeddingVisitor : wasm::OverriddenVisitor<EmbeddingVisitor, SideEffect> {
CircularBitStreamReader &stream;
SideEffect embedExpressionList(const wasm::ExpressionList &exprs, CircularBitStreamReader &stream) {
auto effect = SideEffect::none;
for (const auto expr : exprs) {
effect = (std::max)(embedExpression(expr, stream), effect);
}
return effect; explicit EmbeddingVisitor(CircularBitStreamReader &stream)
} : stream(stream) {}
SideEffect embedBlock(wasm::Block &expr, CircularBitStreamReader &stream) { SideEffect visitExpressionList(const wasm::ExpressionList &exprs) {
return embedExpressionList(expr.list, stream); auto effect = SideEffect::none;
}
SideEffect embedIf(wasm::If &expr, CircularBitStreamReader &stream) { for (const auto expr : exprs) {
return (std::max)({ effect = (std::max)(visit(expr), effect);
embedExpression(expr.condition, stream), }
embedExpression(expr.ifTrue, stream),
embedExpression(expr.ifFalse, stream),
});
}
SideEffect embedLoop(wasm::Loop &expr, CircularBitStreamReader &stream) { return effect;
return embedExpression(expr.body, stream); }
}
SideEffect embedBreak(wasm::Break &expr, CircularBitStreamReader &stream) { SideEffect visitBlock(wasm::Block *expr) {
embedExpression(expr.value, stream); return visitExpressionList(expr->list);
embedExpression(expr.condition, stream); }
return SideEffect::write; SideEffect visitIf(wasm::If *expr) {
} return (std::max)({
visit(expr->condition),
visit(expr->ifTrue),
visit(expr->ifFalse),
});
}
SideEffect embedSwitch(wasm::Switch &expr, CircularBitStreamReader &stream) { SideEffect visitLoop(wasm::Loop *expr) {
return (std::max)(embedExpression(expr.condition, stream), embedExpression(expr.value, stream)); return visit(expr->body);
} }
SideEffect embedCall(wasm::Call &expr, CircularBitStreamReader &stream) { SideEffect visitBreak(wasm::Break *expr) {
embedExpressionList(expr.operands, stream); visit(expr->value);
visit(expr->condition);
// It is difficult to estimate the side effects of the function calls return SideEffect::write;
return SideEffect::write; }
}
SideEffect embedCallIndirect(wasm::CallIndirect &expr, CircularBitStreamReader &stream) { SideEffect visitSwitch(wasm::Switch *expr) {
embedExpression(expr.target, stream); return (std::max)(visit(expr->condition), visit(expr->value));
embedExpressionList(expr.operands, stream); }
// It is difficult to estimate the side effects of the function calls SideEffect visitCall(wasm::Call *expr) {
return SideEffect::write; visitExpressionList(expr->operands);
}
SideEffect embedGetLocal([[maybe_unused]] wasm::GetLocal &expr, // It is difficult to estimate the side effects of the function calls
[[maybe_unused]] CircularBitStreamReader &stream) { return SideEffect::write;
return SideEffect::readOnly; }
}
SideEffect embedSetLocal(wasm::SetLocal &expr, CircularBitStreamReader &stream) { SideEffect visitCallIndirect(wasm::CallIndirect *expr) {
embedExpression(expr.value, stream); visit(expr->target);
visitExpressionList(expr->operands);
return SideEffect::write; // It is difficult to estimate the side effects of the function calls
} return SideEffect::write;
}
SideEffect embedGetGlobal([[maybe_unused]] wasm::GetGlobal &expr, SideEffect visitGetLocal([[maybe_unused]] wasm::GetLocal *expr) {
[[maybe_unused]] CircularBitStreamReader &stream) { return SideEffect::readOnly;
return SideEffect::readOnly; }
}
SideEffect embedSetGlobal(wasm::SetGlobal &expr, CircularBitStreamReader &stream) { SideEffect visitSetLocal(wasm::SetLocal *expr) {
embedExpression(expr.value, stream); visit(expr->value);
return SideEffect::write; return SideEffect::write;
} }
SideEffect embedLoad(wasm::Load &expr, CircularBitStreamReader &stream) { SideEffect visitGetGlobal([[maybe_unused]] wasm::GetGlobal *expr) {
return (std::max)(embedExpression(expr.ptr, stream), SideEffect::readOnly); return SideEffect::readOnly;
} }
SideEffect embedStore(wasm::Store &expr, CircularBitStreamReader &stream) { SideEffect visitSetGlobal(wasm::SetGlobal *expr) {
embedExpression(expr.ptr, stream); visit(expr->value);
embedExpression(expr.value, stream);
return SideEffect::write; return SideEffect::write;
} }
SideEffect embedConst([[maybe_unused]] wasm::Const &expr, [[maybe_unused]] CircularBitStreamReader &stream) { SideEffect visitLoad(wasm::Load *expr) {
return SideEffect::none; return (std::max)(visit(expr->ptr), SideEffect::readOnly);
} }
SideEffect embedUnary(wasm::Unary &expr, CircularBitStreamReader &stream) { SideEffect visitStore(wasm::Store *expr) {
return embedExpression(expr.value, stream); visit(expr->ptr);
} visit(expr->value);
SideEffect embedBinary(wasm::Binary &expr, CircularBitStreamReader &stream) { return SideEffect::write;
if (!isCommutative(expr.op)) {
// The operands of noncommutative instructions cannot be swapped
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream));
} }
if (!(*expr.left < *expr.right) && !(*expr.right < *expr.left)) { SideEffect visitConst([[maybe_unused]] wasm::Const *expr) {
// If both sides are the same or cannot be ordered, skip embedding return SideEffect::none;
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream));
} }
// Sort both of the operands SideEffect visitUnary(wasm::Unary *expr) {
auto [lo, hi] = std::minmax(expr.left, expr.right, [](auto a, auto b) { return *a < *b; }); return visit(expr->value);
}
const auto loEffect = embedExpression(lo, stream); SideEffect visitBinary(wasm::Binary *expr) {
const auto hiEffect = embedExpression(hi, stream); if (!isCommutative(expr->op)) {
// The operands of noncommutative instructions cannot be swapped
return (std::max)(visit(expr->left), visit(expr->right));
}
if (static_cast<std::uint32_t>(loEffect) + static_cast<std::uint32_t>(hiEffect) >= 3) { if (!(*expr->left < *expr->right) && !(*expr->right < *expr->left)) {
// The operands have side effect and cannot be swapped // If both sides are the same or cannot be ordered, skip visitding
return (std::max)(loEffect, hiEffect); return (std::max)(visit(expr->left), visit(expr->right));
} }
// Embed watermarks by swapping operands // Sort both of the operands
const bool bit = stream.readBit(); auto [lo, hi] = std::minmax(expr->left, expr->right, [](auto a, auto b) { return *a < *b; });
if (bit == (expr.left == lo)) { const auto loEffect = visit(lo);
swapOperands(expr); const auto hiEffect = visit(hi);
}
return (std::max)(loEffect, hiEffect); if (static_cast<std::uint32_t>(loEffect) + static_cast<std::uint32_t>(hiEffect) >= 3) {
} // The operands have side effect and cannot be swapped
return (std::max)(loEffect, hiEffect);
}
SideEffect embedSelect(wasm::Select &expr, CircularBitStreamReader &stream) { // Embed watermarks by swapping operands
return (std::max)({ const bool bit = stream.readBit();
embedExpression(expr.condition, stream),
embedExpression(expr.ifTrue, stream),
embedExpression(expr.ifFalse, stream),
});
}
SideEffect embedDrop(wasm::Drop &expr, CircularBitStreamReader &stream) { if (bit == (expr->left == lo)) {
return embedExpression(expr.value, stream); swapOperands(*expr);
} }
SideEffect embedReturn(wasm::Return &expr, CircularBitStreamReader &stream) { return (std::max)(loEffect, hiEffect);
embedExpression(expr.value, stream); }
return SideEffect::write; SideEffect visitSelect(wasm::Select *expr) {
} return (std::max)({
visit(expr->condition),
visit(expr->ifTrue),
visit(expr->ifFalse),
});
}
SideEffect embedHost(wasm::Host &expr, CircularBitStreamReader &stream) { SideEffect visitDrop(wasm::Drop *expr) {
embedExpressionList(expr.operands, stream); return visit(expr->value);
}
return SideEffect::write; SideEffect visitReturn(wasm::Return *expr) {
} visit(expr->value);
SideEffect embedNop([[maybe_unused]] wasm::Nop &expr, [[maybe_unused]] CircularBitStreamReader &stream) { return SideEffect::write;
return SideEffect::none; }
}
SideEffect embedUnreachable([[maybe_unused]] wasm::Unreachable &expr, SideEffect visitHost(wasm::Host *expr) {
[[maybe_unused]] CircularBitStreamReader &stream) { visitExpressionList(expr->operands);
return SideEffect::write;
}
SideEffect embedAtomicRMW(wasm::AtomicRMW &expr, CircularBitStreamReader &stream) { return SideEffect::write;
embedExpression(expr.ptr, stream); }
embedExpression(expr.value, stream);
return SideEffect::write; SideEffect visitNop([[maybe_unused]] wasm::Nop *exp) {
} return SideEffect::none;
}
SideEffect embedAtomicCmpxchg(wasm::AtomicCmpxchg &expr, CircularBitStreamReader &stream) { SideEffect visitUnreachable([[maybe_unused]] wasm::Unreachable *expr) {
embedExpression(expr.ptr, stream); return SideEffect::write;
embedExpression(expr.expected, stream); }
embedExpression(expr.replacement, stream);
return SideEffect::write; SideEffect visitAtomicRMW(wasm::AtomicRMW *expr) {
} visit(expr->ptr);
visit(expr->value);
SideEffect embedAtomicWait(wasm::AtomicWait &expr, CircularBitStreamReader &stream) { return SideEffect::write;
embedExpression(expr.ptr, stream); }
embedExpression(expr.expected, stream);
embedExpression(expr.timeout, stream);
return SideEffect::write; SideEffect visitAtomicCmpxchg(wasm::AtomicCmpxchg *expr) {
} visit(expr->ptr);
visit(expr->expected);
visit(expr->replacement);
SideEffect embedAtomicNotify(wasm::AtomicNotify &expr, CircularBitStreamReader &stream) { return SideEffect::write;
embedExpression(expr.ptr, stream); }
embedExpression(expr.notifyCount, stream);
return SideEffect::write; SideEffect visitAtomicWait(wasm::AtomicWait *expr) {
} visit(expr->ptr);
visit(expr->expected);
visit(expr->timeout);
SideEffect embedSIMDExtract(wasm::SIMDExtract &expr, CircularBitStreamReader &stream) { return SideEffect::write;
return embedExpression(expr.vec, stream); }
}
SideEffect embedSIMDReplace(wasm::SIMDReplace &expr, CircularBitStreamReader &stream) { SideEffect visitAtomicNotify(wasm::AtomicNotify *expr) {
return (std::max)(embedExpression(expr.vec, stream), embedExpression(expr.value, stream)); visit(expr->ptr);
} visit(expr->notifyCount);
SideEffect embedSIMDShuffle(wasm::SIMDShuffle &expr, CircularBitStreamReader &stream) { return SideEffect::write;
return (std::max)(embedExpression(expr.left, stream), embedExpression(expr.right, stream)); }
}
SideEffect embedSIMDBitselect(wasm::SIMDBitselect &expr, CircularBitStreamReader &stream) { SideEffect visitSIMDExtract(wasm::SIMDExtract *expr) {
return (std::max)({ return visit(expr->vec);
embedExpression(expr.cond, stream), }
embedExpression(expr.left, stream),
embedExpression(expr.right, stream),
});
}
SideEffect embedSIMDShift(wasm::SIMDShift &expr, CircularBitStreamReader &stream) { SideEffect visitSIMDReplace(wasm::SIMDReplace *expr) {
return (std::max)(embedExpression(expr.vec, stream), embedExpression(expr.shift, stream)); return (std::max)(visit(expr->vec), visit(expr->value));
} }
SideEffect embedMemoryInit(wasm::MemoryInit &expr, CircularBitStreamReader &stream) { SideEffect visitSIMDShuffle(wasm::SIMDShuffle *expr) {
embedExpression(expr.dest, stream); return (std::max)(visit(expr->left), visit(expr->right));
embedExpression(expr.offset, stream); }
embedExpression(expr.size, stream);
return SideEffect::write; SideEffect visitSIMDBitselect(wasm::SIMDBitselect *expr) {
} return (std::max)({
visit(expr->cond),
visit(expr->left),
visit(expr->right),
});
}
SideEffect embedDataDrop([[maybe_unused]] wasm::DataDrop &expr, SideEffect visitSIMDShift(wasm::SIMDShift *expr) {
[[maybe_unused]] CircularBitStreamReader &stream) { return (std::max)(visit(expr->vec), visit(expr->shift));
return SideEffect::write; }
}
SideEffect embedMemoryCopy(wasm::MemoryCopy &expr, CircularBitStreamReader &stream) { SideEffect visitMemoryInit(wasm::MemoryInit *expr) {
embedExpression(expr.dest, stream); visit(expr->dest);
embedExpression(expr.source, stream); visit(expr->offset);
embedExpression(expr.size, stream); visit(expr->size);
return SideEffect::write; return SideEffect::write;
} }
SideEffect embedMemoryFill(wasm::MemoryFill &expr, CircularBitStreamReader &stream) { SideEffect visitDataDrop([[maybe_unused]] wasm::DataDrop *expr) {
embedExpression(expr.dest, stream); return SideEffect::write;
embedExpression(expr.value, stream); }
embedExpression(expr.size, stream);
return SideEffect::write; SideEffect visitMemoryCopy(wasm::MemoryCopy *expr) {
} visit(expr->dest);
visit(expr->source);
visit(expr->size);
SideEffect embedExpression(wasm::Expression *expr, CircularBitStreamReader &stream) { return SideEffect::write;
if (expr == nullptr) {
return SideEffect::none;
} }
switch (expr->_id) { SideEffect visitMemoryFill(wasm::MemoryFill *expr) {
#define EXPR_TYPE(name) \ visit(expr->dest);
case ::wasm::Expression::name##Id: \ visit(expr->value);
return embed##name(*expr->cast<::wasm::name>(), stream); visit(expr->size);
EXPR_TYPES()
#undef EXPR_TYPE
default: return SideEffect::write;
WASM_UNREACHABLE();
} }
}
void embedFunction(wasm::Function &function, CircularBitStreamReader &stream) { SideEffect visit(wasm::Expression *expr) {
embedExpression(function.body, stream); if (expr == nullptr) {
} return SideEffect::none;
}
return OverriddenVisitor::visit(expr);
}
void visitFunction(wasm::Function *function) {
visit(function->body);
}
};
} // namespace } // namespace
std::size_t embedOperandSwapping(wasm::Module &module, CircularBitStreamReader &stream) { std::size_t embedOperandSwapping(wasm::Module &module, CircularBitStreamReader &stream) {
...@@ -596,10 +589,11 @@ namespace kyut::watermarker { ...@@ -596,10 +589,11 @@ namespace kyut::watermarker {
std::begin(functions), std::end(functions), [](const auto a, const auto b) { return a->name < b->name; }); std::begin(functions), std::end(functions), [](const auto a, const auto b) { return a->name < b->name; });
// Embed watermarks // Embed watermarks
std::size_t posStart = stream.tell(); EmbeddingVisitor embedder{stream};
const auto posStart = stream.tell();
for (const auto f : functions) { for (const auto f : functions) {
embedFunction(*f, stream); embedder.visitFunction(f);
} }
return stream.tell() - posStart; return stream.tell() - posStart;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment