Integrate SSH Library with Waveshell Installation/Auto-update (#322)

* refactor launch code to integrate install easier

The previous set up of launch was difficult to navigate. This makes it
much clearer which will make the auto install flow easier to manage.

* feat: integrate auto install into new ssh setup

This change makes it possible to auto install using the ssh library
instead of making a call to the ssh cli command. This will auto install
if the installed waveshell version is incorrect or cannot be found.

* chore: clean up some lints for sshclient

There was a context that didn't have it's cancel function deferred and
an error that wasn't being handle. They're fixed now.

* fix: disconnect client if requested or launch fail

A recent commit made it so a client remained part of the MShellProc
after being disconnected. This is undesireable since a manual
disconnection indicates that the user will need to enter their
credentials again if required. Similarly, if the launch fails with an
error, the expectation is that credentials will have to be entered
again.

* fix: use legacy timer for the time being

The legacy timer frustrates me because it adds a lot of state to the
MShellProc struct that is complicated to manage. But it currently works,
so I will be keeping it for the time being.

* fix: change separator between remoteref and name

With the inclusion of the port number in the canonical id, the :
separator between the remoteref and remote name causes problems if the
port is parsed instead. This changes it to a # in order to avoid this
conflict.

* fix: check for null when closing extra files

It is possible for the list of extra files to contain null files. This
change ensures the null files will not be erroneously closed.

* fix: change connecting method to show port once

With port added to the canonicalname, it no longer makes sense to append
the port afterward.

* feat: use user input modal for sudo connection

The sudo connection used to have a unique way of entering a password.
This change provides an alternative method using the user input modal
that the other connection methods use. It does not work perfectly with
this revision, but the basic building blocks are in place. It needs a
few timer updates to be complete.

* fix: remove old timer to prevent conflicts with it

With this change the old timer is no longer needed. It is not fully
removed yet, but it is disabled so as to not get in the way.
Additionally, error handling has been slightly improved.

There is still a bug where an incorrect password prints a new password
prompt after the error message. That needs to be fixed in the future.
This commit is contained in:
Sylvie Crowe 2024-02-29 11:37:03 -08:00 committed by GitHub
parent 75be66bada
commit 6c115716b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 451 additions and 236 deletions

View File

@ -1,9 +1,7 @@
github.com/aws/aws-sdk-go-v2/service/s3 v1.27.11 h1:3/gm/JTX9bX8CpzTgIlrtYpB3EVBDxyg/GY/QdcIEZw=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs=
github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4=
github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
github.com/wavetermdev/ssh_config v0.0.0-20240109090616-36c8da3d7376 h1:tFhJgTu7lgd+hldLfPSzDCoWUpXI8wHKR3rxq5jTLkQ=
github.com/wavetermdev/ssh_config v0.0.0-20240109090616-36c8da3d7376/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@ -652,7 +652,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
}
cproc, _, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd})
cproc, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd})
if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
return

View File

@ -57,9 +57,28 @@ func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser,
}
func (cw CmdWrap) Start() error {
defer func() {
for _, extraFile := range cw.Cmd.ExtraFiles {
if extraFile != nil {
extraFile.Close()
}
}
}()
return cw.Cmd.Start()
}
func (cw CmdWrap) StdinPipe() (io.WriteCloser, error) {
return cw.Cmd.StdinPipe()
}
func (cw CmdWrap) StdoutPipe() (io.ReadCloser, error) {
return cw.Cmd.StdoutPipe()
}
func (cw CmdWrap) StderrPipe() (io.ReadCloser, error) {
return cw.Cmd.StderrPipe()
}
type SessionWrap struct {
Session *ssh.Session
StartCmd string
@ -101,12 +120,35 @@ func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadClos
return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
}
func (sw SessionWrap) StdinPipe() (io.WriteCloser, error) {
return sw.Session.StdinPipe()
}
func (sw SessionWrap) StdoutPipe() (io.ReadCloser, error) {
stdoutReader, err := sw.Session.StdoutPipe()
if err != nil {
return nil, err
}
return io.NopCloser(stdoutReader), nil
}
func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
stderrReader, err := sw.Session.StderrPipe()
if err != nil {
return nil, err
}
return io.NopCloser(stderrReader), nil
}
type ConnInterface interface {
Kill()
Wait() error
Sender() (*packet.PacketSender, io.WriteCloser, error)
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
Start() error
StdinPipe() (io.WriteCloser, error)
StdoutPipe() (io.ReadCloser, error)
StderrPipe() (io.ReadCloser, error)
}
type ClientProc struct {
@ -120,20 +162,44 @@ type ClientProc struct {
Output *packet.PacketParser
}
type WaveshellLaunchError struct {
InitPk *packet.InitPacketType
}
func (wle WaveshellLaunchError) Error() string {
if wle.InitPk.NotFound {
return "waveshell client not found"
} else if semver.MajorMinor(wle.InitPk.Version) != semver.MajorMinor(base.MShellVersion) {
return fmt.Sprintf("invalid remote waveshell version '%s', must be '=%s'", wle.InitPk.Version, semver.MajorMinor(base.MShellVersion))
}
return fmt.Sprintf("invalid waveshell: init packet=%v", *wle.InitPk)
}
type InvalidPacketError struct {
InvalidPk *packet.PacketType
}
func (ipe InvalidPacketError) Error() string {
if ipe.InvalidPk == nil {
return "no init packet received from waveshell client"
}
return fmt.Sprintf("invalid packet received from waveshell client: %s", packet.AsString(*ipe.InvalidPk))
}
// returns (clientproc, initpk, error)
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) {
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, error) {
startTs := time.Now()
sender, inputWriter, err := ecmd.Sender()
if err != nil {
return nil, nil, err
return nil, err
}
packetParser, stdoutReader, stderrReader, err := ecmd.Parser()
if err != nil {
return nil, nil, err
return nil, err
}
err = ecmd.Start()
if err != nil {
return nil, nil, fmt.Errorf("running local client: %w", err)
return nil, fmt.Errorf("running local client: %w", err)
}
cproc := &ClientProc{
Cmd: ecmd,
@ -150,29 +216,27 @@ func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *pack
case pk = <-packetParser.MainCh:
case <-ctx.Done():
cproc.Close()
return nil, nil, ctx.Err()
return nil, ctx.Err()
}
if pk != nil {
if pk.GetType() != packet.InitPacketStr {
cproc.Close()
return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
}
initPk := pk.(*packet.InitPacketType)
if initPk.NotFound {
cproc.Close()
return nil, initPk, fmt.Errorf("mshell client not found")
}
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
cproc.Close()
return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
}
cproc.InitPk = initPk
}
if cproc.InitPk == nil {
if pk == nil {
cproc.Close()
return nil, nil, fmt.Errorf("no init packet received from mshell client")
return nil, InvalidPacketError{}
}
return cproc, cproc.InitPk, nil
if pk.GetType() != packet.InitPacketStr {
cproc.Close()
return nil, InvalidPacketError{InvalidPk: &pk}
}
initPk := pk.(*packet.InitPacketType)
if initPk.NotFound {
cproc.Close()
return nil, WaveshellLaunchError{InitPk: initPk}
}
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
cproc.Close()
return nil, WaveshellLaunchError{InitPk: initPk}
}
cproc.InitPk = initPk
return cproc, nil
}
func (cproc *ClientProc) Close() {

View File

@ -437,6 +437,16 @@ func MakeMShellSingleCmd() (*exec.Cmd, error) {
return ecmd, nil
}
func MakeLocalExecCmd(cmdStr string, sapi shellapi.ShellApi) *exec.Cmd {
homeDir, _ := os.UserHomeDir() // ignore error
if homeDir == "" {
homeDir = "/"
}
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", cmdStr)
ecmd.Dir = homeDir
return ecmd
}
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) *exec.Cmd {
remoteCommand = strings.TrimSpace(remoteCommand)
if opts.SSHHost == "" {
@ -587,7 +597,7 @@ func sendMShellBinary(input io.WriteCloser, mshellStream io.Reader) {
}()
}
func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshellStream io.Reader, mshellReaderFn MShellBinaryReaderFn, msgFn func(string)) error {
func RunInstallFromCmd(ctx context.Context, ecmd ConnInterface, tryDetect bool, mshellStream io.Reader, mshellReaderFn MShellBinaryReaderFn, msgFn func(string)) error {
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return fmt.Errorf("creating stdin pipe: %v", err)

View File

@ -443,7 +443,7 @@ func parseFullRemoteRef(fullRemoteRef string) (string, string, string, error) {
if strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]") {
fullRemoteRef = fullRemoteRef[1 : len(fullRemoteRef)-1]
}
fields := strings.Split(fullRemoteRef, ":")
fields := strings.Split(fullRemoteRef, "#")
if len(fields) > 3 {
return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef)
}

View File

@ -37,6 +37,7 @@ import (
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/userinput"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
@ -83,12 +84,6 @@ else
fi
`
const WaveshellServerRunOnlyFmt = `
PATH=$PATH:~/.mshell;
[%PINGPACKET%]
mshell-[%VERSION%] --server
`
func MakeLocalMShellCommandStr(isSudo bool) (string, error) {
mshellPath, err := scbase.LocalMShellBinaryPath()
if err != nil {
@ -107,13 +102,6 @@ func MakeServerCommandStr() string {
return rtn
}
func MakeServerRunOnlyCommandStr() string {
rtn := strings.ReplaceAll(WaveshellServerRunOnlyFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket)
return rtn
}
const (
StatusConnected = sstore.RemoteStatus_Connected
StatusConnecting = sstore.RemoteStatus_Connecting
@ -175,6 +163,7 @@ type MShellProc struct {
RunningCmds map[base.CommandKey]RunCmdType
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name]
launcher Launcher // for conditional launch method based on ssh library in use. remove once ssh library is stabilized
Client *ssh.Client
}
type RunCmdType struct {
@ -602,7 +591,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
if msh.Status == StatusConnecting {
state.WaitingForPassword = msh.isWaitingForPassword_nolock()
if msh.MakeClientDeadline != nil {
state.ConnectTimeout = int((*msh.MakeClientDeadline).Sub(time.Now()) / time.Second)
state.ConnectTimeout = int(time.Until(*msh.MakeClientDeadline) / time.Second)
if state.ConnectTimeout < 0 {
state.ConnectTimeout = 0
}
@ -734,7 +723,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
func SendRemoteInput(pk *scpacket.RemoteInputPacketType) error {
data, err := base64.StdEncoding.DecodeString(pk.InputData64)
if err != nil {
return fmt.Errorf("cannot decode base64: %v\n", err)
return fmt.Errorf("cannot decode base64: %v", err)
}
msh := GetRemoteById(pk.RemoteId)
if msh == nil {
@ -892,6 +881,7 @@ func (msh *MShellProc) Disconnect(force bool) {
defer msh.Lock.Unlock()
if msh.ServerProc != nil {
msh.ServerProc.Close()
msh.Client = nil
}
if msh.MakeClientCancelFn != nil {
msh.MakeClientCancelFn()
@ -989,6 +979,45 @@ func (msh *MShellProc) isWaitingForPassphrase_nolock() bool {
return pwIdx != -1
}
func (msh *MShellProc) RunPasswordReadLoop(cmdPty *os.File) {
buf := make([]byte, PtyReadBufSize)
for {
_, readErr := cmdPty.Read(buf)
if readErr == io.EOF {
return
}
if readErr != nil {
msh.WriteToPtyBuffer("*error reading from controlling-pty: %v\n", readErr)
return
}
var newIsWaiting bool
msh.WithLock(func() {
newIsWaiting = msh.isWaitingForPassword_nolock()
})
if newIsWaiting {
break
}
}
request := &userinput.UserInputRequestType{
QueryText: "Please enter your password",
ResponseType: "text",
Title: "Sudo Password",
Markdown: false,
}
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
msh.WriteToPtyBuffer("*error timed out waiting for password: %v\n", err)
return
}
msh.WithLock(func() {
curOffset := msh.PtyBuffer.TotalWritten()
msh.PtyBuffer.Write([]byte(response.Text))
sendRemotePtyUpdate(msh.Remote.RemoteId, curOffset, []byte(response.Text))
})
}
func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) {
buf := make([]byte, PtyReadBufSize)
var isWaiting bool
@ -1015,6 +1044,141 @@ func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) {
}
}
func (msh *MShellProc) CheckPasswordRequested(ctx context.Context, requiresPassword chan bool) {
for {
msh.WithLock(func() {
if msh.isWaitingForPassword_nolock() {
select {
case requiresPassword <- true:
default:
}
return
}
if msh.Status != StatusConnecting {
select {
case requiresPassword <- false:
default:
}
return
}
})
select {
case <-ctx.Done():
return
default:
}
time.Sleep(100 * time.Millisecond)
}
}
func (msh *MShellProc) SendPassword(pw string) {
msh.WithLock(func() {
if msh.ControllingPty == nil {
return
}
pwBytes := []byte(pw + "\r")
msh.writeToPtyBuffer_nolock("~[sent password]\r\n")
_, err := msh.ControllingPty.Write(pwBytes)
if err != nil {
msh.writeToPtyBuffer_nolock("*cannot write password to controlling pty: %v\n", err)
}
})
}
func (msh *MShellProc) WaitAndSendPasswordNew(pw string) {
requiresPassword := make(chan bool, 1)
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
if pw != "" {
// do an extra check with the saved password if it is provided
go msh.CheckPasswordRequested(ctx, requiresPassword)
select {
case <-ctx.Done():
err := ctx.Err()
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user: %v", err)
} else {
errMsg = fmt.Errorf("timed out waiting for password prompt: %v", err)
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
case required := <-requiresPassword:
if !required {
// we don't need user input in this case, so we exit early
return
}
}
msh.SendPassword(pw)
}
// ask for user input once
go msh.CheckPasswordRequested(ctx, requiresPassword)
select {
case <-ctx.Done():
err := ctx.Err()
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user")
} else {
errMsg = fmt.Errorf("timed out waiting for password prompt")
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
case required := <-requiresPassword:
if !required {
// we don't need user input in this case, so we exit early
return
}
}
request := &userinput.UserInputRequestType{
QueryText: "Please enter your password",
ResponseType: "text",
Title: "Sudo Password",
Markdown: false,
}
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user")
} else {
errMsg = fmt.Errorf("timed out waiting for user input")
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
}
msh.SendPassword(response.Text)
//error out if requested again
go msh.CheckPasswordRequested(ctx, requiresPassword)
select {
case <-ctx.Done():
err := ctx.Err()
var errMsg error
if err == context.Canceled {
errMsg = fmt.Errorf("canceled by the user")
} else {
errMsg = fmt.Errorf("timed out waiting for password prompt")
}
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
return
case required := <-requiresPassword:
if !required {
// we don't need user input in this case, so we exit early
return
}
}
errMsg := fmt.Errorf("*error, incorrect password")
msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error())
msh.setErrorStatus(errMsg)
}
func (msh *MShellProc) WaitAndSendPassword(pw string) {
var numWaits int
for {
@ -1073,29 +1237,34 @@ func (msh *MShellProc) RunInstall() {
msh.WriteToPtyBuffer("*error: cannot install on remote that is already trying to install, cancel current install to try again\n")
return
}
sapi, err := shellapi.MakeShellApi(packet.ShellType_bash)
if remoteCopy.Local {
msh.WriteToPtyBuffer("*error: cannot install on a local remote\n")
return
}
_, err := shellapi.MakeShellApi(packet.ShellType_bash)
if err != nil {
msh.WriteToPtyBuffer("*error: %v\n", err)
return
}
msh.WriteToPtyBuffer("installing mshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName)
sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
sshOpts.SSHErrorsToTty = true
cmdStr := shexec.MakeInstallCommandStr()
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
cmdPty, err := msh.addControllingTty(ecmd)
if msh.Client == nil {
client, err := ConnectToClient(remoteCopy.SSHOpts)
if err != nil {
statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.setInstallErrorStatus(statusErr)
return
}
msh.WithLock(func() {
msh.Client = client
})
}
session, err := msh.Client.NewSession()
if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell install command: %w", err)
statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.setInstallErrorStatus(statusErr)
return
}
defer func() {
if len(ecmd.ExtraFiles) > 0 {
ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close()
}
cmdPty.Close()
}()
go msh.RunPtyReadLoop(cmdPty)
installSession := shexec.SessionWrap{Session: session, StartCmd: shexec.MakeInstallCommandStr()}
msh.WriteToPtyBuffer("installing waveshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName)
clientCtx, clientCancelFn := context.WithCancel(context.Background())
defer clientCancelFn()
msh.WithLock(func() {
@ -1107,7 +1276,7 @@ func (msh *MShellProc) RunInstall() {
msgFn := func(msg string) {
msh.WriteToPtyBuffer("%s", msg)
}
err = shexec.RunInstallFromCmd(clientCtx, ecmd, true, nil, scbase.MShellBinaryReader, msgFn)
err = shexec.RunInstallFromCmd(clientCtx, installSession, true, nil, scbase.MShellBinaryReader, msgFn)
if err == context.Canceled {
msh.WriteToPtyBuffer("*install canceled\n")
msh.WithLock(func() {
@ -1130,13 +1299,12 @@ func (msh *MShellProc) RunInstall() {
msh.Err = nil
connectMode = msh.Remote.ConnectMode
})
msh.WriteToPtyBuffer("successfully installed mshell %s to ~/.mshell\n", scbase.MShellVersion)
msh.WriteToPtyBuffer("successfully installed waveshell %s to ~/.mshell\n", scbase.MShellVersion)
go msh.NotifyRemoteUpdate()
if connectMode == sstore.ConnectModeStartup || connectMode == sstore.ConnectModeAuto {
// the install was successful, and we don't have a manual connect mode, try to connect
go msh.Launch(true)
}
return
}
func (msh *MShellProc) updateRemoteStateVars(ctx context.Context, remoteId string, initPk *packet.InitPacketType) {
@ -1286,6 +1454,56 @@ func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error
return utilfn.CombineStrArrays(rtn, activeShells), nil
}
func (msh *MShellProc) createWaveshellSession(remoteCopy sstore.RemoteType) (shexec.ConnInterface, error) {
msh.WithLock(func() {
msh.Err = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientDeadline = nil
go msh.NotifyRemoteUpdate()
})
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
if err != nil {
return nil, err
}
var wsSession shexec.ConnInterface
if remoteCopy.SSHOpts.SSHHost == "" && remoteCopy.Local {
cmdStr, err := MakeLocalMShellCommandStr(remoteCopy.IsSudo())
if err != nil {
return nil, fmt.Errorf("cannot find local mshell binary: %v", err)
}
ecmd := shexec.MakeLocalExecCmd(cmdStr, sapi)
var cmdPty *os.File
cmdPty, err = msh.addControllingTty(ecmd)
if err != nil {
return nil, fmt.Errorf("cannot attach controlling tty to mshell command: %v", err)
}
go msh.RunPtyReadLoop(cmdPty)
go msh.WaitAndSendPasswordNew(remoteCopy.SSHOpts.SSHPassword)
wsSession = shexec.CmdWrap{Cmd: ecmd}
} else if msh.Client == nil {
client, err := ConnectToClient(remoteCopy.SSHOpts)
if err != nil {
return nil, fmt.Errorf("ssh cannot connect to client: %w", err)
}
msh.WithLock(func() {
msh.Client = client
})
session, err := client.NewSession()
if err != nil {
return nil, fmt.Errorf("ssh cannot create session: %w", err)
}
wsSession = shexec.SessionWrap{Session: session, StartCmd: MakeServerCommandStr()}
} else {
session, err := msh.Client.NewSession()
if err != nil {
return nil, fmt.Errorf("ssh cannot create session: %w", err)
}
wsSession = shexec.SessionWrap{Session: session, StartCmd: MakeServerCommandStr()}
}
return wsSession, nil
}
// for conditional launch method based on ssh library in use
// remove once ssh library is stabilized
type NewLauncher struct{}
@ -1306,147 +1524,79 @@ func (NewLauncher) Launch(msh *MShellProc, interactive bool) {
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
return
}
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
if err != nil {
msh.WriteToPtyBuffer("*error, %v\n", err)
return
}
istatus := msh.GetInstallStatus()
if istatus == StatusConnecting {
msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n")
return
}
if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 {
msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort)
} else {
msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName)
}
sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
sshOpts.SSHErrorsToTty = true
if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive {
sshOpts.BatchMode = true
}
var cproc *shexec.ClientProc
var initPk *packet.InitPacketType
if sshOpts.SSHHost == "" && remoteCopy.Local {
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn()
msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName)
wsSession, err := msh.createWaveshellSession(remoteCopy)
if err != nil {
msh.WriteToPtyBuffer("*error, %s\n", err.Error())
msh.setErrorStatus(err)
msh.WithLock(func() {
msh.Err = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate()
msh.Client = nil
})
go msh.watchClientDeadlineTime()
cmdStr, err := MakeLocalMShellCommandStr(remoteCopy.IsSudo())
if err != nil {
msh.WriteToPtyBuffer("*error, cannot find local mshell binary: %v\n", err)
return
}
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
var cmdPty *os.File
cmdPty, err = msh.addControllingTty(ecmd)
if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
defer func() {
if len(ecmd.ExtraFiles) > 0 {
ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close()
}
}()
go msh.RunPtyReadLoop(cmdPty)
if remoteCopy.SSHOpts.SSHPassword != "" {
go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword)
}
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
} else {
msh.WithLock(func() {
msh.Err = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientDeadline = nil
go msh.NotifyRemoteUpdate()
})
var client *ssh.Client
client, err = ConnectToClient(remoteCopy.SSHOpts)
if err != nil {
statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
var session *ssh.Session
session, err = client.NewSession()
if err != nil {
statusErr := fmt.Errorf("ssh cannot create session: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn()
msh.WithLock(func() {
msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate()
})
go msh.watchClientDeadlineTime()
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()})
return
}
// TODO check if initPk.State is not nil
var mshellVersion string
var hitDeadline bool
var makeClientCtx context.Context
var makeClientCancelFn context.CancelFunc
msh.WithLock(func() {
makeClientCtx, makeClientCancelFn = context.WithCancel(context.Background())
msh.MakeClientCancelFn = makeClientCancelFn
msh.MakeClientDeadline = nil
go msh.NotifyRemoteUpdate()
})
defer makeClientCancelFn()
cproc, err := shexec.MakeClientProc(makeClientCtx, wsSession)
msh.WithLock(func() {
msh.MakeClientCancelFn = nil
if time.Now().After(*msh.MakeClientDeadline) {
hitDeadline = true
}
msh.MakeClientDeadline = nil
if initPk == nil {
msh.ErrNoInitPk = true
}
if initPk != nil {
msh.UName = initPk.UName
mshellVersion = initPk.Version
if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = initPk.Shell
}
})
if err == context.DeadlineExceeded {
msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout"))
msh.WithLock(func() {
msh.Client = nil
})
return
} else if err == context.Canceled {
msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() {
msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate()
})
msh.WithLock(func() {
msh.Client = nil
})
return
} else if serr, ok := err.(shexec.WaveshellLaunchError); ok {
msh.WithLock(func() {
msh.UName = serr.InitPk.UName
msh.NeedsMShellUpgrade = true
msh.InitPkShellType = serr.InitPk.Shell
})
msh.StateMap.Clear()
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(serr)
go msh.tryAutoInstall()
return
} else if err != nil {
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(err)
msh.WithLock(func() {
msh.Client = nil
})
return
}
msh.WithLock(func() {
msh.UName = cproc.InitPk.UName
msh.InitPkShellType = cproc.InitPk.Shell
msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below
})
if err == context.Canceled {
if hitDeadline {
msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout"))
} else {
msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() {
msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate()
})
}
return
}
if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) {
err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion)
}
if err != nil {
msh.setErrorStatus(err)
msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err)
go msh.tryAutoInstall()
return
}
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, cproc.InitPk)
msh.WithLock(func() {
msh.ServerProc = cproc
msh.Status = StatusConnected
@ -1465,7 +1615,6 @@ func (NewLauncher) Launch(msh *MShellProc, interactive bool) {
go msh.ProcessPackets()
msh.initActiveShells()
go msh.NotifyRemoteUpdate()
return
}
// for conditional launch method based on ssh library in use
@ -1536,66 +1685,58 @@ func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) {
if remoteCopy.SSHOpts.SSHPassword != "" {
go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword)
}
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn()
var makeClientCtx context.Context
var makeClientCancelFn context.CancelFunc
msh.WithLock(func() {
deadlineTime := time.Now().Add(RemoteConnectTimeout)
makeClientCtx, makeClientCancelFn = context.WithDeadline(context.Background(), deadlineTime)
defer makeClientCancelFn()
msh.Err = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate()
})
go msh.watchClientDeadlineTime()
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
// TODO check if initPk.State is not nil
var mshellVersion string
var hitDeadline bool
cproc, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
msh.WithLock(func() {
msh.MakeClientCancelFn = nil
if time.Now().After(*msh.MakeClientDeadline) {
hitDeadline = true
}
msh.MakeClientDeadline = nil
if initPk == nil {
msh.ErrNoInitPk = true
}
if initPk != nil {
msh.UName = initPk.UName
mshellVersion = initPk.Version
if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = initPk.Shell
}
msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below
})
if err == context.Canceled {
if hitDeadline {
msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout"))
} else {
msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() {
msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate()
})
}
if err == context.DeadlineExceeded {
msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout"))
return
}
if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) {
err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion)
}
if err != nil {
msh.setErrorStatus(err)
msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err)
} else if err == context.Canceled {
msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() {
msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate()
})
return
} else if serr, ok := err.(shexec.WaveshellLaunchError); ok {
msh.WithLock(func() {
msh.UName = serr.InitPk.UName
if semver.Compare(serr.InitPk.Version, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = serr.InitPk.Shell
})
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(serr)
go msh.tryAutoInstall()
return
} else if err != nil {
msh.WriteToPtyBuffer("*error, %s\n", serr.Error())
msh.setErrorStatus(err)
return
}
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, cproc.InitPk)
msh.WithLock(func() {
msh.ServerProc = cproc
msh.Status = StatusConnected
@ -1614,7 +1755,6 @@ func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) {
go msh.ProcessPackets()
msh.initActiveShells()
go msh.NotifyRemoteUpdate()
return
}
func (msh *MShellProc) initActiveShells() {
@ -2099,7 +2239,6 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
}
}
scbus.MainUpdateBus.DoUpdate(update)
return
}
func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType) {
@ -2143,7 +2282,6 @@ func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
msh.WriteToPtyBuffer("cmderr> [remote %s] [error] adding cmderr: %v\n", msh.GetRemoteName(), err)
return
}
return
}
func (msh *MShellProc) ResetDataPos(ck base.CommandKey) {

View File

@ -110,7 +110,8 @@ func createPublicKeyCallback(sshKeywords *SshKeywords, passphrase string) func()
QueryText: fmt.Sprintf("Enter passphrase for the SSH key: %s", identityFile),
Title: "Publickey Auth + Passphrase",
}
ctx, _ := context.WithTimeout(context.Background(), 60*time.Second)
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
// this is an error where we actually do want to stop
@ -267,6 +268,10 @@ func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerificatio
}
_, err = f.WriteString(newLine)
if err != nil {
f.Close()
return err
}
return f.Close()
}