From 18f000174b291b76f28fef064f7c24b532fd2805 Mon Sep 17 00:00:00 2001 From: David Shiflet Date: Fri, 1 Dec 2023 21:42:47 -0600 Subject: [PATCH] implement -L for local instances (#483) --- cmd/sqlcmd/sqlcmd.go | 79 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 77891cb1..1b9d7b20 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -5,15 +5,19 @@ package sqlcmd import ( + "context" "errors" "fmt" + "net" "os" "regexp" "strconv" "strings" + "time" mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-mssqldb/azuread" + "github.com/microsoft/go-mssqldb/msdsn" "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/pkg/console" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" @@ -216,7 +220,7 @@ func Execute(version string) { fmt.Println() fmt.Println(localizer.Sprintf("Servers:")) } - fmt.Println(" ;UID:Login ID=?;PWD:Password=?;Trusted_Connection:Use Integrated Security=?;*APP:AppName=?;*WSID:WorkStation ID=?;") + listLocalServers() os.Exit(0) } if len(argss) > 0 { @@ -820,3 +824,76 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.SetError(nil) return s.Exitcode, err } + +func listLocalServers() { + bmsg := []byte{byte(msdsn.BrowserAllInstances)} + resp := make([]byte, 16*1024-1) + dialer := &net.Dialer{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + conn, err := dialer.DialContext(ctx, "udp", ":1434") + // silently ignore failures to connect, same as ODBC + if err != nil { + return + } + defer conn.Close() + dl, _ := ctx.Deadline() + _ = conn.SetDeadline(dl) + _, err = conn.Write(bmsg) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + fmt.Println(err) + } + return + } + read, err := conn.Read(resp) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + fmt.Println(err) + } + return + } + + data := parseInstances(resp[:read]) + instances := make([]string, 0, len(data)) + for s := range data { + if s == "MSSQLSERVER" { + + instances = append(instances, "(local)", data[s]["ServerName"]) + } else { + instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) + } + } + for _, s := range instances { + fmt.Println(" ", s) + } +} + +func parseInstances(msg []byte) msdsn.BrowserData { + results := msdsn.BrowserData{} + if len(msg) > 3 && msg[0] == 5 { + out_s := string(msg[3:]) + tokens := strings.Split(out_s, ";") + instdict := map[string]string{} + got_name := false + var name string + for _, token := range tokens { + if got_name { + instdict[name] = token + got_name = false + } else { + name = token + if len(name) == 0 { + if len(instdict) == 0 { + break + } + results[strings.ToUpper(instdict["InstanceName"])] = instdict + instdict = map[string]string{} + continue + } + got_name = true + } + } + } + return results +}