diff --git a/cmd/generate-bindings/evm.go b/cmd/generate-bindings/evm.go new file mode 100644 index 00000000..6764c70d --- /dev/null +++ b/cmd/generate-bindings/evm.go @@ -0,0 +1,270 @@ +package generatebindings + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/spf13/viper" + + "github.com/smartcontractkit/cre-cli/cmd/creinit" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" + "github.com/smartcontractkit/cre-cli/internal/validation" +) + +func resolveEvmInputs(args []string, v *viper.Viper) (EvmInputs, error) { + // Get current working directory as default project root + currentDir, err := os.Getwd() + if err != nil { + return EvmInputs{}, fmt.Errorf("failed to get current working directory: %w", err) + } + + // Resolve project root with fallback to current directory + projectRoot := v.GetString("project-root") + if projectRoot == "" { + projectRoot = currentDir + } + + contractsPath := filepath.Join(projectRoot, "contracts") + if _, err := os.Stat(contractsPath); err != nil { + return EvmInputs{}, fmt.Errorf("contracts folder not found in project root: %s", contractsPath) + } + + // Chain family is now a positional argument + chainFamily := args[0] + + // Language defaults are handled by StringP + language := v.GetString("language") + + // Resolve ABI path with fallback to contracts/{chainFamily}/src/abi/ + abiPath := v.GetString("abi") + if abiPath == "" { + abiPath = filepath.Join(projectRoot, "contracts", chainFamily, "src", "abi") + } + + // Package name defaults are handled by StringP + pkgName := v.GetString("pkg") + + // Output path is contracts/{chainFamily}/src/generated/ under projectRoot + outPath := filepath.Join(projectRoot, "contracts", chainFamily, "src", "generated") + + return EvmInputs{ + ProjectRoot: projectRoot, + ChainFamily: chainFamily, + Language: language, + AbiPath: abiPath, + PkgName: pkgName, + OutPath: outPath, + }, nil +} + +func validateEvmInputs(inputs EvmInputs) error { + validate, err := 
validation.NewValidator() + if err != nil { + return fmt.Errorf("failed to initialize validator: %w", err) + } + + if err = validate.Struct(inputs); err != nil { + return validate.ParseValidationErrors(err) + } + + // Additional validation for ABI path + if _, err := os.Stat(inputs.AbiPath); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("ABI path does not exist: %s", inputs.AbiPath) + } + return fmt.Errorf("failed to access ABI path: %w", err) + } + + // Validate that if AbiPath is a directory, it contains .abi files + if info, err := os.Stat(inputs.AbiPath); err == nil && info.IsDir() { + files, err := filepath.Glob(filepath.Join(inputs.AbiPath, "*.abi")) + if err != nil { + return fmt.Errorf("failed to check for ABI files in directory: %w", err) + } + if len(files) == 0 { + return fmt.Errorf("no .abi files found in directory: %s", inputs.AbiPath) + } + } + + return nil +} + +// contractNameToPackage converts contract names to valid Go package names +// Examples: IERC20 -> ierc20, ReserveManager -> reserve_manager, IReserveManager -> ireserve_manager +func contractNameToPackage(contractName string) string { + if contractName == "" { + return "" + } + + var result []rune + runes := []rune(contractName) + + for i, r := range runes { + // Convert to lowercase + if r >= 'A' && r <= 'Z' { + lower := r - 'A' + 'a' + + // Add underscore before uppercase letters, but not: + // - At the beginning (i == 0) + // - If the previous character was also uppercase and this is followed by lowercase (e.g., "ERC" in "ERC20") + // - If this is part of a sequence of uppercase letters at the beginning (e.g., "IERC20" -> "ierc20") + if i > 0 { + prevIsUpper := runes[i-1] >= 'A' && runes[i-1] <= 'Z' + nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z' + + // Add underscore if: + // - Previous char was lowercase (CamelCase boundary) + // - Previous char was uppercase but this char is followed by lowercase (end of acronym) + if !prevIsUpper || 
(prevIsUpper && nextIsLower && i > 1) { + result = append(result, '_') + } + } + + result = append(result, lower) + } else { + result = append(result, r) + } + } + + return string(result) +} + +func processEvmAbiDirectory(inputs EvmInputs) error { + // Read all .abi files in the directory + files, err := filepath.Glob(filepath.Join(inputs.AbiPath, "*.abi")) + if err != nil { + return fmt.Errorf("failed to find ABI files: %w", err) + } + + if len(files) == 0 { + return fmt.Errorf("no .abi files found in directory: %s", inputs.AbiPath) + } + + packageNames := make(map[string]bool) + for _, abiFile := range files { + contractName := filepath.Base(abiFile) + contractName = contractName[:len(contractName)-4] + packageName := contractNameToPackage(contractName) + if _, exists := packageNames[packageName]; exists { + return fmt.Errorf("package name collision: multiple contracts would generate the same package name '%s' (contracts are converted to snake_case for package names). Please rename one of your contract files to avoid this conflict", packageName) + } + packageNames[packageName] = true + } + + // Process each ABI file + for _, abiFile := range files { + // Extract contract name from filename (remove .abi extension) + contractName := filepath.Base(abiFile) + contractName = contractName[:len(contractName)-4] // Remove .abi extension + + // Convert contract name to package name + packageName := contractNameToPackage(contractName) + + // Create per-contract output directory + contractOutDir := filepath.Join(inputs.OutPath, packageName) + if err := os.MkdirAll(contractOutDir, 0o755); err != nil { + return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) + } + + // Create output file path in contract-specific directory + outputFile := filepath.Join(contractOutDir, contractName+".go") + + fmt.Printf("Processing ABI file: %s, contract: %s, package: %s, output: %s\n", abiFile, contractName, packageName, outputFile) + + err = 
evm.GenerateBindings( + "", // combinedJSONPath - empty for now + abiFile, + packageName, // Use contract-specific package name + contractName, // Use contract name as type name + outputFile, + ) + if err != nil { + return fmt.Errorf("failed to generate bindings for %s: %w", contractName, err) + } + } + + return nil +} + +func processEvmSingleAbi(inputs EvmInputs) error { + // Extract contract name from ABI file path + contractName := filepath.Base(inputs.AbiPath) + if filepath.Ext(contractName) == ".abi" { + contractName = contractName[:len(contractName)-4] // Remove .abi extension + } + + // Convert contract name to package name + packageName := contractNameToPackage(contractName) + + // Create per-contract output directory + contractOutDir := filepath.Join(inputs.OutPath, packageName) + if err := os.MkdirAll(contractOutDir, 0o755); err != nil { + return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) + } + + // Create output file path in contract-specific directory + outputFile := filepath.Join(contractOutDir, contractName+".go") + + fmt.Printf("Processing single ABI file: %s, contract: %s, package: %s, output: %s\n", inputs.AbiPath, contractName, packageName, outputFile) + + return evm.GenerateBindings( + "", // combinedJSONPath - empty for now + inputs.AbiPath, + packageName, // Use contract-specific package name + contractName, // Use contract name as type name + outputFile, + ) +} + +func executeEvm(inputs EvmInputs) error { + fmt.Printf("GenerateBindings would be called here: projectRoot=%s, chainFamily=%s, language=%s, abiPath=%s, pkgName=%s, outPath=%s\n", inputs.ProjectRoot, inputs.ChainFamily, inputs.Language, inputs.AbiPath, inputs.PkgName, inputs.OutPath) + + // Validate language + switch inputs.Language { + case "go": + // Language supported, continue + default: + return fmt.Errorf("unsupported language: %s", inputs.Language) + } + + // Validate chain family and handle accordingly + switch inputs.ChainFamily { + 
case "evm": + // Create output directory if it doesn't exist + if err := os.MkdirAll(inputs.OutPath, 0o755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + // Check if ABI path is a directory or file + info, err := os.Stat(inputs.AbiPath) + if err != nil { + return fmt.Errorf("failed to access ABI path: %w", err) + } + + if info.IsDir() { + if err := processEvmAbiDirectory(inputs); err != nil { + return err + } + } else { + if err := processEvmSingleAbi(inputs); err != nil { + return err + } + } + + err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go@"+creinit.SdkVersion) + if err != nil { + return err + } + err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm@"+creinit.EVMCapabilitiesVersion) + if err != nil { + return err + } + if err = runCommand(inputs.ProjectRoot, "go", "mod", "tidy"); err != nil { + return err + } + return nil + default: + return fmt.Errorf("unsupported chain family: %s", inputs.ChainFamily) + } +} diff --git a/cmd/generate-bindings/evm/README.md b/cmd/generate-bindings/evm/README.md new file mode 100644 index 00000000..da346939 --- /dev/null +++ b/cmd/generate-bindings/evm/README.md @@ -0,0 +1,49 @@ +## License + +This repository contains two separate license regimes: + +1. **LGPL-3.0-or-later** for all code in `./abigen` (the forked go-ethereum abigen). + See the full text in `LICENSE` under “GNU LESSER…” +2. **MIT** for everything else in this repo. + See the full text in `LICENSE` under “MIT License”. + + +# CRE Generated Bindings (MVP) + +This project utilizes a forked version of `abigen` (from go-ethereum) +that lets you generate Go bindings for your smart contracts using a custom template. + +## Prerequisites + +1. **Go** + Install Go 1.18 or later: + ```bash + brew install go # macOS (Homebrew) + sudo apt install golang # Ubuntu/Debian + ``` +2. 
**Solidity compiler** + Install `solc` to compile or verify your contracts: + ```bash + npm install -g solc # via npm + brew install solidity # macOS (Homebrew) + ``` + +## Usage +### Programmatic API + +```go +import "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" + +func main() { + err := evm.GenerateBindings( + "./pkg/bindings/build/MyContract_combined.json", // or "" if using abiPath + "./pkg/bindings/MyContract.abi", // or "" for combined-json mode + "bindings", // Go package name + "MyContract", // typeName (single-ABI only) + "./pkg/bindings/build/bindings.go", // output file + ) + if err != nil { + log.Fatalf("generate bindings: %v", err) + } +} +``` \ No newline at end of file diff --git a/cmd/generate-bindings/bindings/abigen/FORK_METADATA.md b/cmd/generate-bindings/evm/abigen/FORK_METADATA.md similarity index 100% rename from cmd/generate-bindings/bindings/abigen/FORK_METADATA.md rename to cmd/generate-bindings/evm/abigen/FORK_METADATA.md diff --git a/cmd/generate-bindings/bindings/abigen/bind.go b/cmd/generate-bindings/evm/abigen/bind.go similarity index 100% rename from cmd/generate-bindings/bindings/abigen/bind.go rename to cmd/generate-bindings/evm/abigen/bind.go diff --git a/cmd/generate-bindings/bindings/abigen/bindv2.go b/cmd/generate-bindings/evm/abigen/bindv2.go similarity index 100% rename from cmd/generate-bindings/bindings/abigen/bindv2.go rename to cmd/generate-bindings/evm/abigen/bindv2.go diff --git a/cmd/generate-bindings/bindings/abigen/source.go.tpl b/cmd/generate-bindings/evm/abigen/source.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/abigen/source.go.tpl rename to cmd/generate-bindings/evm/abigen/source.go.tpl diff --git a/cmd/generate-bindings/bindings/abigen/source2.go.tpl b/cmd/generate-bindings/evm/abigen/source2.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/abigen/source2.go.tpl rename to cmd/generate-bindings/evm/abigen/source2.go.tpl diff --git 
a/cmd/generate-bindings/bindings/abigen/template.go b/cmd/generate-bindings/evm/abigen/template.go similarity index 100% rename from cmd/generate-bindings/bindings/abigen/template.go rename to cmd/generate-bindings/evm/abigen/template.go diff --git a/cmd/generate-bindings/bindings/bindgen.go b/cmd/generate-bindings/evm/bindgen.go similarity index 96% rename from cmd/generate-bindings/bindings/bindgen.go rename to cmd/generate-bindings/evm/bindgen.go index 593ed6dc..6ca40f63 100644 --- a/cmd/generate-bindings/bindings/bindgen.go +++ b/cmd/generate-bindings/evm/bindgen.go @@ -1,4 +1,4 @@ -package bindings +package evm import ( _ "embed" @@ -11,7 +11,7 @@ import ( "github.com/ethereum/go-ethereum/common/compiler" "github.com/ethereum/go-ethereum/crypto" - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings/abigen" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm/abigen" ) //go:embed sourcecre.go.tpl diff --git a/cmd/generate-bindings/bindings/bindings_test.go b/cmd/generate-bindings/evm/bindings_test.go similarity index 99% rename from cmd/generate-bindings/bindings/bindings_test.go rename to cmd/generate-bindings/evm/bindings_test.go index de225b36..ab558459 100644 --- a/cmd/generate-bindings/bindings/bindings_test.go +++ b/cmd/generate-bindings/evm/bindings_test.go @@ -1,4 +1,4 @@ -package bindings_test +package evm_test import ( "context" @@ -20,7 +20,7 @@ import ( "github.com/smartcontractkit/cre-sdk-go/cre/testutils" consensusmock "github.com/smartcontractkit/cre-sdk-go/internal_testing/capabilities/consensus/mock" - datastorage "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings/testdata" + datastorage "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm/testdata" ) const anyChainSelector = uint64(1337) diff --git a/cmd/generate-bindings/bindings/gen.go b/cmd/generate-bindings/evm/gen.go similarity index 67% rename from cmd/generate-bindings/bindings/gen.go rename to cmd/generate-bindings/evm/gen.go index 
febfa9dc..34410db7 100644 --- a/cmd/generate-bindings/bindings/gen.go +++ b/cmd/generate-bindings/evm/gen.go @@ -1,2 +1,2 @@ //go:generate go run ./testdata/gen -package bindings +package evm diff --git a/cmd/generate-bindings/evm/gen_test.go b/cmd/generate-bindings/evm/gen_test.go new file mode 100644 index 00000000..730875f1 --- /dev/null +++ b/cmd/generate-bindings/evm/gen_test.go @@ -0,0 +1,31 @@ +package evm_test + +import ( + "testing" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" +) + +func TestGenerateBindings(t *testing.T) { + if err := evm.GenerateBindings( + "./testdata/DataStorage_combined.json", + "", + "bindings", + "", + "./testdata/bindings.go", + ); err != nil { + t.Fatal(err) + } +} + +func TestGenerateBindingsOld(t *testing.T) { + if err := evm.GenerateBindings( + "./testdata/DataStorage_combined.json", + "", + "bindingsold", + "", + "./testdata/bindingsold.go", + ); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/generate-bindings/bindings/mockcontract.go.tpl b/cmd/generate-bindings/evm/mockcontract.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/mockcontract.go.tpl rename to cmd/generate-bindings/evm/mockcontract.go.tpl diff --git a/cmd/generate-bindings/bindings/sourcecre.go.tpl b/cmd/generate-bindings/evm/sourcecre.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/sourcecre.go.tpl rename to cmd/generate-bindings/evm/sourcecre.go.tpl diff --git a/cmd/generate-bindings/bindings/testdata/DataStorage.sol b/cmd/generate-bindings/evm/testdata/DataStorage.sol similarity index 100% rename from cmd/generate-bindings/bindings/testdata/DataStorage.sol rename to cmd/generate-bindings/evm/testdata/DataStorage.sol diff --git a/cmd/generate-bindings/bindings/testdata/DataStorage_combined.json b/cmd/generate-bindings/evm/testdata/DataStorage_combined.json similarity index 100% rename from cmd/generate-bindings/bindings/testdata/DataStorage_combined.json rename to 
cmd/generate-bindings/evm/testdata/DataStorage_combined.json diff --git a/cmd/generate-bindings/bindings/testdata/bindings.go b/cmd/generate-bindings/evm/testdata/bindings.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/bindings.go rename to cmd/generate-bindings/evm/testdata/bindings.go diff --git a/cmd/generate-bindings/bindings/testdata/bindings_mock.go b/cmd/generate-bindings/evm/testdata/bindings_mock.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/bindings_mock.go rename to cmd/generate-bindings/evm/testdata/bindings_mock.go diff --git a/cmd/generate-bindings/evm/testdata/bindingsold.go b/cmd/generate-bindings/evm/testdata/bindingsold.go new file mode 100644 index 00000000..5755e867 --- /dev/null +++ b/cmd/generate-bindings/evm/testdata/bindingsold.go @@ -0,0 +1,1249 @@ +// Code generated — DO NOT EDIT. + +package bindingsold + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "math/big" + "reflect" + "strings" + + ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/rpc" + "google.golang.org/protobuf/types/known/emptypb" + + pb2 "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/values/pb" + "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm" + "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm/bindings" + "github.com/smartcontractkit/cre-sdk-go/cre" +) + +var ( + _ = bytes.Equal + _ = errors.New + _ = fmt.Sprintf + _ = big.NewInt + _ = strings.NewReader + _ = ethereum.NotFound + _ = bind.Bind + _ = common.Big1 + _ = types.BloomLookup + _ = event.NewSubscription + _ = abi.ConvertType + _ = emptypb.Empty{} + _ = pb.NewBigIntFromInt + _ = 
pb2.AggregationType_AGGREGATION_TYPE_COMMON_PREFIX + _ = bindings.FilterOptions{} + _ = evm.FilterLogTriggerRequest{} + _ = cre.ResponseBufferTooSmall + _ = rpc.API{} + _ = json.Unmarshal + _ = reflect.Bool +) + +var DataStorageMetaData = &bind.MetaData{ + ABI: "[{\"inputs\":[{\"internalType\":\"address\",\"name\":\"requester\",\"type\":\"address\"},{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"internalType\":\"string\",\"name\":\"reason\",\"type\":\"string\"}],\"name\":\"DataNotFound\",\"type\":\"error\"},{\"inputs\":[{\"internalType\":\"address\",\"name\":\"requester\",\"type\":\"address\"},{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"internalType\":\"string\",\"name\":\"reason\",\"type\":\"string\"}],\"name\":\"DataNotFound2\",\"type\":\"error\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"address\",\"name\":\"caller\",\"type\":\"address\"},{\"indexed\":false,\"internalType\":\"string\",\"name\":\"message\",\"type\":\"string\"}],\"name\":\"AccessLogged\",\"type\":\"event\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"address\",\"name\":\"sender\",\"type\":\"address\"},{\"indexed\":false,\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"indexed\":false,\"internalType\":\"string\",\"name\":\"value\",\"type\":\"string\"}],\"name\":\"DataStored\",\"type\":\"event\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":false,\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"components\":[{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"internalType\":\"string\",\"name\":\"value\",\"type\":\"string\"}],\"indexed\":true,\"internalType\":\"structDataStorage.UserData\",\"name\":\"userData\",\"type\":\"tuple\"},{\"indexed\":false,\"internalType\":\"string\",\"name\":\"sender\",\"type\":\"string\"},{\"indexed\":true,\"internalType\":\"bytes\",\"name\":\"metadata\",\"type\":\"bytes\"},{\"indexed\":true,\"internalType\":\"byt
es[]\",\"name\":\"metadataArray\",\"type\":\"bytes[]\"}],\"name\":\"DynamicEvent\",\"type\":\"event\"},{\"anonymous\":false,\"inputs\":[],\"name\":\"NoFields\",\"type\":\"event\"},{\"inputs\":[],\"name\":\"getMultipleReserves\",\"outputs\":[{\"components\":[{\"internalType\":\"uint256\",\"name\":\"totalMinted\",\"type\":\"uint256\"},{\"internalType\":\"uint256\",\"name\":\"totalReserve\",\"type\":\"uint256\"}],\"internalType\":\"structDataStorage.UpdateReserves[]\",\"name\":\"reserves\",\"type\":\"tuple[]\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"getReserves\",\"outputs\":[{\"components\":[{\"internalType\":\"uint256\",\"name\":\"totalMinted\",\"type\":\"uint256\"},{\"internalType\":\"uint256\",\"name\":\"totalReserve\",\"type\":\"uint256\"}],\"internalType\":\"structDataStorage.UpdateReserves\",\"name\":\"\",\"type\":\"tuple\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"getTupleReserves\",\"outputs\":[{\"internalType\":\"uint256\",\"name\":\"totalMinted\",\"type\":\"uint256\"},{\"internalType\":\"uint256\",\"name\":\"totalReserve\",\"type\":\"uint256\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"getValue\",\"outputs\":[{\"internalType\":\"string\",\"name\":\"\",\"type\":\"string\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"string\",\"name\":\"message\",\"type\":\"string\"}],\"name\":\"logAccess\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"bytes\",\"name\":\"metadata\",\"type\":\"bytes\"},{\"internalType\":\"bytes\",\"name\":\"payload\",\"type\":\"bytes\"}],\"name\":\"onReport\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"address\",\"name\":\"user\",\"type\":\"address\"},{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"}],\"name\":\"readData\",\"outputs\":[{\"internalType\":
\"string\",\"name\":\"\",\"type\":\"string\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"internalType\":\"string\",\"name\":\"value\",\"type\":\"string\"}],\"name\":\"storeData\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"components\":[{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"internalType\":\"string\",\"name\":\"value\",\"type\":\"string\"}],\"internalType\":\"structDataStorage.UserData\",\"name\":\"userData\",\"type\":\"tuple\"}],\"name\":\"storeUserData\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"string\",\"name\":\"key\",\"type\":\"string\"},{\"internalType\":\"string\",\"name\":\"newValue\",\"type\":\"string\"}],\"name\":\"updateData\",\"outputs\":[{\"internalType\":\"string\",\"name\":\"oldValue\",\"type\":\"string\"}],\"stateMutability\":\"nonpayable\",\"type\":\"function\"}]", + Bin: 
"0x6080604052348015600e575f5ffd5b50610dfe8061001c5f395ff3fe608060405234801561000f575f5ffd5b506004361061009b575f3560e01c806398458c5d1161006357806398458c5d14610145578063b765cb7c14610158578063bddbb0231461016d578063ccf1582714610180578063f5bfa81514610193575f5ffd5b80630902f1ac1461009f57806320965255146100df578063255e0caf146101085780634ece5b4c1461011d578063805f213214610132575b5f5ffd5b6040805180820182525f80825260209182015281518083018352606480825260c89183019182528351908152905191810191909152015b60405180910390f35b6040805180820190915260048152631d195cdd60e21b60208201525b6040516100d691906106b0565b604080516064815260c86020820152016100d6565b61013061012b36600461070d565b6101a6565b005b61013061014036600461070d565b610232565b610130610153366004610777565b6102c9565b61016061036c565b6040516100d691906107ad565b6100fb61017b36600461070d565b610418565b61013061018e366004610804565b610549565b6100fb6101a1366004610842565b610590565b335f90815260208190526040908190209051839183916101c9908890889061089d565b908152602001604051809103902091826101e4929190610944565b50336001600160a01b03167fc95c7d5d3ac582f659cd004afbea77723e1315567b6557f3c059e8eb9586518f858585856040516102249493929190610a25565b60405180910390a250505050565b5f61023f82840184610adf565b602080820151335f90815291829052604091829020835192519394509092909161026891610b90565b908152602001604051809103902090816102829190610ba6565b508051602082015160405133927fc95c7d5d3ac582f659cd004afbea77723e1315567b6557f3c059e8eb9586518f926102ba92610c60565b60405180910390a25050505050565b6102d66020820182610c8d565b335f9081526020819052604090206102ee8480610c8d565b6040516102fc92919061089d565b90815260200160405180910390209182610317929190610944565b50337fc95c7d5d3ac582f659cd004afbea77723e1315567b6557f3c059e8eb9586518f6103448380610c8d565b6103516020860186610c8d565b6040516103619493929190610a25565b60405180910390a250565b6040805160028082526060828101909352816020015b604080518082019091525f808252602082015281526020019060019003908161038257905050905060405180604001604052806064815260200160c8815250815f815181106103d
3576103d3610ccf565b6020026020010181905250604051806040016040528061012c81526020016101908152508160018151811061040a5761040a610ccf565b602002602001018190525090565b335f908152602081905260409081902090516060919061043b908790879061089d565b90815260200160405180910390208054610454906108c0565b80601f0160208091040260200160405190810160405280929190818152602001828054610480906108c0565b80156104cb5780601f106104a2576101008083540402835291602001916104cb565b820191905f5260205f20905b8154815290600101906020018083116104ae57829003601f168201915b5050505050905080515f036105025733858560405163f1e5020960e01b81526004016104f993929190610ce3565b60405180910390fd5b335f9081526020819052604090819020905184918491610525908990899061089d565b90815260200160405180910390209182610540929190610944565b50949350505050565b336001600160a01b03167fe2ab1536af9681ad9e5927bca61830526c4cd932e970162eef77328af1fdcfb58383604051610584929190610d46565b60405180910390a25050565b6001600160a01b0383165f90815260208190526040808220905160609291906105bc908690869061089d565b908152602001604051809103902080546105d5906108c0565b80601f0160208091040260200160405190810160405280929190818152602001828054610601906108c0565b801561064c5780601f106106235761010080835404028352916020019161064c565b820191905f5260205f20905b81548152906001019060200180831161062f57829003601f168201915b5050505050905080515f0361067a5784848460405163f1e5020960e01b81526004016104f993929190610d59565b949350505050565b5f81518084528060208401602086015e5f602082860101526020601f19601f83011685010191505092915050565b602081525f6106c26020830184610682565b9392505050565b5f5f83601f8401126106d9575f5ffd5b5081356001600160401b038111156106ef575f5ffd5b602083019150836020828501011115610706575f5ffd5b9250929050565b5f5f5f5f60408587031215610720575f5ffd5b84356001600160401b03811115610735575f5ffd5b610741878288016106c9565b90955093505060208501356001600160401b0381111561075f575f5ffd5b61076b878288016106c9565b95989497509550505050565b5f60208284031215610787575f5ffd5b81356001600160401b0381111561079c575f5ffd5b8201604081850312156106c2575f5ffd5b602080825
282518282018190525f918401906040840190835b818110156107f9576107e383855180518252602090810151910152565b60209390930192604092909201916001016107c6565b509095945050505050565b5f5f60208385031215610815575f5ffd5b82356001600160401b0381111561082a575f5ffd5b610836858286016106c9565b90969095509350505050565b5f5f5f60408486031215610854575f5ffd5b83356001600160a01b038116811461086a575f5ffd5b925060208401356001600160401b03811115610884575f5ffd5b610890868287016106c9565b9497909650939450505050565b818382375f9101908152919050565b634e487b7160e01b5f52604160045260245ffd5b600181811c908216806108d457607f821691505b6020821081036108f257634e487b7160e01b5f52602260045260245ffd5b50919050565b601f82111561093f57805f5260205f20601f840160051c8101602085101561091d5750805b601f840160051c820191505b8181101561093c575f8155600101610929565b50505b505050565b6001600160401b0383111561095b5761095b6108ac565b61096f8361096983546108c0565b836108f8565b5f601f8411600181146109a0575f85156109895750838201355b5f19600387901b1c1916600186901b17835561093c565b5f83815260208120601f198716915b828110156109cf57868501358255602094850194600190920191016109af565b50868210156109eb575f1960f88860031b161c19848701351681555b505060018560011b0183555050505050565b81835281816020850137505f828201602090810191909152601f909101601f19169091010190565b604081525f610a386040830186886109fd565b8281036020840152610a4b8185876109fd565b979650505050505050565b5f82601f830112610a65575f5ffd5b81356001600160401b03811115610a7e57610a7e6108ac565b604051601f8201601f19908116603f011681016001600160401b0381118282101715610aac57610aac6108ac565b604052818152838201602001851015610ac3575f5ffd5b816020850160208301375f918101602001919091529392505050565b5f60208284031215610aef575f5ffd5b81356001600160401b03811115610b04575f5ffd5b820160408185031215610b15575f5ffd5b604080519081016001600160401b0381118282101715610b3757610b376108ac565b60405281356001600160401b03811115610b4f575f5ffd5b610b5b86828501610a56565b82525060208201356001600160401b03811115610b76575f5ffd5b610b8286828501610a56565b602083015250949350505050565b5f82518060208501845
e5f920191825250919050565b81516001600160401b03811115610bbf57610bbf6108ac565b610bd381610bcd84546108c0565b846108f8565b6020601f821160018114610c05575f8315610bee5750848201515b5f19600385901b1c1916600184901b17845561093c565b5f84815260208120601f198516915b82811015610c345787850151825560209485019460019092019101610c14565b5084821015610c5157868401515f19600387901b60f8161c191681555b50505050600190811b01905550565b604081525f610c726040830185610682565b8281036020840152610c848185610682565b95945050505050565b5f5f8335601e19843603018112610ca2575f5ffd5b8301803591506001600160401b03821115610cbb575f5ffd5b602001915036819003821315610706575f5ffd5b634e487b7160e01b5f52603260045260245ffd5b6001600160a01b03841681526060602082018190525f90610d0790830184866109fd565b828103604093840152601a81527f4e6f206578697374696e67206461746120746f20757064617465000000000000602082015291909101949350505050565b602081525f61067a6020830184866109fd565b6001600160a01b03841681526060602082018190525f90610d7d90830184866109fd565b8281036040840152602181527f4e6f2064617461206173736f63696174656420776974682074686973206b65796020820152601760f91b60408201526060810191505094935050505056fea26469706673582212206938aba98b7e3f58f5746f55b3e72f9e14ee3ad529e8ff28f32829e2cc0303c364736f6c634300081e0033", +} + +// Structs +type UpdateReserves struct { + TotalMinted *big.Int + TotalReserve *big.Int +} + +type UserData struct { + Key string + Value string +} + +// Contract Method Inputs +type LogAccessInput struct { + Message string +} + +type OnReportInput struct { + Metadata []byte + Payload []byte +} + +type ReadDataInput struct { + User common.Address + Key string +} + +type StoreDataInput struct { + Key string + Value string +} + +type StoreUserDataInput struct { + UserData UserData +} + +type UpdateDataInput struct { + Key string + NewValue string +} + +// Contract Method Outputs +type GetTupleReservesOutput struct { + TotalMinted *big.Int + TotalReserve *big.Int +} + +// Errors +type DataNotFound struct { + Requester common.Address + Key string + Reason 
string +} + +type DataNotFound2 struct { + Requester common.Address + Key string + Reason string +} + +// Events +// The Topics struct should be used as a filter (for log triggers). +// Note: It is only possible to filter on indexed fields. +// Indexed (string and bytes) fields will be of type common.Hash. +// They need to he (crypto.Keccak256) hashed and passed in. +// Indexed (tuple/slice/array) fields can be passed in as is, the EncodeTopics function will handle the hashing. +// +// The Decoded struct will be the result of calling decode (Adapt) on the log trigger result. +// Indexed dynamic type fields will be of type common.Hash. + +type AccessLoggedTopics struct { + Caller common.Address +} + +type AccessLoggedDecoded struct { + Caller common.Address + Message string +} + +type DataStoredTopics struct { + Sender common.Address +} + +type DataStoredDecoded struct { + Sender common.Address + Key string + Value string +} + +type DynamicEventTopics struct { + UserData UserData + Metadata common.Hash + MetadataArray [][]byte +} + +type DynamicEventDecoded struct { + Key string + UserData common.Hash + Sender string + Metadata common.Hash + MetadataArray common.Hash +} + +type NoFieldsTopics struct { +} + +type NoFieldsDecoded struct { +} + +// Main Binding Type for DataStorage +type DataStorage struct { + Address common.Address + Options *bindings.ContractInitOptions + ABI *abi.ABI + client *evm.Client + Codec DataStorageCodec +} + +type DataStorageCodec interface { + EncodeGetMultipleReservesMethodCall() ([]byte, error) + DecodeGetMultipleReservesMethodOutput(data []byte) ([]UpdateReserves, error) + EncodeGetReservesMethodCall() ([]byte, error) + DecodeGetReservesMethodOutput(data []byte) (UpdateReserves, error) + EncodeGetTupleReservesMethodCall() ([]byte, error) + DecodeGetTupleReservesMethodOutput(data []byte) (GetTupleReservesOutput, error) + EncodeGetValueMethodCall() ([]byte, error) + DecodeGetValueMethodOutput(data []byte) (string, error) + 
EncodeLogAccessMethodCall(in LogAccessInput) ([]byte, error) + EncodeOnReportMethodCall(in OnReportInput) ([]byte, error) + EncodeReadDataMethodCall(in ReadDataInput) ([]byte, error) + DecodeReadDataMethodOutput(data []byte) (string, error) + EncodeStoreDataMethodCall(in StoreDataInput) ([]byte, error) + EncodeStoreUserDataMethodCall(in StoreUserDataInput) ([]byte, error) + EncodeUpdateDataMethodCall(in UpdateDataInput) ([]byte, error) + DecodeUpdateDataMethodOutput(data []byte) (string, error) + EncodeUpdateReservesStruct(in UpdateReserves) ([]byte, error) + EncodeUserDataStruct(in UserData) ([]byte, error) + AccessLoggedLogHash() []byte + EncodeAccessLoggedTopics(evt abi.Event, values []AccessLoggedTopics) ([]*evm.TopicValues, error) + DecodeAccessLogged(log *evm.Log) (*AccessLoggedDecoded, error) + DataStoredLogHash() []byte + EncodeDataStoredTopics(evt abi.Event, values []DataStoredTopics) ([]*evm.TopicValues, error) + DecodeDataStored(log *evm.Log) (*DataStoredDecoded, error) + DynamicEventLogHash() []byte + EncodeDynamicEventTopics(evt abi.Event, values []DynamicEventTopics) ([]*evm.TopicValues, error) + DecodeDynamicEvent(log *evm.Log) (*DynamicEventDecoded, error) + NoFieldsLogHash() []byte + EncodeNoFieldsTopics(evt abi.Event, values []NoFieldsTopics) ([]*evm.TopicValues, error) + DecodeNoFields(log *evm.Log) (*NoFieldsDecoded, error) +} + +func NewDataStorage( + client *evm.Client, + address common.Address, + options *bindings.ContractInitOptions, +) (*DataStorage, error) { + parsed, err := abi.JSON(strings.NewReader(DataStorageMetaData.ABI)) + if err != nil { + return nil, err + } + codec, err := NewCodec() + if err != nil { + return nil, err + } + return &DataStorage{ + Address: address, + Options: options, + ABI: &parsed, + client: client, + Codec: codec, + }, nil +} + +type Codec struct { + abi *abi.ABI +} + +func NewCodec() (DataStorageCodec, error) { + parsed, err := abi.JSON(strings.NewReader(DataStorageMetaData.ABI)) + if err != nil { + return 
nil, err + } + return &Codec{abi: &parsed}, nil +} + +func (c *Codec) EncodeGetMultipleReservesMethodCall() ([]byte, error) { + return c.abi.Pack("getMultipleReserves") +} + +func (c *Codec) DecodeGetMultipleReservesMethodOutput(data []byte) ([]UpdateReserves, error) { + vals, err := c.abi.Methods["getMultipleReserves"].Outputs.Unpack(data) + if err != nil { + return *new([]UpdateReserves), err + } + jsonData, err := json.Marshal(vals[0]) + if err != nil { + return *new([]UpdateReserves), fmt.Errorf("failed to marshal ABI result: %w", err) + } + + var result []UpdateReserves + if err := json.Unmarshal(jsonData, &result); err != nil { + return *new([]UpdateReserves), fmt.Errorf("failed to unmarshal to []UpdateReserves: %w", err) + } + + return result, nil +} + +func (c *Codec) EncodeGetReservesMethodCall() ([]byte, error) { + return c.abi.Pack("getReserves") +} + +func (c *Codec) DecodeGetReservesMethodOutput(data []byte) (UpdateReserves, error) { + vals, err := c.abi.Methods["getReserves"].Outputs.Unpack(data) + if err != nil { + return *new(UpdateReserves), err + } + jsonData, err := json.Marshal(vals[0]) + if err != nil { + return *new(UpdateReserves), fmt.Errorf("failed to marshal ABI result: %w", err) + } + + var result UpdateReserves + if err := json.Unmarshal(jsonData, &result); err != nil { + return *new(UpdateReserves), fmt.Errorf("failed to unmarshal to UpdateReserves: %w", err) + } + + return result, nil +} + +func (c *Codec) EncodeGetTupleReservesMethodCall() ([]byte, error) { + return c.abi.Pack("getTupleReserves") +} + +func (c *Codec) DecodeGetTupleReservesMethodOutput(data []byte) (GetTupleReservesOutput, error) { + vals, err := c.abi.Methods["getTupleReserves"].Outputs.Unpack(data) + if err != nil { + return GetTupleReservesOutput{}, err + } + if len(vals) != 2 { + return GetTupleReservesOutput{}, fmt.Errorf("expected 2 values, got %d", len(vals)) + } + jsonData0, err := json.Marshal(vals[0]) + if err != nil { + return GetTupleReservesOutput{}, 
fmt.Errorf("failed to marshal ABI result 0: %w", err) + } + + var result0 *big.Int + if err := json.Unmarshal(jsonData0, &result0); err != nil { + return GetTupleReservesOutput{}, fmt.Errorf("failed to unmarshal to *big.Int: %w", err) + } + jsonData1, err := json.Marshal(vals[1]) + if err != nil { + return GetTupleReservesOutput{}, fmt.Errorf("failed to marshal ABI result 1: %w", err) + } + + var result1 *big.Int + if err := json.Unmarshal(jsonData1, &result1); err != nil { + return GetTupleReservesOutput{}, fmt.Errorf("failed to unmarshal to *big.Int: %w", err) + } + + return GetTupleReservesOutput{ + TotalMinted: result0, + TotalReserve: result1, + }, nil +} + +func (c *Codec) EncodeGetValueMethodCall() ([]byte, error) { + return c.abi.Pack("getValue") +} + +func (c *Codec) DecodeGetValueMethodOutput(data []byte) (string, error) { + vals, err := c.abi.Methods["getValue"].Outputs.Unpack(data) + if err != nil { + return *new(string), err + } + jsonData, err := json.Marshal(vals[0]) + if err != nil { + return *new(string), fmt.Errorf("failed to marshal ABI result: %w", err) + } + + var result string + if err := json.Unmarshal(jsonData, &result); err != nil { + return *new(string), fmt.Errorf("failed to unmarshal to string: %w", err) + } + + return result, nil +} + +func (c *Codec) EncodeLogAccessMethodCall(in LogAccessInput) ([]byte, error) { + return c.abi.Pack("logAccess", in.Message) +} + +func (c *Codec) EncodeOnReportMethodCall(in OnReportInput) ([]byte, error) { + return c.abi.Pack("onReport", in.Metadata, in.Payload) +} + +func (c *Codec) EncodeReadDataMethodCall(in ReadDataInput) ([]byte, error) { + return c.abi.Pack("readData", in.User, in.Key) +} + +func (c *Codec) DecodeReadDataMethodOutput(data []byte) (string, error) { + vals, err := c.abi.Methods["readData"].Outputs.Unpack(data) + if err != nil { + return *new(string), err + } + jsonData, err := json.Marshal(vals[0]) + if err != nil { + return *new(string), fmt.Errorf("failed to marshal ABI result: 
%w", err) + } + + var result string + if err := json.Unmarshal(jsonData, &result); err != nil { + return *new(string), fmt.Errorf("failed to unmarshal to string: %w", err) + } + + return result, nil +} + +func (c *Codec) EncodeStoreDataMethodCall(in StoreDataInput) ([]byte, error) { + return c.abi.Pack("storeData", in.Key, in.Value) +} + +func (c *Codec) EncodeStoreUserDataMethodCall(in StoreUserDataInput) ([]byte, error) { + return c.abi.Pack("storeUserData", in.UserData) +} + +func (c *Codec) EncodeUpdateDataMethodCall(in UpdateDataInput) ([]byte, error) { + return c.abi.Pack("updateData", in.Key, in.NewValue) +} + +func (c *Codec) DecodeUpdateDataMethodOutput(data []byte) (string, error) { + vals, err := c.abi.Methods["updateData"].Outputs.Unpack(data) + if err != nil { + return *new(string), err + } + jsonData, err := json.Marshal(vals[0]) + if err != nil { + return *new(string), fmt.Errorf("failed to marshal ABI result: %w", err) + } + + var result string + if err := json.Unmarshal(jsonData, &result); err != nil { + return *new(string), fmt.Errorf("failed to unmarshal to string: %w", err) + } + + return result, nil +} + +func (c *Codec) EncodeUpdateReservesStruct(in UpdateReserves) ([]byte, error) { + tupleType, err := abi.NewType( + "tuple", "", + []abi.ArgumentMarshaling{ + {Name: "totalMinted", Type: "uint256"}, + {Name: "totalReserve", Type: "uint256"}, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create tuple type for UpdateReserves: %w", err) + } + args := abi.Arguments{ + {Name: "updateReserves", Type: tupleType}, + } + + return args.Pack(in) +} +func (c *Codec) EncodeUserDataStruct(in UserData) ([]byte, error) { + tupleType, err := abi.NewType( + "tuple", "", + []abi.ArgumentMarshaling{ + {Name: "key", Type: "string"}, + {Name: "value", Type: "string"}, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create tuple type for UserData: %w", err) + } + args := abi.Arguments{ + {Name: "userData", Type: tupleType}, + } + + 
return args.Pack(in) +} + +func (c *Codec) AccessLoggedLogHash() []byte { + return c.abi.Events["AccessLogged"].ID.Bytes() +} + +func (c *Codec) EncodeAccessLoggedTopics( + evt abi.Event, + values []AccessLoggedTopics, +) ([]*evm.TopicValues, error) { + var callerRule []interface{} + for _, v := range values { + if reflect.ValueOf(v.Caller).IsZero() { + callerRule = append(callerRule, common.Hash{}) + continue + } + fieldVal, err := bindings.PrepareTopicArg(evt.Inputs[0], v.Caller) + if err != nil { + return nil, err + } + callerRule = append(callerRule, fieldVal) + } + + rawTopics, err := abi.MakeTopics( + callerRule, + ) + if err != nil { + return nil, err + } + + return bindings.PrepareTopics(rawTopics, evt.ID.Bytes()), nil +} + +// DecodeAccessLogged decodes a log into a AccessLogged struct. +func (c *Codec) DecodeAccessLogged(log *evm.Log) (*AccessLoggedDecoded, error) { + event := new(AccessLoggedDecoded) + if err := c.abi.UnpackIntoInterface(event, "AccessLogged", log.Data); err != nil { + return nil, err + } + var indexed abi.Arguments + for _, arg := range c.abi.Events["AccessLogged"].Inputs { + if arg.Indexed { + if arg.Type.T == abi.TupleTy { + // abigen throws on tuple, so converting to bytes to + // receive back the common.Hash as is instead of error + arg.Type.T = abi.BytesTy + } + indexed = append(indexed, arg) + } + } + // Convert [][]byte → []common.Hash + topics := make([]common.Hash, len(log.Topics)) + for i, t := range log.Topics { + topics[i] = common.BytesToHash(t) + } + + if err := abi.ParseTopics(event, indexed, topics[1:]); err != nil { + return nil, err + } + return event, nil +} + +func (c *Codec) DataStoredLogHash() []byte { + return c.abi.Events["DataStored"].ID.Bytes() +} + +func (c *Codec) EncodeDataStoredTopics( + evt abi.Event, + values []DataStoredTopics, +) ([]*evm.TopicValues, error) { + var senderRule []interface{} + for _, v := range values { + if reflect.ValueOf(v.Sender).IsZero() { + senderRule = append(senderRule, 
common.Hash{}) + continue + } + fieldVal, err := bindings.PrepareTopicArg(evt.Inputs[0], v.Sender) + if err != nil { + return nil, err + } + senderRule = append(senderRule, fieldVal) + } + + rawTopics, err := abi.MakeTopics( + senderRule, + ) + if err != nil { + return nil, err + } + + return bindings.PrepareTopics(rawTopics, evt.ID.Bytes()), nil +} + +// DecodeDataStored decodes a log into a DataStored struct. +func (c *Codec) DecodeDataStored(log *evm.Log) (*DataStoredDecoded, error) { + event := new(DataStoredDecoded) + if err := c.abi.UnpackIntoInterface(event, "DataStored", log.Data); err != nil { + return nil, err + } + var indexed abi.Arguments + for _, arg := range c.abi.Events["DataStored"].Inputs { + if arg.Indexed { + if arg.Type.T == abi.TupleTy { + // abigen throws on tuple, so converting to bytes to + // receive back the common.Hash as is instead of error + arg.Type.T = abi.BytesTy + } + indexed = append(indexed, arg) + } + } + // Convert [][]byte → []common.Hash + topics := make([]common.Hash, len(log.Topics)) + for i, t := range log.Topics { + topics[i] = common.BytesToHash(t) + } + + if err := abi.ParseTopics(event, indexed, topics[1:]); err != nil { + return nil, err + } + return event, nil +} + +func (c *Codec) DynamicEventLogHash() []byte { + return c.abi.Events["DynamicEvent"].ID.Bytes() +} + +func (c *Codec) EncodeDynamicEventTopics( + evt abi.Event, + values []DynamicEventTopics, +) ([]*evm.TopicValues, error) { + var userDataRule []interface{} + for _, v := range values { + if reflect.ValueOf(v.UserData).IsZero() { + userDataRule = append(userDataRule, common.Hash{}) + continue + } + fieldVal, err := bindings.PrepareTopicArg(evt.Inputs[1], v.UserData) + if err != nil { + return nil, err + } + userDataRule = append(userDataRule, fieldVal) + } + var metadataRule []interface{} + for _, v := range values { + if reflect.ValueOf(v.Metadata).IsZero() { + metadataRule = append(metadataRule, common.Hash{}) + continue + } + fieldVal, err := 
bindings.PrepareTopicArg(evt.Inputs[3], v.Metadata) + if err != nil { + return nil, err + } + metadataRule = append(metadataRule, fieldVal) + } + var metadataArrayRule []interface{} + for _, v := range values { + if reflect.ValueOf(v.MetadataArray).IsZero() { + metadataArrayRule = append(metadataArrayRule, common.Hash{}) + continue + } + fieldVal, err := bindings.PrepareTopicArg(evt.Inputs[4], v.MetadataArray) + if err != nil { + return nil, err + } + metadataArrayRule = append(metadataArrayRule, fieldVal) + } + + rawTopics, err := abi.MakeTopics( + userDataRule, + metadataRule, + metadataArrayRule, + ) + if err != nil { + return nil, err + } + + return bindings.PrepareTopics(rawTopics, evt.ID.Bytes()), nil +} + +// DecodeDynamicEvent decodes a log into a DynamicEvent struct. +func (c *Codec) DecodeDynamicEvent(log *evm.Log) (*DynamicEventDecoded, error) { + event := new(DynamicEventDecoded) + if err := c.abi.UnpackIntoInterface(event, "DynamicEvent", log.Data); err != nil { + return nil, err + } + var indexed abi.Arguments + for _, arg := range c.abi.Events["DynamicEvent"].Inputs { + if arg.Indexed { + if arg.Type.T == abi.TupleTy { + // abigen throws on tuple, so converting to bytes to + // receive back the common.Hash as is instead of error + arg.Type.T = abi.BytesTy + } + indexed = append(indexed, arg) + } + } + // Convert [][]byte → []common.Hash + topics := make([]common.Hash, len(log.Topics)) + for i, t := range log.Topics { + topics[i] = common.BytesToHash(t) + } + + if err := abi.ParseTopics(event, indexed, topics[1:]); err != nil { + return nil, err + } + return event, nil +} + +func (c *Codec) NoFieldsLogHash() []byte { + return c.abi.Events["NoFields"].ID.Bytes() +} + +func (c *Codec) EncodeNoFieldsTopics( + evt abi.Event, + values []NoFieldsTopics, +) ([]*evm.TopicValues, error) { + + rawTopics, err := abi.MakeTopics() + if err != nil { + return nil, err + } + + return bindings.PrepareTopics(rawTopics, evt.ID.Bytes()), nil +} + +// DecodeNoFields 
decodes a log into a NoFields struct. +func (c *Codec) DecodeNoFields(log *evm.Log) (*NoFieldsDecoded, error) { + event := new(NoFieldsDecoded) + if err := c.abi.UnpackIntoInterface(event, "NoFields", log.Data); err != nil { + return nil, err + } + var indexed abi.Arguments + for _, arg := range c.abi.Events["NoFields"].Inputs { + if arg.Indexed { + if arg.Type.T == abi.TupleTy { + // abigen throws on tuple, so converting to bytes to + // receive back the common.Hash as is instead of error + arg.Type.T = abi.BytesTy + } + indexed = append(indexed, arg) + } + } + // Convert [][]byte → []common.Hash + topics := make([]common.Hash, len(log.Topics)) + for i, t := range log.Topics { + topics[i] = common.BytesToHash(t) + } + + if err := abi.ParseTopics(event, indexed, topics[1:]); err != nil { + return nil, err + } + return event, nil +} + +func (c DataStorage) GetMultipleReserves( + runtime cre.Runtime, + blockNumber *big.Int, +) cre.Promise[[]UpdateReserves] { + calldata, err := c.Codec.EncodeGetMultipleReservesMethodCall() + if err != nil { + return cre.PromiseFromResult[[]UpdateReserves](*new([]UpdateReserves), err) + } + + var bn cre.Promise[*pb.BigInt] + if blockNumber == nil { + promise := c.client.HeaderByNumber(runtime, &evm.HeaderByNumberRequest{ + BlockNumber: bindings.FinalizedBlockNumber, + }) + + bn = cre.Then(promise, func(finalizedBlock *evm.HeaderByNumberReply) (*pb.BigInt, error) { + if finalizedBlock == nil || finalizedBlock.Header == nil { + return nil, errors.New("failed to get finalized block header") + } + return finalizedBlock.Header.BlockNumber, nil + }) + } else { + bn = cre.PromiseFromResult(pb.NewBigIntFromInt(blockNumber), nil) + } + + promise := cre.ThenPromise(bn, func(bn *pb.BigInt) cre.Promise[*evm.CallContractReply] { + return c.client.CallContract(runtime, &evm.CallContractRequest{ + Call: &evm.CallMsg{To: c.Address.Bytes(), Data: calldata}, + BlockNumber: bn, + }) + }) + return cre.Then(promise, func(response *evm.CallContractReply) 
([]UpdateReserves, error) { + return c.Codec.DecodeGetMultipleReservesMethodOutput(response.Data) + }) + +} + +func (c DataStorage) GetReserves( + runtime cre.Runtime, + blockNumber *big.Int, +) cre.Promise[UpdateReserves] { + calldata, err := c.Codec.EncodeGetReservesMethodCall() + if err != nil { + return cre.PromiseFromResult[UpdateReserves](*new(UpdateReserves), err) + } + + var bn cre.Promise[*pb.BigInt] + if blockNumber == nil { + promise := c.client.HeaderByNumber(runtime, &evm.HeaderByNumberRequest{ + BlockNumber: bindings.FinalizedBlockNumber, + }) + + bn = cre.Then(promise, func(finalizedBlock *evm.HeaderByNumberReply) (*pb.BigInt, error) { + if finalizedBlock == nil || finalizedBlock.Header == nil { + return nil, errors.New("failed to get finalized block header") + } + return finalizedBlock.Header.BlockNumber, nil + }) + } else { + bn = cre.PromiseFromResult(pb.NewBigIntFromInt(blockNumber), nil) + } + + promise := cre.ThenPromise(bn, func(bn *pb.BigInt) cre.Promise[*evm.CallContractReply] { + return c.client.CallContract(runtime, &evm.CallContractRequest{ + Call: &evm.CallMsg{To: c.Address.Bytes(), Data: calldata}, + BlockNumber: bn, + }) + }) + return cre.Then(promise, func(response *evm.CallContractReply) (UpdateReserves, error) { + return c.Codec.DecodeGetReservesMethodOutput(response.Data) + }) + +} + +func (c DataStorage) GetTupleReserves( + runtime cre.Runtime, + blockNumber *big.Int, +) cre.Promise[GetTupleReservesOutput] { + calldata, err := c.Codec.EncodeGetTupleReservesMethodCall() + if err != nil { + return cre.PromiseFromResult[GetTupleReservesOutput](GetTupleReservesOutput{}, err) + } + + var bn cre.Promise[*pb.BigInt] + if blockNumber == nil { + promise := c.client.HeaderByNumber(runtime, &evm.HeaderByNumberRequest{ + BlockNumber: bindings.FinalizedBlockNumber, + }) + + bn = cre.Then(promise, func(finalizedBlock *evm.HeaderByNumberReply) (*pb.BigInt, error) { + if finalizedBlock == nil || finalizedBlock.Header == nil { + return nil, 
errors.New("failed to get finalized block header") + } + return finalizedBlock.Header.BlockNumber, nil + }) + } else { + bn = cre.PromiseFromResult(pb.NewBigIntFromInt(blockNumber), nil) + } + + promise := cre.ThenPromise(bn, func(bn *pb.BigInt) cre.Promise[*evm.CallContractReply] { + return c.client.CallContract(runtime, &evm.CallContractRequest{ + Call: &evm.CallMsg{To: c.Address.Bytes(), Data: calldata}, + BlockNumber: bn, + }) + }) + return cre.Then(promise, func(response *evm.CallContractReply) (GetTupleReservesOutput, error) { + return c.Codec.DecodeGetTupleReservesMethodOutput(response.Data) + }) + +} + +func (c DataStorage) GetValue( + runtime cre.Runtime, + blockNumber *big.Int, +) cre.Promise[string] { + calldata, err := c.Codec.EncodeGetValueMethodCall() + if err != nil { + return cre.PromiseFromResult[string](*new(string), err) + } + + var bn cre.Promise[*pb.BigInt] + if blockNumber == nil { + promise := c.client.HeaderByNumber(runtime, &evm.HeaderByNumberRequest{ + BlockNumber: bindings.FinalizedBlockNumber, + }) + + bn = cre.Then(promise, func(finalizedBlock *evm.HeaderByNumberReply) (*pb.BigInt, error) { + if finalizedBlock == nil || finalizedBlock.Header == nil { + return nil, errors.New("failed to get finalized block header") + } + return finalizedBlock.Header.BlockNumber, nil + }) + } else { + bn = cre.PromiseFromResult(pb.NewBigIntFromInt(blockNumber), nil) + } + + promise := cre.ThenPromise(bn, func(bn *pb.BigInt) cre.Promise[*evm.CallContractReply] { + return c.client.CallContract(runtime, &evm.CallContractRequest{ + Call: &evm.CallMsg{To: c.Address.Bytes(), Data: calldata}, + BlockNumber: bn, + }) + }) + return cre.Then(promise, func(response *evm.CallContractReply) (string, error) { + return c.Codec.DecodeGetValueMethodOutput(response.Data) + }) + +} + +func (c DataStorage) ReadData( + runtime cre.Runtime, + args ReadDataInput, + blockNumber *big.Int, +) cre.Promise[string] { + calldata, err := c.Codec.EncodeReadDataMethodCall(args) + if err 
!= nil { + return cre.PromiseFromResult[string](*new(string), err) + } + + var bn cre.Promise[*pb.BigInt] + if blockNumber == nil { + promise := c.client.HeaderByNumber(runtime, &evm.HeaderByNumberRequest{ + BlockNumber: bindings.FinalizedBlockNumber, + }) + + bn = cre.Then(promise, func(finalizedBlock *evm.HeaderByNumberReply) (*pb.BigInt, error) { + if finalizedBlock == nil || finalizedBlock.Header == nil { + return nil, errors.New("failed to get finalized block header") + } + return finalizedBlock.Header.BlockNumber, nil + }) + } else { + bn = cre.PromiseFromResult(pb.NewBigIntFromInt(blockNumber), nil) + } + + promise := cre.ThenPromise(bn, func(bn *pb.BigInt) cre.Promise[*evm.CallContractReply] { + return c.client.CallContract(runtime, &evm.CallContractRequest{ + Call: &evm.CallMsg{To: c.Address.Bytes(), Data: calldata}, + BlockNumber: bn, + }) + }) + return cre.Then(promise, func(response *evm.CallContractReply) (string, error) { + return c.Codec.DecodeReadDataMethodOutput(response.Data) + }) + +} + +func (c DataStorage) WriteReportFromUpdateReserves( + runtime cre.Runtime, + input UpdateReserves, + gasConfig *evm.GasConfig, +) cre.Promise[*evm.WriteReportReply] { + encoded, err := c.Codec.EncodeUpdateReservesStruct(input) + if err != nil { + return cre.PromiseFromResult[*evm.WriteReportReply](nil, err) + } + promise := runtime.GenerateReport(&pb2.ReportRequest{ + EncodedPayload: encoded, + EncoderName: "evm", + SigningAlgo: "ecdsa", + HashingAlgo: "keccak256", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*evm.WriteReportReply] { + return c.client.WriteReport(runtime, &evm.WriteCreReportRequest{ + Receiver: c.Address.Bytes(), + Report: report, + GasConfig: gasConfig, + }) + }) +} + +func (c DataStorage) WriteReportFromUserData( + runtime cre.Runtime, + input UserData, + gasConfig *evm.GasConfig, +) cre.Promise[*evm.WriteReportReply] { + encoded, err := c.Codec.EncodeUserDataStruct(input) + if err != nil { + return 
cre.PromiseFromResult[*evm.WriteReportReply](nil, err) + } + promise := runtime.GenerateReport(&pb2.ReportRequest{ + EncodedPayload: encoded, + EncoderName: "evm", + SigningAlgo: "ecdsa", + HashingAlgo: "keccak256", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*evm.WriteReportReply] { + return c.client.WriteReport(runtime, &evm.WriteCreReportRequest{ + Receiver: c.Address.Bytes(), + Report: report, + GasConfig: gasConfig, + }) + }) +} + +func (c DataStorage) WriteReport( + runtime cre.Runtime, + report *cre.Report, + gasConfig *evm.GasConfig, +) cre.Promise[*evm.WriteReportReply] { + return c.client.WriteReport(runtime, &evm.WriteCreReportRequest{ + Receiver: c.Address.Bytes(), + Report: report, + GasConfig: gasConfig, + }) +} + +// DecodeDataNotFoundError decodes a DataNotFound error from revert data. +func (c *DataStorage) DecodeDataNotFoundError(data []byte) (*DataNotFound, error) { + args := c.ABI.Errors["DataNotFound"].Inputs + values, err := args.Unpack(data[4:]) + if err != nil { + return nil, fmt.Errorf("failed to unpack error: %w", err) + } + if len(values) != 3 { + return nil, fmt.Errorf("expected 3 values, got %d", len(values)) + } + + requester, ok0 := values[0].(common.Address) + if !ok0 { + return nil, fmt.Errorf("unexpected type for requester in DataNotFound error") + } + + key, ok1 := values[1].(string) + if !ok1 { + return nil, fmt.Errorf("unexpected type for key in DataNotFound error") + } + + reason, ok2 := values[2].(string) + if !ok2 { + return nil, fmt.Errorf("unexpected type for reason in DataNotFound error") + } + + return &DataNotFound{ + Requester: requester, + Key: key, + Reason: reason, + }, nil +} + +// Error implements the error interface for DataNotFound. +func (e *DataNotFound) Error() string { + return fmt.Sprintf("DataNotFound error: requester=%v; key=%v; reason=%v;", e.Requester, e.Key, e.Reason) +} + +// DecodeDataNotFound2Error decodes a DataNotFound2 error from revert data. 
+func (c *DataStorage) DecodeDataNotFound2Error(data []byte) (*DataNotFound2, error) { + args := c.ABI.Errors["DataNotFound2"].Inputs + values, err := args.Unpack(data[4:]) + if err != nil { + return nil, fmt.Errorf("failed to unpack error: %w", err) + } + if len(values) != 3 { + return nil, fmt.Errorf("expected 3 values, got %d", len(values)) + } + + requester, ok0 := values[0].(common.Address) + if !ok0 { + return nil, fmt.Errorf("unexpected type for requester in DataNotFound2 error") + } + + key, ok1 := values[1].(string) + if !ok1 { + return nil, fmt.Errorf("unexpected type for key in DataNotFound2 error") + } + + reason, ok2 := values[2].(string) + if !ok2 { + return nil, fmt.Errorf("unexpected type for reason in DataNotFound2 error") + } + + return &DataNotFound2{ + Requester: requester, + Key: key, + Reason: reason, + }, nil +} + +// Error implements the error interface for DataNotFound2. +func (e *DataNotFound2) Error() string { + return fmt.Sprintf("DataNotFound2 error: requester=%v; key=%v; reason=%v;", e.Requester, e.Key, e.Reason) +} + +func (c *DataStorage) UnpackError(data []byte) (any, error) { + switch common.Bytes2Hex(data[:4]) { + case common.Bytes2Hex(c.ABI.Errors["DataNotFound"].ID.Bytes()[:4]): + return c.DecodeDataNotFoundError(data) + case common.Bytes2Hex(c.ABI.Errors["DataNotFound2"].ID.Bytes()[:4]): + return c.DecodeDataNotFound2Error(data) + default: + return nil, errors.New("unknown error selector") + } +} + +// AccessLoggedTrigger wraps the raw log trigger and provides decoded AccessLoggedDecoded data +type AccessLoggedTrigger struct { + cre.Trigger[*evm.Log, *evm.Log] // Embed the raw trigger + contract *DataStorage // Keep reference for decoding +} + +// Adapt method that decodes the log into AccessLogged data +func (t *AccessLoggedTrigger) Adapt(l *evm.Log) (*bindings.DecodedLog[AccessLoggedDecoded], error) { + // Decode the log using the contract's codec + decoded, err := t.contract.Codec.DecodeAccessLogged(l) + if err != nil { + 
return nil, fmt.Errorf("failed to decode AccessLogged log: %w", err) + } + + return &bindings.DecodedLog[AccessLoggedDecoded]{ + Log: l, // Original log + Data: *decoded, // Decoded data + }, nil +} + +func (c *DataStorage) LogTriggerAccessLoggedLog(chainSelector uint64, confidence evm.ConfidenceLevel, filters []AccessLoggedTopics) (cre.Trigger[*evm.Log, *bindings.DecodedLog[AccessLoggedDecoded]], error) { + event := c.ABI.Events["AccessLogged"] + topics, err := c.Codec.EncodeAccessLoggedTopics(event, filters) + if err != nil { + return nil, fmt.Errorf("failed to encode topics for AccessLogged: %w", err) + } + + rawTrigger := evm.LogTrigger(chainSelector, &evm.FilterLogTriggerRequest{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: topics, + Confidence: confidence, + }) + + return &AccessLoggedTrigger{ + Trigger: rawTrigger, + contract: c, + }, nil +} + +func (c *DataStorage) FilterLogsAccessLogged(runtime cre.Runtime, options *bindings.FilterOptions) (cre.Promise[*evm.FilterLogsReply], error) { + if options == nil { + return nil, errors.New("FilterLogs options are required.") + } + return c.client.FilterLogs(runtime, &evm.FilterLogsRequest{ + FilterQuery: &evm.FilterQuery{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: []*evm.Topics{ + {Topic: [][]byte{c.Codec.AccessLoggedLogHash()}}, + }, + BlockHash: options.BlockHash, + FromBlock: pb.NewBigIntFromInt(options.FromBlock), + ToBlock: pb.NewBigIntFromInt(options.ToBlock), + }, + }), nil +} + +// DataStoredTrigger wraps the raw log trigger and provides decoded DataStoredDecoded data +type DataStoredTrigger struct { + cre.Trigger[*evm.Log, *evm.Log] // Embed the raw trigger + contract *DataStorage // Keep reference for decoding +} + +// Adapt method that decodes the log into DataStored data +func (t *DataStoredTrigger) Adapt(l *evm.Log) (*bindings.DecodedLog[DataStoredDecoded], error) { + // Decode the log using the contract's codec + decoded, err := t.contract.Codec.DecodeDataStored(l) + if err != nil { + 
return nil, fmt.Errorf("failed to decode DataStored log: %w", err) + } + + return &bindings.DecodedLog[DataStoredDecoded]{ + Log: l, // Original log + Data: *decoded, // Decoded data + }, nil +} + +func (c *DataStorage) LogTriggerDataStoredLog(chainSelector uint64, confidence evm.ConfidenceLevel, filters []DataStoredTopics) (cre.Trigger[*evm.Log, *bindings.DecodedLog[DataStoredDecoded]], error) { + event := c.ABI.Events["DataStored"] + topics, err := c.Codec.EncodeDataStoredTopics(event, filters) + if err != nil { + return nil, fmt.Errorf("failed to encode topics for DataStored: %w", err) + } + + rawTrigger := evm.LogTrigger(chainSelector, &evm.FilterLogTriggerRequest{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: topics, + Confidence: confidence, + }) + + return &DataStoredTrigger{ + Trigger: rawTrigger, + contract: c, + }, nil +} + +func (c *DataStorage) FilterLogsDataStored(runtime cre.Runtime, options *bindings.FilterOptions) (cre.Promise[*evm.FilterLogsReply], error) { + if options == nil { + return nil, errors.New("FilterLogs options are required.") + } + return c.client.FilterLogs(runtime, &evm.FilterLogsRequest{ + FilterQuery: &evm.FilterQuery{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: []*evm.Topics{ + {Topic: [][]byte{c.Codec.DataStoredLogHash()}}, + }, + BlockHash: options.BlockHash, + FromBlock: pb.NewBigIntFromInt(options.FromBlock), + ToBlock: pb.NewBigIntFromInt(options.ToBlock), + }, + }), nil +} + +// DynamicEventTrigger wraps the raw log trigger and provides decoded DynamicEventDecoded data +type DynamicEventTrigger struct { + cre.Trigger[*evm.Log, *evm.Log] // Embed the raw trigger + contract *DataStorage // Keep reference for decoding +} + +// Adapt method that decodes the log into DynamicEvent data +func (t *DynamicEventTrigger) Adapt(l *evm.Log) (*bindings.DecodedLog[DynamicEventDecoded], error) { + // Decode the log using the contract's codec + decoded, err := t.contract.Codec.DecodeDynamicEvent(l) + if err != nil { + return 
nil, fmt.Errorf("failed to decode DynamicEvent log: %w", err) + } + + return &bindings.DecodedLog[DynamicEventDecoded]{ + Log: l, // Original log + Data: *decoded, // Decoded data + }, nil +} + +func (c *DataStorage) LogTriggerDynamicEventLog(chainSelector uint64, confidence evm.ConfidenceLevel, filters []DynamicEventTopics) (cre.Trigger[*evm.Log, *bindings.DecodedLog[DynamicEventDecoded]], error) { + event := c.ABI.Events["DynamicEvent"] + topics, err := c.Codec.EncodeDynamicEventTopics(event, filters) + if err != nil { + return nil, fmt.Errorf("failed to encode topics for DynamicEvent: %w", err) + } + + rawTrigger := evm.LogTrigger(chainSelector, &evm.FilterLogTriggerRequest{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: topics, + Confidence: confidence, + }) + + return &DynamicEventTrigger{ + Trigger: rawTrigger, + contract: c, + }, nil +} + +func (c *DataStorage) FilterLogsDynamicEvent(runtime cre.Runtime, options *bindings.FilterOptions) (cre.Promise[*evm.FilterLogsReply], error) { + if options == nil { + return nil, errors.New("FilterLogs options are required.") + } + return c.client.FilterLogs(runtime, &evm.FilterLogsRequest{ + FilterQuery: &evm.FilterQuery{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: []*evm.Topics{ + {Topic: [][]byte{c.Codec.DynamicEventLogHash()}}, + }, + BlockHash: options.BlockHash, + FromBlock: pb.NewBigIntFromInt(options.FromBlock), + ToBlock: pb.NewBigIntFromInt(options.ToBlock), + }, + }), nil +} + +// NoFieldsTrigger wraps the raw log trigger and provides decoded NoFieldsDecoded data +type NoFieldsTrigger struct { + cre.Trigger[*evm.Log, *evm.Log] // Embed the raw trigger + contract *DataStorage // Keep reference for decoding +} + +// Adapt method that decodes the log into NoFields data +func (t *NoFieldsTrigger) Adapt(l *evm.Log) (*bindings.DecodedLog[NoFieldsDecoded], error) { + // Decode the log using the contract's codec + decoded, err := t.contract.Codec.DecodeNoFields(l) + if err != nil { + return nil, 
fmt.Errorf("failed to decode NoFields log: %w", err) + } + + return &bindings.DecodedLog[NoFieldsDecoded]{ + Log: l, // Original log + Data: *decoded, // Decoded data + }, nil +} + +func (c *DataStorage) LogTriggerNoFieldsLog(chainSelector uint64, confidence evm.ConfidenceLevel, filters []NoFieldsTopics) (cre.Trigger[*evm.Log, *bindings.DecodedLog[NoFieldsDecoded]], error) { + event := c.ABI.Events["NoFields"] + topics, err := c.Codec.EncodeNoFieldsTopics(event, filters) + if err != nil { + return nil, fmt.Errorf("failed to encode topics for NoFields: %w", err) + } + + rawTrigger := evm.LogTrigger(chainSelector, &evm.FilterLogTriggerRequest{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: topics, + Confidence: confidence, + }) + + return &NoFieldsTrigger{ + Trigger: rawTrigger, + contract: c, + }, nil +} + +func (c *DataStorage) FilterLogsNoFields(runtime cre.Runtime, options *bindings.FilterOptions) (cre.Promise[*evm.FilterLogsReply], error) { + if options == nil { + return nil, errors.New("FilterLogs options are required.") + } + return c.client.FilterLogs(runtime, &evm.FilterLogsRequest{ + FilterQuery: &evm.FilterQuery{ + Addresses: [][]byte{c.Address.Bytes()}, + Topics: []*evm.Topics{ + {Topic: [][]byte{c.Codec.NoFieldsLogHash()}}, + }, + BlockHash: options.BlockHash, + FromBlock: pb.NewBigIntFromInt(options.FromBlock), + ToBlock: pb.NewBigIntFromInt(options.ToBlock), + }, + }), nil +} diff --git a/cmd/generate-bindings/evm/testdata/bindingsold_mock.go b/cmd/generate-bindings/evm/testdata/bindingsold_mock.go new file mode 100644 index 00000000..0ed39bcc --- /dev/null +++ b/cmd/generate-bindings/evm/testdata/bindingsold_mock.go @@ -0,0 +1,117 @@ +// Code generated — DO NOT EDIT. 
+ +//go:build !wasip1 + +package bindingsold + +import ( + "errors" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + evmmock "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm/mock" +) + +var ( + _ = errors.New + _ = fmt.Errorf + _ = big.NewInt + _ = common.Big1 +) + +// DataStorageMock is a mock implementation of DataStorage for testing. +type DataStorageMock struct { + GetMultipleReserves func() ([]UpdateReserves, error) + GetReserves func() (UpdateReserves, error) + GetTupleReserves func() (GetTupleReservesOutput, error) + GetValue func() (string, error) + ReadData func(ReadDataInput) (string, error) +} + +// NewDataStorageMock creates a new DataStorageMock for testing. +func NewDataStorageMock(address common.Address, clientMock *evmmock.ClientCapability) *DataStorageMock { + mock := &DataStorageMock{} + + codec, err := NewCodec() + if err != nil { + panic("failed to create codec for mock: " + err.Error()) + } + + abi := codec.(*Codec).abi + _ = abi + + funcMap := map[string]func([]byte) ([]byte, error){ + string(abi.Methods["getMultipleReserves"].ID[:4]): func(payload []byte) ([]byte, error) { + if mock.GetMultipleReserves == nil { + return nil, errors.New("getMultipleReserves method not mocked") + } + result, err := mock.GetMultipleReserves() + if err != nil { + return nil, err + } + return abi.Methods["getMultipleReserves"].Outputs.Pack(result) + }, + string(abi.Methods["getReserves"].ID[:4]): func(payload []byte) ([]byte, error) { + if mock.GetReserves == nil { + return nil, errors.New("getReserves method not mocked") + } + result, err := mock.GetReserves() + if err != nil { + return nil, err + } + return abi.Methods["getReserves"].Outputs.Pack(result) + }, + string(abi.Methods["getTupleReserves"].ID[:4]): func(payload []byte) ([]byte, error) { + if mock.GetTupleReserves == nil { + return nil, errors.New("getTupleReserves method not mocked") + } + result, err := mock.GetTupleReserves() + if err != nil { + return nil, err + 
} + return abi.Methods["getTupleReserves"].Outputs.Pack( + result.TotalMinted, + result.TotalReserve, + ) + }, + string(abi.Methods["getValue"].ID[:4]): func(payload []byte) ([]byte, error) { + if mock.GetValue == nil { + return nil, errors.New("getValue method not mocked") + } + result, err := mock.GetValue() + if err != nil { + return nil, err + } + return abi.Methods["getValue"].Outputs.Pack(result) + }, + string(abi.Methods["readData"].ID[:4]): func(payload []byte) ([]byte, error) { + if mock.ReadData == nil { + return nil, errors.New("readData method not mocked") + } + inputs := abi.Methods["readData"].Inputs + + values, err := inputs.Unpack(payload) + if err != nil { + return nil, errors.New("Failed to unpack payload") + } + if len(values) != 2 { + return nil, errors.New("expected 2 input values") + } + + args := ReadDataInput{ + User: values[0].(common.Address), + Key: values[1].(string), + } + + result, err := mock.ReadData(args) + if err != nil { + return nil, err + } + return abi.Methods["readData"].Outputs.Pack(result) + }, + } + + evmmock.AddContractMock(address, clientMock, funcMap, nil) + return mock +} diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract.sol b/cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract.sol similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract.sol rename to cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract.sol diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract_combined.json b/cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract_combined.json similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract_combined.json rename to cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract_combined.json diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings.go 
b/cmd/generate-bindings/evm/testdata/emptybindings/emptybindings.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings.go rename to cmd/generate-bindings/evm/testdata/emptybindings/emptybindings.go diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings_mock.go b/cmd/generate-bindings/evm/testdata/emptybindings/emptybindings_mock.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings_mock.go rename to cmd/generate-bindings/evm/testdata/emptybindings/emptybindings_mock.go diff --git a/cmd/generate-bindings/bindings/testdata/gen/main.go b/cmd/generate-bindings/evm/testdata/gen/main.go similarity index 70% rename from cmd/generate-bindings/bindings/testdata/gen/main.go rename to cmd/generate-bindings/evm/testdata/gen/main.go index 2eda5a71..44836dd4 100644 --- a/cmd/generate-bindings/bindings/testdata/gen/main.go +++ b/cmd/generate-bindings/evm/testdata/gen/main.go @@ -1,11 +1,11 @@ package main import ( - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" ) func main() { - if err := bindings.GenerateBindings( + if err := evm.GenerateBindings( "./testdata/DataStorage_combined.json", "", "bindings", @@ -15,7 +15,7 @@ func main() { panic(err) } - if err := bindings.GenerateBindings( + if err := evm.GenerateBindings( "./testdata/emptybindings/EmptyContract_combined.json", "", "emptybindings", diff --git a/cmd/generate-bindings/generate-bindings.go b/cmd/generate-bindings/generate-bindings.go index 7da55c94..2e934940 100644 --- a/cmd/generate-bindings/generate-bindings.go +++ b/cmd/generate-bindings/generate-bindings.go @@ -4,28 +4,35 @@ import ( "fmt" "os" "os/exec" - "path/filepath" - "github.com/rs/zerolog" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/smartcontractkit/cre-cli/cmd/creinit" - 
"github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" "github.com/smartcontractkit/cre-cli/internal/runtime" - "github.com/smartcontractkit/cre-cli/internal/validation" + "github.com/spf13/cobra" ) -type Inputs struct { +// runCommand executes a command in a specified directory +func runCommand(dir string, command string, args ...string) error { + cmd := exec.Command(command, args...) + cmd.Dir = dir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run %s: %w", command, err) + } + + return nil +} + +type EvmInputs struct { ProjectRoot string `validate:"required,dir" cli:"--project-root"` ChainFamily string `validate:"required,oneof=evm" cli:"--chain-family"` Language string `validate:"required,oneof=go" cli:"--language"` - AbiPath string `validate:"required,path_read" cli:"--abi"` - PkgName string `validate:"required" cli:"--pkg"` - OutPath string `validate:"required" cli:"--out"` + // just keeping it simple for now + AbiPath string `validate:"required,path_read" cli:"--abi"` + PkgName string `validate:"required" cli:"--pkg"` + OutPath string `validate:"required" cli:"--out"` } -func New(runtimeContext *runtime.Context) *cobra.Command { +func NewEvmBindings(runtimeContext *runtime.Context) *cobra.Command { generateBindingsCmd := &cobra.Command{ Use: "generate-bindings ", Short: "Generate bindings from contract ABI", @@ -36,17 +43,14 @@ For example, IERC20.abi generates bindings in generated/ierc20/ package.`, Example: " cre generate-bindings evm", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - handler := newHandler(runtimeContext) - - inputs, err := handler.ResolveInputs(args, runtimeContext.Viper) + inputs, err := resolveEvmInputs(args, runtimeContext.Viper) if err != nil { return err } - err = handler.ValidateInputs(inputs) - if err != nil { + if err := validateEvmInputs(inputs); err != nil { return err } - return handler.Execute(inputs) + return 
executeEvm(inputs) }, } @@ -58,284 +62,40 @@ For example, IERC20.abi generates bindings in generated/ierc20/ package.`, return generateBindingsCmd } -type handler struct { - log *zerolog.Logger - validated bool -} - -func newHandler(ctx *runtime.Context) *handler { - return &handler{ - log: ctx.Logger, - validated: false, - } -} - -func (h *handler) ResolveInputs(args []string, v *viper.Viper) (Inputs, error) { - // Get current working directory as default project root - currentDir, err := os.Getwd() - if err != nil { - return Inputs{}, fmt.Errorf("failed to get current working directory: %w", err) - } - - // Resolve project root with fallback to current directory - projectRoot := v.GetString("project-root") - if projectRoot == "" { - projectRoot = currentDir - } - - contractsPath := filepath.Join(projectRoot, "contracts") - if _, err := os.Stat(contractsPath); err != nil { - return Inputs{}, fmt.Errorf("contracts folder not found in project root: %s", contractsPath) - } - - // Chain family is now a positional argument - chainFamily := args[0] - - // Language defaults are handled by StringP - language := v.GetString("language") - - // Resolve ABI path with fallback to contracts/{chainFamily}/src/abi/ - abiPath := v.GetString("abi") - if abiPath == "" { - abiPath = filepath.Join(projectRoot, "contracts", chainFamily, "src", "abi") - } - - // Package name defaults are handled by StringP - pkgName := v.GetString("pkg") - - // Output path is contracts/{chainFamily}/src/generated/ under projectRoot - outPath := filepath.Join(projectRoot, "contracts", chainFamily, "src", "generated") - - return Inputs{ - ProjectRoot: projectRoot, - ChainFamily: chainFamily, - Language: language, - AbiPath: abiPath, - PkgName: pkgName, - OutPath: outPath, - }, nil -} - -func (h *handler) ValidateInputs(inputs Inputs) error { - validate, err := validation.NewValidator() - if err != nil { - return fmt.Errorf("failed to initialize validator: %w", err) - } - - if err = 
validate.Struct(inputs); err != nil { - return validate.ParseValidationErrors(err) - } - - // Additional validation for ABI path - if _, err := os.Stat(inputs.AbiPath); err != nil { - if os.IsNotExist(err) { - return fmt.Errorf("ABI path does not exist: %s", inputs.AbiPath) - } - return fmt.Errorf("failed to access ABI path: %w", err) - } - - // Validate that if AbiPath is a directory, it contains .abi files - if info, err := os.Stat(inputs.AbiPath); err == nil && info.IsDir() { - files, err := filepath.Glob(filepath.Join(inputs.AbiPath, "*.abi")) - if err != nil { - return fmt.Errorf("failed to check for ABI files in directory: %w", err) - } - if len(files) == 0 { - return fmt.Errorf("no .abi files found in directory: %s", inputs.AbiPath) - } - } - - h.validated = true - return nil -} - -// contractNameToPackage converts contract names to valid Go package names -// Examples: IERC20 -> ierc20, ReserveManager -> reserve_manager, IReserveManager -> ireserve_manager -func contractNameToPackage(contractName string) string { - if contractName == "" { - return "" - } - - var result []rune - runes := []rune(contractName) - - for i, r := range runes { - // Convert to lowercase - if r >= 'A' && r <= 'Z' { - lower := r - 'A' + 'a' - - // Add underscore before uppercase letters, but not: - // - At the beginning (i == 0) - // - If the previous character was also uppercase and this is followed by lowercase (e.g., "ERC" in "ERC20") - // - If this is part of a sequence of uppercase letters at the beginning (e.g., "IERC20" -> "ierc20") - if i > 0 { - prevIsUpper := runes[i-1] >= 'A' && runes[i-1] <= 'Z' - nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z' - - // Add underscore if: - // - Previous char was lowercase (CamelCase boundary) - // - Previous char was uppercase but this char is followed by lowercase (end of acronym) - if !prevIsUpper || (prevIsUpper && nextIsLower && i > 1) { - result = append(result, '_') - } - } - - result = append(result, lower) 
- } else { - result = append(result, r) - } - } - - return string(result) -} - -func (h *handler) processAbiDirectory(inputs Inputs) error { - // Read all .abi files in the directory - files, err := filepath.Glob(filepath.Join(inputs.AbiPath, "*.abi")) - if err != nil { - return fmt.Errorf("failed to find ABI files: %w", err) - } - - if len(files) == 0 { - return fmt.Errorf("no .abi files found in directory: %s", inputs.AbiPath) - } - - packageNames := make(map[string]bool) - for _, abiFile := range files { - contractName := filepath.Base(abiFile) - contractName = contractName[:len(contractName)-4] - packageName := contractNameToPackage(contractName) - if _, exists := packageNames[packageName]; exists { - return fmt.Errorf("package name collision: multiple contracts would generate the same package name '%s' (contracts are converted to snake_case for package names). Please rename one of your contract files to avoid this conflict", packageName) - } - packageNames[packageName] = true - } - - // Process each ABI file - for _, abiFile := range files { - // Extract contract name from filename (remove .abi extension) - contractName := filepath.Base(abiFile) - contractName = contractName[:len(contractName)-4] // Remove .abi extension - - // Convert contract name to package name - packageName := contractNameToPackage(contractName) - - // Create per-contract output directory - contractOutDir := filepath.Join(inputs.OutPath, packageName) - if err := os.MkdirAll(contractOutDir, 0o755); err != nil { - return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) - } - - // Create output file path in contract-specific directory - outputFile := filepath.Join(contractOutDir, contractName+".go") - - fmt.Printf("Processing ABI file: %s, contract: %s, package: %s, output: %s\n", abiFile, contractName, packageName, outputFile) - - err = bindings.GenerateBindings( - "", // combinedJSONPath - empty for now - abiFile, - packageName, // Use contract-specific 
package name - contractName, // Use contract name as type name - outputFile, - ) - if err != nil { - return fmt.Errorf("failed to generate bindings for %s: %w", contractName, err) - } - } - - return nil -} - -func (h *handler) processSingleAbi(inputs Inputs) error { - // Extract contract name from ABI file path - contractName := filepath.Base(inputs.AbiPath) - if filepath.Ext(contractName) == ".abi" { - contractName = contractName[:len(contractName)-4] // Remove .abi extension - } - - // Convert contract name to package name - packageName := contractNameToPackage(contractName) - - // Create per-contract output directory - contractOutDir := filepath.Join(inputs.OutPath, packageName) - if err := os.MkdirAll(contractOutDir, 0o755); err != nil { - return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) - } - - // Create output file path in contract-specific directory - outputFile := filepath.Join(contractOutDir, contractName+".go") - - fmt.Printf("Processing single ABI file: %s, contract: %s, package: %s, output: %s\n", inputs.AbiPath, contractName, packageName, outputFile) - - return bindings.GenerateBindings( - "", // combinedJSONPath - empty for now - inputs.AbiPath, - packageName, // Use contract-specific package name - contractName, // Use contract name as type name - outputFile, - ) +type SolanaInputs struct { + ProjectRoot string `validate:"required,dir" cli:"--project-root"` + Language string `validate:"required,oneof=go" cli:"--language"` + // just keeping it simple for now + IdlPath string `validate:"required,path_read" cli:"--idl"` + // PkgName string `validate:"required" cli:"--pkg"` + OutPath string `validate:"required" cli:"--out"` } -func (h *handler) Execute(inputs Inputs) error { - fmt.Printf("GenerateBindings would be called here: projectRoot=%s, chainFamily=%s, language=%s, abiPath=%s, pkgName=%s, outPath=%s\n", inputs.ProjectRoot, inputs.ChainFamily, inputs.Language, inputs.AbiPath, inputs.PkgName, inputs.OutPath) 
- - // Validate language - switch inputs.Language { - case "go": - // Language supported, continue - default: - return fmt.Errorf("unsupported language: %s", inputs.Language) - } - - // Validate chain family and handle accordingly - switch inputs.ChainFamily { - case "evm": - // Create output directory if it doesn't exist - if err := os.MkdirAll(inputs.OutPath, 0o755); err != nil { - return fmt.Errorf("failed to create output directory: %w", err) - } - - // Check if ABI path is a directory or file - info, err := os.Stat(inputs.AbiPath) - if err != nil { - return fmt.Errorf("failed to access ABI path: %w", err) - } - - if info.IsDir() { - if err := h.processAbiDirectory(inputs); err != nil { +func NewSolanaBindings(runtimeContext *runtime.Context) *cobra.Command { + var generateBindingsCmd = &cobra.Command{ + Use: "generate-bindings-solana", + Short: "Generate bindings from contract IDL", + Long: `This command generates bindings from contract IDL files. +Supports Solana chain family and Go language. +Each contract gets its own package subdirectory to avoid naming conflicts. 
+For example, data_storage.json generates bindings in generated/data_storage/ package.`, + Example: " cre generate-bindings-solana", + RunE: func(cmd *cobra.Command, args []string) error { + inputs, err := resolveSolanaInputs(args, runtimeContext.Viper) + if err != nil { return err } - } else { - if err := h.processSingleAbi(inputs); err != nil { + if err := validateSolanaInputs(inputs); err != nil { return err } - } - - err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go@"+creinit.SdkVersion) - if err != nil { - return err - } - err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm@"+creinit.EVMCapabilitiesVersion) - if err != nil { - return err - } - if err = runCommand(inputs.ProjectRoot, "go", "mod", "tidy"); err != nil { - return err - } - return nil - default: - return fmt.Errorf("unsupported chain family: %s", inputs.ChainFamily) + return executeSolana(inputs) + }, } -} -func runCommand(dir string, command string, args ...string) error { - cmd := exec.Command(command, args...) 
- cmd.Dir = dir - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run %s: %w", command, err) - } + generateBindingsCmd.Flags().StringP("project-root", "p", "", "Path to project root directory (defaults to current directory)") + generateBindingsCmd.Flags().StringP("language", "l", "go", "Target language (go)") + generateBindingsCmd.Flags().StringP("idl", "i", "", "Path to IDL directory (defaults to contracts/solana/src/idl/)") + generateBindingsCmd.Flags().StringP("pkg", "k", "bindings", "Base package name (each contract gets its own subdirectory)") - return nil + return generateBindingsCmd } diff --git a/cmd/generate-bindings/generate-bindings_test.go b/cmd/generate-bindings/generate-bindings_test.go index 140df93c..2f1c762b 100644 --- a/cmd/generate-bindings/generate-bindings_test.go +++ b/cmd/generate-bindings/generate-bindings_test.go @@ -6,13 +6,11 @@ import ( "path/filepath" "testing" - "github.com/rs/zerolog" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" - "github.com/smartcontractkit/cre-cli/internal/runtime" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" ) func TestContractNameToPackage(t *testing.T) { @@ -42,7 +40,7 @@ func TestContractNameToPackage(t *testing.T) { } } -func TestResolveInputs_DefaultFallbacks(t *testing.T) { +func TestResolveEvmInputs_DefaultFallbacks(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) @@ -69,15 +67,12 @@ func TestResolveInputs_DefaultFallbacks(t *testing.T) { err = os.Chdir(tempDir) require.NoError(t, err) - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - // Test with minimal input (only chain-family) v := viper.New() v.Set("language", "go") // Default from StringP v.Set("pkg", "bindings") // 
Default from StringP - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := resolveEvmInputs([]string{"evm"}, v) require.NoError(t, err) // Use filepath.EvalSymlinks to handle macOS /var vs /private/var symlink issues @@ -96,22 +91,19 @@ func TestResolveInputs_DefaultFallbacks(t *testing.T) { } // command should run in projectRoot which contains contracts directory -func TestResolveInputs_CustomProjectRoot(t *testing.T) { +func TestResolveEvmInputs_CustomProjectRoot(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - // Test with custom project root v := viper.New() v.Set("project-root", tempDir) v.Set("language", "go") // Default from StringP v.Set("pkg", "bindings") // Default from StringP - _, err = handler.ResolveInputs([]string{"evm"}, v) + _, err = resolveEvmInputs([]string{"evm"}, v) require.Error(t, err) expectedErrMsg := fmt.Sprintf("contracts folder not found in project root: %s", tempDir) @@ -119,7 +111,7 @@ func TestResolveInputs_CustomProjectRoot(t *testing.T) { } // Empty project root should default to current directory, and this should contain contracts and go.mod -func TestResolveInputs_EmptyProjectRoot(t *testing.T) { +func TestResolveEvmInputs_EmptyProjectRoot(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) @@ -146,16 +138,13 @@ func TestResolveInputs_EmptyProjectRoot(t *testing.T) { err = os.Chdir(tempDir) require.NoError(t, err) - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - // Test with empty project root (should use current directory) v := viper.New() v.Set("project-root", "") v.Set("language", "go") // Default from StringP v.Set("pkg", "bindings") // Default from StringP - inputs, err := 
handler.ResolveInputs([]string{"evm"}, v) + inputs, err := resolveEvmInputs([]string{"evm"}, v) require.NoError(t, err) // Use filepath.EvalSymlinks to handle macOS /var vs /private/var symlink issues @@ -173,12 +162,9 @@ func TestResolveInputs_EmptyProjectRoot(t *testing.T) { assert.Equal(t, expectedOut, actualOut) } -func TestValidateInputs_RequiredChainFamily(t *testing.T) { - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - +func TestValidateEvmInputs_RequiredChainFamily(t *testing.T) { // Test validation with missing chain family - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: "/tmp", ChainFamily: "", // Missing required field Language: "go", @@ -187,12 +173,12 @@ func TestValidateInputs_RequiredChainFamily(t *testing.T) { OutPath: "/tmp/out", } - err := handler.ValidateInputs(inputs) + err := validateEvmInputs(inputs) require.Error(t, err) assert.Contains(t, err.Error(), "chain-family") } -func TestValidateInputs_ValidInputs(t *testing.T) { +func TestValidateEvmInputs_ValidEvmInputs(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) @@ -204,11 +190,8 @@ func TestValidateInputs_ValidInputs(t *testing.T) { err = os.WriteFile(abiFile, []byte(abiContent), 0600) require.NoError(t, err) - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - // Test validation with valid inputs (using single file) - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "evm", Language: "go", @@ -217,9 +200,8 @@ func TestValidateInputs_ValidInputs(t *testing.T) { OutPath: tempDir, } - err = handler.ValidateInputs(inputs) + err = validateEvmInputs(inputs) require.NoError(t, err) - assert.True(t, handler.validated) // Test validation with directory containing .abi files abiDir := filepath.Join(tempDir, "abi") @@ -229,22 +211,18 @@ func TestValidateInputs_ValidInputs(t *testing.T) { require.NoError(t, err) inputs.AbiPath = 
abiDir - err = handler.ValidateInputs(inputs) + err = validateEvmInputs(inputs) require.NoError(t, err) - assert.True(t, handler.validated) } -func TestValidateInputs_InvalidChainFamily(t *testing.T) { +func TestValidateEvmInputs_InvalidChainFamily(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - // Test validation with invalid chain family - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "solana", // No longer supported Language: "go", @@ -253,22 +231,19 @@ func TestValidateInputs_InvalidChainFamily(t *testing.T) { OutPath: tempDir, } - err = handler.ValidateInputs(inputs) + err = validateEvmInputs(inputs) require.Error(t, err) assert.Contains(t, err.Error(), "chain-family") } -func TestValidateInputs_InvalidLanguage(t *testing.T) { +func TestValidateEvmInputs_InvalidLanguage(t *testing.T) { // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - // Test validation with invalid language - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "evm", Language: "typescript", // No longer supported @@ -277,17 +252,14 @@ func TestValidateInputs_InvalidLanguage(t *testing.T) { OutPath: tempDir, } - err = handler.ValidateInputs(inputs) + err = validateEvmInputs(inputs) require.Error(t, err) assert.Contains(t, err.Error(), "language") } -func TestValidateInputs_NonExistentDirectory(t *testing.T) { - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - +func TestValidateEvmInputs_NonExistentDirectory(t *testing.T) { // Test validation with non-existent directory - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: "/non/existent/path", 
ChainFamily: "evm", Language: "go", @@ -296,7 +268,7 @@ func TestValidateInputs_NonExistentDirectory(t *testing.T) { OutPath: "/non/existent/out", } - err := handler.ValidateInputs(inputs) + err := validateEvmInputs(inputs) require.Error(t, err) assert.Contains(t, err.Error(), "project-root") } @@ -320,14 +292,7 @@ func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { err = os.WriteFile(filepath.Join(abiDir, "Contract2.abi"), []byte(abiContent), 0600) require.NoError(t, err) - // Create a mock logger to prevent nil pointer dereference - logger := zerolog.New(os.Stderr).With().Timestamp().Logger() - runtimeCtx := &runtime.Context{ - Logger: &logger, - } - handler := newHandler(runtimeCtx) - - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "evm", Language: "go", @@ -338,7 +303,7 @@ func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { // This test will fail because it tries to call the actual bindings.GenerateBindings // but it tests the directory processing logic - err = handler.processAbiDirectory(inputs) + err = processEvmAbiDirectory(inputs) // We expect an error because the bindings package requires actual ABI format // but we can check that it created the expected directory structure if err == nil { @@ -382,14 +347,7 @@ func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { require.NoError(t, err) } - // Create a mock logger - logger := zerolog.New(os.Stderr).With().Timestamp().Logger() - runtimeCtx := &runtime.Context{ - Logger: &logger, - } - handler := newHandler(runtimeCtx) - - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "evm", Language: "go", @@ -399,7 +357,7 @@ func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { } // Try to process - the mock ABI content might actually work - err = handler.processAbiDirectory(inputs) + err = processEvmAbiDirectory(inputs) if err != nil { t.Logf("Expected error occurred: %v", err) } @@ -423,13 +381,7 @@ func 
TestProcessAbiDirectory_NoAbiFiles(t *testing.T) { err = os.MkdirAll(abiDir, 0755) require.NoError(t, err) - logger := zerolog.New(os.Stderr).With().Timestamp().Logger() - runtimeCtx := &runtime.Context{ - Logger: &logger, - } - handler := newHandler(runtimeCtx) - - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "evm", Language: "go", @@ -438,7 +390,7 @@ func TestProcessAbiDirectory_NoAbiFiles(t *testing.T) { OutPath: outDir, } - err = handler.processAbiDirectory(inputs) + err = processEvmAbiDirectory(inputs) require.Error(t, err) assert.Contains(t, err.Error(), "no .abi files found") } @@ -463,13 +415,7 @@ func TestProcessAbiDirectory_PackageNameCollision(t *testing.T) { err = os.WriteFile(filepath.Join(abiDir, "test_contract.abi"), []byte(abiContent), 0600) require.NoError(t, err) - logger := zerolog.New(os.Stderr).With().Timestamp().Logger() - runtimeCtx := &runtime.Context{ - Logger: &logger, - } - handler := newHandler(runtimeCtx) - - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: tempDir, ChainFamily: "evm", Language: "go", @@ -478,20 +424,14 @@ func TestProcessAbiDirectory_PackageNameCollision(t *testing.T) { OutPath: outDir, } - err = handler.processAbiDirectory(inputs) + err = processEvmAbiDirectory(inputs) fmt.Println(err.Error()) require.Error(t, err) require.Equal(t, err.Error(), "package name collision: multiple contracts would generate the same package name 'test_contract' (contracts are converted to snake_case for package names). 
Please rename one of your contract files to avoid this conflict") } func TestProcessAbiDirectory_NonExistentDirectory(t *testing.T) { - logger := zerolog.New(os.Stderr).With().Timestamp().Logger() - runtimeCtx := &runtime.Context{ - Logger: &logger, - } - handler := newHandler(runtimeCtx) - - inputs := Inputs{ + inputs := EvmInputs{ ProjectRoot: "/tmp", ChainFamily: "evm", Language: "go", @@ -500,7 +440,7 @@ func TestProcessAbiDirectory_NonExistentDirectory(t *testing.T) { OutPath: "/tmp/out", } - err := handler.processAbiDirectory(inputs) + err := processEvmAbiDirectory(inputs) require.Error(t, err) // For non-existent directory, filepath.Glob returns empty slice, so we get the "no .abi files found" error assert.Contains(t, err.Error(), "no .abi files found") @@ -616,7 +556,7 @@ func TestGenerateBindings_UnconventionalNaming(t *testing.T) { require.NoError(t, err) outFile := filepath.Join(tempDir, "bindings.go") - err = bindings.GenerateBindings("", abiFile, tc.pkgName, tc.typeName, outFile) + err = evm.GenerateBindings("", abiFile, tc.pkgName, tc.typeName, outFile) if tc.shouldFail { require.Error(t, err, "Expected binding generation to fail for %s", tc.name) diff --git a/cmd/generate-bindings/solana.go b/cmd/generate-bindings/solana.go new file mode 100644 index 00000000..408d2f9f --- /dev/null +++ b/cmd/generate-bindings/solana.go @@ -0,0 +1,191 @@ +package generatebindings + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/spf13/viper" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana" + "github.com/smartcontractkit/cre-cli/internal/validation" +) + +func resolveSolanaInputs(args []string, v *viper.Viper) (SolanaInputs, error) { + // Get current working directory as default project root + currentDir, err := os.Getwd() + if err != nil { + return SolanaInputs{}, fmt.Errorf("failed to get current working directory: %w", err) + } + + // Resolve project root with fallback to current directory + projectRoot := 
v.GetString("project-root") + if projectRoot == "" { + projectRoot = currentDir + } + + contractsPath := filepath.Join(projectRoot, "contracts") + if _, err := os.Stat(contractsPath); err != nil { + return SolanaInputs{}, fmt.Errorf("contracts folder not found in project root: %s", contractsPath) + } + + // Language defaults are handled by StringP + language := v.GetString("language") + + // Resolve ABI path with fallback to contracts/{chainFamily}/src/abi/ + idlPath := v.GetString("idl") + if idlPath == "" { + idlPath = filepath.Join(projectRoot, "contracts", "solana", "src", "idl") + } + + // Output path is contracts/{chainFamily}/src/generated/ under projectRoot + outPath := filepath.Join(projectRoot, "contracts", "solana", "src", "generated") + + return SolanaInputs{ + ProjectRoot: projectRoot, + Language: language, + IdlPath: idlPath, + OutPath: outPath, + }, nil +} + +func validateSolanaInputs(inputs SolanaInputs) error { + validate, err := validation.NewValidator() + if err != nil { + return fmt.Errorf("failed to initialize validator: %w", err) + } + + if err = validate.Struct(inputs); err != nil { + return validate.ParseValidationErrors(err) + } + + // Additional validation for Idl path + if _, err := os.Stat(inputs.IdlPath); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("IDL path does not exist: %s", inputs.IdlPath) + } + return fmt.Errorf("failed to access IDL path: %w", err) + } + + // Validate that if IdlPath is a directory, it contains .json files + if info, err := os.Stat(inputs.IdlPath); err == nil && info.IsDir() { + files, err := filepath.Glob(filepath.Join(inputs.IdlPath, "*.json")) + if err != nil { + return fmt.Errorf("failed to check for ABI files in directory: %w", err) + } + if len(files) == 0 { + return fmt.Errorf("no .json files found in directory: %s", inputs.IdlPath) + } + } + + return nil +} + +func processSolanaIdlDirectory(inputs SolanaInputs) error { + // Read all .json files in the directory + files, err := 
filepath.Glob(filepath.Join(inputs.IdlPath, "*.json"))
	if err != nil {
		return fmt.Errorf("failed to find IDL files: %w", err)
	}

	if len(files) == 0 {
		return fmt.Errorf("no .json files found in directory: %s", inputs.IdlPath)
	}

	// Process each IDL file
	for _, idlFile := range files {
		// Extract contract name from filename by stripping the ".json" extension.
		contractName := filepath.Base(idlFile)
		contractName = contractName[:len(contractName)-len(".json")]

		// Create per-contract output directory
		contractOutDir := filepath.Join(inputs.OutPath, contractName)
		if err := os.MkdirAll(contractOutDir, 0755); err != nil {
			return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err)
		}

		// Create output file path in contract-specific directory
		outputFile := filepath.Join(contractOutDir, contractName+".go")

		fmt.Printf("Processing IDL file: %s, contract: %s, output: %s\n", idlFile, contractName, outputFile)

		err = solana.GenerateBindings(
			idlFile,
			contractName,
			contractOutDir,
		)
		if err != nil {
			return fmt.Errorf("failed to generate bindings for %s: %w", idlFile, err)
		}
	}

	return nil
}

// processSolanaSingleIdl generates bindings for a single IDL file, writing the
// output into a per-contract sub-directory under inputs.OutPath.
func processSolanaSingleIdl(inputs SolanaInputs) error {
	// Extract contract name from IDL file path
	contractName := filepath.Base(inputs.IdlPath)
	if filepath.Ext(contractName) == ".json" {
		// BUG FIX: strip the full ".json" suffix (5 bytes). The previous code
		// removed only 4 bytes, leaving a trailing "." in the contract name
		// (e.g. "my_idl." instead of "my_idl"), which leaked into the output
		// directory and generated file name.
		contractName = contractName[:len(contractName)-len(".json")]
	}

	// Create per-contract output directory
	contractOutDir := filepath.Join(inputs.OutPath, contractName)
	if err := os.MkdirAll(contractOutDir, 0755); err != nil {
		return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err)
	}

	fmt.Printf("Processing single IDL file: %s, contract: %s, output: %s\n", inputs.IdlPath, contractName, contractOutDir)

	return solana.GenerateBindings(
		inputs.IdlPath,
		contractName,
		contractOutDir,
	)
}

// executeSolana validates the requested language, ensures the output directory
// exists, then dispatches to directory- or single-file IDL processing.
func executeSolana(inputs SolanaInputs)
error { + // Validate language + switch inputs.Language { + case "go": + // Language supported, continue + default: + return fmt.Errorf("unsupported language: %s", inputs.Language) + } + + // Create output directory if it doesn't exist + if err := os.MkdirAll(inputs.OutPath, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + // Check if IDL path is a directory or file + info, err := os.Stat(inputs.IdlPath) + if err != nil { + return fmt.Errorf("failed to access IDL path: %w", err) + } + + if info.IsDir() { + if err := processSolanaIdlDirectory(inputs); err != nil { + return err + } + } else { + if err := processSolanaSingleIdl(inputs); err != nil { + return err + } + } + + // TODO: Add Solana-specific SDK dependencies when available + // err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go@"+creinit.SdkVersion) + // if err != nil { + // return err + // } + // err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana@"+creinit.SolanaCapabilitiesVersion) + // if err != nil { + // return err + // } + // if err = runCommand(inputs.ProjectRoot, "go", "mod", "tidy"); err != nil { + // return err + // } + + return nil +} diff --git a/cmd/generate-bindings/bindings/README.md b/cmd/generate-bindings/solana/README.md similarity index 100% rename from cmd/generate-bindings/bindings/README.md rename to cmd/generate-bindings/solana/README.md diff --git a/cmd/generate-bindings/solana/anchor-go/anchor_test.go b/cmd/generate-bindings/solana/anchor-go/anchor_test.go new file mode 100644 index 00000000..0f7fe914 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/anchor_test.go @@ -0,0 +1,180 @@ +package main_test + +import ( + "fmt" + "go/token" + "log/slog" + "os" + "os/exec" + "path" + "testing" + + "github.com/gagliardetto/anchor-go/generator" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" + 
bin "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" +) + +const defaultProgramName = "myprogram" + +func TestAnchorGo(t *testing.T) { + + var outputDir = "/Users/yashvardhan/cre-cli/cmd/generate-bindings/solana_bindings/testdata/my_anchor_project" + var programName = "my_project" + var modPath = "" + var pathToIdl = "/Users/yashvardhan/cre-client-program/my-project/target/idl/my_project.json" + var programIDOverride solana.PublicKey + programIDOverride = solana.MustPublicKeyFromBase58("2GvhVcTPPkHbGduj6efNowFoWBQjE77Xab1uBKCYJvNN") + + modPath = path.Join("github.com", "gagliardetto", "anchor-go", "generated") + slog.Info("Using default module path", "modPath", modPath) + if err := os.MkdirAll(outputDir, 0o777); err != nil { + panic(fmt.Errorf("Failed to create output directory: %w", err)) + } + slog.Info("Starting code generation", + "outputDir", outputDir, + "modPath", modPath, + "pathToIdl", pathToIdl, + "programID", func() string { + if programIDOverride.IsZero() { + return "not provided" + } + return programIDOverride.String() + }(), + ) + + options := generator.GeneratorOptions{ + OutputDir: outputDir, + Package: programName, + ProgramName: programName, + ModPath: modPath, + SkipGoMod: true, + } + if !programIDOverride.IsZero() { + options.ProgramId = &programIDOverride + slog.Info("Using provided program ID", "programID", programIDOverride.String()) + } + parsedIdl, err := idl.ParseFromFilepath(pathToIdl) + if err != nil { + panic(err) + } + if parsedIdl == nil { + panic("Parsed IDL is nil, please check the IDL file path and format.") + } + if err := parsedIdl.Validate(); err != nil { + panic(fmt.Errorf("Invalid IDL: %w", err)) + } + { + { + if parsedIdl.Address != nil && !parsedIdl.Address.IsZero() && options.ProgramId == nil { + // If the IDL has an address, use it as the program ID: + slog.Info("Using IDL address as program ID", "address", parsedIdl.Address.String()) + options.ProgramId = parsedIdl.Address + } + } + 
parsedIdl.Metadata.Name = bin.ToSnakeForSighash(parsedIdl.Metadata.Name)
		{
			// check that the name is not a reserved keyword:
			if parsedIdl.Metadata.Name != "" {
				if tools.IsReservedKeyword(parsedIdl.Metadata.Name) {
					slog.Warn("The IDL metadata.name is a reserved Go keyword: adding a suffix to avoid conflicts.",
						"name", parsedIdl.Metadata.Name,
						"reservedKeyword", token.Lookup(parsedIdl.Metadata.Name).String(),
					)
					// Add a suffix to the name to avoid conflicts with Go reserved keywords:
					parsedIdl.Metadata.Name += "_program"
				}
				if !tools.IsValidIdent(parsedIdl.Metadata.Name) {
					// add a prefix to the name to avoid conflicts with Go reserved keywords:
					parsedIdl.Metadata.Name = "my_" + parsedIdl.Metadata.Name
				}
			}
			// if begins with
		}
		// BUG FIX: panic only when NO package name is available at all — i.e.
		// neither the -name flag (programName) nor the IDL metadata.name supplies
		// one. The previous condition (Metadata.Name != "") panicked precisely
		// when a usable fallback existed, contradicting the panic message below.
		if programName == "" && parsedIdl.Metadata.Name == "" {
			panic("Please provide a package name using the -name flag, or ensure the IDL has a valid metadata.name field.")
		}
		if programName == defaultProgramName && parsedIdl.Metadata.Name != "" {
			cleanedName := bin.ToSnakeForSighash(parsedIdl.Metadata.Name)
			options.Package = cleanedName
			options.ProgramName = cleanedName
			slog.Info("Using IDL metadata.name as package name", "packageName", cleanedName)
		}

		slog.Info("Parsed IDL successfully",
			"version", parsedIdl.Metadata.Version,
			"name", parsedIdl.Metadata.Name,
			"address", parsedIdl.Address,
			"programId", func() string {
				if parsedIdl.Address.IsZero() {
					return "not provided"
				}
				return parsedIdl.Address.String()
			}(),
			"instructionsCount", len(parsedIdl.Instructions),
			"accountsCount", len(parsedIdl.Accounts),
			"eventsCount", len(parsedIdl.Events),
			"typesCount", len(parsedIdl.Types),
			"constantsCount", len(parsedIdl.Constants),
			"errorsCount", len(parsedIdl.Errors),
		)
	}
	gen := generator.NewGenerator(parsedIdl, &options)
	generatedFiles, err := gen.Generate()
	if err != nil {
		panic(err)
	}

	{
		for _, file := range generatedFiles.Files {
			{
				// Save assets:
assetFilename := file.Name + assetFilepath := path.Join(options.OutputDir, assetFilename) + + // Create file: + goFile, err := os.Create(assetFilepath) + if err != nil { + panic(err) + } + defer goFile.Close() + + slog.Info("Writing file", + "filepath", assetFilepath, + "name", file.Name, + "modPath", options.ModPath, + ) + err = file.File.Render(goFile) + if err != nil { + panic(err) + } + } + } + // executeCmd(outputDir, "go", "mod", "tidy") + // executeCmd(outputDir, "go", "fmt") + // executeCmd(outputDir, "go", "build", "-o", "/dev/null") // Just to ensure everything compiles. + slog.Info("Generation completed successfully", + "outputDir", options.OutputDir, + "modPath", options.ModPath, + "package", options.Package, + "programName", options.ProgramName, + ) + } +} + +func executeCmd(dir string, name string, arg ...string) { + cmd := exec.Command(name, arg...) + cmd.Dir = dir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + if err != nil { + panic(err) + } +} + +func hasCommand(name string) bool { + _, err := exec.LookPath(name) + return err == nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/accounts.go b/cmd/generate-bindings/solana/anchor-go/generator/accounts.go new file mode 100644 index 00000000..542d4020 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/accounts.go @@ -0,0 +1,407 @@ +package generator + +import ( + "fmt" + "strconv" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/anchor-go/tools" +) + +func (g *Generator) genfile_accounts() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains parsers for the accounts defined in the IDL.") + file.HeaderComment("Code generated by https://github.com/smartcontractkit/cre-cli. DO NOT EDIT.") + + names := []string{} + { + for _, acc := range g.idl.Accounts { + names = append(names, tools.ToCamelUpper(acc.Name)) + } + } + { + code, err := g.gen_accountParser(names) + if err != nil { + return nil, fmt.Errorf("error generating account parser: %w", err) + } + file.Add(code) + } + + return &OutputFile{ + Name: "accounts.go", + File: file, + }, nil +} + +func (g *Generator) gen_accountParser(accountNames []string) (Code, error) { + code := Empty() + { + code.Func().Id("ParseAnyAccount"). + Params(Id("accountData").Index().Byte()). + Params(Any(), Error()). + BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("accountData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek account discriminator: %w"), Err()), + ), + ) + + block.Switch(Id("discriminator")).BlockFunc(func(switchBlock *Group) { + for _, name := range accountNames { + switchBlock.Case(Id(FormatAccountDiscriminatorName(name))).Block( + Id("value").Op(":=").New(Id(name)), + Err().Op(":=").Id("value").Dot("UnmarshalWithDecoder").Call(Id("decoder")), + If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal account as "+name+": %w"), Err()), + ), + ), + Return(Id("value"), Nil()), + ) + } + switchBlock.Default().Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("unknown discriminator: %s"), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + }) + }) + } + { + code.Line().Line() + // for each account, generate a function to parse it: + for _, name := range accountNames { + discriminatorName := 
FormatAccountDiscriminatorName(name) + + code.Func().Id("ParseAccount_"+name). + Params(Id("accountData").Index().Byte()). + Params(Op("*").Id(name), Error()). + BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("accountData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek discriminator: %w"), Err()), + ), + ) + + block.If(Id("discriminator").Op("!=").Id(discriminatorName)).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("expected discriminator %v, got %s"), Id(discriminatorName), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + + block.Id("acc").Op(":=").New(Id(name)) + block.Err().Op("=").Id("acc").Dot("UnmarshalWithDecoder").Call(Id("decoder")) + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal account of type "+name+": %w"), Err()), + ), + ) + + block.Return(Id("acc"), Nil()) + }) + code.Line().Line() + + // DecodeAccount method for the codec + code.Add(creDecodeAccountFn(name)) + code.Line().Line() + + // Read account data onchain using read capability + // code.Add(creReadAccountFn(name, g)) + code.Line().Line() + } + } + return code, nil +} + +func (g *Generator) gen_IDLTypeDefTyStruct( + name string, + docs []string, + typ idl.IdlTypeDefTyStruct, + withDiscriminator bool, +) (Code, error) { + st := newStatement() + + exportedAccountName := tools.ToCamelUpper(name) + { + // Declare the struct: + code := Empty() + addComments(code, docs) + code.Type().Id(exportedAccountName).StructFunc(func(fieldsGroup *Group) { + switch fields := typ.Fields.(type) { + case idl.IdlDefinedFieldsNamed: + // Generate unique field names to handle duplicates + uniqueFieldNames := generateUniqueFieldNames(fields) + + for fieldIndex, field := range fields { + + // Add docs for 
the field: + for docIndex, doc := range field.Docs { + if docIndex == 0 && fieldIndex > 0 { + fieldsGroup.Line() + } + fieldsGroup.Comment(doc) + } + // fieldsGroup.Line() + optionality := IsOption(field.Ty) || IsCOption(field.Ty) + + // TODO: optionality for complex enums is a nil interface. + uniqueFieldName := uniqueFieldNames[field.Name] + fieldsGroup.Add(genFieldWithName(field, uniqueFieldName, optionality)). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(field.Ty) { + tagMap["bin"] = "optional" + } + if IsCOption(field.Ty) { + tagMap["bin"] = "coption" + } + // add json tag: use original field name to avoid duplicates + tagMap["json"] = field.Name + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + case idl.IdlDefinedFieldsTuple: + // panic(fmt.Errorf("tuple fields not supported: %s", spew.Sdump(fields))) + for fieldIndex, field := range fields { + + fieldsGroup.Line() + optionality := IsOption(field) || IsCOption(field) + + fieldsGroup.Add(genFieldNamed( + FormatTupleItemName(fieldIndex), + field, + optionality, + )). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(field) { + tagMap["bin"] = "optional" + } + if IsCOption(field) { + tagMap["bin"] = "coption" + } + // add json tag: + tagMap["json"] = tools.ToCamelLower(FormatTupleItemName(fieldIndex)) + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + + case nil: + // No fields, just an empty struct. + // TODO: should we panic here? 
+ default: + panic(fmt.Errorf("unknown fields type: %T", typ.Fields)) + } + }) + st.Add(code.Line()) + } + { + // Declare the decoder/encoder methods: + code := Empty() + + { + discriminatorName := FormatAccountDiscriminatorName(exportedAccountName) + + // Declare MarshalWithEncoder: + // TODO: + code.Line().Line().Add( + gen_MarshalWithEncoder_struct( + g.idl, + withDiscriminator, + exportedAccountName, + discriminatorName, + typ.Fields, + true, + )) + + // Declare UnmarshalWithDecoder + code.Line().Line().Add( + gen_UnmarshalWithDecoder_struct( + g.idl, + withDiscriminator, + exportedAccountName, + discriminatorName, + typ.Fields, + )) + } + st.Add(code.Line().Line()) + } + { + code := Empty() + code.Add(creGenerateCodecEncoderForTypes(exportedAccountName)) + st.Add(code.Line().Line()) + } + { + // Declare the WriteReportFrom methods: + // TODO: should i exclude events here ? currently it does accounts/structs/events + code := Empty() + code.Add(creWriteReportFromStructs(exportedAccountName, g)) + st.Add(code.Line().Line()) + } + return st, nil +} + +// generateUniqueFieldNames creates unique Go field names from IDL field names, +// handling cases where multiple IDL fields would map to the same Go field name +func generateUniqueFieldNames(fields []idl.IdlField) map[string]string { + fieldNameMap := make(map[string]string) + usedNames := make(map[string]int) + + for _, field := range fields { + baseName := tools.ToCamelUpper(field.Name) + finalName := baseName + + // Check if this name has been used before + if count, exists := usedNames[baseName]; exists { + // Add a numeric suffix to make it unique + finalName = baseName + fmt.Sprintf("%d", count+1) + usedNames[baseName] = count + 1 + } else { + usedNames[baseName] = 0 + } + + fieldNameMap[field.Name] = finalName + } + + return fieldNameMap +} + +func genField(field idl.IdlField, pointer bool) Code { + return genFieldNamed(field.Name, field.Ty, pointer) +} + +// genFieldWithName generates a field with a custom 
field name (for handling duplicates) +func genFieldWithName(field idl.IdlField, fieldName string, pointer bool) Code { + return genFieldNamed(fieldName, field.Ty, pointer) +} + +func genFieldNamed(name string, typ idltype.IdlType, pointer bool) Code { + st := newStatement() + st.Id(tools.ToCamelUpper(name)). + Add(func() Code { + if isComplexEnum(typ) { + return nil + } + if pointer { + return Op("*") + } + return nil + }()). + Add(genTypeName(typ)) + return st +} + +func genTypeName(idlTypeEnv idltype.IdlType) Code { + st := newStatement() + switch { + case IsIDLTypeKind(idlTypeEnv): + { + str := idlTypeEnv + st.Add(IDLTypeKind_ToTypeDeclCode(str)) + } + case IsOption(idlTypeEnv): + { + opt := idlTypeEnv.(*idltype.Option) + // TODO: optional = pointer? or that's determined upstream? + st.Add(genTypeName(opt.Option)) + } + case IsCOption(idlTypeEnv): + { + copt := idlTypeEnv.(*idltype.COption) + st.Add(genTypeName(copt.COption)) + } + case IsVec(idlTypeEnv): + { + vec := idlTypeEnv.(*idltype.Vec) + st.Index().Add(genTypeName(vec.Vec)) + } + case IsDefined(idlTypeEnv): + { + def := idlTypeEnv.(*idltype.Defined) + st.Add(Id(tools.ToCamelUpper(def.Name))) + } + case IsArray(idlTypeEnv): + { + arr := idlTypeEnv.(*idltype.Array) + { + switch size := arr.Size.(type) { + case *idltype.IdlArrayLenGeneric: + panic(fmt.Sprintf("generic array length not supported: %s", spew.Sdump(size))) + case *idltype.IdlArrayLenValue: + if size.Value < 0 { + panic(fmt.Sprintf("expected positive integer, got %d", size.Value)) + } + st.Index(Id(strconv.Itoa(int(size.Value)))).Add(genTypeName(arr.Type)) + } + } + } + default: + panic("unhandled type: " + spew.Sdump(idlTypeEnv)) + } + return st +} + +func IDLTypeKind_ToTypeDeclCode(ts idltype.IdlType) *Statement { + stat := newStatement() + switch ts.(type) { + case *idltype.Bool: + stat.Bool() + case *idltype.U8: + stat.Uint8() + case *idltype.I8: + stat.Int8() + case *idltype.U16: + // TODO: some types have their implementation in 
github.com/gagliardetto/binary + stat.Uint16() + case *idltype.I16: + stat.Int16() + case *idltype.U32: + stat.Uint32() + case *idltype.I32: + stat.Int32() + case *idltype.F32: + stat.Float32() + case *idltype.U64: + stat.Uint64() + case *idltype.I64: + stat.Int64() + case *idltype.F64: + stat.Float64() + case *idltype.U128: + stat.Qual(PkgBinary, "Uint128") + case *idltype.I128: + stat.Qual(PkgBinary, "Int128") + case *idltype.Bytes: + stat.Index().Byte() + case *idltype.String: + stat.String() + case *idltype.Pubkey: + stat.Qual(PkgSolanaGo, "PublicKey") + + default: + panic(fmt.Sprintf("unhandled type: %s", spew.Sdump(ts))) + } + + return stat +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go new file mode 100644 index 00000000..224dafa0 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go @@ -0,0 +1,37 @@ +package generator + +import ( + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" +) + +// typeRegistryComplexEnum contains all types that are a complex enum (and thus implemented as an interface). 
+var typeRegistryComplexEnum = make(map[string]struct{}) + +func isComplexEnum(envel idltype.IdlType) bool { + switch vv := envel.(type) { + case *idltype.Defined: + _, ok := typeRegistryComplexEnum[vv.Name] + return ok + } + return false +} + +func register_TypeName_as_ComplexEnum(name string) { + typeRegistryComplexEnum[name] = struct{}{} +} + +func registerComplexEnums(def idl.IdlTypeDef) { + switch vv := def.Ty.(type) { + case *idl.IdlTypeDefTyEnum: + enumTypeName := def.Name + if !vv.IsAllSimple() { + register_TypeName_as_ComplexEnum(enumTypeName) + } + case idl.IdlTypeDefTyEnum: + enumTypeName := def.Name + if !vv.IsAllSimple() { + register_TypeName_as_ComplexEnum(enumTypeName) + } + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/constants.go b/cmd/generate-bindings/solana/anchor-go/generator/constants.go new file mode 100644 index 00000000..e544cb21 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/constants.go @@ -0,0 +1,351 @@ +package generator + +import ( + "encoding/json" + "fmt" + "math/big" + "strconv" + "strings" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/solana-go" +) + +func (g *Generator) gen_constants() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains constants.") + { + if len(g.idl.Constants) > 0 { + file.Comment("Constants defined in the IDL:") + file.Line() + } + code := Empty() + for coi, co := range g.idl.Constants { + if co.Name == "" { + continue // Skip constants without a name. + } + if len(co.Value) == 0 { + continue // Skip constants without a value. 
+ } + + addComments(code, co.Docs) + + switch ty := co.Ty.(type) { + case *idltype.String: + _ = ty + // "value":"\"organism\"" + v, err := strconv.Unquote(co.Value) + if err != nil { + return nil, fmt.Errorf("failed to unquote string constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(v) + code.Line() + case *idltype.Bytes: + _ = ty + // "value":"[102, 101, 101, 95, 118, 97, 117, 108, 116]" + var b []byte + err := json.Unmarshal([]byte(co.Value), &b) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal bytes constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Var().Id(co.Name).Op("=").Index().Byte().Op("{").ListFunc(func(byteGroup *Group) { + for _, byteVal := range b[:] { + byteGroup.Lit(int(byteVal)) + } + }).Op("}") + code.Line() + case *idltype.Pubkey: + _ = ty + // "value":"MiNTdCbWwAu3boEeTL6HzS5VgLb89mhf8VhMLtMrmWL" + pk, err := solana.PublicKeyFromBase58(co.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse pubkey constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Var().Id(co.Name).Op("=").Qual(PkgSolanaGo, "MustPublicKeyFromBase58").Call(Lit(pk.String())) + code.Line() + case *idltype.Bool: + _ = ty + // "value":"true" + v, err := strconv.ParseBool(co.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse bool constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Var().Id(co.Name).Op("=").Lit(v) + code.Line() + case *idltype.U8: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse u8 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint8(v)) + code.Line() + case *idltype.I8: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse i8 constants[%d] 
%s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int8(v)) + code.Line() + case *idltype.U16: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse u16 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint16(v)) + code.Line() + case *idltype.I16: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse i16 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int16(v)) + code.Line() + case *idltype.U32: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse u32 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint32(v)) + code.Line() + case *idltype.I32: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse i32 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int32(v)) + code.Line() + case *idltype.U64: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse u64 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint64(v)) + code.Line() + case *idltype.I64: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse i64 constants[%d] %s: 
%w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int64(v)) + code.Line() + case *idltype.U128: + _ = ty + // "value":"100_000_000" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + bigInt := new(big.Int) + _, ok := bigInt.SetString(cleanValue, 10) + if !ok { + return nil, fmt.Errorf("failed to parse u128 constants[%d] %s: invalid format", coi, spew.Sdump(co)) + } + // Generate code that creates a big.Int from string + code.Var().Id(co.Name).Op("=").Func().Params().Op("*").Qual("math/big", "Int").Block( + Id("val").Op(",").Id("ok").Op(":=").New(Qual("math/big", "Int")).Dot("SetString").Call(Lit(cleanValue), Lit(10)), + If(Op("!").Id("ok")).Block( + Panic(Lit(fmt.Sprintf("invalid u128 constant %s", co.Name))), + ), + Return(Id("val")), + ).Call() + code.Line() + case *idltype.I128: + _ = ty + // "value":"-100_000_000" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + bigInt := new(big.Int) + _, ok := bigInt.SetString(cleanValue, 10) + if !ok { + return nil, fmt.Errorf("failed to parse i128 constants[%d] %s: invalid format", coi, spew.Sdump(co)) + } + // Generate code that creates a big.Int from string + code.Var().Id(co.Name).Op("=").Func().Params().Op("*").Qual("math/big", "Int").Block( + Id("val").Op(",").Id("ok").Op(":=").New(Qual("math/big", "Int")).Dot("SetString").Call(Lit(cleanValue), Lit(10)), + If(Op("!").Id("ok")).Block( + Panic(Lit(fmt.Sprintf("invalid i128 constant %s", co.Name))), + ), + Return(Id("val")), + ).Call() + code.Line() + case *idltype.F32: + _ = ty + // "value":"3.14" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseFloat(cleanValue, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse f32 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(float32(v)) + code.Line() + case *idltype.F64: + _ = ty + // "value":"3.14" + // "value":"4e-6" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := 
strconv.ParseFloat(cleanValue, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse f64 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(v) + code.Line() + case *idltype.Array: + _ = ty + // "type":{"array":["u8",23]},"value":"[115, 101, 110, 100, 95, 119, 105, 116, 104, 95, 115, 119, 97, 112, 95, 100, 101, 108, 101, 103, 97, 116, 101]" + var b []any + err := json.Unmarshal([]byte(co.Value), &b) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal array constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + size, ok := ty.Size.(*idltype.IdlArrayLenValue) + if !ok { + return nil, fmt.Errorf("expected IdlArrayLenValue for constants[%d] %s, got %T", coi, spew.Sdump(co), ty.Size) + } + if len(b) != size.Value { + return nil, fmt.Errorf("expected %d elements in array constants[%d] %s, got %d", ty.Size, coi, spew.Sdump(co), len(b)) + } + code.Var().Id(co.Name).Op("=").Index(Lit(size.Value)).Do(func(index *Statement) { + switch ty.Type.(type) { + case *idltype.U8: + index.Byte() + case *idltype.I8: + index.Int8() + case *idltype.U16: + index.Uint16() + case *idltype.I16: + index.Int16() + case *idltype.U32: + index.Uint32() + case *idltype.I32: + index.Int32() + case *idltype.U64: + index.Uint64() + case *idltype.I64: + index.Int64() + case *idltype.F32: + index.Float32() + case *idltype.F64: + index.Float64() + case *idltype.String: + index.String() + case *idltype.Bool: + index.Bool() + default: + panic(fmt.Errorf("unsupported array type for constants[%d] %s: %T", coi, spew.Sdump(co), ty.Type)) + } + }).Op("{").ListFunc(func(byteGroup *Group) { + for _, val := range b[:] { + switch ty.Type.(type) { + case *idltype.U8: + byteGroup.Lit(byte(val.(float64))) + case *idltype.I8: + byteGroup.Lit(int8(val.(float64))) + case *idltype.U16: + byteGroup.Lit(uint16(val.(float64))) + case *idltype.I16: + byteGroup.Lit(int16(val.(float64))) + case *idltype.U32: + byteGroup.Lit(uint32(val.(float64))) + case *idltype.I32: 
+ byteGroup.Lit(int32(val.(float64))) + case *idltype.U64: + byteGroup.Lit(uint64(val.(float64))) + case *idltype.I64: + byteGroup.Lit(int64(val.(float64))) + case *idltype.F32: + // TODO: is this correct? Are they encoded as strings? + v, err := strconv.ParseFloat(val.(string), 32) + if err != nil { + panic(fmt.Errorf("failed to parse f32 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(float32(v)) + case *idltype.F64: + // TODO: is this correct? Are they encoded as strings? + v, err := strconv.ParseFloat(val.(string), 64) + if err != nil { + panic(fmt.Errorf("failed to parse f64 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(v) + case *idltype.String: + v, err := strconv.Unquote(val.(string)) + if err != nil { + panic(fmt.Errorf("failed to unquote string in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(v) + case *idltype.Bool: + v, err := strconv.ParseBool(val.(string)) + if err != nil { + panic(fmt.Errorf("failed to parse bool in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(v) + default: + panic(fmt.Errorf("unsupported array type for constants[%d] %s: %T", coi, spew.Sdump(co), ty.Type)) + } + } + }).Op("}") + code.Line() + + case *idltype.Defined: + _ = ty + // Handle user-defined types like usize, isize, etc. 
+ switch ty.Name { + case "usize": + // usize is typically a pointer-sized unsigned integer + // In most cases, this is equivalent to u64 on 64-bit systems + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse usize constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint64(v)) + code.Line() + case "isize": + // isize is typically a pointer-sized signed integer + // In most cases, this is equivalent to i64 on 64-bit systems + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse isize constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int64(v)) + code.Line() + default: + // For other defined types, we could try to resolve them, + // but for now, we'll return an error with more specific information + return nil, fmt.Errorf("unsupported defined type '%s' for constants[%d] %s: %T", ty.Name, coi, spew.Sdump(co), ty) + } + + default: + return nil, fmt.Errorf("unsupported constant type for constants[%d] %s: %T", coi, spew.Sdump(co), ty) + } + } + file.Add(code) + } + return &OutputFile{ + Name: "constants.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/constants_test.go b/cmd/generate-bindings/solana/anchor-go/generator/constants_test.go new file mode 100644 index 00000000..080fb8e6 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/constants_test.go @@ -0,0 +1,808 @@ +package generator + +import ( + "fmt" + "strings" + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenConstants(t *testing.T) { + tests := []struct { + name string + constants []idl.IdlConst + 
expectError bool + expectCode []string // 期望在生成的代码中找到的字符串 + }{ + { + name: "String constant", + constants: []idl.IdlConst{ + { + Name: "TEST_STRING", + Ty: &idltype.String{}, + Value: `"hello world"`, + }, + }, + expectCode: []string{ + "const TEST_STRING = \"hello world\"", + }, + }, + { + name: "Boolean constants", + constants: []idl.IdlConst{ + { + Name: "IS_ENABLED", + Ty: &idltype.Bool{}, + Value: "true", + }, + { + Name: "IS_DISABLED", + Ty: &idltype.Bool{}, + Value: "false", + }, + }, + expectCode: []string{ + "var IS_ENABLED = true", + "var IS_DISABLED = false", + }, + }, + { + name: "Unsigned integer constants", + constants: []idl.IdlConst{ + { + Name: "MAX_U8", + Ty: &idltype.U8{}, + Value: "255", + }, + { + Name: "MAX_U16", + Ty: &idltype.U16{}, + Value: "65535", + }, + { + Name: "MAX_U32", + Ty: &idltype.U32{}, + Value: "4294967295", + }, + { + Name: "MAX_U64", + Ty: &idltype.U64{}, + Value: "18446744073709551615", + }, + }, + expectCode: []string{ + "const MAX_U8 = uint8(0xff)", + "const MAX_U16 = uint16(0xffff)", + "const MAX_U32 = uint32(0xffffffff)", + "const MAX_U64 = uint64(0xffffffffffffffff)", + }, + }, + { + name: "Signed integer constants", + constants: []idl.IdlConst{ + { + Name: "MIN_I8", + Ty: &idltype.I8{}, + Value: "-128", + }, + { + Name: "MIN_I16", + Ty: &idltype.I16{}, + Value: "-32768", + }, + { + Name: "MIN_I32", + Ty: &idltype.I32{}, + Value: "-2147483648", + }, + { + Name: "MIN_I64", + Ty: &idltype.I64{}, + Value: "-9223372036854775808", + }, + }, + expectCode: []string{ + "const MIN_I8 = int8(-128)", + "const MIN_I16 = int16(-32768)", + "const MIN_I32 = int32(-2147483648)", + "const MIN_I64 = int64(-9223372036854775808)", + }, + }, + { + name: "Float constants", + constants: []idl.IdlConst{ + { + Name: "PI_F32", + Ty: &idltype.F32{}, + Value: "3.14159", + }, + { + Name: "E_F64", + Ty: &idltype.F64{}, + Value: "2.718281828459045", + }, + }, + expectCode: []string{ + "const PI_F32 = float32(3.14159)", + "const E_F64 = 
2.718281828459045", + }, + }, + { + name: "Numbers with underscores", + constants: []idl.IdlConst{ + { + Name: "LARGE_NUMBER", + Ty: &idltype.U64{}, + Value: "100_000_000", + }, + { + Name: "NEGATIVE_NUMBER", + Ty: &idltype.I32{}, + Value: "-1_000_000", + }, + }, + expectCode: []string{ + "const LARGE_NUMBER = uint64(0x5f5e100)", + "const NEGATIVE_NUMBER = int32(-1000000)", + }, + }, + { + name: "usize constant", + constants: []idl.IdlConst{ + { + Name: "MAX_BIN_PER_ARRAY", + Ty: &idltype.Defined{ + Name: "usize", + }, + Value: "70", + }, + }, + expectCode: []string{ + "const MAX_BIN_PER_ARRAY = uint64(0x46)", + }, + }, + { + name: "isize constant", + constants: []idl.IdlConst{ + { + Name: "MIN_BIN_ID", + Ty: &idltype.Defined{ + Name: "isize", + }, + Value: "-443636", + }, + }, + expectCode: []string{ + "const MIN_BIN_ID = int64(-443636)", + }, + }, + { + name: "u128 constant", + constants: []idl.IdlConst{ + { + Name: "MAX_BASE_FEE", + Ty: &idltype.U128{}, + Value: "100_000_000", + }, + }, + expectCode: []string{ + "var MAX_BASE_FEE = func() *big.Int", + ".SetString(\"100000000\", 10)", + }, + }, + { + name: "i128 constant", + constants: []idl.IdlConst{ + { + Name: "MIN_BALANCE", + Ty: &idltype.I128{}, + Value: "-1_000_000_000_000", + }, + }, + expectCode: []string{ + "var MIN_BALANCE = func() *big.Int", + ".SetString(\"-1000000000000\", 10)", + }, + }, + { + name: "Bytes constant", + constants: []idl.IdlConst{ + { + Name: "SEED_BYTES", + Ty: &idltype.Bytes{}, + Value: "[102, 101, 101, 95, 118, 97, 117, 108, 116]", + }, + }, + expectCode: []string{ + "var SEED_BYTES = []byte{102, 101, 101, 95, 118, 97, 117, 108, 116}", + }, + }, + { + name: "Pubkey constant", + constants: []idl.IdlConst{ + { + Name: "PROGRAM_ID", + Ty: &idltype.Pubkey{}, + Value: "11111111111111111111111111111112", // System Program ID + }, + }, + expectCode: []string{ + "var PROGRAM_ID = solanago.MustPublicKeyFromBase58(\"11111111111111111111111111111112\")", + }, + }, + { + name: "Empty name - 
should be skipped", + constants: []idl.IdlConst{ + { + Name: "", + Ty: &idltype.U8{}, + Value: "42", + }, + { + Name: "VALID_CONST", + Ty: &idltype.U8{}, + Value: "42", + }, + }, + expectCode: []string{ + "const VALID_CONST = uint8(0x2a)", + }, + }, + { + name: "Empty value - should be skipped", + constants: []idl.IdlConst{ + { + Name: "EMPTY_VALUE", + Ty: &idltype.U8{}, + Value: "", + }, + { + Name: "VALID_CONST", + Ty: &idltype.U8{}, + Value: "42", + }, + }, + expectCode: []string{ + "const VALID_CONST = uint8(0x2a)", + }, + }, + { + name: "Unsupported defined type", + constants: []idl.IdlConst{ + { + Name: "CUSTOM_TYPE", + Ty: &idltype.Defined{ + Name: "CustomType", + }, + Value: "42", + }, + }, + expectError: true, + }, + { + name: "Invalid string format", + constants: []idl.IdlConst{ + { + Name: "INVALID_STRING", + Ty: &idltype.String{}, + Value: "invalid string format", // 应该有引号 + }, + }, + expectError: true, + }, + { + name: "Invalid number format", + constants: []idl.IdlConst{ + { + Name: "INVALID_NUMBER", + Ty: &idltype.U8{}, + Value: "not_a_number", + }, + }, + expectError: true, + }, + { + name: "Invalid u128 format", + constants: []idl.IdlConst{ + { + Name: "INVALID_U128", + Ty: &idltype.U128{}, + Value: "not_a_number", + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 创建一个最小的 IDL 结构 + idlData := &idl.Idl{ + Constants: tt.constants, + } + + // 创建生成器 + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + // 生成常量 + outputFile, err := gen.gen_constants() + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, outputFile) + + // 获取生成的代码 + generatedCode := outputFile.File.GoString() + + // 检查期望的代码片段是否存在 + for _, expectedCode := range tt.expectCode { + assert.Contains(t, generatedCode, expectedCode, + "Expected code snippet not found: %s\nGenerated code:\n%s", + expectedCode, generatedCode) + } + + // 
基本的结构检查 + assert.Contains(t, generatedCode, "package test") + assert.Contains(t, generatedCode, "Code generated by https://github.com/gagliardetto/anchor-go") + assert.Contains(t, generatedCode, "This file contains constants") + }) + } +} + +func TestGenConstantsWithArrays(t *testing.T) { + // 测试数组常量 + constants := []idl.IdlConst{ + { + Name: "BYTE_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.U8{}, + Size: &idltype.IdlArrayLenValue{Value: 3}, + }, + Value: "[1, 2, 3]", + }, + } + + idlData := &idl.Idl{ + Constants: constants, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var BYTE_ARRAY = [3]byte{uint8(0x1), uint8(0x2), uint8(0x3)}") +} + +func TestGenConstantsEdgeCases(t *testing.T) { + t.Run("No constants", func(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{}, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "package test") + // 不应该包含 "Constants defined in the IDL:" 注释 + assert.NotContains(t, generatedCode, "Constants defined in the IDL:") + }) + + t.Run("Underscore cleaning", func(t *testing.T) { + // 测试下划线清理功能 + testCases := []struct { + value string + expected string + }{ + {"1_000", "1000"}, + {"1_000_000", "1000000"}, + {"1_2_3_4", "1234"}, + {"100", "100"}, // 没有下划线 + } + + for _, tc := range testCases { + constants := []idl.IdlConst{ + { + Name: "TEST_VALUE", + Ty: &idltype.U64{}, + Value: tc.value, + }, + } + + idlData := &idl.Idl{ + Constants: constants, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err, 
"Failed for value: %s", tc.value) + + generatedCode := outputFile.File.GoString() + + // 验证生成的代码不包含原始的下划线值 + if strings.Contains(tc.value, "_") { + assert.NotContains(t, generatedCode, tc.value) + } + } + }) +} + +func TestGenConstantsPerformance(t *testing.T) { + // 测试大量常量的性能 + constants := make([]idl.IdlConst, 1000) + for i := 0; i < 1000; i++ { + constants[i] = idl.IdlConst{ + Name: fmt.Sprintf("CONST_%d", i), + Ty: &idltype.U32{}, + Value: fmt.Sprintf("%d", i), + } + } + + idlData := &idl.Idl{ + Constants: constants, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + // 测试性能(应该在合理时间内完成) + outputFile, err := gen.gen_constants() + require.NoError(t, err) + require.NotNil(t, outputFile) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "CONST_0") + assert.Contains(t, generatedCode, "CONST_999") +} + +// TestGenConstantsSpecialCases 测试特殊情况 +func TestGenConstantsSpecialCases(t *testing.T) { + t.Run("Zero values", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "ZERO_U8", + Ty: &idltype.U8{}, + Value: "0", + }, + { + Name: "ZERO_I32", + Ty: &idltype.I32{}, + Value: "0", + }, + { + Name: "ZERO_F64", + Ty: &idltype.F64{}, + Value: "0.0", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "const ZERO_U8 = uint8(0x0)") + assert.Contains(t, generatedCode, "const ZERO_I32 = int32(0)") + assert.Contains(t, generatedCode, "const ZERO_F64 = 0") + }) + + t.Run("Maximum values", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "MAX_U8_VALUE", + Ty: &idltype.U8{}, + Value: "255", + }, + { + Name: "MAX_I8_VALUE", + Ty: &idltype.I8{}, + Value: "127", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: 
idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "const MAX_U8_VALUE = uint8(0xff)") + assert.Contains(t, generatedCode, "const MAX_I8_VALUE = int8(127)") + }) + + t.Run("Complex underscores", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "COMPLEX_NUMBER", + Ty: &idltype.U64{}, + Value: "1_000_000_000_000_000_000", + }, + { + Name: "HEX_LIKE_NUMBER", + Ty: &idltype.U32{}, + Value: "0_x_F_F_F_F", // 这不是真正的十六进制,只是包含下划线的数字 + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + // 第二个应该失败,因为它不是有效的数字 + outputFile, err := gen.gen_constants() + assert.Error(t, err) // 应该失败 + _ = outputFile + }) + + t.Run("Scientific notation", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "SCIENTIFIC_F32", + Ty: &idltype.F32{}, + Value: "1.23e-4", + }, + { + Name: "SCIENTIFIC_F64", + Ty: &idltype.F64{}, + Value: "1.23456789e10", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "const SCIENTIFIC_F32 = float32(0.000123)") + assert.Contains(t, generatedCode, "const SCIENTIFIC_F64 = 1.23456789e+10") + }) + + t.Run("Empty bytes array", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "EMPTY_BYTES", + Ty: &idltype.Bytes{}, + Value: "[]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var EMPTY_BYTES = []byte{}") + }) + + t.Run("With 
docs", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "DOCUMENTED_CONST", + Docs: []string{"This is a test constant", "With multiple lines of documentation"}, + Ty: &idltype.U32{}, + Value: "42", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "// This is a test constant") + assert.Contains(t, generatedCode, "// With multiple lines of documentation") + assert.Contains(t, generatedCode, "const DOCUMENTED_CONST = uint32(0x2a)") + }) +} + +// TestGenConstantsErrorCases 测试各种错误情况 +func TestGenConstantsErrorCases(t *testing.T) { + t.Run("Invalid pubkey", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "INVALID_PUBKEY", + Ty: &idltype.Pubkey{}, + Value: "invalid_pubkey_format", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse pubkey") + }) + + t.Run("Invalid bytes format", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "INVALID_BYTES", + Ty: &idltype.Bytes{}, + Value: "[1, 2, invalid]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal bytes") + }) + + t.Run("Invalid array format", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "INVALID_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.U8{}, + Size: &idltype.IdlArrayLenValue{Value: 3}, + }, + Value: "[1, 2]", // 只有2个元素,但期望3个 + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: 
&GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "got 2") + }) + + t.Run("Number overflow", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "OVERFLOW_U8", + Ty: &idltype.U8{}, + Value: "256", // 超出 u8 范围 + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse u8") + }) +} + +// TestGenConstantsRealWorldExamples 测试真实世界的例子 +func TestGenConstantsRealWorldExamples(t *testing.T) { + t.Run("Solana program constants", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "LAMPORTS_PER_SOL", + Ty: &idltype.U64{}, + Value: "1_000_000_000", + }, + { + Name: "SEED_PREFIX", + Ty: &idltype.String{}, + Value: `"anchor"`, + }, + { + Name: "MAX_SEED_LEN", + Ty: &idltype.U32{}, + Value: "32", + }, + { + Name: "SYSTEM_PROGRAM_ID", + Ty: &idltype.Pubkey{}, + Value: "11111111111111111111111111111112", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "myprogram"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "package myprogram") + assert.Contains(t, generatedCode, "const LAMPORTS_PER_SOL = uint64(0x3b9aca00)") + assert.Contains(t, generatedCode, "const SEED_PREFIX = \"anchor\"") + assert.Contains(t, generatedCode, "const MAX_SEED_LEN = uint32(0x20)") + assert.Contains(t, generatedCode, "var SYSTEM_PROGRAM_ID = solanago.MustPublicKeyFromBase58") + }) + + t.Run("Mixed types with all supported features", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "FEATURE_ENABLED", + Docs: []string{"Feature flag for new functionality"}, + Ty: &idltype.Bool{}, + Value: "true", + }, + { + Name: "MAX_BIN_COUNT", + Docs: 
[]string{"Maximum number of bins per array"}, + Ty: &idltype.Defined{ + Name: "usize", + }, + Value: "70", + }, + { + Name: "PROTOCOL_FEE", + Docs: []string{"Protocol fee in basis points"}, + Ty: &idltype.U128{}, + Value: "10_000_000_000_000_000_000", + }, + { + Name: "SIGNATURE_SEED", + Ty: &idltype.Array{ + Type: &idltype.U8{}, + Size: &idltype.IdlArrayLenValue{Value: 8}, + }, + Value: "[115, 105, 103, 110, 97, 116, 117, 114]", // "signatur" in ASCII + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + + // 检查注释 + assert.Contains(t, generatedCode, "// Feature flag for new functionality") + assert.Contains(t, generatedCode, "// Maximum number of bins per array") + assert.Contains(t, generatedCode, "// Protocol fee in basis points") + + // 检查生成的常量 + assert.Contains(t, generatedCode, "var FEATURE_ENABLED = true") + assert.Contains(t, generatedCode, "const MAX_BIN_COUNT = uint64(0x46)") + assert.Contains(t, generatedCode, "var PROTOCOL_FEE = func() *big.Int") + assert.Contains(t, generatedCode, "var SIGNATURE_SEED = [8]byte{uint8(0x73), uint8(0x69), uint8(0x67), uint8(0x6e), uint8(0x61), uint8(0x74), uint8(0x75), uint8(0x72)}") + }) +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/cre.go b/cmd/generate-bindings/solana/anchor-go/generator/cre.go new file mode 100644 index 00000000..6ce6b662 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/cre.go @@ -0,0 +1,634 @@ +package generator + +import ( + "encoding/json" + "fmt" + + "github.com/dave/jennifer/jen" + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/tools" +) + +// func (c *Codec) Decode(data []byte) (*, error) { +func creDecodeAccountFn(name string) Code { + return Func(). + Params(Id("c").Op("*").Id("Codec")). + Id("Decode"+name). 
+		Params(Id("data").Index().Byte()).
+		Params(Op("*").Id(name), Error()).
+		Block(Return(Id("ParseAccount_" + name).Call(Id("data"))))
+}
+
+// creGenerateCodecEncoderForTypes emits a Codec method of the shape:
+//
+//	func (c *Codec) Encode<Name>Struct(in <Name>) ([]byte, error) {
+//		return in.Marshal()
+//	}
+//
+// where <Name> is exportedAccountName; it simply delegates to the
+// struct's generated Marshal method.
+func creGenerateCodecEncoderForTypes(exportedAccountName string) Code {
+	return Func().
+		Params(Id("c").Op("*").Id("Codec")).
+		Id("Encode"+exportedAccountName+"Struct").
+		Params(Id("in").Id(exportedAccountName)).
+		Params(Index().Byte(), Error()).
+		Block(Return(Id("in").Dot("Marshal").Call()))
+}
+
+// creReadAccountFn emits, on the generated contract type (named from
+// g.options.Package), a method of the shape:
+//
+//	func (c *<Pkg>) ReadAccount_<Name>(runtime cre.Runtime, accountAddress solanago.PublicKey, blockNumber *big.Int) cre.Promise[*<Name>]
+func creReadAccountFn(name string, g *Generator) Code {
+	code := Func().
+		Params(Id("c").Op("*").Id(tools.ToCamelUpper(g.options.Package))). // method receiver
+		Id("ReadAccount_" + name).
+		Params(
+			ListMultiline(
+				func(paramsCode *Group) {
+					paramsCode.Id("runtime").Qual(PkgCRE, "Runtime")
+					paramsCode.Id("accountAddress").Qual(PkgSolanaGo, "PublicKey")
+					paramsCode.Id("blockNumber").Op("*").Qual(PkgBig, "Int")
+				},
+			),
+		).
+		Params(
+			Qual(PkgCRE, "Promise").Types(Op("*").Id(name)),
+		).
+ BlockFunc(func(block *Group) { + block.Comment("cre account read") + // bn := cre.PromiseFromResult(uint64(blockNumber.Int64()), nil) + block.Id("bn").Op(":=").Qual(PkgCRE, "PromiseFromResult").Call( + Id("uint64").Call(Id("blockNumber").Dot("Int64").Call()), + Nil(), + ) + // promise := cre.ThenPromise(bn, func(bn uint64) cre.Promise[*solana.GetAccountInfoReply] { + // return c.client.GetAccountInfoWithOpts(runtime, &solana.GetAccountInfoRequest{ + // Account: types.PublicKey(accountAddress), + // Opts: &solana.GetAccountInfoOpts{MinContextSlot: &bn}, + // }) + // }) + block.Id("promise").Op(":=").Qual(PkgCRE, "ThenPromise").Call( + Id("bn"), + getAccountInfoLambda(), + ) + // return cre.Then(promise, func(response *solana.GetAccountInfoReply) (*DataAccount, error) { + // return ParseAccount_DataAccount(response.Value.Data.AsDecodedBinary) + // }) + block.Return( + Qual(PkgCRE, "Then").Call( + Id("promise"), + parseAccountLambda(name), + ), + ) + }) + return code +} + +// if err block +// +// return cre.PromiseFromResult[*](nil, err) +// } +func creWriteReportErrorBlock() Code { + code := Empty() + code.If(Id("err").Op("!=").Nil()).Block( + Return( + Qual(PkgCRE, "PromiseFromResult").Types(Op("*").Qual(PkgRealSolanaCre, "WriteReportReply")).Call( + Nil(), Id("err"), + ))) + code.Line().Line() + return code +} + +// func (c *DataStorage) WriteReportFrom(runtime cre.Runtime, input , accountList []solanago.PublicKey) cre.Promise[*solana.WriteReportReply] { +func creWriteReportFromStructs(exportedAccountName string, g *Generator) Code { + code := Empty() + declarerName := newWriteReportFromInstructionFuncName(exportedAccountName) + code.Func(). + Params(Id("c").Op("*").Id(tools.ToCamelUpper(g.options.Package))). // method receiver + Id(declarerName). 
+ // params + Params( + ListMultiline(func(p *Group) { + p.Id("runtime").Qual(PkgCRE, "Runtime") + p.Id("input").Id(exportedAccountName) + p.Id("remainingAccounts").Index().Op("*").Qual(PkgRealSolanaCre, "AccountMeta") + }), + ). + // return type + Params(Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgRealSolanaCre, "WriteReportReply"))). + BlockFunc(func(block *Group) { + // encoded, err := c.Codec.EncodeStruct(input) + block.List(Id("encodedInput"), Id("err")).Op(":="). + Id("c").Dot("Codec").Dot("Encode" + exportedAccountName + "Struct").Call(Id("input")) + + // if err block + block.Add(creWriteReportErrorBlock()) + + // encodedAccountList, err := EncodeAccountList(accountList) + block.Id("encodedAccountList").Op(":="). + Qual(PkgBindings, "CalculateAccountsHash").Call(Id("remainingAccounts")).Line() + + // fwdReport := ForwarderReport{Payload: encodedInput, AccountHash: encodedAccountList} + block.Id("fwdReport").Op(":=").Qual(PkgBindings, "ForwarderReport").Values(Dict{ + Id("Payload"): Id("encodedInput"), + Id("AccountHash"): Id("encodedAccountList"), + }) + + // encodedFwdReport, err := fwdReport.Marshal() + block.List(Id("encodedFwdReport"), Id("err")).Op(":=").Id("fwdReport").Dot("Marshal").Call() + + // if err block + block.Add(creWriteReportErrorBlock()) + + // promise := runtime.GenerateReport(&pb2.ReportRequest{ ... }) + block.Id("promise").Op(":=").Id("runtime").Dot("GenerateReport").Call( + Op("&").Qual(PkgPb2, "ReportRequest").Values(Dict{ + Id("EncodedPayload"): Id("encodedFwdReport"), + Id("EncoderName"): Lit("solana"), + Id("SigningAlgo"): Lit("ed25519"), + Id("HashingAlgo"): Lit("sha256"), + }), + ).Line() + + // typedAccountList := make([]solana.PublicKey, len(accountList)) + // block.Id("typedAccountList").Op(":="). 
+ // Id("make").Call( + // Index().Qual(PkgSolanaCre, "PublicKey"), + // Id("len").Call(Id("accountList")), + // ) + + // // for i, account := range accountList { + // // typedAccountList[i] = solana.PublicKey(account) + // // } + // block.For( + // List(Id("i"), Id("account")).Op(":=").Range().Id("accountList"), + // ).Block( + // Id("typedAccountList").Index(Id("i")).Op("="). + // Qual(PkgSolanaCre, "PublicKey").Call(Id("account")), + // ).Line() + + //return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + // return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + // AccountList: typedAccountList, + // Receiver: ProgramID.Bytes(), + // Report: report, + // }) + // }) + block.Return( + Qual(PkgCRE, "ThenPromise").Call( + Id("promise"), + creWriteReportFromStructsLambda(), + ), + ) + }) + return code +} + +func creWriteReportFromStructsLambda() *Statement { + // func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + // return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + // AccountList: typedAccountList, + // Receiver: ProgramID.Bytes(), + // Report: report, + // }) + // } + return Func(). + Params(Id("report").Op("*").Qual(PkgCRE, "Report")). + Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgRealSolanaCre, "WriteReportReply")). + Block( + Return( + Id("c").Dot("client").Dot("WriteReport").Call( + Id("runtime"), + Op("&").Qual(PkgRealSolanaCre, "WriteCreReportRequest").Values(jen.Dict{ + Id("Receiver"): Id("ProgramID").Dot("Bytes").Call(), + Id("Report"): Id("report"), + Id("RemainingAccounts"): Id("remainingAccounts"), + }), + ), + ), + ) +} + +// func (c *Codec) Decode(event *solana.Log) (*, error) { +func creDecodeEventFn(name string) Code { + return Func(). + Params(Id("c").Op("*").Id("Codec")). // method receiver + Id("Decode"+name). + Params(Id("event").Op("*").Qual(PkgSolanaCre, "Log")). + Params(Op("*").Id(name), Error()). 
+ BlockFunc(func(block *Group) { + // res, err := ParseEvent_(event.Data) + block.List(Id("res"), Id("err")).Op(":=").Id("ParseEvent_" + name).Call( + Id("event").Dot("Data"), + ) + block.Add(nilErrBlock()) + block.Return(Id("res"), Nil()) + }).Line().Line() +} + +// type Trigger struct { +// cre.Trigger[*solana.Log, *solana.Log] +// contract * +// } +func creTriggerType(name string, g *Generator) Code { + return Type().Id(name+"Trigger"). + Struct( + Qual(PkgCRE, "Trigger"). // embedded generic type + Types( + Op("*").Qual(PkgSolanaCre, "Log"), + Op("*").Qual(PkgSolanaCre, "Log"), + ), + Id("contract").Op("*").Id(tools.ToCamelUpper(g.options.Package)), + ).Line().Line() +} + +// func (t *AccessLoggedTrigger) Adapt(l *solana.Log) (*bindings.DecodedLog[AccessLogged], error) { +func creLogTriggerAdaptFn(name string) Code { + return Func(). + Params(Id("t").Op("*").Id(name+"Trigger")). // receiver (*DataUpdatedTrigger) + Id("Adapt"). + Params(Id("l").Op("*").Qual(PkgSolanaCre, "Log")). + Params( + Op("*").Qual(PkgBindings, "DecodedLog").Types(Id(name)), // return type + Error(), + ). + Block( + // decoded, err := t.contract.Codec.Decode(l) + List(Id("decoded"), Id("err")).Op(":=").Id("t").Dot("contract").Dot("Codec").Dot("Decode"+name).Call(Id("l")), + // if err != nil { return nil, err } + Add(nilErrBlock()), + // return &bindings.DecodedLog{ Log: l, Data: *decoded } + Return( + Op("&").Qual(PkgBindings, "DecodedLog").Types(Id(name)).Values(Dict{ + Id("Log"): Id("l"), + Id("Data"): Op("*").Id("decoded"), + }), + Nil(), + ), + ).Line().Line() +} + +// func (c *pkgName) LogTrigger_(chainSelector uint64, subKeyPathAndValue []solana.SubKeyPathAndFilter) (cre.Trigger[*solana.Log, *bindings.DecodedLog[]], error) { +func creLogTriggerFunc(name string, g *Generator) Code { + return Func(). + Params(Id("c").Op("*").Id(tools.ToCamelUpper(g.options.Package))). // method receiver + Id("LogTrigger_"+name). 
+ Params( + Id("chainSelector").Uint64(), + Id("subKeyPathAndValue").Index().Qual(PkgSolanaCre, "SubKeyPathAndFilter"), + ). + Params( + Qual(PkgCRE, "Trigger").Types( + Op("*").Qual(PkgSolanaCre, "Log"), + Op("*").Qual(PkgBindings, "DecodedLog").Types(Id(name)), + ), + Error(), + ). + BlockFunc(func(b *jen.Group) { + // eventIdl := types.GetIdlEvent(c.IdlTypes, "") + b.List(Id("eventIdl"), Id("err")).Op(":=").Qual(PkgAnchorIdlCodec, "GetIdlEvent").Call( + Id("c").Dot("IdlTypes"), + Lit(name), + ) + b.Add(nilErrBlock()) + + // if len(subKeyPathAndValue) > 4 { return nil, fmt.Errorf(...) } + b.If(Len(Id("subKeyPathAndValue")).Op(">").Lit(4)).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call( + Lit("too many subkey path and value pairs: %d"), + Len(Id("subKeyPathAndValue")), + ), + ), + ) + + // subKeyPaths, subKeyFilters, err := bindings.ValidateSubKeyPathAndValue[](subKeyPathAndValue) + b.List( + Id("subKeyPaths"), + Id("subKeyFilters"), + Id("err"), + ).Op(":=").Qual(PkgBindings, "ValidateSubKeyPathAndValue"). + Types(Id(name)). + Call(Id("subKeyPathAndValue")) + + b.If(Id("err").Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call( + Lit("failed to validate subkey path and value: %w"), + Id("err"), + ), + ), + ) + + // rawTrigger := solana.LogTrigger(chainSelector, &solana.FilterLogTriggerRequest{ ... 
}) + b.Id("rawTrigger").Op(":=").Qual(PkgSolanaCre, "LogTrigger").Call( + Id("chainSelector"), + Op("&").Qual(PkgSolanaCre, "FilterLogTriggerRequest").Values(jen.Dict{ + Id("Address"): Id("ProgramID").Dot("Bytes").Call(), + Id("EventName"): Lit(name), + Id("EventSig"): Id("Event_" + name), + Id("EventIdl"): Id("eventIdl"), + Id("SubkeyPaths"): Id("subKeyPaths"), + Id("SubkeyFilters"): Id("subKeyFilters"), + }), + ) + + // return &Trigger{ Trigger: rawTrigger }, nil + b.Return( + Op("&").Id(name+"Trigger").Values(jen.Dict{ + Id("Trigger"): Id("rawTrigger"), + Id("contract"): Id("c"), + }), + Nil(), + ) + }).Line().Line() +} + +func nilErrBlock() Code { + return If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("err")), + ) +} + +func creEventFuncs(name string, g *Generator) Code { + code := Empty() + // event decode func + code.Add(creDecodeEventFn(name)) + + // trigger type + code.Add(creTriggerType(name, g)) + + // Adapt func + code.Add(creLogTriggerAdaptFn(name)) + + // Log trigger func + code.Add(creLogTriggerFunc(name, g)) + + return code +} + +// genfile_constructor generates the file `constructor.go`. +func (g *Generator) genfile_constructor() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains the constructor for the program.") + + { + // idl string + code := newStatement() + idlData, err := json.Marshal(g.idl) + if err != nil { + return nil, fmt.Errorf("error reading IDL file: %w", err) + } + code.Var().Id("IDL").Op("=").Lit(string(idlData)) + file.Add(code) + code.Line() + + // contract type + code = newStatement() + code.Type().Id(tools.ToCamelUpper(g.options.Package)).Struct( + Id("IdlTypes").Op("*").Qual(PkgAnchorIdlCodec, "IdlTypeDefSlice"), + Id("client").Op("*").Qual(PkgRealSolanaCre, "Client"), + Id("Codec").Id(tools.ToCamelUpper(g.options.Package)+"Codec"), + ) + code.Line() + file.Add(code) + code.Line() + + // codec type + code = newStatement() + code.Type().Id("Codec").Struct() + code.Line() + file.Add(code) + + // new constructor + code = newStatement() + code.Func(). + Id("New"+tools.ToCamelUpper(g.options.Package)). + Params( + Id("client").Op("*").Qual(PkgRealSolanaCre, "Client"), + ). + Params( + Op("*").Id(tools.ToCamelUpper(g.options.Package)), Error(), + ). + Block( + // type idlTypesStruct struct { anchorcodec.IdlTypeDefSlice `json:"types"` } + Type().Id("idlTypesStruct").Struct( + Qual(PkgAnchorIdlCodec, "IdlTypeDefSlice"). 
+ Tag(map[string]string{"json": "types"}), + ), + + // var idlTypes idlTypesStruct + Var().Id("idlTypes").Id("idlTypesStruct"), + + // err := json.Unmarshal([]byte(IDL), &idlTypes) + Id("err").Op(":=").Qual(PkgJson, "Unmarshal").Call( + Index().Byte().Parens(Id("IDL")), + Op("&").Id("idlTypes"), + ), + + // if err != nil { return nil, err } + If(Err().Op("!=").Nil()).Block( + Return(Nil(), Err()), + ), + + // return &DataStorage{ Codec: &Codec{}, IdlTypes: &idlTypes.IdlTypeDefSlice, client: client }, nil + Return( + Op("&").Id(tools.ToCamelUpper(g.options.Package)).Values(Dict{ + Id("Codec"): Op("&").Id("Codec").Values(), + Id("IdlTypes"): Op("&").Id("idlTypes").Dot("IdlTypeDefSlice"), + Id("client"): Id("client"), + }), + Nil(), + ), + ) + file.Add(code) + code.Line() + + methods, err := g.generateCodecMethods() + if err != nil { + return nil, err + } + + // Codec interface + code = newStatement() + code.Type().Id(tools.ToCamelUpper(g.options.Package) + "Codec").Interface(methods...) + file.Add(code) + code.Line() + } + + return &OutputFile{ + Name: "constructor.go", + File: file, + }, nil +} + +func getAccountInfoLambda() *Statement { + // func(bn uint64) cre.Promise[*solana.GetAccountInfoWithOptsReply] { + // return c.client.GetAccountInfoWithOpts(runtime, &solana.GetAccountInfoWithOptsRequest{ + // Account: types.PublicKey(accountAddress), + // Opts: &solana.GetAccountInfoOpts{MinContextSlot: &bn}, + // }) + // } + return Func(). + Params(Id("bn").Uint64()). + Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgRealSolanaCre, "GetAccountInfoWithOptsReply")). 
+ Block( + Return( + Id("c").Dot("client").Dot("GetAccountInfoWithOpts").Call( + Id("runtime"), + Op("&").Qual(PkgRealSolanaCre, "GetAccountInfoWithOptsRequest").Values(Dict{ + Id("Account"): Id("accountAddress").Dot("Bytes").Call(), + Id("Opts"): Op("&").Qual(PkgRealSolanaCre, "GetAccountInfoOpts").Values(Dict{ + Id("MinContextSlot"): Id("bn"), + }), + }), + ), + ), + ) +} + +func parseAccountLambda(name string) *Statement { + // func(response *solana.GetAccountInfoWithOptsReply) (*DataAccount, error) { + // return ParseAccount_DataAccount(response.Value.Data.AsDecodedBinary) + // } + return Func(). + Params(Id("response").Op("*").Qual(PkgRealSolanaCre, "GetAccountInfoWithOptsReply")). + Params(Op("*").Id(name), Error()). + Block( + Return( + Id("ParseAccount_" + name).Call( + Id("response").Dot("Value").Dot("Data").Dot("GetRaw").Call(), + ), + ), + ) +} + +func dummyForwarderCode(file *File) { + code := newStatement() + code.Type().Id("ForwarderReport").Struct( + Id("AccountHash").Index(Lit(32)).Byte().Tag(map[string]string{"json": "account_hash"}), + Id("Payload").Index().Byte().Tag(map[string]string{"json": "payload"}), + ) + file.Add(code) + code.Line() + + code = newStatement() + code.Func(). + Params(Id("c").Op("*").Id("Codec")). + Id("EncodeForwarderReportStruct"). + Params(Id("in").Id("ForwarderReport")). + Params(Index().Byte(), Error()). + Block( + Return(Id("in").Dot("Marshal").Call()), + ) + file.Add(code) + code.Line() + + code = newStatement() + code.Func(). + Params(Id("obj").Id("ForwarderReport")). + Id("MarshalWithEncoder"). + Params(Id("encoder").Op("*").Qual(PkgBinary, "Encoder")). + Params(Id("err").Error()). 
+ Block( + Comment("Serialize `AccountHash`:"), + Id("err").Op("=").Id("encoder").Dot("Encode").Call(Id("obj").Dot("AccountHash")), + If(Id("err").Op("!=").Nil()).Block( + Return(Qual(PkgAnchorGoErrors, "NewField").Call(Lit("AccountHash"), Id("err"))), + ), + Comment("Serialize `Payload`:"), + Id("err").Op("=").Id("encoder").Dot("Encode").Call(Id("obj").Dot("Payload")), + If(Id("err").Op("!=").Nil()).Block( + Return(Qual(PkgAnchorGoErrors, "NewField").Call(Lit("Payload"), Id("err"))), + ), + Return(Nil()), + ) + file.Add(code) + code.Line() + + code = newStatement() + code.Func(). + Params(Id("obj").Id("ForwarderReport")). + Id("Marshal"). + Params(). + Params(Index().Byte(), Error()). + Block( + Id("buf").Op(":=").Qual("bytes", "NewBuffer").Call(Nil()), + Id("encoder").Op(":=").Qual(PkgBinary, "NewBorshEncoder").Call(Id("buf")), + Id("err").Op(":=").Id("obj").Dot("MarshalWithEncoder").Call(Id("encoder")), + If(Id("err").Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("error while encoding ForwarderReport: %w"), Id("err")), + ), + ), + Return(Id("buf").Dot("Bytes").Call(), Nil()), + ) + file.Add(code) + code.Line() +} + +func (g *Generator) generateCodecAccountMethods() ([]Code, error) { + accountMethods := make([]Code, 0, len(g.idl.Accounts)) + for _, acc := range g.idl.Accounts { + methodName := "Decode" + acc.Name + m := Id(methodName). + Params(Id("data").Index().Byte()). // ([]byte) + Params( + Op("*").Id(acc.Name), // (*DataAccount) + Error(), // error + ) + + accountMethods = append(accountMethods, m) + } + + return accountMethods, nil +} + +func (g *Generator) generateCodecEventMethods() ([]Code, error) { + eventMethods := make([]Code, 0, len(g.idl.Events)) + for _, event := range g.idl.Events { + methodName := "Decode" + event.Name + m := Id(methodName). + Params(Id("log").Op("*").Qual(PkgSolanaCre, "Log")). 
+ Params( + Op("*").Id(event.Name), + Error(), + ) + + eventMethods = append(eventMethods, m) + } + + return eventMethods, nil +} + +func (g *Generator) generateCodecStructMethod() ([]Code, error) { + structMethods := make([]Code, 0, len(g.idl.Types)) + for _, typ := range g.idl.Types { + methodName := "Encode" + typ.Name + "Struct" + m := Id(methodName). + Params( + Id("in").Id(typ.Name), // e.g., AccessLogged / DataAccount / ... + ). + Params( + Index().Byte(), // []byte + Error(), // error + ) + structMethods = append(structMethods, m) + } + return structMethods, nil +} + +func (g *Generator) generateCodecMethods() ([]Code, error) { + accountMethods, err := g.generateCodecAccountMethods() + if err != nil { + return nil, err + } + // eventMethods, err := g.generateCodecEventMethods() + // if err != nil { + // return nil, err + // } + structMethods, err := g.generateCodecStructMethod() + if err != nil { + return nil, err + } + // return append(append(accountMethods, eventMethods...), structMethods...), nil + return append(accountMethods, structMethods...), nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/discriminator.go b/cmd/generate-bindings/solana/anchor-go/generator/discriminator.go new file mode 100644 index 00000000..585095c9 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/discriminator.go @@ -0,0 +1,115 @@ +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) gen_discriminators() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains the discriminators for accounts and events defined in the IDL.") + + { + accountDiscriminatorsCodes := Empty() + accountDiscriminatorsCodes.Comment("Account discriminators") + accountDiscriminatorsCodes.Line() + accountDiscriminatorsCodes.Var().Parens( + DoGroup(func(code *Group) { + for _, account := range g.idl.Accounts { + if account.Discriminator == nil { + continue + } + + discriminator := account.Discriminator + if len(discriminator) != 8 { + panic(fmt.Errorf("discriminator for account %s must be exactly 8 bytes long, got %d bytes", account.Name, len(discriminator))) + } + + discriminatorName := FormatAccountDiscriminatorName(account.Name) + { + code.Id(discriminatorName).Op("=").Index(Lit(8)).Byte().Op("{").ListFunc(func(byteGroup *Group) { + for _, byteVal := range discriminator[:] { + byteGroup.Lit(int(byteVal)) + } + }).Op("}") + } + code.Line() + } + }), + ) + file.Add(accountDiscriminatorsCodes) + file.Line() + } + { + // Generate the discriminators for events. + eventDiscriminatorsCodes := Empty() + eventDiscriminatorsCodes.Comment("Event discriminators") + eventDiscriminatorsCodes.Line() + eventDiscriminatorsCodes.Var().Parens( + DoGroup(func(code *Group) { + for _, event := range g.idl.Events { + if event.Discriminator == nil { + continue + } + + discriminator := event.Discriminator + if len(discriminator) != 8 { + panic(fmt.Errorf("discriminator for event %s must be exactly 8 bytes long", event.Name)) + } + + discriminatorName := FormatEventDiscriminatorName(event.Name) + { + code.Id(discriminatorName).Op("=").Index(Lit(8)).Byte().Op("{").ListFunc(func(byteGroup *Group) { + for _, byteVal := range discriminator[:] { + byteGroup.Lit(int(byteVal)) + } + }).Op("}") + } + code.Line() + } + }), + ) + file.Add(eventDiscriminatorsCodes) + file.Line() + } + { + // Generate the discriminators for instructions. 
+ instructionDiscriminatorsCodes := Empty() + instructionDiscriminatorsCodes.Comment("Instruction discriminators") + instructionDiscriminatorsCodes.Line() + instructionDiscriminatorsCodes.Var().Parens( + DoGroup( + func(code *Group) { + for _, instruction := range g.idl.Instructions { + if instruction.Discriminator == nil { + continue + } + + discriminator := instruction.Discriminator + if len(discriminator) != 8 { + panic(fmt.Errorf("discriminator for instruction %s must be exactly 8 bytes long", instruction.Name)) + } + + discriminatorName := FormatInstructionDiscriminatorName(instruction.Name) + { + code.Id(discriminatorName).Op("=").Index(Lit(8)).Byte().Op("{").ListFunc(func(byteGroup *Group) { + for _, byteVal := range discriminator[:] { + byteGroup.Lit(int(byteVal)) + } + }).Op("}") + } + code.Line() + } + }, + ), + ) + file.Add(instructionDiscriminatorsCodes) + file.Line() + } + return &OutputFile{ + Name: "discriminators.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/doc.go b/cmd/generate-bindings/solana/anchor-go/generator/doc.go new file mode 100644 index 00000000..330dfb46 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/doc.go @@ -0,0 +1,33 @@ +package generator + +import ( + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) genfile_doc() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains documentation and example usage for the generated code.") + + // TODO: + // - example usage + // - documentation + + file.Line().Line() + + if len(g.idl.Docs) == 0 { + file.Comment("No documentation available from the IDL.") + file.Comment("Please refer to the IDL source or the program documentation for more information.") + file.Line() + } else { + file.Comment("Documentation from the IDL:") + for _, comment := range g.idl.Docs { + file.Comment(comment) + } + } + + return &OutputFile{ + Name: "doc.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/errors.go b/cmd/generate-bindings/solana/anchor-go/generator/errors.go new file mode 100644 index 00000000..cb76a155 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/errors.go @@ -0,0 +1,100 @@ +package generator + +import ( + "encoding/json" + "errors" + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/solana-go/rpc/jsonrpc" +) + +func (g *Generator) gen_errors() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains errors.") + { + code := Empty() + for _, e := range g.idl.Errors { + _ = e + // spew.Dump(e) + + // type IdlErrorCode struct { + // Code uint32 `json:"code"` + // Name string `json:"name"` + // // #[serde(skip_serializing_if = "is_default")] + // // pub msg: Option, + // Msg Option[string] `json:"msg,omitzero"` + // } + } + file.Add(code) + } + return &OutputFile{ + Name: "errors.go", + File: file, + }, nil +} + +type CustomError interface { + Code() int + Name() string + Error() string +} +type customErrorDef struct { + code int + name string + msg string +} + +func (e *customErrorDef) Code() int { + return e.code +} + +func (e *customErrorDef) Name() string { + return e.name +} + +func (e *customErrorDef) Error() string { + return fmt.Sprintf("%s(%d): %s", e.name, e.code, e.msg) +} + +var Errors = map[int]CustomError{} + +func DecodeCustomError(rpcErr error) (err error, ok bool) { + if errCode, o := decodeErrorCode(rpcErr); o { + if customErr, o := Errors[errCode]; o { + err = customErr + ok = true + return + } + } + return +} + +func decodeErrorCode(rpcErr error) (errorCode int, ok bool) { + var jErr *jsonrpc.RPCError + if errors.As(rpcErr, &jErr) && jErr.Data != nil { + if root, o := jErr.Data.(map[string]any); o { + if rootErr, o := root["err"].(map[string]any); o { + if rootErrInstructionError, o := rootErr["InstructionError"]; o { + if rootErrInstructionErrorItems, o := rootErrInstructionError.([]any); o { + if len(rootErrInstructionErrorItems) == 2 { + if v, o := rootErrInstructionErrorItems[1].(map[string]any); o { + if v2, o := v["Custom"].(json.Number); o { + if code, err := v2.Int64(); err == nil { + ok = true + errorCode = int(code) + } + } else if v2, o := v["Custom"].(float64); o { + ok = true + errorCode = int(v2) + } + } + } + } + } + } + } + } + return +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/events.go b/cmd/generate-bindings/solana/anchor-go/generator/events.go new 
file mode 100644 index 00000000..425112ec --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/events.go @@ -0,0 +1,135 @@ +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/tools" +) + +func (g *Generator) genfile_events() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains parsers for the events defined in the IDL.") + + names := []string{} + { + for _, event := range g.idl.Events { + names = append(names, tools.ToCamelUpper(event.Name)) + } + } + { + code, err := g.gen_eventParser(names) + if err != nil { + return nil, fmt.Errorf("error generating event parser: %w", err) + } + file.Add(code) + } + + return &OutputFile{ + Name: "events.go", + File: file, + }, nil +} + +func (g *Generator) gen_eventParser(eventNames []string) (Code, error) { + code := Empty() + { + code.Func().Id("ParseAnyEvent"). + Params(Id("eventData").Index().Byte()). + Params(Any(), Error()). 
+ BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("eventData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek event discriminator: %w"), Err()), + ), + ) + + block.Switch(Id("discriminator")).BlockFunc(func(switchBlock *Group) { + for _, name := range eventNames { + switchBlock.Case(Id(FormatEventDiscriminatorName(name))).Block( + Id("value").Op(":=").New(Id(name)), + Err().Op(":=").Id("value").Dot("UnmarshalWithDecoder").Call(Id("decoder")), + If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal event as "+name+": %w"), Err()), + ), + ), + Return(Id("value"), Nil()), + ) + } + switchBlock.Default().Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("unknown discriminator: %s"), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + }) + }) + } + { + code.Line().Line() + // for each event, generate a function to parse it: + for _, name := range eventNames { + discriminatorName := FormatEventDiscriminatorName(name) + + code.Func().Id("ParseEvent_"+name). + Params(Id("eventData").Index().Byte()). + Params(Op("*").Id(name), Error()). 
+ BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("eventData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek discriminator: %w"), Err()), + ), + ) + + block.If(Id("discriminator").Op("!=").Id(discriminatorName)).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("expected discriminator %v, got %s"), Id(discriminatorName), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + + block.Id("event").Op(":=").New(Id(name)) + block.Err().Op("=").Id("event").Dot("UnmarshalWithDecoder").Call(Id("decoder")) + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal event of type "+name+": %w"), Err()), + ), + ) + + block.Return(Id("event"), Nil()) + }) + code.Line().Line() + + // code.Add(creEventFuncs(name, g)) + } + } + return code, nil +} + +/* +type LogTriggerConfig struct { + Name string + Address lptypes.PublicKey + EventName string + EventSig lptypes.EventSignature + StartingBlock int64 + EventIdl lptypes.EventIdl + SubkeyPaths [][]string + Retention time.Duration + MaxLogsKept int64 + SubkeyFilters []SubkeyFilterCriteria +} + +type SubkeyFilterCriteria struct { + SubkeyIndex uint64 + Comparers []primitives.ValueComparator +} + +*/ diff --git a/cmd/generate-bindings/solana/anchor-go/generator/fetchers.go b/cmd/generate-bindings/solana/anchor-go/generator/fetchers.go new file mode 100644 index 00000000..0e7484bb --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/fetchers.go @@ -0,0 +1,17 @@ +package generator + +import ( + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) gen_fetchers() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains fetcher functions.") + { + } + return &OutputFile{ + Name: "fetchers.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/generator.go b/cmd/generate-bindings/solana/anchor-go/generator/generator.go new file mode 100644 index 00000000..98a35353 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/generator.go @@ -0,0 +1,167 @@ +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/solana-go" +) + +var Debug = false // Set to true to enable debug logging. + +type Generator struct { + options *GeneratorOptions + idl *idl.Idl +} + +type GeneratorOptions struct { + OutputDir string // Directory to write the generated code to. + Package string // Package name for the generated code. + ModPath string // Module path for the generated code. E.g. "github.com/gagliardetto/mysolana-program-go" + ProgramId *solana.PublicKey // Program ID to use in the generated code. + ProgramName string // Name of the program for the generated code. + SkipGoMod bool // If true, skip generating the go.mod file. +} + +func NewGenerator(idl *idl.Idl, options *GeneratorOptions) *Generator { + return &Generator{ + idl: idl, + options: options, + } +} + +type OutputFile struct { + Name string // Name of the output file. + File *File +} + +type Output struct { + Files []*OutputFile // List of output files to be generated. + GoMod []byte // Go module file content. 
+} + +func (g *Generator) Generate() (*Output, error) { + if g.idl == nil { + return nil, fmt.Errorf("IDL is nil, cannot generate code") + } + if g.options == nil { + g.options = &GeneratorOptions{ + OutputDir: "generated", + Package: "idlclient", + ModPath: "github.com/gagliardetto/anchor-go/idlclient", + ProgramId: nil, + ProgramName: "myprogram", + } + } + if err := g.idl.Validate(); err != nil { + return nil, fmt.Errorf("invalid IDL: %w", err) + } + output := &Output{ + Files: make([]*OutputFile, 0), + } + + { + // Register complex enums. + { + // register complex enums: + // TODO: .types is the only place where we can find complex enums? (or enums in general?) + for _, typ := range g.idl.Types { + registerComplexEnums(typ) + } + } + if len(g.idl.Docs) > 0 { + file, err := g.genfile_doc() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if len(g.idl.Accounts) > 0 { + file, err := g.genfile_accounts() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if len(g.idl.Events) > 0 { + file, err := g.genfile_events() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.genfile_constructor() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.genfile_types() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_discriminators() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_fetchers() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_errors() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_constants() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_tests() + if err != 
nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_instructions() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if g.options.ProgramId != nil { + file, err := g.genfile_programID(*g.options.ProgramId) + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if !g.options.SkipGoMod { + goMod, err := g.gen_gomod() + if err != nil { + return nil, err + } + output.GoMod = goMod + } + } + + return output, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/gomod.go b/cmd/generate-bindings/solana/anchor-go/generator/gomod.go new file mode 100644 index 00000000..633c1d5d --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/gomod.go @@ -0,0 +1,26 @@ +package generator + +import ( + "golang.org/x/mod/modfile" +) + +// gen_gomod generates a `go.mod` file for the generated code, and writes +// it to the destination directory. +func (g *Generator) gen_gomod() ([]byte, error) { + mdf := &modfile.File{} + mdf.AddModuleStmt(g.options.ModPath) + + mdf.AddNewRequire("github.com/gagliardetto/solana-go", "v1.12.0", false) + mdf.AddNewRequire("github.com/gagliardetto/anchor-go", "v0.3.2", false) + mdf.AddNewRequire("github.com/gagliardetto/binary", "v0.8.0", false) + mdf.AddNewRequire("github.com/gagliardetto/treeout", "v0.1.4", false) + mdf.AddNewRequire("github.com/gagliardetto/gofuzz", "v1.2.2", false) + mdf.AddNewRequire("github.com/stretchr/testify", "v1.10.0", false) + mdf.AddNewRequire("github.com/davecgh/go-spew", "v1.1.1", false) + + // add replacement for "github.com/gagliardetto/anchor-go/errors" to ../../demo-anchor-go/errors + // mdf.AddReplace("github.com/gagliardetto/anchor-go", "", "../../demo-anchor-go", "") + mdf.Cleanup() + + return mdf.Format() +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/id.go b/cmd/generate-bindings/solana/anchor-go/generator/id.go new file mode 100644 index 
00000000..5415f537 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/id.go @@ -0,0 +1,25 @@ +package generator + +import ( + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/solana-go" +) + +// TODO: +// - generate program IDs for mainnet, devnet, testnet, and localnet. + +func (g *Generator) genfile_programID(id solana.PublicKey) (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains the program ID.") + + { + idDecl := Var().Id("ProgramID").Op("=").Qual(PkgSolanaGo, "MustPublicKeyFromBase58").Call(Lit(id.String())) + file.Add(idDecl) + } + + return &OutputFile{ + Name: "program_id.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/instructions.go b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go new file mode 100644 index 00000000..ed3c6cac --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go @@ -0,0 +1,793 @@ +package generator + +import ( + "fmt" + "strings" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +func (g *Generator) gen_instructions() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains instructions and instruction parsers.") + { + for _, instruction := range g.idl.Instructions { + ixCode := Empty() + { + declarerName := newInstructionFuncName(instruction.Name) + ixCode.Commentf("Builds a %q instruction.", instruction.Name) + { + if len(instruction.Docs) > 0 { + ixCode.Line() + // Add documentation comments for the instruction. 
+ for _, doc := range instruction.Docs { + ixCode.Comment(doc) + } + } + } + ixCode.Line() + ixCode.Func().Id(declarerName). + Params( + DoGroup( + func(g *Group) { + addCommentSections := len(instruction.Args) > 0 && len(instruction.Accounts) > 0 + if addCommentSections { + g.Line().Comment("Params:") + } + g.Add( + ListMultiline( + func(paramsCode *Group) { + for _, param := range instruction.Args { + paramType := genTypeName(param.Ty) + if IsOption(param.Ty) || IsCOption(param.Ty) { + paramType = Op("*").Add(paramType) + } + paramsCode.Id(formatParamName(param.Name)).Add(paramType) + } + }, + ), + ) + if addCommentSections { + g.Line().Comment("Accounts:") + } + g.Add( + ListMultiline( + func(accountsCode *Group) { + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + accountsCode.Id(formatAccountNameParam(acc.Name)).Qual(PkgSolanaGo, "PublicKey") + } + // TODO: for accounts: + // - Optional? + // - PDA? + // - Address? + // - Relations? + case *idl.IdlInstructionAccounts: + { + panic(fmt.Errorf("Accounts groups are not supported yet: %s", acc.Name)) + // accs := acc.Accounts + // // add comment for the accounts + // if len(accs) > 0 { + // accountsCode.Commentf("Accounts group %q:", acc.Name) + // } + // for _, acc := range accs { + // // If the account has a name, use it as the parameter name. + // // Otherwise, use a generic name. + // acc := acc.(*idl.IdlInstructionAccount) + // accountName := formatAccountNameParam(acc.Name) + // accountsCode.Id(accountName).Qual(PkgSolanaGo, "PublicKey") + // } + } + default: + panic("unknown account type: " + spew.Sdump(account)) + } + } + }, + ), + ) + }, + ), + ). 
+ ParamsFunc(func(returnsCode *Group) { + returnsCode.Qual(PkgSolanaGo, "Instruction") + returnsCode.Error() + }).BlockFunc(func(body *Group) { + if len(instruction.Args) > 0 { + body.Id("buf__").Op(":=").New(Qual("bytes", "Buffer")) + body.Id("enc__").Op(":=").Qual(PkgBinary, "NewBorshEncoder").Call(Id("buf__")) + + { + // write the discriminator + body.Line().Comment("Encode the instruction discriminator.") + discriminatorName := FormatInstructionDiscriminatorName(instruction.Name) + body.Err().Op(":=").Id("enc__").Dot("WriteBytes").Call(Id(discriminatorName).Index(Op(":")), False()) + body.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to write instruction discriminator: %w"), Err()), + ), + ) + } + // for _, param := range instruction.Args { + // paramName := formatParamName(param.Name) + // isComplexEnum(param.Ty) + + // body.Line().Commentf("Encode the parameter: %s", paramName) + // body.Block( + // Err().Op(":=").Id("enc__").Dot("Encode").Call(Id(paramName)), + // If(Err().Op("!=").Nil()).Block( + // Return( + // Nil(), + // Qual(PkgAnchorGoErrors, "NewField").Call( + // Lit(paramName), + // Err(), + // ), + // ), + // ), + // ) + // } + checkNil := true + body.BlockFunc(func(g *Group) { + gen_marshal_DefinedFieldsNamed( + g, + instruction.Args, + checkNil, + func(param idl.IdlField) *Statement { + return Id(formatParamName(param.Name)) + }, + "enc__", + true, // returnNilErr + func(param idl.IdlField) string { + return formatParamName(param.Name) + }, + ) + }) + } + body.Id("accounts__").Op(":=").Qual(PkgSolanaGo, "AccountMetaSlice").Block() + if len(instruction.Accounts) > 0 { + body.Line().Comment("Add the accounts to the instruction.") + + body.Block( + DoGroup(func(body *Group) { + for ai, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + if ai > 0 { + body.Line() + } + body.Comment(formatAccountCommentDocs(ai, acc)) + body.Line() + { + // add 
comment for the account + if len(acc.Docs) > 0 { + for _, doc := range acc.Docs { + body.Comment(doc).Line() + } + } + } + accountName := formatAccountNameParam(acc.Name) + body.Id("accounts__").Dot("Append").Call( + Qual(PkgSolanaGo, "NewAccountMeta").Call( + Id(accountName), + Lit(acc.Writable), + Lit(acc.Signer), + ), + ) + } + + case *idl.IdlInstructionAccounts: + { + panic(fmt.Errorf("Accounts groups are not supported yet: %s", acc.Name)) + // if ai > 0 { + // body.Line() + // } + // body.Commentf("Accounts group: %s", acc.Name) + // body.Line() + // accs := acc.Accounts + // for acci, acc := range accs { + // acc := acc.(*idl.IdlInstructionAccount) + // body.Comment(formatAccountCommentDocs(acci, acc)) + // body.Line() + // accountName := formatAccountNameParam(acc.Name) + // body.Id("accounts__").Dot("Append").Call( + // Qual(PkgSolanaGo, "NewAccountMeta").Call( + // Id(accountName), + // Lit(acc.Writable), + // Lit(acc.Signer), + // ), + // ) + // } + } + default: + panic("unknown account type: " + spew.Sdump(account)) + } + } + }), + ) + } + + // create the return instruction + body.Line().Comment("Create the instruction.") + body.Return( + Qual(PkgSolanaGo, "NewInstruction").CallFunc( + func(g *Group) { + g.Add( + ListMultiline(func(gg *Group) { + gg.Id("ProgramID") + gg.Id("accounts__") + if len(instruction.Args) > 0 { + gg.Id("buf__").Dot("Bytes").Call() + } else { + gg.Nil() // No arguments to encode. 
+ } + }), + ) + }, + ), + Nil(), // No error + ) + }) + } + file.Add(ixCode) + } + } + + // Add instruction types and parsers + { + typeNames := []string{} + discriminatorNames := []string{} + for _, instruction := range g.idl.Instructions { + // Check if the instruction name already ends with "instruction" (case-insensitive) + instructionNameLower := strings.ToLower(instruction.Name) + if strings.HasSuffix(instructionNameLower, "instruction") { + // Already has "instruction" suffix, don't add it again + typeNames = append(typeNames, tools.ToCamelUpper(instruction.Name)) + } else { + // Add "Instruction" suffix + typeNames = append(typeNames, tools.ToCamelUpper(instruction.Name)+"Instruction") + } + discriminatorNames = append(discriminatorNames, tools.ToCamelUpper(instruction.Name)) + } + + // Generate instruction struct types + { + for _, instruction := range g.idl.Instructions { + typeCode, err := g.gen_instructionType(instruction) + if err != nil { + return nil, fmt.Errorf("error generating instruction type for %s: %w", instruction.Name, err) + } + file.Add(typeCode) + } + } + + // Generate instruction parsers + { + code, err := g.gen_instructionParser(typeNames, discriminatorNames) + if err != nil { + return nil, fmt.Errorf("error generating instruction parser: %w", err) + } + file.Add(code) + } + } + + return &OutputFile{ + Name: "instructions.go", + File: file, + }, nil +} + +func formatAccountNameParam(accountName string) string { + accountName = accountName + "Account" + if tools.IsReservedKeyword(accountName) { + return accountName + "_" + } + if !tools.IsValidIdent(accountName) { + return "a_" + tools.ToCamelUpper(accountName) + } + return tools.ToCamelLower(accountName) +} + +func formatParamName(paramName string) string { + paramName = paramName + "Param" + if tools.IsReservedKeyword(paramName) { + return paramName + "_" + } + if !tools.IsValidIdent(paramName) { + return "p_" + tools.ToCamelUpper(paramName) + } + return tools.ToCamelLower(paramName) +} 
+ +func newInstructionFuncName(instructionName string) string { + // Check if the instruction name already ends with "instruction" (case-insensitive) + instructionNameLower := strings.ToLower(instructionName) + if strings.HasSuffix(instructionNameLower, "instruction") { + // Already has "instruction" suffix, don't add it again + return "New" + tools.ToCamelUpper(instructionName) + } else { + // Add "Instruction" suffix + return "New" + tools.ToCamelUpper(instructionName) + "Instruction" + } +} + +func newWriteReportFromInstructionFuncName(instructionName string) string { + return "WriteReportFrom" + tools.ToCamelUpper(instructionName) +} + +func formatAccountCommentDocs(index int, account *idl.IdlInstructionAccount) string { + buf := new(strings.Builder) + buf.WriteString(fmt.Sprintf("Account %d %q", index, account.Name)) + buf.WriteString(": ") + if account.Writable { + buf.WriteString("Writable") + } else { + buf.WriteString("Read-only") + } + if account.Signer { + buf.WriteString(", Signer") + } else { + buf.WriteString(", Non-signer") + } + if account.Optional { + buf.WriteString(", Optional") + } else { + buf.WriteString(", Required") + } + if account.Address.IsSome() && !account.Address.Unwrap().IsZero() { + buf.WriteString(fmt.Sprintf(", Address: %s", account.Address.Unwrap().String())) + } + // TODO: Handle PDA and Relations + return buf.String() +} + +func (g *Generator) gen_instructionParser(typeNames []string, discriminatorNames []string) (Code, error) { + code := Empty() + + // Generate Instruction interface + code.Line().Line() + code.Comment("Instruction interface defines common methods for all instruction types") + code.Line() + code.Type().Id("Instruction").Interface( + Id("GetDiscriminator").Params().Params(Index().Byte()), + Line(), + Id("UnmarshalWithDecoder").Params(Id("decoder").Op("*").Qual(PkgBinary, "Decoder")).Params(Error()), + Line(), + Id("UnmarshalAccountIndices").Params(Id("buf").Index().Byte()).Params(Index().Uint8(), Error()), + 
Line(), + Id("PopulateFromAccountIndices").Params(Id("indices").Index().Uint8(), Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey")).Params(Error()), + Line(), + Id("GetAccountKeys").Params().Params(Index().Qual(PkgSolanaGo, "PublicKey")), + ) + + // Single unified ParseInstruction function with optional accounts + code.Line().Line() + code.Comment("ParseInstruction parses instruction data and optionally populates accounts").Line() + code.Comment("If accountIndicesData is nil or empty, accounts will not be populated") + code.Line() + code.Func().Id("ParseInstruction"). + Params( + Id("instructionData").Index().Byte(), + Id("accountIndicesData").Index().Byte(), + Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey"), + ). + Params(Id("Instruction"), Error()). + BlockFunc(func(block *Group) { + block.Comment("Validate inputs") + block.If(Len(Id("instructionData")).Op("<").Lit(8)).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("instruction data too short: expected at least 8 bytes, got %d"), Len(Id("instructionData")))), + ) + + block.Comment("Extract discriminator") + block.Id("discriminator").Op(":=").Index(Lit(8)).Byte().Values() + block.Copy(Id("discriminator").Index(Op(":")), Id("instructionData").Index(Lit(0), Lit(8))) + + block.Comment("Parse based on discriminator") + block.Switch(Id("discriminator")).BlockFunc(func(switchBlock *Group) { + // This for loop runs during code generation, not at runtime + for i, typeName := range typeNames { + discriminatorName := discriminatorNames[i] + switchBlock.Case(Id(FormatInstructionDiscriminatorName(discriminatorName))).Block( + Id("instruction").Op(":=").New(Id(typeName)), + Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("instructionData")), + Id("err").Op(":=").Id("instruction").Dot("UnmarshalWithDecoder").Call(Id("decoder")), + If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to unmarshal instruction as "+typeName+": %w"), Id("err"))), + ), + 
If(Id("accountIndicesData").Op("!=").Nil().Op("&&").Len(Id("accountIndicesData")).Op(">").Lit(0)).Block( + Id("indices").Op(",").Id("err").Op(":=").Id("instruction").Dot("UnmarshalAccountIndices").Call(Id("accountIndicesData")), + If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to unmarshal account indices: %w"), Id("err"))), + ), + Id("err").Op("=").Id("instruction").Dot("PopulateFromAccountIndices").Call(Id("indices"), Id("accountKeys")), + If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to populate accounts: %w"), Id("err"))), + ), + ), + Return(Id("instruction"), Nil()), + ) + } + switchBlock.Default().Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("unknown instruction discriminator: %s"), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + }) + }) + + // Generic ParseInstructionTyped function for type-safe parsing + code.Line().Line() + code.Comment("ParseInstructionTyped parses instruction data and returns a specific instruction type") + code.Comment("T must implement the Instruction interface") + code.Line() + code.Func().Id("ParseInstructionTyped"). + Types(Id("T").Id("Instruction")). + Params( + Id("instructionData").Index().Byte(), + Id("accountIndicesData").Index().Byte(), + Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey"), + ). + Params(Id("T"), Error()). 
+ BlockFunc(func(block *Group) { + block.Id("instruction").Op(",").Id("err").Op(":=").Id("ParseInstruction").Call(Id("instructionData"), Id("accountIndicesData"), Id("accountKeys")) + block.If(Id("err").Op("!=").Nil()).Block( + Return(Op("*").New(Id("T")), Id("err")), + ) + block.Id("typed").Op(",").Id("ok").Op(":=").Id("instruction").Assert(Id("T")) + block.If(Op("!").Id("ok")).Block( + Return(Op("*").New(Id("T")), Qual("fmt", "Errorf").Call(Lit("instruction is not of expected type"))), + ) + block.Return(Id("typed"), Nil()) + }) + + // Convenience function for parsing without accounts + code.Line().Line() + code.Comment("ParseInstructionWithoutAccounts parses instruction data without account information") + code.Line() + code.Func().Id("ParseInstructionWithoutAccounts"). + Params(Id("instructionData").Index().Byte()). + Params(Id("Instruction"), Error()). + Block( + Return(Id("ParseInstruction").Call(Id("instructionData"), Nil(), Index().Qual(PkgSolanaGo, "PublicKey").Op("{}"))), + ) + + // Convenience function for parsing with accounts + code.Line().Line() + code.Comment("ParseInstructionWithAccounts parses instruction data with account information") + code.Line() + code.Func().Id("ParseInstructionWithAccounts"). + Params( + Id("instructionData").Index().Byte(), + Id("accountIndicesData").Index().Byte(), + Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey"), + ). + Params(Id("Instruction"), Error()). 
+ Block( + Return(Id("ParseInstruction").Call(Id("instructionData"), Id("accountIndicesData"), Id("accountKeys"))), + ) + + return code, nil +} + +func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, error) { + code := Empty() + + // Check if the instruction name already ends with "instruction" (case-insensitive) + instructionNameLower := strings.ToLower(instruction.Name) + var typeName string + if strings.HasSuffix(instructionNameLower, "instruction") { + // Already has "instruction" suffix, don't add it again + typeName = tools.ToCamelUpper(instruction.Name) + } else { + // Add "Instruction" suffix + typeName = tools.ToCamelUpper(instruction.Name) + "Instruction" + } + + // Generate the instruction struct type + code.Type().Id(typeName).StructFunc(func(structGroup *Group) { + // Add fields for each instruction argument + for _, arg := range instruction.Args { + fieldType := genTypeName(arg.Ty) + if IsOption(arg.Ty) || IsCOption(arg.Ty) { + fieldType = Op("*").Add(fieldType) + } + structGroup.Id(tools.ToCamelUpper(arg.Name)).Add(fieldType).Tag(map[string]string{ + "json": arg.Name, + }) + } + + // Add fields for each instruction account + if len(instruction.Accounts) > 0 { + structGroup.Line().Comment("Accounts:") + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + // Add account field with metadata + fieldName := tools.ToCamelUpper(acc.Name) + structGroup.Id(fieldName).Qual(PkgSolanaGo, "PublicKey").Tag(map[string]string{ + "json": acc.Name, + }) + + // Add account metadata fields + if acc.Writable { + structGroup.Id(fieldName + "Writable").Bool().Tag(map[string]string{ + "json": acc.Name + "_writable", + }) + } + if acc.Signer { + structGroup.Id(fieldName + "Signer").Bool().Tag(map[string]string{ + "json": acc.Name + "_signer", + }) + } + if acc.Optional { + structGroup.Id(fieldName + "Optional").Bool().Tag(map[string]string{ + "json": acc.Name + "_optional", + }) + 
} + } + case *idl.IdlInstructionAccounts: + { + // Handle account groups (not fully implemented yet) + structGroup.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + } + }) + + // Generate GetDiscriminator method (required by Instruction interface) + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("GetDiscriminator"). + Params(). + Params(Index().Byte()). + Block( + Return(Id(FormatInstructionDiscriminatorName(tools.ToCamelUpper(instruction.Name))).Index(Op(":"))), + ) + + // Generate UnmarshalWithDecoder method + code.Line().Line() + code.Commentf("UnmarshalWithDecoder unmarshals the %s from Borsh-encoded bytes prefixed with its discriminator.", typeName).Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("UnmarshalWithDecoder"). + Params(Id("decoder").Op("*").Qual(PkgBinary, "Decoder")). + Params(Error()). + BlockFunc(func(block *Group) { + // Note: discriminator has already been read and validated by the parser + // Read instruction arguments + if len(instruction.Args) > 0 { + block.Var().Id("err").Error() + } + { + // Read the discriminator and check it against the expected value + block.Comment("Read the discriminator and check it against the expected value:") + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + block.If(Err().Op("!=").Nil()).Block( + Return(Qual("fmt", "Errorf").Call(Lit("failed to read instruction discriminator for %s: %w"), Lit(typeName), Err())), + ) + block.If(Id("discriminator").Op("!=").Id(FormatInstructionDiscriminatorName(tools.ToCamelUpper(instruction.Name)))).Block( + Return( + Qual("fmt", "Errorf").Call( + Lit("instruction discriminator mismatch for %s: expected %s, got %s"), + Lit(typeName), + Id(FormatInstructionDiscriminatorName(tools.ToCamelUpper(instruction.Name))), + Id("discriminator"), + ), + ), + ) + } + for _, arg := range instruction.Args { + fieldName := tools.ToCamelUpper(arg.Name) + block.Commentf("Deserialize 
`%s`:", fieldName) + + if IsOption(arg.Ty) || IsCOption(arg.Ty) { + var optionalityReaderName string + switch { + case IsOption(arg.Ty): + optionalityReaderName = "ReadOption" + case IsCOption(arg.Ty): + optionalityReaderName = "ReadCOption" + } + + block.BlockFunc(func(optGroup *Group) { + optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call() + optGroup.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + optGroup.If(Id("ok")).Block( + List(Err()).Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(fieldName)), + If(Err().Op("!=").Nil()).Block( + Return(Err()), + ), + ) + }) + } else { + block.List(Err()).Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(fieldName)) + block.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + } + } + + // Note: Accounts are not typically serialized in instruction data + // They are passed as part of the transaction's account metas + // This method only deserializes the instruction arguments + + block.Return(Nil()) + }) + + // Generate account-related methods if instruction has accounts + if len(instruction.Accounts) > 0 { + // Generate UnmarshalAccountIndices method + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("UnmarshalAccountIndices"). + Params(Id("buf").Index().Byte()). + Params(Index().Uint8(), Error()). 
+ BlockFunc(func(block *Group) { + block.Comment("UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes") + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("buf")) + block.Id("indices").Op(":=").Make(Index().Uint8(), Lit(0)) + block.Id("index").Op(":=").Uint8().Call(Lit(0)) + block.Var().Id("err").Error() + + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + block.Commentf("Decode from %s account index", acc.Name) + block.Id("index").Op("=").Uint8().Call(Lit(0)) + block.List(Err()).Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("index")) + block.If(Err().Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to decode %s account index: %w"), Lit(acc.Name), Err())), + ) + block.Id("indices").Op("=").Append(Id("indices"), Id("index")) + } + case *idl.IdlInstructionAccounts: + { + block.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + + block.Return(Id("indices"), Nil()) + }) + + // Generate PopulateFromAccountIndices method + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("PopulateFromAccountIndices"). + Params(Id("indices").Index().Uint8(), Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey")). + Params(Error()). 
+ BlockFunc(func(block *Group) { + block.Comment("PopulateFromAccountIndices sets account public keys from indices and account keys array") + + // Count expected accounts + expectedAccountCount := 0 + for _, account := range instruction.Accounts { + switch account.(type) { + case *idl.IdlInstructionAccount: + expectedAccountCount++ + } + } + + block.If(Len(Id("indices")).Op("!=").Lit(expectedAccountCount)).Block( + Return(Qual("fmt", "Errorf").Call(Lit("mismatch between expected accounts (%d) and provided indices (%d)"), Lit(expectedAccountCount), Len(Id("indices")))), + ) + + block.Id("indexOffset").Op(":=").Lit(0) + + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + fieldName := tools.ToCamelUpper(acc.Name) + block.Commentf("Set %s account from index", acc.Name) + block.If(Id("indices").Index(Id("indexOffset")).Op(">=").Uint8().Call(Len(Id("accountKeys")))).Block( + Return(Qual("fmt", "Errorf").Call(Lit("account index %d for %s is out of bounds (max: %d)"), Id("indices").Index(Id("indexOffset")), Lit(acc.Name), Len(Id("accountKeys")).Op("-").Lit(1))), + ) + block.Id("obj").Dot(fieldName).Op("=").Id("accountKeys").Index(Id("indices").Index(Id("indexOffset"))) + block.Id("indexOffset").Op("++") + } + case *idl.IdlInstructionAccounts: + { + block.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + + block.Return(Nil()) + }) + + // Generate GetAccountKeys method + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("GetAccountKeys"). + Params(). + Params(Index().Qual(PkgSolanaGo, "PublicKey")). 
+ BlockFunc(func(block *Group) { + block.Id("keys").Op(":=").Make(Index().Qual(PkgSolanaGo, "PublicKey"), Lit(0)) + + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + fieldName := tools.ToCamelUpper(acc.Name) + block.Id("keys").Op("=").Append(Id("keys"), Id("obj").Dot(fieldName)) + } + case *idl.IdlInstructionAccounts: + { + block.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + + block.Return(Id("keys")) + }) + } else { + // Generate empty implementations for instructions without accounts + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("UnmarshalAccountIndices"). + Params(Id("buf").Index().Byte()). + Params(Index().Uint8(), Error()). + Block( + Return(Index().Uint8().Op("{}"), Nil()), + ) + + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("PopulateFromAccountIndices"). + Params(Id("indices").Index().Uint8(), Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey")). + Params(Error()). + Block( + Return(Nil()), + ) + + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("GetAccountKeys"). + Params(). + Params(Index().Qual(PkgSolanaGo, "PublicKey")). + Block( + Return(Index().Qual(PkgSolanaGo, "PublicKey").Op("{}")), + ) + } + + // Generate Unmarshal method + code.Line().Line() + code.Commentf("Unmarshal unmarshals the %s from Borsh-encoded bytes prefixed with the discriminator.", typeName).Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("Unmarshal"). + Params(Id("buf").Index().Byte()). + Params(Error()). 
+ BlockFunc(func(block *Group) { + block.Var().Id("err").Error() + block.List(Err()).Op("=").Id("obj").Dot("UnmarshalWithDecoder").Call( + Qual(PkgBinary, "NewBorshDecoder").Call(Id("buf")), + ) + block.If(Err().Op("!=").Nil()).Block( + Return( + Qual("fmt", "Errorf").Call( + Lit("error while unmarshaling "+typeName+": %w"), + Err(), + ), + ), + ) + block.Return(Nil()) + }) + + // Generate Unmarshal function + code.Line().Line() + code.Commentf("Unmarshal%s unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator.", typeName).Line() + code.Func().Id("Unmarshal"+typeName). + Params(Id("buf").Index().Byte()). + Params(Op("*").Id(typeName), Error()). + BlockFunc(func(block *Group) { + block.Id("obj").Op(":=").New(Id(typeName)) + block.Var().Id("err").Error() + block.List(Err()).Op("=").Id("obj").Dot("Unmarshal").Call(Id("buf")) + block.If(Err().Op("!=").Nil()).Block( + Return(Nil(), Err()), + ) + block.Return(Id("obj"), Nil()) + }) + + return code, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/is.go b/cmd/generate-bindings/solana/anchor-go/generator/is.go new file mode 100644 index 00000000..64d4cad7 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/is.go @@ -0,0 +1,100 @@ +package generator + +import "github.com/gagliardetto/anchor-go/idl/idltype" + +func IsOption(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Option: + return true + default: + return false + } +} + +func IsCOption(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.COption: + return true + default: + return false + } +} + +func IsDefined(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Defined: + return true + default: + return false + } +} + +func IsVec(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Vec: + return true + default: + return false + } +} + +func IsArray(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Array: + return true + default: + return false + } 
+} + +func IsIDLTypeKind(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Bool: + return true + case *idltype.U8: + return true + case *idltype.I8: + return true + case *idltype.U16: + return true + case *idltype.I16: + return true + case *idltype.U32: + return true + case *idltype.I32: + return true + case *idltype.F32: + return true + case *idltype.U64: + return true + case *idltype.I64: + return true + case *idltype.F64: + return true + case *idltype.U128: + return true + case *idltype.I128: + return true + case *idltype.U256: + return true + case *idltype.I256: + return true + case *idltype.Bytes: + return true + case *idltype.String: + return true + case *idltype.Pubkey: + return true + default: + return false + } +} + +func IsBool(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Bool: + return true + default: + return false + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go new file mode 100644 index 00000000..9a5a11e2 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go @@ -0,0 +1,381 @@ +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/anchor-go/tools" +) + +func gen_MarshalWithEncoder_struct( + idl_ *idl.Idl, + withDiscriminator bool, + receiverTypeName string, + discriminatorName string, + fields idl.IdlDefinedFields, + checkNil bool, +) Code { + code := Empty() + { + // Declare MarshalWithEncoder + code.Func().Params(Id("obj").Id(receiverTypeName)).Id("MarshalWithEncoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("encoder").Op("*").Qual(PkgBinary, "Encoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). 
+ BlockFunc(func(body *Group) { + // Body: + if withDiscriminator && discriminatorName != "" { + body.Comment("Write account discriminator:") + body.Err().Op("=").Id("encoder").Dot("WriteBytes").Call(Id(discriminatorName).Index(Op(":")), False()) + body.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + } + switch fields := fields.(type) { + case idl.IdlDefinedFieldsNamed: + gen_marshal_DefinedFieldsNamed( + body, + fields, + checkNil, + func(field idl.IdlField) *Statement { + return Id("obj").Dot(tools.ToCamelUpper(field.Name)) + }, + "encoder", + false, // returnNilErr + func(field idl.IdlField) string { + return tools.ToCamelUpper(field.Name) + }, + ) + case idl.IdlDefinedFieldsTuple: + convertedFields := tupleToFieldsNamed(fields) + gen_marshal_DefinedFieldsNamed( + body, + convertedFields, + checkNil, + func(field idl.IdlField) *Statement { + return Id("obj").Dot(tools.ToCamelUpper(field.Name)) + }, + "encoder", + false, // returnNilErr + func(field idl.IdlField) string { + return tools.ToCamelUpper(field.Name) + }, + ) + case nil: + // No fields, just an empty struct. + // TODO: should we panic here? + default: + panic(fmt.Sprintf("unexpected fields type: %T", fields)) + } + + body.Return(Nil()) + }) + } + { + code.Line().Line() + // also generate a + // func (obj ) Marshal() ([]byte, error) { + // return obj.MarshalWithEncoder(bin.NewBorshEncoder(buf)) + // } + // func (obj ) Marshal() ([]byte, error) { + // buf := new(bytes.Buffer) + // enc := bin.NewBorshEncoder(buf) + // err := enc.Encode(meta) + // if err != nil { + // return nil, err + // } + // return buf.Bytes(), nil + // } + code.Func().Params(Id("obj").Id(receiverTypeName)).Id("Marshal"). + Params( + ListFunc(func(results *Group) { + // no parameters + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Index().Byte() + results.Error() + }), + ). 
+ BlockFunc(func(body *Group) { + // Body: + body.Id("buf").Op(":=").Qual("bytes", "NewBuffer").Call(Nil()) + body.Id("encoder").Op(":=").Qual(PkgBinary, "NewBorshEncoder").Call(Id("buf")) + body.Err().Op(":=").Id("obj").Dot("MarshalWithEncoder").Call(Id("encoder")) + body.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call( + Lit("error while encoding "+receiverTypeName+": %w"), + Err(), + ), + ), + ) + body.Return( + Id("buf").Dot("Bytes").Call(), + Nil(), + ) + }) + } + + return code +} + +func gen_marshal_DefinedFieldsNamed( + body *Group, + fields idl.IdlDefinedFieldsNamed, + checkNil bool, + nameFormatter func(field idl.IdlField) *Statement, + encoderVariableName string, + returnNilErr bool, + traceNameFormatter func(field idl.IdlField) string, +) { + for _, field := range fields { + exportedArgName := traceNameFormatter(field) + if IsOption(field.Ty) || IsCOption(field.Ty) { + body.Commentf("Serialize `%s` (optional):", exportedArgName) + } else { + body.Commentf("Serialize `%s`:", exportedArgName) + } + + if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) { + switch field.Ty.(type) { + case *idltype.Defined: + enumTypeName := field.Ty.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + argBody.Err().Op(":=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)) + argBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ) + }) + case *idltype.Array: + enumTypeName := field.Ty.(*idltype.Array).Type.(*idltype.Defined).Name + // TODO: handle array length, which is defined in the type. 
+ body.BlockFunc(func(argBody *Group) { + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Len(nameFormatter(field)), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.Err().Op(":=").Id(formatEnumEncoderName(enumTypeName)).Call( + Id(encoderVariableName), + nameFormatter(field).Index(Id("i")), + ) + forBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ) + }, + ), + ) + }) + }) + case *idltype.Vec: + enumTypeName := field.Ty.(*idltype.Vec).Vec.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + argBody.Err().Op(":=").Id(encoderVariableName).Dot("WriteLength").Call( + Len(nameFormatter(field)), + ) + argBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while writing vector length: %w"), + Err(), + ), + ) + }, + ), + ) + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Len(nameFormatter(field)), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.Err().Op(":=").Id(formatEnumEncoderName(enumTypeName)).Call( + Id(encoderVariableName), + nameFormatter(field).Index(Id("i")), + ) + forBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ) + }, + ), + ) + }) + }) + } + } else { + if IsOption(field.Ty) || IsCOption(field.Ty) { + var optionalityWriterName string + if IsOption(field.Ty) { + optionalityWriterName = "WriteOption" + } else { + optionalityWriterName = 
"WriteCOption" + } + if checkNil { + body.BlockFunc(func(optGroup *Group) { + // if nil: + optGroup.If(nameFormatter(field).Op("==").Nil()).Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(False()), + If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while encoding optionality: %w"), + Err(), + ), + ) + }, + ), + ), + ).Else().Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()), + If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while encoding optionality: %w"), + Err(), + ), + ) + }, + ), + ), + Err().Op("=").Id(encoderVariableName).Dot("Encode").Call(nameFormatter(field)), + If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ), + ) + }) + } else { + body.BlockFunc(func(optGroup *Group) { + // TODO: make optional fields of accounts a pointer. 
+ // Write as if not nil: + optGroup.Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()) + optGroup.If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while encoding optionality: %w"), + Err(), + ), + ) + }, + ), + ) + optGroup.Err().Op("=").Id(encoderVariableName).Dot("Encode").Call(nameFormatter(field)) + optGroup.If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ) + }) + } + } else { + body.Err().Op("=").Id(encoderVariableName).Dot("Encode").Call(nameFormatter(field)) + body.If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ) + } + } + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/tests.go b/cmd/generate-bindings/solana/anchor-go/generator/tests.go new file mode 100644 index 00000000..96fe02f9 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/tests.go @@ -0,0 +1,17 @@ +package generator + +import ( + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) gen_tests() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. 
DO NOT EDIT.") + file.HeaderComment("This file contains tests.") + { + } + return &OutputFile{ + Name: "tests_test.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/tools.go b/cmd/generate-bindings/solana/anchor-go/generator/tools.go new file mode 100644 index 00000000..15065ed3 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/tools.go @@ -0,0 +1,84 @@ +package generator + +import ( + "os" + "path" + + . "github.com/dave/jennifer/jen" +) + +const ( + PkgBinary = "github.com/gagliardetto/binary" + PkgCRE = "github.com/smartcontractkit/cre-sdk-go/cre" + PkgPb = "github.com/smartcontractkit/chainlink-protos/cre/go/values/pb" + PkgPb2 = "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + PkgSolanaGo = "github.com/gagliardetto/solana-go" + + PkgSolanaGoText = "github.com/gagliardetto/solana-go/text" + PkgAnchorGoErrors = "github.com/gagliardetto/anchor-go/errors" + PkgBig = "math/big" + // TODO: use or remove this: + PkgTreeout = "github.com/gagliardetto/treeout" + PkgFormat = "github.com/gagliardetto/solana-go/text/format" + PkgGoFuzz = "github.com/gagliardetto/gofuzz" + PkgTestifyRequire = "github.com/stretchr/testify/require" + PkgSolanaCre = "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana" + PkgRealSolanaCre = "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + PkgBindings = "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/bindings" + PkgSolanaTypes = "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/cre-sdk-go/types" + PkgIdl = "github.com/gagliardetto/anchor-go/idl" + PkgAnchorIdlCodec = "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/cre-sdk-go/anchorcodec" + PkgJson = "encoding/json" +) + +func WriteFile(outDir string, assetFileName string, file *File) error { + // Save Go assets: + assetFilepath := path.Join(outDir, assetFileName) + + // Create file Golang 
file: + goFile, err := os.Create(assetFilepath) + if err != nil { + panic(err) + } + defer goFile.Close() + + // Write generated Golang to file: + return file.Render(goFile) +} + +func DoGroup(f func(*Group)) *Statement { + g := &Group{} + g.CustomFunc(Options{ + Multi: false, + }, f) + s := newStatement() + *s = append(*s, g) + return s +} + +func DoGroupMultiline(f func(*Group)) *Statement { + g := &Group{} + g.CustomFunc(Options{ + Multi: true, + }, f) + s := newStatement() + *s = append(*s, g) + return s +} + +func ListMultiline(f func(*Group)) *Statement { + g := &Group{} + g.CustomFunc(Options{ + Multi: true, + Separator: ",", + Open: "", + Close: " ", + }, f) + s := newStatement() + *s = append(*s, g) + return s +} + +func newStatement() *Statement { + return &Statement{} +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/types.go b/cmd/generate-bindings/solana/anchor-go/generator/types.go new file mode 100644 index 00000000..ba25514e --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/types.go @@ -0,0 +1,411 @@ +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +// genfile_types generates the file `types.go`. +func (g *Generator) genfile_types() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains parsers for the types defined in the IDL.") + + { + for index, typ := range g.idl.Types { + code, err := g.gen_IDLTypeDef(typ) + if err != nil { + return nil, fmt.Errorf("error generating type %d: %w", index, err) + } + file.Add(code) + } + } + + return &OutputFile{ + Name: "types.go", + File: file, + }, nil +} + +// `def.Type` is `IDLTypeDefTy` (which is an interface): +// either `IDLTypeDefTyEnum` or `IDLTypeDefTyStruct`. 
+func (g *Generator) gen_IDLTypeDef(def idl.IdlTypeDef) (Code, error) { + switch vv := def.Ty.(type) { + case *idl.IdlTypeDefTyStruct: + return g.gen_IDLTypeDefTyStruct(def.Name, def.Docs, *vv, false) + case *idl.IdlTypeDefTyEnum: + return g.gen_IDLTypeDefTyEnum(def.Name, def.Docs, *vv) + default: + panic(fmt.Errorf("unhandled type: %T", vv)) + } +} + +func (g *Generator) gen_IDLTypeDefTyEnum(name string, docs []string, typ idl.IdlTypeDefTyEnum) (Code, error) { + if typ.Variants.IsAllSimple() { + return g.gen_simpleEnum(name, docs, typ) + } + return g.gen_complexEnum(name, docs, typ) +} + +func (g *Generator) gen_simpleEnum(name string, docs []string, typ idl.IdlTypeDefTyEnum) (Code, error) { + st := newStatement() + + code := newStatement() + enumTypeName := tools.ToCamelUpper(name) + + addComments(code, docs) + { + code.Type().Id(enumTypeName).Qual(PkgBinary, "BorshEnum") + code.Line().Const().Parens(DoGroup(func(gr *Group) { + for variantIndex, variant := range typ.Variants { + // TODO: enum variants should have docs too. + // for docIndex, doc := range variant.Docs { + // if docIndex == 0 { + // gr.Line() + // } + // gr.Comment(doc).Line() + // } + + gr.Id(formatSimpleEnumVariantName(variant.Name, enumTypeName)).Add(func() Code { + if variantIndex == 0 { + return Id(enumTypeName).Op("=").Iota() + } + return nil + }()).Line() + } + // TODO: check for fields, etc. + })) + + // Generate stringer for the uint8 enum values: + code.Line().Line().Func().Params(Id("value").Id(enumTypeName)).Id("String"). + Params(). + Params(String()). 
+ BlockFunc(func(body *Group) { + body.Switch(Id("value")).BlockFunc(func(switchBlock *Group) { + for _, variant := range typ.Variants { + switchBlock.Case(Id(formatSimpleEnumVariantName(variant.Name, enumTypeName))).Line().Return(Lit(variant.Name)) + } + switchBlock.Default().Line().Return(Lit("")) + }) + }) + st.Add(code.Line()) + } + return st, nil +} + +func addComments(code *Statement, docs []string) { + for _, doc := range docs { + code.Line() + code.Comment(doc) + } + if len(docs) > 0 { + code.Line() + } +} + +func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeDefTyEnum) (Code, error) { + st := newStatement() + + code := newStatement() + enumTypeName := tools.ToCamelUpper(name) + + // Add comments for the enum type: + addComments(code, docs) + { + register_TypeName_as_ComplexEnum(name) + containerName := formatEnumContainerName(enumTypeName) + interfaceMethodName := formatInterfaceMethodName(enumTypeName) + + // Declare the interface of the enum type: + code.Commentf("The %q interface for the %q complex enum.", interfaceMethodName, enumTypeName).Line() + code.Type().Id(enumTypeName).Interface( + Id(interfaceMethodName).Call(), + ).Line().Line() + + // Declare the enum variants container (non-exported, used internally) + code.Type().Id(containerName).StructFunc( + func(structGroup *Group) { + structGroup.Id("Enum").Qual(PkgBinary, "BorshEnum").Tag(map[string]string{ + "bin": "enum", + }) + + for _, variant := range typ.Variants { + structGroup.Id(tools.ToCamelUpper(variant.Name)).Id(formatComplexEnumVariantTypeName(enumTypeName, variant.Name)) + } + }, + ).Line().Line() + + // Declare parser function for the enum type: + code.Func().Id(formatEnumParserName(enumTypeName)).Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("decoder").Op("*").Qual(PkgBinary, "Decoder") + }), + ).Params( + ListFunc(func(results *Group) { + // Results: + results.Id(enumTypeName) + results.Error() + }), + ). 
+ BlockFunc(func(body *Group) { + enumName := enumTypeName + body.BlockFunc(func(argBody *Group) { + argBody.List(Id("tmp")).Op(":=").New(Id(formatEnumContainerName(enumName))) + + argBody.Err().Op(":=").Id("decoder").Dot("Decode").Call(Id("tmp")) + + argBody.If( + Err().Op("!=").Nil(), + ).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed parsing "+enumTypeName+": %w"), Err()), + ), + ) + + argBody.Switch(Id("tmp").Dot("Enum")). + BlockFunc(func(switchGroup *Group) { + interfaceType := g.idl.Types.ByName(enumName) + + for variantIndex, variant := range interfaceType.Ty.(*idl.IdlTypeDefTyEnum).Variants { + variantTypeNameComplex := formatComplexEnumVariantTypeName(enumName, variant.Name) + + if variant.IsSimple() { + // TODO: the actual value is not important; + // what's important is the type. + switchGroup.Case(Lit(variantIndex)). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Parens(Op("*").Id(variantTypeNameComplex)). + Parens(Op("&").Id("tmp").Dot("Enum")), + Nil(), + ) + }) + } else { + switchGroup.Case(Lit(variantIndex)). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Op("&").Id("tmp").Dot(tools.ToCamelUpper(variant.Name)), + Nil(), + ) + }) + } + } + switchGroup.Default(). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit(enumTypeName+": unknown enum index: %v"), Id("tmp").Dot("Enum")), + ) + }) + }) + }) + }).Line().Line() + + // Declare the marshaler for the enum type: + code.Func().Id(formatEnumEncoderName(enumTypeName)).Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("encoder").Op("*").Qual(PkgBinary, "Encoder") + params.Id("value").Id(enumTypeName) + }), + ).Params( + ListFunc(func(results *Group) { + // Results: + results.Error() + }), + ). 
+ BlockFunc(func(body *Group) { + body.BlockFunc(func(argBody *Group) { + argBody.List(Id("tmp")).Op(":=").Id(formatEnumContainerName(enumTypeName)).Block() + argBody.Switch(Id("realvalue").Op(":=").Id("value").Op(".").Parens(Type())). + BlockFunc(func(switchGroup *Group) { + // TODO: maybe it's from idl.Accounts ??? + interfaceType := g.idl.Types.ByName(enumTypeName) + for variantIndex, variant := range interfaceType.Ty.(*idl.IdlTypeDefTyEnum).Variants { + variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) + + switchGroup.Case(Op("*").Id(variantTypeNameStruct)). + BlockFunc(func(caseGroup *Group) { + caseGroup.Id("tmp").Dot("Enum").Op("=").Lit(variantIndex) + caseGroup.Id("tmp").Dot(tools.ToCamelUpper(variant.Name)).Op("=").Op("*").Id("realvalue") + }) + } + }) + + argBody.Return(Id("encoder").Dot("Encode").Call(Id("tmp"))) + }) + }).Line().Line() + + for _, variant := range typ.Variants { + // Name of the variant type if the enum is a complex enum (i.e. enum variants are inline structs): + variantTypeNameComplex := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) + + // Declare the enum variant types: + if variant.IsSimple() { + // TODO: make the name {variantTypeName}_{interface_name} ??? + code.Type().Id(variantTypeNameComplex).Uint8().Line().Line() + } else if variant.Fields.IsSome() { + code.Commentf("Variant %q of enum %q", variant.Name, enumTypeName).Line() + code.Type().Id(variantTypeNameComplex).StructFunc( + func(structGroup *Group) { + switch fields := variant.Fields.Unwrap().(type) { + case idl.IdlDefinedFieldsNamed: + for _, variantField := range fields { + optionality := IsOption(variantField.Ty) || IsCOption(variantField.Ty) + structGroup.Add(genField(variantField, optionality)). 
+ Add(func() Code { + tagMap := map[string]string{} + if IsOption(variantField.Ty) { + tagMap["bin"] = "optional" + } + if IsCOption(variantField.Ty) { + tagMap["bin"] = "coption" + } + // add json tag: + tagMap["json"] = tools.ToCamelLower(variantField.Name) + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + case idl.IdlDefinedFieldsTuple: + for itemIndex, tupleItem := range fields { + optionality := IsOption(tupleItem) || IsCOption(tupleItem) + tupleItemName := FormatTupleItemName(itemIndex) + structGroup.Add(genFieldNamed(tupleItemName, tupleItem, optionality)). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(tupleItem) { + tagMap["bin"] = "optional" + } + if IsCOption(tupleItem) { + tagMap["bin"] = "coption" + } + // add json tag: + tagMap["json"] = tools.ToCamelLower(tupleItemName) + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + default: + panic("not handled: " + spew.Sdump(variant.Fields)) + } + }, + ).Line().Line() + } + + if variant.IsSimple() { + // Declare MarshalWithEncoder + code.Line().Line().Func().Params(Id("obj").Id(variantTypeNameComplex)).Id("MarshalWithEncoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("encoder").Op("*").Qual(PkgBinary, "Encoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). + BlockFunc(func(body *Group) { + body.Return(Nil()) + }) + code.Line().Line() + + // Declare UnmarshalWithDecoder + code.Func().Params(Id("obj").Op("*").Id(variantTypeNameComplex)).Id("UnmarshalWithDecoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("decoder").Op("*").Qual(PkgBinary, "Decoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). 
+ BlockFunc(func(body *Group) { + body.Return(Nil()) + }) + code.Line().Line() + } else if variant.Fields.IsSome() { + switch fields := variant.Fields.Unwrap().(type) { + case idl.IdlDefinedFieldsNamed: + // Declare MarshalWithEncoder: + code.Line().Line().Add( + gen_MarshalWithEncoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + true, + )) + + // Declare UnmarshalWithDecoder + code.Line().Line().Add( + gen_UnmarshalWithDecoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + )) + code.Line().Line() + case idl.IdlDefinedFieldsTuple: + // TODO: handle tuples + // Declare MarshalWithEncoder: + code.Line().Line().Add( + gen_MarshalWithEncoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + true, + )) + + // Declare UnmarshalWithDecoder + code.Line().Line().Add( + gen_UnmarshalWithDecoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + )) + code.Line().Line() + default: + panic("not handled: " + spew.Sdump(variant.Fields)) + } + } + + // Declare the method to implement the parent enum interface: + if variant.IsSimple() { + code.Func().Params(Id("_").Op("*").Id(variantTypeNameComplex)).Id(interfaceMethodName).Params().Block().Line().Line() + } else { + code.Func().Params(Id("_").Op("*").Id(variantTypeNameComplex)).Id(interfaceMethodName).Params().Block().Line().Line() + } + } + + st.Add(code.Line().Line()) + } + return st, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go new file mode 100644 index 00000000..1ee4f80d --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go @@ -0,0 +1,377 @@ +package generator + +import ( + "fmt" + + . 
"github.com/dave/jennifer/jen"
	"github.com/gagliardetto/anchor-go/idl"
	"github.com/gagliardetto/anchor-go/idl/idltype"
	"github.com/gagliardetto/anchor-go/tools"
)

// formatComplexEnumVariantTypeName builds the Go type name for one
// data-carrying variant of a complex enum: "<Enum>_<Variant>".
func formatComplexEnumVariantTypeName(enumTypeName string, variantName string) string {
	return fmt.Sprintf("%s_%s", tools.ToCamelUpper(enumTypeName), tools.ToCamelUpper(variantName))
}

// formatSimpleEnumVariantName builds the Go const name for one variant of
// a simple enum: "<Enum>_<Variant>".
// NOTE(review): the parameter order (variantName first, enumTypeName
// second) is the reverse of formatComplexEnumVariantTypeName even though
// the output shape is identical — existing callers depend on this order,
// so it is kept as-is; worth unifying in a follow-up.
func formatSimpleEnumVariantName(variantName string, enumTypeName string) string {
	return fmt.Sprintf("%s_%s", tools.ToCamelUpper(enumTypeName), tools.ToCamelUpper(variantName))
}

// FormatTupleItemName names the i-th field of a tuple variant ("V0",
// "V1", …) so tuples can be represented as named struct fields.
func FormatTupleItemName(index int) string {
	return tools.ToCamelUpper(fmt.Sprintf("V%d", index))
}

// formatEnumContainerName names the unexported container struct that
// holds the Borsh discriminant plus one field per complex-enum variant.
func formatEnumContainerName(enumTypeName string) string {
	return tools.ToCamelLower(enumTypeName) + "EnumContainer"
}

// formatInterfaceMethodName names the unexported marker method
// ("is<Enum>") that ties each variant type to its enum interface.
func formatInterfaceMethodName(enumTypeName string) string {
	return "is" + tools.ToCamelUpper(enumTypeName)
}

// formatDiscriminatorName builds the name of a generated discriminator
// constant, e.g. "Account_<Name>" / "Event_<Name>" / "Instruction_<Name>".
func formatDiscriminatorName(kind string, exportedAccountName string) string {
	// trim prefix or suffix "Account" or "Event" from exportedAccountName
	exportedAccountName = tools.ToCamelUpper(exportedAccountName)

	// // TODO: sometimes there's accounts/events like this:
	// // - "Fund"
	// // - "FundAccount"
	// // This will create a name collision and fail to compile because
	// // we remove the "Account" or "Event" suffix from the second one,
	// // so there's a duplicate name "Fund".
+ // exportedAccountName = strings.TrimSuffix(exportedAccountName, "Account") + // exportedAccountName = strings.TrimSuffix(exportedAccountName, "Event") + // exportedAccountName = strings.TrimPrefix(exportedAccountName, "Account") + // exportedAccountName = strings.TrimPrefix(exportedAccountName, "Event") + + return kind + "_" + tools.ToCamelUpper(exportedAccountName) +} + +func FormatAccountDiscriminatorName(exportedAccountName string) string { + return formatDiscriminatorName("Account", exportedAccountName) +} + +func FormatEventDiscriminatorName(exportedEventName string) string { + return formatDiscriminatorName("Event", exportedEventName) +} + +func FormatInstructionDiscriminatorName(exportedInstructionName string) string { + return formatDiscriminatorName("Instruction", exportedInstructionName) +} + +func formatBuilderFuncName(insExportedName string) string { + return "New" + insExportedName + "InstructionBuilder" +} + +func formatEnumParserName(enumTypeName string) string { + return "Decode" + enumTypeName +} + +func formatEnumEncoderName(enumTypeName string) string { + return "Encode" + enumTypeName +} + +func gen_UnmarshalWithDecoder_struct( + idl_ *idl.Idl, + withDiscriminator bool, + receiverTypeName string, + discriminatorName string, + fields idl.IdlDefinedFields, +) Code { + code := Empty() + { + // Declare UnmarshalWithDecoder + code.Func().Params(Id("obj").Op("*").Id(receiverTypeName)).Id("UnmarshalWithDecoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("decoder").Op("*").Qual(PkgBinary, "Decoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). 
+ BlockFunc(func(body *Group) { + // Body: + if withDiscriminator && discriminatorName != "" { + body.Comment("Read and check account discriminator:") + body.BlockFunc(func(discReadBody *Group) { + discReadBody.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + discReadBody.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + discReadBody.If(Op("!").Id("discriminator").Dot("Equal").Call(Id(discriminatorName).Index(Op(":")))).Block( + Return( + Qual("fmt", "Errorf").Call( + Line().Lit("wrong discriminator: wanted %s, got %s"), + Line().Id(discriminatorName).Index(Op(":")), + Line().Qual("fmt", "Sprint").Call(Id("discriminator").Index(Op(":"))), + ), + ), + ) + }) + } + + switch fields := fields.(type) { + case idl.IdlDefinedFieldsNamed: + gen_unmarshal_DefinedFieldsNamed(body, fields) + case idl.IdlDefinedFieldsTuple: + convertedFields := tupleToFieldsNamed(fields) + gen_unmarshal_DefinedFieldsNamed(body, convertedFields) + case nil: + // No fields, just an empty struct. + // TODO: should we panic here? + default: + panic(fmt.Sprintf("unexpected fields type: %T", fields)) + } + + body.Return(Nil()) + }) + } + { + code.Line().Line() + // func (obj *) Unmarshal(buf []byte) (err error) { + // return obj.UnmarshalWithDecoder(bin.NewBorshDecoder(buf)) + // } + code.Func().Params(Id("obj").Op("*").Id(receiverTypeName)).Id("Unmarshal"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("buf").Index().Byte() + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + body.Err().Op(":=").Id("obj").Dot("UnmarshalWithDecoder").Call( + Qual(PkgBinary, "NewBorshDecoder").Call(Id("buf")), + ) + body.If(Err().Op("!=").Nil()).Block( + // If there was an error, return it. + Return( + Qual("fmt", "Errorf").Call( + Lit("error while unmarshaling "+receiverTypeName+": %w"), + Err(), + ), + ), + ) + body.Return( + Nil(), // No error. 
+ ) + }) + } + { + code.Line().Line() + // func Unmarshal(buf []byte) (, error) { + // obj := new() + // err := obj.Unmarshal(buf) + // if err != nil { + // return nil, err + // } + // return obj, nil + // } + code.Func().Id("Unmarshal" + receiverTypeName). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("buf").Index().Byte() + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Op("*").Id(receiverTypeName) + results.Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + body.Id("obj").Op(":=").New(Id(receiverTypeName)) + body.Err().Op(":=").Id("obj").Dot("Unmarshal").Call(Id("buf")) + body.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Err(), + ), + ) + body.Return( + Id("obj"), + Nil(), // No error. + ) + }) + } + return code +} + +func tupleToFieldsNamed( + tuple idl.IdlDefinedFieldsTuple, +) idl.IdlDefinedFieldsNamed { + fields := make(idl.IdlDefinedFieldsNamed, len(tuple)) + for i, item := range tuple { + tupleItemName := FormatTupleItemName(i) + fields[i] = idl.IdlField{ + Name: tupleItemName, + Ty: item, + } + } + return fields +} + +func gen_unmarshal_DefinedFieldsNamed( + body *Group, + fields idl.IdlDefinedFieldsNamed, +) { + for _, field := range fields { + exportedArgName := tools.ToCamelUpper(field.Name) + if IsOption(field.Ty) || IsCOption(field.Ty) { + body.Commentf("Deserialize `%s` (optional):", exportedArgName) + } else { + body.Commentf("Deserialize `%s`:", exportedArgName) + } + + if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) { + // TODO: this assumes this cannot be an option; + // - check whether this is an option? 
+ switch field.Ty.(type) { + case *idltype.Defined: + enumName := field.Ty.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + { + argBody.Var().Err().Error() + argBody.List( + Id("obj").Dot(exportedArgName), + Err(), + ).Op("=").Id(formatEnumParserName(enumName)).Call(Id("decoder")) + } + argBody.If( + Err().Op("!=").Nil(), + ).Block( + Return(Err()), + ) + }) + case *idltype.Array: + enumTypeName := field.Ty.(*idltype.Array).Type.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + // Read the array items: + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Len(Id("obj").Dot(exportedArgName)), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.List( + Id("obj").Dot(exportedArgName).Index(Id("i")), + Err(), + ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")) + forBody.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ), + ), + ) + }) + }) + case *idltype.Vec: + enumTypeName := field.Ty.(*idltype.Vec).Vec.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + // Read the vector length: + argBody.List(Id("vecLen"), Err()).Op(":=").Id("decoder").Dot("ReadLength").Call() + argBody.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while reading vector length: %w"), + Err(), + ), + ), + ), + ) + // Create the vector: + argBody.Id("obj").Dot(exportedArgName).Op("=").Make(Index().Id(enumTypeName), Id("vecLen")) + // Read the vector items: + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Id("vecLen"), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.List( + Id("obj").Dot(exportedArgName).Index(Id("i")), + Err(), + ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")) + forBody.If(Err().Op("!=").Nil()).Block( + Return( + 
Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ), + ), + ) + }) + }) + } + } else { + if IsOption(field.Ty) || IsCOption(field.Ty) { + var optionalityReaderName string + switch { + case IsOption(field.Ty): + optionalityReaderName = "ReadOption" + case IsCOption(field.Ty): + optionalityReaderName = "ReadCOption" + } + + body.BlockFunc(func(optGroup *Group) { + // if nil: + optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call() + optGroup.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while reading optionality: %w"), + Err(), + ), + ), + ), + ) + optGroup.If(Id("ok")).Block( + Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(exportedArgName)), + If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ), + ), + ), + ) + }) + } else { + body.Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(exportedArgName)) + body.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ), + ), + ) + } + } + } +} diff --git a/cmd/generate-bindings/solana/bindgen.go b/cmd/generate-bindings/solana/bindgen.go new file mode 100644 index 00000000..d3a481de --- /dev/null +++ b/cmd/generate-bindings/solana/bindgen.go @@ -0,0 +1,344 @@ +package solana + +import ( + "flag" + "fmt" + "go/token" + "log/slog" + "os" + "os/exec" + "path" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" + bin "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/anchor-go/generator" +) + +const defaultProgramName = "myprogram" + +func GenerateBindings( + pathToIdl string, + programName string, + 
outputDir string, +) error { + if pathToIdl == "" { + panic("Please provide the path to the IDL file using the -idl flag") + } + if outputDir == "" { + panic("Please provide the output directory using the -output flag") + } + if err := os.MkdirAll(outputDir, 0o777); err != nil { + panic(fmt.Errorf("Failed to create output directory: %w", err)) + } + + slog.Info("Starting code generation", + "outputDir", outputDir, + "pathToIdl", pathToIdl, + ) + options := generator.GeneratorOptions{ + OutputDir: outputDir, + Package: programName, + ProgramName: programName, + } + parsedIdl, err := idl.ParseFromFilepath(pathToIdl) + if err != nil { + panic(err) + } + if parsedIdl == nil { + panic("Parsed IDL is nil, please check the IDL file path and format.") + } + if err := parsedIdl.Validate(); err != nil { + panic(fmt.Errorf("Invalid IDL: %w", err)) + } + { + { + if parsedIdl.Address != nil && !parsedIdl.Address.IsZero() { + // If the IDL has an address, use it as the program ID: + slog.Info("Using IDL address as program ID", "address", parsedIdl.Address.String()) + options.ProgramId = parsedIdl.Address + } else { + panic("Please ensure the IDL has a valid metadata.address field.") + } + } + parsedIdl.Metadata.Name = bin.ToSnakeForSighash(parsedIdl.Metadata.Name) + { + // check that the name is not a reserved keyword: + if parsedIdl.Metadata.Name != "" { + if tools.IsReservedKeyword(parsedIdl.Metadata.Name) { + slog.Warn("The IDL metadata.name is a reserved Go keyword: adding a suffix to avoid conflicts.", + "name", parsedIdl.Metadata.Name, + "reservedKeyword", token.Lookup(parsedIdl.Metadata.Name).String(), + ) + // Add a suffix to the name to avoid conflicts with Go reserved keywords: + parsedIdl.Metadata.Name += "_program" + } + if !tools.IsValidIdent(parsedIdl.Metadata.Name) { + // add a prefix to the name to avoid conflicts with Go reserved keywords: + parsedIdl.Metadata.Name = "my_" + parsedIdl.Metadata.Name + } + } + // if begins with + } + if programName == "" && 
parsedIdl.Metadata.Name == "" {
			// BUGFIX: this guard was `parsedIdl.Metadata.Name != ""`, which
			// panicked exactly when the IDL DID supply a usable fallback name
			// and let the truly-nameless case through. Panic only when neither
			// the -name flag nor the IDL metadata provides a name — which is
			// what the message below asks the user to fix.
			panic("Please provide a package name using the -name flag, or ensure the IDL has a valid metadata.name field.")
		}
		// The caller left the default placeholder name: prefer the IDL's own name.
		if programName == defaultProgramName && parsedIdl.Metadata.Name != "" {
			cleanedName := bin.ToSnakeForSighash(parsedIdl.Metadata.Name)
			options.Package = cleanedName
			options.ProgramName = cleanedName
			slog.Info("Using IDL metadata.name as package name", "packageName", cleanedName)
		}

		slog.Info("Parsed IDL successfully",
			"version", parsedIdl.Metadata.Version,
			"name", parsedIdl.Metadata.Name,
			"address", parsedIdl.Address,
			"programId", func() string {
				// Address is non-nil here: the code above panics when it is missing.
				if parsedIdl.Address.IsZero() {
					return "not provided"
				}
				return parsedIdl.Address.String()
			}(),
			"instructionsCount", len(parsedIdl.Instructions),
			"accountsCount", len(parsedIdl.Accounts),
			"eventsCount", len(parsedIdl.Events),
			"typesCount", len(parsedIdl.Types),
			"constantsCount", len(parsedIdl.Constants),
			"errorsCount", len(parsedIdl.Errors),
		)
	}
	gen := generator.NewGenerator(parsedIdl, &options)
	generatedFiles, err := gen.Generate()
	if err != nil {
		panic(err)
	}

	{
		for _, file := range generatedFiles.Files {
			{
				// Save assets:
				assetFilename := file.Name
				assetFilepath := path.Join(options.OutputDir, assetFilename)

				// Create file:
				goFile, err := os.Create(assetFilepath)
				if err != nil {
					panic(err)
				}

				slog.Info("Writing file",
					"filepath", assetFilepath,
					"name", file.Name,
					"modPath", options.ModPath,
				)
				err = file.File.Render(goFile)
				// BUGFIX: the original used `defer goFile.Close()` inside this
				// loop, which kept every generated file's handle open until
				// GenerateBindings returned and silently dropped Close errors
				// (Close can report write-flush failures). Close each file as
				// soon as it has been rendered and surface both errors.
				closeErr := goFile.Close()
				if err != nil {
					panic(err)
				}
				if closeErr != nil {
					panic(fmt.Errorf("closing %s: %w", assetFilepath, closeErr))
				}
			}
		}
		// executeCmd(outputDir, "go", "mod", "tidy")
		// executeCmd(outputDir, "go", "fmt")
		// executeCmd(outputDir, "go", "build", "-o", "/dev/null") // Just to ensure everything compiles.
+ slog.Info("Generation completed successfully", + "outputDir", options.OutputDir, + "modPath", options.ModPath, + "package", options.Package, + "programName", options.ProgramName, + ) + } + return nil +} + +// ignore +func main() { + var outputDir string + var programName string + var modPath string + var pathToIdl string + var programIDOverride solana.PublicKey + flag.Var(&programIDOverride, "program-id", "Program ID to use in the generated code (optional; must be a valid Solana public key). If not provided, it will be derived from the IDL metadata.address field if available.") + flag.StringVar(&outputDir, "output", "", "Directory to write the generated code to") + flag.StringVar(&programName, "name", defaultProgramName, "Name of the program for the generated code (optional; example: myprogram). If not provided, it will be derived from the IDL metadata.name field.") + flag.StringVar(&modPath, "mod-path", "", "Module path for the generated code (optional; example: github.com/gagliardetto/mysolana-program-go)") + flag.StringVar(&pathToIdl, "idl", "", "Path to the IDL file (required)") + var skipGoMod bool + flag.BoolVar(&skipGoMod, "no-go-mod", false, "Skip generating the go.mod file (useful for testing)") + flag.Parse() + if pathToIdl == "" { + panic("Please provide the path to the IDL file using the -idl flag") + } + if outputDir == "" { + panic("Please provide the output directory using the -output flag") + } + + if modPath == "" { + modPath = path.Join("github.com", "gagliardetto", "anchor-go", "generated") + slog.Info("Using default module path", "modPath", modPath) + } else { + slog.Info("Using provided module path", "modPath", modPath) + } + if err := os.MkdirAll(outputDir, 0o777); err != nil { + panic(fmt.Errorf("Failed to create output directory: %w", err)) + } + slog.Info("Starting code generation", + "outputDir", outputDir, + "modPath", modPath, + "pathToIdl", pathToIdl, + "programID", func() string { + if programIDOverride.IsZero() { + return "not 
provided" + } + return programIDOverride.String() + }(), + ) + + options := generator.GeneratorOptions{ + OutputDir: outputDir, + Package: programName, + ProgramName: programName, + ModPath: modPath, + SkipGoMod: skipGoMod, + } + if !programIDOverride.IsZero() { + options.ProgramId = &programIDOverride + slog.Info("Using provided program ID", "programID", programIDOverride.String()) + } + parsedIdl, err := idl.ParseFromFilepath(pathToIdl) + if err != nil { + panic(err) + } + if parsedIdl == nil { + panic("Parsed IDL is nil, please check the IDL file path and format.") + } + if err := parsedIdl.Validate(); err != nil { + panic(fmt.Errorf("Invalid IDL: %w", err)) + } + { + { + if parsedIdl.Address != nil && !parsedIdl.Address.IsZero() && options.ProgramId == nil { + // If the IDL has an address, use it as the program ID: + slog.Info("Using IDL address as program ID", "address", parsedIdl.Address.String()) + options.ProgramId = parsedIdl.Address + } + } + parsedIdl.Metadata.Name = bin.ToSnakeForSighash(parsedIdl.Metadata.Name) + { + // check that the name is not a reserved keyword: + if parsedIdl.Metadata.Name != "" { + if tools.IsReservedKeyword(parsedIdl.Metadata.Name) { + slog.Warn("The IDL metadata.name is a reserved Go keyword: adding a suffix to avoid conflicts.", + "name", parsedIdl.Metadata.Name, + "reservedKeyword", token.Lookup(parsedIdl.Metadata.Name).String(), + ) + // Add a suffix to the name to avoid conflicts with Go reserved keywords: + parsedIdl.Metadata.Name += "_program" + } + if !tools.IsValidIdent(parsedIdl.Metadata.Name) { + // add a prefix to the name to avoid conflicts with Go reserved keywords: + parsedIdl.Metadata.Name = "my_" + parsedIdl.Metadata.Name + } + } + // if begins with + } + if programName == "" && parsedIdl.Metadata.Name != "" { + panic("Please provide a package name using the -name flag, or ensure the IDL has a valid metadata.name field.") + } + if programName == defaultProgramName && parsedIdl.Metadata.Name != "" { + cleanedName 
:= bin.ToSnakeForSighash(parsedIdl.Metadata.Name) + options.Package = cleanedName + options.ProgramName = cleanedName + slog.Info("Using IDL metadata.name as package name", "packageName", cleanedName) + } + + slog.Info("Parsed IDL successfully", + "version", parsedIdl.Metadata.Version, + "name", parsedIdl.Metadata.Name, + "address", parsedIdl.Address, + "programId", func() string { + if parsedIdl.Address.IsZero() { + return "not provided" + } + return parsedIdl.Address.String() + }(), + "instructionsCount", len(parsedIdl.Instructions), + "accountsCount", len(parsedIdl.Accounts), + "eventsCount", len(parsedIdl.Events), + "typesCount", len(parsedIdl.Types), + "constantsCount", len(parsedIdl.Constants), + "errorsCount", len(parsedIdl.Errors), + ) + } + gen := generator.NewGenerator(parsedIdl, &options) + generatedFiles, err := gen.Generate() + if err != nil { + panic(err) + } + + if !skipGoMod { + goModFilepath := path.Join(options.OutputDir, "go.mod") + slog.Info("Writing go.mod file", + "filepath", goModFilepath, + "modPath", options.ModPath, + ) + + err = os.WriteFile(goModFilepath, []byte(generatedFiles.GoMod), 0o777) + if err != nil { + panic(err) + } + } + { + for _, file := range generatedFiles.Files { + { + // Save assets: + assetFilename := file.Name + assetFilepath := path.Join(options.OutputDir, assetFilename) + + // Create file: + goFile, err := os.Create(assetFilepath) + if err != nil { + panic(err) + } + defer goFile.Close() + + slog.Info("Writing file", + "filepath", assetFilepath, + "name", file.Name, + "modPath", options.ModPath, + ) + err = file.File.Render(goFile) + if err != nil { + panic(err) + } + } + } + // executeCmd(outputDir, "go", "mod", "tidy") + // executeCmd(outputDir, "go", "fmt") + // executeCmd(outputDir, "go", "build", "-o", "/dev/null") // Just to ensure everything compiles. 
+ slog.Info("Generation completed successfully", + "outputDir", options.OutputDir, + "modPath", options.ModPath, + "package", options.Package, + "programName", options.ProgramName, + ) + } +} + +func executeCmd(dir string, name string, arg ...string) { + cmd := exec.Command(name, arg...) + cmd.Dir = dir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + if err != nil { + panic(err) + } +} + +func hasCommand(name string) bool { + _, err := exec.LookPath(name) + return err == nil +} diff --git a/cmd/generate-bindings/solana/bindings_test.go b/cmd/generate-bindings/solana/bindings_test.go new file mode 100644 index 00000000..371fc1ee --- /dev/null +++ b/cmd/generate-bindings/solana/bindings_test.go @@ -0,0 +1,343 @@ +package solana_test + +import ( + "context" + "testing" + + "github.com/gagliardetto/solana-go" + "github.com/test-go/testify/require" + "google.golang.org/protobuf/proto" + + ocr3types "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + realSolana "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + realSolanaMock "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/mock" + + datastorage "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/testdata/data_storage" + "github.com/smartcontractkit/cre-sdk-go/cre/testutils" + consensusmock "github.com/smartcontractkit/cre-sdk-go/internal_testing/capabilities/consensus/mock" +) + +const anyChainSelector = uint64(1337) + +func TestGeneratedBindingsCodec(t *testing.T) { + codec := datastorage.Codec{} + + t.Run("encode functions", func(t *testing.T) { + // structs + userData := datastorage.UserData{ + Key: "testKey", + Value: "testValue", + } + _, err := codec.EncodeUserDataStruct(userData) + require.NoError(t, err) + + testPrivKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + testPubKey := testPrivKey.PublicKey() + + logAccess := 
datastorage.AccessLogged{ + Caller: testPubKey, + Message: "testMessage", + } + _, err = codec.EncodeAccessLoggedStruct(logAccess) + require.NoError(t, err) + + readData := datastorage.DataAccount{ + Sender: testPubKey.String(), + Key: "testKey", + Value: "testValue", + } + _, err = codec.EncodeDataAccountStruct(readData) + require.NoError(t, err) + + storeData := datastorage.DynamicEvent{ + Key: "testKey", + UserData: userData, + Sender: testPubKey.String(), + Metadata: []byte("testMetadata"), + MetadataArray: [][]byte{}, + } + _, err = codec.EncodeDynamicEventStruct(storeData) + require.NoError(t, err) + + storeUserData := datastorage.UpdateReserves{ + TotalMinted: 100, + TotalReserve: uint64(200), + } + _, err = codec.EncodeUpdateReservesStruct(storeUserData) + require.NoError(t, err) + + // onReport := datastorage.OnReportInput{ + // Metadata: []byte("testMetadata"), + // Payload: []byte("testPayload"), + // } + // _, err = codec.EncodeOnReportMethodCall(onReport) + // require.NoError(t, err) + }) +} + +func TestWriteReportMethods(t *testing.T) { + client := &realSolana.Client{ChainSelector: anyChainSelector} + ds, err := datastorage.NewDataStorage(client) + require.NoError(t, err, "Failed to create DataStorage instance") + + report := ocr3types.Metadata{ + Version: 1, + ExecutionID: "1234567890123456789012345678901234567890123456789012345678901234", + Timestamp: 1620000000, + DONID: 1, + DONConfigVersion: 1, + WorkflowID: "1234567890123456789012345678901234567890123456789012345678901234", + WorkflowName: "12", + WorkflowOwner: "1234567890123456789012345678901234567890", + ReportID: "1234", + } + + rawReport, err := report.Encode() + require.NoError(t, err) + + consensusCap, err := consensusmock.NewConsensusCapability(t) + require.NoError(t, err, "Failed to create Consensus capability") + consensusCap.Report = func(_ context.Context, input *sdk.ReportRequest) (*sdk.ReportResponse, error) { + return &sdk.ReportResponse{ + RawReport: rawReport, + }, nil + } + + 
solanaCap, err := realSolanaMock.NewClientCapability(anyChainSelector, t) + require.NoError(t, err, "Failed to create Solana client capability") + solanaCap.WriteReport = func(_ context.Context, req *realSolana.WriteReportRequest) (*realSolana.WriteReportReply, error) { + return &realSolana.WriteReportReply{ + TxStatus: realSolana.TxStatus_TX_STATUS_SUCCESS, + TxSignature: []byte{0x01, 0x02, 0x03, 0x04}, + }, nil + } + + runtime := testutils.NewRuntime(t, testutils.Secrets{}) + + reply := ds.WriteReportFromUserData(runtime, datastorage.UserData{ + Key: "testKey", + Value: "testValue", + }, nil) + require.NoError(t, err, "WriteReportDataStorageUserData should not return an error") + response, err := reply.Await() + require.NoError(t, err, "Awaiting WriteReportDataStorageUserData reply should not return an error") + require.NotNil(t, response, "Response from WriteReportDataStorageUserData should not be nil") + require.True(t, proto.Equal(&realSolana.WriteReportReply{ + TxStatus: realSolana.TxStatus_TX_STATUS_SUCCESS, + TxSignature: []byte{0x01, 0x02, 0x03, 0x04}, + }, response), "Response should match expected WriteReportReply") +} + +func TestEncodeStruct(t *testing.T) { + client := &realSolana.Client{ChainSelector: anyChainSelector} + ds, err := datastorage.NewDataStorage(client) + require.NoError(t, err, "Failed to create DataStorage instance") + + str := datastorage.DataAccount{ + Key: "testKey", + Value: "testValue", + Sender: "testSender", + } + + encoded, err := ds.Codec.EncodeDataAccountStruct(str) + require.NoError(t, err, "Encoding DataStorageDataAccount should not return an error") + require.NotNil(t, encoded, "Encoded data should not be nil") +} + +// func TestReadMethods(t *testing.T) { +// t.Run("single", func(t *testing.T) { +// client := &realSolana.Client{ChainSelector: anyChainSelector} +// ds, err := datastorage.NewDataStorage(client) +// require.NoError(t, err, "Failed to create DataStorage instance") + +// // Encode the expected string response 
+// dataAccount := datastorage.DataAccount{ +// Sender: "testSender", +// Key: "testKey", +// Value: "testValue", +// } +// encodedData, err := dataAccount.Marshal() +// require.NoError(t, err) + +// dataAccountDiscriminator := datastorage.Account_DataAccount +// dataAccountDiscriminatorBytes := dataAccountDiscriminator[:] +// encodedData = append(dataAccountDiscriminatorBytes, encodedData...) + +// solanaCap, err := realSolanaMock.NewClientCapability(anyChainSelector, t) +// require.NoError(t, err, "Failed to create EVM client capability") + +// solanaCap.GetAccountInfoWithOpts = func(_ context.Context, input *realSolana.GetAccountInfoWithOptsRequest) (*realSolana.GetAccountInfoWithOptsReply, error) { +// reply := &realSolana.GetAccountInfoWithOptsReply{ +// Value: &realSolana.Account{ +// Data: &realSolana.DataBytesOrJSON{ +// // AsDecodedBinary: encodedData, +// Body: &realSolana.DataBytesOrJSON_Raw{ +// Raw: encodedData, +// }, +// }, +// }, +// } +// return reply, nil +// } +// runtime := testutils.NewRuntime(t, testutils.Secrets{}) +// randomAddress, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// testBigInt := big.NewInt(123456) +// reply := ds.ReadAccount_DataAccount(runtime, randomAddress.PublicKey(), testBigInt) +// require.NotNil(t, reply, "ReadData should return a non-nil promise") + +// resp, err := reply.Await() +// require.NoError(t, err, "Awaiting ReadData reply should not return an error") +// require.Equal(t, dataAccount.Value, resp.Value, "Decoded value should match expected string") +// }) +// } + +// func TestDecodeEvents(t *testing.T) { +// t.Run("Success", func(t *testing.T) { +// client := &realSolana.Client{ChainSelector: anyChainSelector} +// ds, err := datastorage.NewDataStorage(client) +// require.NoError(t, err, "Failed to create DataStorage instance") + +// testPrivKey, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// testPubKey := testPrivKey.PublicKey() +// testLog := datastorage.AccessLogged{ 
+// Caller: testPubKey, +// Message: "testMessage", +// } + +// data, err := ds.Codec.EncodeAccessLoggedStruct(testLog) +// require.NoError(t, err) +// discriminator := datastorage.Event_AccessLogged + +// log := &solanasdk.Log{ +// Data: append(discriminator[:], data...), +// } + +// out, err := ds.Codec.DecodeAccessLogged(log) +// require.NoError(t, err) +// require.Equal(t, testPubKey, out.Caller) +// require.Equal(t, "testMessage", out.Message) + +// testLog2 := datastorage.DynamicEvent{ +// Key: "testKey", +// UserData: datastorage.UserData{ +// Key: "testKey", +// Value: "testValue", +// }, +// Sender: testPubKey.String(), +// Metadata: []byte("testMetadata"), +// MetadataArray: [][]byte{}, +// } +// data2, err := ds.Codec.EncodeDynamicEventStruct(testLog2) +// require.NoError(t, err) +// discriminator2 := datastorage.Event_DynamicEvent +// log2 := &solanasdk.Log{ +// Data: append(discriminator2[:], data2...), +// } +// out2, err := ds.Codec.DecodeDynamicEvent(log2) +// require.NoError(t, err) +// require.Equal(t, testPubKey.String(), out2.Sender) +// require.Equal(t, "testMetadata", string(out2.Metadata)) +// require.Equal(t, "testKey", out2.Key) +// require.Equal(t, "testValue", out2.UserData.Value) +// }) +// } + +// func TestLogTrigger(t *testing.T) { +// client := &realSolana.Client{ChainSelector: anyChainSelector} +// ds, err := datastorage.NewDataStorage(client) +// require.NoError(t, err, "Failed to create DataStorage instance") +// t.Run("simple event", func(t *testing.T) { +// testPrivKey, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// testPubKey := testPrivKey.PublicKey() +// events := []datastorage.AccessLogged{ +// { +// Caller: testPubKey, +// Message: "testMessage", +// }, +// } + +// encoded, err := ds.Codec.EncodeAccessLoggedStruct(events[0]) +// require.NoError(t, err, "Encoding AccessLogged should not return an error") +// discriminator := datastorage.Event_AccessLogged +// encoded = append(discriminator[:], encoded...) 
+ +// trigger, err := ds.LogTrigger_AccessLogged(anyChainSelector, []solanasdk.SubKeyPathAndFilter{ +// { +// SubkeyPath: "Caller", +// Value: testPubKey, +// }, +// }) +// require.NotNil(t, trigger) +// require.NoError(t, err) + +// // Create a mock log that simulates what would be returned by the blockchain +// mockLog := &solanasdk.Log{ +// Address: solanatypes.PublicKey(datastorage.ProgramID), +// Data: encoded, +// } + +// // Call Adapt to decode the log +// decodedLog, err := trigger.Adapt(mockLog) +// require.NoError(t, err, "Adapt should not return an error") +// require.NotNil(t, decodedLog, "Decoded log should not be nil") +// require.Equal(t, events[0].Caller, decodedLog.Data.Caller, "Decoded caller should match") +// require.Equal(t, events[0].Message, decodedLog.Data.Message, "Decoded message should match") +// }) + +// t.Run("dynamic event", func(t *testing.T) { +// testPrivKey, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// testPubKey := testPrivKey.PublicKey() +// events := []datastorage.DynamicEvent{ +// { +// Key: "testKey", +// UserData: datastorage.UserData{ +// Key: "testKey", +// Value: "testValue", +// }, +// Sender: testPubKey.String(), +// Metadata: []byte("testMetadata"), +// MetadataArray: [][]byte{}, +// }, +// } + +// encoded, err := ds.Codec.EncodeDynamicEventStruct(events[0]) +// require.NoError(t, err, "Encoding DynamicEvent should not return an error") +// discriminator := datastorage.Event_DynamicEvent +// encoded = append(discriminator[:], encoded...) 
+ +// trigger, err := ds.LogTrigger_DynamicEvent(anyChainSelector, []solanasdk.SubKeyPathAndFilter{ +// { +// SubkeyPath: "UserData.Key", +// Value: "testKey", +// }, +// { +// SubkeyPath: "Key", +// Value: "testKey", +// }, +// }) +// require.NotNil(t, trigger) +// require.NoError(t, err) + +// // Create a mock log that simulates what would be returned by the blockchain +// mockLog := &solanasdk.Log{ +// Address: solanatypes.PublicKey(datastorage.ProgramID), +// Data: encoded, +// } + +// // Call Adapt to decode the log +// decodedLog, err := trigger.Adapt(mockLog) +// require.NoError(t, err, "Adapt should not return an error") +// require.NotNil(t, decodedLog, "Decoded log should not be nil") +// require.Equal(t, events[0].Key, decodedLog.Data.Key, "Decoded key should match") +// require.Equal(t, events[0].UserData.Key, decodedLog.Data.UserData.Key, "Decoded user data key should match") +// require.Equal(t, events[0].UserData.Value, decodedLog.Data.UserData.Value, "Decoded user data value should match") +// require.Equal(t, events[0].Sender, decodedLog.Data.Sender, "Decoded sender should match") +// require.Equal(t, events[0].Metadata, decodedLog.Data.Metadata, "Decoded metadata should match") +// }) +// } diff --git a/cmd/generate-bindings/solana/bindings_test_temp.go b/cmd/generate-bindings/solana/bindings_test_temp.go new file mode 100644 index 00000000..ca9c2c99 --- /dev/null +++ b/cmd/generate-bindings/solana/bindings_test_temp.go @@ -0,0 +1,249 @@ +package solana + +// "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana_bindings/testdata/forwarder" +// "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana_bindings/testdata/receiver" + +const anyChainSelector = uint64(1337) + +/* + +// deploy +solana-test-validator -r \ + --upgradeable-program 5PdwLUj8VqLpA8RKGAUaReEies7kEebeQcZzcrB2R7ya /Users/yashvardhan/cre-client-program/my-project/target/deploy/forwarder.so Av3xZHYnFoW7wW4FEApAtHf8JeYauwaNm5cVLqk6MLfk \ + --upgradeable-program 
G5t6jDm3pmQFwW4y9KQn1iDkZrSEvC78H8cL7XaoTA3Q /Users/yashvardhan/cre-client-program/my-project/target/deploy/receiver.so Av3xZHYnFoW7wW4FEApAtHf8JeYauwaNm5cVLqk6MLfk + + +// update bindings +./anchor \ + --idl /Users/yashvardhan/cre-client-program/my-project/target/idl/forwarder.json \ + --output /Users/yashvardhan/cre-cli/cmd/generate-bindings/solana_bindings/testdata/forwarder \ + --program-id 5PdwLUj8VqLpA8RKGAUaReEies7kEebeQcZzcrB2R7ya \ + --no-go-mod + + ./anchor \ + --idl /Users/yashvardhan/cre-client-program/my-project/target/idl/receiver.json \ + --output /Users/yashvardhan/cre-cli/cmd/generate-bindings/solana_bindings/testdata/receiver \ + --program-id G5t6jDm3pmQFwW4y9KQn1iDkZrSEvC78H8cL7XaoTA3Q \ + --no-go-mod + +*/ + +// func TestSolanaBasic(t *testing.T) { +// solanaClient := rpc.New("http://localhost:8899") +// pk, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// common.FundAccounts(context.Background(), []solana.PrivateKey{pk}, solanaClient, t) +// // version, err := solanaClient.GetVersion(context.Background()) +// // require.NoError(t, err) +// // fmt.Println("version", version) +// // health, err := solanaClient.GetHealth(context.Background()) +// // require.NoError(t, err) +// // fmt.Println("health", health) +// // fmt.Println(forwarder.ProgramID.String()) +// // fmt.Println(receiver.ProgramID.String()) +// counterAccount, _, _ := solana.FindProgramAddress( +// [][]byte{[]byte("counter")}, +// forwarder.ProgramID, +// ) +// ix1, err := forwarder.NewReportInstruction(123456, receiver.Instruction_OnReport, counterAccount, receiver.ProgramID, solana.SystemProgramID) +// require.NoError(t, err) +// fmt.Println("ix11", ix1) +// // ix2, err := receiver.NewOnReportInstruction(123456) +// // require.NoError(t, err) +// // fmt.Println("ix2", ix2) +// res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{ix1}, pk, rpc.CommitmentConfirmed) +// require.NoError(t, err) +// fmt.Println("res", 
res.Meta.LogMessages) +// } + +// func TestSolanaInit(t *testing.T) { +// solanaClient := rpc.New("http://localhost:8899") +// pk, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// common.FundAccounts(context.Background(), []solana.PrivateKey{pk}, solanaClient, t) + +// counterAccount, _, _ := solana.FindProgramAddress( +// [][]byte{[]byte("counter")}, +// forwarder.ProgramID, +// ) +// ix1, err := forwarder.NewInitializeInstruction(123456, pk.PublicKey(), counterAccount, solana.SystemProgramID) +// require.NoError(t, err) +// fmt.Println("ix1", ix1) +// res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{ix1}, pk, rpc.CommitmentConfirmed) +// require.NoError(t, err) +// fmt.Println("res", res.Meta.LogMessages) +// } + +// func TestSolanaReadAccount(t *testing.T) { +// // create client +// solanaClient := rpc.New("http://localhost:8899") +// // find pda +// counterAccountAddress, _, _ := solana.FindProgramAddress( +// [][]byte{[]byte("counter")}, +// forwarder.ProgramID, +// ) +// resp, err := solanaClient.GetAccountInfoWithOpts( +// context.Background(), +// counterAccountAddress, +// &rpc.GetAccountInfoOpts{ +// Commitment: rpc.CommitmentConfirmed, +// DataSlice: nil, +// }, +// ) +// require.NoError(t, err) +// counter, err := forwarder.ParseAccount_Counter(resp.Value.Data.GetBinary()) +// require.NoError(t, err) +// fmt.Println("counter ", counter.Counter) +// } + +// func TestSolanaInit2(t *testing.T) { +// solanaClient := rpc.New("http://localhost:8899") +// pk, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// common.FundAccounts(context.Background(), []solana.PrivateKey{pk}, solanaClient, t) +// txId := uint64(3) +// txIdLE := common.Uint64ToLE(txId) +// executionStateAccount, _, _ := solana.FindProgramAddress( +// [][]byte{[]byte("execution_state"), txIdLE}, +// forwarder.ProgramID, +// ) +// ix1, err := forwarder.NewReport2Instruction( +// 123456, +// txId, +// 
receiver.Instruction_OnReport, +// pk.PublicKey(), +// executionStateAccount, +// receiver.ProgramID, +// solana.SystemProgramID, +// ) +// require.NoError(t, err) +// fmt.Println("ix1", ix1) +// res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{ix1}, pk, rpc.CommitmentConfirmed) +// // require.NoError(t, err) +// fmt.Println("res error", err) +// fmt.Println("res", res) +// // fmt.Println("res", res.Meta.LogMessages) + +// resp, err := solanaClient.GetAccountInfoWithOpts( +// context.Background(), +// executionStateAccount, +// &rpc.GetAccountInfoOpts{ +// Commitment: rpc.CommitmentConfirmed, +// DataSlice: nil, +// }, +// ) +// require.NoError(t, err) +// executionState, err := forwarder.ParseAccount_ExecutionState(resp.Value.Data.GetBinary()) +// require.NoError(t, err) +// fmt.Println("executionState ", executionState.Success) +// fmt.Println("executionState ", executionState.Failure) +// fmt.Println("executionState ", executionState.TransmissionId) +// } + +// func TestSolanaInit(t *testing.T) { +// solanaClient := rpc.New("http://localhost:8899") +// pk, err := solana.NewRandomPrivateKey() +// require.NoError(t, err) +// common.FundAccounts(context.Background(), []solana.PrivateKey{pk}, solanaClient, t) + +// // dataAccountAccount, _, err := solana.FindProgramAddress( +// // [][]byte{[]byte("test")}, +// // my_anchor_project.ProgramID, +// // ) +// // ix, err := my_anchor_project.NewInitializeInstruction( +// // "test-data", +// // dataAccountAccount, +// // pk.PublicKey(), +// // solana.SystemProgramID, +// // ) +// require.NoError(t, err) + +// res, err := common.SendAndConfirm( +// context.Background(), +// solanaClient, +// []solana.Instruction{}, +// pk, +// rpc.CommitmentConfirmed, +// common.AddSigners(pk), +// ) +// require.NoError(t, err) +// fmt.Println("res", res.Meta.LogMessages) + +// } + +// func TestSolanaGetData(t *testing.T) { +// solanaClient := rpc.New("http://localhost:8899") +// pk, err := 
solana.NewRandomPrivateKey() +// require.NoError(t, err) +// common.FundAccounts(context.Background(), []solana.PrivateKey{pk}, solanaClient, t) + +// // dataAccountAccount, _, err := solana.FindProgramAddress( +// // [][]byte{[]byte("test")}, +// // my_anchor_project.ProgramID, +// // ) + +// // ix3, err := my_anchor_project.NewGetInputDataInstruction("test-data") +// require.NoError(t, err) +// // ix4, err := my_anchor_project.NewGetInputDataFromAccountInstruction("test-data", dataAccountAccount) +// // require.NoError(t, err) +// // res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{ix3, ix4}, pk, rpc.CommitmentConfirmed) +// res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{}, pk, rpc.CommitmentConfirmed) + +// require.NoError(t, err) +// for _, log := range res.Meta.LogMessages { +// if strings.Contains(log, "Program log:") { +// fmt.Println("log", log) +// } +// } +// } + +// func TestSolanaWriteAccount(t *testing.T) { +// // solanaClient := rpc.New("http://localhost:8899") +// // pk, err := solana.NewRandomPrivateKey() +// // require.NoError(t, err) +// // common.FundAccounts(context.Background(), []solana.PrivateKey{pk}, solanaClient, t) + +// // // dataAccountAddress, _, err := solana.FindProgramAddress( +// // // [][]byte{[]byte("test")}, +// // // // my_anchor_project.ProgramID, +// // // ) +// // // ix, err := my_anchor_project.NewUpdateDataInstruction("test-data-new", dataAccountAddress) +// // require.NoError(t, err) + +// // // ix2, err := my_anchor_project.NewUpdateDataWithTypedReturnInstruction("test-data-new", dataAccountAddress) +// // // require.NoError(t, err) + +// // // res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{ix, ix2}, pk, rpc.CommitmentConfirmed) +// // res, err := common.SendAndConfirm(context.Background(), solanaClient, []solana.Instruction{ix}, pk, rpc.CommitmentConfirmed) + +// // require.NoError(t, err) +// 
// fmt.Println("res", res.Meta.LogMessages) + +// // // output, err := common.ExtractTypedReturnValue(context.Background(), res.Meta.LogMessages, my_anchor_project.ProgramID.String(), func(b []byte) string { +// // // require.Len(t, b, int(binary.LittleEndian.Uint32(b[:4]))+4) // the first 4 bytes just encodes the length +// // // return string(b[4:]) +// // // }) +// // require.NoError(t, err) +// // fmt.Println("output", output) + +// // output2, err := common.ExtractAnchorTypedReturnValue[my_anchor_project.UpdateResponse](context.Background(), res.Meta.LogMessages, my_anchor_project.ProgramID.String()) +// // require.NoError(t, err) +// // fmt.Println("output2", output2) + +// // output3, err := my_anchor_project.SendUpdateDataInstruction("test-data-new", dataAccountAddress, solanaClient, pk, rpc.CommitmentConfirmed) +// // require.NoError(t, err) +// // fmt.Println("output3", output3) + +// // output4, err := my_anchor_project.SendUpdateDataWithTypedReturnInstruction("test-data-new", dataAccountAddress, solanaClient, pk, rpc.CommitmentConfirmed) +// // require.NoError(t, err) +// // fmt.Println("output4", output4.Data) +// } + +/* +anchor-go \ + --idl /Users/yashvardhan/cre-client-program/my-project/target/idl/data_storage.json \ + --output /Users/yashvardhan/cre-cli/cmd/generate-bindings/solana_bindings/testdata/data_storage \ + --program-id ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL \ + --no-go-mod + +*/ diff --git a/cmd/generate-bindings/solana/cre-sdk-go/anchorcodec/anchoridl.go b/cmd/generate-bindings/solana/cre-sdk-go/anchorcodec/anchoridl.go new file mode 100644 index 00000000..8c0b80e7 --- /dev/null +++ b/cmd/generate-bindings/solana/cre-sdk-go/anchorcodec/anchoridl.go @@ -0,0 +1,548 @@ +package anchorcodec + +/* +copied from https://github.com/gagliardetto/anchor-go where the IDL definition is not importable due to being defined +in the `main` package. 
*/

import (
	"encoding/json"
	"fmt"

	"github.com/davecgh/go-spew/spew"
	"github.com/gagliardetto/utilz"
)

// IDL mirrors the Anchor IDL JSON document produced by the Anchor toolchain.
// https://github.com/project-serum/anchor/blob/97e9e03fb041b8b888a9876a7c0676d9bb4736f3/ts/src/idl.ts
type IDL struct {
	Version      string           `json:"version"`
	Name         string           `json:"name"`
	Instructions []IdlInstruction `json:"instructions"`
	Accounts     IdlTypeDefSlice  `json:"accounts,omitempty"`
	Types        IdlTypeDefSlice  `json:"types,omitempty"`
	Events       []IdlEvent       `json:"events,omitempty"`
	Errors       []IdlErrorCode   `json:"errors,omitempty"`
	Constants    []IdlConstant    `json:"constants,omitempty"`
}

// IdlConstant is a named constant declared in the IDL.
type IdlConstant struct {
	Name  string
	Type  IdlType
	Value string
}

// IdlTypeDefSlice is a list of named type definitions.
type IdlTypeDefSlice []IdlTypeDef

// GetByName returns a pointer to a copy of the type definition with the given
// name, or nil if none matches. Mutating the result does not affect the slice.
func (named IdlTypeDefSlice) GetByName(name string) *IdlTypeDef {
	for i := range named {
		v := named[i]
		if v.Name == name {
			return &v
		}
	}
	return nil
}

// IdlEvent describes an Anchor event and its fields.
type IdlEvent struct {
	Name   string          `json:"name"`
	Fields []IdlEventField `json:"fields"`
}

// IdlEventField is a single field of an event.
type IdlEventField struct {
	Name  string  `json:"name"`
	Type  IdlType `json:"type"`
	Index bool    `json:"index"`
}

// EventIDLTypes bundles an event with the type definitions needed to decode it.
type EventIDLTypes struct {
	Event IdlEvent
	Types IdlTypeDefSlice
}

// IdlInstruction describes one program instruction: its accounts and args.
type IdlInstruction struct {
	Name     string              `json:"name"`
	Docs     []string            `json:"docs"` // @custom
	Accounts IdlAccountItemSlice `json:"accounts"`
	Args     []IdlField          `json:"args"`
}

// IdlAccountItemSlice is a possibly-nested list of instruction accounts.
type IdlAccountItemSlice []IdlAccountItem

// NumAccounts counts leaf accounts, recursing into nested account groups.
func (slice IdlAccountItemSlice) NumAccounts() (count int) {
	for _, item := range slice {
		if item.IdlAccount != nil {
			count++
		}

		if item.IdlAccounts != nil {
			count += item.IdlAccounts.Accounts.NumAccounts()
		}
	}

	return count
}

// IdlAccountItem is a tagged union: exactly one of the two fields is non-nil.
// type IdlAccountItem = IdlAccount | IdlAccounts;
type IdlAccountItem struct {
	IdlAccount  *IdlAccount
	IdlAccounts *IdlAccounts
}

// UnmarshalJSON decodes either a single account (detected via "isMut") or a
// nested account group (detected via "accounts"). Exactly one of the two
// discriminator keys must be present.
func (env *IdlAccountItem) UnmarshalJSON(data []byte) error {
	var temp interface{}
	if err := json.Unmarshal(data, &temp); err != nil {
		return err
	}

	if temp == nil {
		return fmt.Errorf("envelope is nil: %v", env)
	}

	switch v := temp.(type) {
	case map[string]interface{}:
		if len(v) == 0 {
			return nil
		}

		_, hasAccounts := v["accounts"]
		_, hasIsMut := v["isMut"]

		// Both present or both absent is ambiguous — reject.
		if hasAccounts == hasIsMut {
			return fmt.Errorf("invalid idl structure: expected exactly one of 'accounts' or 'isMut'")
		}

		if hasAccounts {
			return utilz.TranscodeJSON(temp, &env.IdlAccounts)
		}

		return utilz.TranscodeJSON(temp, &env.IdlAccount)
	default:
		return fmt.Errorf("unknown kind: %s", spew.Sdump(temp))
	}
}

// MarshalJSON emits the variant that is set, first validating the union
// invariant and rejecting cyclic nested account groups (which would
// otherwise recurse forever during encoding).
func (env IdlAccountItem) MarshalJSON() ([]byte, error) {
	if (env.IdlAccount == nil) == (env.IdlAccounts == nil) {
		return nil, fmt.Errorf("invalid structure: expected either IdlAccount or IdlAccounts to be defined")
	}

	visited := make(map[*IdlAccounts]struct{})
	if err := checkForIdlAccountsCycle(env.IdlAccounts, visited); err != nil {
		return nil, err
	}

	var result interface{}
	if env.IdlAccounts != nil {
		result = map[string]interface{}{
			"accounts": env.IdlAccounts,
		}
	} else {
		result = env.IdlAccount
	}

	return json.Marshal(result)
}

// checkForIdlAccountsCycle walks nested account groups depth-first and
// reports an error if the same *IdlAccounts node is reachable twice.
func checkForIdlAccountsCycle(acc *IdlAccounts, visited map[*IdlAccounts]struct{}) error {
	if acc == nil {
		return nil
	}

	if _, exists := visited[acc]; exists {
		return fmt.Errorf("cycle detected in IdlAccounts named %q", acc.Name)
	}
	visited[acc] = struct{}{}

	for _, item := range acc.Accounts {
		if (item.IdlAccount == nil) == (item.IdlAccounts == nil) {
			return fmt.Errorf("invalid nested structure: expected either IdlAccount or IdlAccounts to be defined")
		}
		if item.IdlAccounts != nil {
			if err := checkForIdlAccountsCycle(item.IdlAccounts, visited); err != nil {
				return err
			}
		}
	}
	return nil
}

// IdlAccount is a single (leaf) instruction account.
type IdlAccount struct {
	Docs     []string `json:"docs"` // @custom
	Name     string   `json:"name"`
	IsMut    bool     `json:"isMut"`
	IsSigner bool     `json:"isSigner"`
	Optional bool     `json:"optional"` // @custom
}

// A nested/recursive version of
IdlAccount. +type IdlAccounts struct { + Name string `json:"name"` + Docs []string `json:"docs"` // @custom + Accounts IdlAccountItemSlice `json:"accounts"` +} + +type IdlField struct { + Name string `json:"name"` + Docs []string `json:"docs"` // @custom + Type IdlType `json:"type"` +} + +// PDA is a struct that does not correlate to an official IDL type +// It is needed to encode seeds to calculate the address for PDA account reads +type PDATypeDef struct { + Prefix []byte `json:"prefix,omitempty"` + Seeds []PDASeed `json:"seeds,omitempty"` +} + +type PDASeed struct { + Name string `json:"name"` + Type IdlType `json:"type"` +} + +type IdlTypeAsString string + +const ( + IdlTypeBool IdlTypeAsString = "bool" + IdlTypeU8 IdlTypeAsString = "u8" + IdlTypeI8 IdlTypeAsString = "i8" + IdlTypeU16 IdlTypeAsString = "u16" + IdlTypeI16 IdlTypeAsString = "i16" + IdlTypeU32 IdlTypeAsString = "u32" + IdlTypeI32 IdlTypeAsString = "i32" + IdlTypeU64 IdlTypeAsString = "u64" + IdlTypeI64 IdlTypeAsString = "i64" + IdlTypeU128 IdlTypeAsString = "u128" + IdlTypeI128 IdlTypeAsString = "i128" + IdlTypeBytes IdlTypeAsString = "bytes" + IdlTypeString IdlTypeAsString = "string" + IdlTypePublicKey IdlTypeAsString = "publicKey" + + // Custom additions: + IdlTypeUnixTimestamp IdlTypeAsString = "unixTimestamp" + IdlTypeHash IdlTypeAsString = "hash" + IdlTypeDuration IdlTypeAsString = "duration" +) + +type IdlTypeVec struct { + Vec IdlType `json:"vec"` +} + +type IdlTypeOption struct { + Option IdlType `json:"option"` +} + +// User defined type. 
+type IdlTypeDefined struct { + Defined string `json:"defined"` +} + +// Wrapper type: +type IdlTypeArray struct { + Thing IdlType + Num int +} + +func (env IdlType) MarshalJSON() ([]byte, error) { + var result interface{} + switch { + case env.IsString(): + result = env.GetString() + case env.IsIdlTypeVec(): + result = env.GetIdlTypeVec() + case env.IsIdlTypeOption(): + result = env.GetIdlTypeOption() + case env.IsIdlTypeDefined(): + result = env.GetIdlTypeDefined() + case env.IsArray(): + array := env.GetArray() + result = map[string]interface{}{ + "array": []interface{}{array.Thing, array.Num}, + } + default: + return nil, fmt.Errorf("nil envelope is not supported in IdlType") + } + + return json.Marshal(result) +} + +type newIdlTypeDefinedNamed struct { + Name string `json:"name"` +} + +type newIdlTypeDefined struct { + Defined newIdlTypeDefinedNamed `json:"defined"` +} + +func (env *IdlType) UnmarshalJSON(data []byte) error { + var temp interface{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp == nil { + return fmt.Errorf("envelope is nil: %v", env) + } + + switch v := temp.(type) { + case string: + env.AsString = IdlTypeAsString(v) + case map[string]interface{}: + if len(v) == 0 { + return nil + } + + var typeFound bool + if _, ok := v["vec"]; ok { + var target IdlTypeVec + if err := utilz.TranscodeJSON(temp, &target); err != nil { + return err + } + typeFound = true + env.AsIdlTypeVec = &target + } + if _, ok := v["option"]; ok { + if typeFound { + return fmt.Errorf("multiple types found for IdlType: %s", spew.Sdump(temp)) + } + var target IdlTypeOption + if err := utilz.TranscodeJSON(temp, &target); err != nil { + return err + } + typeFound = true + env.asIdlTypeOption = &target + } + if _, ok := v["defined"]; ok { + if typeFound { + return fmt.Errorf("multiple types found for IdlType: %s", spew.Sdump(temp)) + } + var target IdlTypeDefined + if err := utilz.TranscodeJSON(temp, &target); err != nil { + var target2 
newIdlTypeDefined + if err2 := utilz.TranscodeJSON(temp, &target2); err2 != nil { + return err + } + target = IdlTypeDefined{Defined: target2.Defined.Name} + } + typeFound = true + env.AsIdlTypeDefined = &target + } + if got, ok := v["array"]; ok { + if typeFound { + return fmt.Errorf("multiple types found for IdlType: %s", spew.Sdump(temp)) + } + arrVal, ok := got.([]interface{}) + if !ok { + return fmt.Errorf("array is not in expected format: %s", spew.Sdump(got)) + } + if len(arrVal) != 2 { + return fmt.Errorf("array is not of expected length: %s", spew.Sdump(got)) + } + var target IdlTypeArray + if err := utilz.TranscodeJSON(arrVal[0], &target.Thing); err != nil { + return err + } + num, ok := arrVal[1].(float64) + if !ok { + return fmt.Errorf("value is unexpected type: %T, expected float64", arrVal[1]) + } + target.Num = int(num) + env.AsIdlTypeArray = &target + } + default: + return fmt.Errorf("Unknown kind: %s", spew.Sdump(temp)) + } + + return nil +} + +// Wrapper type: +type IdlType struct { + AsString IdlTypeAsString + AsIdlTypeVec *IdlTypeVec + asIdlTypeOption *IdlTypeOption + AsIdlTypeDefined *IdlTypeDefined + AsIdlTypeArray *IdlTypeArray +} + +func NewIdlStringType(asString IdlTypeAsString) IdlType { + return IdlType{ + AsString: asString, + } +} + +func (env *IdlType) IsString() bool { + return env.AsString != "" +} +func (env *IdlType) IsIdlTypeVec() bool { + return env.AsIdlTypeVec != nil +} +func (env *IdlType) IsIdlTypeOption() bool { + return env.asIdlTypeOption != nil +} +func (env *IdlType) IsIdlTypeDefined() bool { + return env.AsIdlTypeDefined != nil +} +func (env *IdlType) IsArray() bool { + return env.AsIdlTypeArray != nil +} + +// Getters: +func (env *IdlType) GetString() IdlTypeAsString { + return env.AsString +} +func (env *IdlType) GetIdlTypeVec() *IdlTypeVec { + return env.AsIdlTypeVec +} +func (env *IdlType) GetIdlTypeOption() *IdlTypeOption { + return env.asIdlTypeOption +} +func (env *IdlType) GetIdlTypeDefined() *IdlTypeDefined { + 
return env.AsIdlTypeDefined +} +func (env *IdlType) GetArray() *IdlTypeArray { + return env.AsIdlTypeArray +} + +type IdlTypeDef struct { + Name string `json:"name"` + Type IdlTypeDefTy `json:"type"` +} + +type IdlTypeDefTyKind string + +const ( + IdlTypeDefTyKindStruct IdlTypeDefTyKind = "struct" + IdlTypeDefTyKindEnum IdlTypeDefTyKind = "enum" + IdlTypeDefTyKindCustom IdlTypeDefTyKind = "custom" +) + +type IdlTypeDefTyStruct struct { + Kind IdlTypeDefTyKind `json:"kind"` // == "struct" + + Fields *IdlTypeDefStruct `json:"fields,omitempty"` +} + +type IdlTypeDefTyEnum struct { + Kind IdlTypeDefTyKind `json:"kind"` // == "enum" + + Variants IdlEnumVariantSlice `json:"variants,omitempty"` +} + +var NilIdlTypeDefTy = IdlTypeDef{Type: IdlTypeDefTy{ + Kind: "struct", + Fields: &IdlTypeDefStruct{}, +}} + +type IdlTypeDefTy struct { + Kind IdlTypeDefTyKind `json:"kind"` + + Fields *IdlTypeDefStruct `json:"fields,omitempty"` + Variants IdlEnumVariantSlice `json:"variants,omitempty"` + Codec string `json:"codec,omitempty"` +} + +type IdlEnumVariantSlice []IdlEnumVariant + +func (slice IdlEnumVariantSlice) IsAllUint8() bool { + for _, elem := range slice { + if !elem.IsUint8() { + return false + } + } + return true +} + +func (slice IdlEnumVariantSlice) IsSimpleEnum() bool { + return slice.IsAllUint8() +} + +type IdlTypeDefStruct = []IdlField + +type IdlEnumVariant struct { + Name string `json:"name"` + Docs []string `json:"docs"` // @custom + Fields *IdlEnumFields `json:"fields,omitempty"` +} + +func (variant *IdlEnumVariant) IsUint8() bool { + // it's a simple uint8 if there is no fields data + return variant.Fields == nil +} + +// type IdlEnumFields = IdlEnumFieldsNamed | IdlEnumFieldsTuple; +type IdlEnumFields struct { + IdlEnumFieldsNamed *IdlEnumFieldsNamed + IdlEnumFieldsTuple *IdlEnumFieldsTuple +} + +type IdlEnumFieldsNamed []IdlField + +type IdlEnumFieldsTuple []IdlType + +// TODO: verify with examples +func (env *IdlEnumFields) UnmarshalJSON(data []byte) error { 
+ var temp any + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp == nil { + return fmt.Errorf("envelope is nil: %v", env) + } + + switch v := temp.(type) { + case []any: + if len(v) == 0 { + return nil + } + + firstItem := v[0] + + if _, ok := firstItem.(map[string]any)["name"]; ok { + // TODO: + // If has `name` field, then it's most likely a IdlEnumFieldsNamed. + return utilz.TranscodeJSON(temp, &env.IdlEnumFieldsNamed) + } + return utilz.TranscodeJSON(temp, &env.IdlEnumFieldsTuple) + case map[string]any: + // Only one or the other field is set. Returning early is safe + if named, ok := v["IdlEnumFieldsNamed"]; ok { + return utilz.TranscodeJSON(named, &env.IdlEnumFieldsNamed) + } + if tuple, ok := v["IdlEnumFieldsTuple"]; ok { + return utilz.TranscodeJSON(tuple, &env.IdlEnumFieldsTuple) + } + return fmt.Errorf("Unknown type: %s", spew.Sdump(v)) + default: + return fmt.Errorf("Unknown kind: %s", spew.Sdump(temp)) + } +} + +type IdlErrorCode struct { + Code int `json:"code"` + Name string `json:"name"` + Msg string `json:"msg,omitempty"` +} + +func GetIdlEvent(idlTypes *IdlTypeDefSlice, eventName string) (EventIDLTypes, error) { + myevent := IdlEvent{} + for _, typDefs := range *idlTypes { + if typDefs.Name != eventName { + continue + } + fields := *typDefs.Type.Fields + for _, field := range fields { + myevent.Fields = append(myevent.Fields, IdlEventField{ + Name: field.Name, + Type: field.Type, + }) + } + } + if len(myevent.Fields) == 0 { + return EventIDLTypes{}, fmt.Errorf("event %s has no fields", eventName) + } + return EventIDLTypes{ + Event: myevent, + Types: *idlTypes, + }, nil +} diff --git a/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana/bindings/bindings.go b/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana/bindings/bindings.go new file mode 100644 index 00000000..5ce723a1 --- /dev/null +++ 
b/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana/bindings/bindings.go
@@ -0,0 +1,93 @@
+package bindings
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+
+	"github.com/gagliardetto/anchor-go/idl"
+	"github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives"
+	solana "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana"
+	realSolana "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana"
+)
+
+// ValidateSubKeyPathAndValue resolves each dotted subkey path against struct type T (no pointer traversal) and strictly type-checks each filter value against the resolved leaf field type.
+func ValidateSubKeyPathAndValue[T any](inputs []solana.SubKeyPathAndFilter) ([][]string, []solana.SubkeyFilterCriteria, error) {
+	var zero T
+	root := reflect.TypeOf(zero)
+	if root.Kind() != reflect.Struct {
+		return nil, nil, fmt.Errorf("T must be a struct, got %v", root.Kind())
+	}
+
+	paths := make([][]string, 0, len(inputs))
+	filters := make([]solana.SubkeyFilterCriteria, 0, len(inputs))
+
+	for i, in := range inputs {
+		parts := strings.Split(in.SubkeyPath, ".")
+		// strings.Split never returns an empty slice, so check the path itself.
+		if in.SubkeyPath == "" {
+			return nil, nil, fmt.Errorf("empty subkey path at index %d", i)
+		}
+
+		leafT, err := resolveLeafTypeNoPtr(root, parts)
+		if err != nil {
+			return nil, nil, fmt.Errorf("path %q: %w", in.SubkeyPath, err)
+		}
+		if leafT.Kind() == reflect.Struct {
+			return nil, nil, fmt.Errorf("path %q resolves to a struct (%s); expected a scalar/leaf", in.SubkeyPath, leafT)
+		}
+
+		// Strict: require exact/assignable dynamic type (no conversions).
+		if in.Value == nil {
+			return nil, nil, fmt.Errorf("path %q: got nil value for non-pointer type %s", in.SubkeyPath, leafT)
+		}
+		valT := reflect.TypeOf(in.Value)
+		if !valT.AssignableTo(leafT) {
+			return nil, nil, fmt.Errorf("path %q: value type %s not assignable to field type %s", in.SubkeyPath, valT, leafT)
+		}
+
+		paths = append(paths, parts)
+		filters = append(filters, solana.SubkeyFilterCriteria{
+			SubkeyIndex: uint64(i),
+			Comparers:   []primitives.ValueComparator{{Value: in.Value, Operator: primitives.Eq}},
+		})
+	}
+
+	return paths, filters, nil
+}
+
+// resolveLeafTypeNoPtr walks parts through exported struct fields of root and returns the leaf field's type; pointers are not dereferenced.
+func resolveLeafTypeNoPtr(root reflect.Type, parts []string) (reflect.Type, error) {
+	t := root
+	for idx, name := range parts {
+		if t.Kind() != reflect.Struct {
+			return nil, fmt.Errorf("segment %q at #%d: %s is not a struct", name, idx, t)
+		}
+		sf, ok := t.FieldByName(name)
+		if !ok {
+			return nil, fmt.Errorf("field %q not found on %s", name, t)
+		}
+		if sf.PkgPath != "" { // unexported
+			return nil, fmt.Errorf("field %q on %s is unexported", name, t)
+		}
+		t = sf.Type
+	}
+	return t, nil
+}
+
+// ExtractEventIDL returns the IDL type definition whose name matches eventName.
+func ExtractEventIDL(eventName string, contractIdl *idl.Idl) (idl.IdlTypeDef, error) {
+	for _, typeDef := range contractIdl.Types {
+		if typeDef.Name == eventName {
+			return typeDef, nil
+		}
+	}
+	return idl.IdlTypeDef{}, fmt.Errorf("type %s not found", eventName)
+}
+
+// DecodedLog pairs a raw log with its Borsh-decoded event payload of type T.
+type DecodedLog[T any] struct {
+	Log  *solana.Log
+	Data T
+}
+
+// EncodeAccountList must produce the same encoding expected by the solana forwarder report; currently a stub returning the zero hash.
+func EncodeAccountList(remainingAccounts []*realSolana.AccountMeta) ([32]byte, error) {
+	return [32]byte{}, nil
+}
diff --git a/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana/client.pb.go b/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana/client.pb.go
new file mode 100644
index 00000000..ecb3ef64
--- /dev/null
+++ b/cmd/generate-bindings/solana/cre-sdk-go/capabilities/blockchain/solana/client.pb.go
@@ -0,0 +1,473 @@
+package solana
+
+import (
+	"bytes"
+	"fmt"
+	"time"
+
+ "github.com/gagliardetto/anchor-go/errors" + binary "github.com/gagliardetto/binary" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + solanatypes "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/types" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/cre-sdk-go/anchorcodec" + + // solanatypes "github.com/chainlink-solana/pkg/solana/logpoller/types" + + "github.com/smartcontractkit/cre-sdk-go/cre" + "google.golang.org/protobuf/reflect/protoreflect" +) + +const ( + PublicKeyLength = 32 + SignatureLength = 64 +) + +type FilterLogTriggerRequest struct { + Address []byte + EventName string + EventSig solanatypes.EventSignature + EventIdl anchorcodec.EventIDLTypes // this is the only change + SubkeyPaths solanatypes.SubKeyPaths + SubkeyFilters []SubkeyFilterCriteria +} +type SubkeyFilterCriteria struct { + SubkeyIndex uint64 + Comparers []primitives.ValueComparator +} + +type SubKeyPathAndFilter struct { + SubkeyPath string + Value any +} + +type Log struct { + ID int64 + FilterID int64 + ChainID string + LogIndex int64 + BlockHash solanatypes.Hash + BlockNumber int64 + BlockTimestamp time.Time + Address solanatypes.PublicKey + EventSig solanatypes.EventSignature + SubkeyValues solanatypes.IndexedValues + TxHash solanatypes.Signature + Data []byte + CreatedAt time.Time + ExpiresAt *time.Time + SequenceNum int64 + Error *string +} + +func LogTrigger(chainSelector uint64, config *FilterLogTriggerRequest) cre.Trigger[*Log, *Log] { + return nil +} + +func (*Log) ProtoMessage() {} + +func (x *Log) ProtoReflect() protoreflect.Message { + return nil // not implemented +} + +type ForwarderReport struct { + AccountHash [32]byte `json:"account_hash"` + Payload []byte `json:"payload"` +} + +func (obj ForwarderReport) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `AccountHash`: + err = encoder.Encode(obj.AccountHash) + if err != nil { + return errors.NewField("AccountHash", err) + } + // 
Serialize `Payload`: + err = encoder.Encode(obj.Payload) + if err != nil { + return errors.NewField("Payload", err) + } + return nil +} + +func (obj ForwarderReport) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding ForwarderReport: %w", err) + } + return buf.Bytes(), nil +} + +// type Client struct { +// ChainSelector uint64 +// // TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-799 allow defaults for capabilities +// } + +// func (c *Client) GetAccountInfoWithOpts(runtime cre.Runtime, req *GetAccountInfoRequest) cre.Promise[*GetAccountInfoReply] { +// wrapped := &anypb.Any{} + +// capCallResponse := cre.Then(runtime.CallCapability(&sdkpb.CapabilityRequest{ +// Id: "solana" + ":ChainSelector:" + strconv.FormatUint(c.ChainSelector, 10) + "@1.0.0", +// Payload: wrapped, +// Method: "GetAccountInfoWithOpts", +// }), func(i *sdkpb.CapabilityResponse) (*GetAccountInfoReply, error) { +// switch payload := i.Response.(type) { +// case *sdkpb.CapabilityResponse_Error: +// return nil, errors.New(payload.Error) +// case *sdkpb.CapabilityResponse_Payload: +// output := &GetAccountInfoReply{} +// err := payload.Payload.UnmarshalTo(output) +// return output, err +// default: +// return nil, errors.New("unexpected response type") +// } +// }) + +// return capCallResponse +// } + +// func (c *Client) GetMultipleAccountsWithOpts(runtime cre.Runtime, req GetMultipleAccountsRequest) cre.Promise[*GetMultipleAccountsReply] { +// return cre.PromiseFromResult[*GetMultipleAccountsReply](nil, nil) +// } + +// func (c *Client) SimulateTX(runtime cre.Runtime, input *SimulateTXRequest) cre.Promise[*SimulateTXReply] { +// return cre.PromiseFromResult[*SimulateTXReply](nil, nil) +// } + +// func (c *Client) WriteReport(runtime cre.Runtime, input *WriteCreReportRequest) cre.Promise[*WriteReportReply] { +// wrapped := &anypb.Any{} + 
+// capCallResponse := cre.Then(runtime.CallCapability(&sdkpb.CapabilityRequest{ +// Id: "solana" + ":ChainSelector:" + strconv.FormatUint(c.ChainSelector, 10) + "@1.0.0", +// Payload: wrapped, +// Method: "WriteReport", +// }), func(i *sdkpb.CapabilityResponse) (*WriteReportReply, error) { +// switch payload := i.Response.(type) { +// case *sdkpb.CapabilityResponse_Error: +// return nil, errors.New(payload.Error) +// case *sdkpb.CapabilityResponse_Payload: +// output := &WriteReportReply{} +// err := payload.Payload.UnmarshalTo(output) +// return output, err +// default: +// return nil, errors.New("unexpected response type") +// } +// }) + +// return capCallResponse +// } + +// type SimulateTXRequest struct { +// Receiver solanatypes.PublicKey +// EncodedTransaction []byte +// Opts *SimulateTXOpts +// } + +// type SimulateTXOpts struct { +// // If true the transaction signatures will be verified +// // (default: false, conflicts with ReplaceRecentBlockhash) +// SigVerify bool + +// // Commitment level to simulate the transaction at. +// // (default: "finalized"). +// Commitment CommitmentType + +// // If true the transaction recent blockhash will be replaced with the most recent blockhash. +// // (default: false, conflicts with SigVerify) +// ReplaceRecentBlockhash bool + +// Accounts *SimulateTransactionAccountsOpts +// } + +// type SimulateTransactionAccountsOpts struct { +// // (optional) Encoding for returned Account data, +// // either "base64" (default), "base64+zstd" or "jsonParsed". +// // - "jsonParsed" encoding attempts to use program-specific state parsers +// // to return more human-readable and explicit account state data. +// // If "jsonParsed" is requested but a parser cannot be found, +// // the field falls back to binary encoding, detectable when +// // the data field is type . +// Encoding EncodingType + +// // An array of accounts to return. 
+// Addresses []solanatypes.PublicKey +// } + +// type SimulateTXReply struct { +// // Error if transaction failed, null if transaction succeeded. +// Err *string + +// // Array of log messages the transaction instructions output during execution, +// // null if simulation failed before the transaction was able to execute +// // (for example due to an invalid blockhash or signature verification failure) +// Logs []string +// // Array of accounts with the same length as the accounts.addresses array in the request. +// Accounts []*Account + +// // The number of compute budget units consumed during the processing of this transaction. +// UnitsConsumed *uint64 +// } + +// represents solana-go EncodingType +// type EncodingType string + +// const ( +// EncodingBase58 EncodingType = "base58" // limited to Account data of less than 129 bytes +// EncodingBase64 EncodingType = "base64" // will return base64 encoded data for Account data of any size +// EncodingBase64Zstd EncodingType = "base64+zstd" // compresses the Account data using Zstandard and base64-encodes the result + +// // attempts to use program-specific state parsers to +// // return more human-readable and explicit account state data. +// // If "jsonParsed" is requested but a parser cannot be found, +// // the field falls back to "base64" encoding, detectable when the data field is type . +// // Cannot be used if specifying dataSlice parameters (offset, length). +// EncodingJSONParsed EncodingType = "jsonParsed" + +// EncodingJSON EncodingType = "json" // NOTE: you're probably looking for EncodingJSONParsed +// ) + +// represents solana-go CommitmentType +// type CommitmentType string + +// const ( +// // The node will query the most recent block confirmed by supermajority +// // of the cluster as having reached maximum lockout, +// // meaning the cluster has recognized this block as finalized. 
+// CommitmentFinalized CommitmentType = "finalized" + +// // The node will query the most recent block that has been voted on by supermajority of the cluster. +// // - It incorporates votes from gossip and replay. +// // - It does not count votes on descendants of a block, only direct votes on that block. +// // - This confirmation level also upholds "optimistic confirmation" guarantees in release 1.3 and onwards. +// CommitmentConfirmed CommitmentType = "confirmed" + +// // The node will query its most recent block. Note that the block may still be skipped by the cluster. +// CommitmentProcessed CommitmentType = "processed" +// ) + +// type GetAccountInfoRequest struct { +// Account solanatypes.PublicKey +// Opts *GetAccountInfoOpts +// } + +// func (*GetAccountInfoRequest) ProtoMessage() {} + +// func (x *GetAccountInfoRequest) ProtoReflect() protoreflect.Message { +// var file_capabilities_blockchain_evm_v1alpha_client_proto_msgTypes = make([]protoimpl.MessageInfo, 26) +// mi := &file_capabilities_blockchain_evm_v1alpha_client_proto_msgTypes[2] +// if x != nil { +// ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) +// if ms.LoadMessageInfo() == nil { +// ms.StoreMessageInfo(mi) +// } +// return ms +// } +// return mi.MessageOf(x) +// } + +// type GetAccountInfoReply struct { +// RPCContext +// Value *Account +// } + +// func (x *GetAccountInfoReply) ProtoReflect() protoreflect.Message { +// var file_capabilities_blockchain_evm_v1alpha_client_proto_msgTypes = make([]protoimpl.MessageInfo, 26) +// mi := &file_capabilities_blockchain_evm_v1alpha_client_proto_msgTypes[2] +// if x != nil { +// ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) +// if ms.LoadMessageInfo() == nil { +// ms.StoreMessageInfo(mi) +// } +// return ms +// } +// return mi.MessageOf(x) +// } + +// type GetMultipleAccountsRequest struct { +// Accounts []solanatypes.PublicKey +// Opts *GetMultipleAccountsOpts +// } + +// type GetMultipleAccountsReply struct { +// RPCContext +// Value 
[]*Account +// } + +// represents solana-go PublicKey +// type PublicKey [PublicKeyLength]byte + +// // represents solana-go Signature +// type Signature [SignatureLength]byte + +// // represents solana-go Hash +// type Hash PublicKey + +// represents solana-go AccountsMeta +// type AccountMeta struct { +// PublicKey PublicKey +// IsWritable bool +// IsSigner bool +// } + +// represents solana-go AccountMetaSlice +// type AccountMetaSlice []*AccountMeta + +// represents solana-go DataSlice +// type DataSlice struct { +// Offset *uint64 +// Length *uint64 +// } + +// represents solana-go GetAccountInfoOpts +// type GetAccountInfoOpts struct { +// // Encoding for Account data. +// // Either "base58" (slow), "base64", "base64+zstd", or "jsonParsed". +// // - "base58" is limited to Account data of less than 129 bytes. +// // - "base64" will return base64 encoded data for Account data of any size. +// // - "base64+zstd" compresses the Account data using Zstandard and base64-encodes the result. +// // - "jsonParsed" encoding attempts to use program-specific state parsers to return more +// // human-readable and explicit account state data. If "jsonParsed" is requested but a parser +// // cannot be found, the field falls back to "base64" encoding, +// // detectable when the data field is type . +// // +// // This parameter is optional. +// Encoding EncodingType + +// // Commitment requirement. +// // +// // This parameter is optional. Default value is Finalized +// Commitment CommitmentType + +// // dataSlice parameters for limiting returned account data: +// // Limits the returned account data using the provided offset and length fields; +// // only available for "base58", "base64" or "base64+zstd" encodings. +// // +// // This parameter is optional. +// DataSlice *DataSlice + +// // The minimum slot that the request can be evaluated at. +// // This parameter is optional. 
+// MinContextSlot *uint64 +// } + +// type Context struct { +// Slot uint64 +// } + +// // represents solana-go RPCContext +// type RPCContext struct { +// Context Context +// } + +// type DataBytesOrJSON struct { +// RawDataEncoding EncodingType +// AsDecodedBinary []byte +// AsJSON []byte +// } + +// // represents solana-go Account +// type Account struct { +// // Number of lamports assigned to this account +// Lamports uint64 + +// // Pubkey of the program this account has been assigned to +// Owner PublicKey + +// // Data associated with the account, either as encoded binary data or JSON format {: }, depending on encoding parameter +// Data *DataBytesOrJSON + +// // Boolean indicating if the account contains a program (and is strictly read-only) +// Executable bool + +// // The epoch at which this account will next owe rent +// RentEpoch *big.Int + +// // The amount of storage space required to store the token account +// Space uint64 +// } + +// represents solana-go TransactionDetailsType +// type TransactionDetailsType string + +// const ( +// TransactionDetailsFull TransactionDetailsType = "full" +// TransactionDetailsSignatures TransactionDetailsType = "signatures" +// TransactionDetailsNone TransactionDetailsType = "none" +// TransactionDetailsAccounts TransactionDetailsType = "accounts" +// ) + +// type TransactionVersion int + +// const ( +// LegacyTransactionVersion TransactionVersion = -1 +// legacyVersion = `"legacy"` +// ) + +// type ConfirmationStatusType string + +// const ( +// ConfirmationStatusProcessed ConfirmationStatusType = "processed" +// ConfirmationStatusConfirmed ConfirmationStatusType = "confirmed" +// ConfirmationStatusFinalized ConfirmationStatusType = "finalized" +// ) + +// type TransactionWithMeta struct { +// // The slot this transaction was processed in. +// Slot uint64 + +// // Estimated production time, as Unix timestamp (seconds since the Unix epoch) +// // of when the transaction was processed. +// // Nil if not available. 
+// BlockTime *UnixTimeSeconds + +// Transaction *DataBytesOrJSON +// // JSON encoded solana-go TransactionMeta +// MetaJSON []byte + +// Version TransactionVersion +// } + +// represents solana-go GetBlockOpts +// type GetBlockOpts struct { +// // Encoding for each returned Transaction, either "json", "jsonParsed", "base58" (slow), "base64". +// // If parameter not provided, the default encoding is "json". +// // - "jsonParsed" encoding attempts to use program-specific instruction parsers to return +// // more human-readable and explicit data in the transaction.message.instructions list. +// // - If "jsonParsed" is requested but a parser cannot be found, the instruction falls back +// // to regular JSON encoding (accounts, data, and programIdIndex fields). +// // +// // This parameter is optional. +// Encoding EncodingType + +// // Level of transaction detail to return. +// // If parameter not provided, the default detail level is "full". +// // +// // This parameter is optional. +// TransactionDetails TransactionDetailsType + +// // Whether to populate the rewards array. +// // If parameter not provided, the default includes rewards. +// // +// // This parameter is optional. +// Rewards *bool + +// // "processed" is not supported. +// // If parameter not provided, the default is "finalized". +// // +// // This parameter is optional. +// Commitment CommitmentType + +// // Max transaction version to return in responses. +// // If the requested block contains a transaction with a higher version, an error will be returned. +// MaxSupportedTransactionVersion *uint64 +// } + +// var ( +// MaxSupportedTransactionVersion0 uint64 = 0 +// MaxSupportedTransactionVersion1 uint64 = 1 +// ) + +// // UnixTimeSeconds represents a UNIX second-resolution timestamp. 
+// type UnixTimeSeconds int64 + +// type GetMultipleAccountsOpts GetAccountInfoOpts diff --git a/cmd/generate-bindings/solana/gen.go b/cmd/generate-bindings/solana/gen.go new file mode 100644 index 00000000..b74413f3 --- /dev/null +++ b/cmd/generate-bindings/solana/gen.go @@ -0,0 +1,2 @@ +//go:generate go run ./testdata/gen +package solana diff --git a/cmd/generate-bindings/solana/gen_test.go b/cmd/generate-bindings/solana/gen_test.go new file mode 100644 index 00000000..1a97f327 --- /dev/null +++ b/cmd/generate-bindings/solana/gen_test.go @@ -0,0 +1,17 @@ +package solana_test + +import ( + "testing" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana" +) + +func TestGenerateBindings(t *testing.T) { + if err := solana.GenerateBindings( + "./testdata/contracts/idl/data_storage.json", + "data_storage", + "./testdata/data_storage", + ); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/generate-bindings/solana/testdata/contracts/idl/data_storage.json b/cmd/generate-bindings/solana/testdata/contracts/idl/data_storage.json new file mode 100644 index 00000000..2ff93a7a --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/contracts/idl/data_storage.json @@ -0,0 +1,511 @@ +{ + "address": "ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL", + "metadata": { + "name": "data_storage", + "version": "0.1.0", + "spec": "0.1.0", + "description": "Created with Anchor" + }, + "instructions": [ + { + "name": "get_multiple_reserves", + "discriminator": [ + 104, + 122, + 140, + 104, + 175, + 151, + 70, + 42 + ], + "accounts": [], + "args": [], + "returns": { + "vec": { + "defined": { + "name": "UpdateReserves" + } + } + } + }, + { + "name": "get_reserves", + "discriminator": [ + 121, + 140, + 237, + 84, + 218, + 105, + 48, + 17 + ], + "accounts": [], + "args": [], + "returns": { + "defined": { + "name": "UpdateReserves" + } + } + }, + { + "name": "get_tuple_reserves", + "discriminator": [ + 189, + 83, + 186, + 20, + 127, + 80, + 109, + 49 + ], + "accounts": [], + 
"args": [] + }, + { + "name": "initialize_data_account", + "discriminator": [ + 9, + 64, + 78, + 49, + 71, + 193, + 15, + 250 + ], + "accounts": [ + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + }, + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "system_program", + "address": "11111111111111111111111111111111" + } + ], + "args": [ + { + "name": "input", + "type": { + "defined": { + "name": "UserData" + } + } + } + ] + }, + { + "name": "log_access", + "discriminator": [ + 196, + 55, + 194, + 24, + 5, + 224, + 161, + 204 + ], + "accounts": [ + { + "name": "user", + "signer": true + } + ], + "args": [ + { + "name": "message", + "type": "string" + } + ] + }, + { + "name": "on_report", + "discriminator": [ + 214, + 173, + 18, + 221, + 173, + 148, + 151, + 208 + ], + "accounts": [ + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + }, + { + "name": "system_program", + "address": "11111111111111111111111111111111" + } + ], + "args": [ + { + "name": "_metadata", + "type": "bytes" + }, + { + "name": "payload", + "type": "bytes" + } + ] + }, + { + "name": "update_key_value_data", + "discriminator": [ + 67, + 137, + 144, + 35, + 210, + 126, + 254, + 79 + ], + "accounts": [ + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + } + 
], + "args": [ + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string" + } + ] + }, + { + "name": "update_user_data", + "discriminator": [ + 11, + 13, + 114, + 150, + 194, + 224, + 192, + 78 + ], + "accounts": [ + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + } + ], + "args": [ + { + "name": "input", + "type": { + "defined": { + "name": "UserData" + } + } + } + ] + } + ], + "accounts": [ + { + "name": "DataAccount", + "discriminator": [ + 85, + 240, + 182, + 158, + 76, + 7, + 18, + 233 + ] + } + ], + "events": [ + { + "name": "AccessLogged", + "discriminator": [ + 243, + 53, + 225, + 71, + 64, + 120, + 109, + 25 + ] + }, + { + "name": "DynamicEvent", + "discriminator": [ + 236, + 145, + 224, + 161, + 9, + 222, + 218, + 237 + ] + }, + { + "name": "NoFields", + "discriminator": [ + 160, + 156, + 94, + 85, + 77, + 122, + 98, + 240 + ] + } + ], + "errors": [ + { + "code": 6000, + "name": "DataNotFound", + "msg": "data not found" + } + ], + "types": [ + { + "name": "AccessLogged", + "type": { + "kind": "struct", + "fields": [ + { + "name": "caller", + "type": "pubkey" + }, + { + "name": "message", + "type": "string" + } + ] + } + }, + { + "name": "DataAccount", + "type": { + "kind": "struct", + "fields": [ + { + "name": "sender", + "type": "string" + }, + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string" + } + ] + } + }, + { + "name": "DynamicEvent", + "type": { + "kind": "struct", + "fields": [ + { + "name": "key", + "type": "string" + }, + { + "name": "user_data", + "type": { + "defined": { + "name": "UserData" + } + } + }, + { + "name": "sender", + "type": "string" + }, + { + "name": "metadata", + "type": "bytes" + }, + { + "name": 
"metadata_array", + "type": { + "vec": "bytes" + } + } + ] + } + }, + { + "name": "NoFields", + "type": { + "kind": "struct", + "fields": [] + } + }, + { + "name": "UpdateReserves", + "type": { + "kind": "struct", + "fields": [ + { + "name": "total_minted", + "type": "u64" + }, + { + "name": "total_reserve", + "type": "u64" + } + ] + } + }, + { + "name": "UserData", + "type": { + "kind": "struct", + "fields": [ + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string" + } + ] + } + } + ] +} \ No newline at end of file diff --git a/cmd/generate-bindings/solana/testdata/contracts/source/data_storage.rs b/cmd/generate-bindings/solana/testdata/contracts/source/data_storage.rs new file mode 100644 index 00000000..dc5b02b3 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/contracts/source/data_storage.rs @@ -0,0 +1,251 @@ +use anchor_lang::prelude::*; + +declare_id!("ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL"); + +#[program] +pub mod data_storage { + use super::*; + + // simulate + pub fn get_reserves(_ctx: Context) -> Result { + Ok(UpdateReserves { + total_minted: 100, + total_reserve: 200, + }) + } + + // simulate + pub fn get_multiple_reserves( + _ctx: Context, + ) -> Result> { + let reserves = vec![ + UpdateReserves { + total_minted: 100, + total_reserve: 200, + }, + UpdateReserves { + total_minted: 300, + total_reserve: 400, + }, + ]; + + Ok(reserves) + } + + // simulate + pub fn get_tuple_reserves(_ctx: Context) -> Result<(u64, u64)> { + Ok((100, 200)) + } + + pub fn initialize_data_account(ctx: Context, input: UserData) -> Result<()> { + ctx.accounts.data_account.sender = ctx.accounts.user.key().to_string(); + ctx.accounts.data_account.key = input.key; + ctx.accounts.data_account.value = input.value; + Ok(()) + } + + // no event + pub fn update_key_value_data( + ctx: Context, + key: String, + value: String, + ) -> Result<()> { + let acc = &mut ctx.accounts.data_account; + + acc.sender = ctx.accounts.user.key().to_string(); + 
acc.key = key.clone(); + acc.value = value.clone(); + + let user_data = UserData { + key: key.clone(), + value: value.clone(), + }; + + emit!(DynamicEvent { + key: key, + user_data: user_data, + sender: ctx.accounts.user.key().to_string(), + metadata: vec![1, 2, 3], + metadata_array: vec![], + }); + + Ok(()) + } + + pub fn update_user_data(ctx: Context, input: UserData) -> Result<()> { + let acc = &mut ctx.accounts.data_account; + + acc.sender = ctx.accounts.user.key().to_string(); + acc.key = input.key.clone(); + acc.value = input.value.clone(); + + let user_data_cloned = input.clone(); + + emit!(DynamicEvent { + key: input.key, + user_data: user_data_cloned, + sender: ctx.accounts.user.key().to_string(), + metadata: vec![1, 2, 3], + metadata_array: vec![], + }); + + Ok(()) + } + + pub fn log_access(ctx: Context, message: String) -> Result<()> { + emit!(AccessLogged { + caller: ctx.accounts.user.key(), + message, + }); + Ok(()) + } + + pub fn on_report(ctx: Context, _metadata: Vec, payload: Vec) -> Result<()> { + // decode payload into UserData + let mut bytes: &[u8] = &payload; + let user = UserData::deserialize(&mut bytes)?; // requires AnchorDeserialize on UserData + + // update mapping-equivalent: this user's PDA + let acc = &mut ctx.accounts.data_account; + acc.sender = ctx.accounts.user.key().to_string(); + acc.key = user.key.clone(); + acc.value = user.value.clone(); + + let user_cloned = user.clone(); + + // emit event + emit!(DynamicEvent { + sender: ctx.accounts.user.key().to_string(), + key: user.key, + user_data: user_cloned, + metadata: vec![1, 2, 3], + metadata_array: vec![], + }); + + Ok(()) + } + + pub fn handle_forwarder_report( + _ctx: Context, + _report: ForwarderReport, + ) -> Result<()> { + // TODO: implement forwarding logic here + Ok(()) + } +} + +// read data from here +#[account] +pub struct DataAccount { + pub sender: String, + pub key: String, + pub value: String, +} + +#[derive(Accounts)] +pub struct Initialize<'info> { + #[account( + 
init, + payer = user, + space = 8 + + (4 + 64) // sender max 64 + + (4 + 64) // key max 64 + + (4 + 256) // value max 256 + + 1, // bump + seeds = [b"data_account", user.key().as_ref()], // seed for deterministic PDA + bump + )] + pub data_account: Account<'info, DataAccount>, + + #[account(mut)] + pub user: Signer<'info>, + pub system_program: Program<'info, System>, +} + +#[derive(Accounts)] +pub struct UpdateData<'info> { + #[account(mut)] + pub user: Signer<'info>, + + // PDA: one account per user, same seeds as Initialize + #[account( + mut, + seeds = [b"data_account", user.key().as_ref()], + bump, + )] + pub data_account: Account<'info, DataAccount>, +} + +// just use to have a complex event type ? +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug, PartialEq)] +pub struct UserData { + pub key: String, + pub value: String, +} + +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug, PartialEq)] +pub struct ForwarderReport { + pub account_hash: Vec, + pub payload: Vec, +} + +#[event] +pub struct DynamicEvent { + pub key: String, + pub user_data: UserData, + pub sender: String, + pub metadata: Vec, + pub metadata_array: Vec>, +} + +#[event] +pub struct AccessLogged { + pub caller: Pubkey, + pub message: String, +} + +#[event] +pub struct NoFields {} + +#[error_code] +pub enum DataError { + #[msg("data not found")] + DataNotFound = 0, +} + +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)] +pub struct UpdateReserves { + pub total_minted: u64, + pub total_reserve: u64, +} + +// empty contexts +#[derive(Accounts)] +pub struct GetReserves {} +#[derive(Accounts)] +pub struct GetMultipleReserves {} +#[derive(Accounts)] +pub struct GetTupleReserves {} + +#[derive(Accounts)] +pub struct HandleForwarderReport {} + +#[derive(Accounts)] +pub struct LogAccess<'info> { + pub user: Signer<'info>, +} + +#[derive(Accounts)] +pub struct OnReport<'info> { + #[account(mut)] + pub user: Signer<'info>, + + #[account( + mut, + seeds = [b"data_account", 
user.key().as_ref()], + bump, + )] + pub data_account: Account<'info, DataAccount>, + + pub system_program: Program<'info, System>, +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/accounts.go b/cmd/generate-bindings/solana/testdata/data_storage/accounts.go new file mode 100644 index 00000000..f5711951 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/accounts.go @@ -0,0 +1,50 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains parsers for the accounts defined in the IDL. +// Code generated by https://github.com/smartcontractkit/cre-cli. DO NOT EDIT. + +package data_storage + +import ( + "fmt" + binary "github.com/gagliardetto/binary" +) + +func ParseAnyAccount(accountData []byte) (any, error) { + decoder := binary.NewBorshDecoder(accountData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek account discriminator: %w", err) + } + switch discriminator { + case Account_DataAccount: + value := new(DataAccount) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account as DataAccount: %w", err) + } + return value, nil + default: + return nil, fmt.Errorf("unknown discriminator: %s", binary.FormatDiscriminator(discriminator)) + } +} + +func ParseAccount_DataAccount(accountData []byte) (*DataAccount, error) { + decoder := binary.NewBorshDecoder(accountData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Account_DataAccount { + return nil, fmt.Errorf("expected discriminator %v, got %s", Account_DataAccount, binary.FormatDiscriminator(discriminator)) + } + acc := new(DataAccount) + err = acc.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account of type DataAccount: %w", err) + } + return acc, nil +} + +func 
(c *Codec) DecodeDataAccount(data []byte) (*DataAccount, error) { + return ParseAccount_DataAccount(data) +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/constants.go b/cmd/generate-bindings/solana/testdata/data_storage/constants.go new file mode 100644 index 00000000..0c192cb2 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/constants.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains constants. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/constructor.go b/cmd/generate-bindings/solana/testdata/data_storage/constructor.go new file mode 100644 index 00000000..14846baf --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/constructor.go @@ -0,0 +1,46 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains the constructor for the program. + +package data_storage + +import ( + "encoding/json" + anchorcodec "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/cre-sdk-go/anchorcodec" + solana "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" +) + +var IDL = "{\"address\":\"ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL\",\"metadata\":{\"name\":\"data_storage\",\"version\":\"0.1.0\",\"spec\":\"0.1.0\",\"description\":\"Created with 
Anchor\"},\"instructions\":[{\"name\":\"get_multiple_reserves\",\"discriminator\":[104,122,140,104,175,151,70,42],\"accounts\":[],\"args\":[],\"returns\":{\"vec\":{\"defined\":{\"name\":\"UpdateReserves\"}}}},{\"name\":\"get_reserves\",\"discriminator\":[121,140,237,84,218,105,48,17],\"accounts\":[],\"args\":[],\"returns\":{\"defined\":{\"name\":\"UpdateReserves\"}}},{\"name\":\"get_tuple_reserves\",\"discriminator\":[189,83,186,20,127,80,109,49],\"accounts\":[],\"args\":[]},{\"name\":\"initialize_data_account\",\"discriminator\":[9,64,78,49,71,193,15,250],\"accounts\":[{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}},{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"system_program\",\"address\":\"11111111111111111111111111111111\"}],\"args\":[{\"name\":\"input\",\"type\":{\"defined\":{\"name\":\"UserData\"}}}]},{\"name\":\"log_access\",\"discriminator\":[196,55,194,24,5,224,161,204],\"accounts\":[{\"name\":\"user\",\"signer\":true}],\"args\":[{\"name\":\"message\",\"type\":\"string\"}]},{\"name\":\"on_report\",\"discriminator\":[214,173,18,221,173,148,151,208],\"accounts\":[{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}},{\"name\":\"system_program\",\"address\":\"11111111111111111111111111111111\"}],\"args\":[{\"name\":\"_metadata\",\"type\":\"bytes\"},{\"name\":\"payload\",\"type\":\"bytes\"}]},{\"name\":\"update_key_value_data\",\"discriminator\":[67,137,144,35,210,126,254,79],\"accounts\":[{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}}],\"a
rgs\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":\"string\"}]},{\"name\":\"update_user_data\",\"discriminator\":[11,13,114,150,194,224,192,78],\"accounts\":[{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}}],\"args\":[{\"name\":\"input\",\"type\":{\"defined\":{\"name\":\"UserData\"}}}]}],\"accounts\":[{\"name\":\"DataAccount\",\"discriminator\":[85,240,182,158,76,7,18,233]}],\"events\":[{\"name\":\"AccessLogged\",\"discriminator\":[243,53,225,71,64,120,109,25]},{\"name\":\"DynamicEvent\",\"discriminator\":[236,145,224,161,9,222,218,237]},{\"name\":\"NoFields\",\"discriminator\":[160,156,94,85,77,122,98,240]}],\"errors\":[{\"code\":6000,\"name\":\"DataNotFound\",\"msg\":\"data not found\"}],\"types\":[{\"name\":\"AccessLogged\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"caller\",\"type\":\"pubkey\"},{\"name\":\"message\",\"type\":\"string\"}]}},{\"name\":\"DataAccount\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"sender\",\"type\":\"string\"},{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":\"string\"}]}},{\"name\":\"DynamicEvent\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"user_data\",\"type\":{\"defined\":{\"name\":\"UserData\"}}},{\"name\":\"sender\",\"type\":\"string\"},{\"name\":\"metadata\",\"type\":\"bytes\"},{\"name\":\"metadata_array\",\"type\":{\"vec\":\"bytes\"}}]}},{\"name\":\"NoFields\",\"type\":{\"kind\":\"struct\",\"fields\":[]}},{\"name\":\"UpdateReserves\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"total_minted\",\"type\":\"u64\"},{\"name\":\"total_reserve\",\"type\":\"u64\"}]}},{\"name\":\"UserData\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":\"string\"}]}}]}" + +type 
DataStorage struct { + IdlTypes *anchorcodec.IdlTypeDefSlice + client *solana.Client + Codec DataStorageCodec +} + +type Codec struct{} + +func NewDataStorage(client *solana.Client) (*DataStorage, error) { + type idlTypesStruct struct { + anchorcodec.IdlTypeDefSlice `json:"types"` + } + var idlTypes idlTypesStruct + err := json.Unmarshal([]byte(IDL), &idlTypes) + if err != nil { + return nil, err + } + return &DataStorage{ + Codec: &Codec{}, + IdlTypes: &idlTypes.IdlTypeDefSlice, + client: client, + }, nil +} + +type DataStorageCodec interface { + DecodeDataAccount(data []byte) (*DataAccount, error) + EncodeAccessLoggedStruct(in AccessLogged) ([]byte, error) + EncodeDataAccountStruct(in DataAccount) ([]byte, error) + EncodeDynamicEventStruct(in DynamicEvent) ([]byte, error) + EncodeNoFieldsStruct(in NoFields) ([]byte, error) + EncodeUpdateReservesStruct(in UpdateReserves) ([]byte, error) + EncodeUserDataStruct(in UserData) ([]byte, error) +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/discriminators.go b/cmd/generate-bindings/solana/testdata/data_storage/discriminators.go new file mode 100644 index 00000000..da7b9135 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/discriminators.go @@ -0,0 +1,28 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains the discriminators for accounts and events defined in the IDL. 
+ +package data_storage + +// Account discriminators +var ( + Account_DataAccount = [8]byte{85, 240, 182, 158, 76, 7, 18, 233} +) + +// Event discriminators +var ( + Event_AccessLogged = [8]byte{243, 53, 225, 71, 64, 120, 109, 25} + Event_DynamicEvent = [8]byte{236, 145, 224, 161, 9, 222, 218, 237} + Event_NoFields = [8]byte{160, 156, 94, 85, 77, 122, 98, 240} +) + +// Instruction discriminators +var ( + Instruction_GetMultipleReserves = [8]byte{104, 122, 140, 104, 175, 151, 70, 42} + Instruction_GetReserves = [8]byte{121, 140, 237, 84, 218, 105, 48, 17} + Instruction_GetTupleReserves = [8]byte{189, 83, 186, 20, 127, 80, 109, 49} + Instruction_InitializeDataAccount = [8]byte{9, 64, 78, 49, 71, 193, 15, 250} + Instruction_LogAccess = [8]byte{196, 55, 194, 24, 5, 224, 161, 204} + Instruction_OnReport = [8]byte{214, 173, 18, 221, 173, 148, 151, 208} + Instruction_UpdateKeyValueData = [8]byte{67, 137, 144, 35, 210, 126, 254, 79} + Instruction_UpdateUserData = [8]byte{11, 13, 114, 150, 194, 224, 192, 78} +) diff --git a/cmd/generate-bindings/solana/testdata/data_storage/errors.go b/cmd/generate-bindings/solana/testdata/data_storage/errors.go new file mode 100644 index 00000000..576b057b --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/errors.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains errors. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/events.go b/cmd/generate-bindings/solana/testdata/data_storage/events.go new file mode 100644 index 00000000..804e0344 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/events.go @@ -0,0 +1,93 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains parsers for the events defined in the IDL. 
+ +package data_storage + +import ( + "fmt" + binary "github.com/gagliardetto/binary" +) + +func ParseAnyEvent(eventData []byte) (any, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek event discriminator: %w", err) + } + switch discriminator { + case Event_AccessLogged: + value := new(AccessLogged) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event as AccessLogged: %w", err) + } + return value, nil + case Event_DynamicEvent: + value := new(DynamicEvent) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event as DynamicEvent: %w", err) + } + return value, nil + case Event_NoFields: + value := new(NoFields) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event as NoFields: %w", err) + } + return value, nil + default: + return nil, fmt.Errorf("unknown discriminator: %s", binary.FormatDiscriminator(discriminator)) + } +} + +func ParseEvent_AccessLogged(eventData []byte) (*AccessLogged, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Event_AccessLogged { + return nil, fmt.Errorf("expected discriminator %v, got %s", Event_AccessLogged, binary.FormatDiscriminator(discriminator)) + } + event := new(AccessLogged) + err = event.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event of type AccessLogged: %w", err) + } + return event, nil +} + +func ParseEvent_DynamicEvent(eventData []byte) (*DynamicEvent, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek 
discriminator: %w", err) + } + if discriminator != Event_DynamicEvent { + return nil, fmt.Errorf("expected discriminator %v, got %s", Event_DynamicEvent, binary.FormatDiscriminator(discriminator)) + } + event := new(DynamicEvent) + err = event.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event of type DynamicEvent: %w", err) + } + return event, nil +} + +func ParseEvent_NoFields(eventData []byte) (*NoFields, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Event_NoFields { + return nil, fmt.Errorf("expected discriminator %v, got %s", Event_NoFields, binary.FormatDiscriminator(discriminator)) + } + event := new(NoFields) + err = event.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event of type NoFields: %w", err) + } + return event, nil +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/fetchers.go b/cmd/generate-bindings/solana/testdata/data_storage/fetchers.go new file mode 100644 index 00000000..606a2030 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/fetchers.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains fetcher functions. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/instructions.go b/cmd/generate-bindings/solana/testdata/data_storage/instructions.go new file mode 100644 index 00000000..2bed748e --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/instructions.go @@ -0,0 +1,1181 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains instructions and instruction parsers. 
+ +package data_storage + +import ( + "bytes" + "fmt" + errors "github.com/gagliardetto/anchor-go/errors" + binary "github.com/gagliardetto/binary" + solanago "github.com/gagliardetto/solana-go" +) + +// Builds a "get_multiple_reserves" instruction. +func NewGetMultipleReservesInstruction() (solanago.Instruction, error) { + accounts__ := solanago.AccountMetaSlice{} + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + nil, + ), nil +} + +// Builds a "get_reserves" instruction. +func NewGetReservesInstruction() (solanago.Instruction, error) { + accounts__ := solanago.AccountMetaSlice{} + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + nil, + ), nil +} + +// Builds a "get_tuple_reserves" instruction. +func NewGetTupleReservesInstruction() (solanago.Instruction, error) { + accounts__ := solanago.AccountMetaSlice{} + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + nil, + ), nil +} + +// Builds a "initialize_data_account" instruction. +func NewInitializeDataAccountInstruction( + // Params: + inputParam UserData, + + // Accounts: + dataAccountAccount solanago.PublicKey, + userAccount solanago.PublicKey, + systemProgramAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_InitializeDataAccount[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `inputParam`: + err = enc__.Encode(inputParam) + if err != nil { + return nil, errors.NewField("inputParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. 
+ { + // Account 0 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + // Account 1 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 2 "system_program": Read-only, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(systemProgramAccount, false, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "log_access" instruction. +func NewLogAccessInstruction( + // Params: + messageParam string, + + // Accounts: + userAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_LogAccess[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `messageParam`: + err = enc__.Encode(messageParam) + if err != nil { + return nil, errors.NewField("messageParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Read-only, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, false, true)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "on_report" instruction. +func NewOnReportInstruction( + // Params: + metadataParam []byte, + payloadParam []byte, + + // Accounts: + userAccount solanago.PublicKey, + dataAccountAccount solanago.PublicKey, + systemProgramAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. 
+ err := enc__.WriteBytes(Instruction_OnReport[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `metadataParam`: + err = enc__.Encode(metadataParam) + if err != nil { + return nil, errors.NewField("metadataParam", err) + } + // Serialize `payloadParam`: + err = enc__.Encode(payloadParam) + if err != nil { + return nil, errors.NewField("payloadParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 1 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + // Account 2 "system_program": Read-only, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(systemProgramAccount, false, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "update_key_value_data" instruction. +func NewUpdateKeyValueDataInstruction( + // Params: + keyParam string, + valueParam string, + + // Accounts: + userAccount solanago.PublicKey, + dataAccountAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_UpdateKeyValueData[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `keyParam`: + err = enc__.Encode(keyParam) + if err != nil { + return nil, errors.NewField("keyParam", err) + } + // Serialize `valueParam`: + err = enc__.Encode(valueParam) + if err != nil { + return nil, errors.NewField("valueParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. 
+ { + // Account 0 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 1 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "update_user_data" instruction. +func NewUpdateUserDataInstruction( + // Params: + inputParam UserData, + + // Accounts: + userAccount solanago.PublicKey, + dataAccountAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_UpdateUserData[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `inputParam`: + err = enc__.Encode(inputParam) + if err != nil { + return nil, errors.NewField("inputParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 1 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +type GetMultipleReservesInstruction struct{} + +func (obj *GetMultipleReservesInstruction) GetDiscriminator() []byte { + return Instruction_GetMultipleReserves[:] +} + +// UnmarshalWithDecoder unmarshals the GetMultipleReservesInstruction from Borsh-encoded bytes prefixed with its discriminator. 
+func (obj *GetMultipleReservesInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "GetMultipleReservesInstruction", err) + } + if discriminator != Instruction_GetMultipleReserves { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "GetMultipleReservesInstruction", Instruction_GetMultipleReserves, discriminator) + } + return nil +} + +func (obj *GetMultipleReservesInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + return []uint8{}, nil +} + +func (obj *GetMultipleReservesInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + return nil +} + +func (obj *GetMultipleReservesInstruction) GetAccountKeys() []solanago.PublicKey { + return []solanago.PublicKey{} +} + +// Unmarshal unmarshals the GetMultipleReservesInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *GetMultipleReservesInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling GetMultipleReservesInstruction: %w", err) + } + return nil +} + +// UnmarshalGetMultipleReservesInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. 
+func UnmarshalGetMultipleReservesInstruction(buf []byte) (*GetMultipleReservesInstruction, error) { + obj := new(GetMultipleReservesInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type GetReservesInstruction struct{} + +func (obj *GetReservesInstruction) GetDiscriminator() []byte { + return Instruction_GetReserves[:] +} + +// UnmarshalWithDecoder unmarshals the GetReservesInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *GetReservesInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "GetReservesInstruction", err) + } + if discriminator != Instruction_GetReserves { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "GetReservesInstruction", Instruction_GetReserves, discriminator) + } + return nil +} + +func (obj *GetReservesInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + return []uint8{}, nil +} + +func (obj *GetReservesInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + return nil +} + +func (obj *GetReservesInstruction) GetAccountKeys() []solanago.PublicKey { + return []solanago.PublicKey{} +} + +// Unmarshal unmarshals the GetReservesInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *GetReservesInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling GetReservesInstruction: %w", err) + } + return nil +} + +// UnmarshalGetReservesInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. 
+func UnmarshalGetReservesInstruction(buf []byte) (*GetReservesInstruction, error) { + obj := new(GetReservesInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type GetTupleReservesInstruction struct{} + +func (obj *GetTupleReservesInstruction) GetDiscriminator() []byte { + return Instruction_GetTupleReserves[:] +} + +// UnmarshalWithDecoder unmarshals the GetTupleReservesInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *GetTupleReservesInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "GetTupleReservesInstruction", err) + } + if discriminator != Instruction_GetTupleReserves { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "GetTupleReservesInstruction", Instruction_GetTupleReserves, discriminator) + } + return nil +} + +func (obj *GetTupleReservesInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + return []uint8{}, nil +} + +func (obj *GetTupleReservesInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + return nil +} + +func (obj *GetTupleReservesInstruction) GetAccountKeys() []solanago.PublicKey { + return []solanago.PublicKey{} +} + +// Unmarshal unmarshals the GetTupleReservesInstruction from Borsh-encoded bytes prefixed with the discriminator. 
+func (obj *GetTupleReservesInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling GetTupleReservesInstruction: %w", err) + } + return nil +} + +// UnmarshalGetTupleReservesInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalGetTupleReservesInstruction(buf []byte) (*GetTupleReservesInstruction, error) { + obj := new(GetTupleReservesInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type InitializeDataAccountInstruction struct { + Input UserData `json:"input"` + + // Accounts: + DataAccount solanago.PublicKey `json:"data_account"` + DataAccountWritable bool `json:"data_account_writable"` + User solanago.PublicKey `json:"user"` + UserWritable bool `json:"user_writable"` + UserSigner bool `json:"user_signer"` + SystemProgram solanago.PublicKey `json:"system_program"` +} + +func (obj *InitializeDataAccountInstruction) GetDiscriminator() []byte { + return Instruction_InitializeDataAccount[:] +} + +// UnmarshalWithDecoder unmarshals the InitializeDataAccountInstruction from Borsh-encoded bytes prefixed with its discriminator. 
+func (obj *InitializeDataAccountInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + var err error + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "InitializeDataAccountInstruction", err) + } + if discriminator != Instruction_InitializeDataAccount { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "InitializeDataAccountInstruction", Instruction_InitializeDataAccount, discriminator) + } + // Deserialize `Input`: + err = decoder.Decode(&obj.Input) + if err != nil { + return err + } + return nil +} + +func (obj *InitializeDataAccountInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + // UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes + decoder := binary.NewBorshDecoder(buf) + indices := make([]uint8, 0) + index := uint8(0) + var err error + // Decode from data_account account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err) + } + indices = append(indices, index) + // Decode from user account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err) + } + indices = append(indices, index) + // Decode from system_program account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "system_program", err) + } + indices = append(indices, index) + return indices, nil +} + +func (obj *InitializeDataAccountInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + // PopulateFromAccountIndices sets account public keys from indices and account keys array + if len(indices) != 3 { + return 
fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 3, len(indices)) + } + indexOffset := 0 + // Set data_account account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1) + } + obj.DataAccount = accountKeys[indices[indexOffset]] + indexOffset++ + // Set user account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1) + } + obj.User = accountKeys[indices[indexOffset]] + indexOffset++ + // Set system_program account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "system_program", len(accountKeys)-1) + } + obj.SystemProgram = accountKeys[indices[indexOffset]] + indexOffset++ + return nil +} + +func (obj *InitializeDataAccountInstruction) GetAccountKeys() []solanago.PublicKey { + keys := make([]solanago.PublicKey, 0) + keys = append(keys, obj.DataAccount) + keys = append(keys, obj.User) + keys = append(keys, obj.SystemProgram) + return keys +} + +// Unmarshal unmarshals the InitializeDataAccountInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *InitializeDataAccountInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling InitializeDataAccountInstruction: %w", err) + } + return nil +} + +// UnmarshalInitializeDataAccountInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. 
+func UnmarshalInitializeDataAccountInstruction(buf []byte) (*InitializeDataAccountInstruction, error) {
+	obj := new(InitializeDataAccountInstruction)
+	var err error
+	err = obj.Unmarshal(buf)
+	if err != nil {
+		return nil, err
+	}
+	return obj, nil
+}
+
+type LogAccessInstruction struct {
+	Message string `json:"message"`
+
+	// Accounts:
+	User       solanago.PublicKey `json:"user"`
+	UserSigner bool               `json:"user_signer"`
+}
+
+func (obj *LogAccessInstruction) GetDiscriminator() []byte {
+	return Instruction_LogAccess[:]
+}
+
+// UnmarshalWithDecoder unmarshals the LogAccessInstruction from Borsh-encoded bytes prefixed with its discriminator.
+func (obj *LogAccessInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error {
+	var err error
+	// Read the discriminator and check it against the expected value:
+	discriminator, err := decoder.ReadDiscriminator()
+	if err != nil {
+		return fmt.Errorf("failed to read instruction discriminator for %s: %w", "LogAccessInstruction", err)
+	}
+	if discriminator != Instruction_LogAccess {
+		return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "LogAccessInstruction", Instruction_LogAccess, discriminator)
+	}
+	// Deserialize `Message`:
+	err = decoder.Decode(&obj.Message)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (obj *LogAccessInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) {
+	// UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes
+	decoder := binary.NewBorshDecoder(buf)
+	indices := make([]uint8, 0)
+	index := uint8(0)
+	var err error
+	// Decode from user account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err)
+	}
+	indices = append(indices, index)
+	return indices, nil
+}
+
+func (obj *LogAccessInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error {
+	// PopulateFromAccountIndices sets account public keys from indices and account keys array
+	if len(indices) != 1 {
+		return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 1, len(indices))
+	}
+	indexOffset := 0
+	// Set user account from index
+	if int(indices[indexOffset]) >= len(accountKeys) { // compare in int: uint8(len(accountKeys)) truncates when len > 255
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1)
+	}
+	obj.User = accountKeys[indices[indexOffset]]
+	indexOffset++
+	return nil
+}
+
+func (obj *LogAccessInstruction) GetAccountKeys() []solanago.PublicKey {
+	keys := make([]solanago.PublicKey, 0)
+	keys = append(keys, obj.User)
+	return keys
+}
+
+// Unmarshal unmarshals the LogAccessInstruction from Borsh-encoded bytes prefixed with the discriminator.
+func (obj *LogAccessInstruction) Unmarshal(buf []byte) error {
+	var err error
+	err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf))
+	if err != nil {
+		return fmt.Errorf("error while unmarshaling LogAccessInstruction: %w", err)
+	}
+	return nil
+}
+
+// UnmarshalLogAccessInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator.
+func UnmarshalLogAccessInstruction(buf []byte) (*LogAccessInstruction, error) {
+	obj := new(LogAccessInstruction)
+	var err error
+	err = obj.Unmarshal(buf)
+	if err != nil {
+		return nil, err
+	}
+	return obj, nil
+}
+
+type OnReportInstruction struct {
+	Metadata []byte `json:"_metadata"`
+	Payload  []byte `json:"payload"`
+
+	// Accounts:
+	User                solanago.PublicKey `json:"user"`
+	UserWritable        bool               `json:"user_writable"`
+	UserSigner          bool               `json:"user_signer"`
+	DataAccount         solanago.PublicKey `json:"data_account"`
+	DataAccountWritable bool               `json:"data_account_writable"`
+	SystemProgram       solanago.PublicKey `json:"system_program"`
+}
+
+func (obj *OnReportInstruction) GetDiscriminator() []byte {
+	return Instruction_OnReport[:]
+}
+
+// UnmarshalWithDecoder unmarshals the OnReportInstruction from Borsh-encoded bytes prefixed with its discriminator.
+func (obj *OnReportInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error {
+	var err error
+	// Read the discriminator and check it against the expected value:
+	discriminator, err := decoder.ReadDiscriminator()
+	if err != nil {
+		return fmt.Errorf("failed to read instruction discriminator for %s: %w", "OnReportInstruction", err)
+	}
+	if discriminator != Instruction_OnReport {
+		return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "OnReportInstruction", Instruction_OnReport, discriminator)
+	}
+	// Deserialize `Metadata`:
+	err = decoder.Decode(&obj.Metadata)
+	if err != nil {
+		return err
+	}
+	// Deserialize `Payload`:
+	err = decoder.Decode(&obj.Payload)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (obj *OnReportInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) {
+	// UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes
+	decoder := binary.NewBorshDecoder(buf)
+	indices := make([]uint8, 0)
+	index := uint8(0)
+	var err error
+	// Decode from user account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err)
+	}
+	indices = append(indices, index)
+	// Decode from data_account account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err)
+	}
+	indices = append(indices, index)
+	// Decode from system_program account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "system_program", err)
+	}
+	indices = append(indices, index)
+	return indices, nil
+}
+
+func (obj *OnReportInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error {
+	// PopulateFromAccountIndices sets account public keys from indices and account keys array
+	if len(indices) != 3 {
+		return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 3, len(indices))
+	}
+	indexOffset := 0
+	// Set user account from index
+	if int(indices[indexOffset]) >= len(accountKeys) { // compare in int: uint8(len(accountKeys)) truncates when len > 255
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1)
+	}
+	obj.User = accountKeys[indices[indexOffset]]
+	indexOffset++
+	// Set data_account account from index
+	if int(indices[indexOffset]) >= len(accountKeys) {
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1)
+	}
+	obj.DataAccount = accountKeys[indices[indexOffset]]
+	indexOffset++
+	// Set system_program account from index
+	if int(indices[indexOffset]) >= len(accountKeys) {
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "system_program", len(accountKeys)-1)
+	}
+	obj.SystemProgram = accountKeys[indices[indexOffset]]
+	indexOffset++
+	return nil
+}
+
+func (obj *OnReportInstruction) GetAccountKeys() []solanago.PublicKey {
+	keys := make([]solanago.PublicKey, 0)
+	keys = append(keys, obj.User)
+	keys = append(keys, obj.DataAccount)
+	keys = append(keys, obj.SystemProgram)
+	return keys
+}
+
+// Unmarshal unmarshals the OnReportInstruction from Borsh-encoded bytes prefixed with the discriminator.
+func (obj *OnReportInstruction) Unmarshal(buf []byte) error {
+	var err error
+	err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf))
+	if err != nil {
+		return fmt.Errorf("error while unmarshaling OnReportInstruction: %w", err)
+	}
+	return nil
+}
+
+// UnmarshalOnReportInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator.
+func UnmarshalOnReportInstruction(buf []byte) (*OnReportInstruction, error) {
+	obj := new(OnReportInstruction)
+	var err error
+	err = obj.Unmarshal(buf)
+	if err != nil {
+		return nil, err
+	}
+	return obj, nil
+}
+
+type UpdateKeyValueDataInstruction struct {
+	Key   string `json:"key"`
+	Value string `json:"value"`
+
+	// Accounts:
+	User                solanago.PublicKey `json:"user"`
+	UserWritable        bool               `json:"user_writable"`
+	UserSigner          bool               `json:"user_signer"`
+	DataAccount         solanago.PublicKey `json:"data_account"`
+	DataAccountWritable bool               `json:"data_account_writable"`
+}
+
+func (obj *UpdateKeyValueDataInstruction) GetDiscriminator() []byte {
+	return Instruction_UpdateKeyValueData[:]
+}
+
+// UnmarshalWithDecoder unmarshals the UpdateKeyValueDataInstruction from Borsh-encoded bytes prefixed with its discriminator.
+func (obj *UpdateKeyValueDataInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error {
+	var err error
+	// Read the discriminator and check it against the expected value:
+	discriminator, err := decoder.ReadDiscriminator()
+	if err != nil {
+		return fmt.Errorf("failed to read instruction discriminator for %s: %w", "UpdateKeyValueDataInstruction", err)
+	}
+	if discriminator != Instruction_UpdateKeyValueData {
+		return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "UpdateKeyValueDataInstruction", Instruction_UpdateKeyValueData, discriminator)
+	}
+	// Deserialize `Key`:
+	err = decoder.Decode(&obj.Key)
+	if err != nil {
+		return err
+	}
+	// Deserialize `Value`:
+	err = decoder.Decode(&obj.Value)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (obj *UpdateKeyValueDataInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) {
+	// UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes
+	decoder := binary.NewBorshDecoder(buf)
+	indices := make([]uint8, 0)
+	index := uint8(0)
+	var err error
+	// Decode from user account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err)
+	}
+	indices = append(indices, index)
+	// Decode from data_account account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err)
+	}
+	indices = append(indices, index)
+	return indices, nil
+}
+
+func (obj *UpdateKeyValueDataInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error {
+	// PopulateFromAccountIndices sets account public keys from indices and account keys array
+	if len(indices) != 2 {
+		return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 2, len(indices))
+	}
+	indexOffset := 0
+	// Set user account from index
+	if int(indices[indexOffset]) >= len(accountKeys) { // compare in int: uint8(len(accountKeys)) truncates when len > 255
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1)
+	}
+	obj.User = accountKeys[indices[indexOffset]]
+	indexOffset++
+	// Set data_account account from index
+	if int(indices[indexOffset]) >= len(accountKeys) {
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1)
+	}
+	obj.DataAccount = accountKeys[indices[indexOffset]]
+	indexOffset++
+	return nil
+}
+
+func (obj *UpdateKeyValueDataInstruction) GetAccountKeys() []solanago.PublicKey {
+	keys := make([]solanago.PublicKey, 0)
+	keys = append(keys, obj.User)
+	keys = append(keys, obj.DataAccount)
+	return keys
+}
+
+// Unmarshal unmarshals the UpdateKeyValueDataInstruction from Borsh-encoded bytes prefixed with the discriminator.
+func (obj *UpdateKeyValueDataInstruction) Unmarshal(buf []byte) error {
+	var err error
+	err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf))
+	if err != nil {
+		return fmt.Errorf("error while unmarshaling UpdateKeyValueDataInstruction: %w", err)
+	}
+	return nil
+}
+
+// UnmarshalUpdateKeyValueDataInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator.
+func UnmarshalUpdateKeyValueDataInstruction(buf []byte) (*UpdateKeyValueDataInstruction, error) {
+	obj := new(UpdateKeyValueDataInstruction)
+	var err error
+	err = obj.Unmarshal(buf)
+	if err != nil {
+		return nil, err
+	}
+	return obj, nil
+}
+
+type UpdateUserDataInstruction struct {
+	Input UserData `json:"input"`
+
+	// Accounts:
+	User                solanago.PublicKey `json:"user"`
+	UserWritable        bool               `json:"user_writable"`
+	UserSigner          bool               `json:"user_signer"`
+	DataAccount         solanago.PublicKey `json:"data_account"`
+	DataAccountWritable bool               `json:"data_account_writable"`
+}
+
+func (obj *UpdateUserDataInstruction) GetDiscriminator() []byte {
+	return Instruction_UpdateUserData[:]
+}
+
+// UnmarshalWithDecoder unmarshals the UpdateUserDataInstruction from Borsh-encoded bytes prefixed with its discriminator.
+func (obj *UpdateUserDataInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error {
+	var err error
+	// Read the discriminator and check it against the expected value:
+	discriminator, err := decoder.ReadDiscriminator()
+	if err != nil {
+		return fmt.Errorf("failed to read instruction discriminator for %s: %w", "UpdateUserDataInstruction", err)
+	}
+	if discriminator != Instruction_UpdateUserData {
+		return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "UpdateUserDataInstruction", Instruction_UpdateUserData, discriminator)
+	}
+	// Deserialize `Input`:
+	err = decoder.Decode(&obj.Input)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (obj *UpdateUserDataInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) {
+	// UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes
+	decoder := binary.NewBorshDecoder(buf)
+	indices := make([]uint8, 0)
+	index := uint8(0)
+	var err error
+	// Decode from user account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err)
+	}
+	indices = append(indices, index)
+	// Decode from data_account account index
+	index = uint8(0)
+	err = decoder.Decode(&index)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err)
+	}
+	indices = append(indices, index)
+	return indices, nil
+}
+
+func (obj *UpdateUserDataInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error {
+	// PopulateFromAccountIndices sets account public keys from indices and account keys array
+	if len(indices) != 2 {
+		return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 2, len(indices))
+	}
+	indexOffset := 0
+	// Set user account from index
+	if int(indices[indexOffset]) >= len(accountKeys) { // compare in int: uint8(len(accountKeys)) truncates when len > 255
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1)
+	}
+	obj.User = accountKeys[indices[indexOffset]]
+	indexOffset++
+	// Set data_account account from index
+	if int(indices[indexOffset]) >= len(accountKeys) {
+		return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1)
+	}
+	obj.DataAccount = accountKeys[indices[indexOffset]]
+	indexOffset++
+	return nil
+}
+
+func (obj *UpdateUserDataInstruction) GetAccountKeys() []solanago.PublicKey {
+	keys := make([]solanago.PublicKey, 0)
+	keys = append(keys, obj.User)
+	keys = append(keys, obj.DataAccount)
+	return keys
+}
+
+// Unmarshal unmarshals the UpdateUserDataInstruction from Borsh-encoded bytes prefixed with the discriminator.
+func (obj *UpdateUserDataInstruction) Unmarshal(buf []byte) error {
+	var err error
+	err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf))
+	if err != nil {
+		return fmt.Errorf("error while unmarshaling UpdateUserDataInstruction: %w", err)
+	}
+	return nil
+}
+
+// UnmarshalUpdateUserDataInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator.
+func UnmarshalUpdateUserDataInstruction(buf []byte) (*UpdateUserDataInstruction, error) {
+	obj := new(UpdateUserDataInstruction)
+	var err error
+	err = obj.Unmarshal(buf)
+	if err != nil {
+		return nil, err
+	}
+	return obj, nil
+}
+
+// Instruction interface defines common methods for all instruction types
+type Instruction interface {
+	GetDiscriminator() []byte
+
+	UnmarshalWithDecoder(decoder *binary.Decoder) error
+
+	UnmarshalAccountIndices(buf []byte) ([]uint8, error)
+
+	PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error
+
+	GetAccountKeys() []solanago.PublicKey
+}
+
+// ParseInstruction parses instruction data and optionally populates accounts
+// If accountIndicesData is nil or empty, accounts will not be populated
+func ParseInstruction(instructionData []byte, accountIndicesData []byte, accountKeys []solanago.PublicKey) (Instruction, error) {
+	// Validate inputs
+	if len(instructionData) < 8 {
+		return nil, fmt.Errorf("instruction data too short: expected at least 8 bytes, got %d", len(instructionData))
+	}
+	// Extract discriminator
+	discriminator := [8]byte{}
+	copy(discriminator[:], instructionData[0:8])
+	// Parse based on discriminator
+	switch discriminator {
+	case Instruction_GetMultipleReserves:
+		instruction := new(GetMultipleReservesInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as GetMultipleReservesInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 { // len(nil) == 0, so a separate nil check is redundant (staticcheck S1009)
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_GetReserves:
+		instruction := new(GetReservesInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as GetReservesInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_GetTupleReserves:
+		instruction := new(GetTupleReservesInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as GetTupleReservesInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_InitializeDataAccount:
+		instruction := new(InitializeDataAccountInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as InitializeDataAccountInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_LogAccess:
+		instruction := new(LogAccessInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as LogAccessInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_OnReport:
+		instruction := new(OnReportInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as OnReportInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_UpdateKeyValueData:
+		instruction := new(UpdateKeyValueDataInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as UpdateKeyValueDataInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	case Instruction_UpdateUserData:
+		instruction := new(UpdateUserDataInstruction)
+		decoder := binary.NewBorshDecoder(instructionData)
+		err := instruction.UnmarshalWithDecoder(decoder)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal instruction as UpdateUserDataInstruction: %w", err)
+		}
+		if len(accountIndicesData) > 0 {
+			indices, err := instruction.UnmarshalAccountIndices(accountIndicesData)
+			if err != nil {
+				return nil, fmt.Errorf("failed to unmarshal account indices: %w", err)
+			}
+			err = instruction.PopulateFromAccountIndices(indices, accountKeys)
+			if err != nil {
+				return nil, fmt.Errorf("failed to populate accounts: %w", err)
+			}
+		}
+		return instruction, nil
+	default:
+		return nil, fmt.Errorf("unknown instruction discriminator: %s", binary.FormatDiscriminator(discriminator))
+	}
+}
+
+// ParseInstructionTyped parses instruction data and returns a specific instruction type; T must implement the Instruction interface.
+func ParseInstructionTyped[T Instruction](instructionData []byte, accountIndicesData []byte, accountKeys []solanago.PublicKey) (T, error) {
+	instruction, err := ParseInstruction(instructionData, accountIndicesData, accountKeys)
+	if err != nil {
+		return *new(T), err
+	}
+	typed, ok := instruction.(T)
+	if !ok {
+		return *new(T), fmt.Errorf("instruction is not of expected type")
+	}
+	return typed, nil
+}
+
+// ParseInstructionWithoutAccounts parses instruction data without account information
+func ParseInstructionWithoutAccounts(instructionData []byte) (Instruction, error) {
+	return ParseInstruction(instructionData, nil, []solanago.PublicKey{})
+}
+
+// ParseInstructionWithAccounts parses instruction data with account information
+func ParseInstructionWithAccounts(instructionData []byte, accountIndicesData []byte, accountKeys []solanago.PublicKey) (Instruction, error) {
+	return ParseInstruction(instructionData, accountIndicesData, accountKeys)
+}
diff --git a/cmd/generate-bindings/solana/testdata/data_storage/program_id.go b/cmd/generate-bindings/solana/testdata/data_storage/program_id.go
new file mode 100644
index 00000000..1e0b7950
--- /dev/null
+++ b/cmd/generate-bindings/solana/testdata/data_storage/program_id.go
@@ -0,0 +1,8 @@
+// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.
+// This file contains the program ID.
+
+package data_storage
+
+import solanago "github.com/gagliardetto/solana-go"
+
+var ProgramID = solanago.MustPublicKeyFromBase58("ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL")
diff --git a/cmd/generate-bindings/solana/testdata/data_storage/tests_test.go b/cmd/generate-bindings/solana/testdata/data_storage/tests_test.go
new file mode 100644
index 00000000..704cda06
--- /dev/null
+++ b/cmd/generate-bindings/solana/testdata/data_storage/tests_test.go
@@ -0,0 +1,4 @@
+// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.
+// This file contains tests.
+
+package data_storage
diff --git a/cmd/generate-bindings/solana/testdata/data_storage/types.go b/cmd/generate-bindings/solana/testdata/data_storage/types.go
new file mode 100644
index 00000000..29470c23
--- /dev/null
+++ b/cmd/generate-bindings/solana/testdata/data_storage/types.go
@@ -0,0 +1,643 @@
+// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.
+// This file contains parsers for the types defined in the IDL.
+ +package data_storage + +import ( + "bytes" + "fmt" + errors "github.com/gagliardetto/anchor-go/errors" + binary "github.com/gagliardetto/binary" + solanago "github.com/gagliardetto/solana-go" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + solana "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + bindings "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/bindings" + cre "github.com/smartcontractkit/cre-sdk-go/cre" +) + +type AccessLogged struct { + Caller solanago.PublicKey `json:"caller"` + Message string `json:"message"` +} + +func (obj AccessLogged) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Caller`: + err = encoder.Encode(obj.Caller) + if err != nil { + return errors.NewField("Caller", err) + } + // Serialize `Message`: + err = encoder.Encode(obj.Message) + if err != nil { + return errors.NewField("Message", err) + } + return nil +} + +func (obj AccessLogged) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding AccessLogged: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *AccessLogged) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Caller`: + err = decoder.Decode(&obj.Caller) + if err != nil { + return errors.NewField("Caller", err) + } + // Deserialize `Message`: + err = decoder.Decode(&obj.Message) + if err != nil { + return errors.NewField("Message", err) + } + return nil +} + +func (obj *AccessLogged) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling AccessLogged: %w", err) + } + return nil +} + +func UnmarshalAccessLogged(buf []byte) (*AccessLogged, error) { + obj := new(AccessLogged) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil 
+} + +func (c *Codec) EncodeAccessLoggedStruct(in AccessLogged) ([]byte, error) { + return in.Marshal() +} + +func (c *DataStorage) WriteReportFromAccessLogged( + runtime cre.Runtime, + input AccessLogged, + remainingAccounts []*solana.AccountMeta, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeAccessLoggedStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "sha256", + SigningAlgo: "ed25519", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +type DataAccount struct { + Sender string `json:"sender"` + Key string `json:"key"` + Value string `json:"value"` +} + +func (obj DataAccount) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Sender`: + err = encoder.Encode(obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Serialize `Key`: + err = encoder.Encode(obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Serialize `Value`: + err = encoder.Encode(obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj DataAccount) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, 
fmt.Errorf("error while encoding DataAccount: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *DataAccount) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Sender`: + err = decoder.Decode(&obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Deserialize `Value`: + err = decoder.Decode(&obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj *DataAccount) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling DataAccount: %w", err) + } + return nil +} + +func UnmarshalDataAccount(buf []byte) (*DataAccount, error) { + obj := new(DataAccount) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeDataAccountStruct(in DataAccount) ([]byte, error) { + return in.Marshal() +} + +func (c *DataStorage) WriteReportFromDataAccount( + runtime cre.Runtime, + input DataAccount, + remainingAccounts []*solana.AccountMeta, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeDataAccountStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "sha256", + SigningAlgo: "ed25519", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + 
return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +type DynamicEvent struct { + Key string `json:"key"` + UserData UserData `json:"user_data"` + Sender string `json:"sender"` + Metadata []byte `json:"metadata"` + MetadataArray [][]byte `json:"metadata_array"` +} + +func (obj DynamicEvent) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Key`: + err = encoder.Encode(obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Serialize `UserData`: + err = encoder.Encode(obj.UserData) + if err != nil { + return errors.NewField("UserData", err) + } + // Serialize `Sender`: + err = encoder.Encode(obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Serialize `Metadata`: + err = encoder.Encode(obj.Metadata) + if err != nil { + return errors.NewField("Metadata", err) + } + // Serialize `MetadataArray`: + err = encoder.Encode(obj.MetadataArray) + if err != nil { + return errors.NewField("MetadataArray", err) + } + return nil +} + +func (obj DynamicEvent) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding DynamicEvent: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *DynamicEvent) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Deserialize `UserData`: + err = decoder.Decode(&obj.UserData) + if err != nil { + return errors.NewField("UserData", err) + } + // Deserialize `Sender`: + err = decoder.Decode(&obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Deserialize `Metadata`: + err = decoder.Decode(&obj.Metadata) + if err != nil { + return errors.NewField("Metadata", err) + } 
+ // Deserialize `MetadataArray`: + err = decoder.Decode(&obj.MetadataArray) + if err != nil { + return errors.NewField("MetadataArray", err) + } + return nil +} + +func (obj *DynamicEvent) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling DynamicEvent: %w", err) + } + return nil +} + +func UnmarshalDynamicEvent(buf []byte) (*DynamicEvent, error) { + obj := new(DynamicEvent) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeDynamicEventStruct(in DynamicEvent) ([]byte, error) { + return in.Marshal() +} + +func (c *DataStorage) WriteReportFromDynamicEvent( + runtime cre.Runtime, + input DynamicEvent, + remainingAccounts []*solana.AccountMeta, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeDynamicEventStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "sha256", + SigningAlgo: "ed25519", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +type NoFields struct{} + +func (obj NoFields) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + return nil +} + +func (obj NoFields) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := 
binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding NoFields: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *NoFields) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + return nil +} + +func (obj *NoFields) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling NoFields: %w", err) + } + return nil +} + +func UnmarshalNoFields(buf []byte) (*NoFields, error) { + obj := new(NoFields) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeNoFieldsStruct(in NoFields) ([]byte, error) { + return in.Marshal() +} + +func (c *DataStorage) WriteReportFromNoFields( + runtime cre.Runtime, + input NoFields, + remainingAccounts []*solana.AccountMeta, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeNoFieldsStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "sha256", + SigningAlgo: "ed25519", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +type UpdateReserves struct { + TotalMinted uint64 `json:"total_minted"` + TotalReserve uint64 `json:"total_reserve"` +} + +func 
(obj UpdateReserves) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `TotalMinted`: + err = encoder.Encode(obj.TotalMinted) + if err != nil { + return errors.NewField("TotalMinted", err) + } + // Serialize `TotalReserve`: + err = encoder.Encode(obj.TotalReserve) + if err != nil { + return errors.NewField("TotalReserve", err) + } + return nil +} + +func (obj UpdateReserves) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding UpdateReserves: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *UpdateReserves) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `TotalMinted`: + err = decoder.Decode(&obj.TotalMinted) + if err != nil { + return errors.NewField("TotalMinted", err) + } + // Deserialize `TotalReserve`: + err = decoder.Decode(&obj.TotalReserve) + if err != nil { + return errors.NewField("TotalReserve", err) + } + return nil +} + +func (obj *UpdateReserves) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling UpdateReserves: %w", err) + } + return nil +} + +func UnmarshalUpdateReserves(buf []byte) (*UpdateReserves, error) { + obj := new(UpdateReserves) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeUpdateReservesStruct(in UpdateReserves) ([]byte, error) { + return in.Marshal() +} + +func (c *DataStorage) WriteReportFromUpdateReserves( + runtime cre.Runtime, + input UpdateReserves, + remainingAccounts []*solana.AccountMeta, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeUpdateReservesStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := 
bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "sha256", + SigningAlgo: "ed25519", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +type UserData struct { + Key string `json:"key"` + Value string `json:"value"` +} + +func (obj UserData) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Key`: + err = encoder.Encode(obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Serialize `Value`: + err = encoder.Encode(obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj UserData) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding UserData: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *UserData) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Deserialize `Value`: + err = decoder.Decode(&obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj *UserData) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling UserData: %w", err) + } + return nil +} + 
+func UnmarshalUserData(buf []byte) (*UserData, error) { + obj := new(UserData) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeUserDataStruct(in UserData) ([]byte, error) { + return in.Marshal() +} + +func (c *DataStorage) WriteReportFromUserData( + runtime cre.Runtime, + input UserData, + remainingAccounts []*solana.AccountMeta, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeUserDataStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "sha256", + SigningAlgo: "ed25519", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} diff --git a/cmd/generate-bindings/solana/testdata/gen/main.go b/cmd/generate-bindings/solana/testdata/gen/main.go new file mode 100644 index 00000000..b8630d7c --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/gen/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana" +) + +func main() { + if err := solana.GenerateBindings( + "./testdata/data_storage", + "data_storage", + "./testdata/data_storage.json", + "ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL", + ); err != nil { + panic(err) + } + +} diff --git a/cmd/root.go b/cmd/root.go index 51af5fb8..fd420731 100644 
--- a/cmd/root.go +++ b/cmd/root.go @@ -213,7 +213,8 @@ func newRootCommand() *cobra.Command { loginCmd := login.New(runtimeContext) logoutCmd := logout.New(runtimeContext) initCmd := creinit.New(runtimeContext) - genBindingsCmd := generatebindings.New(runtimeContext) + genBindingsEvmCmd := generatebindings.NewEvmBindings(runtimeContext) + genBindingsSolanaCmd := generatebindings.NewSolanaBindings(runtimeContext) accountCmd := account.New(runtimeContext) whoamiCmd := whoami.New(runtimeContext) updateCmd := update.New(runtimeContext) @@ -247,7 +248,8 @@ func newRootCommand() *cobra.Command { whoamiCmd, secretsCmd, workflowCmd, - genBindingsCmd, + genBindingsEvmCmd, + genBindingsSolanaCmd, updateCmd, ) diff --git a/flake..nix b/flake..nix new file mode 100644 index 00000000..7055a1df --- /dev/null +++ b/flake..nix @@ -0,0 +1,17 @@ +{ + description = "Go 1.25 dev shell"; + + inputs.nixpkgs.url = "github:NixOS/nixpkgs"; + + outputs = { self, nixpkgs }: + let + system = "x86_64-linux"; # or aarch64-darwin, x86_64-darwin + pkgs = import nixpkgs { inherit system; }; + in { + devShells.${system}.default = pkgs.mkShell { + packages = [ + pkgs.go_1_25 + ]; + }; + }; +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000..cfd92af9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,26 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1767628002, + "narHash": "sha256-XMjHybP9zKNqPxFyVJmWQY1PG94c8fvSv806Vf25GrE=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c4d3bbe305980cf82115ea6b21e7ec6d524f43f1", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000..d7dd44ca --- /dev/null +++ b/flake.nix @@ -0,0 +1,18 @@ +{ + description = "cre-cli dev shell"; + + inputs.nixpkgs.url = "github:NixOS/nixpkgs"; # or nixpkgs-unstable + + 
outputs = { self, nixpkgs }: + let + system = "aarch64-darwin"; + pkgs = import nixpkgs { inherit system; }; + in + { + devShells.${system}.default = pkgs.mkShell { + packages = [ + pkgs.go_1_25 + ]; + }; + }; +} diff --git a/go.mod b/go.mod index 13413179..60b8642b 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,15 @@ require ( github.com/avast/retry-go/v4 v4.6.1 github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.6 + github.com/dave/jennifer v1.7.1 + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc github.com/denisbrodbeck/machineid v1.0.1 github.com/ethereum/go-ethereum v1.16.8 github.com/fatih/color v1.18.0 + github.com/gagliardetto/anchor-go v1.0.1-0.20250824230401-85e63d2061fb + github.com/gagliardetto/binary v0.8.0 + github.com/gagliardetto/solana-go v1.13.0 + github.com/gagliardetto/utilz v0.1.3 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.28.0 @@ -28,11 +34,13 @@ require ( github.com/smartcontractkit/chainlink-evm/gethwrappers v0.0.0-20251222115927-36a18321243c github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260206000552-087e235a7963 github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20260106052706-6dd937cb5ec6 + github.com/smartcontractkit/chainlink-solana v1.1.2-0.20260121103211-89fe83165431 github.com/smartcontractkit/chainlink-testing-framework/seth v1.51.3 github.com/smartcontractkit/chainlink/deployment v0.0.0-20260109210342-7c60a208545f github.com/smartcontractkit/chainlink/v2 v2.29.1-cre-beta.0.0.20260209203649-eeb0170a4b93 github.com/smartcontractkit/cre-sdk-go v1.2.0 github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm v1.0.0-beta.5 + github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210120110-1f2d5201a23f github.com/smartcontractkit/mcms v0.31.1 github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20251120172354-e8ec0386b06c 
github.com/spf13/cobra v1.10.1 @@ -41,6 +49,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/test-go/testify v1.1.4 go.uber.org/zap v1.27.1 + golang.org/x/mod v0.32.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 @@ -130,7 +139,6 @@ require ( github.com/crate-crypto/go-eth-kzg v1.4.0 // indirect github.com/crate-crypto/go-ipa v0.0.0-20240724233137-53bbb0ceb27a // indirect github.com/danieljoos/wincred v1.2.1 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dchest/siphash v1.2.3 // indirect github.com/deckarep/golang-set/v2 v2.7.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect @@ -152,11 +160,7 @@ require ( github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect - github.com/gagliardetto/anchor-go v1.0.0 // indirect - github.com/gagliardetto/binary v0.8.0 // indirect - github.com/gagliardetto/solana-go v1.13.0 // indirect github.com/gagliardetto/treeout v0.1.4 // indirect - github.com/gagliardetto/utilz v0.1.3 // indirect github.com/getsentry/sentry-go v0.27.0 // indirect github.com/gin-contrib/sessions v0.0.5 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -315,7 +319,6 @@ require ( github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b // indirect github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 // indirect github.com/smartcontractkit/chainlink-protos/svr v1.1.1-0.20260203131522-bb8bc5c423b3 // indirect - github.com/smartcontractkit/chainlink-solana v1.1.2-0.20260121103211-89fe83165431 // indirect github.com/smartcontractkit/chainlink-sui v0.0.0-20260124000807-bff5e296dfb7 // indirect github.com/smartcontractkit/chainlink-tron/relayer v0.0.11-0.20251014143056-a0c6328c91e9 // indirect github.com/smartcontractkit/freeport v0.1.3-0.20250716200817-cb5dfd0e369e // indirect 
@@ -385,7 +388,6 @@ require ( golang.org/x/arch v0.11.0 // indirect golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect - golang.org/x/mod v0.32.0 // indirect golang.org/x/net v0.49.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect diff --git a/go.sum b/go.sum index d281cbc4..1b5409f9 100644 --- a/go.sum +++ b/go.sum @@ -322,6 +322,8 @@ github.com/danieljoos/wincred v1.2.1 h1:dl9cBrupW8+r5250DYkYxocLeZ1Y4vB1kxgtjxw8 github.com/danieljoos/wincred v1.2.1/go.mod h1:uGaFL9fDn3OLTvzCGulzE+SzjEe5NGlh5FdCcyfPwps= github.com/danielkov/gin-helmet v0.0.0-20171108135313-1387e224435e h1:5jVSh2l/ho6ajWhSPNN84eHEdq3dp0T7+f6r3Tc6hsk= github.com/danielkov/gin-helmet v0.0.0-20171108135313-1387e224435e/go.mod h1:IJgIiGUARc4aOr4bOQ85klmjsShkEEfiRc6q/yBSfo8= +github.com/dave/jennifer v1.7.1 h1:B4jJJDHelWcDhlRQxWeo0Npa/pYKBLrirAQoTN45txo= +github.com/dave/jennifer v1.7.1/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -412,8 +414,8 @@ github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gabriel-vasile/mimetype v1.4.10 h1:zyueNbySn/z8mJZHLt6IPw0KoZsiQNszIpU+bX4+ZK0= github.com/gabriel-vasile/mimetype v1.4.10/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= -github.com/gagliardetto/anchor-go v1.0.0 h1:YNt9I/9NOrNzz5uuzfzByAcbp39Ft07w63iPqC/wi34= -github.com/gagliardetto/anchor-go v1.0.0/go.mod h1:X6c9bx9JnmwNiyy8hmV5pAsq1c/zzPvkdzeq9/qmlCg= +github.com/gagliardetto/anchor-go v1.0.1-0.20250824230401-85e63d2061fb 
h1:vEroPXDXe9mWZJSnGD3zbGfg/VxDwdcuzXXP07b3nVY= +github.com/gagliardetto/anchor-go v1.0.1-0.20250824230401-85e63d2061fb/go.mod h1:HBp4PS/YTGjRGbI2ENChy55PoSs0ZExnMH0EC7CCGMg= github.com/gagliardetto/binary v0.8.0 h1:U9ahc45v9HW0d15LoN++vIXSJyqR/pWw8DDlhd7zvxg= github.com/gagliardetto/binary v0.8.0/go.mod h1:2tfj51g5o9dnvsc+fL3Jxr22MuWzYXwx9wEoN0XQ7/c= github.com/gagliardetto/gofuzz v1.2.2 h1:XL/8qDMzcgvR4+CyRQW9UGdwPRPMHVJfqQ/uMvSUuQw= @@ -1177,6 +1179,14 @@ github.com/smartcontractkit/cre-sdk-go v1.2.0 h1:CAZkJuku0faMlhK5biRL962DNnCSyMu github.com/smartcontractkit/cre-sdk-go v1.2.0/go.mod h1:sgiRyHUiPcxp1e/EMnaJ+ddMFL4MbE3UMZ2MORAAS9U= github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm v1.0.0-beta.5 h1:XMlLU3UVAHjEGDJ2E6cYp8zlyxnctEZ6p2gz+tvMqxI= github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm v1.0.0-beta.5/go.mod h1:v/xKxzUsxkIpT1ZM77vExyNU+dkCQ/y7oXvBbn7v6yY= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.0 h1:8AzRC735Z3vCvcEyElBq8DXv884mQE67bCp6YNGI3jY= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.0/go.mod h1:zlMsfDAXcrfEOGvHhrxaq4oxQN1egrzPwOSK2d5RnwQ= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210113557-d4d7a18ac8b0 h1:PhZTlaBkQbRyMG5usUX6zTOZ1W4zAyi0uA/A8sbPbuo= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210113557-d4d7a18ac8b0/go.mod h1:zlMsfDAXcrfEOGvHhrxaq4oxQN1egrzPwOSK2d5RnwQ= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210114941-eb63f517a727 h1:F2m6TiXwGe65ePjURHktI7XBExe0hquz/EcXfQrLIMA= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210114941-eb63f517a727/go.mod h1:NlLiprUaGE/tYfh3M8m6QIwB38zIh5ntZpMqTjlSSRg= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210120110-1f2d5201a23f h1:9tY7ZZTmcI3HA/XU9cc1UtUYGVfHyDsUupkAA7qsOzM= 
+github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.1-0.20260210120110-1f2d5201a23f/go.mod h1:NlLiprUaGE/tYfh3M8m6QIwB38zIh5ntZpMqTjlSSRg= github.com/smartcontractkit/freeport v0.1.3-0.20250716200817-cb5dfd0e369e h1:Hv9Mww35LrufCdM9wtS9yVi/rEWGI1UnjHbcKKU0nVY= github.com/smartcontractkit/freeport v0.1.3-0.20250716200817-cb5dfd0e369e/go.mod h1:T4zH9R8R8lVWKfU7tUvYz2o2jMv1OpGCdpY2j2QZXzU= github.com/smartcontractkit/grpc-proxy v0.0.0-20240830132753-a7e17fec5ab7 h1:12ijqMM9tvYVEm+nR826WsrNi6zCKpwBhuApq127wHs=