Commit 5d78143f authored by nagayama15's avatar nagayama15

Implement function ordering watermark extractor

parent 233da654
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include "../BitStreamWriter.hpp"
#include "../CircularBitStreamReader.hpp" #include "../CircularBitStreamReader.hpp"
namespace kyut::watermarker { namespace kyut::watermarker {
...@@ -77,4 +78,61 @@ namespace kyut::watermarker { ...@@ -77,4 +78,61 @@ namespace kyut::watermarker {
return numBits; return numBits;
} }
std::size_t extractFunctionOrdering(wasm::Module &module, BitStreamWriter &stream, std::size_t maxChunkSize) {
assert(2 <= maxChunkSize && maxChunkSize < 21 && "because 21! > 2^64");
// Number of bits extracted in the module
std::size_t numBits = 0;
// Split according to the function in the module has body or not
// [begin, start) has no body, and [start, end) has
const auto start = std::partition(
std::begin(module.functions), std::end(module.functions), [](const auto &f) { return f->body == nullptr; });
const size_t count = std::distance(start, std::end(module.functions));
for (size_t i = 0; i < count; i += maxChunkSize) {
const auto chunkSize = (std::min)(maxChunkSize, count - i);
const auto chunkBegin = start + i;
const auto chunkEnd = chunkBegin + chunkSize;
// Number of bits embedded in the chunk
const auto numBitsEmbeddedInChunk = factorialBitLengthTable[chunkSize];
// Extract watermarks from the chunk
std::vector<wasm::Function *> functions;
functions.reserve(chunkSize);
std::transform(chunkBegin, chunkEnd, std::back_inserter(functions), [](const auto &f) { return f.get(); });
std::sort(std::begin(functions), std::end(functions), [](const auto &a, const auto &b) {
return a->name < b->name;
});
std::int64_t watermark = 0;
std::int64_t base = 1;
for (auto it = chunkBegin; it != chunkEnd; ++it) {
// Get index of the function `*it`
const auto pos = std::find_if(
std::begin(functions), std::end(functions), [it](const auto &f) { return f == it->get(); });
assert(pos != std::end(functions));
const std::size_t index = std::distance(std::begin(functions), pos);
watermark += index * base;
base *= functions.size();
// Remove the function found in this step
functions.erase(pos);
}
stream.write(watermark, numBitsEmbeddedInChunk);
numBits += numBitsEmbeddedInChunk;
}
return numBits;
}
} // namespace kyut::watermarker } // namespace kyut::watermarker
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
#include <wasm.h> #include <wasm.h>
namespace kyut { namespace kyut {
class BitStreamWriter;
class CircularBitStreamReader; class CircularBitStreamReader;
} } // namespace kyut
namespace kyut::watermarker { namespace kyut::watermarker {
/** /**
...@@ -17,6 +18,16 @@ namespace kyut::watermarker { ...@@ -17,6 +18,16 @@ namespace kyut::watermarker {
* @return Number of watermark bits embedded in the module. * @return Number of watermark bits embedded in the module.
*/ */
std::size_t embedFunctionOrdering(wasm::Module &module, CircularBitStreamReader &stream, std::size_t maxChunkSize); std::size_t embedFunctionOrdering(wasm::Module &module, CircularBitStreamReader &stream, std::size_t maxChunkSize);
/**
* @brief Extract watermarks by changing the order of functions.
*
* @param module
* @param stream Output stream to save the watermarks.
* @param maxChunkSize The maximum number of functions in the watermark chunk.
* @return Number of watermark bits extracted from the module.
*/
std::size_t extractFunctionOrdering(wasm::Module &module, BitStreamWriter &stream, std::size_t maxChunkSize);
} // namespace kyut::watermarker } // namespace kyut::watermarker
#endif // INCLUDE_kyut_watermark_FunctionOrderingWatermarker_hpp #endif // INCLUDE_kyut_watermark_FunctionOrderingWatermarker_hpp
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <wasm-io.h> #include <wasm-io.h>
#include <kyut/BitStreamWriter.hpp>
#include <kyut/CircularBitStreamReader.hpp> #include <kyut/CircularBitStreamReader.hpp>
BOOST_AUTO_TEST_SUITE(kyut) BOOST_AUTO_TEST_SUITE(kyut)
...@@ -83,6 +84,75 @@ BOOST_AUTO_TEST_CASE(embed_function_ordering) { ...@@ -83,6 +84,75 @@ BOOST_AUTO_TEST_CASE(embed_function_ordering) {
} }
} }
BOOST_AUTO_TEST_CASE(extract_function_ordering) {
wasm::Module module;
wasm::ModuleReader{}.read(KYUT_TEST_SOURCE_DIR "/example/test1.wast", module);
// Embed 0b00
{
CircularBitStreamReader s{{0b0000'0000}};
const auto numBitsEmbedded = embedFunctionOrdering(module, s, 10);
BOOST_REQUIRE_EQUAL(numBitsEmbedded, std::size_t{2});
BitStreamWriter w;
const auto numBitsExtracted = extractFunctionOrdering(module, w, 10);
BOOST_REQUIRE_EQUAL(numBitsExtracted, std::size_t{2});
BOOST_REQUIRE_EQUAL(w.tell(), std::size_t{2});
BOOST_REQUIRE_EQUAL(w.data().size(), std::size_t{1});
BOOST_REQUIRE_EQUAL(w.data()[0], std::uint8_t{0b0000'0000});
}
// Embed 0b01
{
CircularBitStreamReader s{{0b0100'0000}};
const auto numBitsEmbedded = embedFunctionOrdering(module, s, 10);
BOOST_REQUIRE_EQUAL(numBitsEmbedded, std::size_t{2});
BitStreamWriter w;
const auto numBitsExtracted = extractFunctionOrdering(module, w, 10);
BOOST_REQUIRE_EQUAL(numBitsExtracted, std::size_t{2});
BOOST_REQUIRE_EQUAL(w.tell(), std::size_t{2});
BOOST_REQUIRE_EQUAL(w.data().size(), std::size_t{1});
BOOST_REQUIRE_EQUAL(w.data()[0], std::uint8_t{0b0100'0000});
}
// Embed 0b10
{
CircularBitStreamReader s{{0b1000'0000}};
const auto numBitsEmbedded = embedFunctionOrdering(module, s, 10);
BOOST_REQUIRE_EQUAL(numBitsEmbedded, std::size_t{2});
BitStreamWriter w;
const auto numBitsExtracted = extractFunctionOrdering(module, w, 10);
BOOST_REQUIRE_EQUAL(numBitsExtracted, std::size_t{2});
BOOST_REQUIRE_EQUAL(w.tell(), std::size_t{2});
BOOST_REQUIRE_EQUAL(w.data().size(), std::size_t{1});
BOOST_REQUIRE_EQUAL(w.data()[0], std::uint8_t{0b1000'0000});
}
// Embed 0b11
{
CircularBitStreamReader s{{0b1100'0000}};
const auto numBitsEmbedded = embedFunctionOrdering(module, s, 10);
BOOST_REQUIRE_EQUAL(numBitsEmbedded, std::size_t{2});
BitStreamWriter w;
const auto numBitsExtracted = extractFunctionOrdering(module, w, 10);
BOOST_REQUIRE_EQUAL(numBitsExtracted, std::size_t{2});
BOOST_REQUIRE_EQUAL(w.tell(), std::size_t{2});
BOOST_REQUIRE_EQUAL(w.data().size(), std::size_t{1});
BOOST_REQUIRE_EQUAL(w.data()[0], std::uint8_t{0b1100'0000});
}
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment