Integrate SSH Library with Waveshell Installation/Auto-update (#322)

* refactor launch code to integrate install easier

The previous set up of launch was difficult to navigate. This makes it
much clearer which will make the auto install flow easier to manage.

* feat: integrate auto install into new ssh setup

This change makes it possible to auto install using the ssh library
instead of making a call to the ssh cli command. This will auto install
if the installed waveshell version is incorrect or cannot be found.

* chore: clean up some lints for sshclient

There was a context that didn't have it's cancel function deferred and
an error that wasn't being handle. They're fixed now.

* fix: disconnect client if requested or launch fail

A recent commit made it so a client remained part of the MShellProc
after being disconnected. This is undesireable since a manual
disconnection indicates that the user will need to enter their
credentials again if required. Similarly, if the launch fails with an
error, the expectation is that credentials will have to be entered
again.

* fix: use legacy timer for the time being

The legacy timer frustrates me because it adds a lot of state to the
MShellProc struct that is complicated to manage. But it currently works,
so I will be keeping it for the time being.

* fix: change separator between remoteref and name

With the inclusion of the port number in the canonical id, the :
separator between the remoteref and remote name causes problems if the
port is parsed instead. This changes it to a # in order to avoid this
conflict.

* fix: check for null when closing extra files

It is possible for the list of extra files to contain null files. This
change ensures the null files will not be erroneously closed.

* fix: change connecting method to show port once

With port added to the canonicalname, it no longer makes sense to append
the port afterward.

* feat: use user input modal for sudo connection

The sudo connection used to have a unique way of entering a password.
This change provides an alternative method using the user input modal
that the other connection methods use. It does not work perfectly with
this revision, but the basic building blocks are in place. It needs a
few timer updates to be complete.

* fix: remove old timer to prevent conflicts with it

With this change the old timer is no longer needed. It is not fully
removed yet, but it is disabled so as to not get in the way.
Additionally, error handling has been slightly improved.

There is still a bug where an incorrect password prints a new password
prompt after the error message. That needs to be fixed in the future.
This commit is contained in:
Sylvie Crowe 2024-02-29 11:37:03 -08:00 committed by GitHub
parent 75be66bada
commit 6c115716b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 451 additions and 236 deletions

View File

@ -1,9 +1,7 @@
github.com/aws/aws-sdk-go-v2/service/s3 v1.27.11 h1:3/gm/JTX9bX8CpzTgIlrtYpB3EVBDxyg/GY/QdcIEZw= github.com/aws/aws-sdk-go-v2/service/s3 v1.27.11 h1:3/gm/JTX9bX8CpzTgIlrtYpB3EVBDxyg/GY/QdcIEZw=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs=
github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4=
github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
github.com/wavetermdev/ssh_config v0.0.0-20240109090616-36c8da3d7376 h1:tFhJgTu7lgd+hldLfPSzDCoWUpXI8wHKR3rxq5jTLkQ=
github.com/wavetermdev/ssh_config v0.0.0-20240109090616-36c8da3d7376/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@ -652,7 +652,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err)) m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return return
} }
cproc, _, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd}) cproc, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd})
if err != nil { if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err)) m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
return return

View File

@ -57,9 +57,28 @@ func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser,
} }
func (cw CmdWrap) Start() error { func (cw CmdWrap) Start() error {
defer func() {
for _, extraFile := range cw.Cmd.ExtraFiles {
if extraFile != nil {
extraFile.Close()
}
}
}()
return cw.Cmd.Start() return cw.Cmd.Start()
} }
func (cw CmdWrap) StdinPipe() (io.WriteCloser, error) {
return cw.Cmd.StdinPipe()
}
func (cw CmdWrap) StdoutPipe() (io.ReadCloser, error) {
return cw.Cmd.StdoutPipe()
}
func (cw CmdWrap) StderrPipe() (io.ReadCloser, error) {
return cw.Cmd.StderrPipe()
}
type SessionWrap struct { type SessionWrap struct {
Session *ssh.Session Session *ssh.Session
StartCmd string StartCmd string
@ -101,12 +120,35 @@ func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadClos
return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
} }
func (sw SessionWrap) StdinPipe() (io.WriteCloser, error) {
return sw.Session.StdinPipe()
}
func (sw SessionWrap) StdoutPipe() (io.ReadCloser, error) {
stdoutReader, err := sw.Session.StdoutPipe()
if err != nil {
return nil, err
}
return io.NopCloser(stdoutReader), nil
}
func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
stderrReader, err := sw.Session.StderrPipe()
if err != nil {
return nil, err
}
return io.NopCloser(stderrReader), nil
}
type ConnInterface interface { type ConnInterface interface {
Kill() Kill()
Wait() error Wait() error
Sender() (*packet.PacketSender, io.WriteCloser, error) Sender() (*packet.PacketSender, io.WriteCloser, error)
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
Start() error Start() error
StdinPipe() (io.WriteCloser, error)
StdoutPipe() (io.ReadCloser, error)
StderrPipe() (io.ReadCloser, error)
} }
type ClientProc struct { type ClientProc struct {
@ -120,20 +162,44 @@ type ClientProc struct {
Output *packet.PacketParser Output *packet.PacketParser
} }
type WaveshellLaunchError struct {
InitPk *packet.InitPacketType
}
func (wle WaveshellLaunchError) Error() string {
if wle.InitPk.NotFound {
return "waveshell client not found"
} else if semver.MajorMinor(wle.InitPk.Version) != semver.MajorMinor(base.MShellVersion) {
return fmt.Sprintf("invalid remote waveshell version '%s', must be '=%s'", wle.InitPk.Version, semver.MajorMinor(base.MShellVersion))
}
return fmt.Sprintf("invalid waveshell: init packet=%v", *wle.InitPk)
}
type InvalidPacketError struct {
InvalidPk *packet.PacketType
}
func (ipe InvalidPacketError) Error() string {
if ipe.InvalidPk == nil {
return "no init packet received from waveshell client"
}
return fmt.Sprintf("invalid packet received from waveshell client: %s", packet.AsString(*ipe.InvalidPk))
}
// returns (clientproc, initpk, error) // returns (clientproc, initpk, error)
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) { func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, error) {
startTs := time.Now() startTs := time.Now()
sender, inputWriter, err := ecmd.Sender() sender, inputWriter, err := ecmd.Sender()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
packetParser, stdoutReader, stderrReader, err := ecmd.Parser() packetParser, stdoutReader, stderrReader, err := ecmd.Parser()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("running local client: %w", err) return nil, fmt.Errorf("running local client: %w", err)
} }
cproc := &ClientProc{ cproc := &ClientProc{
Cmd: ecmd, Cmd: ecmd,
@ -150,29 +216,27 @@ func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *pack
case pk = <-packetParser.MainCh: case pk = <-packetParser.MainCh:
case <-ctx.Done(): case <-ctx.Done():
cproc.Close() cproc.Close()
return nil, nil, ctx.Err() return nil, ctx.Err()
}
if pk == nil {
cproc.Close()
return nil, InvalidPacketError{}
} }
if pk != nil {
if pk.GetType() != packet.InitPacketStr { if pk.GetType() != packet.InitPacketStr {
cproc.Close() cproc.Close()
return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk)) return nil, InvalidPacketError{InvalidPk: &pk}
} }
initPk := pk.(*packet.InitPacketType) initPk := pk.(*packet.InitPacketType)
if initPk.NotFound { if initPk.NotFound {
cproc.Close() cproc.Close()
return nil, initPk, fmt.Errorf("mshell client not found") return nil, WaveshellLaunchError{InitPk: initPk}
} }
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) { if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
cproc.Close() cproc.Close()
return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion)) return nil, WaveshellLaunchError{InitPk: initPk}
} }
cproc.InitPk = initPk cproc.InitPk = initPk
} return cproc, nil
if cproc.InitPk == nil {
cproc.Close()
return nil, nil, fmt.Errorf("no init packet received from mshell client")
}
return cproc, cproc.InitPk, nil
} }
func (cproc *ClientProc) Close() { func (cproc *ClientProc) Close() {

View File

@ -437,6 +437,16 @@ func MakeMShellSingleCmd() (*exec.Cmd, error) {
return ecmd, nil return ecmd, nil
} }
func MakeLocalExecCmd(cmdStr string, sapi shellapi.ShellApi) *exec.Cmd {
homeDir, _ := os.UserHomeDir() // ignore error
if homeDir == "" {
homeDir = "/"
}
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", cmdStr)
ecmd.Dir = homeDir
return ecmd
}
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) *exec.Cmd { func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) *exec.Cmd {
remoteCommand = strings.TrimSpace(remoteCommand) remoteCommand = strings.TrimSpace(remoteCommand)
if opts.SSHHost == "" { if opts.SSHHost == "" {
@ -587,7 +597,7 @@ func sendMShellBinary(input io.WriteCloser, mshellStream io.Reader) {
}() }()
} }
func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshellStream io.Reader, mshellReaderFn MShellBinaryReaderFn, msgFn func(string)) error { func RunInstallFromCmd(ctx context.Context, ecmd ConnInterface, tryDetect bool, mshellStream io.Reader, mshellReaderFn MShellBinaryReaderFn, msgFn func(string)) error {
inputWriter, err := ecmd.StdinPipe() inputWriter, err := ecmd.StdinPipe()
if err != nil { if err != nil {
return fmt.Errorf("creating stdin pipe: %v", err) return fmt.Errorf("creating stdin pipe: %v", err)

View File

@ -443,7 +443,7 @@ func parseFullRemoteRef(fullRemoteRef string) (string, string, string, error) {
if strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]") { if strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]") {
fullRemoteRef = fullRemoteRef[1 : len(fullRemoteRef)-1] fullRemoteRef = fullRemoteRef[1 : len(fullRemoteRef)-1]
} }
fields := strings.Split(fullRemoteRef, ":") fields := strings.Split(fullRemoteRef, "#")
if len(fields) > 3 { if len(fields) > 3 {
return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef) return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef)
} }

View File

@ -37,6 +37,7 @@ import (
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus" "github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket" "github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/userinput"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
@ -83,12 +84,6 @@ else
fi fi
` `
const WaveshellServerRunOnlyFmt = `
PATH=$PATH:~/.mshell;
[%PINGPACKET%]
mshell-[%VERSION%] --server
`
func MakeLocalMShellCommandStr(isSudo bool) (string, error) { func MakeLocalMShellCommandStr(isSudo bool) (string, error) {
mshellPath, err := scbase.LocalMShellBinaryPath() mshellPath, err := scbase.LocalMShellBinaryPath()
if err != nil { if err != nil {
@ -107,13 +102,6 @@ func MakeServerCommandStr() string {
return rtn return rtn
} }
func MakeServerRunOnlyCommandStr() string {
rtn := strings.ReplaceAll(WaveshellServerRunOnlyFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket)
return rtn
}
const ( const (
StatusConnected = sstore.RemoteStatus_Connected StatusConnected = sstore.RemoteStatus_Connected
StatusConnecting = sstore.RemoteStatus_Connecting StatusConnecting = sstore.RemoteStatus_Connecting
@ -175,6 +163,7 @@ type MShellProc struct {
RunningCmds map[base.CommandKey]RunCmdType RunningCmds map[base.CommandKey]RunCmdType
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name]
launcher Launcher // for conditional launch method based on ssh library in use. remove once ssh library is stabilized launcher Launcher // for conditional launch method based on ssh library in use. remove once ssh library is stabilized
Client *ssh.Client
} }
type RunCmdType struct { type RunCmdType struct {
@ -602,7 +591,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
if msh.Status == StatusConnecting { if msh.Status == StatusConnecting {
state.WaitingForPassword = msh.isWaitingForPassword_nolock() state.WaitingForPassword = msh.isWaitingForPassword_nolock()
if msh.MakeClientDeadline != nil { if msh.MakeClientDeadline != nil {
state.ConnectTimeout = int((*msh.MakeClientDeadline).Sub(time.Now()) / time.Second) state.ConnectTimeout = int(time.Until(*msh.MakeClientDeadline) / time.Second)
if state.ConnectTimeout < 0 { if state.ConnectTimeout < 0 {
state.ConnectTimeout = 0 state.ConnectTimeout = 0
} }
@ -734,7 +723,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
func SendRemoteInput(pk *scpacket.RemoteInputPacketType) error { func SendRemoteInput(pk *scpacket.RemoteInputPacketType) error {
data, err := base64.StdEncoding.DecodeString(pk.InputData64) data, err := base64.StdEncoding.DecodeString(pk.InputData64)
if err != nil { if err != nil {
return fmt.Errorf("cannot decode base64: %v\n", err) return fmt.Errorf("cannot decode base64: %v", err)
} }
msh := GetRemoteById(pk.RemoteId) msh := GetRemoteById(pk.RemoteId)
if msh == nil { if msh == nil {
@ -892,6 +881,7 @@ func (msh *MShellProc) Disconnect(force bool) {
defer msh.Lock.Unlock() defer msh.Lock.Unlock()
if msh.ServerProc != nil { if msh.ServerProc != nil {
msh.ServerProc.Close() msh.ServerProc.Close()
msh.Client = nil
} }
if msh.MakeClientCancelFn != nil { if msh.MakeClientCancelFn != nil {
msh.MakeClientCancelFn() msh.MakeClientCancelFn()
@ -989,6 +979,45 @@ func (msh *MShellProc) isWaitingForPassphrase_nolock() bool {
return pwIdx != -1 return pwIdx != -1
} }
func (msh *MShellProc) RunPasswordReadLoop(cmdPty *os.File) {
buf := make([]byte, PtyReadBufSize)
for {
_, readErr := cmdPty.Read(buf)
if readErr == io.EOF {
return
}
if readErr != nil {
msh.WriteToPtyBuffer("*error reading from controlling-pty: %v\n", readErr)
return
}
var newIsWaiting bool
msh.WithLock(func() {
newIsWaiting = msh.isWaitingForPassword_nolock()
})
if newIsWaiting {
break
}
}
request := &userinput.UserInputRequestType{
QueryText: "Please enter your password",
ResponseType: "text",
Title: "Sudo Password",
Markdown: false,
}
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
msh.WriteToPtyBuffer("*error timed out waiting for password: %v\n", err)
return
}
msh.WithLock(func() {
curOffset := msh.PtyBuffer.TotalWritten()
msh.PtyBuffer.Write([]byte(response.Text))
sendRemotePtyUpdate(msh.Remote.RemoteId, curOffset, []byte(response.Text))
})
}
func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) { func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) {
buf := make([]byte, PtyReadBufSize) buf := make([]byte, PtyReadBufSize)
var isWaiting bool var isWaiting bool
@ -1015,6 +1044,141 @@ func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) {
} }
} }
func (msh *MShellProc) CheckPasswordRequested(ctx context.Context, requiresPassword chan bool) {
for {
msh.WithLock(func() {
if msh.isWaitingForPassword_nolock() {
select {
case requiresPassword <- true:
default:
}
return
}
if msh.Status != StatusConnecting {
select {
case requiresPassword <- false:
default:
}
return
}
})
select {
case <-ctx.Done():
return
default:
}
time.Sleep(100 * time.Millisecond)
}
}
func (msh *MShellProc) SendPassword(pw string) {
msh.WithLock(func() {
if msh.ControllingPty == nil {
return
}
pwBytes := []byte(pw + "\r")
msh.writeToPtyBuffer_nolock("~[sent password]\r\n")
_, err := msh.ControllingPty.Write(pwBytes)
if err != nil {
msh.writeToPtyBuffer_nolock("*cannot write password to controlling pty: %v\n", err)
}
})
}
func (msh *MShellProc) WaitAndSendPasswordNew(pw string) {
requiresPassword := make(chan bool, 1)
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
if pw != "" {
// do an extra check with the saved password if it is provided
go msh.CheckPasswordRequested(ctx, requiresPassword)
select {
case <-ctx.Done():
err := ctx.Err()
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user: %v", err)
} else {
errMsg = fmt.Errorf("timed out waiting for password prompt: %v", err)
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
case required := <-requiresPassword:
if !required {
// we don't need user input in this case, so we exit early
return
}
}
msh.SendPassword(pw)
}
// ask for user input once
go msh.CheckPasswordRequested(ctx, requiresPassword)
select {
case <-ctx.Done():
err := ctx.Err()
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user")
} else {
errMsg = fmt.Errorf("timed out waiting for password prompt")
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
case required := <-requiresPassword:
if !required {
// we don't need user input in this case, so we exit early
return
}
}
request := &userinput.UserInputRequestType{
QueryText: "Please enter your password",
ResponseType: "text",
Title: "Sudo Password",
Markdown: false,
}
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user")
} else {
errMsg = fmt.Errorf("timed out waiting for user input")
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
}
msh.SendPassword(response.Text)
//error out if requested again
go msh.CheckPasswordRequested(ctx, requiresPassword)
select {
case <-ctx.Done():
err := ctx.Err()
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user")
} else {
errMsg = fmt.Errorf("timed out waiting for password prompt")
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
case required := <-requiresPassword:
if !required {
// we don't need user input in this case, so we exit early
return
}
}
errMsg := fmt.Errorf("*error, incorrect password")
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
}
func (msh *MShellProc) WaitAndSendPassword(pw string) { func (msh *MShellProc) WaitAndSendPassword(pw string) {
var numWaits int var numWaits int
for { for {
@ -1073,29 +1237,34 @@ func (msh *MShellProc) RunInstall() {
msh.WriteToPtyBuffer("*error: cannot install on remote that is already trying to install, cancel current install to try again\n") msh.WriteToPtyBuffer("*error: cannot install on remote that is already trying to install, cancel current install to try again\n")
return return
} }
sapi, err := shellapi.MakeShellApi(packet.ShellType_bash) if remoteCopy.Local {
msh.WriteToPtyBuffer("*error: cannot install on a local remote\n")
return
}
_, err := shellapi.MakeShellApi(packet.ShellType_bash)
if err != nil { if err != nil {
msh.WriteToPtyBuffer("*error: %v\n", err) msh.WriteToPtyBuffer("*error: %v\n", err)
return return
} }
msh.WriteToPtyBuffer("installing mshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName) if msh.Client == nil {
sshOpts := convertSSHOpts(remoteCopy.SSHOpts) client, err := ConnectToClient(remoteCopy.SSHOpts)
sshOpts.SSHErrorsToTty = true
cmdStr := shexec.MakeInstallCommandStr()
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
cmdPty, err := msh.addControllingTty(ecmd)
if err != nil { if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell install command: %w", err) statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.setInstallErrorStatus(statusErr) msh.setInstallErrorStatus(statusErr)
return return
} }
defer func() { msh.WithLock(func() {
if len(ecmd.ExtraFiles) > 0 { msh.Client = client
ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close() })
} }
cmdPty.Close() session, err := msh.Client.NewSession()
}() if err != nil {
go msh.RunPtyReadLoop(cmdPty) statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.setInstallErrorStatus(statusErr)
return
}
installSession := shexec.SessionWrap{Session: session, StartCmd: shexec.MakeInstallCommandStr()}
msh.WriteToPtyBuffer("installing waveshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName)
clientCtx, clientCancelFn := context.WithCancel(context.Background()) clientCtx, clientCancelFn := context.WithCancel(context.Background())
defer clientCancelFn() defer clientCancelFn()
msh.WithLock(func() { msh.WithLock(func() {
@ -1107,7 +1276,7 @@ func (msh *MShellProc) RunInstall() {
msgFn := func(msg string) { msgFn := func(msg string) {
msh.WriteToPtyBuffer("%s", msg) msh.WriteToPtyBuffer("%s", msg)
} }
err = shexec.RunInstallFromCmd(clientCtx, ecmd, true, nil, scbase.MShellBinaryReader, msgFn) err = shexec.RunInstallFromCmd(clientCtx, installSession, true, nil, scbase.MShellBinaryReader, msgFn)
if err == context.Canceled { if err == context.Canceled {
msh.WriteToPtyBuffer("*install canceled\n") msh.WriteToPtyBuffer("*install canceled\n")
msh.WithLock(func() { msh.WithLock(func() {
@ -1130,13 +1299,12 @@ func (msh *MShellProc) RunInstall() {
msh.Err = nil msh.Err = nil
connectMode = msh.Remote.ConnectMode connectMode = msh.Remote.ConnectMode
}) })
msh.WriteToPtyBuffer("successfully installed mshell %s to ~/.mshell\n", scbase.MShellVersion) msh.WriteToPtyBuffer("successfully installed waveshell %s to ~/.mshell\n", scbase.MShellVersion)
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
if connectMode == sstore.ConnectModeStartup || connectMode == sstore.ConnectModeAuto { if connectMode == sstore.ConnectModeStartup || connectMode == sstore.ConnectModeAuto {
// the install was successful, and we don't have a manual connect mode, try to connect // the install was successful, and we don't have a manual connect mode, try to connect
go msh.Launch(true) go msh.Launch(true)
} }
return
} }
func (msh *MShellProc) updateRemoteStateVars(ctx context.Context, remoteId string, initPk *packet.InitPacketType) { func (msh *MShellProc) updateRemoteStateVars(ctx context.Context, remoteId string, initPk *packet.InitPacketType) {
@ -1286,6 +1454,56 @@ func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error
return utilfn.CombineStrArrays(rtn, activeShells), nil return utilfn.CombineStrArrays(rtn, activeShells), nil
} }
func (msh *MShellProc) createWaveshellSession(remoteCopy sstore.RemoteType) (shexec.ConnInterface, error) {
msh.WithLock(func() {
msh.Err = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientDeadline = nil
go msh.NotifyRemoteUpdate()
})
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
if err != nil {
return nil, err
}
var wsSession shexec.ConnInterface
if remoteCopy.SSHOpts.SSHHost == "" && remoteCopy.Local {
cmdStr, err := MakeLocalMShellCommandStr(remoteCopy.IsSudo())
if err != nil {
return nil, fmt.Errorf("cannot find local mshell binary: %v", err)
}
ecmd := shexec.MakeLocalExecCmd(cmdStr, sapi)
var cmdPty *os.File
cmdPty, err = msh.addControllingTty(ecmd)
if err != nil {
return nil, fmt.Errorf("cannot attach controlling tty to mshell command: %v", err)
}
go msh.RunPtyReadLoop(cmdPty)
go msh.WaitAndSendPasswordNew(remoteCopy.SSHOpts.SSHPassword)
wsSession = shexec.CmdWrap{Cmd: ecmd}
} else if msh.Client == nil {
client, err := ConnectToClient(remoteCopy.SSHOpts)
if err != nil {
return nil, fmt.Errorf("ssh cannot connect to client: %w", err)
}
msh.WithLock(func() {
msh.Client = client
})
session, err := client.NewSession()
if err != nil {
return nil, fmt.Errorf("ssh cannot create session: %w", err)
}
wsSession = shexec.SessionWrap{Session: session, StartCmd: MakeServerCommandStr()}
} else {
session, err := msh.Client.NewSession()
if err != nil {
return nil, fmt.Errorf("ssh cannot create session: %w", err)
}
wsSession = shexec.SessionWrap{Session: session, StartCmd: MakeServerCommandStr()}
}
return wsSession, nil
}
// for conditional launch method based on ssh library in use // for conditional launch method based on ssh library in use
// remove once ssh library is stabilized // remove once ssh library is stabilized
type NewLauncher struct{} type NewLauncher struct{}
@ -1306,147 +1524,79 @@ func (NewLauncher) Launch(msh *MShellProc, interactive bool) {
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n") msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
return return
} }
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
if err != nil {
msh.WriteToPtyBuffer("*error, %v\n", err)
return
}
istatus := msh.GetInstallStatus() istatus := msh.GetInstallStatus()
if istatus == StatusConnecting { if istatus == StatusConnecting {
msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n") msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n")
return return
} }
if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 {
msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort)
} else {
msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName) msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName)
} wsSession, err := msh.createWaveshellSession(remoteCopy)
sshOpts := convertSSHOpts(remoteCopy.SSHOpts) if err != nil {
sshOpts.SSHErrorsToTty = true msh.WriteToPtyBuffer("*error, %s\n", err.Error())
if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive { msh.setErrorStatus(err)
sshOpts.BatchMode = true
}
var cproc *shexec.ClientProc
var initPk *packet.InitPacketType
if sshOpts.SSHHost == "" && remoteCopy.Local {
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn()
msh.WithLock(func() { msh.WithLock(func() {
msh.Err = nil msh.Client = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate()
}) })
go msh.watchClientDeadlineTime()
cmdStr, err := MakeLocalMShellCommandStr(remoteCopy.IsSudo())
if err != nil {
msh.WriteToPtyBuffer("*error, cannot find local mshell binary: %v\n", err)
return return
} }
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi) var makeClientCtx context.Context
var cmdPty *os.File var makeClientCancelFn context.CancelFunc
cmdPty, err = msh.addControllingTty(ecmd)
if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
defer func() {
if len(ecmd.ExtraFiles) > 0 {
ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close()
}
}()
go msh.RunPtyReadLoop(cmdPty)
if remoteCopy.SSHOpts.SSHPassword != "" {
go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword)
}
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
} else {
msh.WithLock(func() { msh.WithLock(func() {
msh.Err = nil makeClientCtx, makeClientCancelFn = context.WithCancel(context.Background())
msh.ErrNoInitPk = false msh.MakeClientCancelFn = makeClientCancelFn
msh.Status = StatusConnecting
msh.MakeClientDeadline = nil msh.MakeClientDeadline = nil
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
}) })
var client *ssh.Client
client, err = ConnectToClient(remoteCopy.SSHOpts)
if err != nil {
statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
var session *ssh.Session
session, err = client.NewSession()
if err != nil {
statusErr := fmt.Errorf("ssh cannot create session: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn() defer makeClientCancelFn()
msh.WithLock(func() { cproc, err := shexec.MakeClientProc(makeClientCtx, wsSession)
msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate()
})
go msh.watchClientDeadlineTime()
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()})
}
// TODO check if initPk.State is not nil
var mshellVersion string
var hitDeadline bool
msh.WithLock(func() { msh.WithLock(func() {
msh.MakeClientCancelFn = nil msh.MakeClientCancelFn = nil
if time.Now().After(*msh.MakeClientDeadline) {
hitDeadline = true
}
msh.MakeClientDeadline = nil msh.MakeClientDeadline = nil
if initPk == nil {
msh.ErrNoInitPk = true
}
if initPk != nil {
msh.UName = initPk.UName
mshellVersion = initPk.Version
if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = initPk.Shell
}
msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below
}) })
if err == context.Canceled { if err == context.DeadlineExceeded {
if hitDeadline {
msh.WriteToPtyBuffer("*connect timeout\n") msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout")) msh.setErrorStatus(errors.New("connect timeout"))
} else { msh.WithLock(func() {
msh.Client = nil
})
return
} else if err == context.Canceled {
msh.WriteToPtyBuffer("*forced disconnection\n") msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() { msh.WithLock(func() {
msh.Status = StatusDisconnected msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
}) })
} msh.WithLock(func() {
msh.Client = nil
})
return return
} } else if serr, ok := err.(shexec.WaveshellLaunchError); ok {
if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) { msh.WithLock(func() {
err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion) msh.UName = serr.InitPk.UName
} msh.NeedsMShellUpgrade = true
if err != nil { msh.InitPkShellType = serr.InitPk.Shell
msh.setErrorStatus(err) })
msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err) msh.StateMap.Clear()
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(serr)
go msh.tryAutoInstall() go msh.tryAutoInstall()
return return
} else if err != nil {
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(err)
msh.WithLock(func() {
msh.Client = nil
})
return
} }
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk) msh.WithLock(func() {
msh.UName = cproc.InitPk.UName
msh.InitPkShellType = cproc.InitPk.Shell
msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below
})
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, cproc.InitPk)
msh.WithLock(func() { msh.WithLock(func() {
msh.ServerProc = cproc msh.ServerProc = cproc
msh.Status = StatusConnected msh.Status = StatusConnected
@ -1465,7 +1615,6 @@ func (NewLauncher) Launch(msh *MShellProc, interactive bool) {
go msh.ProcessPackets() go msh.ProcessPackets()
msh.initActiveShells() msh.initActiveShells()
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
return
} }
// for conditional launch method based on ssh library in use // for conditional launch method based on ssh library in use
@ -1536,66 +1685,58 @@ func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) {
if remoteCopy.SSHOpts.SSHPassword != "" { if remoteCopy.SSHOpts.SSHPassword != "" {
go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword) go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword)
} }
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background()) var makeClientCtx context.Context
defer makeClientCancelFn() var makeClientCancelFn context.CancelFunc
msh.WithLock(func() { msh.WithLock(func() {
deadlineTime := time.Now().Add(RemoteConnectTimeout)
makeClientCtx, makeClientCancelFn = context.WithDeadline(context.Background(), deadlineTime)
defer makeClientCancelFn()
msh.Err = nil msh.Err = nil
msh.ErrNoInitPk = false msh.ErrNoInitPk = false
msh.Status = StatusConnecting msh.Status = StatusConnecting
msh.MakeClientCancelFn = makeClientCancelFn msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
}) })
go msh.watchClientDeadlineTime() go msh.watchClientDeadlineTime()
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) cproc, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
// TODO check if initPk.State is not nil
var mshellVersion string
var hitDeadline bool
msh.WithLock(func() { msh.WithLock(func() {
msh.MakeClientCancelFn = nil msh.MakeClientCancelFn = nil
if time.Now().After(*msh.MakeClientDeadline) {
hitDeadline = true
}
msh.MakeClientDeadline = nil msh.MakeClientDeadline = nil
if initPk == nil {
msh.ErrNoInitPk = true
}
if initPk != nil {
msh.UName = initPk.UName
mshellVersion = initPk.Version
if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = initPk.Shell
}
msh.StateMap.Clear() msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below // no notify here, because we'll call notify in either case below
}) })
if err == context.Canceled { if err == context.DeadlineExceeded {
if hitDeadline {
msh.WriteToPtyBuffer("*connect timeout\n") msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout")) msh.setErrorStatus(errors.New("connect timeout"))
} else { return
} else if err == context.Canceled {
msh.WriteToPtyBuffer("*forced disconnection\n") msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() { msh.WithLock(func() {
msh.Status = StatusDisconnected msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
}) })
}
return return
} else if serr, ok := err.(shexec.WaveshellLaunchError); ok {
msh.WithLock(func() {
msh.UName = serr.InitPk.UName
if semver.Compare(serr.InitPk.Version, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
} }
if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) { msh.InitPkShellType = serr.InitPk.Shell
err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion) })
} msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
if err != nil { msh.setErrorStatus(serr)
msh.setErrorStatus(err)
msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err)
go msh.tryAutoInstall() go msh.tryAutoInstall()
return return
} else if err != nil {
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(err)
return
} }
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, cproc.InitPk)
msh.WithLock(func() { msh.WithLock(func() {
msh.ServerProc = cproc msh.ServerProc = cproc
msh.Status = StatusConnected msh.Status = StatusConnected
@ -1614,7 +1755,6 @@ func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) {
go msh.ProcessPackets() go msh.ProcessPackets()
msh.initActiveShells() msh.initActiveShells()
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
return
} }
func (msh *MShellProc) initActiveShells() { func (msh *MShellProc) initActiveShells() {
@ -2099,7 +2239,6 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
} }
} }
scbus.MainUpdateBus.DoUpdate(update) scbus.MainUpdateBus.DoUpdate(update)
return
} }
func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType) { func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType) {
@ -2143,7 +2282,6 @@ func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
msh.WriteToPtyBuffer("cmderr> [remote %s] [error] adding cmderr: %v\n", msh.GetRemoteName(), err) msh.WriteToPtyBuffer("cmderr> [remote %s] [error] adding cmderr: %v\n", msh.GetRemoteName(), err)
return return
} }
return
} }
func (msh *MShellProc) ResetDataPos(ck base.CommandKey) { func (msh *MShellProc) ResetDataPos(ck base.CommandKey) {

View File

@ -110,7 +110,8 @@ func createPublicKeyCallback(sshKeywords *SshKeywords, passphrase string) func()
QueryText: fmt.Sprintf("Enter passphrase for the SSH key: %s", identityFile), QueryText: fmt.Sprintf("Enter passphrase for the SSH key: %s", identityFile),
Title: "Publickey Auth + Passphrase", Title: "Publickey Auth + Passphrase",
} }
ctx, _ := context.WithTimeout(context.Background(), 60*time.Second) ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil { if err != nil {
// this is an error where we actually do want to stop // this is an error where we actually do want to stop
@ -267,6 +268,10 @@ func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerificatio
} }
_, err = f.WriteString(newLine) _, err = f.WriteString(newLine)
if err != nil {
f.Close()
return err
}
return f.Close() return f.Close()
} }