wasm: make Runner dynamically instansiate Modules for making things massively parallel

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso
2025-04-14 08:29:11 -04:00
parent 72d6eda7de
commit 5610b026cc
2 changed files with 192 additions and 147 deletions
+125 -46
View File
@@ -5,18 +5,36 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math"
"os" "os"
"strconv"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/api"
) )
func UpdateNonce(uint32) {} func UpdateNonce(uint32) {}
var (
validationTime = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "anubis_wasm_validation_time",
Help: "The time taken for the validation function to run per checker (nanoseconds)",
Buckets: prometheus.ExponentialBucketsRange(1, math.Pow(2, 31), 32),
}, []string{"fname"})
validationCount = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "anubis_wasm_validation",
Help: "The number of times the validation logic has been run and its success rate",
}, []string{"fname", "success"})
)
type Runner struct { type Runner struct {
r wazero.Runtime r wazero.Runtime
code wazero.CompiledModule code wazero.CompiledModule
module api.Module fname string
} }
func NewRunner(ctx context.Context, fname string, fin io.ReadCloser) (*Runner, error) { func NewRunner(ctx context.Context, fname string, fin io.ReadCloser) (*Runner, error) {
@@ -41,25 +59,16 @@ func NewRunner(ctx context.Context, fname string, fin io.ReadCloser) (*Runner, e
return nil, fmt.Errorf("wasm: can't compile module: %w", err) return nil, fmt.Errorf("wasm: can't compile module: %w", err)
} }
mod, err := r.InstantiateModule(ctx, code, wazero.NewModuleConfig().WithName(fname))
if err != nil {
return nil, fmt.Errorf("wasm: can't instantiate module: %w", err)
}
result := &Runner{ result := &Runner{
r: r, r: r,
code: code, code: code,
module: mod, fname: fname,
}
if err := result.checkExports(); err != nil {
return nil, fmt.Errorf("wasm: module is missing exports: %w", err)
} }
return result, nil return result, nil
} }
func (r *Runner) checkExports() error { func (r *Runner) checkExports(module api.Module) error {
funcs := []string{ funcs := []string{
"anubis_work", "anubis_work",
"anubis_validate", "anubis_validate",
@@ -74,7 +83,7 @@ func (r *Runner) checkExports() error {
var errs []error var errs []error
for _, fun := range funcs { for _, fun := range funcs {
if r.module.ExportedFunction(fun) == nil { if module.ExportedFunction(fun) == nil {
errs = append(errs, fmt.Errorf("function %s is not defined", fun)) errs = append(errs, fmt.Errorf("function %s is not defined", fun))
} }
} }
@@ -86,8 +95,8 @@ func (r *Runner) checkExports() error {
return nil return nil
} }
func (r *Runner) anubisWork(ctx context.Context, difficulty, initialNonce, iterand uint32) (uint32, error) { func (r *Runner) anubisWork(ctx context.Context, module api.Module, difficulty, initialNonce, iterand uint32) (uint32, error) {
results, err := r.module.ExportedFunction("anubis_work").Call(ctx, uint64(difficulty), uint64(initialNonce), uint64(iterand)) results, err := module.ExportedFunction("anubis_work").Call(ctx, uint64(difficulty), uint64(initialNonce), uint64(iterand))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -95,8 +104,8 @@ func (r *Runner) anubisWork(ctx context.Context, difficulty, initialNonce, itera
return uint32(results[0]), nil return uint32(results[0]), nil
} }
func (r *Runner) anubisValidate(ctx context.Context, nonce, difficulty uint32) (bool, error) { func (r *Runner) anubisValidate(ctx context.Context, module api.Module, nonce, difficulty uint32) (bool, error) {
results, err := r.module.ExportedFunction("anubis_validate").Call(ctx, uint64(nonce), uint64(difficulty)) results, err := module.ExportedFunction("anubis_validate").Call(ctx, uint64(nonce), uint64(difficulty))
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -105,8 +114,8 @@ func (r *Runner) anubisValidate(ctx context.Context, nonce, difficulty uint32) (
return results[0] == 1, nil return results[0] == 1, nil
} }
func (r *Runner) dataPtr(ctx context.Context) (uint32, error) { func (r *Runner) dataPtr(ctx context.Context, module api.Module) (uint32, error) {
results, err := r.module.ExportedFunction("data_ptr").Call(ctx) results, err := module.ExportedFunction("data_ptr").Call(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -114,13 +123,13 @@ func (r *Runner) dataPtr(ctx context.Context) (uint32, error) {
return uint32(results[0]), nil return uint32(results[0]), nil
} }
func (r *Runner) setDataLength(ctx context.Context, length uint32) error { func (r *Runner) setDataLength(ctx context.Context, module api.Module, length uint32) error {
_, err := r.module.ExportedFunction("set_data_length").Call(ctx, uint64(length)) _, err := module.ExportedFunction("set_data_length").Call(ctx, uint64(length))
return err return err
} }
func (r *Runner) resultHashPtr(ctx context.Context) (uint32, error) { func (r *Runner) resultHashPtr(ctx context.Context, module api.Module) (uint32, error) {
results, err := r.module.ExportedFunction("result_hash_ptr").Call(ctx) results, err := module.ExportedFunction("result_hash_ptr").Call(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -128,8 +137,8 @@ func (r *Runner) resultHashPtr(ctx context.Context) (uint32, error) {
return uint32(results[0]), nil return uint32(results[0]), nil
} }
func (r *Runner) resultHashSize(ctx context.Context) (uint32, error) { func (r *Runner) resultHashSize(ctx context.Context, module api.Module) (uint32, error) {
results, err := r.module.ExportedFunction("result_hash_size").Call(ctx) results, err := module.ExportedFunction("result_hash_size").Call(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -137,8 +146,8 @@ func (r *Runner) resultHashSize(ctx context.Context) (uint32, error) {
return uint32(results[0]), nil return uint32(results[0]), nil
} }
func (r *Runner) verificationHashPtr(ctx context.Context) (uint32, error) { func (r *Runner) verificationHashPtr(ctx context.Context, module api.Module) (uint32, error) {
results, err := r.module.ExportedFunction("verification_hash_ptr").Call(ctx) results, err := module.ExportedFunction("verification_hash_ptr").Call(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -146,8 +155,8 @@ func (r *Runner) verificationHashPtr(ctx context.Context) (uint32, error) {
return uint32(results[0]), nil return uint32(results[0]), nil
} }
func (r *Runner) verificationHashSize(ctx context.Context) (uint32, error) { func (r *Runner) verificationHashSize(ctx context.Context, module api.Module) (uint32, error) {
results, err := r.module.ExportedFunction("verification_hash_size").Call(ctx) results, err := module.ExportedFunction("verification_hash_size").Call(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -155,41 +164,41 @@ func (r *Runner) verificationHashSize(ctx context.Context) (uint32, error) {
return uint32(results[0]), nil return uint32(results[0]), nil
} }
func (r *Runner) WriteData(ctx context.Context, data []byte) (uint32, error) { func (r *Runner) writeData(ctx context.Context, module api.Module, data []byte) error {
if len(data) > 4096 { if len(data) > 4096 {
return 0, os.ErrInvalid return os.ErrInvalid
} }
length := uint32(len(data)) length := uint32(len(data))
dataPtr, err := r.dataPtr(ctx) dataPtr, err := r.dataPtr(ctx, module)
if err != nil { if err != nil {
return 0, fmt.Errorf("can't read data pointer: %w", err) return fmt.Errorf("can't read data pointer: %w", err)
} }
if !r.module.Memory().Write(dataPtr, data) { if !module.Memory().Write(dataPtr, data) {
return 0, fmt.Errorf("[unexpected] can't write memory, is data out of range??") return fmt.Errorf("[unexpected] can't write memory, is data out of range??")
} }
if err := r.setDataLength(ctx, length); err != nil { if err := r.setDataLength(ctx, module, length); err != nil {
return 0, fmt.Errorf("can't set data length: %w", err) return fmt.Errorf("can't set data length: %w", err)
} }
return length, nil return nil
} }
func (r *Runner) ReadResult(ctx context.Context) ([]byte, error) { func (r *Runner) readResult(ctx context.Context, module api.Module) ([]byte, error) {
length, err := r.resultHashSize(ctx) length, err := r.resultHashSize(ctx, module)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't get result hash size: %w", err) return nil, fmt.Errorf("can't get result hash size: %w", err)
} }
ptr, err := r.resultHashPtr(ctx) ptr, err := r.resultHashPtr(ctx, module)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't get result hash pointer: %w", err) return nil, fmt.Errorf("can't get result hash pointer: %w", err)
} }
buf, ok := r.module.Memory().Read(ptr, length) buf, ok := module.Memory().Read(ptr, length)
if !ok { if !ok {
return nil, fmt.Errorf("[unexpected] can't read from memory, is something out of range??") return nil, fmt.Errorf("[unexpected] can't read from memory, is something out of range??")
} }
@@ -197,8 +206,78 @@ func (r *Runner) ReadResult(ctx context.Context) ([]byte, error) {
return buf, nil return buf, nil
} }
func (r *Runner) WriteVerification(ctx context.Context, data []byte) error { func (r *Runner) run(ctx context.Context, data []byte, difficulty, initialNonce, iterand uint32) (uint32, []byte, api.Module, error) {
length, err := r.verificationHashSize(ctx) mod, err := r.r.InstantiateModule(ctx, r.code, wazero.NewModuleConfig().WithName(r.fname))
if err != nil {
return 0, nil, nil, fmt.Errorf("can't instantiate module: %w", err)
}
if err := r.checkExports(mod); err != nil {
return 0, nil, nil, err
}
if err := r.writeData(ctx, mod, data); err != nil {
return 0, nil, nil, err
}
nonce, err := r.anubisWork(ctx, mod, difficulty, initialNonce, iterand)
if err != nil {
return 0, nil, nil, fmt.Errorf("can't run work function: %w", err)
}
hash, err := r.readResult(ctx, mod)
if err != nil {
return 0, nil, nil, fmt.Errorf("can't read result: %w", err)
}
return nonce, hash, mod, nil
}
func (r *Runner) Run(ctx context.Context, data []byte, difficulty, initialNonce, iterand uint32) (uint32, []byte, error) {
nonce, hash, _, err := r.run(ctx, data, difficulty, initialNonce, iterand)
if err != nil {
return 0, nil, fmt.Errorf("can't run %s: %w", r.fname, err)
}
return nonce, hash, nil
}
func (r *Runner) verify(ctx context.Context, data, verify []byte, nonce, difficulty uint32) (bool, api.Module, error) {
mod, err := r.r.InstantiateModule(ctx, r.code, wazero.NewModuleConfig().WithName(r.fname))
if err != nil {
return false, nil, fmt.Errorf("can't instantiate module: %w", err)
}
if err := r.checkExports(mod); err != nil {
return false, nil, err
}
if err := r.writeData(ctx, mod, data); err != nil {
return false, nil, err
}
if err := r.writeVerification(ctx, mod, verify); err != nil {
return false, nil, err
}
ok, err := r.anubisValidate(ctx, mod, nonce, difficulty)
if err != nil {
return false, nil, fmt.Errorf("can't validate hash %x from challenge %x, nonce %d and difficulty %d: %w", verify, data, nonce, difficulty, err)
}
return ok, mod, nil
}
func (r *Runner) Verify(ctx context.Context, data, verify []byte, nonce, difficulty uint32) (bool, error) {
t0 := time.Now()
ok, _, err := r.verify(ctx, data, verify, nonce, difficulty)
validationTime.WithLabelValues(r.fname).Observe(float64(time.Since(t0)))
validationCount.WithLabelValues(r.fname, strconv.FormatBool(ok))
return ok, err
}
func (r *Runner) writeVerification(ctx context.Context, module api.Module, data []byte) error {
length, err := r.verificationHashSize(ctx, module)
if err != nil { if err != nil {
return fmt.Errorf("can't get verification hash size: %v", err) return fmt.Errorf("can't get verification hash size: %v", err)
} }
@@ -207,12 +286,12 @@ func (r *Runner) WriteVerification(ctx context.Context, data []byte) error {
return fmt.Errorf("data is too big, want %d bytes, got: %d", length, len(data)) return fmt.Errorf("data is too big, want %d bytes, got: %d", length, len(data))
} }
ptr, err := r.verificationHashPtr(ctx) ptr, err := r.verificationHashPtr(ctx, module)
if err != nil { if err != nil {
return fmt.Errorf("can't get verification hash pointer: %v", err) return fmt.Errorf("can't get verification hash pointer: %v", err)
} }
if !r.module.Memory().Write(ptr, data) { if !module.Memory().Write(ptr, data) {
return fmt.Errorf("[unexpected] can't write memory, is data out of range??") return fmt.Errorf("[unexpected] can't write memory, is data out of range??")
} }
+58 -92
View File
@@ -5,14 +5,13 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io/fs" "io/fs"
"os"
"testing" "testing"
"time" "time"
"github.com/TecharoHQ/anubis/web" "github.com/TecharoHQ/anubis/web"
) )
func abiTest(t *testing.T, fname string, difficulty uint32) { func abiTest(t testing.TB, fname string, difficulty uint32) {
fin, err := web.Static.Open("static/wasm/" + fname) fin, err := web.Static.Open("static/wasm/" + fname)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -30,31 +29,16 @@ func abiTest(t *testing.T, fname string, difficulty uint32) {
fmt.Fprint(h, t.Name()) fmt.Fprint(h, t.Name())
data := h.Sum(nil) data := h.Sum(nil)
if n, err := runner.WriteData(ctx, data); err != nil { nonce, hash, mod, err := runner.run(ctx, data, difficulty, 0, 1)
t.Fatalf("can't write data: %v", err)
} else {
t.Logf("wrote %d bytes to data segment", n)
}
t0 := time.Now()
nonce, err := runner.anubisWork(ctx, difficulty, 0, 1)
if err != nil { if err != nil {
t.Fatalf("can't do test work run: %v", err) t.Fatal(err)
}
t.Logf("got nonce %d in %s", nonce, time.Since(t0))
hash, err := runner.ReadResult(ctx)
if err != nil {
t.Fatalf("can't read result: %v", err)
} }
t.Logf("got hash %x", hash) if err := runner.writeVerification(ctx, mod, hash); err != nil {
if err := runner.WriteVerification(ctx, hash); err != nil {
t.Fatalf("can't write verification: %v", err) t.Fatalf("can't write verification: %v", err)
} }
ok, err := runner.anubisValidate(ctx, nonce, difficulty) ok, err := runner.anubisValidate(ctx, mod, nonce, difficulty)
if err != nil { if err != nil {
t.Fatalf("can't run validation: %v", err) t.Fatalf("can't run validation: %v", err)
} }
@@ -63,7 +47,7 @@ func abiTest(t *testing.T, fname string, difficulty uint32) {
t.Error("validation failed") t.Error("validation failed")
} }
t.Logf("used %d pages of wasm memory (%d bytes)", runner.module.Memory().Size()/63356, runner.module.Memory().Size()) t.Logf("used %d pages of wasm memory (%d bytes)", mod.Memory().Size()/63356, mod.Memory().Size())
} }
func TestAlgos(t *testing.T) { func TestAlgos(t *testing.T) {
@@ -81,6 +65,8 @@ func TestAlgos(t *testing.T) {
} }
func bench(b *testing.B, fname string, difficulties []uint32) { func bench(b *testing.B, fname string, difficulties []uint32) {
b.Helper()
fin, err := web.Static.Open("static/wasm/" + fname) fin, err := web.Static.Open("static/wasm/" + fname)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@@ -98,64 +84,54 @@ func bench(b *testing.B, fname string, difficulties []uint32) {
fmt.Fprint(h, "This is an example value that exists only to test the system.") fmt.Fprint(h, "This is an example value that exists only to test the system.")
data := h.Sum(nil) data := h.Sum(nil)
if n, err := runner.WriteData(ctx, data); err != nil { _, _, mod, err := runner.run(ctx, data, 0, 0, 1)
b.Fatalf("can't write data: %v", err) if err != nil {
} else { b.Fatal(err)
b.Logf("wrote %d bytes to data segment", n) }
for _, difficulty := range difficulties {
b.Run(fmt.Sprintf("difficulty/%d", difficulty), func(b *testing.B) {
for b.Loop() {
difficulty := difficulty
_, err := runner.anubisWork(ctx, mod, difficulty, 0, 1)
if err != nil {
b.Fatalf("can't do test work run: %v", err)
}
}
})
} }
} }
func BenchmarkSHA256(b *testing.B) { func BenchmarkSHA256(b *testing.B) {
fin, err := web.Static.Open("static/wasm/sha256.wasm") bench(b, "sha256.wasm", []uint32{4, 6, 8, 10, 12, 14, 16})
if err != nil {
b.Fatal(err)
}
defer fin.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
b.Cleanup(cancel)
runner, err := NewRunner(ctx, "sha256.wasm", fin)
if err != nil {
b.Fatal(err)
}
h := sha256.New()
fmt.Fprint(h, "testificate")
data := h.Sum(nil)
if n, err := runner.WriteData(ctx, data); err != nil {
b.Fatalf("can't write data: %v", err)
} else {
b.Logf("wrote %d bytes to data segment", n)
}
for _, cs := range []struct {
Difficulty uint32
}{
{4},
{6},
{8},
{10},
{12},
{14},
{16},
} {
b.Run(fmt.Sprintf("difficulty/%d", cs.Difficulty), func(b *testing.B) {
for b.Loop() {
difficulty := cs.Difficulty
_, err := runner.anubisWork(ctx, difficulty, 0, 1)
if err != nil {
b.Fatalf("can't do test work run: %v", err)
}
}
})
}
} }
func BenchmarkArgon2ID(b *testing.B) { func BenchmarkArgon2ID(b *testing.B) {
const difficulty = 4 // one nibble, intentionally easy for testing bench(b, "argon2id.wasm", []uint32{4, 6, 8})
}
fin, err := web.Static.Open("static/wasm/argon2id.wasm") func BenchmarkValidate(b *testing.B) {
fnames, err := fs.ReadDir(web.Static, "static/wasm")
if err != nil {
b.Fatal(err)
}
h := sha256.New()
fmt.Fprint(h, "This is an example value that exists only to test the system.")
data := h.Sum(nil)
for _, fname := range fnames {
fname := fname.Name()
difficulty := uint32(1)
switch fname {
case "sha256.wasm":
difficulty = 16
}
b.Run(fname, func(b *testing.B) {
fin, err := web.Static.Open("static/wasm/" + fname)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -163,34 +139,24 @@ func BenchmarkArgon2ID(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
b.Cleanup(cancel) b.Cleanup(cancel)
runner, err := NewRunner(ctx, "argon2id.wasm", fin) runner, err := NewRunner(ctx, fname, fin)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
h := sha256.New() nonce, hash, mod, err := runner.run(ctx, data, difficulty, 0, 1)
fmt.Fprint(h, os.Args[0]) if err != nil {
data := h.Sum(nil) b.Fatal(err)
}
if n, err := runner.WriteData(ctx, data); err != nil {
b.Fatalf("can't write data: %v", err) if err := runner.writeVerification(ctx, mod, hash); err != nil {
} else { b.Fatalf("can't write verification: %v", err)
b.Logf("wrote %d bytes to data segment", n)
} }
for _, cs := range []struct {
Difficulty uint32
}{
{4},
{6},
{8},
} {
b.Run(fmt.Sprintf("difficulty/%d", cs.Difficulty), func(b *testing.B) {
for b.Loop() { for b.Loop() {
difficulty := cs.Difficulty _, err := runner.anubisValidate(ctx, mod, nonce, difficulty)
_, err := runner.anubisWork(ctx, difficulty, 0, 1)
if err != nil { if err != nil {
b.Fatalf("can't do test work run: %v", err) b.Fatalf("can't run validation: %v", err)
} }
} }
}) })