diff --git a/cmd/exec.go b/cmd/exec.go index 5185631..c089497 100644 --- a/cmd/exec.go +++ b/cmd/exec.go @@ -3,10 +3,13 @@ package cmd import ( "corteca/internal/configuration" "corteca/internal/device" + _ "corteca/internal/device/cwmp" + _ "corteca/internal/device/ssh" "corteca/internal/platform" "corteca/internal/tui" "fmt" - "path/filepath" + "io" + "os" "strings" "github.com/spf13/cobra" @@ -26,73 +29,67 @@ var logFile string var publishTargetName string func init() { - execCmd.PersistentFlags().StringVarP(&specifiedArtifact, "artifact", "a", "", "Specify an artifact in the form of 'architecture:imagetype:/path/to/file', architecture=(aarch64|armv7l|x86_64), imagetype=(rootfs|oci)") execCmd.RegisterFlagCompletionFunc("artifact", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return []string{"tar.gz, tar"}, cobra.ShellCompDirectiveFilterFileExt }) rootCmd.AddCommand(execCmd) execCmd.PersistentFlags().StringVar(&logFile, "logfile", platform.DefaultLog, "Specify where SSH logs will be stored") execCmd.PersistentFlags().StringVar(&publishTargetName, "publish", "", "Publish application artifact to specified target") + execCmd.PersistentFlags().StringVarP(&artifact, "artifact", "a", "", "Specify the path to a an artifact to publish") execCmd.PersistentFlags().BoolVar(&skipLocalConfig, "global", false, "Affect global config & ignore any project-local configuration") } -func doExecSequence(sequence, deviceName string) { - if _, exists := config.Sequences[sequence]; !exists { - failOperation(fmt.Sprintf("Sequence '%s' not supported yet", sequence)) +func doExecSequence(sequencename, deviceName string) { + if devConfig, found := config.Devices[deviceName]; !found { + failOperation(fmt.Sprintf("no config for device '%s' was found", deviceName)) + } else { + configuration.GetCmdContext().Device.DeviceConfig = devConfig + configuration.GetCmdContext().Device.Name = deviceName + configuration.GetCmdContext().Arch = configuration.GetCmdContext().Device.Architecture } - requireBuildArtifact() - var found bool - configuration.GetCmdContext().Device.Name = deviceName - configuration.GetCmdContext().Device.DeviceConfig, found = config.Devices[deviceName] - configuration.GetCmdContext().Arch = configuration.GetCmdContext().Device.Architecture - if !found { - failOperation(fmt.Sprintf("device '%s' not found", deviceName)) - } - - // connect to the device console - dev, err := device.NewDevice(configuration.GetCmdContext().Device.Endpoint, logFile) - if err != nil { - failOperation(fmt.Sprintf("could not create device %s", deviceName)) - } - dispatcher, err := dev.Connect() - assertOperation("connecting to device", err) - defer dev.Close() - - if publishTargetName != "" { - if dev.GetProtocol() == device.ConnectionSSH { - containerType := device.DetectContainerFramework(dispatcher) - if containerType == "" { - failOperation("no valid container framework found on device") - } - configuration.GetCmdContext().Build.Options.OutputType = containerType + // prepare log file + var log io.WriteCloser + switch strings.ToLower(logFile) { + case "stdout": + log = os.Stdout + case "stderr": + log = os.Stderr + default: + if f, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666); err != nil { + failOperation(fmt.Sprintf("Could not create log file: %s", err.Error())) } else { - configuration.GetCmdContext().Build.Options.OutputType = "oci" + log = f + defer func() { + if err := f.Close(); err != nil { + tui.LogError("could not close log file (%s)", err.Error()) + } + }() } } - artifactKey := fmt.Sprintf("%s-%s", configuration.GetCmdContext().Arch, configuration.GetCmdContext().Build.Options.OutputType) - buildArtifact, ok := configuration.GetCmdContext().BuildArtifacts[artifactKey] - if !ok { - failOperation(fmt.Sprintf("no build artifact present for target architecture \"%s\"", configuration.GetCmdContext().Arch)) + // connect to the device console + device, err := device.NewDevice(&configuration.GetCmdContext().Device.DeviceConfig, log) + if err != nil { + failOperation(fmt.Sprintf("could not create device %s (%s)", deviceName, err.Error())) } + tui.LogNormal("Selected device '%s', protocol: %s", deviceName, device.GetProtocol()) + defer device.Close() - configuration.GetCmdContext().BuildArtifact = filepath.Base(buildArtifact) - configuration.GetCmdContext().Publish.PublishTarget = config.Publish[publishTargetName] - configuration.GetCmdContext().Publish.Name = publishTargetName // publish build artifact(s) if a publish target has been specified in the deploy source if publishTargetName != "" { - tui.LogNormal("Publishing \"%s\" artifact to \"%s\"", configuration.GetCmdContext().Arch, configuration.GetCmdContext().Publish.Name) - doPublishApp(configuration.GetCmdContext().Publish.Name, configuration.GetCmdContext().Arch, false) + configuration.GetCmdContext().Publish.PublishTarget = config.Publish[publishTargetName] + configuration.GetCmdContext().Publish.Name = publishTargetName + tui.LogNormal("Publishing artifact to '%s'", configuration.GetCmdContext().Publish.Name) + doPublishApp(configuration.GetCmdContext().Publish.Name, false) } // execute the sequence - tui.LogNormal("Deploying %s...", buildArtifact) - if err = config.Sequences.Execute(dispatcher, sequence); err != nil { - tui.LogError("Error while %v: %v", "executing "+sequence+" sequence", err.Error()) - return + if err = config.Sequences.Execute(device, sequencename); err != nil { + tui.LogError("Error while executing sequence '%s': %s", sequencename, err.Error()) + } else { + tui.DisplaySuccessMsg("Sequence completed successfully!") } - tui.DisplaySuccessMsg("Sequence completed successfully!") } func validExecArgsFunc(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { diff --git a/cmd/publish.go b/cmd/publish.go index 7704811..3dc296a 100644 --- a/cmd/publish.go +++ b/cmd/publish.go @@ -8,8 +8,10 @@ import ( "context" "corteca/internal/configuration" "corteca/internal/publish" + "corteca/internal/tui" "fmt" "net" + "net/http" "net/url" "os" "os/signal" @@ -26,149 +28,139 @@ const ( ) var publishCmd = &cobra.Command{ - Use: "publish TARGET [ARCH]", + Use: "publish TARGET", Short: "Publish application artifact(s) to specified target, optionally filtering by architecture.", Long: "Publish application artifact(s) to specified target, optionally filtering by architecture.", Example: "", - Args: cobra.RangeArgs(1, 2), + Args: cobra.ExactArgs(1), ValidArgsFunction: validPublishArgsFunc, Run: func(cmd *cobra.Command, args []string) { targetName := args[0] - arch := "" - - if len(args) > 1 { - arch = args[1] - } - - doPublishApp(targetName, arch, true) + doPublishApp(targetName, true) }, } +type RegistryConfig struct { + configuration.HttpServerEndpoint `yaml:",inline"` + Namespace configuration.TemplateField `yaml:"namespace"` + Reference configuration.TemplateField `yaml:"reference"` +} + func init() { publishCmd.PersistentFlags().BoolVar(&skipLocalConfig, "global", false, "Affect global config & ignore any project-local configuration") - publishCmd.PersistentFlags().StringVarP(&specifiedArtifact, "artifact", "a", "", "Specify an artifact in the form of '[ARCH]:imagetype:/path/to/file', architecture=(aarch64|armv7l|x86_64), imagetype=(rootfs|oci)") + publishCmd.PersistentFlags().StringVarP(&artifact, "artifact", "a", "", "Specify the path to a an artifact to publish") publishCmd.RegisterFlagCompletionFunc("artifact", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return []string{"tar.gz"}, cobra.ShellCompDirectiveFilterFileExt }) rootCmd.AddCommand(publishCmd) } -func doPublishApp(targetName string, arch string, wait bool) { +func doPublishApp(targetName string, wait bool) { requireBuildArtifact() - if specifiedArtifact != "" { - if arch != "" && configuration.GetCmdContext().Arch != arch { - fmt.Printf("Warning: differing architectures [%s,%s] were specified!\nPublishing %s...", arch, configuration.GetCmdContext().Arch, configuration.GetCmdContext().Arch) - } - arch = configuration.GetCmdContext().Arch - } - target, found := config.Publish[targetName] if !found { failOperation(fmt.Sprintf("publish target '%s' not found", targetName)) } switch target.Method { - case configuration.PUBLISH_METHOD_LISTEN: - handlePublishMethodListen(target, wait) - case configuration.PUBLISH_METHOD_PUT: - handlePublishMethodPut(target, arch) - case configuration.PUBLISH_METHOD_COPY: - failOperation("not implemented yet") - case configuration.PUBLISH_METHOD_PUSH: - handlePublishMethodPush(target, arch) - case configuration.PUBLISH_METHOD_REGISTRY: - handlePublishMethodRegistry(target, arch, wait) + case "listen": + serverConfig := configuration.HttpServerEndpoint{} + target.Decode(&serverConfig) + handleListen(serverConfig, wait) + case "put": + clientConfig := configuration.HttpClientEndpoint{} + target.Decode(&clientConfig) + handlePut(clientConfig, artifact) + case "push": + clientConfig := configuration.HttpClientEndpoint{} + target.Decode(&clientConfig) + handlePush(clientConfig, artifact) + case "registry-v2": + registryConfig := RegistryConfig{} + target.Decode(®istryConfig) + handleRegistry(registryConfig, wait) default: - failOperation(fmt.Sprintf("unknown publish method %v", target.Method)) + failOperation(fmt.Sprintf("unknown publish method '%v'", target.Method)) } } -func handlePublishMethodListen(target configuration.PublishTarget, wait bool) { - doListen(target, wait) -} - -func handlePublishMethodRegistry(target configuration.PublishTarget, arch string, wait bool) { - artifact, found := getArtifact(arch, ociSuffix) - if !found { - failOperation(fmt.Sprintf(artifactNotFoundMessage, arch, ociSuffix)) - } - - registryURL, err := url.Parse(target.Addr.String()) - assertOperation("parsing registry url", err) - - hostPort := net.JoinHostPort(registryURL.Hostname(), registryURL.Port()) - registryServer, err := publish.StartRegistry(hostPort, artifact) - if err != nil { - failOperation(fmt.Sprintf("failed to start local registry: %v", err)) - } - - if registryURL.Hostname() == "0.0.0.0" { - registryURL.Host = net.JoinHostPort("127.0.0.1", registryURL.Port()) - } - - err = publish.PushImage(artifact, registryURL, "", false) - assertOperation(fmt.Sprintf("pushing image %s to registry", artifact), err) +func handleListen(target configuration.HttpServerEndpoint, wait bool) { + u, err := url.Parse(target.Addr.String()) + assertOperation("parsing target url", err) + serverRoot := distFolder + srv, err := publish.ListenAsync(serverRoot, u) + assertOperation("starting server", err) if wait { waitForInterruptSignal() - if err := registryServer.Shutdown(context.Background()); err != nil { - fmt.Printf("failed to shutdown registry server: %v", err) - } + srv.Shutdown(context.Background()) } else { - fmt.Printf("Serving %v on %v\n", hostPort, registryURL.String()) + fmt.Printf("Serving %v on %v\n", serverRoot, u.String()) } } -func handlePublishMethodPut(target configuration.PublishTarget, arch string) { - artifact, found := getArtifact(arch, rootfsSuffix) - if !found { - failOperation(fmt.Sprintf(artifactNotFoundMessage, arch, rootfsSuffix)) - } - url, err := publish.AuthenticateHttp(target.Endpoint) +func handlePut(target configuration.HttpClientEndpoint, artifact string) { + // TODO: replace this with target.NewHttpClient() method + url, err := publish.AuthenticateHttp(target) assertOperation("performing http authentication", err) - doPut(artifact, url, target.Token.String()) -} - -func handlePublishMethodPush(target configuration.PublishTarget, arch string) { - artifact, found := getArtifact(arch, ociSuffix) - if !found { - failOperation(fmt.Sprintf(artifactNotFoundMessage, arch, ociSuffix)) + if err := publish.HttpPut(artifact, *url, target.Token.String()); err != nil { + assertOperation(fmt.Sprintf("while uploading file \"%s\" with HTTP(S) PUT", artifact), err) } - url, err := publish.AuthenticateHttp(target.Endpoint) - assertOperation("performing http authentication", err) - - doPush(artifact, url, target.Token.String()) } -func doPush(artifact string, url *url.URL, token string) { - err = publish.PushImage(artifact, url, token, true) +func handlePush(target configuration.HttpClientEndpoint, artifact string) { + err = publish.PushImage(artifact, &target, true) assertOperation(fmt.Sprintf("pushing image %s to registry", artifact), err) } -func getArtifact(arch, suffix string) (string, bool) { - artifactKey := fmt.Sprintf("%s-%s", arch, suffix) - artifactFilename, found := configuration.GetCmdContext().BuildArtifacts[artifactKey] - return artifactFilename, found +func connectableServerURL(server *http.Server) (*url.URL, error) { + u := url.URL{} + // determine schema + if server.TLSConfig != nil { + u.Scheme = "https" + } else { + u.Scheme = "http" + } + // determine host + host, port, err := net.SplitHostPort(server.Addr) + if err != nil { + return nil, fmt.Errorf("cannot determine host/port of server address '%s'", server.Addr) + } + switch host { + case "0.0.0.0", "localhost": + u.Host = net.JoinHostPort("127.0.0.1", port) + default: + u.Host = net.JoinHostPort(host, port) + } + return &u, nil } -func doListen(target configuration.PublishTarget, wait bool) { - u, err := url.Parse(target.Addr.String()) - assertOperation("parsing target url", err) +func handleRegistry(config RegistryConfig, wait bool) { + registryServer, err := publish.StartRegistry(config.HttpServerEndpoint) + if err != nil { + failOperation(fmt.Sprintf("failed to start local registry: %v", err)) + } - serverRoot := distFolder - srv, err := publish.ListenAsync(serverRoot, u) - assertOperation("starting server", err) - if wait { - waitForInterruptSignal() - srv.Shutdown(context.Background()) + if url, err := connectableServerURL(registryServer); err != nil { + failOperation(err.Error()) } else { - fmt.Printf("Serving %v on %v\n", serverRoot, u.String()) + url.Path = fmt.Sprintf("/%s:%s", config.Namespace.String(), config.Reference.String()) + tui.LogNormal("Publishing artifact on '%s'", url.String()) + ep := configuration.Endpoint{Addr: configuration.T(url.String())} + err = publish.PushImage(artifact, &configuration.HttpClientEndpoint{ + Endpoint: ep, + SkipTLSVerification: true, + }, false) + assertOperation(fmt.Sprintf("pushing image %s to registry", artifact), err) } -} -func doPut(artifact string, url *url.URL, token string) { - if err := publish.HttpPut(artifact, *url, token); err != nil { - assertOperation(fmt.Sprintf("while uploading file \"%s\" with HTTP(S) PUT", artifact), err) + if wait { + waitForInterruptSignal() + if err := registryServer.Shutdown(context.Background()); err != nil { + fmt.Printf("failed to shutdown registry server: %v", err) + } + } else { + fmt.Printf("Serving on %v...\n", registryServer.Addr) } } diff --git a/cmd/root.go b/cmd/root.go index 51e012f..c098723 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -13,7 +13,7 @@ import ( "fmt" "os" "path/filepath" - "regexp" + "slices" "strings" "github.com/google/uuid" @@ -33,19 +33,19 @@ var ( ) var ( - config configuration.Settings - configGlobal configuration.Settings - configSystem configuration.Settings - systemConfigRoot string - userConfigRoot string - projectRoot string - distFolder string - specifiedArtifact string - configOverrides []string - templates map[string]configuration.TemplateInfo - appVersion string - skipLocalConfig bool - noRegen bool + config configuration.Settings + configGlobal configuration.Settings + configSystem configuration.Settings + systemConfigRoot string + userConfigRoot string + projectRoot string + distFolder string + artifact string + configOverrides []string + templates map[string]configuration.TemplateInfo + appVersion string + skipLocalConfig bool + noRegen bool ) var rootCmd = &cobra.Command{ @@ -206,17 +206,6 @@ func requireProjectContext() { configuration.GetCmdContext().Build = &config.Build } -func splitSpecifiedArtifact(specifiedArtifact string) (arch, imgType, path string) { - artifactInfo := strings.SplitN(specifiedArtifact, ":", 3) - if len(artifactInfo) < 3 || artifactInfo[2] == "" { - failOperation("architecture, image type or path to artifact is missing") - } - if !(filepath.Ext(artifactInfo[2]) == ".gz" || filepath.Ext(artifactInfo[2]) == ".tar") { - failOperation("artifact file should be of type \".tar.gz\" or \".tar\"") - } - return strings.ToLower(artifactInfo[0]), strings.ToLower(artifactInfo[1]), artifactInfo[2] -} - func getAppNameFromArtifact(artifactPath string) string { artifactName := filepath.Base(artifactPath) splitedArtifactName := strings.SplitN(artifactName, "-", 2) @@ -229,81 +218,35 @@ func getAppNameFromArtifact(artifactPath string) string { } func requireBuildArtifact() { - configuration.GetCmdContext().BuildArtifacts = make(map[string]string) - - if specifiedArtifact != "" { - artifactArch, artifactType, artifactPath := splitSpecifiedArtifact(specifiedArtifact) - if _, err := os.Stat(artifactPath); errors.Is(err, os.ErrNotExist) { - failOperation(fmt.Sprintf("file %s not found", artifactPath)) - } - configuration.GetCmdContext().BuildArtifacts[artifactArch+"-"+artifactType] = artifactPath - distFolder = filepath.Dir(artifactPath) - configuration.GetCmdContext().Arch = artifactArch - // Set necessary build field for deployment - configuration.GetCmdContext().Build = &config.Build - configuration.GetCmdContext().Build.Options.OutputType = artifactType - // Set necessary app fields for deployment - configuration.GetCmdContext().App = &config.App - - if skipLocalConfig || len(configuration.GetCmdContext().App.DUID) == 0 { - configuration.GetCmdContext().App.DUID = generateDUID(artifactPath) + if artifact != "" { + if _, err := os.Stat(artifact); errors.Is(err, os.ErrNotExist) { + failOperation(fmt.Sprintf("file %s not found", artifact)) } - if skipLocalConfig || len(configuration.GetCmdContext().App.Name) == 0 { - configuration.GetCmdContext().App.Name = getAppNameFromArtifact(artifactPath) - } - return - } - requireProjectContext() - - distFolder = filepath.Join(projectRoot, distFolderName) - - rootfsPattern := filepath.Join(distFolder, fmt.Sprintf("%v*-rootfs.tar.gz", config.App.Name)) - ociPattern := filepath.Join(distFolder, fmt.Sprintf("%v*-oci.tar", config.App.Name)) - - rootfsFiles, _ := filepath.Glob(rootfsPattern) - ociFiles, _ := filepath.Glob(ociPattern) - - // Compile a common regular expression to extract the CPU architecture from the filename. - commonArchRegex := regexp.MustCompile(fmt.Sprintf(`^%s-%s-([^-]+)-(rootfs|oci)\.(tar\.gz|tar)$`, regexp.QuoteMeta(config.App.Name), regexp.QuoteMeta(config.App.Version))) - matchArchitectures(commonArchRegex, rootfsFiles, "rootfs") - matchArchitectures(commonArchRegex, ociFiles, "oci") - - if len(configuration.GetCmdContext().BuildArtifacts) == 0 { - failOperation("no build artifacts found") - } -} - -func matchArchitectures(archRegex *regexp.Regexp, distFiles []string, artifactType string) { - for _, distFile := range distFiles { - filename := filepath.Base(distFile) - matches := archRegex.FindStringSubmatch(filename) - - // If the filename contains a CPU architecture, process it. - if len(matches) < 2 { - continue + distFolder = filepath.Dir(artifact) + } else { + requireProjectContext() + distFolder = filepath.Join(projectRoot, distFolderName) + var buildArtifacts []string + patterns := []string{"*.tar.gz", "*.tar", "*.zip"} + for _, pattern := range patterns { + files, _ := filepath.Glob(filepath.Join(distFolder, pattern)) + buildArtifacts = append(buildArtifacts, files...) } - cpuArch := matches[1] - if curArtifactName, ok := configuration.GetCmdContext().BuildArtifacts[cpuArch+"-"+artifactType]; ok { - curArtifactInfo, err := os.Stat(curArtifactName) - if err != nil { - failOperation(fmt.Sprintf("stating artifact %s failed: %v", curArtifactName, err)) - } - distFileInfo, err := os.Stat(distFile) + if len(buildArtifacts) == 0 { + failOperation("no build artifacts found") + } else if len(buildArtifacts) > 1 { + var err error + slices.Sort(buildArtifacts) + artifact, err = tui.PromptForSelection("Select artifact to publish", buildArtifacts, buildArtifacts[0]) if err != nil { - failOperation(fmt.Sprintf("stating artifact %s failed: %v", distFile, err)) + failOperation("artifact selection cancelled") } - - // Update the selection if the new candidate is more recent and continue the loop - if distFileInfo.ModTime().After(curArtifactInfo.ModTime()) { - configuration.GetCmdContext().BuildArtifacts[cpuArch+"-"+artifactType] = distFile - } - - continue + } else { + artifact = buildArtifacts[0] } - configuration.GetCmdContext().BuildArtifacts[cpuArch+"-"+artifactType] = distFile - configuration.GetCmdContext().Arch = cpuArch } + configuration.GetCmdContext().Artifact = artifact } func generateDUID(input string) string { diff --git a/data/corteca.yaml b/data/corteca.yaml index 5ff9c0e..cb99c37 100644 --- a/data/corteca.yaml +++ b/data/corteca.yaml @@ -71,8 +71,6 @@ build: # a sequence (array) of commands that will be executed on the device sequences: # : - # type: # one of `ssh`, `cwmp` - # steps: # - cmd: # delay: # retries: @@ -85,7 +83,6 @@ sequences: # each entry must be in the form of: # : # method: # one of `listen`, `put`, `copy`, `push`, `registry-v2` -# publicURL: # public url of endpoint # addr: # url of endpoint # auth: # authentication type; one of `basic`, `bearer`, `digest` # username: # username for registry authentication @@ -97,7 +94,6 @@ publish: local: addr: http://0.0.0.0:8080 method: listen - publicURL: http://172.17.0.1:8080 # webserver: # addr: https://upload.example.com/artifacts/ @@ -105,9 +101,11 @@ publish: # method: put localRegistry: - addr: http://0.0.0.0:8080 + addr: 0.0.0.0:8080 method: registry-v2 - publicURL: http://127.0.0.1:8080 # Should be changed + namespace: aarch64/${.app.name} + reference: ${.app.version} + # remoteRegistry: # addr: https://corteca-registry.int.net.nokia.com diff --git a/go.mod b/go.mod index d38097a..1151c49 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,18 @@ module corteca go 1.21.0 require ( + github.com/beevik/etree v1.5.1 + github.com/google/go-containerregistry v0.19.2 + github.com/google/uuid v1.6.0 + github.com/icholy/digest v1.1.0 github.com/mitchellh/copystructure v1.2.0 github.com/pterm/pterm v0.12.78 github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.7.0 - golang.org/x/term v0.16.0 -) - -require ( - github.com/beevik/etree v1.5.1 - github.com/google/go-containerregistry v0.19.2 - github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.8.4 github.com/xinsnake/go-http-digest-auth-client v0.6.0 golang.org/x/crypto v0.16.0 + golang.org/x/term v0.16.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 689746e..ea6c557 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQ github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= +github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= +github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= @@ -188,5 +190,5 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= -gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/internal/configuration/configuration.go b/internal/configuration/configuration.go index d07b35d..52c9c12 100755 --- a/internal/configuration/configuration.go +++ b/internal/configuration/configuration.go @@ -13,6 +13,7 @@ import ( specs "corteca/internal/configuration/runtimeSpec" "corteca/internal/configuration/templating" "corteca/internal/fsutil" + "corteca/internal/tui" "errors" "fmt" "io" @@ -36,80 +37,12 @@ const ( ConfigFileName = "corteca.yaml" ) -const ( - PUBLISH_METHOD_UNDEFINED = iota - PUBLISH_METHOD_LISTEN - PUBLISH_METHOD_PUT - PUBLISH_METHOD_COPY - PUBLISH_METHOD_PUSH - PUBLISH_METHOD_REGISTRY -) - -const ( - publishMethodListenName = "listen" - publishMethodPutName = "put" - publishMethodCopyName = "copy" - publishMethodPushName = "push" - publishMethodRegistryName = "registry-v2" -) - -type PublishMethod int - -func (m PublishMethod) MarshalYAML() (interface{}, error) { - var out []byte - var err error - - switch m { - case PUBLISH_METHOD_LISTEN: - out, err = yaml.Marshal(publishMethodListenName) - case PUBLISH_METHOD_PUT: - out, err = yaml.Marshal(publishMethodPutName) - case PUBLISH_METHOD_COPY: - out, err = yaml.Marshal(publishMethodCopyName) - case PUBLISH_METHOD_PUSH: - out, err = yaml.Marshal(publishMethodPushName) - case PUBLISH_METHOD_REGISTRY: - out, err = yaml.Marshal(publishMethodRegistryName) - default: - out = nil - err = fmt.Errorf("invalid publish method (%v)", m) - } - - return strings.TrimSpace(string(out)), err - -} - -func (m *PublishMethod) UnmarshalYAML(data *yaml.Node) error { - var name string - if err := yaml.Unmarshal([]byte(data.Value), &name); err != nil { - return err - } - name = strings.ToLower(name) - switch name { - case publishMethodListenName: - *m = PUBLISH_METHOD_LISTEN - case publishMethodPutName: - *m = PUBLISH_METHOD_PUT - case publishMethodCopyName: - *m = PUBLISH_METHOD_COPY - case publishMethodPushName: - *m = PUBLISH_METHOD_PUSH - case publishMethodRegistryName: - *m = PUBLISH_METHOD_REGISTRY - default: - return fmt.Errorf("unrecognized publish method '%v'", name) - } - return nil -} - var regexKeyValue *regexp.Regexp -var exprRegex *regexp.Regexp -var cmdRegularExpression *regexp.Regexp +var regexDollarExpr *regexp.Regexp func init() { - cmdRegularExpression = regexp.MustCompile(`^\s*\$\((.+)\)\s*$`) regexKeyValue = regexp.MustCompile(`^([[:word:]]+)=(.*)$`) - exprRegex = regexp.MustCompile(`\${\s*(?:\"([^"]*)\":)?(\.?(?:\w*)(?:\.\w*)*)(?:\:(\S))?(?:\:(\S))?\s*}`) + regexDollarExpr = regexp.MustCompile(`\${\s*(?:\"([^"]*)\":)?(\.?(?:\w*)(?:\.\w*)*)(?:\:(\S))?(?:\:(\S))?\s*}`) populateEnvVars() } @@ -166,37 +99,70 @@ type CrossCompileConfig struct { } type PublishTarget struct { - Endpoint `yaml:",omitempty,inline"` - Method PublishMethod `yaml:"method,omitempty"` - PublicURL string `yaml:"publicURL,omitempty"` + Method string `yaml:"method"` + raw *yaml.Node +} + +func (d PublishTarget) MarshalYAML() (interface{}, error) { + return d.raw, nil +} + +func (d *PublishTarget) UnmarshalYAML(value *yaml.Node) error { + d.raw = value + var proxy struct { + Method string `yaml:"method"` + } + err := value.Decode(&proxy) + d.Method = proxy.Method + return err +} + +func (d PublishTarget) Decode(v interface{}) error { + return d.raw.Decode(v) } type DeviceConfig struct { - Endpoint `yaml:",omitempty,inline"` + Endpoint `yaml:",omitempty,inline"` + Architecture string `yaml:"architecure,omitempty"` + raw *yaml.Node +} + +func (d DeviceConfig) MarshalYAML() (interface{}, error) { + return d.raw, nil +} + +func (d *DeviceConfig) UnmarshalYAML(value *yaml.Node) error { + d.raw = value + var proxy struct { + Endpoint `yaml:",omitempty,inline"` + Architecture string `yaml:"architecure,omitempty"` + } + err := value.Decode(&proxy) + d.Endpoint = proxy.Endpoint + d.Architecture = proxy.Architecture + return err +} + +func (d *DeviceConfig) Decode(v interface{}) error { + return d.raw.Decode(v) } type Endpoint struct { - Addr TemplateField `yaml:"addr,omitempty"` - Auth string `yaml:"auth,omitempty"` - Username TemplateField `yaml:"username,omitempty"` - Password TemplateField `yaml:"password,omitempty"` - Password2 TemplateField `yaml:"password2,omitempty"` - PrivateKeyFile TemplateField `yaml:"privateKeyFile,omitempty"` - Token TemplateField `yaml:"token,omitempty"` - CwmpServerAddr string `yaml:"cwmpServerAddr,omitempty"` - Architecture string `yaml:"architecure,omitempty"` + Addr TemplateField `yaml:"addr,omitempty"` } type TemplateField struct { RawTemplate string `yaml:"rawTemplate"` } -// encode TemplateField to YAML data func (t TemplateField) MarshalYAML() (interface{}, error) { return t.RawTemplate, nil } -// decode YAML data into TemplateField +func T(raw string) TemplateField { + return TemplateField{RawTemplate: raw} +} + func (t *TemplateField) UnmarshalYAML(data *yaml.Node) error { if data.Kind != yaml.ScalarNode { return errors.New("wrong value type") @@ -205,6 +171,15 @@ func (t *TemplateField) UnmarshalYAML(data *yaml.Node) error { return nil } +func (t TemplateField) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (t *TemplateField) UnmarshalText(text []byte) error { + t.RawTemplate = string(text) + return nil +} + func (t TemplateField) String() string { return generateExpressions(t.RawTemplate, nil, GetCmdContext()) } @@ -235,10 +210,9 @@ func (t *DictType[T]) UnmarshalYAML(data *yaml.Node) error { var commandContext CmdContext type CmdContext struct { - App *AppSettings `yaml:"app,omitempty"` - Arch string `yaml:"arch,omitempty"` - BuildArtifacts map[string]string `yaml:"buildArtifacts,omitempty"` - Device struct { + App *AppSettings `yaml:"app,omitempty"` + Arch string `yaml:"arch,omitempty"` + Device struct { DeviceConfig `yaml:",omitempty,inline"` Name string `yaml:"name,omitempty"` } `yaml:"device,omitempty"` @@ -246,14 +220,17 @@ type CmdContext struct { PublishTarget `yaml:",omitempty,inline"` Name string `yaml:"name,omitempty"` } `yaml:"publish,omitempty"` - Platform string `yaml:"platform,omitempty"` - Build *BuildSettings `yaml:"build,omitempty"` - BuildArtifact string `yaml:"buildArtifact,omitempty"` - Env map[string]string `yaml:"env,omitempty"` + Platform string `yaml:"platform,omitempty"` + Build *BuildSettings `yaml:"build,omitempty"` + Artifact string `yaml:"artifact,omitempty"` + Env map[string]string `yaml:"env,omitempty"` } func ResetContext() { - commandContext = CmdContext{} + commandContext = CmdContext{ + App: &AppSettings{}, + Build: &BuildSettings{}, + } } func GetCmdContext() *CmdContext { @@ -278,7 +255,7 @@ func evaluateExpressionFunc(visited []string, context any) func(string) string { } return func(expr string) string { - match := exprRegex.FindStringSubmatch(expr) + match := regexDollarExpr.FindStringSubmatch(expr) prefix := match[1] key := match[2] sep1 := match[3] @@ -292,14 +269,15 @@ func evaluateExpressionFunc(visited []string, context any) func(string) string { for i := range visited { if key == visited[i] { - panic(fmt.Sprintf("Circular dependency detected for key: %s", key)) + tui.LogWarning("Circular dependency detected for key: %s", key) + return "" } } value, err := ReadField(context, key) visited = append(visited, key) if err != nil { - fmt.Printf("Warning: could not read field '%s' with error: %v\n", key, err.Error()) + tui.LogWarning("Warning: could not read field '%s' with error: %s", key, err.Error()) return "" } @@ -339,7 +317,7 @@ func evaluateExpressionFunc(visited []string, context any) func(string) string { } func generateExpressions(input string, visited []string, context any) string { - return exprRegex.ReplaceAllStringFunc(input, evaluateExpressionFunc(visited, context)) + return regexDollarExpr.ReplaceAllStringFunc(input, evaluateExpressionFunc(visited, context)) } func NewConfiguration() Settings { @@ -476,6 +454,10 @@ func ReadField(conf any, fieldPath string) (any, error) { // edge case: first elem is empty continue } + // if field is nil + if field.IsNil() { + return nil, fmt.Errorf("cannot address nil value with key '%s'", key) + } // if field is a pointer, dereference if field.Kind() == reflect.Ptr { field = field.Elem() diff --git a/internal/configuration/configuration_test.go b/internal/configuration/configuration_test.go index 0992835..03f21f9 100644 --- a/internal/configuration/configuration_test.go +++ b/internal/configuration/configuration_test.go @@ -1,8 +1,10 @@ +//go:build exclude + // Copyright 2024 Nokia // Licensed under the BSD 3-Clause License. // SPDX-License-Identifier: BSD-3-Clause -package configuration +package configuration_test import ( "bytes" diff --git a/internal/configuration/cwmpcmd.go b/internal/configuration/cwmpcmd.go deleted file mode 100644 index 568da76..0000000 --- a/internal/configuration/cwmpcmd.go +++ /dev/null @@ -1,134 +0,0 @@ -package configuration - -import ( - "corteca/internal/cwmp/messages" - "corteca/internal/dispatcher" - "corteca/internal/tui" - - "fmt" - "math/rand" - "strconv" - "time" -) - -type CwmpCmd struct { - Cmd TemplateField `yaml:"cmd,omitempty" json:"cmd,omitempty"` - Operation string `yaml:"operation,omitempty" json:"operation,omitempty"` - ParameterList []ParameterListValues `yaml:"parameterList,omitempty"` - ParameterNames []string `yaml:"parameterNames,omitempty"` - ParameterKey string `yaml:"parameterKey,omitempty"` - Url TemplateField `yaml:"url,omitempty" json:"url,omitempty"` - Username TemplateField `yaml:"username,omitempty" json:"username,omitempty"` - Password TemplateField `yaml:"password,omitempty" json:"password,omitempty"` - UUID TemplateField `yaml:"uuid,omitempty" json:"uuid,omitempty"` - Version TemplateField `yaml:"version,omitempty" json:"version,omitempty"` - ExecutionEnvRef TemplateField `yaml:"executionenvref,omitempty" json:"executionenvref,omitempty"` - Delay uint `yaml:"delay,omitempty" json:"delay,omitempty"` - Retries uint `yaml:"retries,omitempty" json:"retries,omitempty"` - IgnoreFailure bool `yaml:"ignoreFailure,omitempty" json:"ignoreFailure,omitempty"` - ParameterPath string `yaml:"parameterPath,omitempty" json:"path,omitempty"` - NextLevel bool `yaml:"nextLevel,omitempty" json:"nextLevel,omitempty"` - PrintFormat string `yaml:"printFormat,omitempty" json:"printFormat,omitempty"` -} - -type ParameterListValues struct { - Name string `yaml:"name"` - Value string `yaml:"value"` - Type string `yaml:"type"` -} - -func (sqCmd *CwmpCmd) Execute(dispatcher dispatcher.Dispatcher) (cmdResults string, err error) { - attempts := sqCmd.Retries + 1 - - for attempts > 0 { - if sqCmd.Cmd.String() != "" { - cmdResults, err = sqCmd.executeCommand(dispatcher) - } - - attempts-- - - if !sqCmd.IgnoreFailure && err != nil { - if attempts == 0 { - return "", err - } else { - tui.LogError("Command failed (%s); will retry %d more time(s).", err.Error(), attempts) - } - } - - if sqCmd.Delay > 0 { - tui.LogNormal("=> Waiting for %d millisecond(s)...", sqCmd.Delay) - time.Sleep(time.Duration(sqCmd.Delay) * time.Millisecond) - } - if err == nil { - break - } - - } - - return cmdResults, nil -} - -func (sqCmd *CwmpCmd) executeCommand(dispatcher dispatcher.Dispatcher) (string, error) { - var msg messages.Message - dispatcher.SetPrintFormat(sqCmd.PrintFormat) - switch sqCmd.Cmd.String() { - case "change_du_state": - tui.LogNormal("=> Cmd:\n ChangeDUState\n=> Operation:\n %s", sqCmd.Operation) - dustate := messages.NewChangeDUState() - var operation messages.DeploymentUnitOperationStruct - dustate.CommandKey = strconv.FormatInt(rand.Int63n(9000000000)+1000000000, 10) - dustate.OperationType = sqCmd.Operation - operation.UUID = sqCmd.UUID.String() - operation.URL = sqCmd.Url.String() - operation.ExecutionEnvRef = sqCmd.ExecutionEnvRef.String() - operation.Password = sqCmd.Password.String() - operation.Username = sqCmd.Username.String() - operation.Version = sqCmd.Version.String() - dustate.Operation = operation - msg = dustate - case "get_parameter_names": - if sqCmd.PrintFormat != "json" { - tui.LogNormal("=> Cmd:\n GetParameterNames\n=> Parameter:\n %s", sqCmd.ParameterPath) - } - getParamNames := messages.NewGetParameterNames() - getParamNames.ParameterPath = sqCmd.ParameterPath - getParamNames.NextLevel = sqCmd.NextLevel - msg = getParamNames - case "get_parameter_values": - if len(sqCmd.ParameterNames) == 0 { - return "", fmt.Errorf("empty parameter names list %v", sqCmd.ParameterNames) - } - if sqCmd.PrintFormat != "json" { - tui.LogNormal("=> Cmd:\n GetParameterValues\n=> Parameter(s):") - for _, parameterName := range sqCmd.ParameterNames { - tui.LogNormal(" %s", parameterName) - } - } - - getParamValues := messages.NewGetParameterValues() - getParamValues.ParameterNames = sqCmd.ParameterNames - getParamValues.PrintFormat = sqCmd.PrintFormat - msg = getParamValues - case "set_parameter_values": - tui.LogNormal("=> Cmd:\n SetParameterValues\n=> Parameter(s) & New Value(s):") - setParamValues := messages.NewSetParameterValues() - for _, parameter := range sqCmd.ParameterList { - tui.LogNormal(" %s: %v", parameter.Name, parameter.Value) - paramval := messages.ParameterVal{} - paramval.Name = parameter.Name - paramval.Value.Value = parameter.Value - paramval.Value.Type = parameter.Type - setParamValues.ParameterList = append(setParamValues.ParameterList, paramval) - } - setParamValues.ParameterKey = sqCmd.ParameterKey - msg = setParamValues - default: - tui.LogNormal("=> Cmd: '%s'", sqCmd.Cmd.String()) - - err := fmt.Errorf("rpc %s not implemented", sqCmd.Cmd.String()) - return "", err - } - - cmdResults, err := dispatcher.ExecuteCommand(msg) - return cmdResults, err -} diff --git a/internal/configuration/endpoints.go b/internal/configuration/endpoints.go new file mode 100644 index 0000000..8a80caf --- /dev/null +++ b/internal/configuration/endpoints.go @@ -0,0 +1,109 @@ +package configuration + +import ( + "fmt" + "net/http" + "strings" + + "github.com/icholy/digest" +) + +const ( + BasicClientAuth = "basic" + BearerClientAuth = "bearer" + DigestClientAuth = "digest" +) + +type HttpServerEndpoint struct { + Endpoint `yaml:",inline"` + Certificate TemplateField `yaml:"certificate"` + Key TemplateField `yaml:"key"` +} + +type HttpClientEndpoint struct { + Endpoint `yaml:",inline"` + Auth string `yaml:"auth,omitempty"` + Username TemplateField `yaml:"username,omitempty"` + Password TemplateField `yaml:"password,omitempty"` + Token TemplateField `yaml:"token,omitempty"` + SkipTLSVerification bool `yaml:"skipTLSVerification"` +} + +// transport to use basic authentication +type BasicAuthTransport struct { + Username string + Password string + Transport http.RoundTripper +} + +func (t *BasicAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := req.Clone(req.Context()) + req2.SetBasicAuth(t.Username, t.Password) + return t.transport().RoundTrip(req2) +} + +func (t *BasicAuthTransport) transport() http.RoundTripper { + if t.Transport != nil { + return t.Transport + } + return http.DefaultTransport +} + +// transport to use bearer authentication +type BearerAuthTransport struct { + Token string + Transport http.RoundTripper +} + +func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := req.Clone(req.Context()) + req2.Header.Set("Authorization", "Bearer "+t.Token) + return t.transport().RoundTrip(req2) +} + +func (t *BearerAuthTransport) transport() http.RoundTripper { + if t.Transport != nil { + return t.Transport + } + return http.DefaultTransport + +} +func (ep *HttpClientEndpoint) NewHttpClient() (*http.Client, error) { + token := ep.Token.String() + username := ep.Username.String() + password := ep.Password.String() + var bearer, basic http.RoundTripper + if len(token) > 0 { + bearer = &BearerAuthTransport{Token: token} + } + if len(username) > 0 || len(password) > 0 { + basic = &BasicAuthTransport{Username: username, Password: password} + } + client := &http.Client{} + switch strings.ToLower(ep.Auth) { + case BasicClientAuth: + client.Transport = basic + case BearerClientAuth: + client.Transport = bearer + case DigestClientAuth: + client.Transport = &digest.Transport{Username: username, Password: password} + case "": + // if no explicit auth method specified, prioritize bearer + if bearer != nil { + client.Transport = bearer + } else if basic != nil { + client.Transport = basic + } + default: + return nil, fmt.Errorf("unknown HTTP authentication '%s'", ep.Auth) + } + return client, nil +} + +type SSHClientEndpoint struct { + Endpoint `yaml:",inline"` + Username TemplateField `yaml:"username,omitempty"` + Password TemplateField `yaml:"password,omitempty"` + Password2 TemplateField `yaml:"password2,omitempty"` + PrivateKeyFile TemplateField `yaml:"privateKeyFile,omitempty"` +} diff --git a/internal/configuration/sequence.go b/internal/configuration/sequence.go index 8b7b4e7..ea2d17a 100644 --- a/internal/configuration/sequence.go +++ b/internal/configuration/sequence.go @@ -1,148 +1,180 @@ package configuration import ( - "corteca/internal/dispatcher" + "context" "corteca/internal/tui" + "errors" "fmt" - "strings" + "os" + "regexp" + "time" "gopkg.in/yaml.v3" ) -type SequenceMap map[string]Sequence +var cmdRegularExpression *regexp.Regexp -func (sm *SequenceMap) UnmarshalYAML(data *yaml.Node) error { - sequences := make(map[string]struct { - Type string `yaml:"type"` - Steps []yaml.Node `yaml:"steps"` - }) - if err := data.Decode(sequences); err != nil { - fmt.Println("Decode error") - return err - } +func init() { + cmdRegularExpression = regexp.MustCompile(`^\s*\$\((.+)\)\s*$`) +} - for seqName, rawSeq := range sequences { - switch rawSeq.Type { - case "ssh": - steps := make([]StringCmd, len(rawSeq.Steps)) +const ( + DefaultMaxTimeout = 5 * time.Minute +) - for i, node := range rawSeq.Steps { - var step StringCmd - if err := node.Decode(&step); err != nil { - return fmt.Errorf("error decoding step %d in sequence %q: %w", i, seqName, err) - } - steps[i] = step - } +var ( + ErrAbortSequence = errors.New("fatal error") +) - (*sm)[seqName] = NewStringSequence(rawSeq.Type, steps) - case "cwmp": - steps := make([]CwmpCmd, len(rawSeq.Steps)) +type SequenceMap map[string]Sequence - for i, node := range rawSeq.Steps { - var step CwmpCmd - if err := node.Decode(&step); err != nil { - return fmt.Errorf("error decoding step %d in sequence %q: %w", i, seqName, err) - } - steps[i] = step - } +type Sequence []SequenceCmd - (*sm)[seqName] = NewCwmpSequence(rawSeq.Type, steps) - } - } - return nil +type SequenceCmd struct { + Cmd TemplateField `yaml:"cmd"` + Delay time.Duration `yaml:"duration,omitempty"` + Timeout time.Duration `yaml:"timeout,omitempty"` + Retries uint `yaml:"retries,omitempty"` + IgnoreFailure *bool `yaml:"ignoreFailure,omitempty"` + raw *yaml.Node } -func (sm *SequenceMap) Execute(dispatcher dispatcher.Dispatcher, seq string) error { - selectedSequence, found := (*sm)[seq] - if !found { - return fmt.Errorf("sequence '%s' was not found", seq) +func parseDuration(value string, defaultvalue time.Duration) (time.Duration, error) { + if len(value) > 0 { + return time.ParseDuration(value) + } else { + return defaultvalue, nil } - tui.LogNormal("Executing sequence %s", seq) - return selectedSequence.Execute(dispatcher, *sm) } -type Sequence interface { - GetType() string - Execute(dispatcher.Dispatcher, SequenceMap) error +func (cmd *SequenceCmd) UnmarshalYAML(value *yaml.Node) error { + cmd.raw = value + var proxy struct { + Cmd TemplateField `yaml:"cmd"` + Delay string `yaml:"duration"` + Timeout string `yaml:"timeout"` + Retries uint `yaml:"retries"` + IgnoreFailure *bool `yaml:"ignoreFailure"` + } + if err := value.Decode(&proxy); err != nil { + return err + } + cmd.Cmd = proxy.Cmd + if d, err := parseDuration(proxy.Delay, 0); err != nil { + return err + } else { + cmd.Delay = d + } + if d, err := parseDuration(proxy.Timeout, 0); err != nil { + return err + } else { + cmd.Timeout = d + } + cmd.Retries = proxy.Retries + cmd.IgnoreFailure = proxy.IgnoreFailure + return nil } -func NewStringSequence(tp string, steps []StringCmd) *StringSequence { - return &StringSequence{tp, steps} +func (cmd SequenceCmd) MarshalYAML() (any, error) { + return cmd.raw, nil } -type StringSequence struct { - Type string - Steps []StringCmd +func (cmd *SequenceCmd) Decode(v any) error { + return cmd.raw.Decode(v) } -func (sq *StringSequence) GetType() string { - return sq.Type +type CommandExecutor interface { + BeginSequence() error + ExecuteCommand(context.Context, *SequenceCmd) (any, error) + EndSequence() error } -func (sq *StringSequence) Execute(dispathcer dispatcher.Dispatcher, sm SequenceMap) error { - for idx, step := range sq.Steps { - if seqName, found := findRefToSequence(step.Cmd.String()); found { - err := sm.Execute(dispathcer, seqName) - if err != nil { - return fmt.Errorf("reference sequence failed at step %d: %w", idx+1, err) - } - } else { +func (sm *SequenceMap) Execute(executor CommandExecutor, seqName string) error { + if _, ok := (*sm)[seqName]; !ok { + return fmt.Errorf("sequence '%s' was not found", seqName) + } + if err := executor.BeginSequence(); err != nil { + return fmt.Errorf("failed to initialize sequence: %w", err) + } + if err := sm.executeSequenceSteps(executor, seqName); err != nil { + return err + } + if err := executor.EndSequence(); err != nil { + return fmt.Errorf("failed to shutdown sequence: %w", err) + } + return nil +} - if out, err := step.Execute(dispathcer); err != nil { - return fmt.Errorf("sequence failed at step %d: %w", idx+1, err) - } else { - tui.LogOutData(out) +func (sm *SequenceMap) executeSequenceSteps(executor CommandExecutor, seqName string) error { + seq, found := (*sm)[seqName] + if !found { + return fmt.Errorf("sequence '%s' was not found", seqName) + } + tui.LogNormal("Executing sequence '%s'", seqName) + for idx, step := range seq { + if refSeqName, found := findRefToSequence(step.Cmd.String()); found { + if err := sm.executeSequenceSteps(executor, refSeqName); err != nil { + return err } + continue + } + res, err := executeStep(&step, executor) + if err != nil { + return fmt.Errorf("sequence '%s' failed at step %d: %w", seqName, idx+1, err) } + // TODO: provide option to suppress output + tui.SetOutputColor(tui.CBlue, os.Stdout) + enc := yaml.NewEncoder(os.Stdout) + enc.Encode(res) + tui.ResetOutputColor(os.Stdout) } return nil } -func findRefToSequence(seqCmd string) (string, bool) { - if cmdRefRegex := cmdRegularExpression.FindStringSubmatch(seqCmd); len(cmdRefRegex) == 2 { +func findRefToSequence(expr string) (string, bool) { + if cmdRefRegex := cmdRegularExpression.FindStringSubmatch(expr); len(cmdRefRegex) == 2 { return cmdRefRegex[1], true } else { return "", false } } -type SequenceCmd interface { - Execute(dispatcher.Dispatcher) error -} - -type CwmpSequence struct { - Type string - Steps []CwmpCmd -} - -func NewCwmpSequence(tp string, steps []CwmpCmd) *CwmpSequence { - return &CwmpSequence{tp, steps} -} - -func (sq *CwmpSequence) GetType() string { - return sq.Type +func createContext(timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout == 0 { + timeout = DefaultMaxTimeout + } + return context.WithTimeout(context.Background(), timeout) } -func (sq *CwmpSequence) Execute(dispathcer dispatcher.Dispatcher, sm SequenceMap) error { - for idx, step := range sq.Steps { - // Cmd, Operation and PrintFormat should be lowercase in order for the code to work - step.Cmd.RawTemplate = strings.ToLower(step.Cmd.String()) - step.Operation = strings.ToLower(step.Operation) - step.PrintFormat = strings.ToLower(step.PrintFormat) - if seqName, found := findRefToSequence(step.Cmd.String()); found { - err := sm.Execute(dispathcer, seqName) - if err != nil { - return fmt.Errorf("reference sequence failed at step %d: %w", idx+1, err) - } - } else { - if cmdResults, err := step.Execute(dispathcer); err != nil { - return fmt.Errorf("sequence failed at step %d: %w", idx+1, err) +func executeStep(step *SequenceCmd, executor CommandExecutor) (any, error) { + attempts := step.Retries + 1 + var ( + res any + err error + ) + for attempts > 0 { + if step.Delay > 0 { + tui.LogNormal("Waiting for %s", step.Delay.String()) + time.Sleep(step.Delay) + } + ctx, cancel := createContext(step.Timeout) + defer cancel() + res, err = executor.ExecuteCommand(ctx, step) + attempts-- + if err != nil { + if step.IgnoreFailure != nil && (*step.IgnoreFailure) { + return res, nil } else { - if cmdResults != "" { - tui.LogOutData(cmdResults) + tui.LogError("Command failed: %s", err.Error()) + if attempts > 0 { + tui.LogNormal("Will retry %d more time(s)", attempts) + } else { + return res, err } } + } else { + break } } - return nil + return res, nil } diff --git a/internal/configuration/sequence_test.go b/internal/configuration/sequence_test.go new file mode 100644 index 0000000..dec5bb8 --- /dev/null +++ b/internal/configuration/sequence_test.go @@ -0,0 +1,420 @@ +// Copyright 2024 Nokia +// Licensed under the BSD 3-Clause License. +// SPDX-License-Identifier: BSD-3-Clause + +package configuration_test + +import ( + "context" + "corteca/internal/configuration" + "errors" + "testing" + "time" +) + +// boolPtr returns a pointer to b, a convenience helper for SequenceCmd.IgnoreFailure. +func boolPtr(b bool) *bool { return &b } + +// mockExecutor is a configurable test double for configuration.CommandExecutor. +type mockExecutor struct { + beginErr error + endErr error + + beginCalled int + endCalled int + + // executeFunc is invoked on every ExecuteCommand call. + // callIdx is 0-based. When nil, ExecuteCommand always returns (nil, nil). + executeFunc func(callIdx int, ctx context.Context, cmd *configuration.SequenceCmd) (any, error) + + callCount int + capturedContexts []context.Context +} + +func (m *mockExecutor) BeginSequence() error { + m.beginCalled++ + return m.beginErr +} + +func (m *mockExecutor) ExecuteCommand(ctx context.Context, cmd *configuration.SequenceCmd) (any, error) { + m.capturedContexts = append(m.capturedContexts, ctx) + idx := m.callCount + m.callCount++ + if m.executeFunc != nil { + return m.executeFunc(idx, ctx, cmd) + } + return nil, nil +} + +func (m *mockExecutor) EndSequence() error { + m.endCalled++ + return m.endErr +} + +// alwaysFails returns an executeFunc that always returns the given error. +func alwaysFails(err error) func(int, context.Context, *configuration.SequenceCmd) (any, error) { + return func(_ int, _ context.Context, _ *configuration.SequenceCmd) (any, error) { + return nil, err + } +} + +// succeedsAfter returns an executeFunc that fails for the first n calls, then succeeds. +func succeedsAfter(n int, err error) func(int, context.Context, *configuration.SequenceCmd) (any, error) { + return func(callIdx int, _ context.Context, _ *configuration.SequenceCmd) (any, error) { + if callIdx < n { + return nil, err + } + return "ok", nil + } +} + +// simpleStep is a convenience constructor for a SequenceCmd with only the fields that matter +// for a given test set. +func simpleStep(cmd string, ignoreFailure bool) configuration.SequenceCmd { + return configuration.SequenceCmd{ + Cmd: configuration.T(cmd), + IgnoreFailure: boolPtr(ignoreFailure), + } +} + +// ---- SequenceMap.Execute: lifecycle tests ---------------------------------------- + +// TestExecute_UnknownSequence verifies that calling Execute with a name that does not exist +// in the map returns an error. +func TestExecute_UnknownSequence(t *testing.T) { + sm := configuration.SequenceMap{"existing": {}} + err := sm.Execute(&mockExecutor{}, "missing") + if err == nil { + t.Fatal("expected an error for unknown sequence name, got nil") + } +} + +// TestExecute_BeginAndEndSequence verifies that BeginSequence and EndSequence are each called +// exactly once when skipinit=false and the sequence completes successfully. +func TestExecute_BeginAndEndSequence(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": {simpleStep("cmd", false)}, + } + exec := &mockExecutor{} + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exec.beginCalled != 1 { + t.Errorf("BeginSequence: expected 1 call, got %d", exec.beginCalled) + } + if exec.endCalled != 1 { + t.Errorf("EndSequence: expected 1 call, got %d", exec.endCalled) + } +} + +// TestExecute_EachStepCallsExecuteCommand verifies that ExecuteCommand is called once per step +// in the sequence. +func TestExecute_EachStepCallsExecuteCommand(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": { + simpleStep("cmd1", false), + simpleStep("cmd2", false), + simpleStep("cmd3", false), + }, + } + exec := &mockExecutor{} + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exec.callCount != 3 { + t.Errorf("ExecuteCommand: expected 3 calls (one per step), got %d", exec.callCount) + } +} + +// TestExecute_BeginSequenceError_AbortsBefore verifies that when BeginSequence returns an +// error, no steps are executed and EndSequence is not called. +func TestExecute_BeginSequenceError_AbortsBefore(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": {simpleStep("cmd", false)}, + } + exec := &mockExecutor{beginErr: errors.New("begin failed")} + + if err := sm.Execute(exec, "seq"); err == nil { + t.Fatal("expected error when BeginSequence fails, got nil") + } + if exec.callCount != 0 { + t.Errorf("ExecuteCommand: expected 0 calls after BeginSequence failure, got %d", exec.callCount) + } + if exec.endCalled != 0 { + t.Errorf("EndSequence: expected 0 calls after BeginSequence failure, got %d", exec.endCalled) + } +} + +// TestExecute_StepFailure_EndSequenceSkipped verifies that when a step fails (and +// IgnoreFailure=false), Execute returns an error and EndSequence is NOT called. +func TestExecute_StepFailure_EndSequenceSkipped(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": {simpleStep("cmd", false)}, + } + exec := &mockExecutor{executeFunc: alwaysFails(errors.New("step failed"))} + + if err := sm.Execute(exec, "seq"); err == nil { + t.Fatal("expected error for failed step, got nil") + } + if exec.endCalled != 0 { + t.Errorf("EndSequence: expected 0 calls after step failure, got %d", exec.endCalled) + } +} + +// TestExecute_EndSequenceError_Propagates verifies that an error returned by EndSequence +// is propagated to the caller. +func TestExecute_EndSequenceError_Propagates(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": {simpleStep("cmd", false)}, + } + exec := &mockExecutor{endErr: errors.New("end failed")} + + if err := sm.Execute(exec, "seq"); err == nil { + t.Fatal("expected error when EndSequence fails, got nil") + } +} + +// ---- Retries tests --------------------------------------------------------------- + +// TestExecute_Retries_ExhaustsAllAttempts verifies that when a step always fails, +// ExecuteCommand is called exactly Retries+1 times before the sequence is aborted. +func TestExecute_Retries_ExhaustsAllAttempts(t *testing.T) { + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + Retries: 2, + IgnoreFailure: boolPtr(false), + } + sm := configuration.SequenceMap{"seq": {cmd}} + exec := &mockExecutor{executeFunc: alwaysFails(errors.New("fail"))} + + if err := sm.Execute(exec, "seq"); err == nil { + t.Fatal("expected error after retries exhausted, got nil") + } + if exec.callCount != 3 { + t.Errorf("ExecuteCommand: expected 3 calls (1 initial + 2 retries), got %d", exec.callCount) + } +} + +// TestExecute_Retries_SucceedsOnRetry verifies that the sequence completes successfully when +// a transient failure is resolved on a subsequent retry. +func TestExecute_Retries_SucceedsOnRetry(t *testing.T) { + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + Retries: 2, + IgnoreFailure: boolPtr(false), + } + sm := configuration.SequenceMap{"seq": {cmd}} + exec := &mockExecutor{executeFunc: succeedsAfter(1, errors.New("transient"))} + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("expected success after retry, got: %v", err) + } + if exec.callCount != 2 { + t.Errorf("ExecuteCommand: expected 2 calls (fail then succeed), got %d", exec.callCount) + } +} + +// ---- IgnoreFailure tests --------------------------------------------------------- + +// TestExecute_IgnoreFailure_SkipsRetries verifies that when IgnoreFailure=true the command +// is executed exactly once, regardless of how many Retries are configured. +func TestExecute_IgnoreFailure_SkipsRetries(t *testing.T) { + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + Retries: 2, + IgnoreFailure: boolPtr(true), + } + sm := configuration.SequenceMap{"seq": {cmd}} + exec := &mockExecutor{executeFunc: alwaysFails(errors.New("fail"))} + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("expected no error with IgnoreFailure=true, got: %v", err) + } + if exec.callCount != 1 { + t.Errorf( + "ExecuteCommand: expected exactly 1 call when IgnoreFailure=true (retries must be skipped), got %d", + exec.callCount, + ) + } +} + +// TestExecute_IgnoreFailure_SequenceContinues verifies that execution advances to the next +// step after a step with IgnoreFailure=true fails. +func TestExecute_IgnoreFailure_SequenceContinues(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": { + {Cmd: configuration.T("fail-step"), IgnoreFailure: boolPtr(true)}, + {Cmd: configuration.T("ok-step"), IgnoreFailure: boolPtr(false)}, + }, + } + exec := &mockExecutor{ + executeFunc: func(callIdx int, _ context.Context, _ *configuration.SequenceCmd) (any, error) { + if callIdx == 0 { + return nil, errors.New("ignored failure") + } + return "ok", nil + }, + } + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("expected sequence to complete successfully, got: %v", err) + } + if exec.callCount != 2 { + t.Errorf("ExecuteCommand: expected 2 calls (both steps), got %d", exec.callCount) + } +} + +// TestExecute_IgnoreFailure_False_StopsOnError verifies that when IgnoreFailure=false and a +// step fails, the following steps are not executed. +func TestExecute_IgnoreFailure_False_StopsOnError(t *testing.T) { + sm := configuration.SequenceMap{ + "seq": { + simpleStep("fail-step", false), + simpleStep("should-not-run", false), + }, + } + exec := &mockExecutor{ + executeFunc: func(callIdx int, _ context.Context, _ *configuration.SequenceCmd) (any, error) { + if callIdx == 0 { + return nil, errors.New("step failed") + } + return "ok", nil + }, + } + + if err := sm.Execute(exec, "seq"); err == nil { + t.Fatal("expected error for failed step, got nil") + } + if exec.callCount != 1 { + t.Errorf("ExecuteCommand: expected 1 call (sequence must stop on failure), got %d", exec.callCount) + } +} + +// ---- Delay tests ----------------------------------------------------------------- + +// TestExecute_Delay_AppliedAfterAttempt verifies that a non-zero Delay causes an observable +// pause after ExecuteCommand returns. +func TestExecute_Delay_AppliedAfterAttempt(t *testing.T) { + const delay = 20 * time.Millisecond + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + Delay: delay, + IgnoreFailure: boolPtr(false), + } + sm := configuration.SequenceMap{"seq": {cmd}} + exec := &mockExecutor{} + + start := time.Now() + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if elapsed := time.Since(start); elapsed < delay { + t.Errorf("expected at least %v elapsed due to Delay, got %v", delay, elapsed) + } +} + +// TestExecute_Delay_AppliedBetweenRetries verifies that the Delay is applied after each +// failed attempt when retrying, producing a cumulative pause. +func TestExecute_Delay_AppliedBetweenRetries(t *testing.T) { + const ( + delay = 10 * time.Millisecond + retries = 2 + ) + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + Retries: retries, + Delay: delay, + IgnoreFailure: boolPtr(false), + } + sm := configuration.SequenceMap{"seq": {cmd}} + // Succeed on the last attempt so the sequence finishes without error. + exec := &mockExecutor{executeFunc: succeedsAfter(retries, errors.New("transient"))} + + start := time.Now() + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Each of the three attempts is followed by a Delay sleep, so the total must be + // at least retries+1 × delay. + minExpected := delay * (retries + 1) + if elapsed := time.Since(start); elapsed < minExpected { + t.Errorf("expected at least %v elapsed (%d attempts × %v delay), got %v", + minExpected, retries+1, delay, elapsed) + } +} + +// ---- Timeout tests --------------------------------------------------------------- + +// TestExecute_Timeout_ContextDeadlineSet verifies that when Timeout is configured, the +// context passed to ExecuteCommand carries a matching deadline. +func TestExecute_Timeout_ContextDeadlineSet(t *testing.T) { + const timeout = 500 * time.Millisecond + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + Timeout: timeout, + IgnoreFailure: boolPtr(false), + } + sm := configuration.SequenceMap{"seq": {cmd}} + + var ( + capturedAt time.Time + deadline time.Time + deadlineOk bool + ) + exec := &mockExecutor{ + executeFunc: func(_ int, ctx context.Context, _ *configuration.SequenceCmd) (any, error) { + capturedAt = time.Now() + deadline, deadlineOk = ctx.Deadline() + return nil, nil + }, + } + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !deadlineOk { + t.Fatal("context has no deadline; expected one to be set from Timeout") + } + // At the moment ExecuteCommand was entered, the remaining time must be positive + // and no greater than the configured timeout. + remaining := deadline.Sub(capturedAt) + if remaining <= 0 || remaining > timeout { + t.Errorf("context deadline remaining at capture time = %v; want in (0, %v]", remaining, timeout) + } +} + +// TestExecute_DefaultTimeout_UsedWhenUnset verifies that when Timeout is zero, the context +// deadline falls back to DefaultMaxTimeout (5 minutes). +func TestExecute_DefaultTimeout_UsedWhenUnset(t *testing.T) { + cmd := configuration.SequenceCmd{ + Cmd: configuration.T("cmd"), + // Timeout intentionally left at zero — should fall back to DefaultMaxTimeout. + IgnoreFailure: boolPtr(false), + } + sm := configuration.SequenceMap{"seq": {cmd}} + + exec := &mockExecutor{ + executeFunc: func(_ int, ctx context.Context, _ *configuration.SequenceCmd) (any, error) { + dl, ok := ctx.Deadline() + if !ok { + t.Error("context has no deadline; expected DefaultMaxTimeout to be applied") + return nil, nil + } + remaining := time.Until(dl) + // Allow a generous 1-second tolerance: the deadline must be within + // (DefaultMaxTimeout-1s, DefaultMaxTimeout] from now. + lo := configuration.DefaultMaxTimeout - time.Second + hi := configuration.DefaultMaxTimeout + if remaining < lo || remaining > hi { + t.Errorf("default timeout: remaining = %v, want in [%v, %v]", remaining, lo, hi) + } + return nil, nil + }, + } + + if err := sm.Execute(exec, "seq"); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/configuration/stringcmd.go b/internal/configuration/stringcmd.go deleted file mode 100644 index 78dc3f0..0000000 --- a/internal/configuration/stringcmd.go +++ /dev/null @@ -1,52 +0,0 @@ -package configuration - -import ( - "corteca/internal/dispatcher" - "time" - "fmt" -) - -type StringCmd struct { - Cmd TemplateField `yaml:"cmd,omitempty"` - Delay uint `yaml:"delay,omitempty"` - Retries uint `yaml:"retries,omitempty"` - IgnoreFailure bool `yaml:"ignoreFailure,omitempty"` -} - -func (sqCmd *StringCmd) Execute(dispatcher dispatcher.Dispatcher) (out string, err error) { - attempts := sqCmd.Retries + 1 - - for attempts > 0 { - out, err = sqCmd.executeCommand(dispatcher) - - attempts-- - - if !sqCmd.IgnoreFailure && err != nil { - if attempts == 0 { - return "", err - } else { - fmt.Printf("Command failed (%s); will retry %d more time(s).\n", err.Error(), attempts) - } - } - - if sqCmd.Delay > 0 { - fmt.Printf("=> Waiting for %d millisecond(s)...\n", sqCmd.Delay) - time.Sleep(time.Duration(sqCmd.Delay) * time.Millisecond) - } - if err == nil { - break - } - } - - return out, nil -} - -func (sqCmd *StringCmd) executeCommand(dispatcher dispatcher.Dispatcher) (string, error) { - - cmdStr := sqCmd.Cmd.String() - - fmt.Printf("=> Send cmd: '%s'...\n", cmdStr) - out, err := dispatcher.ExecuteCommand(cmdStr) - - return out, err -} diff --git a/internal/cwmp/messages/ChangeDUState.go b/internal/cwmp/messages/ChangeDUState.go deleted file mode 100644 index c56946b..0000000 --- a/internal/cwmp/messages/ChangeDUState.go +++ /dev/null @@ -1,103 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -type ChangeDUState struct { - ID string - Name string - OperationType string - Operation DeploymentUnitOperationStruct - CommandKey string -} - -type ChangeDUStateBodyStruct struct { - Body ChangeDUStateStruct `xml:"cwmp:ChangeDUState"` -} - -type ChangeDUStateStruct struct { - XMLName xml.Name `xml:"cwmp:ChangeDUState"` - XmlnsCwmp string `xml:"xmlns:cwmp,attr"` - CommandKey string `xml:"CommandKey"` - Operations []DeploymentUnitOperationStruct `xml:"Operations"` -} - -type DeploymentUnitOperationStruct struct { - XmlnsXsi string `xml:"xmlns:xsi,attr"` - XmlnsXsiType string `xml:"xsi:type,attr"` - URL string `xml:"URL"` - UUID string `xml:"UUID"` - Username string `xml:"Username,omitempty"` - Password string `xml:"Password,omitempty"` - ExecutionEnvRef string `xml:"ExecutionEnvRef"` - Version string `xml:"Version"` -} - -func NewChangeDUState() *ChangeDUState { - changeDUState := new(ChangeDUState) - changeDUState.ID = changeDUState.GetID() - changeDUState.Name = changeDUState.GetName() - return changeDUState -} - -// GetName get msg type -func (msg *ChangeDUState) GetName() string { - return "ChangeDUState" -} - -func (msg *ChangeDUState) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -// CreateXML encode into xml -func (msg *ChangeDUState) CreateXML() ([]byte, error) { - env := Envelope{} - id := IDStruct{"1", msg.GetID()} - env.XmlnsEnv = "http://schemas.xmlsoap.org/soap/envelope/" - env.XmlnsEnc = "http://schemas.xmlsoap.org/soap/encoding/" - env.XmlnsXsd = "http://www.w3.org/2001/XMLSchema" - env.XmlnsXsi = "http://www.w3.org/2001/XMLSchema-instance" - env.XmlnsCwmp = "urn:dslforum-org:cwmp-1-0" - env.Header = HeaderStruct{ID: id} - var operation = DeploymentUnitOperationStruct{} - switch msg.OperationType { - case "install": - operation = DeploymentUnitOperationStruct{ - URL: msg.Operation.URL, - UUID: msg.Operation.UUID, - Password: msg.Operation.Password, - Username: msg.Operation.Username, - } - case "uninstall": - operation = DeploymentUnitOperationStruct{ - UUID: msg.Operation.UUID, - Version: msg.Operation.Version, - } - case "update": - operation = DeploymentUnitOperationStruct{ - URL: msg.Operation.URL, - UUID: msg.Operation.UUID, - } - default: - return nil, fmt.Errorf("operation %s not supported", msg.OperationType) - } - - operation.XmlnsXsiType = fmt.Sprintf("cwmp:%sOpStruct", msg.OperationType) - operation.XmlnsXsi = "http://www.w3.org/2001/XMLSchema-instance" - changeDUState := ChangeDUStateStruct{CommandKey: msg.CommandKey, Operations: []DeploymentUnitOperationStruct{operation}} - changeDUState.XmlnsCwmp = "urn:dslforum-org:cwmp-1-0" - env.Body = ChangeDUStateBodyStruct{changeDUState} - return xml.MarshalIndent(env, " ", " ") -} - -func (msg *ChangeDUState) Parse(doc *etree.Document) error { - return nil -} diff --git a/internal/cwmp/messages/ChangeDUStateResponse.go b/internal/cwmp/messages/ChangeDUStateResponse.go deleted file mode 100644 index 6934d71..0000000 --- a/internal/cwmp/messages/ChangeDUStateResponse.go +++ /dev/null @@ -1,41 +0,0 @@ -package messages - -import ( - "fmt" - "time" - - "github.com/beevik/etree" -) - -type ChangeDUStateResponse struct { - ID string - Name string -} - -func (msg *ChangeDUStateResponse) GetName() string { - return "ChangeDUStateResponse" -} - -func (msg *ChangeDUStateResponse) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -// CreateXML encode into xml -func (msg *ChangeDUStateResponse) CreateXML() ([]byte, error) { - return nil, nil -} - -func (msg *ChangeDUStateResponse) Parse(doc *etree.Document) error { - msg.ID = doc.FindElement("//ID").Text() - return nil -} - -func NewChangeDUStateResponse() *ChangeDUStateResponse { - changeDUState := new(ChangeDUStateResponse) - changeDUState.ID = changeDUState.GetID() - changeDUState.Name = changeDUState.GetName() - return changeDUState -} diff --git a/internal/cwmp/messages/DUStateChangeComplete.go b/internal/cwmp/messages/DUStateChangeComplete.go deleted file mode 100644 index 119ca77..0000000 --- a/internal/cwmp/messages/DUStateChangeComplete.go +++ /dev/null @@ -1,116 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - - "strconv" - "time" - - "github.com/beevik/etree" -) - -type DUStateChangeComplete struct { - ID string - Name string - UUID string - DeploymentUnitRef string - Version string - ExecutionUnitRefList []string - Fault FaultStruct - StartTime string - CompleteTime string - CommandKey string -} - -type ChangeDUStateCompleteHeaderStruct struct { - ID IDStruct `xml:"cwmp:ID"` - NoMore interface{} `xml:"cwmp:NoMoreRequests,omitempty"` -} - -type ChangeDUStateCompleteStruct struct { - XMLName xml.Name `xml:"DUStateChangeComplete"` - CommandKey string `xml:"CommandKey"` - Results []DeploymentUnitResultStruct `xml:"Results"` -} - -type DeploymentUnitResultStruct struct { - UUID string `xml:"UUID"` - DeploymentUnitRef string `xml:"DeploymentUnitRef"` - Version string `xml:"Version"` - ExecutionUnitRefList []string `xml:"ExecutionUnitRefList>string"` - OperationPerformed string `xml:"OperationPerformed"` - StartTime string `xml:"StartTime"` - CompleteTime string `xml:"CompleteTime"` - Fault DUCompleteFaultStruct `xml:"Fault"` -} - -type DUCompleteFaultStruct struct { - FaultCode int `xml:"FaultCode"` - FaultString string `xml:"FaultString"` -} - -// GetName get msg type -func (msg *DUStateChangeComplete) GetName() string { - return msg.Name -} - -// GetID get msg id -func (msg *DUStateChangeComplete) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -func (msg *DUStateChangeComplete) Parse(doc *etree.Document) error { - msg.ID = doc.FindElement("//ID").Text() - msg.Name = "DUStateChangeComplete" - - elemUUID := doc.FindElement("//UUID") - if elemUUID == nil { - return fmt.Errorf("failed to parse uuid") - } - msg.UUID = doc.FindElement("//UUID").Text() - - elemDepRef := doc.FindElement("//DeploymentUnitRef") - if elemDepRef == nil { - return fmt.Errorf("failed to parse DeploymentUnitRef") - } - msg.DeploymentUnitRef = elemDepRef.Text() - - elemStartTime := doc.FindElement("//StartTime") - if elemStartTime == nil { - return fmt.Errorf("failed to parse StartTime") - } - msg.StartTime = elemStartTime.Text() - - elemCompleteTime := doc.FindElement("//CompleteTime") - if elemCompleteTime == nil { - return fmt.Errorf("failed to parse CompleteTime") - } - msg.CompleteTime = elemCompleteTime.Text() - - elemFaultCode := doc.FindElement("//FaultCode") - if elemFaultCode == nil { - return fmt.Errorf("failed to parse FaultCode") - } - msg.Fault.FaultCode, _ = strconv.Atoi(elemFaultCode.Text()) - - elemFaultString := doc.FindElement("//FaultString") - if elemFaultString == nil { - return fmt.Errorf("failed to parse FaultString") - } - msg.Fault.FaultString = elemFaultString.Text() - - msg.CommandKey = doc.FindElement("//CommandKey").Text() - return nil -} - -func (msg *DUStateChangeComplete) CreateXML() ([]byte, error) { - return nil, nil -} - -func NewChangeDUStateComplete() *DUStateChangeComplete { - return &DUStateChangeComplete{} -} diff --git a/internal/cwmp/messages/DUStateChangeCompleteResponse.go b/internal/cwmp/messages/DUStateChangeCompleteResponse.go deleted file mode 100644 index e6cf24f..0000000 --- a/internal/cwmp/messages/DUStateChangeCompleteResponse.go +++ /dev/null @@ -1,58 +0,0 @@ -package messages - -import ( - "encoding/xml" - - "github.com/beevik/etree" -) - - -type changeDUStateCompleteResponseBody struct { - Response changeDUStateCompleteResponseStruct `xml:"cwmp:ChangeDUStateCompleteResponse"` -} - -type changeDUStateCompleteResponseStruct struct{} - - -type ChangeDUStateCompleteResponse struct { - ID string - Name string -} - -func (msg *ChangeDUStateCompleteResponse) GetName() string { - return "DUStateChangeCompleteResponse" -} - - -func (msg *ChangeDUStateCompleteResponse) GetID() string { - return msg.ID -} - -func (msg *ChangeDUStateCompleteResponse) CreateXML() ([]byte, error) { - env := Envelope{ - XmlnsEnv: "http://schemas.xmlsoap.org/soap/envelope/", - XmlnsEnc: "http://schemas.xmlsoap.org/soap/encoding/", - XmlnsXsd: "http://www.w3.org/2001/XMLSchema", - XmlnsXsi: "http://www.w3.org/2001/XMLSchema-instance", - XmlnsCwmp: "urn:dslforum-org:cwmp-1-0", - Header: HeaderStruct{ - ID: IDStruct{ - Attr: "1", - Value: msg.GetID(), // You should define GetID() on your type - }, - }, - Body: changeDUStateCompleteResponseBody{ - Response: changeDUStateCompleteResponseStruct{}, - }, - } - - return xml.MarshalIndent(env, " ", " ") -} - -func (msg *ChangeDUStateCompleteResponse) Parse(doc *etree.Document) error { - return nil -} - -func NewDUStateCompleteResponse() *ChangeDUStateCompleteResponse { - return &ChangeDUStateCompleteResponse{} -} diff --git a/internal/cwmp/messages/Fault.go b/internal/cwmp/messages/Fault.go deleted file mode 100644 index 63305cb..0000000 --- a/internal/cwmp/messages/Fault.go +++ /dev/null @@ -1,119 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -// Fault error response -type Fault struct { - ID string - Name string - NoMore int - CwmpFaultCode string - CwmpFaultString string - MsgFaultCode string - MsgFaultString string - SetParameterValuesFault SetParameterValuesFaultStruct -} - -type faultBodyStruct struct { - Fault faultStruct `xml:"SOAP-ENV:Fault"` -} -type faultStruct struct { - FaultCode string `xml:"faultcode"` - FaultString string `xml:"faultstring"` - FaultDetail faultDetailStruct `xml:"detail"` -} - -type faultDetailStruct struct { - CwmpFault cwmpFaultStruct `xml:"cwmp:Fault"` -} - -type cwmpFaultStruct struct { - FaultCode string - FaultString string - SetParameterValuesFault SetParameterValuesFaultStruct -} - -// SetParameterValuesFaultStruct setParameterValues Fault -type SetParameterValuesFaultStruct struct { - ParameterName string - FaultCode string - FaultString string - ParameterKey string -} - -// NewFault create Fault object -func NewFault() (m *Fault) { - m = &Fault{} - m.ID = m.GetID() - m.Name = m.GetName() - return m -} - -// GetName get msg type -func (msg *Fault) GetName() string { - return "Fault" -} - -// GetID get msg id -func (msg *Fault) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -// CreateXML encode into xml -func (msg *Fault) CreateXML() ([]byte, error) { - env := Envelope{} - id := IDStruct{"1", msg.GetID()} - env.XmlnsEnv = "http://schemas.xmlsoap.org/soap/envelope/" - env.XmlnsEnc = "http://schemas.xmlsoap.org/soap/encoding/" - env.XmlnsXsd = "http://www.w3.org/2001/XMLSchema" - env.XmlnsXsi = "http://www.w3.org/2001/XMLSchema-instance" - env.XmlnsCwmp = "urn:dslforum-org:cwmp-1-0" - env.Header = HeaderStruct{ID: id} - setParamFault := SetParameterValuesFaultStruct{ - FaultCode: msg.SetParameterValuesFault.FaultCode, - FaultString: msg.SetParameterValuesFault.FaultString, - ParameterName: msg.SetParameterValuesFault.ParameterName, - ParameterKey: msg.SetParameterValuesFault.ParameterKey, - } - cwmpFault := cwmpFaultStruct{ - FaultCode: msg.MsgFaultCode, - FaultString: msg.MsgFaultString, - SetParameterValuesFault: setParamFault, - } - detail := faultDetailStruct{CwmpFault: cwmpFault} - fault := faultStruct{ - FaultCode: msg.CwmpFaultCode, - FaultString: msg.CwmpFaultString, - FaultDetail: detail, - } - env.Body = faultBodyStruct{fault} - return xml.MarshalIndent(env, " ", " ") -} - -// Parse decode from xml -func (msg *Fault) Parse(doc *etree.Document) error { - msg.ID = doc.FindElement("//ID").Text() - faultNode := doc.FindElement("//Fault") - msg.CwmpFaultCode = faultNode.SelectElement("faultcode").Text() - msg.CwmpFaultString = faultNode.SelectElement("faultstring").Text() - detailNode := faultNode.FindElement("//detail") - detailFaultNode := detailNode.FindElement("cwmp:Fault") - msg.MsgFaultCode = detailFaultNode.SelectElement("FaultCode").Text() - msg.MsgFaultString = detailFaultNode.SelectElement("FaultString").Text() - setParamFaultNode := detailFaultNode.FindElement("//SetParameterValuesFault") - if setParamFaultNode != nil { - msg.SetParameterValuesFault.FaultCode = setParamFaultNode.SelectElement("FaultCode").Text() - msg.SetParameterValuesFault.FaultString = setParamFaultNode.SelectElement("FaultString").Text() - msg.SetParameterValuesFault.ParameterName = setParamFaultNode.SelectElement("ParameterName").Text() - } - return nil -} diff --git a/internal/cwmp/messages/GetParameterNames.go b/internal/cwmp/messages/GetParameterNames.go deleted file mode 100644 index 643c765..0000000 --- a/internal/cwmp/messages/GetParameterNames.go +++ /dev/null @@ -1,67 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -type GetParameterNames struct { - Name string - ID string - ParameterPath string - NextLevel bool -} - -type GetParameterNamesBodyStruct struct { - XMLName xml.Name `xml:"soap-env:Body"` - Body GetParameterNamesRPC `xml:"cwmp:GetParameterNames"` -} - -type GetParameterNamesRPC struct { - ParameterPath string `xml:"ParameterPath"` - NextLevel bool `xml:"NextLevel"` -} - -func (g *GetParameterNames) GetName() string { - return "GetParameterNames" -} - -func (g *GetParameterNames) CreateXML() ([]byte, error) { - rpc := GetParameterNamesRPC{ - ParameterPath: g.ParameterPath, - NextLevel: g.NextLevel, - } - - envelope := Envelope{ - XmlnsEnv: "http://schemas.xmlsoap.org/soap/envelope/", - XmlnsEnc: "http://schemas.xmlsoap.org/soap/encoding/", - XmlnsXsd: "http://www.w3.org/2001/XMLSchema", - XmlnsXsi: "http://www.w3.org/2001/XMLSchema-instance", - XmlnsCwmp: "urn:dslforum-org:cwmp-1-0", - Header: HeaderStruct{ID: IDStruct{Attr: "1", Value: g.ID}}, - Body: GetParameterNamesBodyStruct{Body: rpc}, - } - - return xml.MarshalIndent(envelope, "", " ") -} - -func (g *GetParameterNames) GetID() string { - if len(g.ID) < 1 { - g.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", g.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return g.ID -} - -func (g *GetParameterNames) Parse(doc *etree.Document) error { - return nil -} - -func NewGetParameterNames() *GetParameterNames { - getParamNames := new(GetParameterNames) - getParamNames.ID = getParamNames.GetID() - getParamNames.Name = getParamNames.GetName() - return getParamNames -} diff --git a/internal/cwmp/messages/GetParameterNamesResponse.go b/internal/cwmp/messages/GetParameterNamesResponse.go deleted file mode 100644 index 202a651..0000000 --- a/internal/cwmp/messages/GetParameterNamesResponse.go +++ /dev/null @@ -1,68 +0,0 @@ -package messages - -import ( - "fmt" - - "github.com/beevik/etree" -) - -type ParameterInfoStruct struct { - Name string `xml:"Name"` - Writable bool `xml:"Writable"` -} - -type ParameterList struct { - Parameters []ParameterInfoStruct `xml:"ParameterInfoStruct"` -} - -type GetParameterNamesResponse struct { - XMLName string `xml:"cwmp:GetParameterNamesResponse"` - ParameterList ParameterList `xml:"ParameterList"` - ID string `xml:"ID,attr"` -} - -func NewGetParameterNamesResponse() *GetParameterNamesResponse { - return &GetParameterNamesResponse{} -} - -func (resp *GetParameterNamesResponse) GetID() string { - return resp.ID -} - -func (resp *GetParameterNamesResponse) Parse(doc *etree.Document) error { - body := doc.FindElement("//cwmp:GetParameterNamesResponse") - if body == nil { - return fmt.Errorf("GetParameterNamesResponse element not found") - } - - if idElem := doc.FindElement("//cwmp:ID"); idElem != nil { - resp.ID = idElem.Text() - } - - parameterList := body.FindElement("ParameterList") - if parameterList == nil { - return fmt.Errorf("ParameterList element not found") - } - - for _, p := range parameterList.SelectElements("ParameterInfoStruct") { - name := p.FindElement("Name") - writable := p.FindElement("Writable") - - if name != nil && writable != nil { - resp.ParameterList.Parameters = append(resp.ParameterList.Parameters, ParameterInfoStruct{ - Name: name.Text(), - Writable: writable.Text() == "1" || writable.Text() == "true", - }) - } - } - - return nil -} - -func (resp *GetParameterNamesResponse) GetName() string { - return "GetParameterNamesResponse" -} - -func (resp *GetParameterNamesResponse) CreateXML() ([]byte, error) { - return nil, fmt.Errorf("should not be called on a response, it's present to satisfy the Message interface") -} diff --git a/internal/cwmp/messages/GetParameterValues.go b/internal/cwmp/messages/GetParameterValues.go deleted file mode 100644 index 09ae19c..0000000 --- a/internal/cwmp/messages/GetParameterValues.go +++ /dev/null @@ -1,76 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -type ParameterValues struct { - XMLName string - ID string - CommandKey string - PrintFormat string - ParameterNames []string -} - -type BodyStruct struct { - GetParameterValues GetParameterValues `xml:"cwmp:GetParameterValues"` -} - -type GetParameterValues struct { - ParameterNames ParameterNames `xml:"ParameterNames"` -} - -type ParameterNames struct { - XmlnsSoapEnc string `xml:"xmlns:soap-enc,attr"` - ArrayType string `xml:"soap-enc:arrayType,attr"` - Strings []string `xml:"string"` -} - -// GetName get msg type -func (msg *ParameterValues) GetName() string { - return "GetParameterValues" -} - -// GetID get msg id -func (msg *ParameterValues) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -func (msg *ParameterValues) CreateXML() ([]byte, error) { - envelope := Envelope{ - XmlnsEnv: "http://schemas.xmlsoap.org/soap/envelope/", - XmlnsEnc: "http://schemas.xmlsoap.org/soap/encoding/", - XmlnsXsd: "http://www.w3.org/2001/XMLSchema", - XmlnsXsi: "http://www.w3.org/2001/XMLSchema-instance", - XmlnsCwmp: "urn:dslforum-org:cwmp-1-0", - Header: HeaderStruct{ID: IDStruct{Attr: "1", Value: msg.ID}}, - Body: BodyStruct{ - GetParameterValues: GetParameterValues{ - ParameterNames: ParameterNames{ - XmlnsSoapEnc: "http://schemas.xmlsoap.org/soap/encoding/", - ArrayType: fmt.Sprintf("xsd:string[%v]", len(msg.ParameterNames)), - Strings: msg.ParameterNames, - }, - }, - }, - } - return xml.MarshalIndent(envelope, "", " ") -} - -func (msg *ParameterValues) Parse(doc *etree.Document) error { - return nil -} - -func NewGetParameterValues() *ParameterValues { - paramValStruct := new(ParameterValues) - paramValStruct.ID = paramValStruct.GetID() - paramValStruct.XMLName = paramValStruct.GetName() - return paramValStruct -} diff --git a/internal/cwmp/messages/GetParameterValuesResponse.go b/internal/cwmp/messages/GetParameterValuesResponse.go deleted file mode 100644 index 80bb655..0000000 --- a/internal/cwmp/messages/GetParameterValuesResponse.go +++ /dev/null @@ -1,56 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -type GetParameterValuesResponse struct { - ID string `xml:"ID" json:"ID"` - XMLName string `xml:"Name" json:"Name"` - ParameterList []ParameterValuesInfoStruct `xml:"ParameterList" json:"ParameterList"` -} - -type ParameterValuesInfoStruct struct { - Name string `xml:"name" json:"name"` - Value string `xml:"value" json:"value"` -} - -// GetName get msg type -func (msg *GetParameterValuesResponse) GetName() string { - return "GetParameterValuesResponse" -} - -// GetID get msg id -func (msg *GetParameterValuesResponse) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -func (msg *GetParameterValuesResponse) Parse(doc *etree.Document) error { - msg.ID = doc.FindElement("//ID").Text() - msg.XMLName = "GetParameterValuesResponse" - for _, param := range doc.FindElements("//ParameterList/ParameterValueStruct") { - msg.ParameterList = append(msg.ParameterList, ParameterValuesInfoStruct{ - Name: param.SelectElement("Name").Text(), - Value: param.SelectElement("Value").Text(), - }) - } - return nil -} - -func (msg *GetParameterValuesResponse) CreateXML() ([]byte, error) { - xmlStr, err := xml.Marshal(msg) - fmt.Println(string(xmlStr)) - return xmlStr, err - //return nil, fmt.Errorf("createXML should not be called on a response message, it's present to satisfy the Message interface") -} - -func NewGetParameterValuesResponse() *GetParameterValuesResponse { - return &GetParameterValuesResponse{} -} diff --git a/internal/cwmp/messages/GetParameterValues_test.go b/internal/cwmp/messages/GetParameterValues_test.go deleted file mode 100644 index e13b86d..0000000 --- a/internal/cwmp/messages/GetParameterValues_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package messages - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCreateXML(t *testing.T) { - expectedOutput := ` - - 5 - - - - - Device.DeviceInfo.ProvisioningCode - Device.DeviceInfo.X_ALU-COM_FriendlyName - - - -` - getParamValues := NewGetParameterValues() - getParamValues.ParameterNames = []string{"Device.DeviceInfo.ProvisioningCode", "Device.DeviceInfo.X_ALU-COM_FriendlyName"} - getParamValues.ID = "5" - getParamValues.XMLName = "name" - - rpcXML, err := getParamValues.CreateXML() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.Equal(t, expectedOutput, string(rpcXML), "Strings should be equal") -} - -func TestParseXML(t *testing.T) { - expectedMsg := NewGetParameterValuesResponse() - expectedMsg.ID = "5" - expectedMsg.XMLName = "GetParameterValuesResponse" - expectedMsg.ParameterList = append(expectedMsg.ParameterList, ParameterValuesInfoStruct{ - Name: "Device.DeviceInfo.ProvisioningCode", - Value: "MyCustomModel123", - }, ParameterValuesInfoStruct{ - Name: "Device.DeviceInfo.X_ALU-COM_FriendlyName", - Value: "myval", - }) - - requestBody := ` - -5 - - - - - -Device.DeviceInfo.ProvisioningCode -MyCustomModel123 - - -Device.DeviceInfo.X_ALU-COM_FriendlyName -myval - - - - - -` - msg, err := ParseXML([]byte(requestBody)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.Equal(t, expectedMsg, msg, "Messages should be the same") -} diff --git a/internal/cwmp/messages/Inform.go b/internal/cwmp/messages/Inform.go deleted file mode 100644 index 166be4b..0000000 --- a/internal/cwmp/messages/Inform.go +++ /dev/null @@ -1,184 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "strconv" - "time" - - "github.com/beevik/etree" -) - -// Inform tr069 inform (heartbeat) -type Inform struct { - ID string `json:"id"` - Name string `json:"name"` - Manufacturer string `json:"manufacturer"` - OUI string `json:"oui"` - ProductClass string `json:"productClass"` - Sn string `json:"sn"` - Events map[string]string `json:"events"` - MaxEnvelopes int `json:"maxEnvelopes"` - CurrentTime string `json:"currentTime"` - RetryCount int `json:"retryCount"` - Params map[string]string `json:"params"` -} - -type informBodyStruct struct { - Body informStruct `xml:"cwmp:Inform"` -} - -type informStruct struct { - DeviceID deviceIDStruct `xml:"DeviceId"` - Event EventStruct `xml:"Event"` - MaxEnvelopes NodeStruct `xml:"MaxEnvelopes"` - CurrentTime NodeStruct `xml:"CurrentTime"` - RetryCount NodeStruct `xml:"RetryCount"` - Params ParameterListStruct `xml:"ParameterList"` -} - -type deviceIDStruct struct { - Type string `xml:"xsi:type,attr"` - Manufacturer NodeStruct `xml:"Manufacturer"` - OUI NodeStruct `xml:"OUI"` - ProductClass NodeStruct `xml:"ProductClass"` - SerialNumber NodeStruct `xml:"SerialNumber"` -} - -// NewInform create a inform messages -func NewInform() *Inform { - inform := new(Inform) - inform.ID = inform.GetID() - inform.Name = inform.GetName() - inform.Events = make(map[string]string) - inform.Params = make(map[string]string) - return inform -} - -// GetName get msg type -func (msg *Inform) GetName() string { - return "Inform" -} - -// GetID get msg id -func (msg *Inform) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -// CreateXML encode into xml -func (msg *Inform) CreateXML() ([]byte, error) { - env := Envelope{} - id := IDStruct{"1", msg.GetID()} - env.XmlnsEnv = "http://schemas.xmlsoap.org/soap/envelope/" - env.XmlnsEnc = "http://schemas.xmlsoap.org/soap/encoding/" - env.XmlnsXsd = "http://www.w3.org/2001/XMLSchema" - env.XmlnsXsi = "http://www.w3.org/2001/XMLSchema-instance" - env.XmlnsCwmp = "urn:dslforum-org:cwmp-1-0" - env.Header = HeaderStruct{ID: id} - manufacturer := NodeStruct{Type: XsdString, Value: msg.Manufacturer} - oui := NodeStruct{Type: XsdString, Value: msg.OUI} - productClass := NodeStruct{Type: XsdString, Value: msg.ProductClass} - serialNumber := NodeStruct{Type: XsdString, Value: msg.Sn} - deviceID := deviceIDStruct{Type: "cwmp:DeviceIdStruct", Manufacturer: manufacturer, OUI: oui, ProductClass: productClass, SerialNumber: serialNumber} - eventLen := strconv.Itoa(len(msg.Events)) - event := EventStruct{Type: "cwmp:EventStruct[" + eventLen + "]"} - for k, v := range msg.Events { - eventCode := NodeStruct{Type: XsdString, Value: k} - event.Events = append(event.Events, EventNodeStruct{EventCode: eventCode, CommandKey: v}) - } - - maxEnv := strconv.Itoa(msg.MaxEnvelopes) - maxEnvelopes := NodeStruct{Type: XsdString, Value: maxEnv} - currentTime := NodeStruct{Type: XsdString, Value: msg.CurrentTime} - trys := strconv.Itoa(msg.RetryCount) - retryCount := NodeStruct{Type: XsdString, Value: trys} - paramLen := strconv.Itoa(len(msg.Params)) - paramList := ParameterListStruct{Type: "cwmp:ParameterValueStruct[" + paramLen + "]"} - for k, v := range msg.Params { - param := ParameterValueStruct{ - Name: NodeStruct{Type: XsdString, Value: k}, - Value: NodeStruct{Type: XsdString, Value: v}} - paramList.Params = append(paramList.Params, param) - } - info := informStruct{DeviceID: deviceID, Event: event, MaxEnvelopes: maxEnvelopes, CurrentTime: currentTime, RetryCount: retryCount, Params: paramList} - env.Body = informBodyStruct{info} - output, err := xml.MarshalIndent(env, " ", " ") - //output, err := xml.Marshal(env) - if err != nil { - return nil, err - } - return output, nil -} - -// Parse decode from xml -func (msg *Inform) Parse(doc *etree.Document) error { - msg.ID = doc.FindElement("//ID").Text() - deviceNode := doc.FindElement("//DeviceId") - if deviceNode != nil { - msg.Manufacturer = deviceNode.SelectElement("Manufacturer").Text() - msg.OUI = deviceNode.SelectElement("OUI").Text() - msg.ProductClass = deviceNode.SelectElement("ProductClass").Text() - msg.Sn = deviceNode.SelectElement("SerialNumber").Text() - } - informNode := doc.FindElement("//Inform") - if informNode != nil { - var err error - msg.CurrentTime = informNode.SelectElement("CurrentTime").Text() - msg.MaxEnvelopes, err = strconv.Atoi(informNode.SelectElement("MaxEnvelopes").Text()) - if err != nil { - return err - } - msg.RetryCount, err = strconv.Atoi(informNode.SelectElement("RetryCount").Text()) - if err != nil { - return err - } - } - - eventNode := doc.FindElement("//Event") - if eventNode != nil { - //msg.Events = make(map[string]string) - var code string - for _, event := range eventNode.ChildElements() { - if event != nil { - code = event.SelectElement("EventCode").Text() - msg.Events[code] = event.SelectElement("CommandKey").Text() - } - } - } - - paramsNode := doc.FindElement("//ParameterList") - if paramsNode != nil { - //msg.Params = make(map[string]string) - var name string - for _, param := range paramsNode.ChildElements() { - if param != nil { - name = param.SelectElement("Name").Text() - msg.Params[name] = param.SelectElement("Value").Text() - } - } - } - return nil -} - -// IsEvent is a connect request or others -func (msg *Inform) IsEvent(event string) bool { - if _, ok := msg.Events[event]; ok { - return true - } - return false -} - -// GetParam get param in inform -func (msg *Inform) GetParam(name string) (value string) { - value = msg.Params[name] - return -} - -// GetConfigVersion get current config version -func (msg *Inform) GetConfigVersion() (version string) { - version = msg.GetParam("InternetGatewayDevice.DeviceConfig.ConfigVersion") - return -} diff --git a/internal/cwmp/messages/InformResponse.go b/internal/cwmp/messages/InformResponse.go deleted file mode 100644 index 0ecb5c0..0000000 --- a/internal/cwmp/messages/InformResponse.go +++ /dev/null @@ -1,63 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -// InformResponse infrom response -type InformResponse struct { - ID string - Name string - NoMore int - MaxEnvelopes int -} - -type informResponseBodyStruct struct { - Body informResponseStruct `xml:"cwmp:InformResponse"` -} - -type informResponseStruct struct { - MaxEnvelopes int `xml:"MaxEnvelopes"` -} - -// GetID get msg id -func (msg *InformResponse) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -// GetName get msg type -func (msg *InformResponse) GetName() string { - return "InformResponse" -} - -// CreateXML encode into xml -func (msg *InformResponse) CreateXML() ([]byte, error) { - env := Envelope{} - env.XmlnsEnv = "http://schemas.xmlsoap.org/soap/envelope/" - env.XmlnsEnc = "http://schemas.xmlsoap.org/soap/encoding/" - env.XmlnsXsd = "http://www.w3.org/2001/XMLSchema" - env.XmlnsXsi = "http://www.w3.org/2001/XMLSchema-instance" - env.XmlnsCwmp = "urn:dslforum-org:cwmp-1-0" - id := IDStruct{Attr: "1", Value: msg.GetID()} - env.Header = HeaderStruct{ID: id, NoMore: msg.NoMore} - infromResp := informResponseStruct{MaxEnvelopes: msg.MaxEnvelopes} - env.Body = informResponseBodyStruct{infromResp} - //output, err := xml.Marshal(env) - output, err := xml.MarshalIndent(env, " ", " ") - if err != nil { - return nil, err - } - return output, nil -} - -// Parse decode from xml -func (msg *InformResponse) Parse(doc *etree.Document) error { - return fmt.Errorf("parse should not be called on an outgoing message, it's present to satisfy the Message interface") -} diff --git a/internal/cwmp/messages/MessageParser.go b/internal/cwmp/messages/MessageParser.go deleted file mode 100644 index 66e6427..0000000 --- a/internal/cwmp/messages/MessageParser.go +++ /dev/null @@ -1,66 +0,0 @@ -package messages - -import ( - "fmt" - - "github.com/beevik/etree" -) - -func ParseXML(data []byte) (msg Message, err error) { - doc := etree.NewDocument() - doc.ReadFromBytes(data) - - envelope := doc.SelectElement("Envelope") - if envelope == nil { - return nil, fmt.Errorf("Envelope not found") - } - - var body *etree.Element - for _, elem := range envelope.ChildElements() { - if elem.Tag == "Body" || elem.Tag == "soap-env:Body" { - body = elem - break - } - } - - if body != nil { - bodyContent := body.ChildElements()[0] - name := bodyContent.Tag - switch name { - case "Inform": - msg = NewInform() - err = msg.Parse(doc) - case "DUStateChangeComplete": - msg = NewChangeDUStateComplete() - err = msg.Parse(doc) - case "ChangeDUStateResponse": - msg = NewChangeDUStateResponse() - err = msg.Parse(doc) - case "GetParameterNamesResponse": - msg = NewGetParameterNamesResponse() - err = msg.Parse(doc) - if err != nil { - return nil, fmt.Errorf("GetParameterNamesResponse ParseXML generated error: %v", err) - } - case "GetParameterValuesResponse": - msg = NewGetParameterValuesResponse() - err = msg.Parse(doc) - case "SetParameterValuesResponse": - msg = NewSetParameterValuesResponse() - err = msg.Parse(doc) - case "Fault": - msg = NewFault() - err = msg.Parse(doc) - case "GetRPCMethodsResponse": - s, _ := doc.WriteToString() - fmt.Println("msg: ", s) - return nil, fmt.Errorf("message %s not supported", name) - default: - return nil, fmt.Errorf("unknown message %s", name) - } - - return msg, err - } else { - return nil, fmt.Errorf("body element not found") - } -} diff --git a/internal/cwmp/messages/Messages.go b/internal/cwmp/messages/Messages.go deleted file mode 100644 index ee51053..0000000 --- a/internal/cwmp/messages/Messages.go +++ /dev/null @@ -1,112 +0,0 @@ -package messages - -import ( - "encoding/xml" - - "github.com/beevik/etree" -) - -const ( - //XsdString string type - XsdString string = "xsd:string" - //XsdUnsignedint uint type - XsdUnsignedint string = "xsd:unsignedInt" -) - -const ( - //SoapArray array type - SoapArray string = "SOAP-ENC:Array" -) - -const ( - //EventBootStrap first connection - EventBootStrap string = "0 BOOTSTRAP" - //EventBoot reset or power on - EventBoot string = "1 BOOT" - //EventPeriodic periodic inform - EventPeriodic string = "2 PERIODIC" - //EventScheduled scheduled infrorm - EventScheduled string = "3 SCHEDULED" - //EventValueChange value change event - EventValueChange string = "4 VALUE CHANGE" - //EventKicked acs notify cpe - EventKicked string = "5 KICKED" - //EventConnectionRequest cpe request connection - EventConnectionRequest string = "6 CONNECTION REQUEST" - //EventTransferComplete download complete - EventTransferComplete string = "7 TRANSFER COMPLETE" -) - -// Message tr069 msg interface -type Message interface { - Parse(doc *etree.Document) error - CreateXML() ([]byte, error) - GetName() string - GetID() string -} - -// Envelope tr069 body -type Envelope struct { - XMLName xml.Name `xml:"SOAP-ENV:Envelope"` - XmlnsEnv string `xml:"xmlns:SOAP-ENV,attr"` - XmlnsEnc string `xml:"xmlns:SOAP-ENC,attr"` - XmlnsXsd string `xml:"xmlns:xsd,attr"` - XmlnsXsi string `xml:"xmlns:xsi,attr"` - XmlnsCwmp string `xml:"xmlns:cwmp,attr"` - Header interface{} `xml:"SOAP-ENV:Header"` - Body interface{} `xml:"SOAP-ENV:Body"` -} - -// HeaderStruct tr069 header -type HeaderStruct struct { - ID IDStruct `xml:"cwmp:ID"` - NoMore interface{} `xml:"cwmp:NoMoreRequests,omitempty"` -} - -// IDStruct msg id -type IDStruct struct { - Attr string `xml:"SOAP-ENV:mustUnderstand,attr,omitempty"` - Value string `xml:",chardata"` -} - -// NodeStruct node -type NodeStruct struct { - Type interface{} `xml:"xsi:type,attr"` - Value string `xml:",chardata"` -} - -// EventStruct event -type EventStruct struct { - Type string `xml:"SOAP-ENC:arrayType,attr"` - Events []EventNodeStruct `xml:"EventStruct"` -} - -// EventNodeStruct event node -type EventNodeStruct struct { - EventCode NodeStruct `xml:"EventCode"` - CommandKey string `xml:"CommandKey"` -} - -// ParameterListStruct param list -type ParameterListStruct struct { - Type string `xml:"SOAP-ENC:arrayType,attr"` - Params []ParameterValueStruct `xml:"ParameterValueStruct"` -} - -// ParameterValueStruct param value -type ParameterValueStruct struct { - Name NodeStruct `xml:"Name"` - Value NodeStruct `xml:"Value"` -} - -// ValueStruct value -type ValueStruct struct { - Type string - Value string -} - -// FaultStruct error -type FaultStruct struct { - FaultCode int - FaultString string -} diff --git a/internal/cwmp/messages/SetParameterValues.go b/internal/cwmp/messages/SetParameterValues.go deleted file mode 100644 index 0e7b72f..0000000 --- a/internal/cwmp/messages/SetParameterValues.go +++ /dev/null @@ -1,85 +0,0 @@ -package messages - -import ( - "encoding/xml" - "fmt" - "time" - - "github.com/beevik/etree" -) - -type SetParameterValues struct { - ID string - ParameterList []ParameterVal - ParameterKey string -} - -type Body struct { - SetParameterValues SetParameterValuesStruct `xml:"cwmp:SetParameterValues"` -} - -type SetParameterValuesStruct struct { - SetParameterList SetParameterList `xml:"ParameterList"` - ParameterKey string `xml:"parameterKey"` -} - -type SetParameterList struct { - XmlnsSoapEnc string `xml:"xmlns:soap-enc,attr"` - ArrayType string `xml:"soap-enc:arrayType,attr"` - Parameters []ParameterVal `xml:"ParameterValueStruct"` -} - -type ParameterVal struct { - Name string `xml:"Name"` - Value Values `xml:"Value" yaml:"Value"` -} - -type Values struct { - Type string `xml:"xsi:type,attr"` - Value string `xml:",chardata"` -} - -// GetName get msg type -func (msg *SetParameterValues) GetName() string { - return "SetParameterValues" -} - -// GetID get msg id -func (msg *SetParameterValues) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -func (msg *SetParameterValues) CreateXML() ([]byte, error) { - envelope := Envelope{ - XmlnsEnv: "http://schemas.xmlsoap.org/soap/envelope/", - XmlnsEnc: "http://schemas.xmlsoap.org/soap/encoding/", - XmlnsXsd: "http://www.w3.org/2001/XMLSchema", - XmlnsXsi: "http://www.w3.org/2001/XMLSchema-instance", - XmlnsCwmp: "urn:dslforum-org:cwmp-1-0", - Header: HeaderStruct{ID: IDStruct{Attr: "1", Value: msg.ID}}, - Body: Body{ - SetParameterValues: SetParameterValuesStruct{ - ParameterKey: msg.ParameterKey, - SetParameterList: SetParameterList{ - XmlnsSoapEnc: "http://schemas.xmlsoap.org/soap/encoding/", - ArrayType: fmt.Sprintf("cwmp:ParameterValueStruct[%v]", len(msg.ParameterList)), - Parameters: msg.ParameterList, - }, - }, - }, - } - return xml.MarshalIndent(envelope, "", " ") -} - -func (msg *SetParameterValues) Parse(doc *etree.Document) error { - return nil -} - -func NewSetParameterValues() *SetParameterValues { - paramValStruct := new(SetParameterValues) - paramValStruct.ID = paramValStruct.GetID() - return paramValStruct -} diff --git a/internal/cwmp/messages/SetParameterValuesResponse.go b/internal/cwmp/messages/SetParameterValuesResponse.go deleted file mode 100644 index 4ed19a1..0000000 --- a/internal/cwmp/messages/SetParameterValuesResponse.go +++ /dev/null @@ -1,41 +0,0 @@ -package messages - -import ( - "fmt" - "strconv" - "time" - - "github.com/beevik/etree" -) - -type SetParameterValuesResponse struct { - ID string - Status int -} - -// GetName get msg type -func (msg *SetParameterValuesResponse) GetName() string { - return "SetParameterValuesResponse" -} - -// GetID get msg id -func (msg *SetParameterValuesResponse) GetID() string { - if len(msg.ID) < 1 { - msg.ID = fmt.Sprintf("ID:intrnl.unset.id.%s%d.%d", msg.GetName(), time.Now().Unix(), time.Now().UnixNano()) - } - return msg.ID -} - -func (msg *SetParameterValuesResponse) Parse(doc *etree.Document) error { - msg.ID = doc.FindElement("//ID").Text() - msg.Status, _ = strconv.Atoi(doc.FindElement("//SetParameterValuesResponse/Status").Text()) - return nil -} - -func (msg *SetParameterValuesResponse) CreateXML() ([]byte, error) { - return nil, fmt.Errorf("createXML should not be called on a response message, it's present to satisfy the Message interface") -} - -func NewSetParameterValuesResponse() *SetParameterValuesResponse { - return &SetParameterValuesResponse{} -} diff --git a/internal/cwmp/messages/SetParameterValues_test.go b/internal/cwmp/messages/SetParameterValues_test.go deleted file mode 100644 index 137bd05..0000000 --- a/internal/cwmp/messages/SetParameterValues_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package messages - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSetCreateXML(t *testing.T) { - expectedOutput := ` - - 5 - - - - - - Device.DeviceInfo.ProvisioningCode - MyCustomModel123 - - - Device.DeviceInfo.X_ALU-COM_FriendlyName - myval - - - paramKey - - -` - setParamValues := NewSetParameterValues() - setParamValues.ParameterList = []ParameterVal{ - { - Name: "Device.DeviceInfo.ProvisioningCode", - Value: Values{Type: "string", Value: "MyCustomModel123"}, - }, - { - Name: "Device.DeviceInfo.X_ALU-COM_FriendlyName", - Value: Values{Type: "string", Value: "myval"}, - }, - } - setParamValues.ID = "5" - setParamValues.ParameterKey = "paramKey" - - rpcXML, err := setParamValues.CreateXML() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.Equal(t, expectedOutput, string(rpcXML), "Strings should be equal") -} - -func TestSetParseXML(t *testing.T) { - expectedMsg := NewSetParameterValuesResponse() - expectedMsg.ID = "5" - expectedMsg.Status = 0 - - body := ` - -5 - - - -0 - - -` - msg, err := ParseXML([]byte(body)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.Equal(t, expectedMsg, msg, "Messages should be the same") -} diff --git a/internal/cwmp/models/task.go b/internal/cwmp/models/task.go deleted file mode 100644 index 413c81f..0000000 --- a/internal/cwmp/models/task.go +++ /dev/null @@ -1,14 +0,0 @@ -package models - -import ( - "corteca/internal/cwmp/messages" -) - -type ResultsMessage struct { - Code int - Message messages.Message -} - -func NewResulMessage() *ResultsMessage{ - return &ResultsMessage{Code:0, Message: nil} -} diff --git a/internal/device/cwmp/cwmpdevice.go b/internal/device/cwmp/cwmpdevice.go new file mode 100644 index 0000000..6321b1d --- /dev/null +++ b/internal/device/cwmp/cwmpdevice.go @@ -0,0 +1,372 @@ +package cwmp + +import ( + "context" + "corteca/internal/configuration" + "corteca/internal/device" + "corteca/internal/device/cwmp/messages" + "corteca/internal/tui" + "encoding/xml" + "errors" + + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strconv" + "time" + + "github.com/google/uuid" +) + +const ( + DefaultCWMPPort = 7547 +) + +func init() { + device.RegisterDeviceType("cwmp", NewCWMPDevice) + device.RegisterDeviceType("cwmps", NewCWMPDevice) +} + +type CWMPDevice struct { + server *http.Server + in chan messages.Message + out chan *messages.Envelope + log io.Writer + currentID string +} + +type CWMPConfig struct { + configuration.HttpClientEndpoint `yaml:",inline"` + Server configuration.HttpServerEndpoint `yaml:"server"` +} + +func (d *CWMPDevice) NewSessionID() { + if uuid, err := uuid.NewV7(); err != nil { + panic(err) + } else { + d.currentID = uuid.String() + } +} + +func (d *CWMPDevice) SetSessionID(id string) { + d.currentID = id +} + +func (d *CWMPDevice) ResetSessionID() { + d.currentID = "" +} + +func NewCWMPDevice(c *configuration.DeviceConfig, log io.Writer) (device.Device, error) { + cwmpconfig := CWMPConfig{} + if err := c.Decode(&cwmpconfig); err != nil { + return nil, err + } + + d := CWMPDevice{ + log: log, + in: make(chan messages.Message), + out: make(chan *messages.Envelope), + currentID: "", + } + if err := d.initServer(&cwmpconfig.Server); err != nil { + return nil, err + } + if err := d.sendConnectionRequest(&cwmpconfig.HttpClientEndpoint); err != nil { + tui.LogError("Failed sending connection request: %s", err.Error()) + } + tui.DisplaySuccessMsg("Waiting for CPE to establish connection...") + return &d, nil +} + +func (d *CWMPDevice) BeginSequence() error { + d.ResetSessionID() + ctx, cancel := context.WithTimeout(context.Background(), configuration.DefaultMaxTimeout) + defer cancel() + tui.LogNormal("Waiting for (ready) message...") + if _, err := d.expectRPC(ctx, func(m messages.Message) bool { return m == nil }); err != nil { + return err + } + return nil +} + +func (d *CWMPDevice) ExecuteCommand(ctx context.Context, cmd *configuration.SequenceCmd) (any, error) { + rpc, err := d.createRPCFromCmd(cmd) + if err != nil { + return nil, err + } else { + d.NewSessionID() + tui.LogNormal("Sending '%s' RPC...", rpc.GetName()) + env := d.newEnvelope(rpc) + if err := d.pushEnvelope(ctx, &env); err != nil { + return nil, err + } + } + + tui.LogNormal("Waiting for response...") + resp, err := d.pullMessage(ctx) + if err != nil { + return nil, err + } + if fault, ok := resp.(messages.Fault); ok { + return nil, fmt.Errorf("%s (faultcode: %d)", fault.Detail.FaultString, fault.Detail.FaultCode) + } else if err := rpc.ValidateResponse(resp); err != nil { + return nil, err + } + + d.ResetSessionID() + if async, ok := rpc.(messages.AsyncRPC); ok { + return d.handleAsyncRPC(ctx, async) + } + return resp, nil +} + +func (d *CWMPDevice) EndSequence() error { + ctx, cancel := context.WithTimeout(context.Background(), configuration.DefaultMaxTimeout) + defer cancel() + return d.pushEnvelope(ctx, nil) +} + +func (d *CWMPDevice) GetProtocol() string { + return "cwmp" +} + +func (c *CWMPDevice) Close() { + // Graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := c.server.Shutdown(ctx); err != nil { + tui.LogError("Server Shutdown Failed: %s", err) + return + } + tui.LogNormal("Server stopped gracefully!") +} + +func (d *CWMPDevice) handleAsyncRPC(ctx context.Context, rpc messages.AsyncRPC) (messages.Message, error) { + tui.LogNormal("Expecting async notification for '%s'...", rpc.GetName()) + // send blank response to end session + if err := d.pushEnvelope(ctx, nil); err != nil { + return nil, err + } + // wait until an RPC with the same command key arrives + notif, err := d.expectRPC(ctx, func(m messages.Message) bool { return rpc.Match(m) }) + if err != nil { + return nil, err + } + // respond to the notification RPC + if err := d.pushEnvelope(ctx, d.respondToRPC(notif)); err != nil { + return nil, err + } + // wait until a "ready" message arrives + if _, err := d.expectRPC(ctx, func(m messages.Message) bool { return m == nil }); err != nil { + return nil, err + } + // return notification RPC payload + return notif, rpc.ValidateResponse(notif) +} + +func (c *CWMPDevice) sendConnectionRequest(cpe *configuration.HttpClientEndpoint) error { + client, err := cpe.NewHttpClient() + if err != nil { + return err + } + url, err := url.Parse(cpe.Addr.String()) + if err != nil { + return err + } + // convert the scheme to https + switch url.Scheme { + case "cwmp": + url.Scheme = "http" + case "cwmps": + url.Scheme = "https" + default: + panic(fmt.Sprintf("unexpected scheme '%s'", url.Scheme)) + } + if url.Port() == "" { + url.Host = net.JoinHostPort(url.Host, strconv.FormatInt(DefaultCWMPPort, 10)) + } + resp, err := client.Get(url.String()) + if err != nil { + return fmt.Errorf("error sending connection request: %w", err) + } + defer resp.Body.Close() + c.log.Write(fmt.Appendf([]byte(""), "[%s] Connection Request (response: %s)\n", time.Now().Format(time.DateTime), resp.Status)) + io.Copy(io.Discard, resp.Body) + tui.LogNormal("Connection Request sent to %s; status code: %d", url.String(), resp.StatusCode) + return nil +} + +func (d *CWMPDevice) initServer(config *configuration.HttpServerEndpoint) error { + + // TODO: if no url or server is empty we should assume listen on http://0.0.0.0:DefaultCWMPPort + u, err := url.Parse(config.Addr.String()) + if err != nil { + return err + } + + if u.Port() == "" { + u.Host = net.JoinHostPort(u.Host, strconv.Itoa(DefaultCWMPPort)) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", d.handleHTTPRequest) + d.server = &http.Server{ + Addr: u.Host, + Handler: mux, + } + + tui.DisplaySuccessMsg(fmt.Sprintf("Starting CWMP server on %s...", d.server.Addr)) + // Run server in a goroutine + go func() { + var err error + if u.Scheme == "http" { + err = d.server.ListenAndServe() + } else if u.Scheme == "https" { + err = d.server.ListenAndServeTLS(config.Certificate.String(), config.Key.String()) + } + if err != nil && err != http.ErrServerClosed { + tui.LogError("Failed to start server: %s", err) + os.Exit(1) + } + }() + + return nil +} + +func (d *CWMPDevice) handleHTTPRequest(w http.ResponseWriter, r *http.Request) { + // TODO: implement pullMessage and pullEnvelope to be context cancellation aware + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + d.log.Write(fmt.Appendf([]byte(""), "[%s] IN: %s %s %s\n", + time.Now().Format(time.DateTime), + r.Method, + r.RequestURI, + r.Proto)) + // parse response while logging it + tee := io.TeeReader(r.Body, d.log) + if env, err := messages.ParseEnvelopeXML(tee); err != nil { + if errors.Is(err, io.EOF) { + d.in <- nil + } else { + tui.LogError("Malformed request received: %s", err.Error()) + http.Error(w, fmt.Sprintf("Bad request: %s", err.Error()), http.StatusBadRequest) + return + } + } else { + d.log.Write([]byte("\n")) + if len(env.Body.Messages) > 1 { + tui.LogNormal("%d messages received; ignoring all but first", len(env.Body.Messages)) + } else if len(env.Body.Messages) == 0 { + tui.LogError("No message received") + http.Error(w, "Bad request; no message received", http.StatusBadRequest) + return + } + d.in <- env.Body.Messages[0] + d.SetSessionID(env.GetID()) + } + + d.log.Write(fmt.Appendf([]byte(""), "[%s] OUT:\n", time.Now().Format(time.DateTime))) + resp := <-d.out + d.writeHTTPResponse(w, http.StatusOK, resp) + d.log.Write([]byte("\n--------------------------------------------------------------------------------\n")) +} + +// write a reply to the response +func (d *CWMPDevice) writeHTTPResponse(w http.ResponseWriter, statusCode int, resp *messages.Envelope) { + w.WriteHeader(statusCode) + if resp != nil { + w.Header().Set("Content-Type", "text/xml; charset=utf-8") + tee := io.MultiWriter(w, d.log) + enc := xml.NewEncoder(tee) + enc.Indent("", "\t") + if err := enc.Encode(resp); err != nil { + panic(err) + } + } +} + +func (d *CWMPDevice) respondToRPC(r messages.Message) *messages.Envelope { + var env messages.Envelope + if rpc, ok := r.(messages.ACSMethod); ok { + resp := rpc.GenerateResponse() + env = d.newEnvelope(resp) + } else { + env = d.newEnvelope(messages.NewFault(8000, "Method not supported")) + } + return &env +} + +func (d *CWMPDevice) createRPCFromCmd(cmd *configuration.SequenceCmd) (messages.SyncRPC, error) { + rpcName := cmd.Cmd.String() + switch rpcName { + case messages.ChangeDUState{}.GetName(): + var m messages.ChangeDUState + return m, cmd.Decode(&m) + case messages.GetParameterNames{}.GetName(): + var m messages.GetParameterNames + return m, cmd.Decode(&m) + case messages.GetParameterValues{}.GetName(): + var m messages.GetParameterValues + return m, cmd.Decode(&m) + case messages.GetRPCMethods{}.GetName(): + var m messages.GetRPCMethods + return m, cmd.Decode(&m) + case messages.SetParameterValues{}.GetName(): + var m messages.SetParameterValues + return m, cmd.Decode(&m) + default: + return nil, fmt.Errorf("unknown RPC '%s'", rpcName) + } +} + +func (d *CWMPDevice) expectRPC(ctx context.Context, matcher func(messages.Message) bool) (messages.Message, error) { + // loop until message arrives or context expires + for { + rpc, err := d.pullMessage(ctx) + if err != nil { + return nil, err + } + if matcher(rpc) { + return rpc, nil + } else { + env := d.respondToRPC(rpc) + if err := d.pushEnvelope(ctx, env); err != nil { + return nil, err + } + } + } +} + +func (d *CWMPDevice) newEnvelope(msg ...messages.Message) messages.Envelope { + env := messages.Envelope{} + env.Header = &messages.EnvelopeHeader{ + ID: messages.IDStruct{MustUnderstand: "1", Value: d.currentID}, + } + env.Body = messages.EnvelopeBody{Messages: msg} + return env +} + +func (d *CWMPDevice) pullMessage(ctx context.Context) (messages.Message, error) { + select { + case rpc := <-d.in: + return rpc, nil + case <-ctx.Done(): + return nil, fmt.Errorf("timeout while waiting for incoming message") + } +} + +func (d *CWMPDevice) pushEnvelope(ctx context.Context, env *messages.Envelope) error { + select { + case d.out <- env: + return nil + case <-ctx.Done(): + return fmt.Errorf("timeout while sending message") + } +} diff --git a/internal/device/cwmpdevice_test.go b/internal/device/cwmp/cwmpdevice_test.go similarity index 99% rename from internal/device/cwmpdevice_test.go rename to internal/device/cwmp/cwmpdevice_test.go index 3d7f217..369c469 100644 --- a/internal/device/cwmpdevice_test.go +++ b/internal/device/cwmp/cwmpdevice_test.go @@ -1,3 +1,5 @@ +//go:build exclude + package device import ( diff --git a/internal/device/cwmp/messages/ChangeDUState.go b/internal/device/cwmp/messages/ChangeDUState.go new file mode 100644 index 0000000..1c6952c --- /dev/null +++ b/internal/device/cwmp/messages/ChangeDUState.go @@ -0,0 +1,171 @@ +package messages + +import ( + "corteca/internal/configuration" + "encoding/xml" + "fmt" + "strings" + + "gopkg.in/yaml.v3" +) + +type ChangeDUState struct { + XMLName xml.Name `xml:"ChangeDUState" yaml:"-"` + CommandKey configuration.TemplateField `yaml:"CommandKey"` + Operations DUOperationStruct `yaml:"Operations"` +} + +func (cdu ChangeDUState) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "ChangeDUState") + type Alias ChangeDUState + return enc.EncodeElement(Alias(cdu), start) +} + +type DUOperationStruct struct { + Op []DUOperation +} + +func (ops *DUOperationStruct) UnmarshalXML(dec *xml.Decoder, start xml.StartElement) error { + for { + token, err := dec.Token() + if err != nil { + return err + } + switch tok := token.(type) { + case xml.StartElement: + op, err := decodeOpXML(dec, &tok) + if err != nil { + return err + } + ops.Op = append(ops.Op, op) + case xml.EndElement: + if tok.Name.Local == start.Name.Local { + return nil + } + } + } +} + +func decodeOpXML(dec *xml.Decoder, start *xml.StartElement) (DUOperation, error) { + switch start.Name.Local { + case InstallOpStruct{}.GetOpType(): + op := InstallOpStruct{} + return op, dec.DecodeElement(&op, start) + case UpdateOpStruct{}.GetOpType(): + op := UpdateOpStruct{} + return op, dec.DecodeElement(&op, start) + case UninstallOpStruct{}.GetOpType(): + op := UninstallOpStruct{} + return op, dec.DecodeElement(&op, start) + default: + return nil, fmt.Errorf("unknown optype '%s'", start.Name.Local) + } +} + +type DUOperation interface { + GetOpType() string +} + +func (ops DUOperationStruct) MarshalYAML() (any, error) { + nodes := make([]yaml.Node, len(ops.Op)) + for i, op := range ops.Op { + if err := nodes[i].Encode(op); err != nil { + return nil, err + } + nodes[i].Tag = fmt.Sprintf("!%s", op.GetOpType()) + } + return nodes, nil +} + +func (ops *DUOperationStruct) UnmarshalYAML(value *yaml.Node) error { + proxy := []yaml.Node{} + if err := value.Decode(&proxy); err != nil { + return err + } + for _, node := range proxy { + op, err := decodeOpYAML(&node) + if err != nil { + return err + } + ops.Op = append(ops.Op, op) + } + return nil +} + +func decodeOpYAML(node *yaml.Node) (DUOperation, error) { + typeTag, _ := strings.CutPrefix(node.Tag, "!") + switch typeTag { + case InstallOpStruct{}.GetOpType(): + op := InstallOpStruct{} + return op, node.Decode(&op) + case UpdateOpStruct{}.GetOpType(): + op := UpdateOpStruct{} + return op, node.Decode(&op) + case UninstallOpStruct{}.GetOpType(): + op := UninstallOpStruct{} + return op, node.Decode(&op) + default: + return nil, fmt.Errorf("unknown optype '%s'", typeTag) + } +} + +type InstallOpStruct struct { + XMLName xml.Name `xml:"InstallOpStruct" yaml:"-"` + URL configuration.TemplateField `xml:"URL" yaml:"URL"` + UUID configuration.TemplateField `xml:"UUID" yaml:"UUID"` + Username configuration.TemplateField `xml:"Username,omitempty" yaml:"Username,omitempty"` + Password configuration.TemplateField `xml:"Password,omitempty" yaml:"Password,omitempty"` + ExecutionEnvRef configuration.TemplateField `xml:"ExecutionEnvRef" yaml:"ExecutionEnvRef"` +} + +func (op InstallOpStruct) GetOpType() string { return "InstallOpStruct" } + +type UpdateOpStruct struct { + XMLName xml.Name `xml:"UpdateOpStruct" yaml:"-"` + UUID configuration.TemplateField `xml:"UUID" yaml:"UUID"` + Version configuration.TemplateField `xml:"Version" yaml:"Version"` + URL configuration.TemplateField `xml:"URL" yaml:"URL"` + Username configuration.TemplateField `xml:"Username,omitempty" yaml:"Username,omitempty"` + Password configuration.TemplateField `xml:"Password,omitempty" yaml:"Password,omitempty"` +} + +func (op UpdateOpStruct) GetOpType() string { return "UpdateOpStruct" } + +type UninstallOpStruct struct { + XMLName xml.Name `xml:"UninstallOpStruct" yaml:"-"` + UUID configuration.TemplateField `xml:"UUID" yaml:"UUID"` + Version configuration.TemplateField `xml:"Version" yaml:"Version"` + ExecutionEnvRef configuration.TemplateField `xml:"ExecutionEnvRef" yaml:"ExecutionEnvRef"` +} + +func (op UninstallOpStruct) GetOpType() string { return "UninstallOpStruct" } + +func (m ChangeDUState) GetName() string { return "ChangeDUState" } +func (m ChangeDUState) ValidateResponse(msg Message) error { + if resp, ok := msg.(DUStateChangeComplete); ok { + for i := 0; i < len(resp.Results); i++ { + if resp.Results[i].Fault.FaultCode != 0 { + return fmt.Errorf("error(s) occured in DU operation(s)") + } + } + return nil + } else { + return ExpectMessage[ChangeDUStateResponse](msg) + } +} +func (m ChangeDUState) Match(msg Message) bool { + if r, ok := msg.(DUStateChangeComplete); ok { + return r.CommandKey == m.CommandKey.String() + } else { + return false + } +} +func (m ChangeDUState) GenerateResponse() Message { + return ChangeDUStateResponse{} +} + +type ChangeDUStateResponse struct { + XMLName xml.Name `xml:"ChangeDUStateResponse"` +} + +func (m ChangeDUStateResponse) GetName() string { return "ChangeDUStateResponse" } diff --git a/internal/device/cwmp/messages/ChangeDUState_test.go b/internal/device/cwmp/messages/ChangeDUState_test.go new file mode 100644 index 0000000..482cd10 --- /dev/null +++ b/internal/device/cwmp/messages/ChangeDUState_test.go @@ -0,0 +1,184 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/configuration" + c "corteca/internal/configuration" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + ChangeDUStateInputXML = ` + foo + + + http://example.com/some/image:1.0.0 + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + some.user@example.com + somepassword + generic + + + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + 1.0.0 + http://example.com/some/image:1.0.0 + some.user@example.com + somepassword + + + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + 1.0.0 + generic + + +` + + ChangeDUStateInputYaml = `CommandKey: foo +Operations: + - !InstallOpStruct + URL: http://example.com/some/${app.name}:${app.version} + UUID: ${app.duid} + Username: some.user@example.com + Password: somepassword + ExecutionEnvRef: generic + - !UpdateOpStruct + UUID: c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + Version: 1.0.0 + URL: http://example.com/some/image:1.0.0 + Username: some.user@example.com + Password: somepassword + - !UninstallOpStruct + UUID: c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + Version: 1.0.0 + ExecutionEnvRef: generic +` +) + +var ChangeDUStateInputMsg = messages.ChangeDUState{ + CommandKey: c.TemplateField{RawTemplate: "foo"}, + Operations: messages.DUOperationStruct{ + Op: []messages.DUOperation{ + messages.InstallOpStruct{ + URL: c.TemplateField{RawTemplate: "http://example.com/some/${app.name}:${app.version}"}, + UUID: c.TemplateField{RawTemplate: "${app.duid}"}, + Username: c.TemplateField{RawTemplate: "some.user@example.com"}, + Password: c.TemplateField{RawTemplate: "somepassword"}, + ExecutionEnvRef: c.TemplateField{RawTemplate: "generic"}, + }, + messages.UpdateOpStruct{ + URL: c.TemplateField{RawTemplate: "http://example.com/some/image:1.0.0"}, + UUID: c.TemplateField{RawTemplate: "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d"}, + Username: c.TemplateField{RawTemplate: "some.user@example.com"}, + Password: c.TemplateField{RawTemplate: "somepassword"}, + Version: c.TemplateField{RawTemplate: "1.0.0"}, + }, + messages.UninstallOpStruct{ + UUID: c.TemplateField{RawTemplate: "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d"}, + ExecutionEnvRef: c.TemplateField{RawTemplate: "generic"}, + Version: c.TemplateField{RawTemplate: "1.0.0"}, + }, + }, + }, +} + +func setupContext() { + configuration.ResetContext() + configuration.GetCmdContext().App.Name = "image" + configuration.GetCmdContext().App.Version = "1.0.0" + configuration.GetCmdContext().App.DUID = "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d" +} + +func TestChangeDUStateParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(ChangeDUStateInputXML) + dec := xml.NewDecoder(buf) + msg := messages.ChangeDUState{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, "foo", msg.CommandKey.String()) + + assert.NotPanics(t, func() { _ = msg.Operations.Op[0].(messages.InstallOpStruct) }) + assert.Equal(t, "http://example.com/some/image:1.0.0", msg.Operations.Op[0].(messages.InstallOpStruct).URL.String()) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Operations.Op[0].(messages.InstallOpStruct).UUID.String()) + assert.Equal(t, "some.user@example.com", msg.Operations.Op[0].(messages.InstallOpStruct).Username.String()) + assert.Equal(t, "somepassword", msg.Operations.Op[0].(messages.InstallOpStruct).Password.String()) + assert.Equal(t, "generic", msg.Operations.Op[0].(messages.InstallOpStruct).ExecutionEnvRef.String()) + + assert.NotPanics(t, func() { _ = msg.Operations.Op[1].(messages.UpdateOpStruct) }) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Operations.Op[1].(messages.UpdateOpStruct).UUID.String()) + assert.Equal(t, "1.0.0", msg.Operations.Op[1].(messages.UpdateOpStruct).Version.String()) + assert.Equal(t, "http://example.com/some/image:1.0.0", msg.Operations.Op[1].(messages.UpdateOpStruct).URL.String()) + assert.Equal(t, "some.user@example.com", msg.Operations.Op[1].(messages.UpdateOpStruct).Username.String()) + assert.Equal(t, "somepassword", msg.Operations.Op[1].(messages.UpdateOpStruct).Password.String()) + + assert.NotPanics(t, func() { _ = msg.Operations.Op[2].(messages.UninstallOpStruct) }) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Operations.Op[2].(messages.UninstallOpStruct).UUID.String()) + assert.Equal(t, "1.0.0", msg.Operations.Op[2].(messages.UninstallOpStruct).Version.String()) + assert.Equal(t, "generic", msg.Operations.Op[2].(messages.UninstallOpStruct).ExecutionEnvRef.String()) + +} + +func TestChangeDUStateSerializeToXML(t *testing.T) { + setupContext() + msg := ChangeDUStateInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating yaml output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, ChangeDUStateInputXML, buf.String()) +} + +func TestChangeDUStateUmarshalFromYaml(t *testing.T) { + setupContext() + buf := bytes.NewBufferString(ChangeDUStateInputYaml) + dec := yaml.NewDecoder(buf) + msg := messages.ChangeDUState{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing yaml input: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, "foo", msg.CommandKey.String()) + + assert.NotPanics(t, func() { _ = msg.Operations.Op[0].(messages.InstallOpStruct) }) + assert.Equal(t, "InstallOpStruct", msg.Operations.Op[0].GetOpType()) + assert.Equal(t, "http://example.com/some/image:1.0.0", msg.Operations.Op[0].(messages.InstallOpStruct).URL.String()) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Operations.Op[0].(messages.InstallOpStruct).UUID.String()) + assert.Equal(t, "some.user@example.com", msg.Operations.Op[0].(messages.InstallOpStruct).Username.String()) + assert.Equal(t, "somepassword", msg.Operations.Op[0].(messages.InstallOpStruct).Password.String()) + assert.Equal(t, "generic", msg.Operations.Op[0].(messages.InstallOpStruct).ExecutionEnvRef.String()) + + assert.NotPanics(t, func() { _ = msg.Operations.Op[1].(messages.UpdateOpStruct) }) + assert.Equal(t, "UpdateOpStruct", msg.Operations.Op[1].GetOpType()) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Operations.Op[1].(messages.UpdateOpStruct).UUID.String()) + assert.Equal(t, "1.0.0", msg.Operations.Op[1].(messages.UpdateOpStruct).Version.String()) + assert.Equal(t, "http://example.com/some/image:1.0.0", msg.Operations.Op[1].(messages.UpdateOpStruct).URL.String()) + assert.Equal(t, "some.user@example.com", msg.Operations.Op[1].(messages.UpdateOpStruct).Username.String()) + assert.Equal(t, "somepassword", msg.Operations.Op[1].(messages.UpdateOpStruct).Password.String()) + + assert.NotPanics(t, func() { _ = msg.Operations.Op[2].(messages.UninstallOpStruct) }) + assert.Equal(t, "UninstallOpStruct", msg.Operations.Op[2].GetOpType()) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Operations.Op[2].(messages.UninstallOpStruct).UUID.String()) + assert.Equal(t, "1.0.0", msg.Operations.Op[2].(messages.UninstallOpStruct).Version.String()) + assert.Equal(t, "generic", msg.Operations.Op[2].(messages.UninstallOpStruct).ExecutionEnvRef.String()) +} + +func TestChangeDUStateMarshalToYaml(t *testing.T) { + msg := ChangeDUStateInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating yaml output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, ChangeDUStateInputYaml, outbuf.String()) +} diff --git a/internal/device/cwmp/messages/DUStateChangeComplete.go b/internal/device/cwmp/messages/DUStateChangeComplete.go new file mode 100644 index 0000000..def301c --- /dev/null +++ b/internal/device/cwmp/messages/DUStateChangeComplete.go @@ -0,0 +1,48 @@ +package messages + +import ( + "encoding/xml" +) + +type DUStateChangeComplete struct { + XMLName xml.Name `xml:"DUStateChangeComplete" yaml:"-"` + CommandKey string `yaml:"CommandKey"` + Results []OpResultStruct `xml:"Results>OpResultStruct" yaml:"Results"` +} + +func (msg DUStateChangeComplete) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "DUStateChangeComplete") + type Alias DUStateChangeComplete + return enc.EncodeElement(Alias(msg), start) +} + +type OpResultStruct struct { + XMLName xml.Name `xml:"OpResultStruct" yaml:"-"` + UUID string `yaml:"UUID"` + DeploymentUnitRef string `yaml:"DeploymentUnitRef"` + Version string `yaml:"Version"` + CurrentState string `yaml:"CurrentState"` + Resolved bool `yaml:"Resolved"` + ExecutionUnitRefList string `yaml:"ExecutionUnitRefList"` + StartTime string `yaml:"StartTime"` + CompleteTime string `yaml:"CompleteTime"` + Fault FaultStruct `yaml:"Fault"` +} + +func (m DUStateChangeComplete) GetName() string { return "DUStateChangeComplete" } +func (m DUStateChangeComplete) ValidateResponse(msg Message) error { + return ExpectMessage[DUStateChangeCompleteResponse](msg) +} +func (m DUStateChangeComplete) GenerateResponse() Message { return DUStateChangeCompleteResponse{} } + +type DUStateChangeCompleteResponse struct { + XMLName xml.Name `xml:"DUStateChangeCompleteResponse"` +} + +func (msg DUStateChangeCompleteResponse) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "DUStateChangeCompleteResponse") + type Alias DUStateChangeCompleteResponse + return enc.EncodeElement(Alias(msg), start) +} + +func (m DUStateChangeCompleteResponse) GetName() string { return "DUStateChangeCompleteResponse" } diff --git a/internal/device/cwmp/messages/DUStateChangeComplete_test.go b/internal/device/cwmp/messages/DUStateChangeComplete_test.go new file mode 100644 index 0000000..4365c6d --- /dev/null +++ b/internal/device/cwmp/messages/DUStateChangeComplete_test.go @@ -0,0 +1,195 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + DUStateChangeCompleteInputXML = ` + Foobar + + + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + Device.SoftwareModules.DeploymentUnit.1 + 1.0.0 + Installed + true + exec-1 + 2026-04-08T10:00:00Z + 2026-04-08T10:01:30Z + + 0 + + + + + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + Device.SoftwareModules.DeploymentUnit.2 + 1.0.0 + Failed + true + exec-1 + 2026-04-08T10:00:00Z + 2026-04-08T10:01:30Z + + 9027 + System Resources Exceeded + + + +` + + DUStateChangeCompleteInputYAML = `CommandKey: Foobar +Results: + - UUID: c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + DeploymentUnitRef: Device.SoftwareModules.DeploymentUnit.1 + Version: 1.0.0 + CurrentState: Installed + Resolved: true + ExecutionUnitRefList: exec-1 + StartTime: "2026-04-08T10:00:00Z" + CompleteTime: "2026-04-08T10:01:30Z" + Fault: + FaultCode: 0 + FaultString: "" + - UUID: c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + DeploymentUnitRef: Device.SoftwareModules.DeploymentUnit.2 + Version: 1.0.0 + CurrentState: Failed + Resolved: true + ExecutionUnitRefList: exec-1 + StartTime: "2026-04-08T10:00:00Z" + CompleteTime: "2026-04-08T10:01:30Z" + Fault: + FaultCode: 9027 + FaultString: System Resources Exceeded +` +) + +var DUStateChangeCompleteInputMsg = messages.DUStateChangeComplete{ + CommandKey: "Foobar", + Results: []messages.OpResultStruct{ + { + UUID: "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", + DeploymentUnitRef: "Device.SoftwareModules.DeploymentUnit.1", + Version: "1.0.0", + CurrentState: "Installed", + Resolved: true, + ExecutionUnitRefList: "exec-1", + StartTime: "2026-04-08T10:00:00Z", + CompleteTime: "2026-04-08T10:01:30Z", + Fault: messages.FaultStruct{FaultCode: 0, FaultString: ""}, + }, + { + UUID: "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", + DeploymentUnitRef: "Device.SoftwareModules.DeploymentUnit.2", + Version: "1.0.0", + CurrentState: "Failed", + Resolved: true, + ExecutionUnitRefList: "exec-1", + StartTime: "2026-04-08T10:00:00Z", + CompleteTime: "2026-04-08T10:01:30Z", + Fault: messages.FaultStruct{FaultCode: 9027, FaultString: "System Resources Exceeded"}, + }, + }, +} + +func TestDUStateChangeCompleteParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(DUStateChangeCompleteInputXML) + dec := xml.NewDecoder(buf) + msg := messages.DUStateChangeComplete{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Foobar", msg.CommandKey) + assert.Equal(t, 2, len(msg.Results)) + + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Results[0].UUID) + assert.Equal(t, "Device.SoftwareModules.DeploymentUnit.1", msg.Results[0].DeploymentUnitRef) + assert.Equal(t, "1.0.0", msg.Results[0].Version) + assert.Equal(t, "Installed", msg.Results[0].CurrentState) + assert.Equal(t, true, msg.Results[0].Resolved) + assert.Equal(t, "exec-1", msg.Results[0].ExecutionUnitRefList) + assert.Equal(t, "2026-04-08T10:00:00Z", msg.Results[0].StartTime) + assert.Equal(t, "2026-04-08T10:01:30Z", msg.Results[0].CompleteTime) + assert.Equal(t, uint(0), msg.Results[0].Fault.FaultCode) + assert.Equal(t, "", msg.Results[0].Fault.FaultString) + + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Results[1].UUID) + assert.Equal(t, "Device.SoftwareModules.DeploymentUnit.2", msg.Results[1].DeploymentUnitRef) + assert.Equal(t, "1.0.0", msg.Results[1].Version) + assert.Equal(t, "Failed", msg.Results[1].CurrentState) + assert.Equal(t, true, msg.Results[1].Resolved) + assert.Equal(t, "exec-1", msg.Results[1].ExecutionUnitRefList) + assert.Equal(t, "2026-04-08T10:00:00Z", msg.Results[1].StartTime) + assert.Equal(t, "2026-04-08T10:01:30Z", msg.Results[1].CompleteTime) + assert.Equal(t, uint(9027), msg.Results[1].Fault.FaultCode) + assert.Equal(t, "System Resources Exceeded", msg.Results[1].Fault.FaultString) +} + +func TestDUStateChangeCompleteSerializeToXML(t *testing.T) { + msg := DUStateChangeCompleteInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating xml output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, DUStateChangeCompleteInputXML, buf.String()) +} + +func TestDUStateChangeCompleteParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(DUStateChangeCompleteInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.DUStateChangeComplete{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Foobar", msg.CommandKey) + assert.Equal(t, 2, len(msg.Results)) + + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Results[0].UUID) + assert.Equal(t, "Device.SoftwareModules.DeploymentUnit.1", msg.Results[0].DeploymentUnitRef) + assert.Equal(t, "1.0.0", msg.Results[0].Version) + assert.Equal(t, "Installed", msg.Results[0].CurrentState) + assert.Equal(t, true, msg.Results[0].Resolved) + assert.Equal(t, "exec-1", msg.Results[0].ExecutionUnitRefList) + assert.Equal(t, "2026-04-08T10:00:00Z", msg.Results[0].StartTime) + assert.Equal(t, "2026-04-08T10:01:30Z", msg.Results[0].CompleteTime) + assert.Equal(t, uint(0), msg.Results[0].Fault.FaultCode) + assert.Equal(t, "", msg.Results[0].Fault.FaultString) + + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Results[1].UUID) + assert.Equal(t, "Device.SoftwareModules.DeploymentUnit.2", msg.Results[1].DeploymentUnitRef) + assert.Equal(t, "1.0.0", msg.Results[1].Version) + assert.Equal(t, "Failed", msg.Results[1].CurrentState) + assert.Equal(t, true, msg.Results[1].Resolved) + assert.Equal(t, "exec-1", msg.Results[1].ExecutionUnitRefList) + assert.Equal(t, "2026-04-08T10:00:00Z", msg.Results[1].StartTime) + assert.Equal(t, "2026-04-08T10:01:30Z", msg.Results[1].CompleteTime) + assert.Equal(t, uint(9027), msg.Results[1].Fault.FaultCode) + assert.Equal(t, "System Resources Exceeded", msg.Results[1].Fault.FaultString) +} + +func TestDUStateChangeCompleteSerializeToYAML(t *testing.T) { + msg := DUStateChangeCompleteInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, DUStateChangeCompleteInputYAML, outbuf.String()) +} diff --git a/internal/device/cwmp/messages/Envelope.go b/internal/device/cwmp/messages/Envelope.go new file mode 100644 index 0000000..05bb40e --- /dev/null +++ b/internal/device/cwmp/messages/Envelope.go @@ -0,0 +1,189 @@ +package messages + +import ( + "encoding/xml" + "fmt" + "io" +) + +func ParseEnvelopeXML(input io.Reader) (*Envelope, error) { + env := Envelope{} + enc := xml.NewDecoder(input) + return &env, enc.Decode(&env) +} + +type Envelope struct { + XMLName xml.Name `xml:"Envelope"` + Header *EnvelopeHeader `xml:",omitempty"` + Body EnvelopeBody +} + +func (e Envelope) GetID() string { + if e.Header != nil { + return e.Header.ID.Value + } else { + return "" + } +} + +func (e Envelope) GetBody() []Message { + return e.Body.Messages +} + +func NewEnvelope(id string, msg ...Message) Envelope { + env := Envelope{ + Body: EnvelopeBody{ + Messages: msg, + }, + } + if len(id) > 0 { + env.Header = &EnvelopeHeader{ + ID: IDStruct{Value: id, MustUnderstand: "1"}, + } + } + return env +} + +func (e Envelope) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixSoapEnv(&start.Name) + start.Attr = append(start.Attr, + XmlAttr("xmlns:soap-env", "http://schemas.xmlsoap.org/soap/envelope/"), + XmlAttr("xmlns:soap-enc", "http://schemas.xmlsoap.org/soap/encoding/"), + XmlAttr("xmlns:xsd", "http://www.w3.org/2001/XMLSchema"), + XmlAttr("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance"), + XmlAttr("xmlns:cwmp", "urn:dslforum-org:cwmp-1-0"), + ) + type Alias Envelope + return enc.EncodeElement(Alias(e), start) +} + +type EnvelopeHeader struct { + ID IDStruct `xml:"ID"` +} + +func (eh EnvelopeHeader) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixSoapEnv(&start.Name) + type Alias EnvelopeHeader + return enc.EncodeElement(Alias(eh), start) +} + +type IDStruct struct { + MustUnderstand string `xml:"soap-env:mustUnderstand,attr"` + Value string `xml:",chardata"` +} + +func (id IDStruct) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name) + type Alias IDStruct + return enc.EncodeElement(Alias(id), start) +} + +type EnvelopeBody struct { + Messages []Message +} + +func (eb EnvelopeBody) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixSoapEnv(&start.Name) + type Alias EnvelopeBody + return enc.EncodeElement(Alias(eb), start) +} + +func (eb *EnvelopeBody) UnmarshalXML(dec *xml.Decoder, start xml.StartElement) error { + for { + token, err := dec.Token() + if err != nil { + return err + } + switch tok := token.(type) { + case xml.StartElement: + var msg Message + switch tok.Name.Local { + case Inform{}.GetName(): + var m Inform + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case InformResponse{}.GetName(): + var m InformResponse + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case ChangeDUState{}.GetName(): + var m ChangeDUState + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case ChangeDUStateResponse{}.GetName(): + var m ChangeDUStateResponse + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case DUStateChangeComplete{}.GetName(): + var m DUStateChangeComplete + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case DUStateChangeCompleteResponse{}.GetName(): + var m DUStateChangeCompleteResponse + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case Fault{}.GetName(): + var m Fault + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case GetParameterNames{}.GetName(): + var m GetParameterNames + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case GetParameterNamesResponse{}.GetName(): + var m GetParameterNamesResponse + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case GetParameterValues{}.GetName(): + var m GetParameterValues + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case GetParameterValuesResponse{}.GetName(): + var m GetParameterValuesResponse + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case SetParameterValues{}.GetName(): + var m SetParameterValues + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + case SetParameterValuesResponse{}.GetName(): + var m SetParameterValuesResponse + if err := dec.DecodeElement(&m, &tok); err != nil { + return err + } + msg = m + default: + return fmt.Errorf("unknown RPC '%s'", tok.Name.Local) + } + eb.Messages = append(eb.Messages, msg) + case xml.EndElement: + if tok.Name.Local == start.Name.Local { + return nil + } + } + } +} diff --git a/internal/device/cwmp/messages/EnvelopeRPCParsing_test.go b/internal/device/cwmp/messages/EnvelopeRPCParsing_test.go new file mode 100644 index 0000000..6ee4d44 --- /dev/null +++ b/internal/device/cwmp/messages/EnvelopeRPCParsing_test.go @@ -0,0 +1,321 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/device/cwmp/messages" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + InformResponseEnvelopeXML = ` + + + 1 + + +` + + ChangeDUStateEnvelopeXML = ` + + + testkey + + + http://example.com/app.ipk + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + generic + + + + +` + + ChangeDUStateResponseEnvelopeXML = ` + + + +` + + DUStateChangeCompleteEnvelopeXML = ` + + + testkey + + + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + Device.SoftwareModules.DeploymentUnit.1 + 1.0.0 + Installed + true + exec-1 + 2026-04-08T10:00:00Z + 2026-04-08T10:01:30Z + + 0 + + + + + + +` + + DUStateChangeCompleteResponseEnvelopeXML = ` + + + +` + + FaultEnvelopeXML = ` + + + Server + CWMP Fault + + + 9003 + Invalid arguments + + + + +` + + GetParameterNamesEnvelopeXML = ` + + + Device.DeviceInfo. + true + + +` + + GetParameterNamesResponseEnvelopeXML = ` + + + + + Device.DeviceInfo.SoftwareVersion + false + + + + +` + + GetParameterValuesEnvelopeXML = ` + + + + Device.DeviceInfo.SoftwareVersion + + + +` + + GetParameterValuesResponseEnvelopeXML = ` + + + + + Device.DeviceInfo.SoftwareVersion + 1.0.0 + + + + +` + + SetParameterValuesEnvelopeXML = ` + + + + + Device.DeviceInfo.SoftwareVersion + 2.0.0 + + + mykey + + +` + + SetParameterValuesResponseEnvelopeXML = ` + + + 0 + + +` + + MultipleMessagesEnvelopeXML = ` + + + 1 + + + testkey + + + http://example.com/app.ipk + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + generic + + + + +` +) + +func parseEnvelope(t *testing.T, xmlInput string) messages.Envelope { + t.Helper() + buf := bytes.NewBufferString(xmlInput) + if env, err := messages.ParseEnvelopeXML(buf); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + return messages.Envelope{} + } else { + return *env + } +} + +func TestEnvelopeParseInformResponse(t *testing.T) { + env := parseEnvelope(t, InformResponseEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.InformResponse) + require.True(t, ok) + assert.Equal(t, uint(1), msg.MaxEnvelopes) +} + +func TestEnvelopeParseChangeDUState(t *testing.T) { + env := parseEnvelope(t, ChangeDUStateEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.ChangeDUState) + require.True(t, ok) + assert.Equal(t, "testkey", msg.CommandKey.String()) + assert.Equal(t, 1, len(msg.Operations.Op)) + install, ok := msg.Operations.Op[0].(messages.InstallOpStruct) + require.True(t, ok) + assert.Equal(t, "http://example.com/app.ipk", install.URL.String()) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", install.UUID.String()) + assert.Equal(t, "generic", install.ExecutionEnvRef.String()) +} + +func TestEnvelopeParseChangeDUStateResponse(t *testing.T) { + env := parseEnvelope(t, ChangeDUStateResponseEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + _, ok := env.GetBody()[0].(messages.ChangeDUStateResponse) + require.True(t, ok) +} + +func TestEnvelopeParseDUStateChangeComplete(t *testing.T) { + env := parseEnvelope(t, DUStateChangeCompleteEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.DUStateChangeComplete) + require.True(t, ok) + assert.Equal(t, "testkey", msg.CommandKey) + assert.Equal(t, 1, len(msg.Results)) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", msg.Results[0].UUID) + assert.Equal(t, "Device.SoftwareModules.DeploymentUnit.1", msg.Results[0].DeploymentUnitRef) + assert.Equal(t, "Installed", msg.Results[0].CurrentState) + assert.Equal(t, uint(0), msg.Results[0].Fault.FaultCode) +} + +func TestEnvelopeParseDUStateChangeCompleteResponse(t *testing.T) { + env := parseEnvelope(t, DUStateChangeCompleteResponseEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + _, ok := env.GetBody()[0].(messages.DUStateChangeCompleteResponse) + require.True(t, ok) +} + +func TestEnvelopeParseFault(t *testing.T) { + env := parseEnvelope(t, FaultEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.Fault) + require.True(t, ok) + assert.Equal(t, "Server", msg.FaultCode) + assert.Equal(t, "CWMP Fault", msg.FaultString) + assert.Equal(t, uint(9003), msg.Detail.FaultCode) + assert.Equal(t, "Invalid arguments", msg.Detail.FaultString) +} + +func TestEnvelopeParseGetParameterNames(t *testing.T) { + env := parseEnvelope(t, GetParameterNamesEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.GetParameterNames) + require.True(t, ok) + assert.Equal(t, "Device.DeviceInfo.", msg.ParameterPath.String()) + assert.Equal(t, true, msg.NextLevel) +} + +func TestEnvelopeParseGetParameterNamesResponse(t *testing.T) { + env := parseEnvelope(t, GetParameterNamesResponseEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.GetParameterNamesResponse) + require.True(t, ok) + assert.Equal(t, 1, len(msg.ParameterList)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList[0].Name) + assert.Equal(t, false, msg.ParameterList[0].Writable) +} + +func TestEnvelopeParseGetParameterValues(t *testing.T) { + env := parseEnvelope(t, GetParameterValuesEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.GetParameterValues) + require.True(t, ok) + assert.Equal(t, 1, len(msg.ParameterNames.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterNames.Params[0].String()) +} + +func TestEnvelopeParseGetParameterValuesResponse(t *testing.T) { + env := parseEnvelope(t, GetParameterValuesResponseEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.GetParameterValuesResponse) + require.True(t, ok) + assert.Equal(t, 1, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList.Params[0].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[0].Content.Type) + assert.Equal(t, "1.0.0", msg.ParameterList.Params[0].Content.Value.String()) +} + +func TestEnvelopeParseSetParameterValues(t *testing.T) { + env := parseEnvelope(t, SetParameterValuesEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.SetParameterValues) + require.True(t, ok) + assert.Equal(t, 1, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList.Params[0].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[0].Content.Type) + assert.Equal(t, "2.0.0", msg.ParameterList.Params[0].Content.Value.String()) + assert.Equal(t, "mykey", msg.ParameterKey) +} + +func TestEnvelopeParseSetParameterValuesResponse(t *testing.T) { + env := parseEnvelope(t, SetParameterValuesResponseEnvelopeXML) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.SetParameterValuesResponse) + require.True(t, ok) + assert.Equal(t, uint(0), msg.Status) +} + +func TestEnvelopeParseMultipleMessages(t *testing.T) { + env := parseEnvelope(t, MultipleMessagesEnvelopeXML) + require.Equal(t, 2, len(env.GetBody())) + + inform, ok := env.GetBody()[0].(messages.InformResponse) + require.True(t, ok, "expected GetBody()[0] to be messages.InformResponse") + assert.Equal(t, uint(1), inform.MaxEnvelopes) + + cds, ok := env.GetBody()[1].(messages.ChangeDUState) + require.True(t, ok, "expected GetBody()[1] to be *messages.ChangeDUState") + assert.Equal(t, "testkey", cds.CommandKey.String()) + require.Equal(t, 1, len(cds.Operations.Op)) + install, ok := cds.Operations.Op[0].(messages.InstallOpStruct) + require.True(t, ok, "expected Operations.Op[0] to be messages.InstallOpStruct") + assert.Equal(t, "http://example.com/app.ipk", install.URL.String()) + assert.Equal(t, "c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d", install.UUID.String()) + assert.Equal(t, "generic", install.ExecutionEnvRef.String()) +} diff --git a/internal/device/cwmp/messages/EnvelopeRPCSerializing_test.go b/internal/device/cwmp/messages/EnvelopeRPCSerializing_test.go new file mode 100644 index 0000000..2b3e246 --- /dev/null +++ b/internal/device/cwmp/messages/EnvelopeRPCSerializing_test.go @@ -0,0 +1,174 @@ +// Copyright 2024 Nokia +// Licensed under the BSD 3-Clause License. +// SPDX-License-Identifier: BSD-3-Clause + +package messages_test + +import ( + "bytes" + "corteca/internal/configuration" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" +) + +const envelopeHeader = ` + + test-id + + ` + +const envelopeFooter = ` + +` + +const ( + ChangeDUStateEnvelopeOutputXML = envelopeHeader + ` + + testkey + + + http://example.com/app.ipk + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + user + pass + generic + + + ` + envelopeFooter + + FaultEnvelopeOutputXML = envelopeHeader + ` + + Server + CWMP Fault + + + 9027 + System Resources Exceeded + + + ` + envelopeFooter + + GetParameterNamesEnvelopeOutputXML = envelopeHeader + ` + + Device.SoftwareModules. + true + ` + envelopeFooter + + GetParameterValuesEnvelopeOutputXML = envelopeHeader + ` + + + Device.DeviceInfo.SoftwareVersion + Device.DeviceInfo.HardwareVersion + + ` + envelopeFooter + + SetParameterValuesEnvelopeOutputXML = envelopeHeader + ` + + + + Device.DeviceInfo.SoftwareVersion + 2.0.0 + + + Device.DeviceInfo.HardwareVersion + RevB + + + mykey + ` + envelopeFooter + + GetRPCMethodsEnvelopeOutputXML = envelopeHeader + ` + ` + envelopeFooter + + InformResponseAndChangeDUStateEnvelopeOutputXML = envelopeHeader + ` + + 1 + + + testkey + + + http://example.com/app.ipk + c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d + user + pass + generic + + + ` + envelopeFooter +) + +func serializeEnvelope(t *testing.T, env messages.Envelope) string { + t.Helper() + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(env); err != nil { + t.Logf("Failed serializing envelope: %s", err.Error()) + t.FailNow() + } + return buf.String() +} + +func testChangeDUState() messages.ChangeDUState { + return messages.ChangeDUState{ + CommandKey: configuration.T("testkey"), + Operations: messages.DUOperationStruct{ + Op: []messages.DUOperation{ + messages.InstallOpStruct{ + URL: configuration.T("http://example.com/app.ipk"), + UUID: configuration.T("c0c4328b-18a4-4b3b-b1da-e8ea8d8f457d"), + Username: configuration.T("user"), + Password: configuration.T("pass"), + ExecutionEnvRef: configuration.T("generic"), + }, + }, + }, + } +} + +func TestEnvelopeSerializeChangeDUState(t *testing.T) { + msg := testChangeDUState() + env := messages.NewEnvelope("test-id", &msg) + assert.Equal(t, ChangeDUStateEnvelopeOutputXML, serializeEnvelope(t, env)) +} + +func TestEnvelopeSerializeFault(t *testing.T) { + msg := messages.NewFault(9027, "System Resources Exceeded") + env := messages.NewEnvelope("test-id", &msg) + assert.Equal(t, FaultEnvelopeOutputXML, serializeEnvelope(t, env)) +} + +func TestEnvelopeSerializeGetParameterNames(t *testing.T) { + msg := GetParameterNamesInputMsg + env := messages.NewEnvelope("test-id", &msg) + assert.Equal(t, GetParameterNamesEnvelopeOutputXML, serializeEnvelope(t, env)) +} + +func TestEnvelopeSerializeGetParameterValues(t *testing.T) { + msg := GetParameterValuesInputMsg + env := messages.NewEnvelope("test-id", &msg) + assert.Equal(t, GetParameterValuesEnvelopeOutputXML, serializeEnvelope(t, env)) +} + +func TestEnvelopeSerializeSetParameterValues(t *testing.T) { + msg := SetParameterValuesInputMsg + env := messages.NewEnvelope("test-id", &msg) + assert.Equal(t, SetParameterValuesEnvelopeOutputXML, serializeEnvelope(t, env)) +} + +func TestEnvelopeSerializeGetRPCMethods(t *testing.T) { + msg := GetRPCMethodsInputMsg + env := messages.NewEnvelope("test-id", msg) + assert.Equal(t, GetRPCMethodsEnvelopeOutputXML, serializeEnvelope(t, env)) +} + +func TestEnvelopeSerializeInformResponseAndChangeDUState(t *testing.T) { + inform := messages.InformResponse{MaxEnvelopes: 1} + cds := testChangeDUState() + env := messages.NewEnvelope("test-id", &inform, &cds) + assert.Equal(t, InformResponseAndChangeDUStateEnvelopeOutputXML, serializeEnvelope(t, env)) +} diff --git a/internal/device/cwmp/messages/Envelope_test.go b/internal/device/cwmp/messages/Envelope_test.go new file mode 100644 index 0000000..cc45133 --- /dev/null +++ b/internal/device/cwmp/messages/Envelope_test.go @@ -0,0 +1,188 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + BlankEnvelopeInputXML = ` + + + +` + + EnvelopeInputXML = ` + + + + 123456789 + + + + + + + ExampleCorp + ABCDEF + RouterModelX + 1234567890 + + + + + 0 BOOTSTRAP + + + + 1 BOOT + + + + + 1 + + 2026-03-29T12:00:00Z + + 3 + + + + Device.DeviceInfo.SoftwareVersion + 1.0.0 + + + Device.DeviceInfo.HardwareVersion + RevA + + + Device.ManagementServer.ConnectionRequestURL + http://192.168.1.1:7547/ + + + Device.WANDevice.1.WANConnectionDevice.1.WANIPConnection.1.ExternalIPAddress + 203.0.113.45 + + + + +` + + InvalidEnvelopeInputXML = ` + + + +` + + EnvelopeOutputXML = ` + + testEnvelope + + + + + + + + + + + 0 + + 0 + + + +` + + BlankEnvelopeOutputXML = ` + +` +) + +func TestEnvelopeParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(EnvelopeInputXML) + dec := xml.NewDecoder(buf) + env := messages.Envelope{} + if err := dec.Decode(&env); err != nil { + t.Logf("Failed parsing xml input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "123456789", env.GetID()) + require.Equal(t, 1, len(env.GetBody())) + msg, ok := env.GetBody()[0].(messages.Inform) + assert.Equal(t, true, ok) + assert.Equal(t, "ABCDEF", msg.DeviceId.OUI) + assert.Equal(t, "RouterModelX", msg.DeviceId.ProductClass) + assert.Equal(t, "1234567890", msg.DeviceId.SerialNumber) + assert.Equal(t, uint(1), msg.MaxEnvelopes) + assert.Equal(t, "2026-03-29T12:00:00Z", msg.CurrentTime) + assert.Equal(t, uint(3), msg.RetryCount) + assert.Equal(t, 2, len(msg.Event.Events)) + assert.Equal(t, messages.EventBoot, msg.Event.Events[1].EventCode) + assert.Equal(t, 4, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.WANDevice.1.WANConnectionDevice.1.WANIPConnection.1.ExternalIPAddress", msg.ParameterList.Params[3].Name.String()) + assert.Equal(t, "203.0.113.45", msg.ParameterList.Params[3].Content.Value.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[3].Content.Type) +} + +func TestBlankEnvelopeParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(BlankEnvelopeInputXML) + dec := xml.NewDecoder(buf) + env := messages.Envelope{} + if err := dec.Decode(&env); err != nil { + t.Logf("Failed parsing xml input: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, "", env.GetID()) + assert.Empty(t, env.GetBody()) +} + +func TestInvalidRPCEnvelope(t *testing.T) { + buf := bytes.NewBufferString(InvalidEnvelopeInputXML) + dec := xml.NewDecoder(buf) + env := messages.Envelope{} + assert.NotNil(t, dec.Decode(&env)) +} + +func TestEnvelopeSerializeToXML(t *testing.T) { + env := messages.NewEnvelope("testEnvelope", &messages.Inform{}) + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(env); err != nil { + t.Logf("Failed generating xml output: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, EnvelopeOutputXML, buf.String()) +} + +func TestBlankEnvelopeSerializeToXML(t *testing.T) { + env := messages.NewEnvelope("", nil) + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(env); err != nil { + t.Logf("Failed generating xml output: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, BlankEnvelopeOutputXML, buf.String()) +} diff --git a/internal/device/cwmp/messages/Fault.go b/internal/device/cwmp/messages/Fault.go new file mode 100644 index 0000000..0074451 --- /dev/null +++ b/internal/device/cwmp/messages/Fault.go @@ -0,0 +1,51 @@ +package messages + +import ( + "encoding/xml" +) + +func NewFault(code uint, msg string) Fault { + return Fault{ + FaultCode: "Server", + FaultString: "CWMP Fault", + Detail: CwmpFaultStruct{ + FaultStruct: FaultStruct{ + FaultCode: code, + FaultString: msg, + }, + }, + } +} + +type Fault struct { + XMLName xml.Name `xml:"Fault" yaml:"-"` + FaultCode string `xml:"faultcode" yaml:"faultcode"` + FaultString string `xml:"faultstring" yaml:"faultstring"` + Detail CwmpFaultStruct `xml:"detail>Fault" yaml:"Detail"` +} + +func (f Fault) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixSoapEnv(&start.Name, "Fault") + type Alias Fault + return enc.EncodeElement(Alias(f), start) +} + +type CwmpFaultStruct struct { + XMLName xml.Name `xml:"Fault" yaml:"-"` + FaultStruct `yaml:",inline"` + SetParameterValuesFault *SetParameterValuesFaultStruct `xml:"SetParameterValuesFault,omitempty" yaml:"SetParameterValuesFault,omitempty"` +} + +func (cf CwmpFaultStruct) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name) + type Alias CwmpFaultStruct + return enc.EncodeElement(Alias(cf), start) +} + +type SetParameterValuesFaultStruct struct { + ParameterName string `yaml:"ParameterName"` + FaultCode string `yaml:"FaultCode"` + FaultString string `yaml:"FaultString"` +} + +func (msg Fault) GetName() string { return "Fault" } diff --git a/internal/device/cwmp/messages/Fault_test.go b/internal/device/cwmp/messages/Fault_test.go new file mode 100644 index 0000000..0e3e61a --- /dev/null +++ b/internal/device/cwmp/messages/Fault_test.go @@ -0,0 +1,149 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + FaultSimpleInputXML = ` + Client + CWMP fault + + + 9027 + System Resources Exceeded + + +` + + FaultInputXML = ` + Client + CWMP fault + + + 9003 + Invalid arguments + + Device.SomeParam + 9008 + Attempt to set a non-writable parameter + + + +` + + FaultInputYAML = `faultcode: Client +faultstring: CWMP fault +Detail: + FaultCode: 9003 + FaultString: Invalid arguments + SetParameterValuesFault: + ParameterName: Device.SomeParam + FaultCode: "9008" + FaultString: Attempt to set a non-writable parameter +` +) + +var FaultInputMsg = messages.Fault{ + FaultCode: "Client", + FaultString: "CWMP fault", + Detail: messages.CwmpFaultStruct{ + FaultStruct: messages.FaultStruct{ + FaultCode: 9003, + FaultString: "Invalid arguments", + }, + SetParameterValuesFault: &messages.SetParameterValuesFaultStruct{ + ParameterName: "Device.SomeParam", + FaultCode: "9008", + FaultString: "Attempt to set a non-writable parameter", + }, + }, +} + +func TestFaultParseFromXMLNoSetParameterValuesFault(t *testing.T) { + buf := bytes.NewBufferString(FaultSimpleInputXML) + dec := xml.NewDecoder(buf) + msg := messages.Fault{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Client", msg.FaultCode) + assert.Equal(t, "CWMP fault", msg.FaultString) + assert.Equal(t, uint(9027), msg.Detail.FaultCode) + assert.Equal(t, "System Resources Exceeded", msg.Detail.FaultString) + assert.Nil(t, msg.Detail.SetParameterValuesFault) +} + +func TestFaultParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(FaultInputXML) + dec := xml.NewDecoder(buf) + msg := messages.Fault{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Client", msg.FaultCode) + assert.Equal(t, "CWMP fault", msg.FaultString) + + assert.Equal(t, uint(9003), msg.Detail.FaultCode) + assert.Equal(t, "Invalid arguments", msg.Detail.FaultString) + + assert.NotNil(t, msg.Detail.SetParameterValuesFault) + assert.Equal(t, "Device.SomeParam", msg.Detail.SetParameterValuesFault.ParameterName) + assert.Equal(t, "9008", msg.Detail.SetParameterValuesFault.FaultCode) + assert.Equal(t, "Attempt to set a non-writable parameter", msg.Detail.SetParameterValuesFault.FaultString) +} + +func TestFaultSerializeToXML(t *testing.T) { + msg := FaultInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, FaultInputXML, buf.String()) +} + +func TestFaultParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(FaultInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.Fault{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Client", msg.FaultCode) + assert.Equal(t, "CWMP fault", msg.FaultString) + + assert.Equal(t, uint(9003), msg.Detail.FaultCode) + assert.Equal(t, "Invalid arguments", msg.Detail.FaultString) + + assert.NotNil(t, msg.Detail.SetParameterValuesFault) + assert.Equal(t, "Device.SomeParam", msg.Detail.SetParameterValuesFault.ParameterName) + assert.Equal(t, "9008", msg.Detail.SetParameterValuesFault.FaultCode) + assert.Equal(t, "Attempt to set a non-writable parameter", msg.Detail.SetParameterValuesFault.FaultString) +} + +func TestFaultSerializeToYAML(t *testing.T) { + msg := FaultInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, FaultInputYAML, outbuf.String()) +} diff --git a/internal/device/cwmp/messages/GetParameterNames.go b/internal/device/cwmp/messages/GetParameterNames.go new file mode 100644 index 0000000..9d74212 --- /dev/null +++ b/internal/device/cwmp/messages/GetParameterNames.go @@ -0,0 +1,41 @@ +package messages + +import ( + c "corteca/internal/configuration" + "encoding/xml" +) + +type GetParameterNames struct { + XMLName xml.Name `xml:"GetParameterNames" yaml:"-"` + ParameterPath c.TemplateField `yaml:"ParameterPath"` + NextLevel bool `yaml:"NextLevel"` +} + +func (msg GetParameterNames) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "GetParameterNames") + type Alias GetParameterNames + return enc.EncodeElement(Alias(msg), start) +} + +func (msg GetParameterNames) GetName() string { return "GetParameterNames" } +func (msg GetParameterNames) ValidateResponse(resp Message) error { + return ExpectMessage[GetParameterNamesResponse](resp) +} + +type GetParameterNamesResponse struct { + XMLName xml.Name `xml:"GetParameterNamesResponse" yaml:"-"` + ParameterList []ParameterInfoStruct `xml:"ParameterList>ParameterInfoStruct" yaml:"ParameterList"` +} + +func (msg GetParameterNamesResponse) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name) + type Alias GetParameterNamesResponse + return enc.EncodeElement(Alias(msg), start) +} + +type ParameterInfoStruct struct { + Name string `yaml:"Name"` + Writable bool `yaml:"Writable"` +} + +func (msg GetParameterNamesResponse) GetName() string { return "GetParameterNamesResponse" } diff --git a/internal/device/cwmp/messages/GetParameterNames_test.go b/internal/device/cwmp/messages/GetParameterNames_test.go new file mode 100644 index 0000000..b4bcb21 --- /dev/null +++ b/internal/device/cwmp/messages/GetParameterNames_test.go @@ -0,0 +1,161 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/configuration" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + GetParameterNamesInputXML = ` + Device.SoftwareModules. + true +` + + GetParameterNamesInputYAML = `ParameterPath: Device.SoftwareModules. +NextLevel: true +` + + GetParameterNamesResponseInputXML = ` + + + Device.SoftwareModules. + false + + + Device.SoftwareModules.ExecutionUnit.1.Version + false + + +` + + GetParameterNamesResponseInputYAML = `ParameterList: + - Name: Device.SoftwareModules. + Writable: false + - Name: Device.SoftwareModules.ExecutionUnit.1.Version + Writable: false +` +) + +var GetParameterNamesInputMsg = messages.GetParameterNames{ + ParameterPath: configuration.T("Device.SoftwareModules."), + NextLevel: true, +} + +var GetParameterNamesResponseInputMsg = messages.GetParameterNamesResponse{ + ParameterList: []messages.ParameterInfoStruct{ + {Name: "Device.SoftwareModules.", Writable: false}, + {Name: "Device.SoftwareModules.ExecutionUnit.1.Version", Writable: false}, + }, +} + +func TestGetParameterNamesParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterNamesInputXML) + dec := xml.NewDecoder(buf) + msg := messages.GetParameterNames{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Device.SoftwareModules.", msg.ParameterPath.String()) + assert.Equal(t, true, msg.NextLevel) +} + +func TestGetParameterNamesSerializeToXML(t *testing.T) { + msg := GetParameterNamesInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterNamesInputXML, buf.String()) +} + +func TestGetParameterNamesParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterNamesInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.GetParameterNames{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "Device.SoftwareModules.", msg.ParameterPath.String()) + assert.Equal(t, true, msg.NextLevel) +} + +func TestGetParameterNamesSerializeToYAML(t *testing.T) { + msg := GetParameterNamesInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterNamesInputYAML, outbuf.String()) +} + +func TestGetParameterNamesResponseParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterNamesResponseInputXML) + dec := xml.NewDecoder(buf) + msg := messages.GetParameterNamesResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterList)) + assert.Equal(t, "Device.SoftwareModules.", msg.ParameterList[0].Name) + assert.Equal(t, false, msg.ParameterList[0].Writable) + assert.Equal(t, "Device.SoftwareModules.ExecutionUnit.1.Version", msg.ParameterList[1].Name) + assert.Equal(t, false, msg.ParameterList[1].Writable) +} + +func TestGetParameterNamesResponseSerializeToXML(t *testing.T) { + msg := GetParameterNamesResponseInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterNamesResponseInputXML, buf.String()) +} + +func TestGetParameterNamesResponseParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterNamesResponseInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.GetParameterNamesResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterList)) + assert.Equal(t, "Device.SoftwareModules.", msg.ParameterList[0].Name) + assert.Equal(t, false, msg.ParameterList[0].Writable) + assert.Equal(t, "Device.SoftwareModules.ExecutionUnit.1.Version", msg.ParameterList[1].Name) + assert.Equal(t, false, msg.ParameterList[1].Writable) +} + +func TestGetParameterNamesResponseSerializeToYAML(t *testing.T) { + msg := GetParameterNamesResponseInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterNamesResponseInputYAML, outbuf.String()) +} diff --git a/internal/device/cwmp/messages/GetParameterValues.go b/internal/device/cwmp/messages/GetParameterValues.go new file mode 100644 index 0000000..ae3811a --- /dev/null +++ b/internal/device/cwmp/messages/GetParameterValues.go @@ -0,0 +1,34 @@ +package messages + +import ( + "encoding/xml" +) + +type GetParameterValues struct { + XMLName xml.Name `xml:"GetParameterValues" yaml:"-"` + ParameterNames ParameterNameListStruct `yaml:"ParameterNames"` +} + +func (msg GetParameterValues) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "GetParameterValues") + type Alias GetParameterValues + return enc.EncodeElement(Alias(msg), start) +} + +func (msg GetParameterValues) GetName() string { return "GetParameterValues" } +func (msg GetParameterValues) ValidateResponse(resp Message) error { + return ExpectMessage[GetParameterValuesResponse](resp) +} + +type GetParameterValuesResponse struct { + XMLName xml.Name `xml:"GetParameterValuesResponse" yaml:"-"` + ParameterList ParameterValueListStruct `yaml:"ParameterList"` +} + +func (msg GetParameterValuesResponse) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name) + type Alias GetParameterValuesResponse + return enc.EncodeElement(Alias(msg), start) +} + +func (msg GetParameterValuesResponse) GetName() string { return "GetParameterValuesResponse" } diff --git a/internal/device/cwmp/messages/GetParameterValues_test.go b/internal/device/cwmp/messages/GetParameterValues_test.go new file mode 100644 index 0000000..3bcf369 --- /dev/null +++ b/internal/device/cwmp/messages/GetParameterValues_test.go @@ -0,0 +1,184 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/configuration" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + GetParameterValuesInputXML = ` + + Device.DeviceInfo.SoftwareVersion + Device.DeviceInfo.HardwareVersion + +` + + GetParameterValuesInputYAML = `ParameterNames: + - Device.DeviceInfo.SoftwareVersion + - Device.DeviceInfo.HardwareVersion +` + + GetParameterValuesResponseInputXML = ` + + + Device.DeviceInfo.SoftwareVersion + 1.0.0 + + + Device.DeviceInfo.HardwareVersion + RevA + + +` + + GetParameterValuesResponseInputYAML = `ParameterList: + - Name: Device.DeviceInfo.SoftwareVersion + Type: xsd:string + Value: 1.0.0 + - Name: Device.DeviceInfo.HardwareVersion + Type: xsd:string + Value: RevA +` +) + +var GetParameterValuesInputMsg = messages.GetParameterValues{ + ParameterNames: messages.ParameterNameListStruct{ + Params: []configuration.TemplateField{ + configuration.T("Device.DeviceInfo.SoftwareVersion"), + configuration.T("Device.DeviceInfo.HardwareVersion"), + }, + }, +} + +var GetParameterValuesResponseInputMsg = messages.GetParameterValuesResponse{ + ParameterList: messages.ParameterValueListStruct{ + Params: []messages.ParameterValueStruct{ + { + Name: configuration.T("Device.DeviceInfo.SoftwareVersion"), + Content: messages.NodeStruct{Type: messages.XsdString, Value: configuration.T("1.0.0")}, + }, + { + Name: configuration.T("Device.DeviceInfo.HardwareVersion"), + Content: messages.NodeStruct{Type: messages.XsdString, Value: configuration.T("RevA")}, + }, + }, + }, +} + +func TestGetParameterValuesParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterValuesInputXML) + dec := xml.NewDecoder(buf) + msg := messages.GetParameterValues{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterNames.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterNames.Params[0].String()) + assert.Equal(t, "Device.DeviceInfo.HardwareVersion", msg.ParameterNames.Params[1].String()) +} + +func TestGetParameterValuesSerializeToXML(t *testing.T) { + msg := GetParameterValuesInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterValuesInputXML, buf.String()) +} + +func TestGetParameterValuesParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterValuesInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.GetParameterValues{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterNames.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterNames.Params[0].String()) + assert.Equal(t, "Device.DeviceInfo.HardwareVersion", msg.ParameterNames.Params[1].String()) +} + +func TestGetParameterValuesSerializeToYAML(t *testing.T) { + msg := GetParameterValuesInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterValuesInputYAML, outbuf.String()) +} + +func TestGetParameterValuesResponseParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterValuesResponseInputXML) + dec := xml.NewDecoder(buf) + msg := messages.GetParameterValuesResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList.Params[0].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[0].Content.Type) + assert.Equal(t, "1.0.0", msg.ParameterList.Params[0].Content.Value.String()) + assert.Equal(t, "Device.DeviceInfo.HardwareVersion", msg.ParameterList.Params[1].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[1].Content.Type) + assert.Equal(t, "RevA", msg.ParameterList.Params[1].Content.Value.String()) +} + +func TestGetParameterValuesResponseSerializeToXML(t *testing.T) { + msg := GetParameterValuesResponseInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterValuesResponseInputXML, buf.String()) +} + +func TestGetParameterValuesResponseParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(GetParameterValuesResponseInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.GetParameterValuesResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList.Params[0].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[0].Content.Type) + assert.Equal(t, "1.0.0", msg.ParameterList.Params[0].Content.Value.String()) + assert.Equal(t, "Device.DeviceInfo.HardwareVersion", msg.ParameterList.Params[1].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[1].Content.Type) + assert.Equal(t, "RevA", msg.ParameterList.Params[1].Content.Value.String()) +} + +func TestGetParameterValuesResponseSerializeToYAML(t *testing.T) { + msg := GetParameterValuesResponseInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetParameterValuesResponseInputYAML, outbuf.String()) +} diff --git a/internal/device/cwmp/messages/GetRPCMethods.go b/internal/device/cwmp/messages/GetRPCMethods.go new file mode 100644 index 0000000..f01f1eb --- /dev/null +++ b/internal/device/cwmp/messages/GetRPCMethods.go @@ -0,0 +1,65 @@ +package messages + +import ( + "encoding/xml" + "fmt" + + "gopkg.in/yaml.v3" +) + +type GetRPCMethods struct { + XMLName xml.Name `xml:"GetRPCMethods" yaml:"-"` +} + +func (msg GetRPCMethods) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "GetRPCMethods") + type Alias GetRPCMethods + return enc.EncodeElement(Alias(msg), start) +} + +func (msg GetRPCMethods) GetName() string { return "GetRPCMethods" } +func (msg GetRPCMethods) ValidateResponse(resp Message) error { + return ExpectMessage[GetRPCMethodsResponse](resp) +} +func (msg GetRPCMethods) GenerateResponse() Message { + return GetRPCMethodsResponse{ + MethodList: MethodListStruct{ + Methods: []string{ + "Inform", + "GetRPCMethods", + "DUStateChangeComplete", + }, + }, + } +} + +type GetRPCMethodsResponse struct { + XMLName xml.Name `xml:"GetRPCMethodsResponse" yaml:"-"` + MethodList MethodListStruct `yaml:"MethodList"` +} + +func (msg GetRPCMethodsResponse) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "GetRPCMethodsResponse") + type Alias GetRPCMethodsResponse + return enc.EncodeElement(Alias(msg), start) +} + +type MethodListStruct struct { + Methods []string `xml:"string"` +} + +func (pl MethodListStruct) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + start.Attr = append(start.Attr, XmlAttr(SoapArray, fmt.Sprintf("xsd:string[%d]", len(pl.Methods)))) + type Alias MethodListStruct + return enc.EncodeElement(Alias(pl), start) +} + +func (pl MethodListStruct) MarshalYAML() (any, error) { + return pl.Methods, nil +} + +func (pl *MethodListStruct) UnmarshalYAML(value *yaml.Node) error { + return value.Decode(&pl.Methods) +} + +func (msg GetRPCMethodsResponse) GetName() string { return "GetRPCMethodsResponse" } diff --git a/internal/device/cwmp/messages/GetRPCMethods_test.go b/internal/device/cwmp/messages/GetRPCMethods_test.go new file mode 100644 index 0000000..1dafeba --- /dev/null +++ b/internal/device/cwmp/messages/GetRPCMethods_test.go @@ -0,0 +1,109 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + GetRPCMethodsInputXML = `` + + GetRPCMethodsResponseInputXML = ` + + Inform + GetRPCMethods + DUStateChangeComplete + +` + + GetRPCMethodsResponseInputYAML = `MethodList: + - Inform + - GetRPCMethods + - DUStateChangeComplete +` +) + +var GetRPCMethodsInputMsg = messages.GetRPCMethods{} + +var GetRPCMethodsResponseInputMsg = messages.GetRPCMethods{}.GenerateResponse() + +func TestGetRPCMethodsParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(GetRPCMethodsInputXML) + dec := xml.NewDecoder(buf) + msg := messages.GetRPCMethods{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } +} + +func TestGetRPCMethodsSerializeToXML(t *testing.T) { + msg := GetRPCMethodsInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetRPCMethodsInputXML, buf.String()) +} + +func TestGetRPCMethodsResponseParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(GetRPCMethodsResponseInputXML) + dec := xml.NewDecoder(buf) + msg := messages.GetRPCMethodsResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 3, len(msg.MethodList.Methods)) + assert.Equal(t, "Inform", msg.MethodList.Methods[0]) + assert.Equal(t, "GetRPCMethods", msg.MethodList.Methods[1]) + assert.Equal(t, "DUStateChangeComplete", msg.MethodList.Methods[2]) +} + +func TestGetRPCMethodsResponseSerializeToXML(t *testing.T) { + msg := GetRPCMethodsResponseInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetRPCMethodsResponseInputXML, buf.String()) +} + +func TestGetRPCMethodsResponseParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(GetRPCMethodsResponseInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.GetRPCMethodsResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 3, len(msg.MethodList.Methods)) + assert.Equal(t, "Inform", msg.MethodList.Methods[0]) + assert.Equal(t, "GetRPCMethods", msg.MethodList.Methods[1]) + assert.Equal(t, "DUStateChangeComplete", msg.MethodList.Methods[2]) +} + +func TestGetRPCMethodsResponseSerializeToYAML(t *testing.T) { + msg := GetRPCMethodsResponseInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, GetRPCMethodsResponseInputYAML, outbuf.String()) +} diff --git a/internal/device/cwmp/messages/Inform.go b/internal/device/cwmp/messages/Inform.go new file mode 100644 index 0000000..a705253 --- /dev/null +++ b/internal/device/cwmp/messages/Inform.go @@ -0,0 +1,76 @@ +package messages + +import ( + "encoding/xml" + "fmt" + + "gopkg.in/yaml.v3" +) + +type Inform struct { + XMLName xml.Name `xml:"Inform" yaml:"-"` + DeviceId DeviceIDStruct `yaml:"DeviceId"` + Event EventList `yaml:"Event"` + MaxEnvelopes uint `yaml:"MaxEnvelopes"` + CurrentTime string `yaml:"CurrentTime"` + RetryCount uint `yaml:"RetryCount"` + ParameterList ParameterValueListStruct `yaml:"ParameterList"` +} + +// custom marshaller to add prefix to name +func (i Inform) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "Inform") + type Alias Inform + return e.EncodeElement(Alias(i), start) +} + +type DeviceIDStruct struct { + Manufacturer string `yaml:"Manufacturer"` + OUI string `yaml:"OUI"` + ProductClass string `yaml:"ProductClass"` + SerialNumber string `yaml:"SerialNumber"` +} + +// EventStruct event +type EventList struct { + Events []EventStruct `xml:"EventStruct"` +} + +// custom marshaller to add type attribute +func (el EventList) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + start.Attr = append(start.Attr, xml.Attr{ + Name: xml.Name{Local: SoapArray}, + Value: fmt.Sprintf("cwmp:EventStruct[%d]", len(el.Events)), + }) + type Alias EventList + return enc.EncodeElement(Alias(el), start) +} + +func (e EventList) MarshalYAML() (any, error) { + return e.Events, nil +} + +func (e *EventList) UnmarshalYAML(value *yaml.Node) error { + return value.Decode(&e.Events) +} + +type EventStruct struct { + EventCode string `yaml:"EventCode"` + CommandKey string `yaml:"CommandKey"` +} + +func (msg Inform) GetName() string { return "Inform" } +func (msg Inform) GenerateResponse() Message { return InformResponse{MaxEnvelopes: 1} } + +type InformResponse struct { + XMLName xml.Name `xml:"InformResponse" yaml:"-" json:"-"` + MaxEnvelopes uint `yaml:"MaxEnvelopes"` +} + +func (i InformResponse) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "InformResponse") + type Alias InformResponse + return e.EncodeElement(Alias(i), start) +} + +func (msg InformResponse) GetName() string { return "InformResponse" } diff --git a/internal/device/cwmp/messages/Inform_test.go b/internal/device/cwmp/messages/Inform_test.go new file mode 100644 index 0000000..90f2a4c --- /dev/null +++ b/internal/device/cwmp/messages/Inform_test.go @@ -0,0 +1,238 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/configuration" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + InformInputXML = ` + + + + ExampleCorp + ABCDEF + RouterModelX + 1234567890 + + + + + 0 BOOTSTRAP + + + + 1 BOOT + + + + + 1 + + 2026-03-29T12:00:00Z + + 5 + + + + Device.DeviceInfo.SoftwareVersion + 1.0.0 + + + Device.DeviceInfo.HardwareVersion + RevA + + + Device.ManagementServer.ConnectionRequestURL + http://192.168.1.1:7547/ + + + Device.WANDevice.1.WANConnectionDevice.1.WANIPConnection.1.ExternalIPAddress + 203.0.113.45 + + + +` + + InformOutputXML = ` + + Nokia + NOKIA + Beacon 9 + SN1234567890 + + + + 1 BOOT + 12345 + + + 6 CONNECTION REQUEST + 67890 + + + 1 + 2026-03-29T23:45:00Z + 2 + + + Device.SoftwareModules.ExecutionUnit.1.Version + 1.0.0 + + +` + + InformInputYAML = `DeviceId: + Manufacturer: Nokia + OUI: NOKIA + ProductClass: Beacon 9 + SerialNumber: SN1234567890 +Event: + - EventCode: 1 BOOT + CommandKey: "12345" + - EventCode: 6 CONNECTION REQUEST + CommandKey: "67890" +MaxEnvelopes: 1 +CurrentTime: "2026-03-29T23:45:00Z" +RetryCount: 2 +ParameterList: + - Name: Device.SoftwareModules.ExecutionUnit.1.Version + Type: xsd:string + Value: 1.0.0 +` + + InformResponseInputXML = ` + 1 +` +) + +func TestInformParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(InformInputXML) + enc := xml.NewDecoder(buf) + msg := messages.Inform{} + if err := enc.Decode(&msg); err != nil { + t.Logf("Failed parsing xml input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, "ExampleCorp", msg.DeviceId.Manufacturer) + assert.Equal(t, "ABCDEF", msg.DeviceId.OUI) + assert.Equal(t, "RouterModelX", msg.DeviceId.ProductClass) + assert.Equal(t, "1234567890", msg.DeviceId.SerialNumber) + assert.Equal(t, uint(1), msg.MaxEnvelopes) + assert.Equal(t, "2026-03-29T12:00:00Z", msg.CurrentTime) + assert.Equal(t, uint(5), msg.RetryCount) + assert.Equal(t, 2, len(msg.Event.Events)) + assert.Equal(t, messages.EventBoot, msg.Event.Events[1].EventCode) + assert.Equal(t, 4, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.WANDevice.1.WANConnectionDevice.1.WANIPConnection.1.ExternalIPAddress", msg.ParameterList.Params[3].Name.String()) + assert.Equal(t, "203.0.113.45", msg.ParameterList.Params[3].Content.Value.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[3].Content.Type) +} + +func TestInformSerializeToXML(t *testing.T) { + msg := messages.Inform{ + DeviceId: messages.DeviceIDStruct{ + Manufacturer: "Nokia", + OUI: "NOKIA", + ProductClass: "Beacon 9", + SerialNumber: "SN1234567890", + }, + Event: messages.EventList{ + Events: []messages.EventStruct{ + messages.EventStruct{EventCode: messages.EventBoot, CommandKey: "12345"}, + messages.EventStruct{EventCode: messages.EventConnectionRequest, CommandKey: "67890"}, + }, + }, + MaxEnvelopes: 1, + RetryCount: 2, + CurrentTime: "2026-03-29T23:45:00Z", + ParameterList: messages.ParameterValueListStruct{ + Params: []messages.ParameterValueStruct{ + messages.ParameterValueStruct{ + Name: configuration.T("Device.SoftwareModules.ExecutionUnit.1.Version"), + Content: messages.NodeStruct{ + Type: messages.XsdString, + Value: configuration.T("1.0.0"), + }, + }, + }, + }, + } + + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating xml output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, InformOutputXML, buf.String()) +} + +func TestInformUmarshalFromYaml(t *testing.T) { + buf := bytes.NewBufferString(InformInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.Inform{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing yaml input: %s", err.Error()) + t.FailNow() + } + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(outbuf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating xml output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, InformOutputXML, outbuf.String()) +} + +func TestInformMarshalToYaml(t *testing.T) { + buf := bytes.NewBufferString(InformOutputXML) + dec := xml.NewDecoder(buf) + msg := messages.Inform{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing xml input: %s", err.Error()) + t.FailNow() + } + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating yaml output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, InformInputYAML, outbuf.String()) +} + +func TestInformResponseParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(InformResponseInputXML) + dec := xml.NewDecoder(buf) + msg := messages.InformResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, uint(1), msg.MaxEnvelopes) +} + +func TestInformResponseSerializeToXML(t *testing.T) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + msg := messages.InformResponse{MaxEnvelopes: 1} + if err := enc.Encode(&msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, InformResponseInputXML, buf.String()) +} diff --git a/internal/device/cwmp/messages/Messages.go b/internal/device/cwmp/messages/Messages.go new file mode 100644 index 0000000..f57b2a6 --- /dev/null +++ b/internal/device/cwmp/messages/Messages.go @@ -0,0 +1,138 @@ +package messages + +import ( + "corteca/internal/configuration" + "encoding/xml" + "fmt" + "strings" + + "gopkg.in/yaml.v3" +) + +const ( + XsdString string = "xsd:string" + XsdUnsignedint string = "xsd:unsignedInt" +) + +const ( + SoapArray string = "soap-enc:array" + XsiType string = "xsi:type" +) + +const ( + EventBootStrap string = "0 BOOTSTRAP" + EventBoot string = "1 BOOT" + EventPeriodic string = "2 PERIODIC" + EventScheduled string = "3 SCHEDULED" + EventValueChange string = "4 VALUE CHANGE" + EventKicked string = "5 KICKED" + EventConnectionRequest string = "6 CONNECTION REQUEST" + EventTransferComplete string = "7 TRANSFER COMPLETE" +) + +// helpers +func PrefixName(prefix string, name *xml.Name, elems ...string) { + if len(elems) > 0 { + name.Local = strings.Join(elems, ".") + } + name.Local = fmt.Sprintf("%s:%s", prefix, name.Local) +} + +func PrefixSoapEnv(name *xml.Name, elems ...string) { + PrefixName("soap-env", name, elems...) +} + +func PrefixCwmp(name *xml.Name, elems ...string) { + PrefixName("cwmp", name, elems...) +} + +func XmlAttr(name, value string) xml.Attr { + return xml.Attr{ + Name: xml.Name{Local: name}, + Value: value, + } +} + +type Message interface { + GetName() string +} + +type SyncRPC interface { + Message + ValidateResponse(Message) error +} + +type AsyncRPC interface { + SyncRPC + Match(m Message) bool +} + +type ACSMethod interface { + Message + GenerateResponse() Message +} + +func ExpectMessage[T Message](m Message) error { + if _, ok := m.(T); !ok { + return fmt.Errorf("unexpected %s received", m.GetName()) + } + return nil +} + +type NodeStruct struct { + Type string `xml:"type,attr,omitempty" yaml:"Type,omitempty"` + Value configuration.TemplateField `xml:",chardata" yaml:"Value"` +} + +func (ns NodeStruct) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + return e.EncodeElement(struct { + Type string `xml:"xsi:type,attr,omitempty"` + Value string `xml:",chardata"` + }{Type: ns.Type, Value: ns.Value.String()}, start) +} + +type ParameterNameListStruct struct { + Params []configuration.TemplateField `xml:"string"` +} + +func (pl ParameterNameListStruct) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + start.Attr = append(start.Attr, XmlAttr(SoapArray, fmt.Sprintf("xsd:string[%d]", len(pl.Params)))) + type Alias ParameterNameListStruct + return enc.EncodeElement(Alias(pl), start) +} + +func (pl ParameterNameListStruct) MarshalYAML() (any, error) { + return pl.Params, nil +} + +func (pl *ParameterNameListStruct) UnmarshalYAML(value *yaml.Node) error { + return value.Decode(&pl.Params) +} + +type ParameterValueListStruct struct { + Params []ParameterValueStruct `xml:"ParameterValueStruct"` +} + +func (pl ParameterValueListStruct) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + start.Attr = append(start.Attr, XmlAttr(SoapArray, fmt.Sprintf("cwmp:ParameterValueStruct[%d]", len(pl.Params)))) + type Alias ParameterValueListStruct + return enc.EncodeElement(Alias(pl), start) +} + +func (pl ParameterValueListStruct) MarshalYAML() (any, error) { + return pl.Params, nil +} + +func (pl *ParameterValueListStruct) UnmarshalYAML(value *yaml.Node) error { + return value.Decode(&pl.Params) +} + +type ParameterValueStruct struct { + Name configuration.TemplateField `xml:"Name" yaml:"Name"` + Content NodeStruct `xml:"Value" yaml:",inline"` +} + +type FaultStruct struct { + FaultCode uint `yaml:"FaultCode"` + FaultString string `yaml:"FaultString"` +} diff --git a/internal/device/cwmp/messages/SetParameterValues.go b/internal/device/cwmp/messages/SetParameterValues.go new file mode 100644 index 0000000..a91c9c2 --- /dev/null +++ b/internal/device/cwmp/messages/SetParameterValues.go @@ -0,0 +1,35 @@ +package messages + +import ( + "encoding/xml" +) + +type SetParameterValues struct { + XMLName xml.Name `xml:"SetParameterValues" yaml:"-"` + ParameterList ParameterValueListStruct `yaml:"ParameterList"` + ParameterKey string `yaml:"ParameterKey"` +} + +func (msg SetParameterValues) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name, "SetParameterValues") + type Alias SetParameterValues + return enc.EncodeElement(Alias(msg), start) +} + +func (msg SetParameterValues) GetName() string { return "SetParameterValues" } +func (msg SetParameterValues) ValidateResponse(resp Message) error { + return ExpectMessage[SetParameterValuesResponse](resp) +} + +type SetParameterValuesResponse struct { + XMLName xml.Name `xml:"SetParameterValuesResponse" yaml:"-"` + Status uint `yaml:"Status"` +} + +func (msg SetParameterValuesResponse) MarshalXML(enc *xml.Encoder, start xml.StartElement) error { + PrefixCwmp(&start.Name) + type Alias SetParameterValuesResponse + return enc.EncodeElement(Alias(msg), start) +} + +func (msg SetParameterValuesResponse) GetName() string { return "SetParameterValuesResponse" } diff --git a/internal/device/cwmp/messages/SetParameterValues_test.go b/internal/device/cwmp/messages/SetParameterValues_test.go new file mode 100644 index 0000000..293206c --- /dev/null +++ b/internal/device/cwmp/messages/SetParameterValues_test.go @@ -0,0 +1,175 @@ +package messages_test + +import ( + "bytes" + "corteca/internal/configuration" + "corteca/internal/device/cwmp/messages" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +const ( + SetParameterValuesInputXML = ` + + + Device.DeviceInfo.SoftwareVersion + 2.0.0 + + + Device.DeviceInfo.HardwareVersion + RevB + + + mykey +` + + SetParameterValuesInputYAML = `ParameterList: + - Name: Device.DeviceInfo.SoftwareVersion + Type: xsd:string + Value: 2.0.0 + - Name: Device.DeviceInfo.HardwareVersion + Type: xsd:string + Value: RevB +ParameterKey: mykey +` + + SetParameterValuesResponseInputXML = ` + 0 +` + + SetParameterValuesResponseInputYAML = `Status: 0 +` +) + +var SetParameterValuesInputMsg = messages.SetParameterValues{ + ParameterList: messages.ParameterValueListStruct{ + Params: []messages.ParameterValueStruct{ + { + Name: configuration.T("Device.DeviceInfo.SoftwareVersion"), + Content: messages.NodeStruct{Type: messages.XsdString, Value: configuration.T("2.0.0")}, + }, + { + Name: configuration.T("Device.DeviceInfo.HardwareVersion"), + Content: messages.NodeStruct{Type: messages.XsdString, Value: configuration.T("RevB")}, + }, + }, + }, + ParameterKey: "mykey", +} + +var SetParameterValuesResponseInputMsg = messages.SetParameterValuesResponse{ + Status: 0, +} + +func TestSetParameterValuesParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(SetParameterValuesInputXML) + dec := xml.NewDecoder(buf) + msg := messages.SetParameterValues{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList.Params[0].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[0].Content.Type) + assert.Equal(t, "2.0.0", msg.ParameterList.Params[0].Content.Value.String()) + assert.Equal(t, "Device.DeviceInfo.HardwareVersion", msg.ParameterList.Params[1].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[1].Content.Type) + assert.Equal(t, "RevB", msg.ParameterList.Params[1].Content.Value.String()) + assert.Equal(t, "mykey", msg.ParameterKey) +} + +func TestSetParameterValuesSerializeToXML(t *testing.T) { + msg := SetParameterValuesInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, SetParameterValuesInputXML, buf.String()) +} + +func TestSetParameterValuesParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(SetParameterValuesInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.SetParameterValues{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, 2, len(msg.ParameterList.Params)) + assert.Equal(t, "Device.DeviceInfo.SoftwareVersion", msg.ParameterList.Params[0].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[0].Content.Type) + assert.Equal(t, "2.0.0", msg.ParameterList.Params[0].Content.Value.String()) + assert.Equal(t, "Device.DeviceInfo.HardwareVersion", msg.ParameterList.Params[1].Name.String()) + assert.Equal(t, messages.XsdString, msg.ParameterList.Params[1].Content.Type) + assert.Equal(t, "RevB", msg.ParameterList.Params[1].Content.Value.String()) + assert.Equal(t, "mykey", msg.ParameterKey) +} + +func TestSetParameterValuesSerializeToYAML(t *testing.T) { + msg := SetParameterValuesInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, SetParameterValuesInputYAML, outbuf.String()) +} + +func TestSetParameterValuesResponseParseFromXML(t *testing.T) { + buf := bytes.NewBufferString(SetParameterValuesResponseInputXML) + dec := xml.NewDecoder(buf) + msg := messages.SetParameterValuesResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing XML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, uint(0), msg.Status) +} + +func TestSetParameterValuesResponseSerializeToXML(t *testing.T) { + msg := SetParameterValuesResponseInputMsg + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := xml.NewEncoder(buf) + enc.Indent("", " ") + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating XML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, SetParameterValuesResponseInputXML, buf.String()) +} + +func TestSetParameterValuesResponseParseFromYAML(t *testing.T) { + buf := bytes.NewBufferString(SetParameterValuesResponseInputYAML) + dec := yaml.NewDecoder(buf) + msg := messages.SetParameterValuesResponse{} + if err := dec.Decode(&msg); err != nil { + t.Logf("Failed parsing YAML input: %s", err.Error()) + t.FailNow() + } + + assert.Equal(t, uint(0), msg.Status) +} + +func TestSetParameterValuesResponseSerializeToYAML(t *testing.T) { + msg := SetParameterValuesResponseInputMsg + outbuf := bytes.NewBuffer(make([]byte, 0, 1024)) + enc := yaml.NewEncoder(outbuf) + enc.SetIndent(4) + if err := enc.Encode(msg); err != nil { + t.Logf("Failed generating YAML output: %s", err.Error()) + t.FailNow() + } + assert.Equal(t, SetParameterValuesResponseInputYAML, outbuf.String()) +} diff --git a/internal/device/cwmpdevice.go b/internal/device/cwmpdevice.go deleted file mode 100644 index 42aa447..0000000 --- a/internal/device/cwmpdevice.go +++ /dev/null @@ -1,453 +0,0 @@ -package device - -import ( - "bufio" - "context" - "corteca/internal/configuration" - "corteca/internal/cwmp/messages" - "corteca/internal/cwmp/models" - "corteca/internal/dispatcher" - "corteca/internal/tui" - - "fmt" - "io" - "net" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "time" - - digestAuthClient "github.com/xinsnake/go-http-digest-auth-client" -) - -const ( - defaultCWMPPort = 7547 -) - -type CWMPDevice struct { - endpoint configuration.Endpoint - protocol string - server *http.Server - resultChan chan *models.ResultsMessage - taskChannel chan messages.Message - readWriter *bufio.ReadWriter - connection net.Conn - Log *Logger - lastCommandKey string -} - -func NewCWMPDevice(endpoint configuration.Endpoint, logfile string) (*CWMPDevice, error) { - device, _ := newCWMP(endpoint, logfile) - - device.protocol = "http" - - return device, nil -} - -func NewCWMPsDevice(endpoint configuration.Endpoint, logfile string) (*CWMPDevice, error) { - device, _ := newCWMP(endpoint, logfile) - - device.protocol = "https" - - return device, nil -} - -func newCWMP(endpoint configuration.Endpoint, logfile string) (*CWMPDevice, error) { - logger := &Logger{} - logger.SetLogFile(logfile) - - return &CWMPDevice{ - endpoint: endpoint, - resultChan: make(chan *models.ResultsMessage), - taskChannel: make(chan messages.Message), - Log: logger, - lastCommandKey: "", - }, nil - -} - -func (c *CWMPDevice) initServer(address string) error { - c.server = &http.Server{ - Addr: address, - Handler: nil, // uses default mux - } - - listener, err := net.Listen("tcp", address) - if err != nil { - return fmt.Errorf("failed to listen: %v", err) - } - - http.HandleFunc("/", c.handleTr069) - tui.DisplaySuccessMsg(fmt.Sprintf("Starting CWMP server on address %v...", address)) - // Run server in a goroutine - go func() { - if err := c.server.Serve(listener); err != nil && err != http.ErrServerClosed { - tui.DisplayErrorMsg(fmt.Sprintf("Failed to start server: %s", err)) - os.Exit(1) - } - }() - - return nil -} - -func (d *CWMPDevice) GetProtocol() int { - return ConnectionCWMP -} - -func (c *CWMPDevice) Connect() (dispatcher.Dispatcher, error) { - var address string - u, err := url.Parse(c.endpoint.CwmpServerAddr) - if err != nil { - return nil, err - } - - if u.Port() == "" { - address = u.Host + ":" + strconv.Itoa(defaultCWMPPort) - } else { - address = u.Host - } - - if err = c.initServer(address); err != nil { - return nil, err - } - - connectionReqURL, err := url.Parse(c.endpoint.Addr.String()) - if err != nil { - return nil, err - } - err = c.checkConnReqValues(connectionReqURL) - if err != nil { - tui.DisplayErrorMsg(fmt.Sprintf("skipping connection request to CPE device: %s", err)) - } else { - err := c.SendConnectionRequest() - if err != nil { - tui.DisplayErrorMsg(err.Error()) - } - } - - tui.DisplaySuccessMsg("Waiting for CPE to establish connection...") - - return dispatcher.NewCWMPDispatcher(c.taskChannel, c.resultChan), nil -} - -func (c *CWMPDevice) Close() { - // Graceful shutdown - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := c.server.Shutdown(ctx); err != nil { - tui.DisplayErrorMsg(fmt.Sprintf("Server Shutdown Failed: %s", err)) - return - } - - tui.DisplaySuccessMsg("Server stopped gracefully!") -} - -func (c *CWMPDevice) checkConnReqValues(u *url.URL) error { - - if u.Hostname() == "" { - return fmt.Errorf("connection request URL: empty hostname") - } else { - if u.Port() == "" { - return fmt.Errorf("connection request URL: empty port") - } - - c.endpoint.Addr.RawTemplate = c.protocol + "://" + u.Host - } - - user := u.User.Username() - if user == "" && configuration.GetCmdContext().Device.Username.String() == "" { - return fmt.Errorf("connection request username: is empty") - } else if user != "" { - c.endpoint.Username.RawTemplate = user - } else { - c.endpoint.Username.RawTemplate = configuration.GetCmdContext().Device.Username.String() - } - - pass, ok := u.User.Password() - if !ok && configuration.GetCmdContext().Device.Password.String() == "" { - return fmt.Errorf("connection request password: is empty") - } else if ok { - c.endpoint.Password.RawTemplate = pass - } else { - c.endpoint.Password.RawTemplate = configuration.GetCmdContext().Device.Password.String() - } - - return nil -} - -func (c *CWMPDevice) SendConnectionRequest() error { - dr := digestAuthClient.NewRequest( - c.endpoint.Username.String(), - c.endpoint.Password.String(), - "GET", - c.endpoint.Addr.String(), - "", - ) - resp, err := dr.Execute() - if err != nil { - return fmt.Errorf("error sending Connection Request: %v", err) - } - defer resp.Body.Close() - - tui.LogNormal("Connection Request sent. Status code: %d", resp.StatusCode) - return nil -} - -func (c *CWMPDevice) handleTr069(w http.ResponseWriter, r *http.Request) { - var err error - hj, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Hijacking not supported", http.StatusInternalServerError) - return - } - - c.connection, c.readWriter, err = hj.Hijack() - if err != nil { - tui.LogError("Hijacking error: %v", err) - return - } - - defer c.connection.Close() - - var requestBody []byte - if r.Method == "POST" { - // receive posted data - requestBody, err = io.ReadAll(r.Body) - if err != nil { - tui.LogError("tr069 read body error") - return - } - } - - var msg messages.Message - msg, err = messages.ParseXML(requestBody) - if err != nil { - tui.DisplayErrorMsg(err.Error()) - } - - msgResponse, _ := c.createResponseMessage(msg) - c.sendReply(msgResponse) - tui.DisplaySuccessMsg("Connection with CPE established") - c.handleSessionFlow() -} - -func ParseMessage(reader *bufio.Reader, contentLength int) (messages.Message, error) { - body := make([]byte, contentLength) - _, err := io.ReadFull(reader, body) - if err != nil { - return nil, fmt.Errorf("reading HTTP body error: %v", err) - } - - return messages.ParseXML(body) -} - -func createOutgoingFaultMsg(reply messages.Message) (xml []byte) { - switch reply.GetName() { - case "InformResponse": - fault := messages.NewFault() - fault.MsgFaultCode = "8002" - fault.MsgFaultString = "error creating inform response" - fault.ID = reply.GetID() - fault.CwmpFaultCode = "Server" - fault.CwmpFaultString = "CWMP fault" - xml, _ = fault.CreateXML() - } - return xml -} - -// Generates a response and a result type message according to the incoming message -// If response message is nil, then no response shall be send to CPE -func (c *CWMPDevice) createResponseMessage(msg messages.Message) (response messages.Message, result *models.ResultsMessage) { - switch msg.GetName() { - case "Inform": - inform := msg.(*messages.Inform) - configuration.GetCmdContext().Device.Addr = configuration.TemplateField{RawTemplate: inform.Params["Device.ManagementServer.ConnectionRequestURL"]} - informResponse := new(messages.InformResponse) - informResponse.ID = inform.ID - informResponse.MaxEnvelopes = 1 - - response = informResponse - case "GetParameterNamesResponse": - result = models.NewResulMessage() - result.Code = 0 - result.Message = msg.(*messages.GetParameterNamesResponse) - - response = nil - case "GetParameterValuesResponse": - result = models.NewResulMessage() - result.Code = 0 - result.Message = msg.(*messages.GetParameterValuesResponse) - - response = nil - case "SetParameterValuesResponse": - status := msg.(*messages.SetParameterValuesResponse).Status - result = models.NewResulMessage() - result.Code = status - result.Message = msg.(*messages.SetParameterValuesResponse) - - response = nil - case "ChangeDUStateResponse": - response = nil - case "DUStateChangeComplete": - ducomplete := msg.(*messages.DUStateChangeComplete) - //if complete is from a previous task do not send results. wait for task - if ducomplete.CommandKey == c.lastCommandKey { - result = models.NewResulMessage() - result.Code = ducomplete.Fault.FaultCode - result.Message = ducomplete - c.lastCommandKey = "" - } else { - tui.LogError("CommandKey not matching. %s != %s", ducomplete.CommandKey, c.lastCommandKey) - } - - completeResp := messages.NewDUStateCompleteResponse() - completeResp.ID = ducomplete.ID - response = completeResp - case "Fault": - faultMsg := msg.(*messages.Fault) - result = models.NewResulMessage() - result.Code, _ = strconv.Atoi(faultMsg.MsgFaultCode) - result.Message = faultMsg - response = nil - default: - fault := messages.NewFault() - fault.CwmpFaultCode = "8002" - fault.CwmpFaultString = "internal error" - fault.ID = msg.GetID() - fault.CwmpFaultCode = "Server" - fault.CwmpFaultString = "CWMP fault" - response = fault - - result = models.NewResulMessage() - result.Code, _ = strconv.Atoi(fault.MsgFaultCode) - result.Message = fault - } - - return response, result -} - -func (c *CWMPDevice) sendReply(msg messages.Message) { - if msg == nil { - c.sendEmptyResponse() - return - } - response, err := msg.CreateXML() - if err != nil { - response = createOutgoingFaultMsg(msg) - } - fmt.Fprintf(c.readWriter, "HTTP/1.1 200 OK\r\nContent-Type: text/xml\r\nContent-Length: %d\r\n\r\n%s", len(response), response) - c.readWriter.Flush() -} - -func (c *CWMPDevice) sendEmptyResponse() { - fmt.Fprint(c.readWriter, "HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n") - c.readWriter.Flush() -} - -func (c *CWMPDevice) handleSessionFlow() { - reader := bufio.NewReader(c.connection) - for { - c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)) - headers, err := readHeaders(reader) - - if err != nil { - // if err == io.EOF { - // log.Println("Connection closed by the CPE") - // } - if err != io.EOF { - tui.LogError("Reading HTTP headers error: %v", err) - } - // Stop reading headers in case of error or connetion termination by the CPE - break - } - - contentLength := GetMessageLength(headers) - - var msg messages.Message - if contentLength == 0 { - if err = c.sendRPC(); err != nil { - error_res := models.NewResulMessage() - error_res.Code = -1 - c.resultChan <- error_res - } - continue - } else { - msg, err = ParseMessage(reader, contentLength) - if err != nil { - c.sendEmptyResponse() - err_msg := messages.NewFault() - err_msg.MsgFaultCode = "-1" - err_msg.MsgFaultString = err.Error() - c.resultChan <- &models.ResultsMessage{Code: -1, Message: err_msg} - continue - } - } - - resp, res := c.createResponseMessage(msg) - c.sendReply(resp) - // if there are results to return to dispatcher - // send to CPE empty message (no more to send) - // and send results to dispatcher - if res != nil { - c.sendEmptyResponse() - c.resultChan <- res - } - } -} - -func readHeaders(reader *bufio.Reader) (map[string]string, error) { - headers := make(map[string]string) - for { - line, err := reader.ReadString('\n') - if err != nil { - return nil, err - } - - line = strings.TrimSpace(line) - if line == "" { - break - } - - if strings.HasPrefix(line, "POST") || strings.HasPrefix(line, "GET") { - headers[":request"] = line - } else { - parts := strings.SplitN(line, ":", 2) - if len(parts) == 2 { - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - headers[strings.ToLower(key)] = value - } - } - } - return headers, nil -} - -func GetMessageLength(headers map[string]string) int { - if val, ok := headers["content-length"]; ok { - var n int - fmt.Sscanf(val, "%d", &n) - return n - } - return 0 -} - -func (c *CWMPDevice) sendRPC() error { - msg := <-c.taskChannel - if msg.GetName() == "ChangeDUState" { - c.lastCommandKey = msg.(*messages.ChangeDUState).CommandKey - } - - rpcXML, err := msg.CreateXML() - if err != nil { - c.sendEmptyResponse() - return err - } - - fmt.Fprintf(c.readWriter, "HTTP/1.1 200 OK\r\nContent-Type: text/xml\r\nContent-Length: %d\r\n\r\n%s", len(rpcXML), string(rpcXML)) - c.readWriter.Flush() - c.sendEmptyResponse() - return nil -} diff --git a/internal/device/device.go b/internal/device/device.go index efc1b47..a57143e 100644 --- a/internal/device/device.go +++ b/internal/device/device.go @@ -6,92 +6,41 @@ package device import ( "corteca/internal/configuration" - "corteca/internal/dispatcher" "fmt" + "io" "net/url" - "os" "strings" ) -// Connection types -const ( - ConnectionTelnet = iota - ConnectionTelnetS - ConnectionSSH - ConnectionFIFO - ConnectionCWMP -) - -// Command constants -const ( - cmdLCMList = "lcm list" - cmdGrepPluginMgr = "pgrep PluginMgr" -) - -// Logger handles logging to a file or standard output -type Logger struct { - LogFile *os.File -} - // Device defines the interface for device operations type Device interface { - Connect() (dispatcher.Dispatcher, error) Close() - GetProtocol() int + GetProtocol() string + configuration.CommandExecutor } -// NewDevice is a factory method that creates a Device based on the endpoint protocol -func NewDevice(endpoint configuration.Endpoint, logfile string) (Device, error) { - u, err := url.Parse(endpoint.Addr.String()) - if err != nil { - return nil, fmt.Errorf("failed to parse endpoint address: %w", err) - } +type DeviceCreator func(*configuration.DeviceConfig, io.Writer) (Device, error) - switch strings.ToLower(u.Scheme) { - case "ssh": - return NewSSHDevice(endpoint, logfile) - case "cwmp": - return NewCWMPDevice(endpoint, logfile) - case "cwmps": - return NewCWMPsDevice(endpoint, logfile) - default: - return nil, fmt.Errorf("unsupported connection type: %s", u.Scheme) - } -} +var deviceTypeRegistry map[string]DeviceCreator -// NewLogger initializes a Logger instance -func NewLogger(filename string) (*Logger, error) { - logger := &Logger{} - if err := logger.SetLogFile(filename); err != nil { - return nil, err - } - return logger, nil +func RegisterDeviceType(typename string, creator DeviceCreator) { + deviceTypeRegistry[strings.ToLower(typename)] = creator } -// SetLogFile configures the log output destination -func (logger *Logger) SetLogFile(filename string) error { - switch filename { - case "stdout": - logger.LogFile = os.Stdout - case "stderr": - logger.LogFile = os.Stderr - default: - file, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) - if err != nil { - return fmt.Errorf("failed to open log file: %w", err) - } - logger.LogFile = file - } - return nil +func init() { + deviceTypeRegistry = make(map[string]DeviceCreator) } -// DetectContainerFramework returns the container framework type based on command results -func DetectContainerFramework(d dispatcher.Dispatcher) string { - if _, err := d.ExecuteCommand(cmdLCMList); err == nil { - return "oci" +// NewDevice is a factory method that creates a Device based on the endpoint protocol +func NewDevice(config *configuration.DeviceConfig, log io.Writer) (Device, error) { + u, err := url.Parse(config.Addr.String()) + if err != nil { + return nil, fmt.Errorf("failed to parse endpoint address: %w", err) } - if _, err := d.ExecuteCommand(cmdGrepPluginMgr); err == nil { - return "rootfs" + typename := strings.ToLower(u.Scheme) + if creator, found := deviceTypeRegistry[typename]; found { + return creator(config, log) + } else { + return nil, fmt.Errorf("unsupported device connection type '%s'", typename) } - return "" } diff --git a/internal/device/device_test.go b/internal/device/device_test.go index ac1ca41..c30a5a8 100644 --- a/internal/device/device_test.go +++ b/internal/device/device_test.go @@ -1,122 +1,125 @@ +// Copyright 2024 Nokia +// Licensed under the BSD 3-Clause License. +// SPDX-License-Identifier: BSD-3-Clause + package device_test import ( + "context" "corteca/internal/configuration" "corteca/internal/device" - "errors" - "os" + "io" "testing" ) -// MockDispatcher simulates dispatcher.Dispatcher behavior -type MockDispatcher struct { - Responses map[string]string - Failures map[string]error - printFormat string +// mockDevice is a minimal Device implementation used in tests. +type mockDevice struct { + protocol string } -func (m *MockDispatcher) ExecuteCommand(cmd any) (string, error) { - commandStr, ok := cmd.(string) - if !ok { - return "", errors.New("invalid command type") - } - if err, exists := m.Failures[commandStr]; exists { - return "", err - } - if output, exists := m.Responses[commandStr]; exists { - return output, nil - } - return "", nil +func (m *mockDevice) Close() {} + +func (m *mockDevice) GetProtocol() string { + return m.protocol } -func (m *MockDispatcher) SetPrintFormat(format string) { - m.printFormat = format +func (m *mockDevice) BeginSequence() error { + return nil } -// --- Tests --- +func (m *mockDevice) ExecuteCommand(_ context.Context, _ *configuration.SequenceCmd) (any, error) { + return nil, nil +} -func TestNewDevice_SSH(t *testing.T) { - endpoint := configuration.Endpoint{ - Addr: configuration.TemplateField{RawTemplate: "ssh://user@localhost"}, - } - dev, err := device.NewDevice(endpoint, "stdout") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if dev.GetProtocol() != device.ConnectionSSH { - t.Errorf("expected SSH protocol, got %d", dev.GetProtocol()) - } +func (m *mockDevice) EndSequence() error { + return nil } -func TestNewDevice_Unsupported(t *testing.T) { - endpoint := configuration.Endpoint{ - Addr: configuration.TemplateField{RawTemplate: "ftp://localhost"}, - } - _, err := device.NewDevice(endpoint, "stdout") - if err == nil { - t.Fatal("expected error for unsupported protocol") +// makeCreator returns a DeviceCreator that records whether it was called and +// returns a mockDevice with the given protocol label. +func makeCreator(protocol string, called *bool) device.DeviceCreator { + return func(cfg *configuration.DeviceConfig, log io.Writer) (device.Device, error) { + *called = true + return &mockDevice{protocol: protocol}, nil } } -func TestNewLogger_Stdout(t *testing.T) { - logger, err := device.NewLogger("stdout") - if err != nil { - t.Fatalf("expected no error, got %v", err) +func TestNewDevice_CorrectCreatorIsDispatched(t *testing.T) { + tests := []struct { + name string + schema string + protocol string + }{ + {name: "alpha schema", schema: "alpha", protocol: "alpha"}, + {name: "beta schema", schema: "beta", protocol: "beta"}, + {name: "gamma schema", schema: "gamma", protocol: "gamma"}, } - if logger == nil || logger.LogFile != os.Stdout { - t.Error("expected logger to use stdout") - } -} -func TestNewLogger_File(t *testing.T) { - filename := "test.log" - defer os.Remove(filename) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + calledAlpha := false + calledBeta := false + calledGamma := false - logger, err := device.NewLogger(filename) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if logger == nil || logger.LogFile == nil { - t.Error("expected logger to open file") - } -} + device.RegisterDeviceType("alpha", makeCreator("alpha", &calledAlpha)) + device.RegisterDeviceType("beta", makeCreator("beta", &calledBeta)) + device.RegisterDeviceType("gamma", makeCreator("gamma", &calledGamma)) -func TestDetectContainerFramework_OCI(t *testing.T) { - mock := &MockDispatcher{ - Responses: map[string]string{ - "lcm list": "running", - }, - } - result := device.DetectContainerFramework(mock) - if result != "oci" { - t.Errorf("expected 'oci', got %s", result) + cfg := &configuration.DeviceConfig{ + Endpoint: configuration.Endpoint{ + Addr: configuration.T(tc.schema + "://some-host"), + }, + } + + dev, err := device.NewDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dev.GetProtocol() != tc.protocol { + t.Errorf("expected protocol %q, got %q", tc.protocol, dev.GetProtocol()) + } + + // Verify exactly the right creator was called. + if tc.schema == "alpha" && !calledAlpha { + t.Error("expected alpha creator to be called") + } + if tc.schema == "beta" && !calledBeta { + t.Error("expected beta creator to be called") + } + if tc.schema == "gamma" && !calledGamma { + t.Error("expected gamma creator to be called") + } + + // Verify the other creators were NOT called. + if tc.schema != "alpha" && calledAlpha { + t.Error("alpha creator should not have been called") + } + if tc.schema != "beta" && calledBeta { + t.Error("beta creator should not have been called") + } + if tc.schema != "gamma" && calledGamma { + t.Error("gamma creator should not have been called") + } + }) } } -func TestDetectContainerFramework_RootFS(t *testing.T) { - mock := &MockDispatcher{ - Failures: map[string]error{ - "lcm list": errors.New("not found"), - }, - Responses: map[string]string{ - "pgrep PluginMgr": "PluginMgr", +func TestNewDevice_UnknownSchema_ReturnsError(t *testing.T) { + device.RegisterDeviceType("alpha", makeCreator("alpha", new(bool))) + device.RegisterDeviceType("beta", makeCreator("beta", new(bool))) + device.RegisterDeviceType("gamma", makeCreator("gamma", new(bool))) + + cfg := &configuration.DeviceConfig{ + Endpoint: configuration.Endpoint{ + Addr: configuration.T("unknown://some-host"), }, } - result := device.DetectContainerFramework(mock) - if result != "rootfs" { - t.Errorf("expected 'rootfs', got %s", result) - } -} -func TestDetectContainerFramework_Unknown(t *testing.T) { - mock := &MockDispatcher{ - Failures: map[string]error{ - "lcm list": errors.New("not found"), - "pgrep PluginMgr": errors.New("not found"), - }, + dev, err := device.NewDevice(cfg, io.Discard) + if err == nil { + t.Fatal("expected an error for unknown schema, got nil") } - result := device.DetectContainerFramework(mock) - if result != "" { - t.Errorf("expected empty string, got %s", result) + if dev != nil { + t.Errorf("expected nil device for unknown schema, got %v", dev) } } diff --git a/internal/device/ssh/mock_ssh_server_test.go b/internal/device/ssh/mock_ssh_server_test.go new file mode 100644 index 0000000..9252f92 --- /dev/null +++ b/internal/device/ssh/mock_ssh_server_test.go @@ -0,0 +1,173 @@ +// Copyright 2024 Nokia +// Licensed under the BSD 3-Clause License. +// SPDX-License-Identifier: BSD-3-Clause + +package ssh_test + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/binary" + "fmt" + "net" + "strings" + "testing" + + "golang.org/x/crypto/ssh" +) + +// cmdHandlerFunc is called by the mock server for every exec request it receives. +// It returns the stdout to send back and the exit code to report. +type cmdHandlerFunc func(cmd string) (stdout string, exitCode uint32) + +// withQuaggaProbe wraps a cmdHandlerFunc so that the "ps | grep ash" probe fired +// unconditionally by NewSSHDevice is handled transparently (returns "ash", exit 0, +// which signals to the device that Quagga is not active and no further action is +// required). All other commands are forwarded to inner. +func withQuaggaProbe(inner cmdHandlerFunc) cmdHandlerFunc { + return func(cmd string) (string, uint32) { + if strings.TrimSpace(cmd) == "ps | grep ash" { + return "ash", 0 + } + return inner(cmd) + } +} + +// startTestServer starts an in-process SSH server on a random loopback port and +// returns its address. The server accepts: +// - password authentication when password is non-empty +// - public-key authentication when authorizedKey is non-nil +// +// handler is called for every exec request the server receives. +// The server and its goroutines are torn down via t.Cleanup. +func startTestServer(t *testing.T, expectedUsername, password string, authorizedKey ssh.PublicKey, handler cmdHandlerFunc) string { + t.Helper() + + // Generate a fresh host key for every server instance. + hostKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("startTestServer: generate host key: %v", err) + } + signer, err := ssh.NewSignerFromKey(hostKey) + if err != nil { + t.Fatalf("startTestServer: create signer: %v", err) + } + + cfg := &ssh.ServerConfig{} + cfg.AddHostKey(signer) + + if password != "" { + cfg.PasswordCallback = func(meta ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + if expectedUsername != "" && meta.User() != expectedUsername { + return nil, fmt.Errorf("wrong username: got %q, want %q", meta.User(), expectedUsername) + } + if string(pass) == password { + return nil, nil + } + return nil, fmt.Errorf("wrong password") + } + } + + if authorizedKey != nil { + cfg.PublicKeyCallback = func(meta ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if expectedUsername != "" && meta.User() != expectedUsername { + return nil, fmt.Errorf("wrong username: got %q, want %q", meta.User(), expectedUsername) + } + if bytes.Equal(key.Marshal(), authorizedKey.Marshal()) { + return nil, nil + } + return nil, fmt.Errorf("unauthorized key") + } + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("startTestServer: listen: %v", err) + } + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return // listener closed — server done + } + go serveConn(conn, cfg, handler) + } + }() + + return ln.Addr().String() +} + +func serveConn(conn net.Conn, cfg *ssh.ServerConfig, handler cmdHandlerFunc) { + sshConn, chans, reqs, err := ssh.NewServerConn(conn, cfg) + if err != nil { + return // auth failure or protocol error — nothing to do + } + defer sshConn.Close() + go ssh.DiscardRequests(reqs) + + for newChan := range chans { + if newChan.ChannelType() != "session" { + newChan.Reject(ssh.UnknownChannelType, "unsupported channel type") + continue + } + ch, requests, err := newChan.Accept() + if err != nil { + return + } + go serveSession(ch, requests, handler) + } +} + +func serveSession(ch ssh.Channel, reqs <-chan *ssh.Request, handler cmdHandlerFunc) { + // defer ch.Close() is the key: when serveSession returns after handling the + // exec request, the channel is closed. This causes the client-side + // s.wait goroutine's "for msg := range reqs" loop to exit, which populates + // s.exitStatus and unblocks session.Wait() — preventing a deadlock. + defer ch.Close() + + for req := range reqs { + switch req.Type { + case "exec": + if len(req.Payload) < 4 { + req.Reply(false, nil) + continue + } + n := binary.BigEndian.Uint32(req.Payload[:4]) + if uint32(len(req.Payload)) < 4+n { + req.Reply(false, nil) + continue + } + cmd := string(req.Payload[4 : 4+n]) + req.Reply(true, nil) + + // Call the handler synchronously. For normal commands this returns + // quickly. For the context-cancellation test the handler blocks on + // a channel until t.Cleanup fires — that is fine because by the + // time the blocking handler is running, executeCommandString has + // already returned ctx.Err() to the test goroutine. + stdout, exitCode := handler(cmd) + if stdout != "" { + ch.Write([]byte(stdout)) //nolint:errcheck + } + exitStatus := make([]byte, 4) + binary.BigEndian.PutUint32(exitStatus, exitCode) + ch.SendRequest("exit-status", false, exitStatus) //nolint:errcheck + + // Return so that defer ch.Close() fires immediately, closing the + // SSH channel and unblocking the client's session.Wait(). + return + + default: + // Covers signal requests (e.g. SIGKILL on context cancel) and any + // other channel requests the client may send while the handler is + // running or between commands. + if req.WantReply { + req.Reply(false, nil) + } + } + } +} diff --git a/internal/device/ssh/sshdevice.go b/internal/device/ssh/sshdevice.go new file mode 100644 index 0000000..784682d --- /dev/null +++ b/internal/device/ssh/sshdevice.go @@ -0,0 +1,275 @@ +package ssh + +import ( + "bytes" + "context" + "corteca/internal/configuration" + "corteca/internal/device" + "corteca/internal/tui" + "errors" + "fmt" + "io" + "net" + "net/url" + "os" + "strings" + "sync" + "time" + + "golang.org/x/crypto/ssh" + stdssh "golang.org/x/crypto/ssh" +) + +// syncWriter wraps an io.Writer with a mutex so that concurrent writes from +// the SSH session's stdout and stderr goroutines are serialised safely. +type syncWriter struct { + mu sync.Mutex + w io.Writer +} + +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + +const ( + deactivateQuaggaCmd = "sed -i 's#/usr/bin/vtysh#/bin/ash#' /etc/passwd" + maxNumRetries = 3 + defaultSSHPort = "22" + + authSSHPassword = "password" + authSSHPublicKey = "publicKey" + cmdCPUArch = "uname -m" +) + +type SSHDevice struct { + client *stdssh.Client + log *syncWriter +} + +func init() { + device.RegisterDeviceType("ssh", NewSSHDevice) +} + +func NewSSHDevice(c *configuration.DeviceConfig, log io.Writer) (device.Device, error) { + d := SSHDevice{ + log: &syncWriter{w: log}, + } + var ( + err error + sshconfig configuration.SSHClientEndpoint + ) + c.Decode(&sshconfig) + + if err = d.connectSSHClient(&sshconfig); err != nil { + return nil, err + } + + passwd2 := sshconfig.Password2.String() + + if active, err := hasQuagga(d.client); err != nil { + return nil, err + } else if active && len(passwd2) > 0 { + if err := deactivateQuagga(d.client, passwd2); err != nil { + return nil, err + } + d.client.Close() + tui.LogNormal("Need to reconnect...") + return NewSSHDevice(c, log) + } + return &d, nil +} + +func (d *SSHDevice) BeginSequence() error { + return nil +} + +func (d *SSHDevice) ExecuteCommand(ctx context.Context, cmd *configuration.SequenceCmd) (any, error) { + // interpret cmd.params as SSHParams (i.e. array of strings) + var params struct { + Params []configuration.TemplateField `yaml:"params"` + } + if err := cmd.Decode(¶ms); err != nil { + return nil, fmt.Errorf("incompatible command parameters specified; array of strings expected") + } + + // render params in case they use templates + paramsRendered := make([]string, len(params.Params)) + for i := 0; i < len(params.Params); i++ { + paramsRendered[i] = params.Params[i].String() + } + + // concatenate into a single string + cmdString := fmt.Sprintf("%s %s", cmd.Cmd.String(), strings.Join(paramsRendered, " ")) + + return d.executeCommandString(ctx, cmdString) +} + +func (d *SSHDevice) executeCommandString(ctx context.Context, cmd string) (any, error) { + session, err := d.client.NewSession() + if err != nil { + return nil, fmt.Errorf("cannot start SSH command session: %w", err) + } + defer session.Close() + output := bytes.NewBuffer(make([]byte, 0, 512)) + session.Stdout = io.MultiWriter(d.log, output) + session.Stderr = d.log + + // start command + if err = session.Start(cmd); err != nil { + return nil, err + } + + // create channel to synchronize + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case err := <-done: + var exitError *stdssh.ExitError + if errors.As(err, &exitError) { + return output, fmt.Errorf("exit code (%d)", exitError.ExitStatus()) + } else { + return nil, err + } + + case <-ctx.Done(): + _ = session.Signal(ssh.SIGKILL) + return nil, ctx.Err() + } +} + +func (d *SSHDevice) EndSequence() error { + return nil +} + +func (d *SSHDevice) GetProtocol() string { + return "ssh" +} + +func (d *SSHDevice) connectSSHClient(sshconfig *configuration.SSHClientEndpoint) error { + u, err := url.Parse(sshconfig.Addr.String()) + if err != nil { + return err + } + if u.Port() == "" { + u.Host = net.JoinHostPort(u.Host, defaultSSHPort) + } + + // Determine username: explicit Username field takes priority over the URL. + username := u.User.Username() + if explicitUser := sshconfig.Username.String(); len(explicitUser) > 0 { + username = explicitUser + } + + config := &stdssh.ClientConfig{ + User: username, + HostKeyCallback: stdssh.InsecureIgnoreHostKey(), // TODO: Replace with secure method + Auth: make([]stdssh.AuthMethod, 0, 2), + } + + // add keyfile, if present + keyPath := sshconfig.PrivateKeyFile.String() + if len(keyPath) > 0 { + key, err := os.ReadFile(keyPath) + if err != nil { + return fmt.Errorf("cannot read private key file %s: %w", keyPath, err) + } + signer, err := stdssh.ParsePrivateKey(key) + if err != nil { + return fmt.Errorf("cannot parse private key file %s: %w", keyPath, err) + } + config.Auth = append(config.Auth, stdssh.PublicKeys(signer)) + } + + // add password, if present + passwd := sshconfig.Password.String() + passwdPresent := len(passwd) > 0 + if passwdPresent { + config.Auth = append(config.Auth, stdssh.Password(passwd)) + } else if passwd, passwdPresent = u.User.Password(); passwdPresent { + config.Auth = append(config.Auth, stdssh.Password(passwd)) + } + + // add prompt for password if no other methods exist + if len(config.Auth) == 0 { + // FIXME: + // the below results in always asking for a password even if the SSH server is not asking for one + // should use something like: config.Auth = append(config.Auth, stdssh.KeyboardInteractive(...)) + passwd, err := tui.PromptForPassword(fmt.Sprintf("%s@%s's password", u.User.Username(), u.Host)) + if err != nil { + return err + } + config.Auth = append(config.Auth, stdssh.Password(passwd)) + } + + // connect to stdssh server + d.client, err = stdssh.Dial("tcp", u.Host, config) + if err != nil { + fmt.Fprintf(d.log, "\n=== New connection to %s at %s ===\n", u.Host, time.Now().Format(time.DateTime)) + } + return err +} + +func hasQuagga(client *stdssh.Client) (bool, error) { + session, err := client.NewSession() + if err != nil { + return false, err + } + defer session.Close() + + var output bytes.Buffer + session.Stdout = &output + + if err := session.Run("ps | grep ash"); err != nil { + return true, nil + } + + return !strings.Contains(output.String(), "ash"), nil +} + +func deactivateQuagga(client *stdssh.Client, password2 string) error { + if password2 == "" { + var err error + password2, err = tui.PromptForPassword("Enter Password2") + if err != nil { + return err + } + } + + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + if err := session.RequestPty("xterm", 80, 40, stdssh.TerminalModes{}); err != nil { + return err + } + + stdin, err := session.StdinPipe() + if err != nil { + return err + } + + if err := session.Shell(); err != nil { + return err + } + + commands := []string{"shell", password2, deactivateQuaggaCmd} + for _, cmd := range commands { + if _, err := stdin.Write([]byte(cmd + "\n")); err != nil { + return fmt.Errorf("failed to run command %q: %w", cmd, err) + } + time.Sleep(1 * time.Second) + } + + return nil +} + +func (d *SSHDevice) Close() { + d.client.Close() +} diff --git a/internal/device/ssh/sshdevice_test.go b/internal/device/ssh/sshdevice_test.go new file mode 100644 index 0000000..a0f0b22 --- /dev/null +++ b/internal/device/ssh/sshdevice_test.go @@ -0,0 +1,386 @@ +// Copyright 2024 Nokia +// Licensed under the BSD 3-Clause License. +// SPDX-License-Identifier: BSD-3-Clause + +package ssh_test + +import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "os" + "strings" + "testing" + "time" + + "corteca/internal/configuration" + devssh "corteca/internal/device/ssh" + + "golang.org/x/crypto/ssh" + "gopkg.in/yaml.v3" +) + +// ============================================================================= +// Test helpers +// ============================================================================= + +const testPassword = "s3cr3t-test-password" + +// mustDeviceConfig unmarshals yamlStr into a *configuration.DeviceConfig, +// ensuring the internal raw yaml.Node is populated (required by DeviceConfig.Decode). +func mustDeviceConfig(t *testing.T, yamlStr string) *configuration.DeviceConfig { + t.Helper() + var cfg configuration.DeviceConfig + if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil { + t.Fatalf("mustDeviceConfig: %v", err) + } + return &cfg +} + +// mustSequenceCmd unmarshals yamlStr into a *configuration.SequenceCmd, +// ensuring the internal raw yaml.Node is populated (required by SequenceCmd.Decode, +// which is called inside SSHDevice.ExecuteCommand to decode the params field). +func mustSequenceCmd(t *testing.T, yamlStr string) *configuration.SequenceCmd { + t.Helper() + var cmd configuration.SequenceCmd + if err := yaml.Unmarshal([]byte(yamlStr), &cmd); err != nil { + t.Fatalf("mustSequenceCmd: %v", err) + } + return &cmd +} + +// ============================================================================= +// Authentication tests +// ============================================================================= + +// TestSSHDevice_PasswordAuth verifies that the correct username and password +// are sent to the server under all four combinations of URL-embedded and +// explicit-field credentials. The priority rule is: explicit Username/Password +// fields in the device config take precedence over credentials embedded in the +// addr URL. +func TestSSHDevice_PasswordAuth(t *testing.T) { + tests := []struct { + name string + expectedUser string + expectedPass string + buildCfg func(addr string) string + }{ + { + // Baseline: both username and password come from the addr URL. + name: "url_credentials_only", + expectedUser: "url-user", + expectedPass: "url-pass", + buildCfg: func(addr string) string { + return fmt.Sprintf("addr: ssh://url-user:url-pass@%s", addr) + }, + }, + { + // Both explicit fields are set; they must override the URL credentials. + name: "explicit_fields_override_url", + expectedUser: "explicit-user", + expectedPass: "explicit-pass", + buildCfg: func(addr string) string { + return fmt.Sprintf( + "addr: ssh://url-user:url-pass@%s\nusername: explicit-user\npassword: explicit-pass\n", + addr, + ) + }, + }, + { + // Only the explicit username field is set; it must override the URL + // username while the password still comes from the URL. + name: "explicit_username_overrides_url", + expectedUser: "explicit-user", + expectedPass: "url-pass", + buildCfg: func(addr string) string { + return fmt.Sprintf( + "addr: ssh://url-user:url-pass@%s\nusername: explicit-user\n", + addr, + ) + }, + }, + { + // Only the explicit password field is set; it must override the URL + // password while the username still comes from the URL. + name: "explicit_password_overrides_url", + expectedUser: "url-user", + expectedPass: "explicit-pass", + buildCfg: func(addr string) string { + return fmt.Sprintf( + "addr: ssh://url-user:url-pass@%s\npassword: explicit-pass\n", + addr, + ) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addr := startTestServer(t, tc.expectedUser, tc.expectedPass, nil, + withQuaggaProbe(func(cmd string) (string, uint32) { + return "", 0 + }), + ) + + cfg := mustDeviceConfig(t, tc.buildCfg(addr)) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("expected successful connection, got: %v", err) + } + dev.Close() + }) + } +} + +// TestSSHDevice_WrongPassword_ReturnsError verifies that NewSSHDevice returns an +// error when the supplied password is rejected by the server. +func TestSSHDevice_WrongPassword_ReturnsError(t *testing.T) { + addr := startTestServer(t, "", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + return "", 0 + })) + + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", "wrong-password", addr)) + _, err := devssh.NewSSHDevice(cfg, io.Discard) + if err == nil { + t.Fatal("expected error with wrong password, got nil") + } +} + +// TestSSHDevice_PublicKeyAuth verifies that NewSSHDevice connects successfully +// when a valid ECDSA private key file is provided and the server accepts the +// corresponding public key. +func TestSSHDevice_PublicKeyAuth(t *testing.T) { + // Generate a fresh ECDSA P-256 key pair for this test. + clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate client key: %v", err) + } + sshPubKey, err := ssh.NewPublicKey(&clientKey.PublicKey) + if err != nil { + t.Fatalf("create SSH public key: %v", err) + } + + // Write the private key to a temporary PEM file so that connectSSHClient + // can read it via os.ReadFile / ssh.ParsePrivateKey. + keyDER, err := x509.MarshalECPrivateKey(clientKey) + if err != nil { + t.Fatalf("marshal EC private key: %v", err) + } + keyFile, err := os.CreateTemp(t.TempDir(), "test-key-*.pem") + if err != nil { + t.Fatalf("create temp key file: %v", err) + } + if err := pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { + t.Fatalf("write PEM key: %v", err) + } + keyFile.Close() + + // Start a server that only accepts the generated public key (no password). + addr := startTestServer(t, "testuser", "", sshPubKey, withQuaggaProbe(func(cmd string) (string, uint32) { + return "", 0 + })) + + cfg := mustDeviceConfig(t, fmt.Sprintf( + "addr: ssh://testuser@%s\nprivateKeyFile: %s\n", + addr, keyFile.Name(), + )) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("expected successful connection with public key, got: %v", err) + } + dev.Close() +} + +// ============================================================================= +// Protocol / lifecycle tests +// ============================================================================= + +// TestSSHDevice_GetProtocol verifies that GetProtocol returns "ssh". +func TestSSHDevice_GetProtocol(t *testing.T) { + addr := startTestServer(t, "testuser", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + return "", 0 + })) + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", testPassword, addr)) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer dev.Close() + + if got := dev.GetProtocol(); got != "ssh" { + t.Errorf("GetProtocol: expected %q, got %q", "ssh", got) + } +} + +// TestSSHDevice_BeginAndEndSequence verifies that both BeginSequence and +// EndSequence are no-ops that return nil. +func TestSSHDevice_BeginAndEndSequence(t *testing.T) { + addr := startTestServer(t, "testuser", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + return "", 0 + })) + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", testPassword, addr)) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer dev.Close() + + if err := dev.BeginSequence(); err != nil { + t.Errorf("BeginSequence: expected nil, got %v", err) + } + if err := dev.EndSequence(); err != nil { + t.Errorf("EndSequence: expected nil, got %v", err) + } +} + +// ============================================================================= +// ExecuteCommand tests +// ============================================================================= + +// TestSSHDevice_ExecuteCommand_OutputCaptured verifies that stdout produced by +// the remote command is written to the log writer supplied to NewSSHDevice. +func TestSSHDevice_ExecuteCommand_OutputCaptured(t *testing.T) { + const want = "hello from server" + + addr := startTestServer(t, "testuser", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + return want + "\n", 0 + })) + + var logBuf bytes.Buffer + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", testPassword, addr)) + dev, err := devssh.NewSSHDevice(cfg, &logBuf) + if err != nil { + t.Fatalf("unexpected error creating device: %v", err) + } + defer dev.Close() + + cmd := mustSequenceCmd(t, "cmd: echo-test") + if _, err := dev.ExecuteCommand(context.Background(), cmd); err != nil { + t.Fatalf("unexpected error executing command: %v", err) + } + + if !strings.Contains(logBuf.String(), want) { + t.Errorf("log writer: expected to contain %q, got %q", want, logBuf.String()) + } +} + +// TestSSHDevice_ExecuteCommand_ParamsConcatenated verifies that the Cmd field +// and the Params array are joined into a single space-separated string before +// being sent to the server. +func TestSSHDevice_ExecuteCommand_ParamsConcatenated(t *testing.T) { + received := make(chan string, 1) + + addr := startTestServer(t, "testuser", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + received <- strings.TrimSpace(cmd) + return "", 0 + })) + + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", testPassword, addr)) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("unexpected error creating device: %v", err) + } + defer dev.Close() + + cmd := mustSequenceCmd(t, ` +cmd: echo +params: + - hello + - world +`) + if _, err := dev.ExecuteCommand(context.Background(), cmd); err != nil { + t.Fatalf("unexpected error executing command: %v", err) + } + + select { + case got := <-received: + const want = "echo hello world" + if got != want { + t.Errorf("server received %q; want %q", got, want) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server to receive command") + } +} + +// TestSSHDevice_ExecuteCommand_ExitError verifies that a non-zero exit code from +// the remote command causes ExecuteCommand to return a non-nil error whose +// message includes the exit code, and that the captured stdout is returned as +// the result value. +func TestSSHDevice_ExecuteCommand_ExitError(t *testing.T) { + const cmdOutput = "something went wrong\n" + + addr := startTestServer(t, "testuser", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + return cmdOutput, 1 + })) + + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", testPassword, addr)) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("unexpected error creating device: %v", err) + } + defer dev.Close() + + cmd := mustSequenceCmd(t, "cmd: failing-cmd") + result, err := dev.ExecuteCommand(context.Background(), cmd) + + if err == nil { + t.Fatal("expected error for non-zero exit code, got nil") + } + if !strings.Contains(err.Error(), "exit code (1)") { + t.Errorf("error %q does not mention exit code 1", err.Error()) + } + if result == nil { + t.Error("expected non-nil result buffer when command fails with an exit error") + } +} + +// TestSSHDevice_ExecuteCommand_ContextCancellation verifies that cancelling the +// context while a command is in progress causes ExecuteCommand to return +// context.Canceled promptly. +func TestSSHDevice_ExecuteCommand_ContextCancellation(t *testing.T) { + // unblock is closed by t.Cleanup to release the blocking handler goroutine + // after the test has finished, preventing a goroutine leak. + unblock := make(chan struct{}) + t.Cleanup(func() { close(unblock) }) + + addr := startTestServer(t, "testuser", testPassword, nil, withQuaggaProbe(func(cmd string) (string, uint32) { + <-unblock // block until the test is done + return "", 1 + })) + + cfg := mustDeviceConfig(t, fmt.Sprintf("addr: ssh://testuser:%s@%s", testPassword, addr)) + dev, err := devssh.NewSSHDevice(cfg, io.Discard) + if err != nil { + t.Fatalf("unexpected error creating device: %v", err) + } + defer dev.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + cmd := mustSequenceCmd(t, "cmd: sleep-forever") + errCh := make(chan error, 1) + go func() { + _, err := dev.ExecuteCommand(ctx, cmd) + errCh <- err + }() + + // Give the command a moment to reach the server before cancelling. + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case err := <-errCh: + if err != context.Canceled { + t.Errorf("expected context.Canceled, got: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for ExecuteCommand to return after context cancellation") + } +} diff --git a/internal/device/sshdevice.go b/internal/device/sshdevice.go deleted file mode 100644 index e2a4f93..0000000 --- a/internal/device/sshdevice.go +++ /dev/null @@ -1,205 +0,0 @@ -package device - -import ( - "bytes" - "corteca/internal/configuration" - "corteca/internal/dispatcher" - "corteca/internal/tui" - "fmt" - "net/url" - "os" - "strings" - "time" - - "golang.org/x/crypto/ssh" -) - -const ( - deactivateQuaggaCmd = "sed -i 's#/usr/bin/vtysh#/bin/ash#' /etc/passwd" - maxNumRetries = 3 - defaultSSHPort = 22 - - authSSHPassword = "password" - authSSHPublicKey = "publicKey" -) - -type SSHClient interface { - NewSession() (*ssh.Session, error) - Close() error -} - -type SSHDevice struct { - client SSHClient - urlInfo *url.URL - auth string - password2 string - keyFile string - token string - log *Logger -} - -func NewSSHDevice(endpoint configuration.Endpoint, logfile string) (*SSHDevice, error) { - log := &Logger{} - if logfile != "" { - log.SetLogFile(logfile) - } - - u, err := url.Parse(endpoint.Addr.String()) - if err != nil { - return nil, err - } - - if u.Port() == "" { - u.Host = u.Host + fmt.Sprintf(":%v", defaultSSHPort) - } - - return &SSHDevice{ - urlInfo: u, - log: log, - auth: endpoint.Auth, - password2: endpoint.Password2.String(), - token: endpoint.Token.String(), - keyFile: endpoint.PrivateKeyFile.String(), - }, nil -} - -func (d *SSHDevice) GetProtocol() int { - return ConnectionSSH -} - -func (d *SSHDevice) Connect() (dispatcher.Dispatcher, error) { - - d.log.LogFile.WriteString(fmt.Sprintf("\n=== New connection to %s on %s ===\n", d.urlInfo.Host, time.Now().Format(time.DateTime))) - - sshConfig, err := d.buildSSHConfig(d.urlInfo) - if err != nil { - return nil, err - } - - if err := d.connectClient(d.urlInfo.Host, sshConfig); err != nil { - return nil, err - } - - if active, err := hasQuagga(d.client); err != nil { - return nil, err - } else if active { - if err := deactivateQuagga(d.client, d.password2); err != nil { - return nil, err - } - d.client.Close() - if err := d.connectClient(d.urlInfo.Host, sshConfig); err != nil { - return nil, err - } - } - - return dispatcher.NewSSHDispatcher(d.client.(*ssh.Client)), nil -} - -func (d *SSHDevice) buildSSHConfig(u *url.URL) (*ssh.ClientConfig, error) { - config := &ssh.ClientConfig{ - User: u.User.Username(), - HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: Replace with secure method - } - - if passwd, ok := u.User.Password(); ok { - config.Auth = []ssh.AuthMethod{ssh.Password(passwd)} - return config, nil - } - - switch d.auth { - case authSSHPassword: - password, err := tui.PromptForPassword(fmt.Sprintf("%s@%s's password", u.User.Username(), u.Host)) - if err != nil { - return nil, err - } - config.Auth = []ssh.AuthMethod{ssh.Password(password)} - case authSSHPublicKey: - keyPath := d.keyFile - if keyPath == "" { - return nil, fmt.Errorf("missing private key file for public key authentication") - } - key, err := os.ReadFile(keyPath) - if err != nil { - return nil, err - } - signer, err := ssh.ParsePrivateKey(key) - if err != nil { - return nil, err - } - config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)} - } - - return config, nil -} - -func (d *SSHDevice) connectClient(host string, config *ssh.ClientConfig) error { - client, err := ssh.Dial("tcp", host, config) - if err != nil { - return err - } - d.client = client - return nil -} - -func hasQuagga(client SSHClient) (bool, error) { - session, err := client.NewSession() - if err != nil { - return false, err - } - defer session.Close() - - var output bytes.Buffer - session.Stdout = &output - - if err := session.Run("ps | grep ash"); err != nil { - return true, nil - } - - return !strings.Contains(output.String(), "ash"), nil -} - -func deactivateQuagga(client SSHClient, password2 string) error { - if password2 == "" { - var err error - password2, err = tui.PromptForPassword("Enter Password2") - if err != nil { - return err - } - } - - session, err := client.NewSession() - if err != nil { - return err - } - defer session.Close() - - if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil { - return err - } - - stdin, err := session.StdinPipe() - if err != nil { - return err - } - - if err := session.Shell(); err != nil { - return err - } - - commands := []string{"shell", password2, deactivateQuaggaCmd} - for _, cmd := range commands { - if _, err := stdin.Write([]byte(cmd + "\n")); err != nil { - return fmt.Errorf("failed to run command %q: %w", cmd, err) - } - time.Sleep(1 * time.Second) - } - - return nil -} - -func (d *SSHDevice) Close() { - if d.log != nil && d.log.LogFile != nil { - d.log.LogFile.Sync() - d.log.LogFile.Close() - } -} diff --git a/internal/device/sshdevice_test.go b/internal/device/sshdevice_test.go deleted file mode 100644 index 6c54791..0000000 --- a/internal/device/sshdevice_test.go +++ /dev/null @@ -1,78 +0,0 @@ - -package device - -import ( - "io" - "testing" - - "corteca/internal/configuration" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/crypto/ssh" -) - -// --- Mock Interfaces --- - -type MockSSHClient struct { - mock.Mock -} - -func (m *MockSSHClient) NewSession() (*ssh.Session, error) { - args := m.Called() - return args.Get(0).(*ssh.Session), args.Error(1) -} - -func (m *MockSSHClient) Close() error { - return m.Called().Error(0) -} - -type MockSSHSession struct { - mock.Mock -} - -func (m *MockSSHSession) Run(cmd string) error { - return m.Called(cmd).Error(0) -} - -func (m *MockSSHSession) StdinPipe() (io.WriteCloser, error) { - args := m.Called() - return args.Get(0).(io.WriteCloser), args.Error(1) -} - -func (m *MockSSHSession) Shell() error { - return m.Called().Error(0) -} - -func (m *MockSSHSession) Close() error { - return m.Called().Error(0) -} - -func (m *MockSSHSession) RequestPty(term string, h, w int, modes ssh.TerminalModes) error { - return m.Called(term, h, w, modes).Error(0) -} - -// --- Tests --- - -func TestNewSSHDevice(t *testing.T) { - endpoint := configuration.Endpoint{ - Addr: configuration.TemplateField{RawTemplate: "ssh://user:pass@localhost:22"}, - Auth: "password", - Username: configuration.TemplateField{RawTemplate: "user"}, - Password: configuration.TemplateField{RawTemplate: "password"}, - Password2: configuration.TemplateField{RawTemplate: "password2"}, - PrivateKeyFile: configuration.TemplateField{RawTemplate: "/path/to/key"}, - } - device, err := NewSSHDevice(endpoint, "test.log") - assert.NoError(t, err) - assert.NotNil(t, device) - assert.Equal(t, endpoint.Addr.String(), device.urlInfo.String()) - assert.Equal(t, endpoint.Auth, device.auth) - assert.Equal(t, endpoint.Token.String(), device.token) - assert.Equal(t, endpoint.Password2.String(), device.password2) - assert.Equal(t, endpoint.PrivateKeyFile.String(), device.keyFile) -} - -func TestGetProtocol(t *testing.T) { - device := &SSHDevice{} - assert.Equal(t, ConnectionSSH, device.GetProtocol()) -} diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go deleted file mode 100644 index 157d338..0000000 --- a/internal/dispatcher/dispatcher.go +++ /dev/null @@ -1,151 +0,0 @@ -package dispatcher - -import ( - "bytes" - "corteca/internal/cwmp/messages" - "corteca/internal/cwmp/models" - "encoding/json" - "fmt" - "io" - "strings" - - "golang.org/x/crypto/ssh" -) - -type Dispatcher interface { - ExecuteCommand(any) (string, error) - SetPrintFormat(string) -} - -type SSHDispatcher struct { - client *ssh.Client -} - -func NewSSHDispatcher(client *ssh.Client) *SSHDispatcher { - return &SSHDispatcher{client: client} -} - -func (ssh_dispatcher *SSHDispatcher) SetPrintFormat(format string) { -} - -func (ssd_dispatcher *SSHDispatcher) ExecuteCommand(cmd any) (string, error) { - session, err := ssd_dispatcher.client.NewSession() - if err != nil { - return "", nil - } - defer session.Close() - - var outBuff bytes.Buffer - mwOut := io.MultiWriter(&outBuff) - session.Stdout = mwOut - session.Stderr = mwOut - - err = session.Run(cmd.(string)) - - if err != nil { - if status, ok := err.(*ssh.ExitError); ok { - return outBuff.String(), fmt.Errorf("exit code (%v)", status.ExitStatus()) - } else { - return outBuff.String(), err - } - } else { - return outBuff.String(), nil - } -} - -type CWMPDispatcher struct { - taskChannel chan messages.Message - resultChannel chan *models.ResultsMessage - printFormat string -} - -func NewCWMPDispatcher(taskChan chan messages.Message, resultChannel chan *models.ResultsMessage) *CWMPDispatcher { - return &CWMPDispatcher{taskChannel: taskChan, resultChannel: resultChannel} -} - -func (d *CWMPDispatcher) SetPrintFormat(format string) { - d.printFormat = format -} - -func formatMessageOutput(result *models.ResultsMessage, printFormat string) (string, error) { - var output string - - if result.Message == nil { - return "", fmt.Errorf("empty message") - } - - msg := result.Message - - switch msg.GetName() { - case "Inform": - case "GetParameterNamesResponse": - var builder strings.Builder - if printFormat == "json" { - parameterValuesList, err := json.Marshal(msg.(*messages.GetParameterNamesResponse).ParameterList.Parameters) - if err != nil { - return "", err - } - - output = string(parameterValuesList) - } else { - for _, parameter := range msg.(*messages.GetParameterNamesResponse).ParameterList.Parameters { - builder.WriteString(fmt.Sprintf("- Name: %-s\n", parameter.Name)) - builder.WriteString(fmt.Sprintf(" Writable: %-v\n", parameter.Writable)) - } - output = builder.String() - } - case "GetParameterValuesResponse": - var resultStr strings.Builder - - if printFormat == "json" { - parameterValuesList, err := json.Marshal(msg.(*messages.GetParameterValuesResponse).ParameterList) - if err != nil { - return "", err - } - output = string(parameterValuesList) - } else { - resultStr.WriteString("\n************** Parameter(s) Value(s) **************\n") - for _, parameter := range msg.(*messages.GetParameterValuesResponse).ParameterList { - resultStr.WriteString(fmt.Sprintf("%s: %s\n", parameter.Name, parameter.Value)) - } - resultStr.WriteString("***************************************************\n") - output = resultStr.String() - } - case "SetParameterValuesResponse": - if result.Code == 0 { - output = "All parameters changes have been validated and applied" - } else { - output = "All Parameter changes have been validated and committed, but some or all are not yet applied (e.g A reboot is required before the new values are applied)" - } - case "ChangeDUStateResponse": - case "DUStateChangeComplete": - ducomplete := msg.(*messages.DUStateChangeComplete) - output = ducomplete.Fault.FaultString - case "Fault": - output = result.Message.(*messages.Fault).MsgFaultString - default: - output = "internal error" - } - - return output, nil -} - -func (d *CWMPDispatcher) ExecuteCommand(cmd any) (string, error) { - task, ok := cmd.(messages.Message) - - if ok { - //Send task to cwmp server - d.taskChannel <- task - //wait for results - result := <-d.resultChannel - - output , err := formatMessageOutput(result, d.printFormat) - - if result.Code != 0 || err != nil { - return "", fmt.Errorf("task \"%s\" with error \"%s\" (cd: %d)", task.GetName(), output, result.Code) - } - return output, nil - } else { - return "", fmt.Errorf("cmd is not a valid cwmp task") - } -} diff --git a/internal/packager/packager.go b/internal/packager/packager.go index 69f9403..78a2e8c 100644 --- a/internal/packager/packager.go +++ b/internal/packager/packager.go @@ -38,7 +38,7 @@ func AnnotateRootFS(dest string, appSettings configuration.AppSettings, buildMet func PackageOCI(buildDir, distPath, arch, platform, rootfsTarGzPath string, appSettings configuration.AppSettings) error { ociDirName := fmt.Sprintf("%s-%s-%s-oci", appSettings.Name, appSettings.Version, arch) - ociTarName := fmt.Sprintf("%s-%s-%s-oci.tar", appSettings.Name, appSettings.Version, arch) + ociTarName := fmt.Sprintf("%s-%s-%s-oci.tar.gz", appSettings.Name, appSettings.Version, arch) ociDirPath := filepath.Join(buildDir, ociDirName) ociTarPath := filepath.Join(distPath, ociTarName) diff --git a/internal/publish/push.go b/internal/publish/push.go index 1d747cf..63d199e 100644 --- a/internal/publish/push.go +++ b/internal/publish/push.go @@ -1,13 +1,16 @@ package publish import ( + "compress/gzip" + "corteca/internal/configuration" "corteca/internal/fsutil" "corteca/internal/tui" "crypto/tls" "fmt" + "io" "net/http" "net/url" - "path/filepath" + "os" "strings" "github.com/google/go-containerregistry/pkg/authn" @@ -15,110 +18,113 @@ import ( v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/layout" "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/pterm/pterm" + "github.com/google/go-containerregistry/pkg/v1/tarball" ) -func PushImage(imagePath string, addr *url.URL, token string, withProgress bool) error { - distDir := filepath.Dir(imagePath) - extractedImagePath := strings.TrimSuffix(imagePath, ".tar") - extractedOCIName := filepath.Base(extractedImagePath) +type gzipReadCloser struct { + *gzip.Reader + file *os.File +} + +func GzipOpener(path string) tarball.Opener { + return func() (io.ReadCloser, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + gz, err := gzip.NewReader(f) + if err != nil { + f.Close() + return nil, err + } - if err := fsutil.ExtractTarball(imagePath, extractedImagePath); err != nil { - return fmt.Errorf("failed to extract OCI image: %w", err) + // combine both closers + return &gzipReadCloser{ + Reader: gz, + file: f, + }, nil } +} - versionRef, err := name.NewTag(fmt.Sprintf("%s%s", addr.Host, addr.Path)) - if err != nil { - return fmt.Errorf("failed to parse image reference: %w", err) +func (g *gzipReadCloser) Close() error { + if err := g.Reader.Close(); err != nil { + return err } + return g.file.Close() +} - index, err := layout.ImageIndexFromPath(extractedImagePath) +func PushImage(tarballPath string, target *configuration.HttpClientEndpoint, withProgress bool) error { + // create image tag from URL + url, err := url.Parse(target.Addr.String()) if err != nil { - return fmt.Errorf("failed to read image index from path: %w", err) + return err } - - manifest, err := index.IndexManifest() + tagOpts := []name.Option{name.StrictValidation} + if url.Scheme == "http" { + tagOpts = append(tagOpts, name.Insecure) + } + tag, err := name.NewTag(strings.ToLower(url.Host+url.Path), tagOpts...) if err != nil { - return fmt.Errorf("failed to get index manifest: %w", err) + return err } - var img v1.Image - for _, desc := range manifest.Manifests { - image, err := index.Image(desc.Digest) - if err != nil { - return fmt.Errorf("failed to get image: %w", err) - } - img = image - break + // get image from tarball + tmp, err := os.MkdirTemp("", "corteca_image_") + if err != nil { + return fmt.Errorf("cannot create tmp folder: %w", err) } - - transport := remote.WithTransport(&http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }) - - auth, err := getAuthenticator(addr, token) + if err := fsutil.ExtractTarball(tarballPath, tmp); err != nil { + return fmt.Errorf("cannot extract %s to %s: %w", tarballPath, tmp, err) + } + lp, err := layout.FromPath(tmp) if err != nil { - return fmt.Errorf("failed to get authenticator: %w", err) + return fmt.Errorf("cannot open OCI layout from %s: %w", tmp, err) } - - options := []remote.Option{ - remote.WithAuth(auth), - transport, + idx, err := lp.ImageIndex() + if err != nil { + return fmt.Errorf("cannot open index from %s: %w", tmp, err) + } + im, err := idx.IndexManifest() + if err != nil { + return fmt.Errorf("cannot open index manifest from %s: %w", tmp, err) + } + image, err := idx.Image(im.Manifests[0].Digest) + if err != nil { + return fmt.Errorf("cannot open image from %s: %w", tmp, err) } + // set client options + clientOpts := []remote.Option{} + if target.SkipTLSVerification { + clientOpts = append(clientOpts, remote.WithTransport(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + })) + } + switch target.Auth { + case configuration.BasicClientAuth: + clientOpts = append(clientOpts, remote.WithAuth(authn.FromConfig(authn.AuthConfig{Username: target.Username.String(), Password: target.Password.String()}))) + case configuration.BearerClientAuth: + clientOpts = append(clientOpts, remote.WithAuth(authn.FromConfig(authn.AuthConfig{RegistryToken: target.Token.String()}))) + } if withProgress { updates := make(chan v1.Update, 8) - progressBar := initializeProgressBar() - go handleProgressUpdates(progressBar, updates) - options = append(options, remote.WithProgress(updates)) + prog := tui.PromptForProgress(fmt.Sprintf("Pushing %s", tag.String())) + defer close(prog) + clientOpts = append(clientOpts, remote.WithProgress(updates)) + go func() { + for update := range updates { + prog <- tui.ProgressUpdate{Current: update.Complete, Total: update.Total} + } + }() } - if err := remote.Write(versionRef, img, options...); err != nil { + if err := remote.Write(tag, image, clientOpts...); err != nil { return fmt.Errorf("failed to push image manifest to registry: %w", err) } - - if err := fsutil.RemoveFilesFromFolder(distDir, []string{extractedOCIName}); err != nil { - return fmt.Errorf("failed to clean up extracted files: %w", err) - } - tui.DisplaySuccessMsg(fmt.Sprintf("Pushed image '%v' as '%v'\n", imagePath, versionRef.Name())) + tui.DisplaySuccessMsg(fmt.Sprintf("Pushed image %s to %s", tarballPath, tag.String())) return nil } - -func handleProgressUpdates(bar *pterm.ProgressbarPrinter, updates chan v1.Update) { - var lastComplete int64 - var totalSizeSet bool - for update := range updates { - if !totalSizeSet && update.Total > 0 { - bar.Total = int(update.Total) - totalSizeSet = true - } - progress := int(update.Complete - lastComplete) - bar.Add(progress) - lastComplete = update.Complete - //pterm.Debug.Println(fmt.Sprintf("Progress: %d/%d", update.Complete, update.Total)) - } - bar.Stop() -} - -func getAuthenticator(registryURL *url.URL, token string) (authn.Authenticator, error) { - if token != "" { - return &authn.Bearer{ - Token: token, - }, nil - } else { - // registryURL should always include a valid credentials or authentication token - password, _ := registryURL.User.Password() - return authn.FromConfig(authn.AuthConfig{ - Username: registryURL.User.Username(), - Password: password, - }), nil - } -} - -func initializeProgressBar() *pterm.ProgressbarPrinter { - bar, _ := pterm.DefaultProgressbar.WithTotal(100).WithTitle("Pushing").Start() - return bar -} diff --git a/internal/publish/put.go b/internal/publish/put.go index 38816da..d3d7f22 100644 --- a/internal/publish/put.go +++ b/internal/publish/put.go @@ -9,6 +9,7 @@ import ( "corteca/internal/tui" "errors" "fmt" + "io" "net/http" "net/url" "path/filepath" @@ -23,6 +24,36 @@ const ( authHttpDigestName = "digest" ) +type PutReader struct { + file afero.File + ch chan<- tui.ProgressUpdate + total int64 +} + +func (r *PutReader) Read(p []byte) (int, error) { + n, err := r.file.Read(p) + if err != nil { + return n, err + } + pos, err := r.file.Seek(0, io.SeekCurrent) + if err != nil { + return n, err + } + if r.total == 0 { + if r.total, err = r.file.Seek(0, io.SeekEnd); err != nil { + return n, err + } + if _, err = r.file.Seek(pos, io.SeekStart); err != nil { + return n, err + } + } + r.ch <- tui.ProgressUpdate{ + Current: pos, + Total: r.total, + } + return n, err +} + func HttpPut(filePath string, url url.URL, token string) error { if url.Scheme != "http" && url.Scheme != "https" { @@ -39,12 +70,10 @@ func HttpPut(filePath string, url url.URL, token string) error { } defer file.Close() - progressReader, err := tui.PromptForProgress(file, fmt.Sprintf("Uploading %s", fileName)) - if err != nil { - return err - } + prog := tui.PromptForProgress(fmt.Sprintf("Uploading %s", fileName)) + defer close(prog) - req, err := http.NewRequest("PUT", url.String(), progressReader) + req, err := http.NewRequest("PUT", url.String(), &PutReader{file: file, ch: prog}) if err != nil { return err } @@ -70,7 +99,6 @@ func HttpPut(filePath string, url url.URL, token string) error { } defer resp.Body.Close() - progressReader.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { return fmt.Errorf("server returned non-successful status: %s", resp.Status) @@ -81,19 +109,19 @@ func HttpPut(filePath string, url url.URL, token string) error { return nil } -func AuthenticateHttp(endpoint configuration.Endpoint) (*url.URL, error) { +func AuthenticateHttp(config configuration.HttpClientEndpoint) (*url.URL, error) { - u, err := url.Parse(endpoint.Addr.String()) + u, err := url.Parse(config.Addr.String()) if err != nil { return nil, err } - authType := strings.ToLower(endpoint.Auth) + authType := strings.ToLower(config.Auth) switch authType { case authHttpBasicName: - username := endpoint.Username.String() - password := endpoint.Password.String() + username := config.Username.String() + password := config.Password.String() // Check for username in .yaml config if username == "" { @@ -117,7 +145,7 @@ func AuthenticateHttp(endpoint configuration.Endpoint) (*url.URL, error) { u.User = url.UserPassword(username, password) case authHttpBearerName: - if endpoint.Token.String() == "" { + if config.Token.String() == "" { return nil, errors.New("no bearer token present in configuration even though HTTP Bearer authentication has been requested") } case authHttpDigestName: diff --git a/internal/publish/registry.go b/internal/publish/registry.go index fc5549e..cb8c921 100644 --- a/internal/publish/registry.go +++ b/internal/publish/registry.go @@ -1,194 +1,28 @@ package publish import ( - "archive/tar" - "compress/gzip" - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "io" - "math/big" - "net" + "corteca/internal/configuration" + "corteca/internal/tui" "net/http" - "os" - "path/filepath" - "strings" - "time" "github.com/google/go-containerregistry/pkg/registry" - v1 "github.com/google/go-containerregistry/pkg/v1" ) -type artifactBlobHandler struct { - artifactPath string -} - -type readCloserWrapper struct { - reader io.Reader - closeFunc func() error -} - -func (r *readCloserWrapper) Read(p []byte) (n int, err error) { - return r.reader.Read(p) -} - -func (r *readCloserWrapper) Close() error { - return r.closeFunc() -} - -func (a *artifactBlobHandler) Get(ctx context.Context, repo string, hash v1.Hash) (io.ReadCloser, error) { - reader, closer, err := findFileInArtifact(a.artifactPath, hash.String()) - if err != nil { - return nil, err - } - - return &readCloserWrapper{ - reader: reader, - closeFunc: closer, - }, nil -} - -func NewArtifactBlobHandler(artifact string) registry.BlobHandler { - return &artifactBlobHandler{artifactPath: artifact} -} - -func findFileInArtifact(artifactPath, targetName string) (io.Reader, func() error, error) { - f, err := os.Open(artifactPath) - if err != nil { - return nil, nil, fmt.Errorf("failed to open archive: %w", err) - } - - gzf, err := gzip.NewReader(f) - if err != nil { - f.Close() - return nil, nil, fmt.Errorf("failed to create gzip reader: %w", err) - } - - tarReader := tar.NewReader(gzf) - - // Extract the algorithmName/encoded portion from ":" digest-format - algorithmName, targetHash, err := splitDigest(targetName) - if err != nil { - return nil, nil, fmt.Errorf("error reading tar file: %w", err) - } - - for { - header, err := tarReader.Next() - if err == io.EOF { - break - } - if err != nil { - return nil, nil, fmt.Errorf("error reading tar file: %w", err) - } - - // Check if "blobs//" format matches header.Name - if header.Name == filepath.ToSlash(filepath.Join("blobs", algorithmName, targetHash)) { - return tarReader, func() error { - errGZF := gzf.Close() - errF := f.Close() - - if errGZF != nil { - return errGZF - } - return errF - }, nil - } - } - - gzf.Close() - f.Close() - return nil, nil, fmt.Errorf("file not found: %s", targetName) -} - -func splitDigest(targetName string) (string, string, error) { - parts := strings.SplitN(targetName, ":", 2) - if len(parts) != 2 { - return "", "", fmt.Errorf("invalid targetName format: expected :, got %s", targetName) - } - - return parts[0], parts[1], nil -} - -func generateSelfSignedCert() (tls.Certificate, error) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to generate private key: %v", err) - } - - notBefore := time.Now() - notAfter := notBefore.Add(365 * 24 * time.Hour) - - serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to generate serial number: %v", err) - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"test"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to create certificate: %v", err) - } - - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - keyBytes, err := x509.MarshalECPrivateKey(priv) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to marshal private key: %v", err) - } - - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}) - - cert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to load key pair: %v", err) - } - - return cert, nil -} - -func StartRegistry(address, artifact string) (*http.Server, error) { - blobHandler := NewArtifactBlobHandler(artifact) - - handler := registry.New(registry.WithBlobHandler(blobHandler)) - - cert, err := generateSelfSignedCert() - if err != nil { - return nil, fmt.Errorf("failed to generate self-signed cert: %v", err) - } - - server := &http.Server{ - Addr: address, - Handler: handler, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - }, - } - +func StartRegistry(config configuration.HttpServerEndpoint) (*http.Server, error) { + handler := registry.New() + server := &http.Server{Addr: config.Addr.String(), Handler: handler} + certFile := config.Certificate.String() + keyFile := config.Key.String() go func() { - if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { - fmt.Printf("ListenAndServeTLS(): %v", err) + var err error + if len(certFile) > 0 { + err = server.ListenAndServeTLS(certFile, keyFile) + } else { + err = server.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + tui.LogError("Error while running registry server: %s", err.Error()) } }() - - time.Sleep(1 * time.Second) return server, nil } diff --git a/internal/tui/scratch.go b/internal/tui/scratch.go new file mode 100644 index 0000000..bf2ab33 --- /dev/null +++ b/internal/tui/scratch.go @@ -0,0 +1,17 @@ +//go:build ignore + +package main + +import ( + "corteca/internal/tui" + "time" +) + +func main() { + prog := tui.PromptForProgress("Testing bar") + defer close(prog) + for i := 1; i <= 15; i++ { + prog <- tui.ProgressUpdate{Current: int64(i), Total: 15} + time.Sleep(250 * time.Millisecond) + } +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index b7531df..136a0ff 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -11,7 +11,6 @@ import ( "os" "github.com/pterm/pterm" - "github.com/spf13/afero" "golang.org/x/term" ) @@ -49,6 +48,13 @@ func LogNormal(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) } +func LogWarning(format string, args ...any) { + SetOutputColor(CYellow, os.Stderr) + format += "\n" + fmt.Fprintf(os.Stderr, format, args...) + ResetOutputColor(os.Stderr) +} + func LogError(format string, args ...any) { SetOutputColor(CRed, os.Stderr) format += "\n" @@ -91,50 +97,33 @@ func PromptForPassword(label string) (string, error) { return result, err } -func DisplayHelpMsg(msg string) { - pterm.ThemeDefault.InfoMessageStyle.Println(msg) -} - -type ProgressBar struct { - file afero.File - progBar *pterm.ProgressbarPrinter -} - -func (pb *ProgressBar) Read(p []byte) (int, error) { - n, err := pb.file.Read(p) - - if err != nil { - return n, err - } - - pb.progBar.Add(n) - - return n, nil +type ProgressUpdate struct { + Current int64 + Total int64 } -func (pb *ProgressBar) Close() { - pb.progBar.Stop() +func PromptForProgress(label string) chan<- ProgressUpdate { + const max = 100 + ch := make(chan ProgressUpdate, 8) + bar, _ := pterm.DefaultProgressbar. + WithTitle(label). + WithTotal(max). + WithShowCount(false). + WithShowElapsedTime(false). + Start() + go func() { + for update := range ch { + current := int((update.Current * int64(max)) / update.Total) + diff := current - bar.Current + bar.Add(diff) + } + bar.Stop() + }() + return ch } -func PromptForProgress(f afero.File, label string) (*ProgressBar, error) { - - fileInfo, err := f.Stat() - - if err != nil { - return nil, err - } - - pb := ProgressBar{ - file: f, - progBar: pterm.DefaultProgressbar.WithTotal(int(fileInfo.Size())).WithTitle(label).WithMaxWidth(-1).WithCurrent(0), - } - - pb.progBar, err = pb.progBar.Start() - if err != nil { - return nil, err - } - - return &pb, nil +func DisplayHelpMsg(msg string) { + pterm.ThemeDefault.InfoMessageStyle.Println(msg) } func DisplaySuccessMsg(msg string) { @@ -142,12 +131,3 @@ func DisplaySuccessMsg(msg string) { LogNormal(msg) ResetOutputColor(os.Stderr) } - -func DisplayErrorMsg(msg string) { - LogError(msg) -} - -func LogOutData(format string, args ...any) { - format += "\n" - fmt.Fprintf(os.Stdout, format, args...) -}