diff --git a/README.md b/README.md index e4a1e35d..3f773351 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,41 @@ switches are most important to you to have implemented next in the new sqlcmd. - `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` username parameter. - The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. - Sqlcmd can now print results using a vertical format. Use the new `--vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable. +- `:help` displays a list of available sqlcmd commands. +- `:serverlist` lists local SQL Server instances discovered via the SQL Server Browser service (UDP port 1434). The command queries the SQL Browser service and displays the server name and instance name for each discovered instance. If no instances are found or the Browser service is not running, no output is produced. Non-timeout errors are printed to stderr. + +``` +1> :serverlist +MYSERVER\SQL2019 +MYSERVER\SQL2022 +``` + +#### Using :serverlist in batch scripts + +When automating server discovery, you can capture the output and check for errors: + +```batch +@echo off +REM Discover local SQL Server instances and connect to the first one +sqlcmd -Q ":serverlist" 2>nul > servers.txt +if %errorlevel% neq 0 ( + echo Error discovering servers + exit /b 1 +) +for /f "tokens=1" %%s in (servers.txt) do ( + echo Connecting to %%s... + sqlcmd -S %%s -Q "SELECT @@SERVERNAME" + goto :done +) +echo No SQL Server instances found +:done +``` + +To capture stderr separately (for error logging): +```batch +sqlcmd -Q ":serverlist" 2>errors.log > servers.txt +if exist errors.log if not "%%~z errors.log"=="0" type errors.log +``` ``` 1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 7d69b24b..cf142d94 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -5,20 +5,16 @@ package sqlcmd import ( - "context" "errors" "fmt" - "net" "os" "regexp" "runtime/trace" "strconv" "strings" - "time" mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-mssqldb/azuread" - "github.com/microsoft/go-mssqldb/msdsn" "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/pkg/console" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" @@ -236,7 +232,7 @@ func Execute(version string) { fmt.Println() fmt.Println(localizer.Sprintf("Servers:")) } - listLocalServers() + sqlcmd.ListLocalServers(os.Stdout) os.Exit(0) } if len(argss) > 0 { @@ -915,76 +911,3 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.SetError(nil) return s.Exitcode, err } - -func listLocalServers() { - bmsg := []byte{byte(msdsn.BrowserAllInstances)} - resp := make([]byte, 16*1024-1) - dialer := &net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - conn, err := dialer.DialContext(ctx, "udp", ":1434") - // silently ignore failures to connect, same as ODBC - if err != nil { - return - } - defer conn.Close() - dl, _ := ctx.Deadline() - _ = conn.SetDeadline(dl) - _, err = conn.Write(bmsg) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - read, err := conn.Read(resp) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - - data := parseInstances(resp[:read]) - instances := make([]string, 0, len(data)) - for s := range data { - if s == "MSSQLSERVER" { - - instances = append(instances, "(local)", data[s]["ServerName"]) - } else { - instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) - } - } - for _, s := range instances { - fmt.Println(" ", s) - } -} - -func parseInstances(msg []byte) msdsn.BrowserData { - results := msdsn.BrowserData{} - if len(msg) > 3 && msg[0] == 5 { - out_s := string(msg[3:]) - tokens := strings.Split(out_s, ";") - instdict := map[string]string{} - got_name := false - var name string - for _, token := range tokens { - if got_name { - instdict[name] = token - got_name = false - } else { - name = token - if len(name) == 0 { - if len(instdict) == 0 { - break - } - results[strings.ToUpper(instdict["InstanceName"])] = instdict - instdict = map[string]string{} - continue - } - got_name = true - } - } - } - return results -} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..72464494 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -113,6 +113,16 @@ func newCommands() Commands { action: xmlCommand, name: "XML", }, + "HELP": { + regex: regexp.MustCompile(`(?im)^[ \t]*:HELP(?:[ \t]+(.*$)|$)`), + action: helpCommand, + name: "HELP", + }, + "SERVERLIST": { + regex: regexp.MustCompile(`(?im)^[ \t]*:SERVERLIST(?:[ \t]+(.*$)|$)`), + action: serverlistCommand, + name: "SERVERLIST", + }, } } @@ -596,6 +606,65 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error { return nil } +// helpCommand displays the list of available sqlcmd commands +func helpCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("HELP", line) + } + helpText := `:!! [] + - Executes a command in the operating system shell. +:connect server[\instance] [-l timeout] [-U user [-P password]] + - Connects to a SQL Server instance. +:ed + - Edits the current or last executed statement cache. +:error + - Redirects error output to a file, stderr, or stdout. +:exit + - Quits sqlcmd immediately. +:exit() + - Execute statement cache; quit with no return value. +:exit() + - Execute the specified query; returns numeric result. +go [] + - Executes the statement cache (n times). +:help + - Shows this list of commands. +:list + - Prints the content of the statement cache. +:listvar + - Lists the set sqlcmd scripting variables. +:on error [exit|ignore] + - Action for batch or sqlcmd command errors. +:out |stderr|stdout + - Redirects query output to a file, stderr, or stdout. +:quit + - Quits sqlcmd immediately. +:r + - Append file contents to the statement cache. +:reset + - Discards the statement cache. +:serverlist + - Lists local SQL Server instances. +:setvar {variable} + - Removes a sqlcmd scripting variable. +:setvar + - Sets a sqlcmd scripting variable. +:xml [on|off] + - Sets XML output mode. +` + _, err := s.GetOutput().Write([]byte(helpText)) + return err +} + +// serverlistCommand lists locally available SQL Server instances +func serverlistCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("SERVERLIST", line) + } + ListLocalServers(s.GetOutput()) + return nil +} + func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { var b *strings.Builder end := len(arg) diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 56d509da..4abf0706 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -54,6 +54,10 @@ func TestCommandParsing(t *testing.T) { {`:XML ON `, "XML", []string{`ON `}}, {`:RESET`, "RESET", []string{""}}, {`RESET`, "RESET", []string{""}}, + {`:HELP`, "HELP", []string{""}}, + {`:help`, "HELP", []string{""}}, + {`:SERVERLIST`, "SERVERLIST", []string{""}}, + {`:serverlist`, "SERVERLIST", []string{""}}, } for _, test := range commands { @@ -464,3 +468,25 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } + +func TestHelpCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + s.SetOutput(buf) + + err := helpCommand(s, []string{""}, 1) + assert.NoError(t, err, "helpCommand should not error") + + output := buf.buf.String() + // Verify key commands are listed + assert.Contains(t, output, ":connect", "help should list :connect") + assert.Contains(t, output, ":exit", "help should list :exit") + assert.Contains(t, output, ":help", "help should list :help") + assert.Contains(t, output, ":setvar", "help should list :setvar") + assert.Contains(t, output, ":listvar", "help should list :listvar") + assert.Contains(t, output, ":out", "help should list :out") + assert.Contains(t, output, ":error", "help should list :error") + assert.Contains(t, output, ":r", "help should list :r") + assert.Contains(t, output, ":serverlist", "help should list :serverlist") + assert.Contains(t, output, "go", "help should list go") +} diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go new file mode 100644 index 00000000..021933d2 --- /dev/null +++ b/pkg/sqlcmd/serverlist.go @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "sort" + "strings" + "time" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +// ListLocalServers queries the SQL Browser service for available SQL Server instances +// and writes the results to the provided writer. +func ListLocalServers(w io.Writer) { + instances, err := GetLocalServerInstances() + if err != nil { + fmt.Fprintln(os.Stderr, err) + } + for _, s := range instances { + _, _ = fmt.Fprintf(w, " %s\n", s) + } +} + +// GetLocalServerInstances queries the SQL Browser service and returns a list of +// available SQL Server instances on the local machine. +// Returns an error for non-timeout network errors. +func GetLocalServerInstances() ([]string, error) { + bmsg := []byte{byte(msdsn.BrowserAllInstances)} + resp := make([]byte, 16*1024-1) + dialer := &net.Dialer{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + conn, err := dialer.DialContext(ctx, "udp", ":1434") + // silently ignore failures to connect, same as ODBC + if err != nil { + return nil, nil + } + defer func() { _ = conn.Close() }() + dl, _ := ctx.Deadline() + _ = conn.SetDeadline(dl) + _, err = conn.Write(bmsg) + if err != nil { + // Only return error if it's not a timeout + if !errors.Is(err, os.ErrDeadlineExceeded) { + return nil, err + } + return nil, nil + } + read, err := conn.Read(resp) + if err != nil { + // Only return error if it's not a timeout + if !errors.Is(err, os.ErrDeadlineExceeded) { + return nil, err + } + return nil, nil + } + + data := parseInstances(resp[:read]) + instances := make([]string, 0, len(data)) + + // Sort instance names for deterministic output + instanceNames := make([]string, 0, len(data)) + for s := range data { + instanceNames = append(instanceNames, s) + } + sort.Strings(instanceNames) + + for _, s := range instanceNames { + serverName := data[s]["ServerName"] + if serverName == "" { + // Skip instances without a ServerName + continue + } + if s == "MSSQLSERVER" { + instances = append(instances, "(local)", serverName) + } else { + instances = append(instances, fmt.Sprintf(`%s\%s`, serverName, s)) + } + } + return instances, nil +} + +func parseInstances(msg []byte) msdsn.BrowserData { + results := msdsn.BrowserData{} + if len(msg) > 3 && msg[0] == 5 { + outStr := string(msg[3:]) + tokens := strings.Split(outStr, ";") + instanceDict := map[string]string{} + gotName := false + var name string + for _, token := range tokens { + if gotName { + instanceDict[name] = token + gotName = false + } else { + name = token + if len(name) == 0 { + if len(instanceDict) == 0 { + break + } + // Only add if InstanceName key exists and is non-empty + if instName, ok := instanceDict["InstanceName"]; ok && instName != "" { + results[strings.ToUpper(instName)] = instanceDict + } + instanceDict = map[string]string{} + continue + } + gotName = true + } + } + } + return results +} diff --git a/pkg/sqlcmd/serverlist_test.go b/pkg/sqlcmd/serverlist_test.go new file mode 100644 index 00000000..6c50ee72 --- /dev/null +++ b/pkg/sqlcmd/serverlist_test.go @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListLocalServers(t *testing.T) { + // Test that ListLocalServers writes to the provided writer without error + // Note: actual server discovery depends on SQL Browser service availability + var buf bytes.Buffer + ListLocalServers(&buf) + // We can't assert specific content since it depends on environment, + // but we verify it doesn't panic and writes valid output + t.Logf("ListLocalServers output: %q", buf.String()) +} + +func TestGetLocalServerInstances(t *testing.T) { + // Test that GetLocalServerInstances returns a slice (may be empty if no servers) + instances, err := GetLocalServerInstances() + // instances may be nil or empty if no SQL Browser is running, that's OK + // err may be non-nil for non-timeout network errors + if err != nil { + t.Logf("GetLocalServerInstances returned error (expected in some environments): %v", err) + } + t.Logf("Found %d instances", len(instances)) + for _, inst := range instances { + assert.NotEmpty(t, inst, "Instance name should not be empty") + } +} + +func TestParseInstances(t *testing.T) { + // Test parsing of SQL Browser response + // Format: 0x05 (response type), 2 bytes length, then semicolon-separated key=value pairs + // Each instance ends with two semicolons + + t.Run("empty response", func(t *testing.T) { + result := parseInstances([]byte{}) + assert.Empty(t, result) + }) + + t.Run("invalid header", func(t *testing.T) { + result := parseInstances([]byte{1, 0, 0}) + assert.Empty(t, result) + }) + + t.Run("valid single instance", func(t *testing.T) { + // Simulating SQL Browser response format + // Header: 0x05 followed by 2 length bytes, then the instance data + data := []byte{5, 0, 0} + instanceData := "ServerName;MYSERVER;InstanceName;MSSQLSERVER;IsClustered;No;Version;15.0.2000.5;tcp;1433;;" + data = append(data, []byte(instanceData)...) + + result := parseInstances(data) + assert.Len(t, result, 1) + assert.Contains(t, result, "MSSQLSERVER") + assert.Equal(t, "MYSERVER", result["MSSQLSERVER"]["ServerName"]) + assert.Equal(t, "1433", result["MSSQLSERVER"]["tcp"]) + }) + + t.Run("valid multiple instances", func(t *testing.T) { + data := []byte{5, 0, 0} + instanceData := "ServerName;MYSERVER;InstanceName;MSSQLSERVER;tcp;1433;;ServerName;MYSERVER;InstanceName;SQLEXPRESS;tcp;1434;;" + data = append(data, []byte(instanceData)...) + + result := parseInstances(data) + assert.Len(t, result, 2) + assert.Contains(t, result, "MSSQLSERVER") + assert.Contains(t, result, "SQLEXPRESS") + }) +} + +func TestServerlistCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + // Run the serverlist command + c := []string{":serverlist"} + err := runSqlCmd(t, s, c) + + // The command should not raise an error even if no servers are found + assert.NoError(t, err, ":serverlist should not raise error") + // Output may be empty if no SQL Browser is running + t.Logf("Serverlist output: %q", buf.buf.String()) +}