diff --git a/.gitignore b/.gitignore index 751dea4..6ee9767 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ identity/ event_initiator.identity.json event_initiator.key event_initiator.key.age +config.yaml +peers.json diff --git a/Makefile.local b/Makefile.local new file mode 100644 index 0000000..f350478 --- /dev/null +++ b/Makefile.local @@ -0,0 +1,11 @@ +.PHONY: clean new + +clean: + @# only kill the window if it exists + @if tmux list-windows -F "#{window_name}" \ + | grep -qw "^mpcium$$"; then \ + tmux kill-window -t mpcium; \ + fi + +new: clean + @tmuxifier load-window mpcium diff --git a/clean_logs.sh b/clean_logs.sh new file mode 100755 index 0000000..96364a3 --- /dev/null +++ b/clean_logs.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Directories to clean under +nodes=("node0" "node1" "node2") + +for dir in "${nodes[@]}"; do + identity_dir="$dir" + echo "Cleaning .txt files in $identity_dir..." + if [ -d "$identity_dir" ]; then + find "$identity_dir" -type f -name "*.txt" -print -delete + else + echo "Directory $identity_dir not found" + fi +done + +echo "✅ Cleanup complete." diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index cff7b70..14c805b 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -18,7 +18,7 @@ import ( "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/node" "github.com/hashicorp/consul/api" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -132,28 +132,36 @@ func runNode(ctx context.Context, c *cli.Command) error { mqManager := messaging.NewNATsMessageQueueManager("mpc", []string{ "mpc.mpc_keygen_success.*", event.SigningResultTopic, + "mpc.mpc_resharing_success.*", + "mpc.mpc_keygen_request.*", // required so the pull consumer below can filter on this subject }, natsConn) + genkeyRequestQueue := mqManager.NewMessagePullSubscriber("mpc_keygen_request") + defer genkeyRequestQueue.Close() genKeySuccessQueue := mqManager.NewMessageQueue("mpc_keygen_success") defer genKeySuccessQueue.Close() singingResultQueue := mqManager.NewMessageQueue("signing_result") defer singingResultQueue.Close() + resharingResultQueue := mqManager.NewMessageQueue("mpc_resharing_success") + defer resharingResultQueue.Close() logger.Info("Node is running", "peerID", nodeID, "name", nodeName) peerNodeIDs := GetPeerIDs(peers) - peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV()) + peerRegistry := node.NewRegistry(nodeID, peerNodeIDs, consulClient.KV()) - mpcNode := mpc.NewNode( + mpcNode := node.NewNode( nodeID, peerNodeIDs, pubsub, directMessaging, badgerKV, keyinfoStore, - peerRegistry, identityStore, + peerRegistry, + consulClient.KV(), ) + // Preload preparams for the first time + mpcNode.PreloadPreParams() defer mpcNode.Close() eventConsumer := eventconsumer.NewEventConsumer( @@ -161,6 +169,7 @@ func runNode(ctx context.Context, c *cli.Command) error { pubsub, genKeySuccessQueue, singingResultQueue, + resharingResultQueue, identityStore, ) eventConsumer.Run() @@ -174,6 +183,7 @@ func runNode(ctx context.Context, c *cli.Command) error { timeoutConsumer.Run() defer timeoutConsumer.Close() signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingStream, pubsub) + keygenConsumer := eventconsumer.NewKeygenConsumer(natsConn, genkeyRequestQueue, pubsub) // Make the node ready before starting the signing consumer peerRegistry.Ready() @@ -188,6 +198,11 @@ func runNode(ctx context.Context, c *cli.Command) error { cancel() }() + logger.Info("Running keygen consumer") + if err := 
keygenConsumer.Run(appContext); err != nil { + logger.Error("error running keygen consumer:", err) + } + if err := signingConsumer.Run(appContext); err != nil { logger.Error("error running consumer:", err) } diff --git a/docker-compose.yaml b/docker-compose.yaml index 7aebc38..88b818d 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,5 +1,3 @@ -version: "3" - services: nats-server: image: nats:latest diff --git a/examples/generate/main.go b/examples/generate/main.go index fb004ed..f7d3972 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -1,15 +1,17 @@ package main import ( + "flag" "fmt" "os" "os/signal" "syscall" + "time" "github.com/fystack/mpcium/pkg/client" "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/mpc" "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -17,6 +19,11 @@ import ( func main() { const environment = "development" + + // Parse the -n flag + numWallets := flag.Int("n", 1, "Number of wallets to generate") + flag.Parse() + config.InitViperConfig() logger.Init(environment, false) @@ -25,25 +32,34 @@ func main() { if err != nil { logger.Fatal("Failed to connect to NATS", err) } - defer natsConn.Drain() // drain inflight msgs + defer natsConn.Drain() defer natsConn.Close() mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, KeyPath: "./event_initiator.key", }) - err = mpcClient.OnWalletCreationResult(func(event mpc.KeygenSuccessEvent) { + + err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { logger.Info("Received wallet creation result", "event", event) }) if err != nil { logger.Fatal("Failed to subscribe to wallet-creation results", err) } - walletID := uuid.New().String() - if err := mpcClient.CreateWallet(walletID); err != nil { - logger.Fatal("CreateWallet failed", err) + for i := 0; i < *numWallets; i++ { + walletID := uuid.New().String() + if err := mpcClient.CreateWallet(walletID); err != nil { + logger.Error("CreateWallet failed", err) + continue + } + time.Sleep(100 * time.Millisecond) + logger.Info("CreateWallet sent", "walletID", walletID) } - logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID) + + logger.Info("All CreateWallet requests sent, awaiting results...") + + // Wait for shutdown signal stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) <-stop diff --git a/examples/reshare/main.go b/examples/reshare/main.go new file mode 100644 index 0000000..b05b8c6 --- /dev/null +++ b/examples/reshare/main.go @@ -0,0 +1,52 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" + "github.com/spf13/viper" +) + +func main() { + const environment = "development" + config.InitViperConfig() + logger.Init(environment, false) + + natsURL := viper.GetString("nats.url") + natsConn, err := nats.Connect(natsURL) + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + defer natsConn.Drain() // drain inflight msgs + defer natsConn.Close() + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + KeyPath: "./event_initiator.key", + }) + err = mpcClient.OnResharingResult(func(event event.ResharingSuccessEvent) { + logger.Info("Received 
resharing result", "event", event) + }) + if err != nil { + logger.Fatal("Failed to subscribe to resharing results", err) + } + + walletID := "892122fd-f2f4-46dc-be25-6fd0b83dff60" + if err := mpcClient.Resharing(walletID, 2, types.KeyTypeSecp256k1); err != nil { + logger.Fatal("Resharing failed", err) + } + logger.Info("Resharing sent, awaiting result...", "walletID", walletID) + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + <-stop + + fmt.Println("Shutting down.") +} diff --git a/examples/sign/main.go b/examples/sign/main.go index d5e2410..e0c0d8f 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -39,9 +39,9 @@ func main() { dummyTx := []byte("deadbeef") // replace with real transaction bytes txMsg := &types.SignTxMessage{ - KeyType: types.KeyTypeEd25519, - WalletID: "77dd7e23-9d5c-4ff1-8759-f119d1b19b36", - NetworkInternalCode: "solana-devnet", + KeyType: types.KeyTypeSecp256k1, + WalletID: "9af13a60-9aa9-4069-ba3f-bd6d821c8905", + NetworkInternalCode: "ethereum-sepolia", TxID: txID, Tx: dummyTx, } diff --git a/go.mod b/go.mod index fb45e42..60cc498 100644 --- a/go.mod +++ b/go.mod @@ -10,17 +10,18 @@ require ( github.com/bnb-chain/tss-lib/v2 v2.0.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 github.com/dgraph-io/badger/v4 v4.2.0 + github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/hashicorp/consul/api v1.26.1 github.com/mitchellh/mapstructure v1.5.0 - github.com/nats-io/nats.go v1.31.0 + github.com/nats-io/nats.go v1.43.0 github.com/rs/zerolog v1.31.0 github.com/samber/lo v1.39.0 github.com/spf13/viper v1.18.0 github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 - go.uber.org/mock v0.3.0 golang.org/x/term v0.31.0 + google.golang.org/protobuf v1.36.6 ) require ( @@ -40,7 +41,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v1.2.4 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/google/flatbuffers v1.12.1 // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -56,12 +56,12 @@ require ( github.com/hashicorp/serf v0.10.1 // indirect github.com/ipfs/go-log v1.0.5 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect - github.com/klauspost/compress v1.17.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect - github.com/nats-io/nkeys v0.4.6 // indirect + github.com/nats-io/nkeys v0.4.11 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/otiai10/primes v0.0.0-20210501021515-f1b2be525a11 // indirect @@ -86,7 +86,6 @@ require ( golang.org/x/net v0.39.0 // indirect golang.org/x/sys v0.32.0 // indirect golang.org/x/text v0.24.0 // indirect - google.golang.org/protobuf v1.36.6 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 3c448e3..7930c50 100644 --- a/go.sum +++ b/go.sum @@ -210,8 +210,8 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= 
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= -github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= -github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -256,10 +256,10 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nats-io/nats.go v1.31.0 h1:/WFBHEc/dOKBF6qf1TZhrdEfTmOZ5JzdJ+Y3m6Y/p7E= -github.com/nats-io/nats.go v1.31.0/go.mod h1:di3Bm5MLsoB4Bx61CBTsxuarI36WbhAwOm8QrW39+i8= -github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= -github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= +github.com/nats-io/nats.go v1.43.0 h1:uRFZ2FEoRvP64+UUhaTokyS18XBCR/xM2vQZKO4i8ug= +github.com/nats-io/nats.go v1.43.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= +github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0= +github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -373,8 +373,6 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= -go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/pkg/client/client.go b/pkg/client/client.go index 6314158..519edf6 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "os" - "path/filepath" "strings" "filippo.io/age" @@ -15,230 +14,259 @@ import ( "github.com/fystack/mpcium/pkg/eventconsumer" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" ) const ( - GenerateWalletSuccessTopic = "mpc.mpc_keygen_success.*" // wildcard to 
listen to all success events + defaultKeyPath = "./event_initiator.key" + keyFileExt = ".age" + + // NATS stream names + mpcSigningStream = "mpc-signing" + + // NATS queue names + mpcKeygenSuccessQueue = "mpc_keygen_success" + mpcSigningResultQueue = "signing_result" + mpcResharingSuccessQueue = "mpc_resharing_success" + mpcKeygenRequestQueue = "mpc_keygen_request" + + // NATS subjects + mpcSigningRequestSubject = "mpc.signing_request.*" + mpcKeygenSuccessSubject = "mpc.mpc_keygen_success.*" + mpcSigningResultSubject = "mpc.signing_result.*" + mpcResharingSuccessSubject = "mpc.mpc_resharing_success.*" + mpcKeygenRequestSubject = "mpc.mpc_keygen_request.*" ) type MPCClient interface { CreateWallet(walletID string) error - OnWalletCreationResult(callback func(event mpc.KeygenSuccessEvent)) error + OnWalletCreationResult(callback func(event.KeygenSuccessEvent)) error SignTransaction(msg *types.SignTxMessage) error - OnSignResult(callback func(event event.SigningResultEvent)) error + OnSignResult(callback func(event.SigningResultEvent)) error + + Resharing(walletID string, newThreshold int, keyType types.KeyType) error + OnResharingResult(callback func(event.ResharingSuccessEvent)) error } type mpcClient struct { - signingStream messaging.StreamPubsub - pubsub messaging.PubSub - genKeySuccessQueue messaging.MessageQueue - signResultQueue messaging.MessageQueue - privKey ed25519.PrivateKey + signingStream messaging.StreamPubsub + pubsub messaging.PubSub + genKeySuccessQueue messaging.MessageQueue + signResultQueue messaging.MessageQueue + resharingResultQueue messaging.MessageQueue + genKeyRequestQueue messaging.MessageQueue + + privKey ed25519.PrivateKey } // Options defines configuration options for creating a new MPCClient type Options struct { - // NATS connection - NatsConn *nats.Conn - - // Key path options - KeyPath string // Path to unencrypted key (default: "./event_initiator.key") - - // Encryption options + NatsConn *nats.Conn + KeyPath string // Path to unencrypted key (default: "./event_initiator.key") Encrypted bool // Whether the key is encrypted Password string // Password for encrypted key } // NewMPCClient creates a new MPC client using the provided options. -// It reads the Ed25519 private key from disk and sets up messaging connections. -// If the key is encrypted (.age file), decryption options must be provided in the config. 
func NewMPCClient(opts Options) MPCClient { - // Set default paths if not provided + // Set default key path if not provided if opts.KeyPath == "" { - opts.KeyPath = filepath.Join(".", "event_initiator.key") + opts.KeyPath = defaultKeyPath } - if strings.HasSuffix(opts.KeyPath, ".age") { + // Auto-detect encryption based on file extension + if strings.HasSuffix(opts.KeyPath, keyFileExt) { opts.Encrypted = true } - var privHexBytes []byte - var err error - - // Check if key file exists - if _, err := os.Stat(opts.KeyPath); err == nil { - if opts.Encrypted { - // Encrypted key exists, try to decrypt it - if opts.Password == "" { - logger.Fatal("Encrypted key found but no decryption option provided", nil) - } + // Load private key + privKey := loadPrivateKey(opts) - // Read encrypted file - encryptedBytes, err := os.ReadFile(opts.KeyPath) - if err != nil { - logger.Fatal("Failed to read encrypted private key file", err) - } + // Initialize messaging components + signingStream := initSigningStream(opts.NatsConn) + pubsub := messaging.NewNATSPubSub(opts.NatsConn) + manager := initMessageQueueManager(opts.NatsConn) - // Decrypt the key using the provided password - privHexBytes, err = decryptPrivateKey(encryptedBytes, opts.Password) - if err != nil { - logger.Fatal("Failed to decrypt private key", err) - } - } else { - // Unencrypted key exists, read it normally - privHexBytes, err = os.ReadFile(opts.KeyPath) - if err != nil { - logger.Fatal("Failed to read private key file", err) - } - } - } else { - logger.Fatal("No private key file found", nil) + return &mpcClient{ + signingStream: signingStream, + pubsub: pubsub, + genKeySuccessQueue: manager.NewMessageQueue(mpcKeygenSuccessQueue), + signResultQueue: manager.NewMessageQueue(mpcSigningResultQueue), + resharingResultQueue: manager.NewMessageQueue(mpcResharingSuccessQueue), + genKeyRequestQueue: manager.NewMessagePullSubscriber(mpcKeygenRequestQueue), + privKey: privKey, } +} - privHex := string(privHexBytes) - // Decode private key from hex - privSeed, err := hex.DecodeString(privHex) - if err != nil { - fmt.Println("Failed to decode private key hex:", err) - os.Exit(1) - } +func initMessageQueueManager(natsConn *nats.Conn) *messaging.NATsMessageQueueManager { + return messaging.NewNATsMessageQueueManager("mpc", []string{ + mpcKeygenSuccessSubject, + mpcSigningResultSubject, + mpcResharingSuccessSubject, + mpcKeygenRequestSubject, + }, natsConn) +} - // Reconstruct full Ed25519 private key from seed - priv := ed25519.NewKeyFromSeed(privSeed) +// CreateWallet generates a GenerateKeyMessage, signs it, and publishes it. 
+func (c *mpcClient) CreateWallet(walletID string) error { + msg := &types.GenerateKeyMessage{WalletID: walletID} - // 2) Create the PubSub for both publish & subscribe - signingStream, err := messaging.NewJetStreamPubSub(opts.NatsConn, "mpc-signing", []string{ - "mpc.signing_request.*", - }) + raw, err := msg.Raw() if err != nil { - logger.Fatal("Failed to create JetStream PubSub", err) + return fmt.Errorf("CreateWallet: raw payload error: %w", err) } - pubsub := messaging.NewNATSPubSub(opts.NatsConn) - - manager := messaging.NewNATsMessageQueueManager("mpc", []string{ - "mpc.mpc_keygen_success.*", - "mpc.signing_result.*", - }, opts.NatsConn) - - genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_success") - signResultQueue := manager.NewMessageQueue("signing_result") + msg.Signature = ed25519.Sign(c.privKey, raw) + bytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("CreateWallet: marshal error: %w", err) + } - return &mpcClient{ - signingStream: signingStream, - pubsub: pubsub, - genKeySuccessQueue: genKeySuccessQueue, - signResultQueue: signResultQueue, - privKey: priv, + if err := c.genKeyRequestQueue.Enqueue(mpcKeygenRequestSubject, bytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf("%s.%s", eventconsumer.MPCGenerateEvent, walletID), + }); err != nil { + return fmt.Errorf("CreateWallet: publish error: %w", err) } + return nil } -// decryptPrivateKey decrypts the encrypted private key using the provided password -func decryptPrivateKey(encryptedData []byte, password string) ([]byte, error) { - // Create an age identity (decryption key) from the password - identity, err := age.NewScryptIdentity(password) +func (c *mpcClient) OnWalletCreationResult(callback func(event.KeygenSuccessEvent)) error { + return c.handleQueueEvent(c.genKeySuccessQueue, event.KeygenSuccessEventTopic, callback) +} + +func (c *mpcClient) SignTransaction(msg *types.SignTxMessage) error { + raw, err := msg.Raw() if err != nil { - return nil, fmt.Errorf("failed to create identity from password: %w", err) + return fmt.Errorf("SignTransaction: raw payload error: %w", err) } - // Create a reader from the encrypted data - decrypter, err := age.Decrypt(strings.NewReader(string(encryptedData)), identity) + msg.Signature = ed25519.Sign(c.privKey, raw) + bytes, err := json.Marshal(msg) if err != nil { - return nil, fmt.Errorf("failed to create decrypter: %w", err) + return fmt.Errorf("SignTransaction: marshal error: %w", err) } - // Read the decrypted data - decryptedData, err := io.ReadAll(decrypter) - if err != nil { - return nil, fmt.Errorf("failed to read decrypted data: %w", err) + if err := c.signingStream.Publish(event.SigningRequestEventTopic, bytes); err != nil { + return fmt.Errorf("SignTransaction: publish error: %w", err) } + return nil +} - return decryptedData, nil +func (c *mpcClient) OnSignResult(callback func(event.SigningResultEvent)) error { + return c.handleQueueEvent(c.signResultQueue, event.SigningResultCompleteTopic, callback) } -// CreateWallet generates a GenerateKeyMessage, signs it, and publishes it. 
-func (c *mpcClient) CreateWallet(walletID string) error { - // build the message - msg := &types.GenerateKeyMessage{ - WalletID: walletID, +func (c *mpcClient) Resharing(walletID string, newThreshold int, keyType types.KeyType) error { + msg := &types.ResharingMessage{ + WalletID: walletID, + NewThreshold: newThreshold, + KeyType: keyType, } - // compute the canonical raw bytes + raw, err := msg.Raw() if err != nil { - return fmt.Errorf("CreateWallet: raw payload error: %w", err) + return fmt.Errorf("Resharing: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + msg.Signature = ed25519.Sign(c.privKey, raw) bytes, err := json.Marshal(msg) if err != nil { - return fmt.Errorf("CreateWallet: marshal error: %w", err) + return fmt.Errorf("Resharing: marshal error: %w", err) } - if err := c.pubsub.Publish(eventconsumer.MPCGenerateEvent, bytes); err != nil { - return fmt.Errorf("CreateWallet: publish error: %w", err) + if err := c.pubsub.Publish(eventconsumer.MPCResharingEvent, bytes); err != nil { + return fmt.Errorf("Resharing: publish error: %w", err) } return nil } -// The callback will be invoked whenever a wallet creation result is received. -func (c *mpcClient) OnWalletCreationResult(callback func(event mpc.KeygenSuccessEvent)) error { - err := c.genKeySuccessQueue.Dequeue(GenerateWalletSuccessTopic, func(msg []byte) error { - var event mpc.KeygenSuccessEvent - err := json.Unmarshal(msg, &event) - if err != nil { - return err +func (c *mpcClient) OnResharingResult(callback func(event.ResharingSuccessEvent)) error { + return c.handleQueueEvent(c.resharingResultQueue, event.ResharingSuccessEventTopic, callback) +} + +// Generic handler for queue events +func (c *mpcClient) handleQueueEvent(queue messaging.MessageQueue, topic string, callback interface{}) error { + return queue.Dequeue(topic, func(msg []byte) error { + switch cb := callback.(type) { + case func(event.KeygenSuccessEvent): + var event event.KeygenSuccessEvent + if err := json.Unmarshal(msg, &event); err != nil { + return err + } + cb(event) + case func(event.SigningResultEvent): + var event event.SigningResultEvent + if err := json.Unmarshal(msg, &event); err != nil { + return err + } + cb(event) + case func(event.ResharingSuccessEvent): + var event event.ResharingSuccessEvent + if err := json.Unmarshal(msg, &event); err != nil { + return err + } + cb(event) + default: + return fmt.Errorf("unsupported callback type") } - callback(event) return nil }) +} + +func loadPrivateKey(opts Options) ed25519.PrivateKey { + if _, err := os.Stat(opts.KeyPath); os.IsNotExist(err) { + logger.Fatal("No private key file found", nil) + } + + var privHexBytes []byte + var err error + + if opts.Encrypted { + if opts.Password == "" { + logger.Fatal("Encrypted key found but no decryption option provided", nil) + } + privHexBytes, err = loadEncryptedKey(opts.KeyPath, opts.Password) + } else { + privHexBytes, err = os.ReadFile(opts.KeyPath) + } if err != nil { - return fmt.Errorf("OnWalletCreationResult: subscribe error: %w", err) + logger.Fatal("Failed to read private key file", err) } - return nil + privSeed, err := hex.DecodeString(string(privHexBytes)) + if err != nil { + logger.Fatal("Failed to decode private key hex", err) + } + + return ed25519.NewKeyFromSeed(privSeed) } -// SignTransaction builds a SignTxMessage, signs it, and publishes it. 
-func (c *mpcClient) SignTransaction(msg *types.SignTxMessage) error { - // compute the canonical raw bytes (omitting Signature field) - raw, err := msg.Raw() +func loadEncryptedKey(keyPath, password string) ([]byte, error) { + encryptedBytes, err := os.ReadFile(keyPath) if err != nil { - return fmt.Errorf("SignTransaction: raw payload error: %w", err) + return nil, fmt.Errorf("failed to read encrypted key file: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) - bytes, err := json.Marshal(msg) + identity, err := age.NewScryptIdentity(password) if err != nil { - return fmt.Errorf("SignTransaction: marshal error: %w", err) + return nil, fmt.Errorf("failed to create identity from password: %w", err) } - if err := c.signingStream.Publish(event.SigningRequestEventTopic, bytes); err != nil { - return fmt.Errorf("SignTransaction: publish error: %w", err) + decrypter, err := age.Decrypt(strings.NewReader(string(encryptedBytes)), identity) + if err != nil { + return nil, fmt.Errorf("failed to create decrypter: %w", err) } - return nil -} -func (c *mpcClient) OnSignResult(callback func(event event.SigningResultEvent)) error { - err := c.signResultQueue.Dequeue(event.SigningResultCompleteTopic, func(msg []byte) error { - var event event.SigningResultEvent - err := json.Unmarshal(msg, &event) - if err != nil { - return err - } - callback(event) - return nil - }) + return io.ReadAll(decrypter) +} +func initSigningStream(natsConn *nats.Conn) messaging.StreamPubsub { + stream, err := messaging.NewJetStreamPubSub(natsConn, mpcSigningStream, []string{mpcSigningRequestSubject}) if err != nil { - return fmt.Errorf("OnSignResult: subscribe error: %w", err) + logger.Fatal("Failed to create JetStream PubSub", err) } - - return nil + return stream } diff --git a/pkg/common/concurrency/utils.go b/pkg/common/concurrency/utils.go new file mode 100644 index 0000000..d668869 --- /dev/null +++ b/pkg/common/concurrency/utils.go @@ -0,0 +1,49 @@ +package concurrency + +import ( + "runtime" +) + +// GetVirtualCoreCount returns the number of logical CPUs (virtual cores) available on the system. +// This includes physical cores *and* hyperthreads. +func GetVirtualCoreCount() int { + return runtime.NumCPU() +} + +// GetTSSConcurrencyLimit returns the recommended maximum number of concurrent TSS sessions. +// It estimates the number of *physical* cores by dividing the virtual core count by 2, +// because each physical core typically has 2 logical threads due to hyperthreading. +// +// Threshold signing (e.g., ECDSA) is CPU-bound and does not benefit much from hyperthreads, +// so we limit concurrency based on physical core estimates. +func GetTSSConcurrencyLimit() int { + logicalCores := GetVirtualCoreCount() + + // Estimate physical cores by dividing virtual CPUs by 2 + estimatedPhysicalCores := logicalCores / 2 + if estimatedPhysicalCores < 1 { + estimatedPhysicalCores = 1 // always allow at least one session + } + + return calculateAllowedSessions(estimatedPhysicalCores) +} + +// calculateAllowedSessions maps physical core count to safe TSS concurrency limits. +// You can tune these thresholds depending on your latency and throughput requirements. +func calculateAllowedSessions(coreCount int) int { + switch { + case coreCount <= 2: + return 1 + case coreCount <= 4: + return 2 + case coreCount <= 8: + return 3 + case coreCount <= 12: + return 5 + case coreCount <= 16: + return 6 + default: + // For large systems, reserve some headroom for OS, logs, GC, etc. 
+ return coreCount / 2 + } +} diff --git a/pkg/event/event.go b/pkg/event/event.go new file mode 100644 index 0000000..1def269 --- /dev/null +++ b/pkg/event/event.go @@ -0,0 +1,37 @@ +package event + +const ( + KeygenSuccessEventTopic = "mpc.mpc_keygen_success.*" + ResharingSuccessEventTopic = "mpc.mpc_resharing_success.*" + + TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" + TypeSigningResultComplete = "mpc.mpc_signing_result_complete.%s.%s" + TypeResharingSuccess = "mpc.mpc_resharing_success.%s.%d" +) + +type KeygenSuccessEvent struct { + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` +} + +type SigningResultEvent struct { + ResultType SigningResultType `json:"result_type"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + R []byte `json:"r"` + S []byte `json:"s"` + SignatureRecovery []byte `json:"signature_recovery"` + + // TODO: define two separate events for eddsa and ecdsa + Signature []byte `json:"signature"` +} + +type ResharingSuccessEvent struct { + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` +} diff --git a/pkg/event/sign.go b/pkg/event/sign.go index cb8d53d..bece990 100644 --- a/pkg/event/sign.go +++ b/pkg/event/sign.go @@ -17,38 +17,3 @@ const ( SigningResultTypeSuccess SigningResultTypeError ) - -type SigningResultEvent struct { - ResultType SigningResultType `json:"result_type"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - R []byte `json:"r"` - S []byte `json:"s"` - SignatureRecovery []byte `json:"signature_recovery"` - - // TODO: define two separate events for eddsa and ecdsa - Signature []byte `json:"signature"` -} - -type SigningResultSuccessEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - R []byte `json:"r"` - S []byte `json:"s"` - SignatureRecovery []byte `json:"signature_recovery"` - - // TODO: define two separate events for eddsa and ecdsa - Signature []byte `json:"signature"` -} - -type SigningResultErrorEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` -} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 2f98210..3671789 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -3,26 +3,37 @@ package eventconsumer import ( "context" "encoding/json" - "errors" "fmt" "log" "math/big" "sync" "time" + "github.com/fystack/mpcium/pkg/common/concurrency" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/monitoring" + "github.com/fystack/mpcium/pkg/mpc/node" + "github.com/fystack/mpcium/pkg/mpc/session" + "github.com/fystack/mpcium/pkg/tsslimiter" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" "github.com/spf13/viper" ) const ( - MPCGenerateEvent = "mpc:generate" - 
MPCSignEvent = "mpc:sign" + MPCGenerateEvent = "mpc:generate" + MPCSignEvent = "mpc:sign" + MPCResharingEvent = "mpc:reshare" + + // Default version for keygen + DefaultVersion int = 1 + SessionTimeout = 15 * time.Second + MaxConcurrentSessions = 5 + // HandlerTimeout caps how long the entire handler waits for all sessions plus result publishing + HandlerTimeout = 20 * time.Second ) type EventConsumer interface { @@ -30,16 +41,23 @@ type EventConsumer interface { Close() error } +func Elaps(start time.Time, text string) { + elapsed := time.Since(start) + fmt.Printf("%s, Elapsed time: %d ms\n", text, elapsed.Milliseconds()) +} + type eventConsumer struct { - node *mpc.Node + node *node.Node pubsub messaging.PubSub mpcThreshold int - genKeySucecssQueue messaging.MessageQueue - signingResultQueue messaging.MessageQueue + genKeySuccessQueue messaging.MessageQueue + signingResultQueue messaging.MessageQueue + resharingResultQueue messaging.MessageQueue keyGenerationSub messaging.Subscription signingSub messaging.Subscription + resharingSub messaging.Subscription identityStore identity.Store // Track active sessions with timestamps for cleanup @@ -48,26 +66,34 @@ type eventConsumer struct { cleanupInterval time.Duration // How often to run cleanup sessionTimeout time.Duration // How long before a session is considered stale cleanupStopChan chan struct{} // Signal to stop cleanup goroutine + limiterQueue tsslimiter.Queue } func NewEventConsumer( - node *mpc.Node, + node *node.Node, pubsub messaging.PubSub, - genKeySucecssQueue messaging.MessageQueue, + genKeySuccessQueue messaging.MessageQueue, signingResultQueue messaging.MessageQueue, + resharingResultQueue messaging.MessageQueue, identityStore identity.Store, ) EventConsumer { + limiter := tsslimiter.NewWeightedLimiter(concurrency.GetTSSConcurrencyLimit()) + bufferSize := 100 + limiterQueue := tsslimiter.NewWeightedQueue(limiter, bufferSize) + ec := &eventConsumer{ - node: node, - pubsub: pubsub, - genKeySucecssQueue: genKeySucecssQueue, - signingResultQueue: signingResultQueue, - activeSessions: make(map[string]time.Time), - cleanupInterval: 5 * time.Minute, // Run cleanup every 5 minutes - sessionTimeout: 30 * time.Minute, // Consider sessions older than 30 minutes stale - cleanupStopChan: make(chan struct{}), - mpcThreshold: viper.GetInt("mpc_threshold"), - identityStore: identityStore, + node: node, + pubsub: pubsub, + genKeySuccessQueue: genKeySuccessQueue, + signingResultQueue: signingResultQueue, + resharingResultQueue: resharingResultQueue, + activeSessions: make(map[string]time.Time), + cleanupInterval: 5 * time.Minute, // Run cleanup every 5 minutes + sessionTimeout: 30 * time.Minute, // Consider sessions older than 30 minutes stale + cleanupStopChan: make(chan struct{}), + mpcThreshold: viper.GetInt("mpc_threshold"), + identityStore: identityStore, + limiterQueue: limiterQueue, } // Start background cleanup goroutine @@ -87,109 +113,166 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume tx signing event", err) } + err = ec.consumeResharingEvent() + if err != nil { + log.Fatal("Failed to consume resharing event", err) + } + logger.Info("MPC Event consumer started...!") } - func (ec *eventConsumer) consumeKeyGenerationEvent() error { + // Enqueue each request as a weighted job so the TSS session limiter bounds concurrency sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) { - raw := natMsg.Data - var msg types.GenerateKeyMessage - err := json.Unmarshal(raw, &msg) - if err != nil { - logger.Error("Failed to unmarshal signing message", err) - return + 
logger.Info("Received key generation event", "subject", natMsg.Subject) + job := tsslimiter.SessionJob{ + Type: tsslimiter.SessionKeygenCombined, + Run: func() error { + return ec.handleKeyGenerationEvent(context.Background(), natMsg) + }, + OnError: func(err error) { + logger.Error("Failed to handle key generation event", err) + }, + Name: fmt.Sprintf("keygen-%s", string(natMsg.Data)), } - logger.Info("Received key generation event", "msg", msg) + ec.limiterQueue.Enqueue(job) + }) - err = ec.identityStore.VerifyInitiatorMessage(&msg) - if err != nil { - logger.Error("Failed to verify initiator message", err) - return - } + if err != nil { + return err + } - walletID := msg.WalletID - session, err := ec.node.CreateKeyGenSession(walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } - eddsaSession, err := ec.node.CreateEDDSAKeyGenSession(walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } + ec.keyGenerationSub = sub + return nil +} + +func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, natMsg *nats.Msg) error { + raw := natMsg.Data + ctx, handlerCancel := context.WithTimeout(parentCtx, HandlerTimeout) + defer handlerCancel() - session.Init() - eddsaSession.Init() + // 1) decode and verify + var msg types.GenerateKeyMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return fmt.Errorf("unmarshal message: %w", err) + } + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { + return fmt.Errorf("verify initiator: %w", err) + } - ctx, done := context.WithCancel(context.Background()) - ctxEddsa, doneEddsa := context.WithCancel(context.Background()) + walletID := msg.WalletID + successEvent := &event.KeygenSuccessEvent{WalletID: walletID} - successEvent := &mpc.KeygenSuccessEvent{ - WalletID: walletID, + // 2) prepare both sessions + s0, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySuccessQueue) + if err != nil { + return fmt.Errorf("create ECDSA session: %w", err) + } + + s1, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySuccessQueue) + if err != nil { + s0.Close() + return fmt.Errorf("create EDDSA session: %w", err) + } + + s0.Listen(ctx) + s1.Listen(ctx) + + defer s0.Close() + defer s1.Close() + + runKeygen := func(s session.Session, keyType types.KeyType) error { + sessionCtx, sessionCancel := context.WithTimeout(ctx, SessionTimeout) + defer sessionCancel() + + // // 1. Wait for all parties to be ready to start + if err := s.WaitForReady(sessionCtx, fmt.Sprintf("KEYGEN-start:%s", keyType)); err != nil { + return fmt.Errorf("failed to wait for ready: %w", err) } - var wg sync.WaitGroup - wg.Add(2) - go func() { - for { - select { - case <-ctx.Done(): - successEvent.ECDSAPubKey = session.GetPubKeyResult() - wg.Done() - return - case err := <-session.ErrCh: - logger.Error("Keygen session error", err) - } + doneCh := make(chan error, 1) + + // 2. 
Start the key generation protocol + s.StartKeygen(sessionCtx, s.Send, func(data []byte) { + logger.Info("[callback] StartKeygen fired", "walletID", walletID, "keyType", keyType) + if err := s.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data); err != nil { + logger.Error("Failed to save key", err, "walletID", walletID, "keyType", keyType) + doneCh <- err + return } - }() - go func() { - for { - select { - case <-ctxEddsa.Done(): - successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() - wg.Done() - return - case err := <-eddsaSession.ErrCh: - logger.Error("Keygen session error", err) + if pubKey, err := s.GetPublicKey(data); err == nil { + switch keyType { + case types.KeyTypeSecp256k1: + successEvent.ECDSAPubKey = pubKey + case types.KeyTypeEd25519: + successEvent.EDDSAPubKey = pubKey } + } else { + logger.Error("Failed to get public key", err, "walletID", walletID, "keyType", keyType) + doneCh <- err + return } - }() - session.ListenToIncomingMessageAsync() - eddsaSession.ListenToIncomingMessageAsync() - // TODO: replace sleep with distributed lock - time.Sleep(1 * time.Second) - - go session.GenerateKey(done) - go eddsaSession.GenerateKey(doneEddsa) + // // 3. Wait for all parties to confirm completion + if err := s.WaitForReady(sessionCtx, fmt.Sprintf("KEYGEN-complete:%s", keyType)); err != nil { + doneCh <- fmt.Errorf("failed to wait for completion: %w", err) + return + } - wg.Wait() - logger.Info("Closing session successfully!", "event", successEvent) + doneCh <- nil + }) - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal keygen success event", err) - return + select { + case err := <-doneCh: + if err != nil { + return fmt.Errorf("keygen onComplete failed: %w", err) + } + return nil + case err := <-s.ErrCh(): + return fmt.Errorf("session error during keygen: %w", err) + case <-sessionCtx.Done(): + return fmt.Errorf("keygen timed out: %w", sessionCtx.Err()) } + } - err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), - }) - if err != nil { - logger.Error("Failed to publish key generation success message", err) - return - } + logger.Info("Starting ECDSA key generation...", "walletID", walletID) + if err := runKeygen(s0, types.KeyTypeSecp256k1); err != nil { + return fmt.Errorf("ECDSA keygen failed: %w", err) + } + logger.Info("ECDSA key generation completed.", "walletID", walletID) - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) + logger.Info("Starting EDDSA key generation...", "walletID", walletID) + if err := runKeygen(s1, types.KeyTypeEd25519); err != nil { + return fmt.Errorf("EDDSA keygen failed: %w", err) + } + logger.Info("EDDSA key generation completed.", "walletID", walletID) + // 3) Send reply to keygen consumer after both keygens complete + // 4) marshal & publish success + successBytes, err := json.Marshal(successEvent) + if err != nil { + return fmt.Errorf("marshal success event: %w", err) + } - }) + if err := ec.genKeySuccessQueue.Enqueue( + fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + successBytes, + &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + }, + ); err != nil { + return fmt.Errorf("enqueue success event: %w", err) + } - ec.keyGenerationSub = sub - if err != nil { - return err + if natMsg.Reply != "" { + 
err = ec.pubsub.Publish(natMsg.Reply, successBytes) + if err != nil { + logger.Error("Failed to publish reply", err) + } else { + logger.Info("Reply sent to keygen consumer", "reply", natMsg.Reply, "walletID", walletID) + } } + + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) return nil } @@ -209,17 +292,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return } - logger.Info( - "Received signing event", - "waleltID", - msg.WalletID, - "type", - msg.KeyType, - "tx", - msg.TxID, - "Id", - ec.node.ID(), - ) + logger.Info("Received signing event", "msg", msg) // Check for duplicate session and track if new if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { @@ -227,105 +300,99 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return } - var session mpc.ISigningSession - switch msg.KeyType { - case types.KeyTypeSecp256k1: - session, err = ec.node.CreateSigningSession( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - ec.mpcThreshold, - ec.signingResultQueue, - ) - case types.KeyTypeEd25519: - session, err = ec.node.CreateEDDSASigningSession( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - ec.mpcThreshold, - ec.signingResultQueue, - ) - - } + // Add session to tracking before starting + ec.addSession(msg.WalletID, msg.TxID) + keyInfoVersion, err := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to create signing session", - natMsg, - ) + logger.Error("Failed to get party version", err) + ec.removeSession(msg.WalletID, msg.TxID) return } - txBigInt := new(big.Int).SetBytes(msg.Tx) - err = session.Init(txBigInt) + signingSession, err := ec.node.CreateSigningSession( + msg.KeyType, + msg.WalletID, + msg.TxID, + keyInfoVersion, + ec.mpcThreshold, + ec.signingResultQueue, + ) + if err != nil { - if errors.Is(err, mpc.ErrNotEnoughParticipants) { - logger.Info("RETRY LATER: Not enough participants to sign") - //Return for retry later - return - } ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, err, - "Failed to init signing session", + "Failed to create signing session", natMsg, ) + ec.removeSession(msg.WalletID, msg.TxID) return } - // Mark session as already processed - ec.addSession(msg.WalletID, msg.TxID) + go signingSession.Listen(context.Background()) - ctx, done := context.WithCancel(context.Background()) + txBigInt := new(big.Int).SetBytes(msg.Tx) go func() { - for { - select { - case <-ctx.Done(): + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + signingSession.StartSigning(ctx, txBigInt, signingSession.Send, func(data []byte) { + cancel() + signatureData, err := signingSession.VerifySignature(msg.Tx, data) + if err != nil { + logger.Error("Failed to verify signature", err) + ec.removeSession(msg.WalletID, msg.TxID) return - case err := <-session.ErrChan(): - if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to sign tx", - natMsg, - ) - return - } } - } - }() - session.ListenToIncomingMessageAsync() - // TODO: use consul distributed lock here, only sign after all nodes has already completed listing to incoming message async - // The purpose of the sleep is to be ensuring that the node has properly set up its message listeners - // before it starts the signing process. 
If the signing process starts sending messages before other nodes - // have set up their listeners, those messages might be missed, potentially causing the signing process to fail. - // One solution: - // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). - // The nodes could explicitly coordinate through request-response patterns before starting signing - time.Sleep(1 * time.Second) - - onSuccess := func(data []byte) { - done() - if natMsg.Reply != "" { - err = ec.pubsub.Publish(natMsg.Reply, data) + signingResult := event.SigningResultEvent{ + WalletID: msg.WalletID, + TxID: msg.TxID, + NetworkInternalCode: msg.NetworkInternalCode, + ResultType: event.SigningResultTypeSuccess, + Signature: data, + R: signatureData.R, + S: signatureData.S, + SignatureRecovery: signatureData.SignatureRecovery, + } + + signingResultBytes, err := json.Marshal(signingResult) if err != nil { - logger.Error("Failed to publish reply", err) - } else { - logger.Info("Reply to the original message", "reply", natMsg.Reply) + logger.Error("Failed to marshal signing result event", err) + ec.removeSession(msg.WalletID, msg.TxID) + return } + + err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeSigningResultComplete, msg.WalletID, msg.TxID), + }) + if err != nil { + logger.Error("Failed to publish signing result event", err) + ec.removeSession(msg.WalletID, msg.TxID) + return + } + + logger.Info("Signing completed", "walletID", msg.WalletID, "txID", msg.TxID, "data", len(data)) + ec.removeSession(msg.WalletID, msg.TxID) + + }) + }() + + go func() { + for err := range signingSession.ErrCh() { + logger.Error("Error from session", err) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to sign tx", + natMsg, + ) + ec.removeSession(msg.WalletID, msg.TxID) } - } - go session.Sign(onSuccess) + }() }) ec.signingSub = sub @@ -336,8 +403,133 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return nil } +func (ec *eventConsumer) consumeResharingEvent() error { + sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err := ec.handleReshareEvent(ctx, natMsg.Data); err != nil { + logger.Error("Failed to handle resharing event", err) + } + }) + if err != nil { + return err + } + + ec.resharingSub = sub + return nil +} + +func (ec *eventConsumer) handleReshareEvent(ctx context.Context, raw []byte) error { + var msg types.ResharingMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return fmt.Errorf("unmarshal message: %w", err) + } + logger.Info("Received resharing event", + "walletID", msg.WalletID, + "oldThreshold", ec.mpcThreshold, + "newThreshold", msg.NewThreshold) + + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { + return fmt.Errorf("verify initiator: %w", err) + } + + keyInfoVersion, _ := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) + + oldSession, err := ec.node.CreateResharingSession(true, msg.KeyType, msg.WalletID, ec.mpcThreshold, keyInfoVersion, ec.resharingResultQueue) + if err != nil { + return fmt.Errorf("create old session: %w", err) + } + + newSession, err := ec.node.CreateResharingSession(false, msg.KeyType, msg.WalletID, msg.NewThreshold, keyInfoVersion, ec.resharingResultQueue) + if err != nil { + return fmt.Errorf("create new session: %w", 
err) + } + + go oldSession.Listen(context.Background()) + go newSession.Listen(context.Background()) + + successEvent := &event.ResharingSuccessEvent{WalletID: msg.WalletID} + + var wg sync.WaitGroup + wg.Add(2) + + // Error monitor + go func() { + for { + select { + case err := <-oldSession.ErrCh(): + logger.Error("Error from old session", err) + case err := <-newSession.ErrCh(): + logger.Error("Error from new session", err) + } + } + }() + + // Start old session + go func() { + ctxOld, cancelOld := context.WithCancel(ctx) + defer cancelOld() + oldSession.StartResharing(ctxOld, + oldSession.PartyIDs(), + newSession.PartyIDs(), + ec.mpcThreshold, + msg.NewThreshold, + oldSession.Send, + func([]byte) { wg.Done() }, + ) + }() + + // Start new session + go func() { + ctxNew, cancelNew := context.WithCancel(ctx) + defer cancelNew() + newSession.StartResharing(ctxNew, + oldSession.PartyIDs(), + newSession.PartyIDs(), + ec.mpcThreshold, + msg.NewThreshold, + newSession.Send, + func(data []byte) { + if pubKey, err := newSession.GetPublicKey(data); err == nil { + newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, keyInfoVersion+1, data) + if msg.KeyType == types.KeyTypeSecp256k1 { + successEvent.ECDSAPubKey = pubKey + } else { + successEvent.EDDSAPubKey = pubKey + } + } else { + logger.Error("Failed to get public key", err) + } + wg.Done() + }, + ) + }() + + wg.Wait() + + eventBytes, err := json.Marshal(successEvent) + if err != nil { + return fmt.Errorf("marshal success event: %w", err) + } + + err = ec.resharingResultQueue.Enqueue( + event.ResharingSuccessEventTopic, + eventBytes, + &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeResharingSuccess, msg.WalletID, keyInfoVersion+1), + }, + ) + if err != nil { + return fmt.Errorf("enqueue resharing success: %w", err) + } + + logger.Info("[COMPLETED RESH] Resharing completed successfully", "walletID", msg.WalletID) + return nil +} + func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { - logger.Error("Signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) + logger.Error("signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) signingResult := event.SigningResultEvent{ ResultType: event.SigningResultTypeError, NetworkInternalCode: NetworkInternalCode, @@ -348,7 +540,7 @@ func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkIntern signingResultBytes, err := json.Marshal(signingResult) if err != nil { - logger.Error("Failed to marshal signing result event", err) + logger.Error("failed to marshal signing result event", err) return } @@ -430,14 +622,24 @@ func (ec *eventConsumer) Close() error { // Signal cleanup routine to stop close(ec.cleanupStopChan) - err := ec.keyGenerationSub.Unsubscribe() - if err != nil { - return err + if ec.keyGenerationSub != nil { + if err := ec.keyGenerationSub.Unsubscribe(); err != nil { + return err + } } - err = ec.signingSub.Unsubscribe() - if err != nil { - return err + if ec.signingSub != nil { + if err := ec.signingSub.Unsubscribe(); err != nil { + return err + } } + if ec.resharingSub != nil { + if err := ec.resharingSub.Unsubscribe(); err != nil { + return err + } + } + + // Ensure all monitoring logs are written to disk before exiting. 
+ monitoring.Close() return nil } diff --git a/pkg/eventconsumer/events.go b/pkg/eventconsumer/events.go index 5b9ca06..4d71714 100644 --- a/pkg/eventconsumer/events.go +++ b/pkg/eventconsumer/events.go @@ -6,7 +6,7 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" + KeyTypeEd25519 KeyType = "ed25519" ) // InitiatorMessage is anything that carries a payload to verify and its signature. diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go new file mode 100644 index 0000000..c810d46 --- /dev/null +++ b/pkg/eventconsumer/keygen_consumer.go @@ -0,0 +1,146 @@ +package eventconsumer + +import ( + "context" + "errors" + "time" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +const ( + // Maximum time to wait for a keygen response. + keygenResponseTimeout = 90 * time.Second + // How often to poll for the reply message. + keygenPollingInterval = 1 * time.Second +) + +// KeygenConsumer represents a consumer that processes keygen events. +type KeygenConsumer interface { + // Run starts the consumer's polling loop in a background goroutine; processing continues until the provided context is canceled. + Run(ctx context.Context) error + // Close performs a graceful shutdown of the consumer. + Close() error +} + +// keygenConsumer implements KeygenConsumer. +type keygenConsumer struct { + natsConn *nats.Conn + pubsub messaging.PubSub + keygenRequestQueue messaging.MessageQueue + + // jsSub holds the JetStream subscription, so it can be cleaned up during Close(). + jsSub messaging.Subscription +} + +// NewKeygenConsumer returns a new instance of KeygenConsumer. +func NewKeygenConsumer(natsConn *nats.Conn, keygenRequestQueue messaging.MessageQueue, pubsub messaging.PubSub) KeygenConsumer { + return &keygenConsumer{ + natsConn: natsConn, + pubsub: pubsub, + keygenRequestQueue: keygenRequestQueue, + } +} + +// Run launches a background goroutine that periodically fetches keygen request events and processes them until the context is canceled; it returns immediately. +func (sc *keygenConsumer) Run(ctx context.Context) error { + logger.Info("Starting key generation event consumer") + + go func() { + // Initial fetch + logger.Info("Calling to fetch key generation events...") + err := sc.keygenRequestQueue.Fetch(5, func(msg jetstream.Msg) error { + sc.handleKeygenEvent(msg) + return nil + }) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + logger.Error("Error fetching key generation events", err) + } + + // Then start the ticker + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + logger.Info("Stopping key generation event processing") + return + + case <-ticker.C: + logger.Info("Calling to fetch key generation events...") + + err := sc.keygenRequestQueue.Fetch(5, func(msg jetstream.Msg) error { + sc.handleKeygenEvent(msg) + return nil + }) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + logger.Error("Error fetching key generation events", err) + } + } + } + }() + + return nil +} + +func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { + // Create a reply inbox to receive the keygen event response. + replyInbox := nats.NewInbox() + + // Use a synchronous subscription for the reply inbox. 
+ replySub, err := sc.natsConn.SubscribeSync(replyInbox) + if err != nil { + logger.Error("KeygenConsumer: Failed to subscribe to reply inbox", err) + _ = msg.Term() + return + } + defer replySub.Unsubscribe() + + // Publish the keygen event with the reply inbox. + if err := sc.pubsub.PublishWithReply(MPCGenerateEvent, replyInbox, msg.Data()); err != nil { + logger.Error("KeygenConsumer: Failed to publish keygen event with reply", err) + _ = msg.Term() + return + } + + deadline := time.Now().Add(keygenResponseTimeout) + for time.Now().Before(deadline) { + replyMsg, err := replySub.NextMsg(keygenPollingInterval) + if err != nil { + if err == nats.ErrTimeout { + continue + } + logger.Error("KeygenConsumer: Error receiving reply message", err) + break + } + if replyMsg != nil { + logger.Info("KeygenConsumer: Completed keygen event reply received") + if err := msg.Ack(); err != nil && !messaging.IsAlreadyAcknowledged(err) { + logger.Error("KeygenConsumer: ACK failed", err) + } + return + } + } + + // Timeout + logger.Warn("KeygenConsumer: Timeout waiting for keygen event response") + if err := msg.Term(); err != nil { + logger.Error("KeygenConsumer: Failed to terminate message", err) + } +} + +// Close unsubscribes from the JetStream subject and cleans up resources. +func (sc *keygenConsumer) Close() error { + if sc.jsSub != nil { + if err := sc.jsSub.Unsubscribe(); err != nil { + logger.Error("KeygenConsumer: Failed to unsubscribe from JetStream", err) + return err + } + logger.Info("KeygenConsumer: Unsubscribed from JetStream") + } + return nil +} diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 863a696..0ee5800 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -8,6 +8,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" "syscall" @@ -264,5 +265,5 @@ func (s *fileStore) VerifyInitiatorMessage(msg types.InitiatorMessage) error { } func partyIDToNodeID(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) + return strings.Split(partyID.Moniker, ":")[0] } diff --git a/pkg/keyinfo/keyinfo.go b/pkg/keyinfo/keyinfo.go index 6952e7d..49d4c7f 100644 --- a/pkg/keyinfo/keyinfo.go +++ b/pkg/keyinfo/keyinfo.go @@ -11,6 +11,7 @@ import ( type KeyInfo struct { ParticipantPeerIDs []string `json:"participant_peer_ids"` Threshold int `json:"threshold"` + Version int `json:"version"` } type store struct { diff --git a/pkg/messaging/message_queue.go b/pkg/messaging/message_queue.go index d5a1c7f..d1a5f95 100644 --- a/pkg/messaging/message_queue.go +++ b/pkg/messaging/message_queue.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "strings" + "time" "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" @@ -17,6 +19,7 @@ var ( type MessageQueue interface { Enqueue(topic string, message []byte, options *EnqueueOptions) error Dequeue(topic string, handler func(message []byte) error) error + Fetch(batch int, handler func(msg jetstream.Msg) error) error Close() } @@ -31,6 +34,13 @@ type msgQueue struct { consumerContext jetstream.ConsumeContext } +type msgPull struct { + consumerName string + js jetstream.JetStream + consumer jetstream.Consumer + fetchMaxWait time.Duration +} + type NATsMessageQueueManager struct { queueName string js jetstream.JetStream @@ -57,7 +67,7 @@ func NewNATsMessageQueueManager(queueName string, subjectWildCards []string, nc Name: queueName, Description: "Stream for " + queueName, Subjects: subjectWildCards, - MaxBytes: 1024, + MaxBytes: 100_000_000, // Light Production (Low Traffic) 100_000_000 (100 MB) 
Storage: jetstream.FileStorage, Retention: jetstream.WorkQueuePolicy, }) @@ -81,7 +91,10 @@ func (m *NATsMessageQueueManager) NewMessageQueue(consumerName string) MessageQu cfg := jetstream.ConsumerConfig{ Name: consumerName, Durable: consumerName, - MaxAckPending: 4, + MaxAckPending: 1000, + // If a message isn't acked within AckWait, it will be redelivered up to MaxDelive + AckWait: 60 * time.Second, + AckPolicy: jetstream.AckExplicitPolicy, FilterSubjects: []string{ consumerWildCard, }, @@ -151,6 +164,70 @@ func (mq *msgQueue) Close() { } } -func (n *msgQueue) handleReconnect(nc *nats.Conn) { - logger.Info("NATS: Reconnected to NATS") +func (m *NATsMessageQueueManager) NewMessagePullSubscriber(consumerName string) MessageQueue { + mq := &msgQueue{ + consumerName: consumerName, + js: m.js, + } + consumerWildCard := fmt.Sprintf("%s.%s.*", m.queueName, consumerName) + cfg := jetstream.ConsumerConfig{ + Name: consumerName, + Durable: consumerName, + MaxAckPending: 1000, + // If a message isn't acked within AckWait, it will be redelivered up to MaxDelive + AckWait: 180 * time.Second, + MaxWaiting: 1000, + AckPolicy: jetstream.AckExplicitPolicy, + FilterSubjects: []string{ + consumerWildCard, + }, + MaxDeliver: 1, + MaxRequestBatch: 10, + } + + logger.Info("Creating pull consumer for subject", "config", cfg) + consumer, err := m.js.CreateOrUpdateConsumer(context.Background(), m.queueName, cfg) + if err != nil { + logger.Fatal("Error creating JetStream consumer: ", err) + } + + mq.consumer = consumer + return mq +} + +func (mq *msgQueue) Fetch(batch int, handler func(msg jetstream.Msg) error) error { + // Reduced fetch timeout from 2 minutes to 30 seconds for faster processing + msgs, err := mq.consumer.Fetch(batch, jetstream.FetchMaxWait(30*time.Second)) + if err != nil { + return fmt.Errorf("error fetching messages: %w", err) + } + + for msg := range msgs.Messages() { + meta, _ := msg.Metadata() + logger.Debug("Received message", "meta", meta) // Changed to Debug to reduce log noise + err := handler(msg) + if err != nil { + if errors.Is(err, ErrPermament) { + logger.Info("Permanent error on message", "subject", msg.Subject) + msg.Term() + continue + } + + logger.Error("Error handling message: ", err) + msg.Nak() + continue + } + + err = msg.Ack() + if err != nil { + if !IsAlreadyAcknowledged(err) { + logger.Error("Error acknowledging message:", err) + } + } + } + return nil +} + +func IsAlreadyAcknowledged(err error) bool { + return err != nil && strings.Contains(err.Error(), nats.ErrMsgAlreadyAckd.Error()) } diff --git a/pkg/messaging/point2point.go b/pkg/messaging/point2point.go index a9af8f7..143b99c 100644 --- a/pkg/messaging/point2point.go +++ b/pkg/messaging/point2point.go @@ -37,7 +37,7 @@ func (d *natsDirectMessaging) Send(id string, message []byte) error { retry.Delay(50*time.Millisecond), retry.DelayType(retry.FixedDelay), retry.OnRetry(func(n uint, err error) { - logger.Error("Failed to send direct message message", err, "retryCount", retryCount) + logger.Error("Failed to send direct message message", err, "retryCount", retryCount, "target", id) }), ) diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 9860e02..4e64a13 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -17,7 +17,7 @@ type Subscription interface { type PubSub interface { Publish(topic string, message []byte) error - PublishWithReply(ttopic, reply string, data []byte) error + PublishWithReply(topic, reply string, data []byte) error Subscribe(topic string, handler func(msg 
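A sketch, under assumptions, of how a pull subscriber returned by NewMessagePullSubscriber could be drained with Fetch; the manager, queue name, batch size, and handler body are placeholders and not part of the patch.

package example

import (
	"context"
	"time"

	"github.com/fystack/mpcium/pkg/logger"
	"github.com/fystack/mpcium/pkg/messaging"
	"github.com/nats-io/nats.go/jetstream"
)

// drainKeygenRequests repeatedly fetches small batches from the pull consumer.
// Each handler result maps onto the Ack/Nak/Term logic implemented in Fetch above.
func drainKeygenRequests(ctx context.Context, m *messaging.NATsMessageQueueManager) {
	queue := m.NewMessagePullSubscriber("mpc_keygen_request")
	defer queue.Close()

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		err := queue.Fetch(10, func(msg jetstream.Msg) error {
			logger.Info("keygen request received", "subject", msg.Subject())
			return nil // returning messaging.ErrPermament would Term() the message instead
		})
		if err != nil {
			logger.Error("Fetch failed", err)
			time.Sleep(time.Second) // back off briefly before the next poll
		}
	}
}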
*nats.Msg)) (Subscription, error) } diff --git a/pkg/monitoring/recorder.go b/pkg/monitoring/recorder.go new file mode 100644 index 0000000..4e6786f --- /dev/null +++ b/pkg/monitoring/recorder.go @@ -0,0 +1,77 @@ +package monitoring + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + "github.com/fystack/mpcium/pkg/logger" +) + +// KeygenTimestamps holds the structured data for a single key generation event. +type KeygenTimestamps struct { + WalletID string `json:"wallet_id"` + NodeID string `json:"node_id"` + KeyType string `json:"key_type"` + StartTime time.Time `json:"start_time"` + CompletionTime time.Time `json:"completion_time"` + InitDurationMs int64 `json:"init_duration_ms"` + RunDurationMs int64 `json:"run_duration_ms"` +} + +var ( + logFile *os.File + logOnce sync.Once + logMux sync.Mutex +) + +// initLogFile initializes the log file for appending. It ensures this only happens once. +func initLogFile() { + logOnce.Do(func() { + logDir := "monitoring" + if err := os.MkdirAll(logDir, 0755); err != nil { + logger.Error("Failed to create monitoring directory", err) + return + } + + var err error + logFile, err = os.OpenFile(filepath.Join(logDir, "keygen_times.jsonl"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + logger.Error("Failed to open keygen log file", err) + } + }) +} + +// RecordKeygenCompletion marshals the timestamp data to JSON and writes it to the log file. +func RecordKeygenCompletion(data KeygenTimestamps) { + initLogFile() + if logFile == nil { + return // Initialization failed + } + + logMux.Lock() + defer logMux.Unlock() + + line, err := json.Marshal(data) + if err != nil { + logger.Error("Failed to marshal keygen timestamp data", err) + return + } + + if _, err := logFile.Write(append(line, '\n')); err != nil { + logger.Error("Failed to write to keygen log file", err) + } +} + +// Close ensures the log file is synced and closed gracefully. 
+func Close() { + logMux.Lock() + defer logMux.Unlock() + + if logFile != nil { + logFile.Sync() + logFile.Close() + } +} diff --git a/pkg/mpc/ecdsa_keygen_session.go b/pkg/mpc/ecdsa_keygen_session.go deleted file mode 100644 index 98dee70..0000000 --- a/pkg/mpc/ecdsa_keygen_session.go +++ /dev/null @@ -1,151 +0,0 @@ -package mpc - -import ( - "crypto/ecdsa" - "encoding/json" - "fmt" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/encoding" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" -) - -const ( - TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" -) - -type KeygenSession struct { - Session - endCh chan *keygen.LocalPartySaveData -} - -type KeygenSuccessEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` -} - -func NewKeygenSession( - walletID string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - preParams *keygen.LocalPreParams, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *KeygenSession { - return &KeygenSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - preParams: preParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:ecdsa:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:ecdsa:%s:%s", nodeID, walletID) - }, - }, - composeKey: func(walletID string) string { - return fmt.Sprintf("ecdsa:%s", walletID) - }, - getRoundFunc: GetEcdsaMsgRound, - resultQueue: resultQueue, - sessionType: SessionTypeEcdsa, - identityStore: identityStore, - }, - endCh: make(chan *keygen.LocalPartySaveData), - } -} - -func (s *KeygenSession) Init() { - logger.Infof("Initializing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.S256(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - s.party = keygen.NewLocalParty(params, s.outCh, s.endCh, *s.preParams) - logger.Infof("[INITIALIZED] Initialized session successfully partyID: %s, peerIDs %s, walletID %s, threshold = %d", s.selfPartyID, s.partyIDs, s.walletID, s.threshold) -} - -func (s *KeygenSession) GenerateKey(done func()) { - logger.Info("Starting to generate key", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case saveData := <-s.endCh: - keyBytes, err := json.Marshal(saveData) - if err != nil { - s.ErrCh <- err - return - } - - err = s.kvstore.Put(s.composeKey(s.walletID), keyBytes) - if err != nil { - logger.Error("Failed to save key", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - keyInfo := keyinfo.KeyInfo{ - ParticipantPeerIDs: s.participantPeerIDs, - Threshold: s.threshold, - } - - err = 
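A small reader for the JSONL file produced by RecordKeygenCompletion, offered as a sketch: it assumes the default monitoring/keygen_times.jsonl path and uses only the key_type and run_duration_ms fields defined above.

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"os"
)

// keygenRecord mirrors the subset of KeygenTimestamps fields needed here.
type keygenRecord struct {
	KeyType       string `json:"key_type"`
	RunDurationMs int64  `json:"run_duration_ms"`
}

func main() {
	f, err := os.Open("monitoring/keygen_times.jsonl")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	defer f.Close()

	sum := map[string]int64{}
	count := map[string]int64{}
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		var rec keygenRecord
		if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
			continue // skip malformed lines rather than aborting
		}
		sum[rec.KeyType] += rec.RunDurationMs
		count[rec.KeyType]++
	}
	for keyType, n := range count {
		fmt.Printf("%s: average run %d ms over %d keygens\n", keyType, sum[keyType]/n, n)
	}
}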
s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) - if err != nil { - logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - publicKey := saveData.ECDSAPub - - pubKey := &ecdsa.PublicKey{ - Curve: publicKey.Curve(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) - if err != nil { - logger.Error("failed to encode public key", err) - s.ErrCh <- fmt.Errorf("failed to encode public key: %w", err) - return - } - s.pubkeyBytes = pubKeyBytes - done() - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - // done() - return - } - } -} diff --git a/pkg/mpc/ecdsa_rounds.go b/pkg/mpc/ecdsa_rounds.go deleted file mode 100644 index 8e70f7a..0000000 --- a/pkg/mpc/ecdsa_rounds.go +++ /dev/null @@ -1,120 +0,0 @@ -package mpc - -import ( - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" -) - -const ( - KEYGEN1 = "KGRound1Message" - KEYGEN2aUnicast = "KGRound2Message1" - KEYGEN2b = "KGRound2Message2" - KEYGEN3 = "KGRound3Message" - KEYSIGN1aUnicast = "SignRound1Message1" - KEYSIGN1b = "SignRound1Message2" - KEYSIGN2Unicast = "SignRound2Message" - KEYSIGN3 = "SignRound3Message" - KEYSIGN4 = "SignRound4Message" - KEYSIGN5 = "SignRound5Message" - KEYSIGN6 = "SignRound6Message" - KEYSIGN7 = "SignRound7Message" - KEYSIGN8 = "SignRound8Message" - KEYSIGN9 = "SignRound9Message" - TSSKEYGENROUNDS = 4 - TSSKEYSIGNROUNDS = 10 -) - -func GetEcdsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) { - parsedMsg, err := tss.ParseWireMessage(msg, partyID, isBroadcast) - if err != nil { - return RoundInfo{}, err - } - switch parsedMsg.Content().(type) { - case *keygen.KGRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: KEYGEN1, - }, nil - - case *keygen.KGRound2Message1: - return RoundInfo{ - Index: 1, - RoundMsg: KEYGEN2aUnicast, - }, nil - - case *keygen.KGRound2Message2: - return RoundInfo{ - Index: 2, - RoundMsg: KEYGEN2b, - }, nil - - case *keygen.KGRound3Message: - return RoundInfo{ - Index: 3, - RoundMsg: KEYGEN3, - }, nil - - case *signing.SignRound1Message1: - return RoundInfo{ - Index: 0, - RoundMsg: KEYSIGN1aUnicast, - }, nil - - case *signing.SignRound1Message2: - return RoundInfo{ - Index: 1, - RoundMsg: KEYSIGN1b, - }, nil - - case *signing.SignRound2Message: - return RoundInfo{ - Index: 2, - RoundMsg: KEYSIGN2Unicast, - }, nil - - case *signing.SignRound3Message: - return RoundInfo{ - Index: 3, - RoundMsg: KEYSIGN3, - }, nil - - case *signing.SignRound4Message: - return RoundInfo{ - Index: 4, - RoundMsg: KEYSIGN4, - }, nil - - case *signing.SignRound5Message: - return RoundInfo{ - Index: 5, - RoundMsg: KEYSIGN5, - }, nil - - case *signing.SignRound6Message: - return RoundInfo{ - Index: 6, - RoundMsg: KEYSIGN6, - }, nil - - case *signing.SignRound7Message: - return RoundInfo{ - Index: 7, - RoundMsg: KEYSIGN7, - }, nil - case *signing.SignRound8Message: - return RoundInfo{ - Index: 8, - RoundMsg: KEYSIGN8, - }, nil - case *signing.SignRound9Message: - return RoundInfo{ - Index: 9, - RoundMsg: KEYSIGN9, - }, nil - - default: - return RoundInfo{}, errors.New("unknown round") - } -} diff --git a/pkg/mpc/ecdsa_signing_session.go b/pkg/mpc/ecdsa_signing_session.go deleted file mode 100644 index c0dcbaf..0000000 --- a/pkg/mpc/ecdsa_signing_session.go +++ /dev/null @@ -1,204 +0,0 @@ -package mpc - -import 
( - "crypto/ecdsa" - "encoding/json" - "fmt" - "math/big" - - "github.com/bnb-chain/tss-lib/v2/common" - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/event" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/samber/lo" -) - -// Ecdsa signing session -type SigningSession struct { - Session - endCh chan *common.SignatureData - data *keygen.LocalPartySaveData - tx *big.Int - txID string - networkInternalCode string -} - -type ISession interface { - ErrChan() <-chan error - ListenToIncomingMessageAsync() -} - -type ISigningSession interface { - ISession - - Init(tx *big.Int) error - Sign(onSuccess func(data []byte)) -} - -func NewSigningSession( - walletID string, - txID string, - networkInternalCode string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - preParams *keygen.LocalPreParams, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *SigningSession { - return &SigningSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - preParams: preParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("sign:ecdsa:broadcast:%s:%s", walletID, txID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:ecdsa:direct:%s:%s", nodeID, txID) - }, - }, - composeKey: func(waleltID string) string { - return fmt.Sprintf("ecdsa:%s", waleltID) - }, - getRoundFunc: GetEcdsaMsgRound, - resultQueue: resultQueue, - identityStore: identityStore, - }, - endCh: make(chan *common.SignatureData), - txID: txID, - networkInternalCode: networkInternalCode, - } -} - -func (s *SigningSession) Init(tx *big.Int) error { - logger.Infof("Initializing signing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.S256(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get wallet data from KVStore") - } - - keyInfo, err := s.keyinfoStore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get key info data") - } - - if len(s.participantPeerIDs) < keyInfo.Threshold+1 { - logger.Warn("Not enough participants to sign", "participants", s.participantPeerIDs, "expected", keyInfo.Threshold+1) - return ErrNotEnoughParticipants - } - - // check if t+1 participants are present - result := lo.Intersect(s.participantPeerIDs, keyInfo.ParticipantPeerIDs) - if len(result) < keyInfo.Threshold+1 { - return fmt.Errorf( - "Incompatible peerIDs to participate in signing. 
Current participants: %v, expected participants: %v", - s.participantPeerIDs, - keyInfo.ParticipantPeerIDs, - ) - } - - logger.Info("Have enough participants to sign", "participants", s.participantPeerIDs) - // Check if all the participants of the key are present - var data keygen.LocalPartySaveData - err = json.Unmarshal(keyData, &data) - if err != nil { - return errors.Wrap(err, "Failed to unmarshal wallet data") - } - - s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) - s.data = &data - s.tx = tx - logger.Info("Initialized sigining session successfully!") - return nil -} - -func (s *SigningSession) Sign(onSuccess func(data []byte)) { - logger.Info("Starting signing", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case sig := <-s.endCh: - publicKey := *s.data.ECDSAPub - pk := ecdsa.PublicKey{ - Curve: publicKey.Curve(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - ok := ecdsa.Verify(&pk, s.tx.Bytes(), new(big.Int).SetBytes(sig.R), new(big.Int).SetBytes(sig.S)) - if !ok { - s.ErrCh <- errors.New("Failed to verify signature") - return - } - - r := event.SigningResultEvent{ - ResultType: event.SigningResultTypeSuccess, - NetworkInternalCode: s.networkInternalCode, - WalletID: s.walletID, - TxID: s.txID, - R: sig.R, - S: sig.S, - SignatureRecovery: sig.SignatureRecovery, - } - - bytes, err := json.Marshal(r) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to marshal raw signature") - return - } - - err = s.resultQueue.Enqueue(event.SigningResultCompleteTopic, bytes, &messaging.EnqueueOptions{ - IdempotententKey: s.txID, - }) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to publish sign success message") - - return - } - - logger.Info("[SIGN] Sign successfully", "walletID", s.walletID) - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - - onSuccess(bytes) - return - } - - } -} diff --git a/pkg/mpc/eddsa_keygen_session.go b/pkg/mpc/eddsa_keygen_session.go deleted file mode 100644 index 7a1b325..0000000 --- a/pkg/mpc/eddsa_keygen_session.go +++ /dev/null @@ -1,137 +0,0 @@ -package mpc - -import ( - "encoding/json" - "fmt" - - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/decred/dcrd/dcrec/edwards/v2" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" -) - -type EDDSAKeygenSession struct { - Session - endCh chan *keygen.LocalPartySaveData -} - -type EDDSAKeygenSuccessEvent struct { - WalletID string `json:"wallet_id"` - PubKey []byte `json:"pub_key"` -} - -func NewEDDSAKeygenSession( - walletID string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *EDDSAKeygenSession { - return &EDDSAKeygenSession{Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - 
ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:eddsa:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:eddsa:%s:%s", nodeID, walletID) - }, - }, - composeKey: func(waleltID string) string { - return fmt.Sprintf("eddsa:%s", waleltID) - }, - getRoundFunc: GetEddsaMsgRound, - resultQueue: resultQueue, - sessionType: SessionTypeEddsa, - identityStore: identityStore, - }, - endCh: make(chan *keygen.LocalPartySaveData), - } -} - -func (s *EDDSAKeygenSession) Init() { - logger.Infof("Initializing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.Edwards(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - s.party = keygen.NewLocalParty(params, s.outCh, s.endCh) - logger.Infof("[INITIALIZED] Initialized session successfully partyID: %s, peerIDs %s, walletID %s, threshold = %d", s.selfPartyID, s.partyIDs, s.walletID, s.threshold) -} - -func (s *EDDSAKeygenSession) GenerateKey(done func()) { - logger.Info("Starting to generate key", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case saveData := <-s.endCh: - keyBytes, err := json.Marshal(saveData) - if err != nil { - s.ErrCh <- err - return - } - - err = s.kvstore.Put(s.composeKey(s.walletID), keyBytes) - if err != nil { - logger.Error("Failed to save key", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - keyInfo := keyinfo.KeyInfo{ - ParticipantPeerIDs: s.participantPeerIDs, - Threshold: s.threshold, - } - - err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) - if err != nil { - logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - publicKey := saveData.EDDSAPub - pkX, pkY := publicKey.X(), publicKey.Y() - pk := edwards.PublicKey{ - Curve: tss.Edwards(), - X: pkX, - Y: pkY, - } - - pubKeyBytes := pk.SerializeCompressed() - s.pubkeyBytes = pubKeyBytes - - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - done() - return - } - } -} diff --git a/pkg/mpc/eddsa_rounds.go b/pkg/mpc/eddsa_rounds.go deleted file mode 100644 index 01519d0..0000000 --- a/pkg/mpc/eddsa_rounds.go +++ /dev/null @@ -1,74 +0,0 @@ -package mpc - -import ( - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/eddsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" -) - -type GetRoundFunc func(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) - -type RoundInfo struct { - Index int - RoundMsg string - MsgIdentifier string -} - -const ( - EDDSA_KEYGEN1 = "KGRound1Message" - EDDSA_KEYGEN2aUnicast = "KGRound2Message1" - EDDSA_KEYGEN2b = "KGRound2Message2" - EDDSA_KEYSIGN1 = "SignRound1Message" - EDDSA_KEYSIGN2 = "SignRound2Message" - EDDSA_KEYSIGN3 = "SignRound3Message" - EDDSA_TSSKEYGENROUNDS = 3 - EDDSA_TSSKEYSIGNROUNDS = 3 -) - -func GetEddsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) { - parsedMsg, err := tss.ParseWireMessage(msg, partyID, isBroadcast) - if err != nil { - return RoundInfo{}, err - } - switch parsedMsg.Content().(type) { - case *keygen.KGRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYGEN1, - }, nil - - case *keygen.KGRound2Message1: - return RoundInfo{ - Index: 1, - RoundMsg: EDDSA_KEYGEN2aUnicast, - 
}, nil - - case *keygen.KGRound2Message2: - return RoundInfo{ - Index: 2, - RoundMsg: EDDSA_KEYGEN2b, - }, nil - - case *signing.SignRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYSIGN1, - }, nil - - case *signing.SignRound2Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYSIGN2, - }, nil - - case *signing.SignRound3Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYSIGN3, - }, nil - - default: - return RoundInfo{}, errors.New("unknown round") - } -} diff --git a/pkg/mpc/eddsa_signing_session.go b/pkg/mpc/eddsa_signing_session.go deleted file mode 100644 index c421839..0000000 --- a/pkg/mpc/eddsa_signing_session.go +++ /dev/null @@ -1,188 +0,0 @@ -package mpc - -import ( - "encoding/json" - "fmt" - "math/big" - - "github.com/bnb-chain/tss-lib/v2/common" - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/eddsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/event" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/decred/dcrd/dcrec/edwards/v2" - "github.com/samber/lo" -) - -type EDDSASigningSession struct { - Session - endCh chan *common.SignatureData - data *keygen.LocalPartySaveData - tx *big.Int - txID string - networkInternalCode string -} - -func NewEDDSASigningSession( - walletID string, - txID string, - networkInternalCode string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *EDDSASigningSession { - return &EDDSASigningSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - // preParams: preParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("sign:eddsa:broadcast:%s:%s", walletID, txID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:eddsa:direct:%s:%s", nodeID, txID) - }, - }, - composeKey: func(waleltID string) string { - return fmt.Sprintf("eddsa:%s", waleltID) - }, - getRoundFunc: GetEddsaMsgRound, - resultQueue: resultQueue, - identityStore: identityStore, - }, - endCh: make(chan *common.SignatureData), - txID: txID, - networkInternalCode: networkInternalCode, - } -} - -func (s *EDDSASigningSession) Init(tx *big.Int) error { - logger.Infof("Initializing signing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.Edwards(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get wallet data from KVStore") - } - - keyInfo, err := s.keyinfoStore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get key info data") - } - - if len(s.participantPeerIDs) < keyInfo.Threshold+1 { - logger.Warn("Not enough participants to sign, expected %d, got %d", 
keyInfo.Threshold+1, len(s.participantPeerIDs)) - return ErrNotEnoughParticipants - } - - // check if t+1 participants are present - result := lo.Intersect(s.participantPeerIDs, keyInfo.ParticipantPeerIDs) - if len(result) < keyInfo.Threshold+1 { - return fmt.Errorf( - "Incompatible peerIDs to participate in signing. Current participants: %v, expected participants: %v", - s.participantPeerIDs, - keyInfo.ParticipantPeerIDs, - ) - } - - logger.Info("Have enough participants to sign", "participants", s.participantPeerIDs) - // Check if all the participants of the key are present - var data keygen.LocalPartySaveData - err = json.Unmarshal(keyData, &data) - if err != nil { - return errors.Wrap(err, "Failed to unmarshal wallet data") - } - - s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) - s.data = &data - s.tx = tx - logger.Info("Initialized sigining session successfully!") - return nil -} - -func (s *EDDSASigningSession) Sign(onSuccess func(data []byte)) { - logger.Info("Starting signing", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case sig := <-s.endCh: - publicKey := *s.data.EDDSAPub - pk := edwards.PublicKey{ - Curve: tss.Edwards(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - ok := edwards.Verify(&pk, s.tx.Bytes(), new(big.Int).SetBytes(sig.R), new(big.Int).SetBytes(sig.S)) - if !ok { - s.ErrCh <- errors.New("Failed to verify signature") - return - } - - r := event.SigningResultEvent{ - ResultType: event.SigningResultTypeSuccess, - NetworkInternalCode: s.networkInternalCode, - WalletID: s.walletID, - TxID: s.txID, - Signature: sig.Signature, - } - - bytes, err := json.Marshal(r) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to marshal raw signature") - return - } - - err = s.resultQueue.Enqueue(event.SigningResultCompleteTopic, bytes, &messaging.EnqueueOptions{ - IdempotententKey: s.txID, - }) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to publish sign success message") - return - } - - logger.Info("[SIGN] Sign successfully", "walletID", s.walletID) - - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - - onSuccess(bytes) - return - } - - } -} diff --git a/pkg/mpc/key_type.go b/pkg/mpc/key_type.go deleted file mode 100644 index 756efa8..0000000 --- a/pkg/mpc/key_type.go +++ /dev/null @@ -1,8 +0,0 @@ -package mpc - -type KeyType string - -const ( - KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" -) diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go deleted file mode 100644 index 6c105f4..0000000 --- a/pkg/mpc/node.go +++ /dev/null @@ -1,219 +0,0 @@ -package mpc - -import ( - "bytes" - "fmt" - "math/big" - "time" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/google/uuid" -) - -const ( - PurposeKeygen string = "keygen" - PurposeSign string = "sign" -) - -type ID string - -type Node struct { - nodeID string - peerIDs []string - - pubSub messaging.PubSub - direct messaging.DirectMessaging - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - ecdsaPreParams *keygen.LocalPreParams - identityStore identity.Store - - peerRegistry PeerRegistry -} - -func CreatePartyID(nodeID string, label string) 
*tss.PartyID { - partyID := uuid.NewString() - key := big.NewInt(0).SetBytes([]byte(nodeID)) - return tss.NewPartyID(partyID, label, key) -} - -func PartyIDToNodeID(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) -} - -func ComparePartyIDs(x, y *tss.PartyID) bool { - return bytes.Equal(x.KeyInt().Bytes(), y.KeyInt().Bytes()) -} - -func ComposeReadyKey(nodeID string) string { - return fmt.Sprintf("ready/%s", nodeID) -} - -func NewNode( - nodeID string, - peerIDs []string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - peerRegistry PeerRegistry, - identityStore identity.Store, -) *Node { - preParams, err := keygen.GeneratePreParams(5 * time.Minute) - if err != nil { - logger.Fatal("Generate pre params failed", err) - } - logger.Info("Starting new node, preparams is generated successfully!") - - go peerRegistry.WatchPeersReady() - - return &Node{ - nodeID: nodeID, - peerIDs: peerIDs, - pubSub: pubSub, - direct: direct, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - ecdsaPreParams: preParams, - peerRegistry: peerRegistry, - identityStore: identityStore, - } -} - -func (p *Node) ID() string { - return p.nodeID -} - -func composeReadyTopic(nodeID string) string { - return fmt.Sprintf("%s-%s", nodeID, "ready") -} - -func (p *Node) CreateKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (*KeygenSession, error) { - if p.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { - return nil, fmt.Errorf("Not enough peers to create gen session! Expected %d, got %d", threshold+1, p.peerRegistry.GetReadyPeersCount()) - } - - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - session := NewKeygenSession( - walletID, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.ecdsaPreParams, - p.kvstore, - p.keyinfoStore, - successQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateEDDSAKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (*EDDSAKeygenSession, error) { - if p.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { - return nil, fmt.Errorf("Not enough peers to create gen session! 
Expected %d, got %d", threshold+1, p.peerRegistry.GetReadyPeersCount()) - } - - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - session := NewEDDSAKeygenSession( - walletID, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.kvstore, - p.keyinfoStore, - successQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateSigningSession( - walletID string, - txID string, - networkInternalCode string, - threshold int, - resultQueue messaging.MessageQueue, -) (*SigningSession, error) { - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - session := NewSigningSession( - walletID, - txID, - networkInternalCode, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.ecdsaPreParams, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateEDDSASigningSession( - walletID string, - txID string, - networkInternalCode string, - threshold int, - resultQueue messaging.MessageQueue, -) (*EDDSASigningSession, error) { - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - session := NewEDDSASigningSession( - walletID, - txID, - networkInternalCode, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { - var selfPartyID *tss.PartyID - partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) - for i, peerID := range readyPeerIDs { - if peerID == p.nodeID { - selfPartyID = CreatePartyID(peerID, purpose) - partyIDs[i] = selfPartyID - } else { - partyIDs[i] = CreatePartyID(peerID, purpose) - } - } - allPartyIDs := tss.SortPartyIDs(partyIDs, 0) - return selfPartyID, allPartyIDs -} - -func (p *Node) Close() { - err := p.peerRegistry.Resign() - if err != nil { - logger.Error("Resign failed", err) - } -} diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go new file mode 100644 index 0000000..6917a7f --- /dev/null +++ b/pkg/mpc/node/node.go @@ -0,0 +1,355 @@ +package node + +import ( + "encoding/json" + "fmt" + "math/big" + "strconv" + "time" + + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/session" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" +) + +// DefaultVersion is the default version for keygen and resharing +const DefaultVersion = 1 + +type Node struct { + nodeID string + peerIDs []string + + pubSub messaging.PubSub + direct messaging.DirectMessaging + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + identityStore identity.Store + + peerRegistry *registry + consulKV infra.ConsulKV +} + +func NewNode( + nodeID string, + peerIDs []string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + identityStore identity.Store, + peerRegistry *registry, + consulKV 
infra.ConsulKV, +) *Node { + go peerRegistry.WatchPeersReady() + + return &Node{ + nodeID: nodeID, + peerIDs: peerIDs, + pubSub: pubSub, + direct: direct, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + identityStore: identityStore, + peerRegistry: peerRegistry, + consulKV: consulKV, + } +} + +func (n *Node) ID() string { + return n.nodeID +} + +func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { + if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { + return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) + } + + readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() + selfPartyID, allPartyIDs := n.generatePartyIDs(session.PurposeKeygen, readyPeerIDs, DefaultVersion) + switch keyType { + case types.KeyTypeSecp256k1: + preparams, err := n.getECDSAPreParams(false) + if err != nil { + return nil, fmt.Errorf("failed to get preparams: %w", err) + } + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + *preparams, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) + + return ecdsaSession, nil + case types.KeyTypeEd25519: + eddsaSession := session.NewEDDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) + return eddsaSession, nil + default: + return nil, fmt.Errorf("invalid key type: %s", keyType) + } +} + +func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID string, partyVersion int, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { + if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { + return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) + } + + readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() + selfPartyID, allPartyIDs := n.generatePartyIDs(session.PurposeSign, readyPeerIDs, partyVersion) + switch keyType { + case types.KeyTypeSecp256k1: + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + keygen.LocalPreParams{}, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) + saveData, err := ecdsaSession.GetSaveData(partyVersion) + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + + ecdsaSession.SetSaveData(saveData) + + return ecdsaSession, nil + case types.KeyTypeEd25519: + eddsaSession := session.NewEDDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) + + saveData, err := eddsaSession.GetSaveData(partyVersion) + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + + eddsaSession.SetSaveData(saveData) + + return eddsaSession, nil + default: + return nil, fmt.Errorf("invalid key type: %s", keyType) + } +} + +func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, walletID string, threshold int, partyVersion int, successQueue messaging.MessageQueue) (session.Session, error) { + if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { + return nil, fmt.Errorf("not enough peers to create resharing session! 
expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) + } + readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() + var selfPartyID *tss.PartyID + var partyIDs []*tss.PartyID + if isOldParty { + selfPartyID, partyIDs = n.generatePartyIDs(session.PurposeKeygen, readyPeerIDs, partyVersion) + } else { + selfPartyID, partyIDs = n.generatePartyIDs(session.PurposeReshare, readyPeerIDs, partyVersion+1) // Increment version for new parties + } + + switch keyType { + case types.KeyTypeSecp256k1: + preparams, err := n.getECDSAPreParams(isOldParty) + if err != nil { + return nil, fmt.Errorf("failed to get preparams: %w", err) + } + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + partyIDs, + threshold, + *preparams, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) + if isOldParty { + saveData, err := ecdsaSession.GetSaveData(partyVersion) + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + ecdsaSession.SetSaveData(saveData) + } else { + // Initialize new save data for new parties + // Reduce the loading time by pre-allocating the save data + saveData := keygen.NewLocalPartySaveData(len(partyIDs)) + saveData.LocalPreParams = *preparams + saveDataBytes, err := json.Marshal(saveData) + if err != nil { + return nil, fmt.Errorf("failed to marshal save data: %w", err) + } + ecdsaSession.SetSaveData(saveDataBytes) + } + return ecdsaSession, nil + case types.KeyTypeEd25519: + eddsaSession := session.NewEDDSASession( + walletID, + selfPartyID, + partyIDs, + threshold, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) + saveData, err := eddsaSession.GetSaveData(partyVersion) + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + eddsaSession.SetSaveData(saveData) + return eddsaSession, nil + default: + return nil, fmt.Errorf("invalid key type: %s", keyType) + } +} + +func (p *Node) Close() { + err := p.peerRegistry.Resign() + if err != nil { + logger.Error("Resign failed", err) + } +} + +func (n *Node) GetReadyPeersIncludeSelf() []string { + return n.peerRegistry.GetReadyPeersIncludeSelf() +} + +func (n *Node) GetKeyInfoVersion(keyType types.KeyType, walletID string) (int, error) { + var walletKey string + switch keyType { + case types.KeyTypeSecp256k1: + walletKey = fmt.Sprintf("ecdsa:%s", walletID) + case types.KeyTypeEd25519: + walletKey = fmt.Sprintf("eddsa:%s", walletID) + default: + return 0, fmt.Errorf("invalid key type: %s", keyType) + } + keyInfo, err := n.keyinfoStore.Get(walletKey) + if err != nil { + return 0, err + } + return int(keyInfo.Version), nil +} + +// PreloadPreParams preloads the preparams for the first time +func (n *Node) PreloadPreParams() { + _, err := n.getECDSAPreParams(false) + if err != nil { + logger.Error("Failed to get preparams", err) + } + _, err = n.getECDSAPreParams(true) + if err != nil { + logger.Error("Failed to get preparams", err) + } +} + +// For ecdsa, we need to generate preparams for each party +// Load preparams from kvstore if exists, otherwise generate and save to kvstore +func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { + var path string + if isOldParty { + path = fmt.Sprintf("preparams.old.%s", n.nodeID) + } else { + path = fmt.Sprintf("preparams.%s", n.nodeID) + } + + preparamsBytes, _ := n.kvstore.Get(path) + if preparamsBytes == nil { + logger.Info("Generating preparams", "isOldParty", isOldParty) + preparams, err := 
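A sketch, under stated assumptions, of how a caller could pair GetKeyInfoVersion with CreateSigningSession so the regenerated party IDs match the key-share version on disk; the wallet ID, threshold, and result queue are placeholders.

package example

import (
	"fmt"

	"github.com/fystack/mpcium/pkg/messaging"
	"github.com/fystack/mpcium/pkg/mpc/node"
	"github.com/fystack/mpcium/pkg/mpc/session"
	"github.com/fystack/mpcium/pkg/types"
)

// newVersionedSigningSession looks up the stored key version and passes it to
// CreateSigningSession as partyVersion, so the party keys line up with the
// shares written at keygen or resharing time.
func newVersionedSigningSession(
	n *node.Node,
	walletID, txID string,
	threshold int,
	resultQueue messaging.MessageQueue,
) (session.Session, error) {
	version, err := n.GetKeyInfoVersion(types.KeyTypeSecp256k1, walletID)
	if err != nil {
		return nil, fmt.Errorf("look up key version: %w", err)
	}
	return n.CreateSigningSession(types.KeyTypeSecp256k1, walletID, txID, version, threshold, resultQueue)
}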
keygen.GeneratePreParams(5 * time.Minute) + if err != nil { + return nil, err + } + preparamsBytes, err = json.Marshal(preparams) + if err != nil { + return nil, err + } + n.kvstore.Put(path, preparamsBytes) + return preparams, nil + } + + var preparams keygen.LocalPreParams + if err := json.Unmarshal(preparamsBytes, &preparams); err != nil { + return nil, err + } + return &preparams, nil +} + +// generatePartyIDs generates the party IDs for the given purpose and version +// It returns the self party ID and all party IDs +// It also sorts the party IDs in place +func (n *Node) generatePartyIDs(purpose session.Purpose, readyPeerIDs []string, version int) (self *tss.PartyID, all []*tss.PartyID) { + // Pre-allocate slice with exact size needed + partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) + + // Create all party IDs in one pass + for _, peerID := range readyPeerIDs { + partyID := createPartyID(peerID, string(purpose), version) + if peerID == n.nodeID { + self = partyID + } + partyIDs = append(partyIDs, partyID) + } + + // Sort party IDs in place + all = tss.SortPartyIDs(partyIDs, 0) + return +} + +// createPartyID creates a new party ID for the given node ID, label and version +// It returns the party ID: random string +// Moniker: for routing messages +// Key: for mpc internal use (need persistent storage) +func createPartyID(nodeID string, label string, version int) *tss.PartyID { + partyID := uuid.NewString() + moniker := nodeID + ":" + label + var key *big.Int + if version == 0 { + key = big.NewInt(0).SetBytes([]byte(nodeID)) + } else { + key = big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) + } + return tss.NewPartyID(partyID, moniker, key) +} diff --git a/pkg/mpc/registry.go b/pkg/mpc/node/registry.go similarity index 99% rename from pkg/mpc/registry.go rename to pkg/mpc/node/registry.go index dba9b55..98c65fa 100644 --- a/pkg/mpc/registry.go +++ b/pkg/mpc/node/registry.go @@ -1,4 +1,4 @@ -package mpc +package node import ( "fmt" diff --git a/pkg/mpc/node_test.go b/pkg/mpc/node_test.go deleted file mode 100644 index dafacda..0000000 --- a/pkg/mpc/node_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package mpc - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -// func TestCreateKeyGenSession(t *testing.T) { -// nodeID := uuid.NewString() - -// peerIDs := []string{ -// nodeID, -// uuid.NewString(), -// uuid.NewString(), -// } -// ctrl := gomock.NewController(t) -// defer ctrl.Finish() -// pubsub := mock.NewMockPubSub(ctrl) -// direct := mock.NewMockDirectMessaging(ctrl) - -// node := NewNode(nodeID, peerIDs, pubsub, direct) - -// session, err := node.CreateKeyGenSession() - -// assert.NoError(t, err) -// assert.Len(t, session.PartyIDs(), 3, "Length of partyIDs should be equal") -// assert.NotNil(t, session.PartyID()) - -// for i, partyID := range session.PartyIDs() { -// // check sortedID -// assert.Equal(t, partyID.Index, i, "Index should be equal") -// } - -// } - -func TestPartyIDToNodeID(t *testing.T) { - partyID := CreatePartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen") - nodeID := PartyIDToNodeID(partyID) - assert.Equal(t, nodeID, "4d8cb873-dc86-4776-b6f6-cf5c668f6468", "NodeID should be equal") -} diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go new file mode 100644 index 0000000..ccb5094 --- /dev/null +++ b/pkg/mpc/party/base.go @@ -0,0 +1,159 @@ +package party + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + "sync" + + "github.com/bnb-chain/tss-lib/v2/tss" + 
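A sketch showing how the versioned party IDs built by createPartyID round-trip back to a node ID via the moniker, which is what the updated partyIDToNodeID in pkg/identity relies on; the node ID, label, and version values are examples only, and only the versioned (non-zero) key branch is shown.

package main

import (
	"fmt"
	"math/big"
	"strconv"
	"strings"

	"github.com/bnb-chain/tss-lib/v2/tss"
	"github.com/google/uuid"
)

// examplePartyID mirrors the versioned branch of createPartyID above: the
// moniker carries "nodeID:label" for routing, while the key carries
// "nodeID:version" so shares generated under different versions never collide.
func examplePartyID(nodeID, label string, version int) *tss.PartyID {
	key := new(big.Int).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version)))
	return tss.NewPartyID(uuid.NewString(), nodeID+":"+label, key)
}

func main() {
	p := examplePartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen", 2)
	nodeID := strings.Split(p.Moniker, ":")[0] // same lookup as partyIDToNodeID
	fmt.Println(nodeID)                        // prints the original node ID
}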
"github.com/fystack/mpcium/pkg/types" +) + +type Party interface { + StartKeygen(ctx context.Context, send func(tss.Message), onComplete func([]byte)) + StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), onComplete func([]byte)) + StartResharing( + ctx context.Context, + oldPartyIDs, + newPartyIDs []*tss.PartyID, + oldThreshold, + newThreshold int, + send func(tss.Message), + onComplete func([]byte), + ) + + WalletID() string + PartyID() *tss.PartyID + PartyIDs() []*tss.PartyID + GetSaveData() []byte + SetSaveData(saveData []byte) + ClassifyMsg(msgBytes []byte) (uint8, bool, error) + InCh() chan types.TssMessage + OutCh() chan tss.Message + ErrCh() chan error +} + +type party struct { + walletID string + threshold int + partyID *tss.PartyID + partyIDs []*tss.PartyID + inCh chan types.TssMessage + outCh chan tss.Message + errCh chan error + + ctx context.Context + cancel context.CancelFunc + closeOnce sync.Once +} + +func NewParty( + walletID string, + partyID *tss.PartyID, + partyIDs []*tss.PartyID, + threshold int, + errCh chan error, +) *party { + return &party{ + walletID: walletID, + threshold: threshold, + partyID: partyID, + partyIDs: partyIDs, + inCh: make(chan types.TssMessage, 1000), + outCh: make(chan tss.Message, 1000), + errCh: errCh, + } +} + +func (p *party) WalletID() string { + return p.walletID +} + +func (p *party) PartyID() *tss.PartyID { + return p.partyID +} + +func (p *party) PartyIDs() []*tss.PartyID { + return p.partyIDs +} + +func (p *party) InCh() chan types.TssMessage { return p.inCh } +func (p *party) OutCh() chan tss.Message { return p.outCh } +func (p *party) ErrCh() chan error { return p.errCh } + +// runParty handles the common party execution loop +// startPartyLoop runs a TSS party, handling messages, errors, and completion. 
+func runParty[T any]( + s Party, + ctx context.Context, + party tss.Party, + send func(tss.Message), + endCh <-chan T, + onComplete func([]byte), +) { + // safe error reporter + safeErr := func(err error) { + select { + case s.ErrCh() <- err: + case <-ctx.Done(): + } + } + + // start the tss party logic + go func() { + defer func() { + if r := recover(); r != nil { + safeErr(fmt.Errorf("panic in party.Start: %v", r)) + } + }() + if err := party.Start(); err != nil { + safeErr(err) + } + }() + + // main handling loop + for { + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + safeErr(fmt.Errorf("party timed out: %w", ctx.Err())) + } + return + + case inMsg, ok := <-s.InCh(): + if !ok { + return + } + ok2, err := party.UpdateFromBytes(inMsg.MsgBytes, inMsg.From, inMsg.IsBroadcast) + if err != nil || !ok2 { + safeErr(errors.New("UpdateFromBytes failed")) + return + } + + case outMsg, ok := <-s.OutCh(): + if !ok { + return + } + // respect cancellation before invoking callback + if ctx.Err() != nil { + return + } + send(outMsg) + + case result, ok := <-endCh: + if !ok { + return + } + bts, err := json.Marshal(result) + if err != nil { + safeErr(err) + return + } + onComplete(bts) + return + } + } +} diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go new file mode 100644 index 0000000..6f2881c --- /dev/null +++ b/pkg/mpc/party/ecdsa.go @@ -0,0 +1,139 @@ +package party + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + "time" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" + "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/monitoring" + "github.com/golang/protobuf/ptypes/any" + "google.golang.org/protobuf/proto" +) + +type ECDSAParty struct { + party + preParams keygen.LocalPreParams + saveData *keygen.LocalPartySaveData + KeygenStart time.Time + KeygenCompletion time.Time +} + +func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, + preParams keygen.LocalPreParams, errCh chan error) *ECDSAParty { + return &ECDSAParty{ + party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), + preParams: preParams, + } +} + +func (s *ECDSAParty) GetSaveData() []byte { + saveData, err := json.Marshal(s.saveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed serializing shares: %w", err) + return nil + } + return saveData +} + +func (s *ECDSAParty) SetSaveData(saveData []byte) { + var localSaveData keygen.LocalPartySaveData + err := json.Unmarshal(saveData, &localSaveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) + return + } + s.saveData = &localSaveData +} + +func (s *ECDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { + msg := &any.Any{} + if err := proto.Unmarshal(msgBytes, msg); err != nil { + return 0, false, err + } + + _, isBroadcast := ecdsaBroadcastMessages[msg.TypeUrl] + // logger.Info("ClassifyMsg", "typeUrl", msg.TypeUrl, "isBroadcast", isBroadcast) + + round := ecdsaMsgURL2Round[msg.TypeUrl] + if round > 4 { + round = round - 4 + } + return round, isBroadcast, nil +} + +func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + end := make(chan *keygen.LocalPartySaveData, 1) + // Time the initialization of TSS parameters and party + s.KeygenStart = time.Now() + params := 
tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := keygen.NewLocalParty(params, s.outCh, end, s.preParams) + initElapsed := time.Since(s.KeygenStart) + logger.Info("[Starting ECDSA] key generation", + "walletID", s.walletID, + "initElapsed", initElapsed.Milliseconds(), + "startTime", s.KeygenStart.Format(time.RFC3339), + ) + + // Time the runParty execution + runStart := time.Now() + runParty(s, ctx, party, send, end, finish) + s.KeygenCompletion = time.Now() + runElapsed := time.Since(runStart) + logger.Info("[Finished ECDSA] key generation run", + "walletID", s.walletID, + "runElapsed", runElapsed.Milliseconds(), + "completionTime", s.KeygenCompletion.Format(time.RFC3339), + ) + + // Record the completion event + monitoring.RecordKeygenCompletion(monitoring.KeygenTimestamps{ + WalletID: s.walletID, + NodeID: s.partyID.Id, + KeyType: "ECDSA", + StartTime: s.KeygenStart, + CompletionTime: s.KeygenCompletion, + InitDurationMs: initElapsed.Milliseconds(), + RunDurationMs: runElapsed.Milliseconds(), + }) +} + +func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + s.ErrCh() <- errors.New("save data is nil") + return + } + end := make(chan *common.SignatureData, 1) + params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := signing.NewLocalParty(msg, params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} + +func (s *ECDSAParty) StartResharing(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, + oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + s.ErrCh() <- errors.New("save data is nil") + return + } + end := make(chan *keygen.LocalPartySaveData, 1) + params := tss.NewReSharingParameters( + tss.S256(), + tss.NewPeerContext(oldPartyIDs), + tss.NewPeerContext(newPartyIDs), + s.partyID, + len(oldPartyIDs), + oldThreshold, + len(newPartyIDs), + newThreshold, + ) + party := resharing.NewLocalParty(params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} diff --git a/pkg/mpc/party/ecdsa_round.go b/pkg/mpc/party/ecdsa_round.go new file mode 100644 index 0000000..0f38c05 --- /dev/null +++ b/pkg/mpc/party/ecdsa_round.go @@ -0,0 +1,56 @@ +package party + +var ( + ecdsaMsgURL2Round = map[string]uint8{ + // DKG + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound1Message": 1, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound2Message1": 2, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound2Message2": 3, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound3Message": 4, + + // Signing + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound1Message1": 5, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound1Message2": 6, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound2Message": 7, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound3Message": 8, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound4Message": 9, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound5Message": 10, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound6Message": 11, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound7Message": 12, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound8Message": 13, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound9Message": 14, + + // Resharing + 
"type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound1Message": 15, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message1": 16, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message2": 17, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound3Message1": 18, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound3Message2": 19, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound4Message1": 20, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound4Message2": 21, + } + + ecdsaBroadcastMessages = map[string]struct{}{ + // DKG + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound1Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound2Message2": {}, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound3Message": {}, + + // Signing + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound1Message2": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound3Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound4Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound5Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound6Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound7Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound8Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound9Message": {}, + + // Resharing + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound1Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message1": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message2": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound3Message1": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound4Message2": {}, + } +) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go new file mode 100644 index 0000000..5d086b1 --- /dev/null +++ b/pkg/mpc/party/eddsa.go @@ -0,0 +1,142 @@ +package party + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + "time" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" + "github.com/bnb-chain/tss-lib/v2/eddsa/signing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/monitoring" + "github.com/golang/protobuf/ptypes/any" + "google.golang.org/protobuf/proto" +) + +type EDDSAParty struct { + party + reshareParams *tss.ReSharingParameters + saveData *keygen.LocalPartySaveData + KeygenStart time.Time + KeygenCompletion time.Time +} + +func NewEDDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, + reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData, errCh chan error) *EDDSAParty { + return &EDDSAParty{ + party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), + reshareParams: reshareParams, + saveData: saveData, + } +} + +func (s *EDDSAParty) GetSaveData() []byte { + saveData, err := json.Marshal(s.saveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed serializing shares: %w", err) + return nil + } + return saveData +} + +func (s *EDDSAParty) SetSaveData(shareData []byte) { + var localSaveData keygen.LocalPartySaveData + err := json.Unmarshal(shareData, &localSaveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) + return + } + s.saveData = &localSaveData +} + +func 
(s *EDDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { + msg := &any.Any{} + if err := proto.Unmarshal(msgBytes, msg); err != nil { + return 0, false, err + } + + _, isBroadcast := eddsaBroadcastMessages[msg.TypeUrl] + + round := eddsaMsgURL2Round[msg.TypeUrl] + if round > 4 { + round = round - 4 + } + return round, isBroadcast, nil +} + +func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + end := make(chan *keygen.LocalPartySaveData, 1) + + // Measure time to initialize the party + s.KeygenStart = time.Now() + params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := keygen.NewLocalParty(params, s.outCh, end) + initElapsed := time.Since(s.KeygenStart) + + logger.Info("[Starting EDDSA] key generation", + "walletID", s.walletID, + "initElapsed", initElapsed.Milliseconds(), + "startTime", s.KeygenStart.Format(time.RFC3339), + ) + + // Measure time to run the party + runStart := time.Now() + runParty(s, ctx, party, send, end, finish) + s.KeygenCompletion = time.Now() + runElapsed := time.Since(runStart) + + logger.Info("[Finished EDDSA] key generation run", + "walletID", s.walletID, + "runElapsed", runElapsed.Milliseconds(), + "completionTime", s.KeygenCompletion.Format(time.RFC3339), + ) + + // Record the completion event + monitoring.RecordKeygenCompletion(monitoring.KeygenTimestamps{ + WalletID: s.walletID, + NodeID: s.partyID.Id, + KeyType: "EDDSA", + StartTime: s.KeygenStart, + CompletionTime: s.KeygenCompletion, + InitDurationMs: initElapsed.Milliseconds(), + RunDurationMs: runElapsed.Milliseconds(), + }) +} + +func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + s.ErrCh() <- errors.New("save data is nil") + return + } + end := make(chan *common.SignatureData, 1) + params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := signing.NewLocalParty(msg, params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} + +func (s *EDDSAParty) StartResharing(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, + oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + s.ErrCh() <- errors.New("save data is nil") + return + } + end := make(chan *keygen.LocalPartySaveData, 1) + params := tss.NewReSharingParameters( + tss.Edwards(), + tss.NewPeerContext(oldPartyIDs), + tss.NewPeerContext(newPartyIDs), + s.partyID, + len(oldPartyIDs), + oldThreshold, + len(newPartyIDs), + newThreshold, + ) + party := resharing.NewLocalParty(params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} diff --git a/pkg/mpc/party/eddsa_round.go b/pkg/mpc/party/eddsa_round.go new file mode 100644 index 0000000..7f0c874 --- /dev/null +++ b/pkg/mpc/party/eddsa_round.go @@ -0,0 +1,38 @@ +package party + +var ( + eddsaMsgURL2Round = map[string]uint8{ + // DKG + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound1Message": 1, + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound2Message1": 2, + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound2Message2": 3, + + // Signing + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound1Message": 4, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound2Message": 5, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound3Message": 6, + + // Resharing + 
"type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound1Message": 7, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound2Message": 8, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound3Message1": 9, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound3Message2": 10, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound4Message": 11, + } + + eddsaBroadcastMessages = map[string]struct{}{ + // DKG + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound1Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound2Message2": {}, + + // Signing + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound1Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound2Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound3Message": {}, + + // Resharing + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound1Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound2Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound4Message": {}, + } +) diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go deleted file mode 100644 index f8204f6..0000000 --- a/pkg/mpc/session.go +++ /dev/null @@ -1,218 +0,0 @@ -package mpc - -import ( - "fmt" - "strings" - "sync" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/types" - "github.com/nats-io/nats.go" -) - -var ( - ErrNotEnoughParticipants = errors.New("Not enough participants to sign") -) - -type TopicComposer struct { - ComposeBroadcastTopic func() string - ComposeDirectTopic func(nodeID string) string -} - -type KeyComposerFn func(id string) string - -type SessionType string - -const ( - SessionTypeEcdsa SessionType = "session_ecdsa" - SessionTypeEddsa SessionType = "session_eddsa" -) - -type Session struct { - walletID string - pubSub messaging.PubSub - direct messaging.DirectMessaging - threshold int - participantPeerIDs []string - selfPartyID *tss.PartyID - // IDs of all parties in the session including self - partyIDs []*tss.PartyID - outCh chan tss.Message - ErrCh chan error - party tss.Party - - // preParams is nil for EDDSA session - preParams *keygen.LocalPreParams - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - broadcastSub messaging.Subscription - directSub messaging.Subscription - resultQueue messaging.MessageQueue - identityStore identity.Store - - topicComposer *TopicComposer - composeKey KeyComposerFn - getRoundFunc GetRoundFunc - mu sync.Mutex - // After the session is done, the key will be stored pubkeyBytes - pubkeyBytes []byte - sessionType SessionType -} - -func (s *Session) PartyID() *tss.PartyID { - return s.selfPartyID -} - -func (s *Session) PartyIDs() []*tss.PartyID { - return s.partyIDs -} - -func (s *Session) PartyCount() int { - return len(s.partyIDs) -} - -func (s *Session) handleTssMessage(keyshare tss.Message) { - data, routing, err := keyshare.WireBytes() - if err != nil { - s.ErrCh <- err - return - } - - tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) - signature, err := s.identityStore.SignMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to sign message: %w", err) - return - } - tssMsg.Signature 
= signature - msg, err := types.MarshalTssMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) - return - } - - if routing.IsBroadcast && len(routing.To) == 0 { - err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg) - if err != nil { - s.ErrCh <- err - return - } - } else { - for _, to := range routing.To { - nodeID := PartyIDToNodeID(to) - topic := s.topicComposer.ComposeDirectTopic(nodeID) - err := s.direct.Send(topic, msg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to send direct message to %s: %w", topic, err) - } - - } - - } -} - -func (s *Session) receiveTssMessage(rawMsg []byte) { - msg, err := types.UnmarshalTssMessage(rawMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to unmarshal message: %w", err) - return - } - err = s.identityStore.VerifyMessage(msg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to verify message: %w, tampered message", err) - return - } - - toIDs := make([]string, len(msg.To)) - for i, id := range msg.To { - toIDs[i] = id.String() - } - - round, err := s.getRoundFunc(msg.MsgBytes, s.selfPartyID, msg.IsBroadcast) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Broken TSS Share") - return - } - - logger.Debug(fmt.Sprintf("%s Received message", s.sessionType), "from", msg.From.String(), "to", strings.Join(toIDs, ","), "isBroadcast", msg.IsBroadcast, "round", round.RoundMsg) - isBroadcast := msg.IsBroadcast && len(msg.To) == 0 - isToSelf := len(msg.To) == 1 && ComparePartyIDs(msg.To[0], s.selfPartyID) - - if isBroadcast || isToSelf { - s.mu.Lock() - defer s.mu.Unlock() - ok, err := s.party.UpdateFromBytes(msg.MsgBytes, msg.From, msg.IsBroadcast) - if !ok || err != nil { - logger.Error("Failed to update party", err, "walletID", s.walletID) - return - } - - } -} - -func (s *Session) SendReplySignSuccess(natMsg *nats.Msg) { - msg := natMsg.Data - s.mu.Lock() - defer s.mu.Unlock() - - err := s.pubSub.Publish(natMsg.Reply, msg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to reply sign sucess message: %w", err) - return - } - logger.Info("Sent reply sign sucess message", "reply", natMsg.Reply) -} - -func (s *Session) ListenToIncomingMessageAsync() { - go func() { - sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { - msg := natMsg.Data - s.receiveTssMessage(msg) - }) - - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) - return - } - - s.broadcastSub = sub - }() - - nodeID := PartyIDToNodeID(s.selfPartyID) - targetID := s.topicComposer.ComposeDirectTopic(nodeID) - sub, err := s.direct.Listen(targetID, func(msg []byte) { - go s.receiveTssMessage(msg) // async for avoid timeout - }) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", targetID, err) - } - s.directSub = sub - -} - -func (s *Session) Close() error { - err := s.broadcastSub.Unsubscribe() - if err != nil { - return err - } - err = s.directSub.Unsubscribe() - if err != nil { - return err - } - return nil -} - -func (s *Session) GetPubKeyResult() []byte { - return s.pubkeyBytes -} - -func (s *Session) ErrChan() <-chan error { - return s.ErrCh -} diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go new file mode 100644 index 0000000..4fad2ca --- /dev/null +++ b/pkg/mpc/session/base.go @@ -0,0 +1,386 @@ +package session + +import ( + "context" + "fmt" + "math/big" + "slices" + "sync" + "time" + + "github.com/bnb-chain/tss-lib/v2/common" + 
"github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/party" + "github.com/fystack/mpcium/pkg/types" + "github.com/hashicorp/consul/api" + "github.com/nats-io/nats.go" +) + +type Curve string + +type Purpose string + +const ( + CurveSecp256k1 Curve = "secp256k1" + CurveEd25519 Curve = "ed25519" + + PurposeKeygen Purpose = "keygen" + PurposeSign Purpose = "sign" + PurposeReshare Purpose = "reshare" +) + +type TopicComposer struct { + ComposeBroadcastTopic func() string + ComposeDirectTopic func(nodeID string) string +} + +type KeyComposerFn func(id string) string + +type Session interface { + StartKeygen(ctx context.Context, send func(tss.Message), onComplete func([]byte)) + StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), onComplete func([]byte)) + StartResharing( + ctx context.Context, + oldPartyIDs []*tss.PartyID, + newPartyIDs []*tss.PartyID, + oldThreshold int, + newThreshold int, + send func(tss.Message), + onComplete func([]byte), + ) + + GetSaveData(version int) ([]byte, error) + GetPublicKey(data []byte) ([]byte, error) + VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) + + PartyIDs() []*tss.PartyID + Send(msg tss.Message) + Listen(ctx context.Context) + SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) + WaitForReady(ctx context.Context, sessionID string) error + ErrCh() chan error + Close() +} + +type session struct { + walletID string + party party.Party + + broadcastSub messaging.Subscription + directSub messaging.Subscription + pubSub messaging.PubSub + direct messaging.DirectMessaging + + identityStore identity.Store + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + + msgBuffer chan []byte + workerCount int + + topicComposer *TopicComposer + composeKey KeyComposerFn + consulKV infra.ConsulKV + + mu sync.Mutex + errCh chan error +} + +func NewSession( + purpose Purpose, + walletID string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + consulKV infra.ConsulKV, +) *session { + errCh := make(chan error, 1000) + return &session{ + walletID: walletID, + pubSub: pubSub, + direct: direct, + identityStore: identityStore, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + errCh: errCh, + consulKV: consulKV, + msgBuffer: make(chan []byte, 100), // Buffer for 100 messages + } +} + +func (s *session) PartyIDs() []*tss.PartyID { + return s.party.PartyIDs() +} + +func (s *session) ErrCh() chan error { + return s.errCh +} + +func (s *session) WaitForReady(ctx context.Context, sessionID string) error { + // build our Consul prefix + prefix := fmt.Sprintf("tss-ready/%s/%s/", s.walletID, sessionID) + + // 1) publish our ready flag + myKey := prefix + s.party.PartyID().String() + if _, err := s.consulKV.Put(&api.KVPair{ + Key: myKey, + Value: []byte("true"), + }, nil); err != nil { + return fmt.Errorf("failed to write ready flag: %w", err) + } + + // 2) poll until we see everyone + total := len(s.party.PartyIDs()) + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + pairs, _, err := s.consulKV.List(prefix, nil) + if err != nil { 
+ logger.Error("error listing readiness keys", err) + continue + } + if len(pairs) >= total { + logger.Debug("[READY] peers ready", "have", len(pairs), "need", total, "walletID", s.walletID) + return nil + } + logger.Debug("[READY] Waiting for peers ready", "wallet", s.walletID, "have", len(pairs), "need", total) + } + } +} + +// Send is a wrapper around the party's Send method +// It signs the message and sends it to the remote party +func (s *session) Send(msg tss.Message) { + data, routing, err := msg.WireBytes() + if err != nil { + s.errCh <- fmt.Errorf("failed to wire bytes: %w", err) + return + } + + tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) + signature, err := s.identityStore.SignMessage(&tssMsg) + if err != nil { + s.errCh <- fmt.Errorf("failed to sign message: %w", err) + return + } + tssMsg.Signature = signature + msgBytes, err := types.MarshalTssMessage(&tssMsg) + if err != nil { + s.errCh <- fmt.Errorf("failed to marshal message: %w", err) + return + } + round, _, err := s.party.ClassifyMsg(data) + if err != nil { + s.errCh <- fmt.Errorf("failed to classify message: %w", err) + return + } + toNodeIDs := make([]string, len(routing.To)) + for i, to := range routing.To { + toNodeIDs[i] = getRoutingFromPartyID(to) + } + logger.Debug("Sending message", "from", routing.From.Moniker, "to", toNodeIDs, "isBroadcast", routing.IsBroadcast, "round", round) + + if routing.IsBroadcast && len(routing.To) == 0 { + err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) + if err != nil { + s.errCh <- fmt.Errorf("failed to publish message: %w", err) + return + } + } else { + for _, to := range routing.To { + nodeID := getRoutingFromPartyID(to) + topic := s.topicComposer.ComposeDirectTopic(nodeID) + err := s.direct.Send(topic, msgBytes) + if err != nil { + s.errCh <- fmt.Errorf("failed to send message: %w", err) + return + } + } + } +} + +// Listen is a wrapper around the party's Listen method +// It subscribes to the broadcast and self direct topics +func (s *session) Listen(ctx context.Context) { + go s.startIncomingMessageWorker(ctx) + + var wg sync.WaitGroup + wg.Add(2) + + selfDirectTopic := s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) + broadcastTopic := s.topicComposer.ComposeBroadcastTopic() + + broadcast := func() { + defer wg.Done() + sub, err := s.pubSub.Subscribe(broadcastTopic, func(natMsg *nats.Msg) { + msg := natMsg.Data + select { + + case <-ctx.Done(): + return + default: + s.msgBuffer <- msg + } + }) + + if err != nil { + s.errCh <- fmt.Errorf("failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) + return + } + + s.broadcastSub = sub + } + + direct := func() { + defer wg.Done() + sub, err := s.direct.Listen(selfDirectTopic, func(msg []byte) { + select { + case <-ctx.Done(): + return + default: + s.msgBuffer <- msg + } + }) + + if err != nil { + s.errCh <- fmt.Errorf("failed to subscribe to direct topic %s: %w", s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), err) + return + } + + s.directSub = sub + } + + go broadcast() + go direct() + wg.Wait() +} + +func (s *session) startIncomingMessageWorker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-s.msgBuffer: + if !ok { + return + } + s.receive(msg) + } + } +} + +// SaveKey saves the key to the keyinfo store and the kvstore +func (s *session) SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) 
{ + keyInfo := keyinfo.KeyInfo{ + ParticipantPeerIDs: participantPeerIDs, + Threshold: threshold, + Version: version, + } + composeKey := s.composeKey(s.walletID) + err = s.keyinfoStore.Save(composeKey, &keyInfo) + if err != nil { + s.errCh <- fmt.Errorf("failed to save keyinfo: %w", err) + return + } + + err = s.kvstore.Put(fmt.Sprintf("%s-%d", composeKey, version), data) + if err != nil { + s.errCh <- fmt.Errorf("failed to save key: %w", err) + return + } + return +} + +func (s *session) SetSaveData(saveBytes []byte) { + s.party.SetSaveData(saveBytes) +} + +// GetSaveData gets the key from the kvstore +func (s *session) GetSaveData(version int) ([]byte, error) { + var key string + composeKey := s.composeKey(s.walletID) + if version == 0 { + key = composeKey + } else { + key = fmt.Sprintf("%s-%d", composeKey, version) + } + data, err := s.kvstore.Get(key) + if err != nil { + return nil, fmt.Errorf("failed to get key: %w", err) + } + return data, nil +} + +func (s *session) Close() { + // Close subscriptions first + if s.broadcastSub != nil { + s.broadcastSub.Unsubscribe() + } + if s.directSub != nil { + s.directSub.Unsubscribe() + } + + if s.msgBuffer != nil { + close(s.msgBuffer) + } + + // Close error channel last + select { + case <-s.errCh: + // Channel already closed + default: + close(s.errCh) + } +} + +// receive is a helper function that receives a message from the party +func (s *session) receive(rawMsg []byte) { + msg, err := types.UnmarshalTssMessage(rawMsg) + if err != nil { + s.errCh <- fmt.Errorf("failed to unmarshal message: %w", err) + return + } + + err = s.identityStore.VerifyMessage(msg) + if err != nil { + s.errCh <- fmt.Errorf("failed to verify message: %w", err) + return + } + + // Skip messages from self + if msg.From.String() == s.party.PartyID().String() { + return + } + + toIDs := make([]string, len(msg.To)) + for i, id := range msg.To { + toIDs[i] = id.String() + } + + isBroadcast := msg.IsBroadcast && len(msg.To) == 0 + isToSelf := slices.Contains(toIDs, s.party.PartyID().String()) + + if isBroadcast || isToSelf { + round, _, err := s.party.ClassifyMsg(msg.MsgBytes) + if err != nil { + s.errCh <- fmt.Errorf("failed to classify message: %w", err) + return + } + logger.Debug("Received message", "from", msg.From.Moniker, "round", round, "isBroadcast", msg.IsBroadcast, "isToSelf", isToSelf) + s.mu.Lock() + defer s.mu.Unlock() + s.party.InCh() <- *msg + } +} diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go new file mode 100644 index 0000000..a03882e --- /dev/null +++ b/pkg/mpc/session/ecdsa.go @@ -0,0 +1,129 @@ +package session + +import ( + "context" + "crypto/ecdsa" + "encoding/json" + "errors" + "fmt" + "math/big" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/party" +) + +type ECDSASession struct { + *session +} + +func NewECDSASession( + walletID string, + partyID *tss.PartyID, + partyIDs []*tss.PartyID, + threshold int, + preParams keygen.LocalPreParams, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + consulKV infra.ConsulKV, +) *ECDSASession { + s := 
NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore, consulKV) + s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, s.errCh) + s.topicComposer = &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("broadcast:ecdsa:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("direct:ecdsa:%s:%s", nodeID, walletID) + }, + } + s.composeKey = func(walletID string) string { + return fmt.Sprintf("ecdsa:%s", walletID) + } + return &ECDSASession{ + session: s, + } +} + +func (s *ECDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + s.party.StartKeygen(ctx, send, finish) +} + +func (s *ECDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + s.party.StartSigning(ctx, msg, send, finish) +} + +func (s *ECDSASession) StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) { + s.party.StartResharing(ctx, oldPartyIDs, newPartyIDs, oldThreshold, newThreshold, send, finish) +} + +func (s *ECDSASession) GetPublicKey(data []byte) ([]byte, error) { + saveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(data, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + publicKey := saveData.ECDSAPub + pubKey := &ecdsa.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return nil, fmt.Errorf("failed to encode public key: %w", err) + } + return pubKeyBytes, nil +} + +func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) { + signatureData := &common.SignatureData{} + err := json.Unmarshal(signature, signatureData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal signature data: %w", err) + } + + data := s.party.GetSaveData() + if data == nil { + return nil, errors.New("save data is nil") + } + + saveData := &keygen.LocalPartySaveData{} + err = json.Unmarshal(data, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + + if saveData.ECDSAPub == nil { + return nil, errors.New("ECDSA public key is nil") + } + + publicKey := saveData.ECDSAPub + pk := &ecdsa.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + // Convert signature components to big integers + r := new(big.Int).SetBytes(signatureData.R) + sigS := new(big.Int).SetBytes(signatureData.S) + + // Verify the signature + ok := ecdsa.Verify(pk, msg, r, sigS) + if !ok { + return nil, errors.New("signature verification failed") + } + + return signatureData, nil +} diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go new file mode 100644 index 0000000..65a162d --- /dev/null +++ b/pkg/mpc/session/eddsa.go @@ -0,0 +1,130 @@ +package session + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/decred/dcrd/dcrec/edwards/v2" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + 
"github.com/fystack/mpcium/pkg/mpc/party" +) + +type EDDSASession struct { + *session +} + +func NewEDDSASession( + walletID string, + partyID *tss.PartyID, + partyIDs []*tss.PartyID, + threshold int, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + consulKV infra.ConsulKV, +) *EDDSASession { + s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore, consulKV) + s.party = party.NewEDDSAParty(walletID, partyID, partyIDs, threshold, nil, nil, s.errCh) + s.topicComposer = &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("broadcast:eddsa:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("direct:eddsa:%s:%s", nodeID, walletID) + }, + } + s.composeKey = func(walletID string) string { + return fmt.Sprintf("eddsa:%s", walletID) + } + return &EDDSASession{ + session: s, + } +} + +func (s *EDDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + s.party.StartKeygen(ctx, send, finish) +} + +func (s *EDDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + s.party.StartSigning(ctx, msg, send, finish) +} + +func (s *EDDSASession) StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) { + s.party.StartResharing(ctx, oldPartyIDs, newPartyIDs, oldThreshold, newThreshold, send, finish) +} + +func (s *EDDSASession) GetPublicKey(data []byte) ([]byte, error) { + saveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(data, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + + if saveData.EDDSAPub == nil { + return nil, errors.New("EDDSA public key is nil") + } + + publicKey := saveData.EDDSAPub + pubKey := &edwards.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + pubKeyBytes := pubKey.SerializeCompressed() + return pubKeyBytes, nil +} + +func (s *EDDSASession) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) { + signatureData := &common.SignatureData{} + err := json.Unmarshal(signature, signatureData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal signature data: %w", err) + } + + data := s.party.GetSaveData() + if data == nil { + return nil, errors.New("save data is nil") + } + + saveData := &keygen.LocalPartySaveData{} + err = json.Unmarshal(data, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + + if saveData.EDDSAPub == nil { + return nil, errors.New("EDDSA public key is nil") + } + + publicKey := saveData.EDDSAPub + pk := &edwards.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + // Convert signature components to big integers + r := new(big.Int).SetBytes(signatureData.R) + sigS := new(big.Int).SetBytes(signatureData.S) + + // Verify the signature + ok := edwards.Verify(pk, msg, r, sigS) + if !ok { + return nil, errors.New("signature verification failed") + } + + return signatureData, nil +} diff --git a/pkg/mpc/session/utils.go b/pkg/mpc/session/utils.go new file mode 100644 index 0000000..581abc8 --- /dev/null +++ b/pkg/mpc/session/utils.go @@ -0,0 +1,8 @@ +package session + +import "github.com/bnb-chain/tss-lib/v2/tss" + +// Moniker saves the routing partyID 
to nodeID mapping +func getRoutingFromPartyID(partyID *tss.PartyID) string { + return partyID.Moniker +} diff --git a/pkg/tsslimiter/queue.go b/pkg/tsslimiter/queue.go new file mode 100644 index 0000000..7e12fa2 --- /dev/null +++ b/pkg/tsslimiter/queue.go @@ -0,0 +1,108 @@ +package tsslimiter + +import ( + "sync" + + "github.com/fystack/mpcium/pkg/logger" +) + +// SessionJob represents a queued job with type, execution logic, and optional error callback +// Run should return an error if execution fails. +type SessionJob struct { + Type SessionType + Run func() error + OnError func(error) + Name string +} + +// Queue defines the interface for a job queue that manages TSS session jobs. +type Queue interface { + // Enqueue adds a new session job to the queue for processing. + Enqueue(job SessionJob) + + // Stop gracefully shuts down the queue and waits for background workers to finish. + Stop() +} + +// WeightedQueue buffers and processes session jobs using the WeightedLimiter +type WeightedQueue struct { + queue chan SessionJob + limiter *WeightedLimiter + stopChan chan struct{} + wg sync.WaitGroup +} + +// NewWeightedQueue initializes a buffered job queue +func NewWeightedQueue(limiter *WeightedLimiter, bufferSize int) *WeightedQueue { + q := &WeightedQueue{ + queue: make(chan SessionJob, bufferSize), + limiter: limiter, + stopChan: make(chan struct{}), + } + + // Start the background worker to process queue + q.wg.Add(1) + go q.run() + return q +} + +// Enqueue adds a job to the queue +func (q *WeightedQueue) Enqueue(job SessionJob) { + q.queue <- job +} + +// run continuously processes jobs based on limiter capacity, logging counters +func (q *WeightedQueue) run() { + defer q.wg.Done() + + for { + select { + case job := <-q.queue: + // Log queue length and limiter state before acquire + usedBefore, max := q.limiter.Stats() + logger.Info("Before Acquire", "usedPoints", usedBefore, "maxPoints", max, "pendingJobs", len(q.queue)) + + // Block until we can acquire budget + q.limiter.Acquire(job.Type) + // if !ok { + // logger.Info("Failed to Acquire", "jobType", job.Type, "name", job.Name) + // // Notify via OnError callback if provided + // if job.OnError != nil { + // job.OnError(fmt.Errorf("tsslimiter: failed to acquire budget for job type %v, job %s", job.Type, job.Name)) + // } + // continue + // } + + // Log limiter state after acquire + usedAfter, _ := q.limiter.Stats() + logger.Info("After Acquire", "usedPoints", usedAfter, "jobType", job.Type, "name", job.Name) + + // Launch job + q.wg.Add(1) + go func(j SessionJob) { + defer q.wg.Done() + defer q.limiter.Release(j.Type) + + usedExec, _ := q.limiter.Stats() + logger.Info("Executing Job", "usedPoints", usedExec, "jobType", j.Type, "name", job.Name) + + err := j.Run() + if err != nil && j.OnError != nil { + // Call the error handler for this job + j.OnError(err) + } + + logger.Info("Pending Jobs", "num", len(q.queue)) + }(job) + + case <-q.stopChan: + return + } + } +} + +// Stop shuts down the queue processing loop and waits for running jobs +func (q *WeightedQueue) Stop() { + close(q.stopChan) + q.wg.Wait() +} diff --git a/pkg/tsslimiter/queue_test.go b/pkg/tsslimiter/queue_test.go new file mode 100644 index 0000000..77ed427 --- /dev/null +++ b/pkg/tsslimiter/queue_test.go @@ -0,0 +1,119 @@ +package tsslimiter_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/fystack/mpcium/pkg/tsslimiter" + "github.com/stretchr/testify/assert" +) + +func TestWeightedQueue_SingleJobExecution(t *testing.T) { + limiter := 
tsslimiter.NewWeightedLimiter(2) // 2 cores = 200 points
+	queue := tsslimiter.NewWeightedQueue(limiter, 10)
+	defer queue.Stop()
+
+	var executed int32 = 0
+
+	job := tsslimiter.SessionJob{
+		Type: tsslimiter.SessionSignECDSA,
+		Run: func() error {
+			atomic.AddInt32(&executed, 1)
+			return nil
+		},
+	}
+
+	queue.Enqueue(job)
+	time.Sleep(200 * time.Millisecond) // Give time to process
+
+	assert.Equal(t, int32(1), executed, "Expected job to execute")
+}
+
+func TestWeightedQueue_RespectsConcurrency(t *testing.T) {
+	limiter := tsslimiter.NewWeightedLimiter(1) // 1 core = 100 points
+	queue := tsslimiter.NewWeightedQueue(limiter, 10)
+	defer queue.Stop()
+
+	var executing int32 = 0
+	var completed int32 = 0
+
+	// 3 jobs each costing 75 (ECDSA keygen) → only 1 fits in the 100-point budget at a time
+	for i := 0; i < 3; i++ {
+		queue.Enqueue(tsslimiter.SessionJob{
+			Type: tsslimiter.SessionKeygenECDSA,
+			Run: func() error {
+				current := atomic.AddInt32(&executing, 1)
+				assert.LessOrEqual(t, current, int32(1), "Too many concurrent jobs running")
+				time.Sleep(100 * time.Millisecond)
+				atomic.AddInt32(&executing, -1)
+				atomic.AddInt32(&completed, 1)
+				return nil
+			},
+		})
+	}
+
+	time.Sleep(500 * time.Millisecond)
+	assert.Equal(t, int32(3), completed, "All jobs should complete sequentially")
+}
+
+func TestWeightedQueue_MixedSessions(t *testing.T) {
+	limiter := tsslimiter.NewWeightedLimiter(2) // 2 cores = 200 points
+	queue := tsslimiter.NewWeightedQueue(limiter, 10)
+	defer queue.Stop()
+
+	var completed int32 = 0
+
+	// Sign (40) + Keygen (75) + Sign (40) = 155 total, fits under 200
+	queue.Enqueue(tsslimiter.SessionJob{
+		Type: tsslimiter.SessionSignECDSA,
+		Run: func() error {
+			time.Sleep(50 * time.Millisecond)
+			atomic.AddInt32(&completed, 1)
+			return nil
+		},
+	})
+	queue.Enqueue(tsslimiter.SessionJob{
+		Type: tsslimiter.SessionKeygenECDSA,
+		Run: func() error {
+			time.Sleep(50 * time.Millisecond)
+			atomic.AddInt32(&completed, 1)
+			return nil
+		},
+	})
+	queue.Enqueue(tsslimiter.SessionJob{
+		Type: tsslimiter.SessionSignECDSA,
+		Run: func() error {
+			time.Sleep(50 * time.Millisecond)
+			atomic.AddInt32(&completed, 1)
+			return nil
+		},
+	})
+
+	time.Sleep(300 * time.Millisecond)
+	assert.Equal(t, int32(3), completed, "All mixed jobs should run within capacity")
+}
+
+func TestWeightedQueue_BackpressureBuffering(t *testing.T) {
+	limiter := tsslimiter.NewWeightedLimiter(1) // 1 core = 100 points
+	queue := tsslimiter.NewWeightedQueue(limiter, 10)
+	defer queue.Stop()
+
+	var completed int32 = 0
+
+	// First job takes most of the budget (75 of 100 points)
+	queue.Enqueue(tsslimiter.SessionJob{
+		Type: tsslimiter.SessionKeygenECDSA,
+		Run: func() error {
+			time.Sleep(150 * time.Millisecond)
+			atomic.AddInt32(&completed, 1)
+			return nil
+		},
+	})
+
+	// Second job should wait in the queue
+	queue.Enqueue(tsslimiter.SessionJob{
+		Type: tsslimiter.SessionSignECDSA,
+		Run: func() error {
+			atomic.AddInt32(&completed, 1)
+			return nil
+		},
+	})
+
+	time.Sleep(400 * time.Millisecond)
+	assert.Equal(t, int32(2), completed, "Both jobs should run in sequence due to backpressure")
+}
diff --git a/pkg/tsslimiter/tsslimiter.go b/pkg/tsslimiter/tsslimiter.go
new file mode 100644
index 0000000..2f34b25
--- /dev/null
+++ b/pkg/tsslimiter/tsslimiter.go
@@ -0,0 +1,122 @@
+package tsslimiter
+
+import (
+	"sync"
+
+	"github.com/fystack/mpcium/pkg/logger"
+)
+
+type SessionType int
+
+const (
+	SessionKeygenECDSA SessionType = iota
+	SessionReshareECDSA
+	SessionSignECDSA
+	SessionKeygenEDDSA
+	SessionReshareEDDSA
+	SessionSignEDDSA
+	SessionKeygenCombined
+)
+
+// sessionCosts defines the estimated CPU cost (in points) of each session type.
+// The values are based on practical benchmarks using tss-lib (ECDSA over secp256k1),
+// where 100 points = 100% of a physical CPU core.
+//
+// These costs allow us to model CPU pressure and prevent overload by setting
+// a total max budget equal to the number of physical cores × 100 points.
+//
+// For example, on a 4-core CPU:
+//
+//   - maxPoints = 400
+//
+//   - You could run 1 ECDSA keygen (75) + 8 ECDSA sign sessions (40 × 8 = 320)
+//
+//   - Or 4 ECDSA resharing sessions (70 × 4 = 280) + 3 ECDSA sign sessions (40 × 3 = 120)
+//
+// Note: These values are conservative to maintain low latency and avoid timeouts.
+var sessionCosts = map[SessionType]int{
+	SessionKeygenECDSA:    75, // ~75% of a core
+	SessionReshareECDSA:   70,
+	SessionSignECDSA:      40,
+	SessionKeygenEDDSA:    25, // ~25% of a core
+	SessionReshareEDDSA:   20,
+	SessionSignEDDSA:      15,
+	SessionKeygenCombined: 100, // ECDSA (75) + EDDSA (25)
+}
+
+type Limiter interface {
+	// TryAcquire attempts to acquire resources for the given session type.
+	// Returns true if successful, false otherwise.
+	TryAcquire(t SessionType) bool
+
+	// Acquire blocks until it successfully acquires resources for the session type.
+	Acquire(t SessionType)
+
+	// Release frees the resources for the given session type.
+	Release(t SessionType)
+	Stats() (int, int)
+}
+
+type WeightedLimiter struct {
+	mu         sync.Mutex
+	usedPoints int
+	maxPoints  int
+	cond       *sync.Cond
+}
+
+// NewWeightedLimiter creates a limiter with maxPoints = maxSessions * 100
+func NewWeightedLimiter(maxSessions int) *WeightedLimiter {
+	l := &WeightedLimiter{
+		maxPoints: maxSessions * 100,
+	}
+	l.cond = sync.NewCond(&l.mu)
+	return l
+}
+
+func (l *WeightedLimiter) TryAcquire(t SessionType) bool {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	logger.Info("TryAcquire", "sessionType", t, "usedPoints", l.usedPoints, "maxPoints", l.maxPoints)
+	cost := sessionCosts[t]
+	if l.usedPoints+cost > l.maxPoints {
+		return false
+	}
+
+	logger.Info("TryAcquire succeeded", "sessionType", t, "cost", cost)
+	l.usedPoints += cost
+	return true
+}
+
+func (l *WeightedLimiter) Acquire(t SessionType) {
+	cost := sessionCosts[t]
+
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	for l.usedPoints+cost > l.maxPoints {
+		l.cond.Wait()
+	}
+
+	l.usedPoints += cost
+}
+
+func (l *WeightedLimiter) Release(t SessionType) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	cost := sessionCosts[t]
+	l.usedPoints -= cost
+	if l.usedPoints < 0 {
+		l.usedPoints = 0
+	}
+
+	logger.Info("Release", "sessionType", t, "usedPoints", l.usedPoints, "maxPoints", l.maxPoints)
+	l.cond.Broadcast() // Wake up waiting goroutines
+}
+
+func (l *WeightedLimiter) Stats() (int, int) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	return l.usedPoints, l.maxPoints
+}
diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go
index edd0bf4..b014b1a 100644
--- a/pkg/types/initiator_msg.go
+++ b/pkg/types/initiator_msg.go
@@ -6,7 +6,7 @@ type KeyType string
 
 const (
 	KeyTypeSecp256k1 KeyType = "secp256k1"
-	KeyTypeEd25519           = "ed25519"
+	KeyTypeEd25519   KeyType = "ed25519"
 )
 
 // InitiatorMessage is anything that carries a payload to verify and its signature.
@@ -33,6 +33,36 @@ type SignTxMessage struct {
 	Signature []byte `json:"signature"`
 }
 
+type ResharingMessage struct {
+	WalletID     string  `json:"wallet_id"`
+	NewThreshold int     `json:"new_threshold"`
+	Signature    []byte  `json:"signature"`
+	KeyType      KeyType `json:"key_type"`
+}
+
+// InitiatorID implements InitiatorMessage.
+func (r *ResharingMessage) InitiatorID() string {
+	return r.WalletID
+}
+
+// Raw implements InitiatorMessage.
+func (r *ResharingMessage) Raw() ([]byte, error) { + // Create a struct with only the fields that should be signed + payload := struct { + WalletID string `json:"wallet_id"` + NewThreshold int `json:"new_threshold"` + }{ + WalletID: r.WalletID, + NewThreshold: r.NewThreshold, + } + return json.Marshal(payload) +} + +// Sig implements InitiatorMessage. +func (r *ResharingMessage) Sig() []byte { + return r.Signature +} + func (m *SignTxMessage) Raw() ([]byte, error) { // omit the Signature field itself when computing the signed‐over data payload := struct { diff --git a/pkg/types/tss.go b/pkg/types/tss.go index 6d61e9c..94559b5 100644 --- a/pkg/types/tss.go +++ b/pkg/types/tss.go @@ -41,6 +41,28 @@ func NewTssMessage( return tssMsg } +func NewTssResharingMessage( + walletID string, + msgBytes []byte, + isBroadcast bool, + from *tss.PartyID, + to []*tss.PartyID, + isToOldCommittee bool, + isToOldAndNewCommittees bool, +) TssMessage { + tssMsg := TssMessage{ + WalletID: walletID, + IsBroadcast: isBroadcast, + MsgBytes: msgBytes, + From: from, + To: to, + IsToOldCommittee: isToOldCommittee, + IsToOldAndNewCommittees: isToOldAndNewCommittees, + } + + return tssMsg +} + func MarshalTssMessage(tssMsg *TssMessage) ([]byte, error) { msgBytes, err := json.Marshal(tssMsg) if err != nil { diff --git a/scripts/migration/add-key-type/main.go b/scripts/migration/add-key-type/main.go index 9891243..a2da004 100644 --- a/scripts/migration/add-key-type/main.go +++ b/scripts/migration/add-key-type/main.go @@ -11,7 +11,7 @@ import ( ) func main() { - logger.Init("production") + logger.Init("production", true) nodeName := flag.String("name", "", "Provide node name") flag.Parse() if *nodeName == "" { diff --git a/scripts/migration/update-keyinfo/main.go b/scripts/migration/update-keyinfo/main.go index 704ff5a..e68134b 100644 --- a/scripts/migration/update-keyinfo/main.go +++ b/scripts/migration/update-keyinfo/main.go @@ -13,7 +13,7 @@ import ( // script to add key type prefix ecdsa for existing keys func main() { config.InitViperConfig() - logger.Init("production") + logger.Init("production", true) appConfig := config.LoadConfig() logger.Info("App config", "config", appConfig) diff --git a/setup_identities.sh b/setup_identities.sh new file mode 100755 index 0000000..53ed0a3 --- /dev/null +++ b/setup_identities.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Number of nodes to create (default is 3) +NUM_NODES=3 + +echo "🚀 Setting up Node Identities..." + +# Create node directories and copy config files +echo "📁 Creating node directories..." +for i in $(seq 0 $((NUM_NODES-1))); do + mkdir -p "node$i/identity" + if [ ! -f "node$i/config.yaml" ]; then + cp config.yaml "node$i/" + fi + if [ ! -f "node$i/peers.json" ]; then + cp peers.json "node$i/" + fi +done + +# Generate identity for each node +echo "🔑 Generating identities for each node..." +for i in $(seq 0 $((NUM_NODES-1))); do + echo "📝 Generating identity for node$i..." + cd "node$i" + mpcium-cli generate-identity --node "node$i" + cd .. +done + +# Distribute identity files to all nodes +echo "🔄 Distributing identity files across nodes..." +for i in $(seq 0 $((NUM_NODES-1))); do + for j in $(seq 0 $((NUM_NODES-1))); do + if [ $i != $j ]; then + echo "📋 Copying node${i}_identity.json to node$j..." + cp "node$i/identity/node${i}_identity.json" "node$j/identity/" + fi + done +done + +echo "✨ Node identities setup complete!" 
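+# Print the expected per-node layout so the operator can verify the generated files.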
+echo +echo "📂 Created folder structure:" +echo "├── node0" +echo "│ ├── config.yaml" +echo "│ ├── identity/" +echo "│ └── peers.json" +echo "├── node1" +echo "│ ├── config.yaml" +echo "│ ├── identity/" +echo "│ └── peers.json" +echo "└── node2" +echo " ├── config.yaml" +echo " ├── identity/" +echo " └── peers.json" +echo +echo "✅ You can now start your nodes with:" +echo "cd node0 && mpcium start -n node0" +echo "cd node1 && mpcium start -n node1" +echo "cd node2 && mpcium start -n node2" \ No newline at end of file diff --git a/setup_initiator.sh b/setup_initiator.sh new file mode 100755 index 0000000..de37e07 --- /dev/null +++ b/setup_initiator.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +echo "🚀 Setting up Event Initiator..." + +# Generate the event initiator +echo "📝 Generating event initiator..." +mpcium-cli generate-initiator + +# Extract the public key from the generated file +if [ -f "event_initiator.identity.json" ]; then + PUBLIC_KEY=$(grep -o '"public_key": *"[^"]*"' event_initiator.identity.json | cut -d'"' -f4) + + if [ -n "$PUBLIC_KEY" ]; then + echo "🔑 Found public key: $PUBLIC_KEY" + + # Update config.yaml + if [ -f "config.yaml" ]; then + echo "📝 Updating config.yaml..." + # Check if event_initiator_pubkey already exists + if grep -q "event_initiator_pubkey:" config.yaml; then + # Replace existing line + sed -i "s/event_initiator_pubkey: .*/event_initiator_pubkey: \"$PUBLIC_KEY\"/" config.yaml + else + # Add new line + echo "event_initiator_pubkey: \"$PUBLIC_KEY\"" >> config.yaml + fi + echo "✅ Successfully updated config.yaml" + else + echo "❌ Error: config.yaml not found. Please create it first." + exit 1 + fi + else + echo "❌ Error: Could not extract public key from event_initiator.identity.json" + exit 1 + fi +else + echo "❌ Error: event_initiator.identity.json not found" + exit 1 +fi + +echo "✨ Event Initiator setup complete!" \ No newline at end of file
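
For reference, a minimal sketch of how the tsslimiter limiter and queue introduced in this patch can be wired together. It relies only on the APIs added above (NewWeightedLimiter, NewWeightedQueue, SessionJob, Enqueue, Stop); the runtime.NumCPU() sizing, the job name, the job body, and the sleep before shutdown are illustrative assumptions rather than part of the change.

```go
package main

import (
	"fmt"
	"runtime"
	"time"

	"github.com/fystack/mpcium/pkg/tsslimiter"
)

func main() {
	// One hundred points per physical core; sizing by runtime.NumCPU() is an
	// assumption made for this sketch, not something the patch prescribes.
	limiter := tsslimiter.NewWeightedLimiter(runtime.NumCPU())
	queue := tsslimiter.NewWeightedQueue(limiter, 64)
	defer queue.Stop()

	queue.Enqueue(tsslimiter.SessionJob{
		Type: tsslimiter.SessionKeygenECDSA,
		Name: "keygen-demo", // hypothetical job name, used only for logging
		Run: func() error {
			// Placeholder for the real work, e.g. kicking off an ECDSA keygen session.
			fmt.Println("running ECDSA keygen under the weighted limiter")
			return nil
		},
		OnError: func(err error) {
			fmt.Println("keygen job failed:", err)
		},
	})

	// Give the background worker time to drain the queue before Stop runs;
	// a real caller would wait on a result channel instead.
	time.Sleep(200 * time.Millisecond)
}
```

Because Acquire blocks until budget frees up rather than rejecting work, bursts of requests are queued and smoothed out; callers that prefer to fail fast can use TryAcquire directly.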