Commit b241ef19 authored by nagayama15's avatar nagayama15

Complete operand swapping method

parent ea7361dd
...@@ -18,12 +18,16 @@ int main(int argc, char *argv[]) { ...@@ -18,12 +18,16 @@ int main(int argc, char *argv[]) {
wasm::Module module; wasm::Module module;
wasm::ModuleReader{}.read(argv[1], module); wasm::ModuleReader{}.read(argv[1], module);
kyut::pass::embedWatermarkOperandSwapping(module); const auto stream = kyut::CircularBitStream::from_string(argv[2]);
kyut::pass::embedWatermarkOperandSwapping(module, *stream);
wasm::ModuleWriter{}.writeText(module, argv[3]); wasm::ModuleWriter{}.writeText(module, argv[3]);
} catch (wasm::ParseException &e) { } catch (wasm::ParseException &e) {
fmt::print(std::cerr, "parse error\n"); fmt::print(std::cerr, "parse error\n");
e.dump(std::cerr); e.dump(std::cerr);
return 1;
} catch (const std::exception &e) { } catch (const std::exception &e) {
fmt::print(std::cerr, "error: {}\n", e.what()); fmt::print(std::cerr, "error: {}\n", e.what());
......
...@@ -37,12 +37,16 @@ namespace kyut { ...@@ -37,12 +37,16 @@ namespace kyut {
return size_bytes() * 8; return size_bytes() * 8;
} }
[[nodiscard]] std::size_t pos_bits() const noexcept {
return pos_bits_;
}
bool read_bit() noexcept { bool read_bit() noexcept {
if (size_bits() == 0) { if (size_bits() == 0) {
return false; return false;
} }
const auto value = bytes_[pos_bits_ >> 3] >> (pos_bits_ & 0x7); const auto value = (bytes_[pos_bits_ >> 3] >> (pos_bits_ & 0x7)) & 0x1;
if (++pos_bits_ >= size_bits()) { if (++pos_bits_ >= size_bits()) {
pos_bits_ = 0; pos_bits_ = 0;
......
...@@ -2,11 +2,15 @@ ...@@ -2,11 +2,15 @@
#include <pass.h> #include <pass.h>
#include "../commutativity.hpp"
#include "../comparison.hpp"
namespace kyut::pass { namespace kyut::pass {
class OperandSwappingWatermarkingVisitor class OperandSwappingWatermarkingVisitor
: public wasm::OverriddenVisitor<OperandSwappingWatermarkingVisitor, bool> { : public wasm::OverriddenVisitor<OperandSwappingWatermarkingVisitor, bool> {
public: public:
explicit OperandSwappingWatermarkingVisitor() = default; explicit OperandSwappingWatermarkingVisitor(CircularBitStream &stream)
: stream_(stream) {}
OperandSwappingWatermarkingVisitor(const OperandSwappingWatermarkingVisitor &) = delete; OperandSwappingWatermarkingVisitor(const OperandSwappingWatermarkingVisitor &) = delete;
OperandSwappingWatermarkingVisitor(OperandSwappingWatermarkingVisitor &&) = delete; OperandSwappingWatermarkingVisitor(OperandSwappingWatermarkingVisitor &&) = delete;
...@@ -217,6 +221,14 @@ namespace kyut::pass { ...@@ -217,6 +221,14 @@ namespace kyut::pass {
has_side_effect = visit(curr->left) || has_side_effect; has_side_effect = visit(curr->left) || has_side_effect;
has_side_effect = visit(curr->right) || has_side_effect; has_side_effect = visit(curr->right) || has_side_effect;
if (isCommutative(curr->op) && !has_side_effect) {
const auto bit = stream_.read_bit();
if (bit == (*curr->left < *curr->right)) {
swapOperands(*curr);
}
}
return has_side_effect; return has_side_effect;
} }
...@@ -278,11 +290,15 @@ namespace kyut::pass { ...@@ -278,11 +290,15 @@ namespace kyut::pass {
bool visitModule([[maybe_unused]] wasm::Module *curr) { bool visitModule([[maybe_unused]] wasm::Module *curr) {
WASM_UNREACHABLE(); WASM_UNREACHABLE();
} }
private:
CircularBitStream &stream_;
}; };
class OperandSwappingWatermarkingPass : public wasm::Pass { class OperandSwappingWatermarkingPass : public wasm::Pass {
public: public:
explicit OperandSwappingWatermarkingPass() = default; explicit OperandSwappingWatermarkingPass(CircularBitStream &stream)
: stream_(stream) {}
OperandSwappingWatermarkingPass(const OperandSwappingWatermarkingPass &) = delete; OperandSwappingWatermarkingPass(const OperandSwappingWatermarkingPass &) = delete;
OperandSwappingWatermarkingPass(OperandSwappingWatermarkingPass &&) = delete; OperandSwappingWatermarkingPass(OperandSwappingWatermarkingPass &&) = delete;
...@@ -297,17 +313,20 @@ namespace kyut::pass { ...@@ -297,17 +313,20 @@ namespace kyut::pass {
} }
void run([[maybe_unused]] wasm::PassRunner *runner, wasm::Module *module) override { void run([[maybe_unused]] wasm::PassRunner *runner, wasm::Module *module) override {
OperandSwappingWatermarkingVisitor visitor{}; OperandSwappingWatermarkingVisitor visitor{stream_};
for (const auto &func : module->functions) { for (const auto &func : module->functions) {
visitor.visit(func->body); visitor.visit(func->body);
} }
} }
private:
CircularBitStream &stream_;
}; };
void embedWatermarkOperandSwapping(wasm::Module &module) { void embedWatermarkOperandSwapping(wasm::Module &module, CircularBitStream &stream) {
wasm::PassRunner runner{&module}; wasm::PassRunner runner{&module};
runner.add<OperandSwappingWatermarkingPass>(); runner.add<OperandSwappingWatermarkingPass>(std::ref(stream));
runner.run(); runner.run();
} }
} // namespace kyut::pass } // namespace kyut::pass
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#include <wasm.h> #include <wasm.h>
#include "../circular_bit_stream.hpp"
namespace kyut::pass { namespace kyut::pass {
void embedWatermarkOperandSwapping(wasm::Module &module); void embedWatermarkOperandSwapping(wasm::Module &module, CircularBitStream &stream);
} }
#endif // INCLUDE_kyut_pass_operand_swapping_watermarker_hpp #endif // INCLUDE_kyut_pass_operand_swapping_watermarker_hpp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment