From b37f7f722e5d386e434341fbda8be800b598f8f4 Mon Sep 17 00:00:00 2001 From: Cole Lashley Date: Thu, 8 Feb 2024 12:37:23 -0800 Subject: [PATCH] Command to copy file from remote to local (#231) * first pass of copy file * first pass fixing up function * fleshed out copy function, still working on display and parameters * implemented scp like syntax * finished implemententation of copy file - there are still issues * more bug fixes, still running into error * pushing waveshell concurrency and channel fixes - still need to do some qol fixes before merge * aesthetic fixes and removed logs * fixed bug in GetRemoteRuntimeState * formatting small fix * fixed pretty print bytes * added local to local command * small fix removing workaround * added workaround back * added some logs for debug * added some more logs * quick bug fix for update cmd race condition * added fix for race condition * added some more logs for debugging * fixed up logs * added proper fe state for dest parameter * implemented setting status indicator output * first pass at updating status indicators * removed logs and small fix ups * removed whitespace * addressed review comments --- waveshell/pkg/packet/parser.go | 8 +- wavesrv/pkg/cmdrunner/cmdrunner.go | 555 +++++++++++++++++++++++++++++ wavesrv/pkg/cmdrunner/resolver.go | 5 + wavesrv/pkg/remote/remote.go | 17 +- wavesrv/pkg/sstore/dbops.go | 1 + wavesrv/pkg/sstore/sstore.go | 4 + 6 files changed, 579 insertions(+), 11 deletions(-) diff --git a/waveshell/pkg/packet/parser.go b/waveshell/pkg/packet/parser.go index 72631a4ae..a41608663 100644 --- a/waveshell/pkg/packet/parser.go +++ b/waveshell/pkg/packet/parser.go @@ -156,16 +156,12 @@ func (p *PacketParser) trySendRpcResponse(pk PacketType) bool { return false } p.Lock.Lock() - defer p.Lock.Unlock() entry := p.RpcMap[respId] + p.Lock.Unlock() if entry == nil { return false } - // nonblocking send - select { - case entry.RespCh <- respPk: - default: - } + entry.RespCh <- respPk return true } diff --git a/wavesrv/pkg/cmdrunner/cmdrunner.go b/wavesrv/pkg/cmdrunner/cmdrunner.go index 718c83fdc..e4c4d9df9 100644 --- a/wavesrv/pkg/cmdrunner/cmdrunner.go +++ b/wavesrv/pkg/cmdrunner/cmdrunner.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/fs" "log" "net/url" @@ -27,6 +28,7 @@ import ( "github.com/kevinburke/ssh_config" "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/packet" + "github.com/wavetermdev/waveterm/waveshell/pkg/server" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/shexec" @@ -199,6 +201,8 @@ func init() { registerCmdFn("remote:reset", RemoteResetCommand) registerCmdFn("remote:parse", RemoteConfigParseCommand) + registerCmdFn("copyfile", CopyFileCommand) + registerCmdFn("screen:resize", ScreenResizeCommand) registerCmdFn("line", LineCommand) @@ -697,6 +701,8 @@ func EvalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore. newPk, rtnErr := EvalMetaCommand(ctxWithHistory, pk) if rtnErr == nil { update, rtnErr = HandleCommand(ctxWithHistory, newPk) + } else { + return nil, fmt.Errorf("error in Eval Meta Command: %v", rtnErr) } if !resolveBool(pk.Kwargs["nohist"], false) { // TODO should this be "pk" or "newPk" (2nd arg) @@ -1102,6 +1108,553 @@ func SidebarRemoveCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) return &sstore.ModelUpdate{Screens: []*sstore.ScreenType{screen}}, nil } +func prettyPrintByteSize(size int64) string { + gbSize := float64(size) / float64(1073741824) + if gbSize > 1 { + return fmt.Sprintf("%.2f Gigabytes", gbSize) + } + mbSize := float64(size) / float64(1048576) + if mbSize > 1 { + return fmt.Sprintf("%.2f Megabytes", mbSize) + } + kbSize := float64(size) / float64(1024) + if kbSize > 1 { + return fmt.Sprintf("%.2f Kilobytes", kbSize) + } + return fmt.Sprintf("%v Bytes", size) +} + +// this can only be called in a defer func, because recover() only works inside of a defe +func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime time.Time, exitSuccess bool, outputPos int64) { + r := recover() + if r != nil { + panicMsg := fmt.Sprintf("panic: %v", r) + log.Printf("panic: %v\n", panicMsg) + writeStringToPty(ctx, cmd, panicMsg, &outputPos) + } + duration := time.Since(startTime) + cmdStatus := sstore.CmdStatusDone + var exitCode int + if !exitSuccess { + cmdStatus = sstore.CmdStatusError + exitCode = 1 + } + ck := base.MakeCommandKey(cmd.ScreenId, cmd.LineId) + donePk := packet.MakeCmdDonePacket(ck) + donePk.Ts = time.Now().UnixMilli() + donePk.ExitCode = exitCode + donePk.DurationMs = duration.Milliseconds() + update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, donePk, cmdStatus) + if err != nil { + // nothing to do + log.Printf("error updating cmddoneinfo (in openai): %v\n", err) + return + } + sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update) +} + +func checkForWriteReady(ctx context.Context, iter *packet.RpcResponseIter) (string, error) { + readyIf, err := iter.Next(ctx) + if err != nil { + return "", fmt.Errorf("error getting write ready response: %v\r\n", err) + } + readyPk, ok := readyIf.(*packet.WriteFileReadyPacketType) + if !ok { + return "", fmt.Errorf("bad write ready packet received %v", readyIf) + } + if readyPk.Error != "" { + return "", fmt.Errorf("ready error: %v", readyPk.Error) + } + return readyPk.RespId, nil +} + +func checkForWriteFinished(ctx context.Context, iter *packet.RpcResponseIter) error { + doneIf, err := iter.Next(ctx) + if err != nil { + return fmt.Errorf("error while getting done response: %v", err) + } + writeDonePk, ok := doneIf.(*packet.WriteFileDonePacketType) + if !ok { + return fmt.Errorf("bad done packet received: %T", doneIf) + } + if writeDonePk.Error != "" { + return fmt.Errorf("done error: %v", writeDonePk.Error) + } + return nil +} + +func doCopyLocalFileToRemote(ctx context.Context, cmd *sstore.CmdType, remote_msh *remote.MShellProc, localPath string, destPath string, outputPos int64) { + var exitSuccess bool + startTime := time.Now() + defer func() { + deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos) + }() + localFile, err := os.Open(localPath) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error, unable to open file %v: %v\r\n", localFile, localPath), &outputPos) + return + } + defer localFile.Close() + writePk := packet.MakeWriteFilePacket() + writePk.ReqId = uuid.New().String() + writePk.Path = destPath + iter, err := remote_msh.WriteFile(ctx, writePk) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error starting file write: %v\r\n", err), &outputPos) + return + } + defer iter.Close() + _, err = checkForWriteReady(ctx, iter) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Write ready packet error: %v\r\n", err), &outputPos) + return + } + fileStat, err := localFile.Stat() + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("error: could not get file stat: %v", err), &outputPos) + return + } + fileSizeBytes := fileStat.Size() + bytesWritten := int64(0) + lastFileTransferPercentage := float64(0) + fileTransferPercentage := float64(0) + writeStringToPty(ctx, cmd, fmt.Sprintf("Source File Size: %s\r\n", prettyPrintByteSize(fileSizeBytes)), &outputPos) + writeStringToPty(ctx, cmd, "[", &outputPos) + var buffer [server.MaxFileDataPacketSize]byte + bufSlice := buffer[:] + for { + dataPk := packet.MakeFileDataPacket(writePk.ReqId) + bytesRead, err := io.ReadFull(localFile, bufSlice) + if err == io.ErrUnexpectedEOF || err == io.EOF { + dataPk.Eof = true + } else if err != nil { + dataErr := fmt.Sprintf("error reading file data: %v", err) + dataPk.Error = dataErr + remote_msh.SendFileData(dataPk) + writeStringToPty(ctx, cmd, dataErr, &outputPos) + return + } + if bytesRead > 0 { + dataPk.Data = make([]byte, bytesRead) + copy(dataPk.Data, bufSlice[0:bytesRead]) + bytesWritten += int64(len(dataPk.Data)) + fileTransferPercentage = float64(bytesWritten) / float64(fileSizeBytes) + + if fileTransferPercentage-lastFileTransferPercentage > float64(0.05) { + writeStringToPty(ctx, cmd, "-", &outputPos) + lastFileTransferPercentage = fileTransferPercentage + } + } + remote_msh.SendFileData(dataPk) + if dataPk.Eof { + break + } + } + err = checkForWriteFinished(ctx, iter) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Write finished packet error %v", err), &outputPos) + return + } + writeStringToPty(ctx, cmd, "] done. \r\n", &outputPos) + writeStringToPty(ctx, cmd, fmt.Sprintf("Finished transferring. Transferred %v bytes\r\n", fileSizeBytes), &outputPos) + exitSuccess = true +} + +func getStatusBarString(filePercentageInt int) string { + statusBarString := "\x1b[2k\r[" + for count := 0; count < 20; count++ { + if (filePercentageInt - count*5) > 0 { + statusBarString += "-" + } else { + statusBarString += " " + } + } + if filePercentageInt < 100 { + statusBarString += fmt.Sprintf("] %v%%", filePercentageInt) + } else { + statusBarString += "]" + } + return statusBarString +} + +func doCopyRemoteFileToRemote(ctx context.Context, cmd *sstore.CmdType, sourceMsh *remote.MShellProc, destMsh *remote.MShellProc, sourcePath string, destPath string, outputPos int64) { + var exitSuccess bool + startTime := time.Now() + defer func() { + deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos) + }() + streamPk := packet.MakeStreamFilePacket() + streamPk.ReqId = uuid.New().String() + streamPk.Path = sourcePath + sourceStreamIter, err := sourceMsh.StreamFile(ctx, streamPk) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error getting file data packet: %v\r\n", err), &outputPos) + return + } + defer sourceStreamIter.Close() + respIf, err := sourceStreamIter.Next(ctx) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error getting next packet: %v\r\n", err), &outputPos) + return + } + resp, ok := respIf.(*packet.StreamFileResponseType) + if !ok { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error in getting packet response: %v\r\n", err), &outputPos) + return + } + if resp == nil || resp.Error != "" { + writeStringToPty(ctx, cmd, fmt.Sprintf("Response packet has error: %v\r\n", err), &outputPos) + return + } + fileSizeBytes := resp.Info.Size + if fileSizeBytes == 0 { + writeStringToPty(ctx, cmd, "Source file does not exist or is empty - exiting\r\n", &outputPos) + return + } + writeStringToPty(ctx, cmd, fmt.Sprintf("Source File Size: %v\r\n", prettyPrintByteSize(fileSizeBytes)), &outputPos) + writePk := packet.MakeWriteFilePacket() + writePk.ReqId = uuid.New().String() + writePk.Path = destPath + destWriteIter, err := destMsh.WriteFile(ctx, writePk) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error starting file write: %v\r\n", err), &outputPos) + return + } + defer destWriteIter.Close() + _, err = checkForWriteReady(ctx, destWriteIter) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Write ready packet error: %v\r\n", err), &outputPos) + return + } + bytesWritten := int64(0) + lastFilePercentageInt := int(0) + fileTransferPercentage := float64(0) + writeStringToPty(ctx, cmd, "[", &outputPos) + for { + dataPkIf, err := sourceStreamIter.Next(ctx) + if err != nil { + log.Printf("error in read-file while getting data: %v\n", err) + return + } + if dataPkIf == nil { + break + } + dataPk, ok := dataPkIf.(*packet.FileDataPacketType) + if !ok { + writeStringToPty(ctx, cmd, fmt.Sprintf("error in read-file, invalid data packet type: %T\r\n", dataPkIf), &outputPos) + return + } + if dataPk.Error != "" { + writeStringToPty(ctx, cmd, fmt.Sprintf("in read-file, data packet error: %s\r\n", dataPk.Error), &outputPos) + return + } + writeDataPk := packet.MakeFileDataPacket(writePk.ReqId) + writeDataPk.Eof = dataPk.Eof + writeDataPk.Error = dataPk.Error + writeDataPk.Type = dataPk.Type + writeDataPk.Data = make([]byte, int64(len(dataPk.Data))) + copy(writeDataPk.Data, dataPk.Data) + err = destMsh.SendFileData(writeDataPk) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("error sending file to dest: %v\r\n", err), &outputPos) + return + } + bytesWritten += int64(len(dataPk.Data)) + fileTransferPercentage = float64(bytesWritten) / float64(fileSizeBytes) + filePercentageInt := int(fileTransferPercentage * 100) + if filePercentageInt-lastFilePercentageInt > 5 { + statusBarString := getStatusBarString(filePercentageInt) + writeStringToPty(ctx, cmd, statusBarString, &outputPos) + lastFilePercentageInt = filePercentageInt + } + } + err = checkForWriteFinished(ctx, destWriteIter) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("\r\nWrite finished packet error %v", err), &outputPos) + return + } + writeStringToPty(ctx, cmd, getStatusBarString(100), &outputPos) + writeStringToPty(ctx, cmd, " done. \r\n", &outputPos) + writeStringToPty(ctx, cmd, fmt.Sprintf("Finished transferring. Transferred %v bytes\r\n", bytesWritten), &outputPos) + exitSuccess = true +} + +func doCopyLocalFileToLocal(ctx context.Context, cmd *sstore.CmdType, sourcePath string, destPath string, outputPos int64) { + var exitSuccess bool + var bytesWritten int64 + startTime := time.Now() + defer func() { + deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos) + }() + sourceFile, err := os.Open(sourcePath) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("error opening source file %v", err), &outputPos) + return + } + defer sourceFile.Close() + sourceFileStat, err := sourceFile.Stat() + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("error getting filestat %v", err), &outputPos) + return + } + fileSizeBytes := sourceFileStat.Size() + writeStringToPty(ctx, cmd, fmt.Sprintf("Source File Size: %v\r\n", prettyPrintByteSize(fileSizeBytes)), &outputPos) + destFile, err := os.Create(destPath) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("error creating dest file %v", err), &outputPos) + return + } + defer destFile.Close() + bytesWritten, err = io.Copy(destFile, sourceFile) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("error copying files %v", err), &outputPos) + return + } + writeStringToPty(ctx, cmd, fmt.Sprintf("Finished transferring. Transferred %v bytes\r\n", bytesWritten), &outputPos) + exitSuccess = true +} + +func doCopyRemoteFileToLocal(ctx context.Context, cmd *sstore.CmdType, remote_msh *remote.MShellProc, sourcePath string, localPath string, outputPos int64) { + var exitSuccess bool + startTime := time.Now() + defer func() { + deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos) + }() + streamPk := packet.MakeStreamFilePacket() + streamPk.ReqId = uuid.New().String() + streamPk.Path = sourcePath + iter, err := remote_msh.StreamFile(ctx, streamPk) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error getting file data packet: %v\r\n", err), &outputPos) + return + } + defer iter.Close() + respIf, err := iter.Next(ctx) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error getting next packet: %v\r\n", err), &outputPos) + return + } + resp, ok := respIf.(*packet.StreamFileResponseType) + if !ok { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error in getting packet response: %v\r\n", err), &outputPos) + return + } + if resp == nil || resp.Error != "" { + writeStringToPty(ctx, cmd, fmt.Sprintf("Response packet has error: %v\r\n", err), &outputPos) + return + } + fileSizeBytes := resp.Info.Size + if fileSizeBytes == 0 { + writeStringToPty(ctx, cmd, "Source file doesn't exist or file is empty - exiting\r\n", &outputPos) + return + } + writeStringToPty(ctx, cmd, fmt.Sprintf("Source File Size: %s\r\n", prettyPrintByteSize(fileSizeBytes)), &outputPos) + localFile, err := os.Create(localPath) + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Error creating file on local %v\r\n", err), &outputPos) + return + } + defer localFile.Close() + bytesWritten := int64(0) + lastFileTransferPercentage := float64(0) + fileTransferPercentage := float64(0) + writeStringToPty(ctx, cmd, "[", &outputPos) + for { + dataPkIf, err := iter.Next(ctx) + if err != nil { + log.Printf("error in read-file while getting data: %v\n", err) + return + } + if dataPkIf == nil { + break + } + dataPk, ok := dataPkIf.(*packet.FileDataPacketType) + if !ok { + writeStringToPty(ctx, cmd, fmt.Sprintf("error in read-file, invalid data packet type: %T\r\n", dataPkIf), &outputPos) + return + } + if dataPk.Error != "" { + writeStringToPty(ctx, cmd, fmt.Sprintf("in read-file, data packet error: %s", dataPk.Error), &outputPos) + return + } + localFile.Write(dataPk.Data) + bytesWritten += int64(len(dataPk.Data)) + fileTransferPercentage = float64(bytesWritten) / float64(fileSizeBytes) + + if fileTransferPercentage-lastFileTransferPercentage > float64(0.05) { + writeStringToPty(ctx, cmd, "-", &outputPos) + lastFileTransferPercentage = fileTransferPercentage + } + } + writeStringToPty(ctx, cmd, "] done. \r\n", &outputPos) + writeStringToPty(ctx, cmd, fmt.Sprintf("Finished transferring. Transferred %v bytes\n", fileSizeBytes), &outputPos) + exitSuccess = true +} + +func writeStringToPty(ctx context.Context, cmd *sstore.CmdType, outputString string, outputPos *int64) { + outBytes := []byte(outputString) + update, err := sstore.AppendToCmdPtyBlob(ctx, cmd.ScreenId, cmd.LineId, outBytes, *outputPos) + *outputPos += int64(len(outBytes)) + if err != nil { + log.Printf("error writing to pty: %v", err) + } + sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update) + err = sstore.SetStatusIndicatorLevel(ctx, cmd.ScreenId, sstore.StatusIndicatorLevel_Output, false) + if err != nil { + // This is not a fatal error, so just log it + log.Printf("error setting status indicator level to output in writeStringToPty: %v\n", err) + } +} + +func parseCopyFileParam(info string) (remote string, path string, err error) { + stringsList := strings.Split(info, ":") + if len(stringsList) == 1 { + // use cur remote + return "", stringsList[0], nil + } else if len(stringsList) == 2 { + remote := strings.Trim(stringsList[0], "[] ") + return remote, stringsList[1], nil + } else { + return "error", "error", fmt.Errorf("malformed arguments") + } +} + +func CopyFileCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) { + if len(pk.Args) == 0 { + return nil, fmt.Errorf("usage: /copyfile [file to copy] local=[path to copy to on local]") + } + ids, err := resolveUiIds(ctx, pk, R_Screen|R_Session|R_RemoteConnected) + if err != nil { + return nil, fmt.Errorf("failed to resolve connected remote id: %v", err) + } + sourceInfo := pk.Args[0] + sourceRemote, sourcePath, err := parseCopyFileParam(sourceInfo) + var sourceRemoteId *ResolvedRemote + var destRemoteId *ResolvedRemote + if err != nil { + return nil, fmt.Errorf("error: malformed arguments - usage: [remote]:path ") + } else if sourceRemote == "" { + // use cur remote + sourceRemote = ConnectedRemote + sourceRemoteId = ids.Remote + if ids.Remote.RemoteCopy.IsLocal() { + sourceRemote = LocalRemote + } + } else { + pk.Kwargs["remote"] = sourceRemote + sourceIds, err := resolveUiIds(ctx, pk, R_Remote) + if err != nil { + return nil, fmt.Errorf("error resolving remote id %v", err) + } + sourceRemoteId = sourceIds.Remote + } + destInfo := pk.Args[1] + destRemote, destPath, err := parseCopyFileParam(destInfo) + if err != nil { + return nil, fmt.Errorf("error: malformed arguments - usage: [remote]:path ") + } else if destRemote == "" { + destRemote = ConnectedRemote + destRemoteId = ids.Remote + if ids.Remote.RemoteCopy.IsLocal() { + destRemote = LocalRemote + } + } else { + pk.Kwargs["remote"] = destRemote + destIds, err := resolveUiIds(ctx, pk, R_Remote) + if err != nil { + return nil, fmt.Errorf("error resolving remote id %v", err) + } + destRemoteId = destIds.Remote + } + if destPath == "" { + return nil, fmt.Errorf("error: malformed arguments - usage: [remote]:path ") + } + + var sourceFullPath string + var destFullPath string + sourceMsh := sourceRemoteId.MShell + if sourceMsh == nil { + return nil, fmt.Errorf("failure getting source remote mshell") + } + sourceRRState := sourceMsh.GetRemoteRuntimeState() + sourcePathWithHome, err := sourceRRState.ExpandHomeDir(sourcePath) + if err != nil { + return nil, fmt.Errorf("expand home dir err: %v", err) + } + sourceFullPath = sourcePathWithHome + if (sourceRemote == ConnectedRemote || sourceRemote == LocalRemote) && !filepath.IsAbs(sourcePathWithHome) && sourceRemoteId.FeState != nil { + sourceCwd := sourceRemoteId.FeState["cwd"] + if sourceCwd != "" { + sourceFullPath = filepath.Join(sourceCwd, sourcePathWithHome) + } + } + if destPath[len(destPath)-1:] == "/" { + sourceFileName := filepath.Base(sourceFullPath) + destPath = filepath.Join(destPath, sourceFileName) + } + destMsh := destRemoteId.MShell + if destMsh == nil { + return nil, fmt.Errorf("failure getting dest remote mshell") + } + destRRState := destMsh.GetRemoteRuntimeState() + destPathWithHome, err := destRRState.ExpandHomeDir(destPath) + if err != nil { + return nil, fmt.Errorf("expand home dir err: %v", err) + } + destFullPath = destPathWithHome + if (destRemote == ConnectedRemote || destRemote == LocalRemote) && !filepath.IsAbs(destPathWithHome) && destRemoteId.FeState != nil { + destCwd := destRemoteId.FeState["cwd"] + if destCwd != "" { + destFullPath = filepath.Join(destCwd, destPathWithHome) + } + } + var outputPos int64 + outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath) + termopts := sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize} + cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), termopts) + writeStringToPty(ctx, cmd, outputStr, &outputPos) + if err != nil { + // TODO tricky error since the command was a success, but we can't show the output + return nil, err + } + update, err := addLineForCmd(ctx, "/copy file", false, ids, cmd, "", nil) + if err != nil { + // TODO tricky error since the command was a success, but we can't show the output + return nil, err + } + update.Interactive = pk.Interactive + if destRemote != ConnectedRemote && destRemoteId != nil && !destRemoteId.RState.IsConnected() { + writeStringToPty(ctx, cmd, fmt.Sprintf("Attempting to autoconnect to remote %v\r\n", destRemote), &outputPos) + err = destRemoteId.MShell.TryAutoConnect() + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Couldn't connect to remote %v\r\n", sourceRemote), &outputPos) + } else { + writeStringToPty(ctx, cmd, "Auto connect successful\r\n", &outputPos) + } + } + if sourceRemote != LocalRemote && sourceRemoteId != nil && !sourceRemoteId.RState.IsConnected() { + writeStringToPty(ctx, cmd, fmt.Sprintf("Attempting to autoconnect to remote %v\r\n", sourceRemote), &outputPos) + err = sourceRemoteId.MShell.TryAutoConnect() + if err != nil { + writeStringToPty(ctx, cmd, fmt.Sprintf("Couldn't connect to remote %v\r\n", sourceRemote), &outputPos) + } else { + writeStringToPty(ctx, cmd, "Auto connect successful\r\n", &outputPos) + } + } + sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update) + update = &sstore.ModelUpdate{} + if destRemote == LocalRemote && sourceRemote == LocalRemote { + go doCopyLocalFileToLocal(context.Background(), cmd, sourceFullPath, destFullPath, outputPos) + } else if destRemote == LocalRemote && sourceRemote != LocalRemote { + go doCopyRemoteFileToLocal(context.Background(), cmd, sourceMsh, sourceFullPath, destFullPath, outputPos) + } else if destRemote != LocalRemote && sourceRemote == LocalRemote { + go doCopyLocalFileToRemote(context.Background(), cmd, destMsh, sourceFullPath, destFullPath, outputPos) + } else if destRemote != LocalRemote && sourceRemote != LocalRemote { + go doCopyRemoteFileToRemote(context.Background(), cmd, sourceMsh, destMsh, sourceFullPath, destFullPath, outputPos) + } + return update, nil +} + func RemoteInstallCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) { ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) if err != nil { @@ -2463,6 +3016,7 @@ func addLineForCmd(ctx context.Context, metaCmd string, shouldFocus bool, ids re Cmd: cmd, Screens: []*sstore.ScreenType{screen}, } + sstore.IncrementNumRunningCmds_Update(update, cmd.ScreenId, 1) updateHistoryContext(ctx, rtnLine, cmd, cmd.FeState) return update, nil } @@ -3422,6 +3976,7 @@ func LineRestartCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) ( NoCreateCmdPtyFile: true, } cmd, callback, err := remote.RunCommand(ctx, rcOpts, runPacket) + sstore.IncrementNumRunningCmds(cmd.ScreenId, 1) if callback != nil { defer callback() } diff --git a/wavesrv/pkg/cmdrunner/resolver.go b/wavesrv/pkg/cmdrunner/resolver.go index ee62436b2..265cf1936 100644 --- a/wavesrv/pkg/cmdrunner/resolver.go +++ b/wavesrv/pkg/cmdrunner/resolver.go @@ -24,6 +24,11 @@ const ( R_RemoteConnected = 16 ) +const ( + ConnectedRemote = "connected" + LocalRemote = "local" +) + type resolvedIds struct { SessionId string ScreenId string diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index 3e5bf06d2..2860f8de8 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -50,6 +50,7 @@ const RemoteTermRows = 8 const RemoteTermCols = 80 const PtyReadBufSize = 100 const RemoteConnectTimeout = 15 * time.Second +const RpcIterChannelSize = 100 var envVarsToStrip map[string]bool = map[string]bool{ "PROMPT": true, @@ -665,7 +666,12 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState { if vars["remoteuser"] == "root" || vars["sudo"] == "1" { vars["isroot"] = "1" } - state.RemoteVars = vars + varsCopy := make(map[string]string) + // deep copy so that concurrent calls don't collide on this data + for key, value := range vars { + varsCopy[key] = value + } + state.RemoteVars = varsCopy state.ActiveShells = msh.StateMap.GetShells() return state } @@ -1203,6 +1209,10 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.Sh return ssPk, nil } +func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) { + return msh.PacketRpcIter(ctx, writePk) +} + func (msh *MShellProc) StreamFile(ctx context.Context, streamPk *packet.StreamFilePacketType) (*packet.RpcResponseIter, error) { return msh.PacketRpcIter(ctx, streamPk) } @@ -1886,7 +1896,6 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru RunPacket: runPacket, }) - go pushNumRunningCmdsUpdate(&runPacket.CK, 1) return cmd, func() { removeCmdWait(runPacket.CK) }, nil } @@ -1925,7 +1934,7 @@ func (msh *MShellProc) PacketRpcIter(ctx context.Context, pk packet.RpcPacketTyp return nil, fmt.Errorf("PacketRpc passed nil packet") } reqId := pk.GetReqId() - msh.ServerProc.Output.RegisterRpc(reqId) + msh.ServerProc.Output.RegisterRpcSz(reqId, RpcIterChannelSize) err := msh.ServerProc.Input.SendPacketCtx(ctx, pk) if err != nil { return nil, err @@ -2064,8 +2073,6 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) { // fall-through (nothing to do) } } - - go pushNumRunningCmdsUpdate(&donePk.CK, -1) sstore.MainBus.SendUpdate(update) return } diff --git a/wavesrv/pkg/sstore/dbops.go b/wavesrv/pkg/sstore/dbops.go index 5e394275c..0b8ca3d79 100644 --- a/wavesrv/pkg/sstore/dbops.go +++ b/wavesrv/pkg/sstore/dbops.go @@ -950,6 +950,7 @@ func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.C // This is not a fatal error, so just log it log.Printf("error setting status indicator level after done packet: %v\n", err) } + IncrementNumRunningCmds_Update(update, screenId, -1) return update, nil } diff --git a/wavesrv/pkg/sstore/sstore.go b/wavesrv/pkg/sstore/sstore.go index 9cd62e50a..b9f2c5cf9 100644 --- a/wavesrv/pkg/sstore/sstore.go +++ b/wavesrv/pkg/sstore/sstore.go @@ -1100,6 +1100,10 @@ type RemoteType struct { OpenAIOpts *OpenAIOptsType `json:"openaiopts,omitempty"` } +func (r *RemoteType) IsLocal() bool { + return r.Local && !r.IsSudo() +} + func (r *RemoteType) IsSudo() bool { return r.SSHOpts != nil && r.SSHOpts.IsSudo }