Commit f649775a authored by nagayama15's avatar nagayama15

Implement the extractor using operands swapping method

parent 8a26c932
...@@ -573,6 +573,258 @@ namespace kyut::watermarker { ...@@ -573,6 +573,258 @@ namespace kyut::watermarker {
visit(function->body); visit(function->body);
} }
}; };
// Watermark extractor
struct ExtractingVisitor : wasm::OverriddenVisitor<ExtractingVisitor, SideEffect> {
BitStreamWriter &stream;
explicit ExtractingVisitor(BitStreamWriter &stream)
: stream(stream) {}
SideEffect visitExpressionList(const wasm::ExpressionList &exprs) {
auto effect = SideEffect::none;
for (const auto expr : exprs) {
effect = (std::max)(visit(expr), effect);
}
return effect;
}
SideEffect visitBlock(wasm::Block *expr) {
return visitExpressionList(expr->list);
}
SideEffect visitIf(wasm::If *expr) {
return (std::max)({
visit(expr->condition),
visit(expr->ifTrue),
visit(expr->ifFalse),
});
}
SideEffect visitLoop(wasm::Loop *expr) {
return visit(expr->body);
}
SideEffect visitBreak(wasm::Break *expr) {
visit(expr->value);
visit(expr->condition);
return SideEffect::write;
}
SideEffect visitSwitch(wasm::Switch *expr) {
return (std::max)(visit(expr->condition), visit(expr->value));
}
SideEffect visitCall(wasm::Call *expr) {
visitExpressionList(expr->operands);
// It is difficult to estimate the side effects of the function calls
return SideEffect::write;
}
SideEffect visitCallIndirect(wasm::CallIndirect *expr) {
visit(expr->target);
visitExpressionList(expr->operands);
// It is difficult to estimate the side effects of the function calls
return SideEffect::write;
}
SideEffect visitGetLocal([[maybe_unused]] wasm::GetLocal *expr) {
return SideEffect::readOnly;
}
SideEffect visitSetLocal(wasm::SetLocal *expr) {
visit(expr->value);
return SideEffect::write;
}
SideEffect visitGetGlobal([[maybe_unused]] wasm::GetGlobal *expr) {
return SideEffect::readOnly;
}
SideEffect visitSetGlobal(wasm::SetGlobal *expr) {
visit(expr->value);
return SideEffect::write;
}
SideEffect visitLoad(wasm::Load *expr) {
return (std::max)(visit(expr->ptr), SideEffect::readOnly);
}
SideEffect visitStore(wasm::Store *expr) {
visit(expr->ptr);
visit(expr->value);
return SideEffect::write;
}
SideEffect visitConst([[maybe_unused]] wasm::Const *expr) {
return SideEffect::none;
}
SideEffect visitUnary(wasm::Unary *expr) {
return visit(expr->value);
}
SideEffect visitBinary(wasm::Binary *expr) {
if (!isCommutative(expr->op)) {
// The operands of noncommutative instructions cannot be swapped
return (std::max)(visit(expr->left), visit(expr->right));
}
if (!(*expr->left < *expr->right) && !(*expr->right < *expr->left)) {
// If both sides are the same or cannot be ordered, skip visitding
return (std::max)(visit(expr->left), visit(expr->right));
}
// Sort both of the operands
auto [lo, hi] = std::minmax(expr->left, expr->right, [](auto a, auto b) { return *a < *b; });
const auto loEffect = visit(lo);
const auto hiEffect = visit(hi);
if (static_cast<std::uint32_t>(loEffect) + static_cast<std::uint32_t>(hiEffect) >= 3) {
// The operands have side effect and cannot be swapped
return (std::max)(loEffect, hiEffect);
}
// Extract watermarks from operands orders
stream.writeBit(lo == expr->right);
return (std::max)(loEffect, hiEffect);
}
SideEffect visitSelect(wasm::Select *expr) {
return (std::max)({
visit(expr->condition),
visit(expr->ifTrue),
visit(expr->ifFalse),
});
}
SideEffect visitDrop(wasm::Drop *expr) {
return visit(expr->value);
}
SideEffect visitReturn(wasm::Return *expr) {
visit(expr->value);
return SideEffect::write;
}
SideEffect visitHost(wasm::Host *expr) {
visitExpressionList(expr->operands);
return SideEffect::write;
}
SideEffect visitNop([[maybe_unused]] wasm::Nop *exp) {
return SideEffect::none;
}
SideEffect visitUnreachable([[maybe_unused]] wasm::Unreachable *expr) {
return SideEffect::write;
}
SideEffect visitAtomicRMW(wasm::AtomicRMW *expr) {
visit(expr->ptr);
visit(expr->value);
return SideEffect::write;
}
SideEffect visitAtomicCmpxchg(wasm::AtomicCmpxchg *expr) {
visit(expr->ptr);
visit(expr->expected);
visit(expr->replacement);
return SideEffect::write;
}
SideEffect visitAtomicWait(wasm::AtomicWait *expr) {
visit(expr->ptr);
visit(expr->expected);
visit(expr->timeout);
return SideEffect::write;
}
SideEffect visitAtomicNotify(wasm::AtomicNotify *expr) {
visit(expr->ptr);
visit(expr->notifyCount);
return SideEffect::write;
}
SideEffect visitSIMDExtract(wasm::SIMDExtract *expr) {
return visit(expr->vec);
}
SideEffect visitSIMDReplace(wasm::SIMDReplace *expr) {
return (std::max)(visit(expr->vec), visit(expr->value));
}
SideEffect visitSIMDShuffle(wasm::SIMDShuffle *expr) {
return (std::max)(visit(expr->left), visit(expr->right));
}
SideEffect visitSIMDBitselect(wasm::SIMDBitselect *expr) {
return (std::max)({
visit(expr->cond),
visit(expr->left),
visit(expr->right),
});
}
SideEffect visitSIMDShift(wasm::SIMDShift *expr) {
return (std::max)(visit(expr->vec), visit(expr->shift));
}
SideEffect visitMemoryInit(wasm::MemoryInit *expr) {
visit(expr->dest);
visit(expr->offset);
visit(expr->size);
return SideEffect::write;
}
SideEffect visitDataDrop([[maybe_unused]] wasm::DataDrop *expr) {
return SideEffect::write;
}
SideEffect visitMemoryCopy(wasm::MemoryCopy *expr) {
visit(expr->dest);
visit(expr->source);
visit(expr->size);
return SideEffect::write;
}
SideEffect visitMemoryFill(wasm::MemoryFill *expr) {
visit(expr->dest);
visit(expr->value);
visit(expr->size);
return SideEffect::write;
}
SideEffect visit(wasm::Expression *expr) {
if (expr == nullptr) {
return SideEffect::none;
}
return OverriddenVisitor::visit(expr);
}
void visitFunction(wasm::Function *function) {
visit(function->body);
}
};
} // namespace } // namespace
std::size_t embedOperandSwapping(wasm::Module &module, CircularBitStreamReader &stream) { std::size_t embedOperandSwapping(wasm::Module &module, CircularBitStreamReader &stream) {
...@@ -589,8 +841,8 @@ namespace kyut::watermarker { ...@@ -589,8 +841,8 @@ namespace kyut::watermarker {
std::begin(functions), std::end(functions), [](const auto a, const auto b) { return a->name < b->name; }); std::begin(functions), std::end(functions), [](const auto a, const auto b) { return a->name < b->name; });
// Embed watermarks // Embed watermarks
EmbeddingVisitor embedder{stream};
const auto posStart = stream.tell(); const auto posStart = stream.tell();
EmbeddingVisitor embedder{stream};
for (const auto f : functions) { for (const auto f : functions) {
embedder.visitFunction(f); embedder.visitFunction(f);
...@@ -600,10 +852,27 @@ namespace kyut::watermarker { ...@@ -600,10 +852,27 @@ namespace kyut::watermarker {
} }
std::size_t extractOperandSwapping(wasm::Module &module, BitStreamWriter &stream) { std::size_t extractOperandSwapping(wasm::Module &module, BitStreamWriter &stream) {
(void)module; // Sort functions in the module by name
(void)stream; std::vector<wasm::Function *> functions;
functions.reserve(module.functions.size());
WASM_UNREACHABLE(); std::transform(std::begin(module.functions),
std::end(module.functions),
std::back_inserter(functions),
[](const auto &f) { return f.get(); });
std::sort(
std::begin(functions), std::end(functions), [](const auto a, const auto b) { return a->name < b->name; });
// Extract watermarks
const auto posStart = stream.tell();
ExtractingVisitor extractor{stream};
for (const auto f : functions) {
extractor.visitFunction(f);
}
return stream.tell() - posStart;
} }
} // namespace kyut::watermarker } // namespace kyut::watermarker
......
...@@ -136,5 +136,27 @@ BOOST_AUTO_TEST_CASE(embed_operand_swapping_110111) { ...@@ -136,5 +136,27 @@ BOOST_AUTO_TEST_CASE(embed_operand_swapping_110111) {
} }
} }
BOOST_AUTO_TEST_CASE(extract_operand_swapping) {
for (std::uint8_t i = 0; i < 64; i++) {
const std::uint8_t x = i << 2;
wasm::Module module;
wasm::ModuleReader{}.read(KYUT_TEST_SOURCE_DIR "/example/test2.wast", module);
// Embed x
CircularBitStreamReader s{{x}};
const auto numBitsEmbedded = embedOperandSwapping(module, s);
BitStreamWriter w;
const auto numBitsExtracted = extractOperandSwapping(module, w);
BOOST_REQUIRE_EQUAL(numBitsEmbedded, std::size_t{6});
BOOST_REQUIRE_EQUAL(numBitsExtracted, std::size_t{6});
BOOST_REQUIRE_EQUAL(w.tell(), std::size_t{6});
BOOST_REQUIRE_EQUAL(w.data().size(), std::size_t{1});
BOOST_REQUIRE_EQUAL(w.data()[0], std::uint8_t{x});
}
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment