From 3385608b4a054d559a101ed517a054db841635a7 Mon Sep 17 00:00:00 2001 From: Sylvie Crowe <107814465+oneirocosm@users.noreply.github.com> Date: Mon, 15 Jul 2024 18:00:10 -0700 Subject: [PATCH] SSH Port (#111) This enables basic ssh for connections using publickey auth without a passphrase. It can be established by creating a widget with the "meta" property set to ``` { "connection": "@:" } ``` where the : is optional. --------- Co-authored-by: sawka --- go.mod | 8 +- go.sum | 12 +- pkg/blockcontroller/blockcontroller.go | 17 +- pkg/remote/sshclient.go | 725 +++++++++++++++++++++++++ pkg/shellexec/conninterface.go | 92 ++++ pkg/shellexec/shellexec.go | 90 ++- 6 files changed, 931 insertions(+), 13 deletions(-) create mode 100644 pkg/remote/sshclient.go create mode 100644 pkg/shellexec/conninterface.go diff --git a/go.mod b/go.mod index 30369e5d1..2017e7d09 100644 --- a/go.mod +++ b/go.mod @@ -14,12 +14,14 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/jmoiron/sqlx v1.4.0 + github.com/kevinburke/ssh_config v1.2.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mitchellh/mapstructure v1.5.0 github.com/sawka/txwrap v0.2.0 github.com/spf13/cobra v1.8.0 github.com/wavetermdev/waveterm/wavesrv v0.0.0-20240508181017-d07068c09d94 - golang.org/x/term v0.17.0 + golang.org/x/crypto v0.25.0 + golang.org/x/term v0.22.0 ) require ( @@ -30,5 +32,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/testify v1.8.4 // indirect go.uber.org/atomic v1.7.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/sys v0.22.0 // indirect ) + +replace github.com/kevinburke/ssh_config => github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 diff --git a/go.sum b/go.sum index 187bd8cef..0b4a525f5 100644 --- a/go.sum +++ b/go.sum @@ -53,15 +53,19 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY= +github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/wavetermdev/waveterm/wavesrv v0.0.0-20240508181017-d07068c09d94 h1:/SPCxd4KHlS4eRTreYEXWFRr8WfRFBcChlV5cgkaO58= github.com/wavetermdev/waveterm/wavesrv v0.0.0-20240508181017-d07068c09d94/go.mod h1:ywoo7DXdYueQ0tTPhVoB+wzRTgERSE19EA3mR6KGRaI= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= +golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index d5349b382..be003c5a3 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -285,9 +285,20 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta map[str } else { return fmt.Errorf("unknown controller type %q", bc.ControllerType) } - shellProc, err := shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts) - if err != nil { - return err + // pty buffer equivalent for ssh? i think if i have the ecmd or session i can manage it with output + // pty write needs stdin, so if i provide that, i might be able to write that way + // need a way to handle setsize??? + var shellProc *shellexec.ShellProc + if remoteName, ok := blockMeta["connection"].(string); ok && remoteName != "" { + shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, remoteName) + if err != nil { + return err + } + } else { + shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts) + if err != nil { + return err + } } bc.UpdateControllerAndSendUpdate(func() bool { bc.ShellProc = shellProc diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go new file mode 100644 index 000000000..7d4578add --- /dev/null +++ b/pkg/remote/sshclient.go @@ -0,0 +1,725 @@ +// Copyright 2023-2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "fmt" + "log" + "net" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" + + "github.com/kevinburke/ssh_config" + "github.com/wavetermdev/thenextwave/pkg/wavebase" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +type UserInputCancelError struct { + Err error +} + +func (uice UserInputCancelError) Error() string { + return uice.Err.Error() +} + +// This exists to trick the ssh library into continuing to try +// different public keys even when the current key cannot be +// properly parsed +func createDummySigner() ([]ssh.Signer, error) { + dummyKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + dummySigner, err := ssh.NewSignerFromKey(dummyKey) + if err != nil { + return nil, err + } + return []ssh.Signer{dummySigner}, nil + +} + +// This is a workaround to only process one identity file at a time, +// even if they have passphrases. It must be combined with retryable +// authentication to work properly +// +// Despite returning an array of signers, we only ever provide one since +// it allows proper user interaction in between attempts +// +// A significant number of errors end up returning dummy values as if +// they were successes. An error in this function prevents any other +// keys from being attempted. But if there's an error because of a dummy +// file, the library can still try again with a new key. +func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, passphrase string) func() ([]ssh.Signer, error) { + var identityFiles []string + existingKeys := make(map[string][]byte) + + // checking the file early prevents us from needing to send a + // dummy signer if there's a problem with the signer + for _, identityFile := range sshKeywords.IdentityFile { + privateKey, err := os.ReadFile(wavebase.ExpandHomeDir(identityFile)) + if err != nil { + // skip this key and try with the next + continue + } + existingKeys[identityFile] = privateKey + identityFiles = append(identityFiles, identityFile) + } + // require pointer to modify list in closure + identityFilesPtr := &identityFiles + + return func() ([]ssh.Signer, error) { + if len(*identityFilesPtr) == 0 { + return nil, fmt.Errorf("no identity files remaining") + } + identityFile := (*identityFilesPtr)[0] + *identityFilesPtr = (*identityFilesPtr)[1:] + privateKey, ok := existingKeys[identityFile] + if !ok { + log.Printf("error with existingKeys, this should never happen") + // skip this key and try with the next + return createDummySigner() + } + signer, err := ssh.ParsePrivateKey(privateKey) + if err == nil { + return []ssh.Signer{signer}, err + } + if _, ok := err.(*ssh.PassphraseMissingError); !ok { + // skip this key and try with the next + return createDummySigner() + } + + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase)) + if err == nil { + return []ssh.Signer{signer}, err + } + if err != x509.IncorrectPasswordError && err.Error() != "bcrypt_pbkdf: empty password" { + // skip this key and try with the next + return createDummySigner() + } + + // batch mode deactivates user input + if sshKeywords.BatchMode { + // skip this key and try with the next + return createDummySigner() + } + + return nil, fmt.Errorf("unimplemented: userinput createPublicKeyCallback") //todo + /* + request := &userinput.UserInputRequestType{ + ResponseType: "text", + QueryText: fmt.Sprintf("Enter passphrase for the SSH key: %s", identityFile), + Title: "Publickey Auth + Passphrase", + } + ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second) + defer cancelFn() + response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + if err != nil { + // this is an error where we actually do want to stop + // trying keys + return nil, UserInputCancelError{Err: err} + } + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(response.Text)) + if err != nil { + // skip this key and try with the next + return createDummySigner() + } + return []ssh.Signer{signer}, err + */ + } +} + +func createDefaultPasswordCallbackPrompt(password string) func() (secret string, err error) { + return func() (secret string, err error) { + // this should be modified to return an error if no password is stored + // but an empty password is not sufficient because some systems allow + // empty passwords + return password, nil + } +} + +func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string) func() (secret string, err error) { + return func() (secret string, err error) { + return "", fmt.Errorf("unimplemented: userinput createInteractivePasswordCallbackPrompt") //todo + /* + ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second) + defer cancelFn() + queryText := fmt.Sprintf( + "Password Authentication requested from connection \n"+ + "%s\n\n"+ + "Password:", remoteDisplayName) + request := &userinput.UserInputRequestType{ + ResponseType: "text", + QueryText: queryText, + Markdown: true, + Title: "Password Authentication", + } + response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + if err != nil { + return "", err + } + return response.Text, nil + */ + } +} + +func createCombinedPasswordCallbackPrompt(connCtx context.Context, password string, remoteDisplayName string) func() (secret string, err error) { + var once sync.Once + return func() (secret string, err error) { + var prompt func() (secret string, err error) + once.Do(func() { prompt = createDefaultPasswordCallbackPrompt(password) }) + if prompt == nil { + prompt = createInteractivePasswordCallbackPrompt(connCtx, remoteDisplayName) + } + return prompt() + } +} + +func createNaiveKbdInteractiveChallenge(password string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + for _, q := range questions { + if strings.Contains(strings.ToLower(q), "password") { + answers = append(answers, password) + } else { + answers = append(answers, "") + } + } + return answers, nil + } +} + +func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, fmt.Errorf("bad response from server: questions has len %d, echos has len %d", len(questions), len(echos)) + } + for i, question := range questions { + echo := echos[i] + answer, err := promptChallengeQuestion(connCtx, question, echo, remoteName) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + return answers, nil + } +} + +func promptChallengeQuestion(connCtx context.Context, question string, echo bool, remoteName string) (answer string, err error) { + // limited to 15 seconds for some reason. this should be investigated more + // in the future + return "", fmt.Errorf("unimplemented: userinput promptChallengeQuestion") //todo + /* + ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second) + defer cancelFn() + queryText := fmt.Sprintf( + "Keyboard Interactive Authentication requested from connection \n"+ + "%s\n\n"+ + "%s", remoteName, question) + request := &userinput.UserInputRequestType{ + ResponseType: "text", + QueryText: queryText, + Markdown: true, + Title: "Keyboard Interactive Authentication", + PublicText: echo, + } + response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + if err != nil { + return "", err + } + return response.Text, nil + */ +} + +func createCombinedKbdInteractiveChallenge(connCtx context.Context, password string, remoteName string) ssh.KeyboardInteractiveChallenge { + var once sync.Once + return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + var challenge ssh.KeyboardInteractiveChallenge + once.Do(func() { challenge = createNaiveKbdInteractiveChallenge(password) }) + if challenge == nil { + challenge = createInteractiveKbdInteractiveChallenge(connCtx, remoteName) + } + return challenge(name, instruction, questions, echos) + } +} + +func openKnownHostsForEdit(knownHostsFilename string) (*os.File, error) { + path, _ := filepath.Split(knownHostsFilename) + err := os.MkdirAll(path, 0700) + if err != nil { + return nil, err + } + return os.OpenFile(knownHostsFilename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) +} + +/* +func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerification func() (*userinput.UserInputResponsePacketType, error)) error { + if getUserVerification == nil { + getUserVerification = func() (*userinput.UserInputResponsePacketType, error) { + return &userinput.UserInputResponsePacketType{ + Type: "confirm", + Confirm: true, + }, nil + } + } + + path, _ := filepath.Split(knownHostsFile) + err := os.MkdirAll(path, 0700) + if err != nil { + return err + } + f, err := os.OpenFile(knownHostsFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return err + } + // do not close writeable files with defer + + // this file works, so let's ask the user for permission + response, err := getUserVerification() + if err != nil { + f.Close() + return UserInputCancelError{Err: err} + } + if !response.Confirm { + f.Close() + return UserInputCancelError{Err: fmt.Errorf("canceled by the user")} + } + + _, err = f.WriteString(newLine + "\n") + if err != nil { + f.Close() + return err + } + return f.Close() +} +*/ + +/* todo +func createUnknownKeyVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*userinput.UserInputResponsePacketType, error) { + base64Key := base64.StdEncoding.EncodeToString(key.Marshal()) + queryText := fmt.Sprintf( + "The authenticity of host '%s (%s)' can't be established "+ + "as it **does not exist in any checked known_hosts files**. "+ + "The host you are attempting to connect to provides this %s key: \n"+ + "%s.\n\n"+ + "**Would you like to continue connecting?** If so, the key will be permanently "+ + "added to the file %s "+ + "to protect from future man-in-the-middle attacks.", hostname, remote, key.Type(), base64Key, knownHostsFile) + request := &userinput.UserInputRequestType{ + ResponseType: "confirm", + QueryText: queryText, + Markdown: true, + Title: "Known Hosts Key Missing", + } + return func() (*userinput.UserInputResponsePacketType, error) { + ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFn() + return userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + } +} +*/ + +/* +func createMissingKnownHostsVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*userinput.UserInputResponsePacketType, error) { + base64Key := base64.StdEncoding.EncodeToString(key.Marshal()) + queryText := fmt.Sprintf( + "The authenticity of host '%s (%s)' can't be established "+ + "as **no known_hosts files could be found**. "+ + "The host you are attempting to connect to provides this %s key: \n"+ + "%s.\n\n"+ + "**Would you like to continue connecting?** If so: \n"+ + "- %s will be created \n"+ + "- the key will be added to %s\n\n"+ + "This will protect from future man-in-the-middle attacks.", hostname, remote, key.Type(), base64Key, knownHostsFile, knownHostsFile) + request := &userinput.UserInputRequestType{ + ResponseType: "confirm", + QueryText: queryText, + Markdown: true, + Title: "Known Hosts File Missing", + } + return func() (*userinput.UserInputResponsePacketType, error) { + ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFn() + return userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + } +} +*/ + +func lineContainsMatch(line []byte, matches [][]byte) bool { + for _, match := range matches { + if bytes.Contains(line, match) { + return true + } + } + return false +} + +func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { + ssh_config.ReloadConfigs() + rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile") + userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes + rawGlobalKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "GlobalKnownHostsFile") + globalKnownHostsFiles := strings.Fields(rawGlobalKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes + + osUser, err := user.Current() + if err != nil { + return nil, err + } + var unexpandedKnownHostsFiles []string + if osUser.Username == "root" { + unexpandedKnownHostsFiles = globalKnownHostsFiles + } else { + unexpandedKnownHostsFiles = append(userKnownHostsFiles, globalKnownHostsFiles...) + } + + var knownHostsFiles []string + for _, filename := range unexpandedKnownHostsFiles { + knownHostsFiles = append(knownHostsFiles, wavebase.ExpandHomeDir(filename)) + } + + // there are no good known hosts files + if len(knownHostsFiles) == 0 { + return nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden") + } + + var unreadableFiles []string + + // the library we use isn't very forgiving about files that are formatted + // incorrectly. if a problem file is found, it is removed from our list + // and we try again + var basicCallback ssh.HostKeyCallback + for basicCallback == nil && len(knownHostsFiles) > 0 { + var err error + basicCallback, err = knownhosts.New(knownHostsFiles...) + if serr, ok := err.(*os.PathError); ok { + badFile := serr.Path + unreadableFiles = append(unreadableFiles, badFile) + var okFiles []string + for _, filename := range knownHostsFiles { + if filename != badFile { + okFiles = append(okFiles, filename) + } + } + if len(okFiles) >= len(knownHostsFiles) { + return nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile) + } + knownHostsFiles = okFiles + } else if err != nil { + // TODO handle obscure problems if possible + return nil, fmt.Errorf("known_hosts formatting error: %+v", err) + } + } + + if basicCallback == nil { + basicCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return &knownhosts.KeyError{} + } + } + + waveHostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error { + err := basicCallback(hostname, remote, key) + if err == nil { + // success + return nil + } else if _, ok := err.(*knownhosts.RevokedError); ok { + // revoked credentials are refused outright + return err + } else if _, ok := err.(*knownhosts.KeyError); !ok { + // this is an unknown error (note the !ok is opposite of usual) + return err + } + serr, _ := err.(*knownhosts.KeyError) + if len(serr.Want) == 0 { + // the key was not found + + // try to write to a file that could be read + //err := fmt.Errorf("placeholder, should not be returned") // a null value here can cause problems with empty slice + return fmt.Errorf("unimplemented: waveHostKeyCallback key not found") //todo + /* + for _, filename := range knownHostsFiles { + newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) + getUserVerification := createUnknownKeyVerifier(filename, hostname, remote.String(), key) + err = writeToKnownHosts(filename, newLine, getUserVerification) + if err == nil { + break + } + if serr, ok := err.(UserInputCancelError); ok { + return serr + } + } + + // try to write to a file that could not be read (file likely doesn't exist) + // should catch cases where there is no known_hosts file + if err != nil { + for _, filename := range unreadableFiles { + newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) + getUserVerification := createMissingKnownHostsVerifier(filename, hostname, remote.String(), key) + err = writeToKnownHosts(filename, newLine, getUserVerification) + if err == nil { + knownHostsFiles = []string{filename} + break + } + if serr, ok := err.(UserInputCancelError); ok { + return serr + } + } + } + if err != nil { + return fmt.Errorf("unable to create new knownhost key: %e", err) + } + */ + } else { + // the key changed + correctKeyFingerprint := base64.StdEncoding.EncodeToString(key.Marshal()) + var bulletListKnownHosts []string + for _, knownHostName := range knownHostsFiles { + withBulletPoint := "- " + knownHostName + bulletListKnownHosts = append(bulletListKnownHosts, withBulletPoint) + } + var offendingKeysFmt []string + for _, badKey := range serr.Want { + formattedKey := "- " + base64.StdEncoding.EncodeToString(badKey.Key.Marshal()) + offendingKeysFmt = append(offendingKeysFmt, formattedKey) + } + // todo + _ = fmt.Sprintf("**WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!**\n\n"+ + "If this is not expected, it is possible that someone could be trying to "+ + "eavesdrop on you via a man-in-the-middle attack. "+ + "Alternatively, the host you are connecting to may have changed its key. "+ + "The %s key sent by the remote hist has the fingerprint: \n"+ + "%s\n\n"+ + "If you are sure this is correct, please update your known_hosts files to "+ + "remove the lines with the offending before trying to connect again. \n"+ + "**Known Hosts Files** \n"+ + "%s\n\n"+ + "**Offending Keys** \n"+ + "%s", key.Type(), correctKeyFingerprint, strings.Join(bulletListKnownHosts, " \n"), strings.Join(offendingKeysFmt, " \n")) + //update := scbus.MakeUpdatePacket() + // create update into alert message + + //send update via bus? + return fmt.Errorf("remote host identification has changed") + } + + updatedCallback, err := knownhosts.New(knownHostsFiles...) + if err != nil { + return err + } + // try one final time + return updatedCallback(hostname, remote, key) + } + + return waveHostKeyCallback, nil +} + +func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + d := net.Dialer{Timeout: config.Timeout} + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return ssh.NewClient(c, chans, reqs), nil +} + +func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error) { + sshConfigKeywords, err := findSshConfigKeywords(opts.SSHHost) + if err != nil { + return nil, err + } + + sshKeywords, err := combineSshKeywords(opts, sshConfigKeywords) + if err != nil { + return nil, err + } + remoteName := sshKeywords.User + "@" + knownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) + + publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, "")) + keyboardInteractive := ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(connCtx, "", remoteName)) + passwordCallback := ssh.PasswordCallback(createCombinedPasswordCallbackPrompt(connCtx, "", remoteName)) + + // batch mode turns off interactive input. this means the number of + // attemtps must drop to 1 with this setup + var attemptsAllowed int + if sshKeywords.BatchMode { + attemptsAllowed = 1 + } else { + attemptsAllowed = 2 + } + + // exclude gssapi-with-mic and hostbased until implemented + authMethodMap := map[string]ssh.AuthMethod{ + "publickey": ssh.RetryableAuthMethod(publicKeyCallback, len(sshKeywords.IdentityFile)), + "keyboard-interactive": ssh.RetryableAuthMethod(keyboardInteractive, attemptsAllowed), + "password": ssh.RetryableAuthMethod(passwordCallback, attemptsAllowed), + } + + authMethodActiveMap := map[string]bool{ + "publickey": sshKeywords.PubkeyAuthentication, + "keyboard-interactive": sshKeywords.KbdInteractiveAuthentication, + "password": sshKeywords.PasswordAuthentication, + } + + var authMethods []ssh.AuthMethod + for _, authMethodName := range sshKeywords.PreferredAuthentications { + authMethodActive, ok := authMethodActiveMap[authMethodName] + if !ok || !authMethodActive { + continue + } + authMethod, ok := authMethodMap[authMethodName] + if !ok { + continue + } + authMethods = append(authMethods, authMethod) + } + + hostKeyCallback, err := createHostKeyCallback(opts) + if err != nil { + return nil, err + } + + clientConfig := &ssh.ClientConfig{ + User: sshKeywords.User, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + } + networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port + return DialContext(connCtx, "tcp", networkAddr, clientConfig) +} + +type SshKeywords struct { + User string + HostName string + Port string + IdentityFile []string + BatchMode bool + PubkeyAuthentication bool + PasswordAuthentication bool + KbdInteractiveAuthentication bool + PreferredAuthentications []string +} + +func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeywords, error) { + sshKeywords := &SshKeywords{} + + if opts.SSHUser != "" { + sshKeywords.User = opts.SSHUser + } else if configKeywords.User != "" { + sshKeywords.User = configKeywords.User + } else { + user, err := user.Current() + if err != nil { + return nil, fmt.Errorf("failed to get user for ssh: %+v", err) + } + sshKeywords.User = user.Username + } + + // we have to check the host value because of the weird way + // we store the pattern as the hostname for imported remotes + if configKeywords.HostName != "" { + sshKeywords.HostName = configKeywords.HostName + } else { + sshKeywords.HostName = opts.SSHHost + } + + if opts.SSHPort != 0 && opts.SSHPort != 22 { + sshKeywords.Port = strconv.Itoa(opts.SSHPort) + } else if configKeywords.Port != "" && configKeywords.Port != "22" { + sshKeywords.Port = configKeywords.Port + } else { + sshKeywords.Port = "22" + } + + sshKeywords.IdentityFile = configKeywords.IdentityFile + + // these are not officially supported in the waveterm frontend but can be configured + // in ssh config files + sshKeywords.BatchMode = configKeywords.BatchMode + sshKeywords.PubkeyAuthentication = configKeywords.PubkeyAuthentication + sshKeywords.PasswordAuthentication = configKeywords.PasswordAuthentication + sshKeywords.KbdInteractiveAuthentication = configKeywords.KbdInteractiveAuthentication + sshKeywords.PreferredAuthentications = configKeywords.PreferredAuthentications + + return sshKeywords, nil +} + +// note that a `var == "yes"` will default to false +// but `var != "no"` will default to true +// when given unexpected strings +func findSshConfigKeywords(hostPattern string) (*SshKeywords, error) { + ssh_config.ReloadConfigs() + sshKeywords := &SshKeywords{} + var err error + + sshKeywords.User, err = ssh_config.GetStrict(hostPattern, "User") + if err != nil { + return nil, err + } + + sshKeywords.HostName, err = ssh_config.GetStrict(hostPattern, "HostName") + if err != nil { + return nil, err + } + + sshKeywords.Port, err = ssh_config.GetStrict(hostPattern, "Port") + if err != nil { + return nil, err + } + + sshKeywords.IdentityFile = ssh_config.GetAll(hostPattern, "IdentityFile") + + batchModeRaw, err := ssh_config.GetStrict(hostPattern, "BatchMode") + if err != nil { + return nil, err + } + sshKeywords.BatchMode = (strings.ToLower(batchModeRaw) == "yes") + + // we currently do not support host-bound or unbound but will use yes when they are selected + pubkeyAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PubkeyAuthentication") + if err != nil { + return nil, err + } + sshKeywords.PubkeyAuthentication = (strings.ToLower(pubkeyAuthenticationRaw) != "no") + + passwordAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PasswordAuthentication") + if err != nil { + return nil, err + } + sshKeywords.PasswordAuthentication = (strings.ToLower(passwordAuthenticationRaw) != "no") + + kbdInteractiveAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "KbdInteractiveAuthentication") + if err != nil { + return nil, err + } + sshKeywords.KbdInteractiveAuthentication = (strings.ToLower(kbdInteractiveAuthenticationRaw) != "no") + + // these are parsed as a single string and must be separated + // these are case sensitive in openssh so they are here too + preferredAuthenticationsRaw, err := ssh_config.GetStrict(hostPattern, "PreferredAuthentications") + if err != nil { + return nil, err + } + sshKeywords.PreferredAuthentications = strings.Split(preferredAuthenticationsRaw, ",") + + return sshKeywords, nil +} + +type SSHOpts struct { + SSHHost string `json:"sshhost"` + SSHUser string `json:"sshuser"` + SSHPort int `json:"sshport,omitempty"` +} diff --git a/pkg/shellexec/conninterface.go b/pkg/shellexec/conninterface.go new file mode 100644 index 000000000..178af27ee --- /dev/null +++ b/pkg/shellexec/conninterface.go @@ -0,0 +1,92 @@ +package shellexec + +import ( + "io" + "os" + "os/exec" + + "golang.org/x/crypto/ssh" +) + +type ConnInterface interface { + Kill() + Wait() error + Start() error + StdinPipe() (io.WriteCloser, error) + StdoutPipe() (io.ReadCloser, error) + StderrPipe() (io.ReadCloser, error) +} + +type CmdWrap struct { + Cmd *exec.Cmd +} + +func (cw CmdWrap) Kill() { + cw.Cmd.Process.Kill() +} + +func (cw CmdWrap) Wait() error { + return cw.Cmd.Wait() +} + +func (cw CmdWrap) Start() error { + defer func() { + for _, extraFile := range cw.Cmd.ExtraFiles { + if extraFile != nil { + extraFile.Close() + } + } + }() + return cw.Cmd.Start() +} + +func (cw CmdWrap) StdinPipe() (io.WriteCloser, error) { + return cw.Cmd.StdinPipe() +} + +func (cw CmdWrap) StdoutPipe() (io.ReadCloser, error) { + return cw.Cmd.StdoutPipe() +} + +func (cw CmdWrap) StderrPipe() (io.ReadCloser, error) { + return cw.Cmd.StderrPipe() +} + +type SessionWrap struct { + Session *ssh.Session + StartCmd string + Tty *os.File +} + +func (sw SessionWrap) Kill() { + sw.Tty.Close() + sw.Session.Close() +} + +func (sw SessionWrap) Wait() error { + return sw.Session.Wait() +} + +func (sw SessionWrap) Start() error { + return sw.Session.Start(sw.StartCmd) +} + +func (sw SessionWrap) StdinPipe() (io.WriteCloser, error) { + return sw.Session.StdinPipe() +} + +func (sw SessionWrap) StdoutPipe() (io.ReadCloser, error) { + stdoutReader, err := sw.Session.StdoutPipe() + if err != nil { + return nil, err + } + return io.NopCloser(stdoutReader), nil +} + +func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) { + stderrReader, err := sw.Session.StderrPipe() + if err != nil { + return nil, err + } + return io.NopCloser(stderrReader), nil +} diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 450d7681d..db97b56c5 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -5,17 +5,24 @@ package shellexec import ( "bytes" + "context" "fmt" "io" + "log" "os" "os/exec" "reflect" + "regexp" + "strconv" + "strings" "sync" "syscall" "github.com/creack/pty" + "github.com/wavetermdev/thenextwave/pkg/remote" "github.com/wavetermdev/thenextwave/pkg/util/shellutil" "github.com/wavetermdev/thenextwave/pkg/wavebase" + "golang.org/x/term" ) type TermSize struct { @@ -31,7 +38,7 @@ type CommandOptsType struct { } type ShellProc struct { - Cmd *exec.Cmd + Cmd ConnInterface Pty *os.File CloseOnce *sync.Once DoneCh chan any // closed after proc.Wait() returns @@ -39,9 +46,9 @@ type ShellProc struct { } func (sp *ShellProc) Close() { - sp.Cmd.Process.Kill() + sp.Cmd.Kill() go func() { - _, waitErr := sp.Cmd.Process.Wait() + waitErr := sp.Cmd.Wait() sp.SetWaitErrorAndSignalDone(waitErr) sp.Pty.Close() }() @@ -104,6 +111,81 @@ func checkCwd(cwd string) error { return nil } +var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`) + +func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType, remoteName string) (*ShellProc, error) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var shellPath string + if cmdStr == "" { + shellPath = "/bin/bash" + } else { + shellPath = cmdStr + } + + var shellOpts []string + if cmdOpts.Login { + shellOpts = append(shellOpts, "-l") + } + if cmdOpts.Interactive { + shellOpts = append(shellOpts, "-i") + } + cmdCombined := fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) + log.Print(cmdCombined) + m := userHostRe.FindStringSubmatch(remoteName) + if m == nil { + return nil, fmt.Errorf("invalid format of user@host argument") + } + remoteUser, remoteHost, remotePortStr := m[1], m[2], m[3] + remoteUser = strings.Trim(remoteUser, "@") + var remotePort int + if remotePortStr != "" { + var err error + remotePort, err = strconv.Atoi(remotePortStr) + if err != nil { + return nil, fmt.Errorf("invalid port specified on user@host argument") + } + } + + client, err := remote.ConnectToClient(ctx, &remote.SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}) //todo specify or remove opts + if err != nil { + return nil, err + } + session, err := client.NewSession() + if err != nil { + return nil, err + } + // todo: connect pty output, etc + // redirect to fake pty??? + + cmdPty, cmdTty, err := pty.Open() + if err != nil { + return nil, fmt.Errorf("opening new pty: %w", err) + } + term.MakeRaw(int(cmdTty.Fd())) + if termSize.Rows == 0 || termSize.Cols == 0 { + termSize.Rows = shellutil.DefaultTermRows + termSize.Cols = shellutil.DefaultTermCols + } + if termSize.Rows <= 0 || termSize.Cols <= 0 { + return nil, fmt.Errorf("invalid term size: %v", termSize) + } + pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) + session.Stdin = cmdTty + session.Stdout = cmdTty + session.Stderr = cmdTty + session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil) + + sessionWrap := SessionWrap{session, cmdCombined, cmdTty} + err = sessionWrap.Start() + if err != nil { + cmdPty.Close() + return nil, err + } + return &ShellProc{Cmd: sessionWrap, Pty: cmdPty, CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil +} + func StartShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) { var ecmd *exec.Cmd var shellOpts []string @@ -156,7 +238,7 @@ func StartShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType) ( cmdPty.Close() return nil, err } - return &ShellProc{Cmd: ecmd, Pty: cmdPty, CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil + return &ShellProc{Cmd: CmdWrap{ecmd}, Pty: cmdPty, CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil } func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize TermSize) ([]byte, error) {