From 5610b026cc78918beec0d5c62ded984caaf27d31 Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Mon, 14 Apr 2025 08:29:11 -0400 Subject: [PATCH] wasm: make Runner dynamically instansiate Modules for making things massively parallel Signed-off-by: Xe Iaso --- wasm/wasm.go | 179 +++++++++++++++++++++++++++++++++------------- wasm/wasm_test.go | 160 ++++++++++++++++------------------------- 2 files changed, 192 insertions(+), 147 deletions(-) diff --git a/wasm/wasm.go b/wasm/wasm.go index 8f1aeb23..742b7feb 100644 --- a/wasm/wasm.go +++ b/wasm/wasm.go @@ -5,18 +5,36 @@ import ( "errors" "fmt" "io" + "math" "os" + "strconv" + "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" ) func UpdateNonce(uint32) {} +var ( + validationTime = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "anubis_wasm_validation_time", + Help: "The time taken for the validation function to run per checker (nanoseconds)", + Buckets: prometheus.ExponentialBucketsRange(1, math.Pow(2, 31), 32), + }, []string{"fname"}) + + validationCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "anubis_wasm_validation", + Help: "The number of times the validation logic has been run and its success rate", + }, []string{"fname", "success"}) +) + type Runner struct { - r wazero.Runtime - code wazero.CompiledModule - module api.Module + r wazero.Runtime + code wazero.CompiledModule + fname string } func NewRunner(ctx context.Context, fname string, fin io.ReadCloser) (*Runner, error) { @@ -41,25 +59,16 @@ func NewRunner(ctx context.Context, fname string, fin io.ReadCloser) (*Runner, e return nil, fmt.Errorf("wasm: can't compile module: %w", err) } - mod, err := r.InstantiateModule(ctx, code, wazero.NewModuleConfig().WithName(fname)) - if err != nil { - return nil, fmt.Errorf("wasm: can't instantiate module: %w", err) - } - result := &Runner{ - r: r, - code: code, - module: mod, - } - - if err := result.checkExports(); err != nil { - return nil, fmt.Errorf("wasm: module is missing exports: %w", err) + r: r, + code: code, + fname: fname, } return result, nil } -func (r *Runner) checkExports() error { +func (r *Runner) checkExports(module api.Module) error { funcs := []string{ "anubis_work", "anubis_validate", @@ -74,7 +83,7 @@ func (r *Runner) checkExports() error { var errs []error for _, fun := range funcs { - if r.module.ExportedFunction(fun) == nil { + if module.ExportedFunction(fun) == nil { errs = append(errs, fmt.Errorf("function %s is not defined", fun)) } } @@ -86,8 +95,8 @@ func (r *Runner) checkExports() error { return nil } -func (r *Runner) anubisWork(ctx context.Context, difficulty, initialNonce, iterand uint32) (uint32, error) { - results, err := r.module.ExportedFunction("anubis_work").Call(ctx, uint64(difficulty), uint64(initialNonce), uint64(iterand)) +func (r *Runner) anubisWork(ctx context.Context, module api.Module, difficulty, initialNonce, iterand uint32) (uint32, error) { + results, err := module.ExportedFunction("anubis_work").Call(ctx, uint64(difficulty), uint64(initialNonce), uint64(iterand)) if err != nil { return 0, err } @@ -95,8 +104,8 @@ func (r *Runner) anubisWork(ctx context.Context, difficulty, initialNonce, itera return uint32(results[0]), nil } -func (r *Runner) anubisValidate(ctx context.Context, nonce, difficulty uint32) (bool, error) { - results, err := r.module.ExportedFunction("anubis_validate").Call(ctx, uint64(nonce), uint64(difficulty)) +func (r *Runner) anubisValidate(ctx context.Context, module api.Module, nonce, difficulty uint32) (bool, error) { + results, err := module.ExportedFunction("anubis_validate").Call(ctx, uint64(nonce), uint64(difficulty)) if err != nil { return false, err } @@ -105,8 +114,8 @@ func (r *Runner) anubisValidate(ctx context.Context, nonce, difficulty uint32) ( return results[0] == 1, nil } -func (r *Runner) dataPtr(ctx context.Context) (uint32, error) { - results, err := r.module.ExportedFunction("data_ptr").Call(ctx) +func (r *Runner) dataPtr(ctx context.Context, module api.Module) (uint32, error) { + results, err := module.ExportedFunction("data_ptr").Call(ctx) if err != nil { return 0, err } @@ -114,13 +123,13 @@ func (r *Runner) dataPtr(ctx context.Context) (uint32, error) { return uint32(results[0]), nil } -func (r *Runner) setDataLength(ctx context.Context, length uint32) error { - _, err := r.module.ExportedFunction("set_data_length").Call(ctx, uint64(length)) +func (r *Runner) setDataLength(ctx context.Context, module api.Module, length uint32) error { + _, err := module.ExportedFunction("set_data_length").Call(ctx, uint64(length)) return err } -func (r *Runner) resultHashPtr(ctx context.Context) (uint32, error) { - results, err := r.module.ExportedFunction("result_hash_ptr").Call(ctx) +func (r *Runner) resultHashPtr(ctx context.Context, module api.Module) (uint32, error) { + results, err := module.ExportedFunction("result_hash_ptr").Call(ctx) if err != nil { return 0, err } @@ -128,8 +137,8 @@ func (r *Runner) resultHashPtr(ctx context.Context) (uint32, error) { return uint32(results[0]), nil } -func (r *Runner) resultHashSize(ctx context.Context) (uint32, error) { - results, err := r.module.ExportedFunction("result_hash_size").Call(ctx) +func (r *Runner) resultHashSize(ctx context.Context, module api.Module) (uint32, error) { + results, err := module.ExportedFunction("result_hash_size").Call(ctx) if err != nil { return 0, err } @@ -137,8 +146,8 @@ func (r *Runner) resultHashSize(ctx context.Context) (uint32, error) { return uint32(results[0]), nil } -func (r *Runner) verificationHashPtr(ctx context.Context) (uint32, error) { - results, err := r.module.ExportedFunction("verification_hash_ptr").Call(ctx) +func (r *Runner) verificationHashPtr(ctx context.Context, module api.Module) (uint32, error) { + results, err := module.ExportedFunction("verification_hash_ptr").Call(ctx) if err != nil { return 0, err } @@ -146,8 +155,8 @@ func (r *Runner) verificationHashPtr(ctx context.Context) (uint32, error) { return uint32(results[0]), nil } -func (r *Runner) verificationHashSize(ctx context.Context) (uint32, error) { - results, err := r.module.ExportedFunction("verification_hash_size").Call(ctx) +func (r *Runner) verificationHashSize(ctx context.Context, module api.Module) (uint32, error) { + results, err := module.ExportedFunction("verification_hash_size").Call(ctx) if err != nil { return 0, err } @@ -155,41 +164,41 @@ func (r *Runner) verificationHashSize(ctx context.Context) (uint32, error) { return uint32(results[0]), nil } -func (r *Runner) WriteData(ctx context.Context, data []byte) (uint32, error) { +func (r *Runner) writeData(ctx context.Context, module api.Module, data []byte) error { if len(data) > 4096 { - return 0, os.ErrInvalid + return os.ErrInvalid } length := uint32(len(data)) - dataPtr, err := r.dataPtr(ctx) + dataPtr, err := r.dataPtr(ctx, module) if err != nil { - return 0, fmt.Errorf("can't read data pointer: %w", err) + return fmt.Errorf("can't read data pointer: %w", err) } - if !r.module.Memory().Write(dataPtr, data) { - return 0, fmt.Errorf("[unexpected] can't write memory, is data out of range??") + if !module.Memory().Write(dataPtr, data) { + return fmt.Errorf("[unexpected] can't write memory, is data out of range??") } - if err := r.setDataLength(ctx, length); err != nil { - return 0, fmt.Errorf("can't set data length: %w", err) + if err := r.setDataLength(ctx, module, length); err != nil { + return fmt.Errorf("can't set data length: %w", err) } - return length, nil + return nil } -func (r *Runner) ReadResult(ctx context.Context) ([]byte, error) { - length, err := r.resultHashSize(ctx) +func (r *Runner) readResult(ctx context.Context, module api.Module) ([]byte, error) { + length, err := r.resultHashSize(ctx, module) if err != nil { return nil, fmt.Errorf("can't get result hash size: %w", err) } - ptr, err := r.resultHashPtr(ctx) + ptr, err := r.resultHashPtr(ctx, module) if err != nil { return nil, fmt.Errorf("can't get result hash pointer: %w", err) } - buf, ok := r.module.Memory().Read(ptr, length) + buf, ok := module.Memory().Read(ptr, length) if !ok { return nil, fmt.Errorf("[unexpected] can't read from memory, is something out of range??") } @@ -197,8 +206,78 @@ func (r *Runner) ReadResult(ctx context.Context) ([]byte, error) { return buf, nil } -func (r *Runner) WriteVerification(ctx context.Context, data []byte) error { - length, err := r.verificationHashSize(ctx) +func (r *Runner) run(ctx context.Context, data []byte, difficulty, initialNonce, iterand uint32) (uint32, []byte, api.Module, error) { + mod, err := r.r.InstantiateModule(ctx, r.code, wazero.NewModuleConfig().WithName(r.fname)) + if err != nil { + return 0, nil, nil, fmt.Errorf("can't instantiate module: %w", err) + } + + if err := r.checkExports(mod); err != nil { + return 0, nil, nil, err + } + + if err := r.writeData(ctx, mod, data); err != nil { + return 0, nil, nil, err + } + + nonce, err := r.anubisWork(ctx, mod, difficulty, initialNonce, iterand) + if err != nil { + return 0, nil, nil, fmt.Errorf("can't run work function: %w", err) + } + + hash, err := r.readResult(ctx, mod) + if err != nil { + return 0, nil, nil, fmt.Errorf("can't read result: %w", err) + } + + return nonce, hash, mod, nil +} + +func (r *Runner) Run(ctx context.Context, data []byte, difficulty, initialNonce, iterand uint32) (uint32, []byte, error) { + nonce, hash, _, err := r.run(ctx, data, difficulty, initialNonce, iterand) + if err != nil { + return 0, nil, fmt.Errorf("can't run %s: %w", r.fname, err) + } + + return nonce, hash, nil +} + +func (r *Runner) verify(ctx context.Context, data, verify []byte, nonce, difficulty uint32) (bool, api.Module, error) { + mod, err := r.r.InstantiateModule(ctx, r.code, wazero.NewModuleConfig().WithName(r.fname)) + if err != nil { + return false, nil, fmt.Errorf("can't instantiate module: %w", err) + } + + if err := r.checkExports(mod); err != nil { + return false, nil, err + } + + if err := r.writeData(ctx, mod, data); err != nil { + return false, nil, err + } + + if err := r.writeVerification(ctx, mod, verify); err != nil { + return false, nil, err + } + + ok, err := r.anubisValidate(ctx, mod, nonce, difficulty) + if err != nil { + return false, nil, fmt.Errorf("can't validate hash %x from challenge %x, nonce %d and difficulty %d: %w", verify, data, nonce, difficulty, err) + } + + return ok, mod, nil +} + +func (r *Runner) Verify(ctx context.Context, data, verify []byte, nonce, difficulty uint32) (bool, error) { + t0 := time.Now() + ok, _, err := r.verify(ctx, data, verify, nonce, difficulty) + validationTime.WithLabelValues(r.fname).Observe(float64(time.Since(t0))) + validationCount.WithLabelValues(r.fname, strconv.FormatBool(ok)) + return ok, err +} + +func (r *Runner) writeVerification(ctx context.Context, module api.Module, data []byte) error { + length, err := r.verificationHashSize(ctx, module) if err != nil { return fmt.Errorf("can't get verification hash size: %v", err) } @@ -207,12 +286,12 @@ func (r *Runner) WriteVerification(ctx context.Context, data []byte) error { return fmt.Errorf("data is too big, want %d bytes, got: %d", length, len(data)) } - ptr, err := r.verificationHashPtr(ctx) + ptr, err := r.verificationHashPtr(ctx, module) if err != nil { return fmt.Errorf("can't get verification hash pointer: %v", err) } - if !r.module.Memory().Write(ptr, data) { + if !module.Memory().Write(ptr, data) { return fmt.Errorf("[unexpected] can't write memory, is data out of range??") } diff --git a/wasm/wasm_test.go b/wasm/wasm_test.go index b60e226e..daaeb259 100644 --- a/wasm/wasm_test.go +++ b/wasm/wasm_test.go @@ -5,14 +5,13 @@ import ( "crypto/sha256" "fmt" "io/fs" - "os" "testing" "time" "github.com/TecharoHQ/anubis/web" ) -func abiTest(t *testing.T, fname string, difficulty uint32) { +func abiTest(t testing.TB, fname string, difficulty uint32) { fin, err := web.Static.Open("static/wasm/" + fname) if err != nil { t.Fatal(err) @@ -30,31 +29,16 @@ func abiTest(t *testing.T, fname string, difficulty uint32) { fmt.Fprint(h, t.Name()) data := h.Sum(nil) - if n, err := runner.WriteData(ctx, data); err != nil { - t.Fatalf("can't write data: %v", err) - } else { - t.Logf("wrote %d bytes to data segment", n) - } - - t0 := time.Now() - nonce, err := runner.anubisWork(ctx, difficulty, 0, 1) + nonce, hash, mod, err := runner.run(ctx, data, difficulty, 0, 1) if err != nil { - t.Fatalf("can't do test work run: %v", err) - } - t.Logf("got nonce %d in %s", nonce, time.Since(t0)) - - hash, err := runner.ReadResult(ctx) - if err != nil { - t.Fatalf("can't read result: %v", err) + t.Fatal(err) } - t.Logf("got hash %x", hash) - - if err := runner.WriteVerification(ctx, hash); err != nil { + if err := runner.writeVerification(ctx, mod, hash); err != nil { t.Fatalf("can't write verification: %v", err) } - ok, err := runner.anubisValidate(ctx, nonce, difficulty) + ok, err := runner.anubisValidate(ctx, mod, nonce, difficulty) if err != nil { t.Fatalf("can't run validation: %v", err) } @@ -63,7 +47,7 @@ func abiTest(t *testing.T, fname string, difficulty uint32) { t.Error("validation failed") } - t.Logf("used %d pages of wasm memory (%d bytes)", runner.module.Memory().Size()/63356, runner.module.Memory().Size()) + t.Logf("used %d pages of wasm memory (%d bytes)", mod.Memory().Size()/63356, mod.Memory().Size()) } func TestAlgos(t *testing.T) { @@ -81,6 +65,8 @@ func TestAlgos(t *testing.T) { } func bench(b *testing.B, fname string, difficulties []uint32) { + b.Helper() + fin, err := web.Static.Open("static/wasm/" + fname) if err != nil { b.Fatal(err) @@ -98,99 +84,79 @@ func bench(b *testing.B, fname string, difficulties []uint32) { fmt.Fprint(h, "This is an example value that exists only to test the system.") data := h.Sum(nil) - if n, err := runner.WriteData(ctx, data); err != nil { - b.Fatalf("can't write data: %v", err) - } else { - b.Logf("wrote %d bytes to data segment", n) + _, _, mod, err := runner.run(ctx, data, 0, 0, 1) + if err != nil { + b.Fatal(err) + } + + for _, difficulty := range difficulties { + b.Run(fmt.Sprintf("difficulty/%d", difficulty), func(b *testing.B) { + for b.Loop() { + difficulty := difficulty + _, err := runner.anubisWork(ctx, mod, difficulty, 0, 1) + if err != nil { + b.Fatalf("can't do test work run: %v", err) + } + } + }) } } func BenchmarkSHA256(b *testing.B) { - fin, err := web.Static.Open("static/wasm/sha256.wasm") - if err != nil { - b.Fatal(err) - } - defer fin.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - b.Cleanup(cancel) - - runner, err := NewRunner(ctx, "sha256.wasm", fin) - if err != nil { - b.Fatal(err) - } - - h := sha256.New() - fmt.Fprint(h, "testificate") - data := h.Sum(nil) - - if n, err := runner.WriteData(ctx, data); err != nil { - b.Fatalf("can't write data: %v", err) - } else { - b.Logf("wrote %d bytes to data segment", n) - } - - for _, cs := range []struct { - Difficulty uint32 - }{ - {4}, - {6}, - {8}, - {10}, - {12}, - {14}, - {16}, - } { - b.Run(fmt.Sprintf("difficulty/%d", cs.Difficulty), func(b *testing.B) { - for b.Loop() { - difficulty := cs.Difficulty - _, err := runner.anubisWork(ctx, difficulty, 0, 1) - if err != nil { - b.Fatalf("can't do test work run: %v", err) - } - } - }) - } + bench(b, "sha256.wasm", []uint32{4, 6, 8, 10, 12, 14, 16}) } func BenchmarkArgon2ID(b *testing.B) { - const difficulty = 4 // one nibble, intentionally easy for testing + bench(b, "argon2id.wasm", []uint32{4, 6, 8}) +} - fin, err := web.Static.Open("static/wasm/argon2id.wasm") - if err != nil { - b.Fatal(err) - } - defer fin.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - b.Cleanup(cancel) - - runner, err := NewRunner(ctx, "argon2id.wasm", fin) +func BenchmarkValidate(b *testing.B) { + fnames, err := fs.ReadDir(web.Static, "static/wasm") if err != nil { b.Fatal(err) } h := sha256.New() - fmt.Fprint(h, os.Args[0]) + fmt.Fprint(h, "This is an example value that exists only to test the system.") data := h.Sum(nil) - if n, err := runner.WriteData(ctx, data); err != nil { - b.Fatalf("can't write data: %v", err) - } else { - b.Logf("wrote %d bytes to data segment", n) - } + for _, fname := range fnames { + fname := fname.Name() + + difficulty := uint32(1) + + switch fname { + case "sha256.wasm": + difficulty = 16 + } + + b.Run(fname, func(b *testing.B) { + fin, err := web.Static.Open("static/wasm/" + fname) + if err != nil { + b.Fatal(err) + } + defer fin.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + b.Cleanup(cancel) + + runner, err := NewRunner(ctx, fname, fin) + if err != nil { + b.Fatal(err) + } + + nonce, hash, mod, err := runner.run(ctx, data, difficulty, 0, 1) + if err != nil { + b.Fatal(err) + } + + if err := runner.writeVerification(ctx, mod, hash); err != nil { + b.Fatalf("can't write verification: %v", err) + } - for _, cs := range []struct { - Difficulty uint32 - }{ - {4}, - {6}, - {8}, - } { - b.Run(fmt.Sprintf("difficulty/%d", cs.Difficulty), func(b *testing.B) { for b.Loop() { - difficulty := cs.Difficulty - _, err := runner.anubisWork(ctx, difficulty, 0, 1) + _, err := runner.anubisValidate(ctx, mod, nonce, difficulty) if err != nil { - b.Fatalf("can't do test work run: %v", err) + b.Fatalf("can't run validation: %v", err) } } })