fixup connect/disconnect to deal with connecting state. use context to cancel remote that is in connecting state

This commit is contained in:
sawka 2022-09-16 12:28:09 -07:00
parent 06e3a86f53
commit fad718d571
2 changed files with 35 additions and 8 deletions

View File

@ -421,6 +421,9 @@ func RemoteConnectCommand(ctx context.Context, pk *scpacket.FeCommandPacketType)
if ids.Remote.RState.IsConnected() { if ids.Remote.RState.IsConnected() {
return sstore.InfoMsgUpdate("remote %q already connected (no action taken)", ids.Remote.DisplayName), nil return sstore.InfoMsgUpdate("remote %q already connected (no action taken)", ids.Remote.DisplayName), nil
} }
if ids.Remote.RState.Status == remote.StatusConnecting {
return sstore.InfoMsgUpdate("remote %q is already trying to connect (no action taken)", ids.Remote.DisplayName), nil
}
go ids.Remote.MShell.Launch() go ids.Remote.MShell.Launch()
return sstore.InfoMsgUpdate("remote %q reconnecting", ids.Remote.DisplayName), nil return sstore.InfoMsgUpdate("remote %q reconnecting", ids.Remote.DisplayName), nil
} }
@ -431,7 +434,8 @@ func RemoteDisconnectCommand(ctx context.Context, pk *scpacket.FeCommandPacketTy
return nil, err return nil, err
} }
force := resolveBool(pk.Kwargs["force"], false) force := resolveBool(pk.Kwargs["force"], false)
if !ids.Remote.RState.IsConnected() && !force { status := ids.Remote.MShell.GetStatus()
if status != remote.StatusConnected && status != remote.StatusConnecting {
return sstore.InfoMsgUpdate("remote %q already disconnected (no action taken)", ids.Remote.DisplayName), nil return sstore.InfoMsgUpdate("remote %q already disconnected (no action taken)", ids.Remote.DisplayName), nil
} }
numCommands := ids.Remote.MShell.GetNumRunningCommands() numCommands := ids.Remote.MShell.GetNumRunningCommands()

View File

@ -72,6 +72,7 @@ type MShellProc struct {
Err error Err error
ControllingPty *os.File ControllingPty *os.File
PtyBuffer *circbuf.Buffer PtyBuffer *circbuf.Buffer
MakeClientCancelFn context.CancelFunc
RunningCmds []base.CommandKey RunningCmds []base.CommandKey
} }
@ -95,6 +96,12 @@ func (state RemoteRuntimeState) IsConnected() bool {
return state.Status == StatusConnected return state.Status == StatusConnected
} }
func (msh *MShellProc) GetStatus() string {
msh.Lock.Lock()
defer msh.Lock.Unlock()
return msh.Status
}
func (state RemoteRuntimeState) GetBaseDisplayName() string { func (state RemoteRuntimeState) GetBaseDisplayName() string {
if state.RemoteAlias != "" { if state.RemoteAlias != "" {
return state.RemoteAlias return state.RemoteAlias
@ -522,6 +529,10 @@ func (msh *MShellProc) Disconnect() {
if msh.ServerProc != nil { if msh.ServerProc != nil {
msh.ServerProc.Close() msh.ServerProc.Close()
} }
if msh.MakeClientCancelFn != nil {
msh.MakeClientCancelFn()
msh.MakeClientCancelFn = nil
}
} }
func (msh *MShellProc) GetRemoteName() string { func (msh *MShellProc) GetRemoteName() string {
@ -598,6 +609,11 @@ func (msh *MShellProc) Launch() {
msh.WriteToPtyBuffer("cannot launch archived remote\n") msh.WriteToPtyBuffer("cannot launch archived remote\n")
return return
} }
curStatus := msh.GetStatus()
if curStatus == StatusConnecting {
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
return
}
msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName) msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName)
sshOpts := convertSSHOpts(remoteCopy.SSHOpts) sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
sshOpts.SSHErrorsToTty = true sshOpts.SSHErrorsToTty = true
@ -615,15 +631,22 @@ func (msh *MShellProc) Launch() {
} }
}() }()
go msh.RunPtyReadLoop(cmdPty) go msh.RunPtyReadLoop(cmdPty)
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn()
msh.WithLock(func() { msh.WithLock(func() {
msh.Status = StatusConnecting msh.Status = StatusConnecting
msh.MakeClientCancelFn = makeClientCancelFn
go msh.NotifyRemoteUpdate() go msh.NotifyRemoteUpdate()
}) })
cproc, uname, err := shexec.MakeClientProc(ecmd) cproc, uname, err := shexec.MakeClientProc(makeClientCtx, ecmd)
msh.WithLock(func() { msh.WithLock(func() {
msh.UName = uname msh.UName = uname
msh.MakeClientCancelFn = nil
// no notify here, because we'll call notify in either case below // no notify here, because we'll call notify in either case below
}) })
if err == context.Canceled {
err = fmt.Errorf("forced disconnection")
}
if err != nil { if err != nil {
msh.setErrorStatus(err) msh.setErrorStatus(err)
msh.WriteToPtyBuffer("*error connecting to remote (uname=%q): %v\n", msh.UName, err) msh.WriteToPtyBuffer("*error connecting to remote (uname=%q): %v\n", msh.UName, err)