diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 18a8df77..d6ee12ad 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -48,6 +48,8 @@ jobs: run: | brew bundle + - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: Setup Golang caches uses: actions/cache@v4 with: diff --git a/go.mod b/go.mod index 6bb97417..6641ceb3 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/playwright-community/playwright-go v0.5001.0 github.com/prometheus/client_golang v1.21.1 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a + github.com/tetratelabs/wazero v1.9.0 github.com/yl2chen/cidranger v1.0.2 golang.org/x/net v0.38.0 ) diff --git a/go.sum b/go.sum index 26c90d7f..ad39bb5f 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/yl2chen/cidranger v1.0.2 h1:lbOWZVCG1tCRX4u24kuM1Tb4nHqWkDxwLdoS+SevawU= github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/internal/test/playwright_test.go b/internal/test/playwright_test.go index 88d94bcb..5d96ebb0 100644 --- a/internal/test/playwright_test.go +++ b/internal/test/playwright_test.go @@ -25,6 +25,7 @@ import ( "os" "os/exec" "strconv" + "strings" "testing" "time" @@ -378,7 +379,7 @@ func saveScreenshot(t *testing.T, page playwright.Page) { return } - f, err := os.CreateTemp("", "anubis-test-fail-*.png") + f, err := os.CreateTemp("./var", "anubis-test-fail-"+strings.ReplaceAll(t.Name(), "/", "--")+"-*.png") if err != nil { t.Logf("could not create temporary file: %v", err) return diff --git a/package.json b/package.json index cfff0b65..77e55c4d 100644 --- a/package.json +++ b/package.json @@ -6,7 +6,9 @@ "scripts": { "test": "npm run assets && go test ./...", "test:integration": "npm run assets && go test -v ./internal/test", - "assets": "go generate ./... && ./web/build.sh && ./xess/build.sh", + "assets:frontend": "go generate ./... && ./web/build.sh && ./xess/build.sh", + "assets:wasm": "cargo build --release --target wasm32-unknown-unknown && mv ./target/wasm32-unknown-unknown/release/*.wasm ./web/static/wasm", + "assets": "npm run assets:frontend && npm run assets:wasm", "build": "npm run assets && go build -o ./var/anubis ./cmd/anubis", "dev": "npm run assets && go run ./cmd/anubis --use-remote-address", "container": "npm run assets && go run ./cmd/containerbuild", diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..ead536e7 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "stable" +targets = ["wasm32-unknown-unknown"] +profile = "minimal" diff --git a/wasm/pow/sha256/src/lib.rs b/wasm/pow/sha256/src/lib.rs index fd53bdf9..8179bf35 100644 --- a/wasm/pow/sha256/src/lib.rs +++ b/wasm/pow/sha256/src/lib.rs @@ -2,38 +2,66 @@ use lazy_static::lazy_static; use sha2::{Digest, Sha256}; use std::sync::Mutex; +// Statically allocated buffers at compile time. lazy_static! { - static ref DATA_BUFFER: Mutex<[u8; 1024]> = Mutex::new([0; 1024]); + /// The data buffer is a bit weird in that it doesn't have an explicit length as it can + /// and will change depending on the challenge input that was sent by the server. + /// However, it can only fit 4096 bytes of data (one amd64 machine page). This is + /// slightly overkill for the purposes of an Anubis check, but it's fine to assume + /// that the browser can afford this much ram usage. + /// + /// Callers should fetch the base data pointer, write up to 4096 bytes, and then + /// `set_data_length` the number of bytes they have written + /// + /// This is also functionally a write-only buffer, so it doesn't really matter that + /// the length of this buffer isn't exposed. + static ref DATA_BUFFER: Mutex<[u8; 4096]> = Mutex::new([0; 4096]); static ref DATA_LENGTH: Mutex = Mutex::new(0); + + /// SHA-256 hashes are 32 bytes (256 bits). These are stored in static buffers due to the + /// fact that you cannot easily pass data from host space to WebAssembly space. static ref RESULT_HASH: Mutex<[u8; 32]> = Mutex::new([0; 32]); static ref VERIFICATION_HASH: Mutex<[u8; 32]> = Mutex::new([0; 32]); } -#[link(wasm_import_module = "anubis")] // Usually matches your JS namespace +#[link(wasm_import_module = "anubis")] unsafe extern "C" { - // Declare the imported function + /// The runtime expects this function to be defined. It is called whenever the Anubis check + /// worker processes about 1024 hashes. This can be a no-op if you want. fn anubis_update_nonce(nonce: u32); } +/// Safe wrapper to `anubis_update_nonce`. fn update_nonce(nonce: u32) { unsafe { anubis_update_nonce(nonce); } } -/// Core validation function +/// Core validation function. Compare each bit in the hash by progressively masking bits until +/// some are found to not be matching. +/// +/// There are probably more clever ways to do this, likely involving lookup tables or something +/// really fun like that. However in my testing this lets us get up to 200 kilohashes per second +/// on my Ryzen 7950x3D, up from about 50 kilohashes per second in JavaScript. fn validate(hash: &[u8], difficulty: u32) -> bool { let mut remaining = difficulty; for &byte in hash { + // If we're out of bits to check, exit. This is all good. if remaining == 0 { break; } + + // If there are more than 8 bits remaining, the entire byte should be a + // zero. This fast-path compares the byte to 0 and if it matches, subtract + // 8 bits. if remaining >= 8 { if byte != 0 { return false; } remaining -= 8; } else { + // Otherwise mask off individual bits and check against them. let mask = 0xFF << (8 - remaining); if (byte & mask) != 0 { return false; @@ -44,7 +72,19 @@ fn validate(hash: &[u8], difficulty: u32) -> bool { true } -/// Computes hash for given nonce +/// Computes hash for given nonce. +/// +/// This differs from the JavaScript implementations by constructing the hash differently. In +/// JavaScript implementations, the SHA-256 input is the result of appending the nonce as an +/// integer to the hex-formatted challenge, eg: +/// +/// sha256(`${challenge}${nonce}`); +/// +/// This **does work**, however I think that this can be done a bit better by operating on the +/// challenge bytes _directly_ and treating the nonce as a salt. +/// +/// The nonce is also randomly encoded in either big or little endian depending on the last +/// byte of the data buffer in an effort to make it more annoying to automate with GPUs. fn compute_hash(nonce: u32) -> [u8; 32] { let data = DATA_BUFFER.lock().unwrap(); let data_len = *DATA_LENGTH.lock().unwrap(); @@ -62,8 +102,24 @@ fn compute_hash(nonce: u32) -> [u8; 32] { hasher.finalize().into() } -// WebAssembly exports - +/// This function is the main entrypoint for the Anubis proof of work implementation. +/// +/// This expects `DATA_BUFFER` to be pre-populated with the challenge value as "raw bytes". +/// The definition of what goes in the data buffer is an exercise for the implementor, but +/// for SHA-256 we store the hash as "raw bytes". The data buffer is intentionally oversized +/// so that the challenge value can be expanded in the future. +/// +/// `difficulty` is the number of leading bits that must match `0` in order for the +/// challenge to be successfully passed. This will be validated by the server. +/// +/// `initial_nonce` is the initial value of the nonce (number used once). This nonce will be +/// appended to the challenge value in order to find a hash matching the specified +/// difficulty. +/// +/// `iterand` (noun form of iterate) is the amount that the nonce should be increased by +/// every iteration of the proof of work loop. This will vary by how many threads are +/// running the proof-of-work check, and also functions as a thread ID. This prevents +/// wasting CPU time retrying a hash+nonce pair that likely won't work. #[unsafe(no_mangle)] pub extern "C" fn anubis_work(difficulty: u32, initial_nonce: u32, iterand: u32) -> u32 { let mut nonce = initial_nonce; @@ -72,6 +128,8 @@ pub extern "C" fn anubis_work(difficulty: u32, initial_nonce: u32, iterand: u32) let hash = compute_hash(nonce); if validate(&hash, difficulty) { + // If the challenge worked, copy the bytes into `RESULT_HASH` so the runtime + // can pick it up. let mut challenge = RESULT_HASH.lock().unwrap(); challenge.copy_from_slice(&hash); return nonce; @@ -92,16 +150,30 @@ pub extern "C" fn anubis_work(difficulty: u32, initial_nonce: u32, iterand: u32) } } +/// This function is called by the server in order to validate a proof-of-work challenge. +/// This expects `DATA_BUFFER` to be set to the challenge value and `VERIFICATION_HASH` to +/// be set to the "raw bytes" of the SHA-256 hash that the client calculated. +/// +/// If everything is good, it returns true. Otherwise, it returns false. +/// +/// XXX(Xe): this could probably return an error code for what step fails, but this is fine +/// for now. #[unsafe(no_mangle)] pub extern "C" fn anubis_validate(nonce: u32, difficulty: u32) -> bool { let computed = compute_hash(nonce); let valid = validate(&computed, difficulty); + if !valid { + return false; + } let verification = VERIFICATION_HASH.lock().unwrap(); - valid && computed == *verification + computed == *verification } -// Memory accessors +// These functions exist to give pointers and lengths to the runtime around the Anubis +// checks, this allows JavaScript and Go to safely manipulate the memory layout that Rust +// has statically allocated at compile time without having to assume how the Rust compiler +// is going to lay it out. #[unsafe(no_mangle)] pub extern "C" fn result_hash_ptr() -> *const u8 { @@ -133,7 +205,6 @@ pub extern "C" fn data_ptr() -> *const u8 { #[unsafe(no_mangle)] pub extern "C" fn set_data_length(len: u32) { - // Add missing length setter let mut data_length = DATA_LENGTH.lock().unwrap(); *data_length = len as usize; } diff --git a/wasm/wasm.go b/wasm/wasm.go new file mode 100644 index 00000000..8f1aeb23 --- /dev/null +++ b/wasm/wasm.go @@ -0,0 +1,220 @@ +package wasm + +import ( + "context" + "errors" + "fmt" + "io" + "os" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +func UpdateNonce(uint32) {} + +type Runner struct { + r wazero.Runtime + code wazero.CompiledModule + module api.Module +} + +func NewRunner(ctx context.Context, fname string, fin io.ReadCloser) (*Runner, error) { + data, err := io.ReadAll(fin) + if err != nil { + return nil, fmt.Errorf("wasm: can't read from fin: %w", err) + } + + r := wazero.NewRuntime(ctx) + + _, err = r.NewHostModuleBuilder("anubis"). + NewFunctionBuilder(). + WithFunc(func(context.Context, uint32) {}). + Export("anubis_update_nonce"). + Instantiate(ctx) + if err != nil { + return nil, fmt.Errorf("wasm: can't export anubis_update_nonce: %w", err) + } + + code, err := r.CompileModule(ctx, data) + if err != nil { + return nil, fmt.Errorf("wasm: can't compile module: %w", err) + } + + mod, err := r.InstantiateModule(ctx, code, wazero.NewModuleConfig().WithName(fname)) + if err != nil { + return nil, fmt.Errorf("wasm: can't instantiate module: %w", err) + } + + result := &Runner{ + r: r, + code: code, + module: mod, + } + + if err := result.checkExports(); err != nil { + return nil, fmt.Errorf("wasm: module is missing exports: %w", err) + } + + return result, nil +} + +func (r *Runner) checkExports() error { + funcs := []string{ + "anubis_work", + "anubis_validate", + "data_ptr", + "set_data_length", + "result_hash_ptr", + "result_hash_size", + "verification_hash_ptr", + "verification_hash_size", + } + + var errs []error + + for _, fun := range funcs { + if r.module.ExportedFunction(fun) == nil { + errs = append(errs, fmt.Errorf("function %s is not defined", fun)) + } + } + + if len(errs) != 0 { + return errors.Join(errs...) + } + + return nil +} + +func (r *Runner) anubisWork(ctx context.Context, difficulty, initialNonce, iterand uint32) (uint32, error) { + results, err := r.module.ExportedFunction("anubis_work").Call(ctx, uint64(difficulty), uint64(initialNonce), uint64(iterand)) + if err != nil { + return 0, err + } + + return uint32(results[0]), nil +} + +func (r *Runner) anubisValidate(ctx context.Context, nonce, difficulty uint32) (bool, error) { + results, err := r.module.ExportedFunction("anubis_validate").Call(ctx, uint64(nonce), uint64(difficulty)) + if err != nil { + return false, err + } + + // Rust booleans are 1 if true + return results[0] == 1, nil +} + +func (r *Runner) dataPtr(ctx context.Context) (uint32, error) { + results, err := r.module.ExportedFunction("data_ptr").Call(ctx) + if err != nil { + return 0, err + } + + return uint32(results[0]), nil +} + +func (r *Runner) setDataLength(ctx context.Context, length uint32) error { + _, err := r.module.ExportedFunction("set_data_length").Call(ctx, uint64(length)) + return err +} + +func (r *Runner) resultHashPtr(ctx context.Context) (uint32, error) { + results, err := r.module.ExportedFunction("result_hash_ptr").Call(ctx) + if err != nil { + return 0, err + } + + return uint32(results[0]), nil +} + +func (r *Runner) resultHashSize(ctx context.Context) (uint32, error) { + results, err := r.module.ExportedFunction("result_hash_size").Call(ctx) + if err != nil { + return 0, err + } + + return uint32(results[0]), nil +} + +func (r *Runner) verificationHashPtr(ctx context.Context) (uint32, error) { + results, err := r.module.ExportedFunction("verification_hash_ptr").Call(ctx) + if err != nil { + return 0, err + } + + return uint32(results[0]), nil +} + +func (r *Runner) verificationHashSize(ctx context.Context) (uint32, error) { + results, err := r.module.ExportedFunction("verification_hash_size").Call(ctx) + if err != nil { + return 0, err + } + + return uint32(results[0]), nil +} + +func (r *Runner) WriteData(ctx context.Context, data []byte) (uint32, error) { + if len(data) > 4096 { + return 0, os.ErrInvalid + } + + length := uint32(len(data)) + + dataPtr, err := r.dataPtr(ctx) + if err != nil { + return 0, fmt.Errorf("can't read data pointer: %w", err) + } + + if !r.module.Memory().Write(dataPtr, data) { + return 0, fmt.Errorf("[unexpected] can't write memory, is data out of range??") + } + + if err := r.setDataLength(ctx, length); err != nil { + return 0, fmt.Errorf("can't set data length: %w", err) + } + + return length, nil +} + +func (r *Runner) ReadResult(ctx context.Context) ([]byte, error) { + length, err := r.resultHashSize(ctx) + if err != nil { + return nil, fmt.Errorf("can't get result hash size: %w", err) + } + + ptr, err := r.resultHashPtr(ctx) + if err != nil { + return nil, fmt.Errorf("can't get result hash pointer: %w", err) + } + + buf, ok := r.module.Memory().Read(ptr, length) + if !ok { + return nil, fmt.Errorf("[unexpected] can't read from memory, is something out of range??") + } + + return buf, nil +} + +func (r *Runner) WriteVerification(ctx context.Context, data []byte) error { + length, err := r.verificationHashSize(ctx) + if err != nil { + return fmt.Errorf("can't get verification hash size: %v", err) + } + + if length != uint32(len(data)) { + return fmt.Errorf("data is too big, want %d bytes, got: %d", length, len(data)) + } + + ptr, err := r.verificationHashPtr(ctx) + if err != nil { + return fmt.Errorf("can't get verification hash pointer: %v", err) + } + + if !r.module.Memory().Write(ptr, data) { + return fmt.Errorf("[unexpected] can't write memory, is data out of range??") + } + + return nil +} diff --git a/wasm/wasm_test.go b/wasm/wasm_test.go new file mode 100644 index 00000000..3b04f065 --- /dev/null +++ b/wasm/wasm_test.go @@ -0,0 +1,66 @@ +package wasm + +import ( + "context" + "crypto/sha256" + "fmt" + "os" + "testing" + "time" + + "github.com/TecharoHQ/anubis/web" +) + +func TestSHA256(t *testing.T) { + const difficulty = 4 // one nibble, intentionally easy for testing + + fin, err := web.Static.Open("static/wasm/sha256.wasm") + if err != nil { + t.Fatal(err) + } + defer fin.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + + runner, err := NewRunner(ctx, "sha256.wasm", fin) + if err != nil { + t.Fatal(err) + } + + h := sha256.New() + fmt.Fprint(h, os.Args[0]) + data := h.Sum(nil) + + if n, err := runner.WriteData(ctx, data); err != nil { + t.Fatalf("can't write data: %v", err) + } else { + t.Logf("wrote %d bytes to data segment", n) + } + + t0 := time.Now() + nonce, err := runner.anubisWork(ctx, difficulty, 0, 1) + if err != nil { + t.Fatalf("can't do test work run: %v", err) + } + t.Logf("got nonce %d in %s", nonce, time.Since(t0)) + + hash, err := runner.ReadResult(ctx) + if err != nil { + t.Fatalf("can't read result: %v", err) + } + + t.Logf("got hash %x", hash) + + if err := runner.WriteVerification(ctx, hash); err != nil { + t.Fatalf("can't write verification: %v", err) + } + + ok, err := runner.anubisValidate(ctx, nonce, difficulty) + if err != nil { + t.Fatalf("can't run validation: %v", err) + } + + if !ok { + t.Error("validation failed") + } +}