diff --git a/waveshell/pkg/server/server.go b/waveshell/pkg/server/server.go index 89e588a54..192b39dc7 100644 --- a/waveshell/pkg/server/server.go +++ b/waveshell/pkg/server/server.go @@ -748,6 +748,10 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("invalid shellstate version: %w", err)) return } + if runPacket.Command == "wave:testerror" { + m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("test error")) + return + } ecmd, err := shexec.MakeMShellSingleCmd() if err != nil { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err)) diff --git a/waveshell/pkg/utilfn/ansi.go b/waveshell/pkg/utilfn/ansi.go index 276989816..34d897f6b 100644 --- a/waveshell/pkg/utilfn/ansi.go +++ b/waveshell/pkg/utilfn/ansi.go @@ -10,3 +10,7 @@ func AnsiResetColor() string { func AnsiGreenColor() string { return "\033[32m" } + +func AnsiRedColor() string { + return "\033[31m" +} diff --git a/wavesrv/cmd/main-server.go b/wavesrv/cmd/main-server.go index 41e37de71..d62a93a11 100644 --- a/wavesrv/cmd/main-server.go +++ b/wavesrv/cmd/main-server.go @@ -1018,6 +1018,8 @@ func main() { wlog.GlobalSubsystem = base.ProcessType_WaveSrv wlog.LogConsumer = wlog.LogWithLogger + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + if len(os.Args) >= 2 && os.Args[1] == "--test" { log.Printf("running test fn\n") err := test() diff --git a/wavesrv/pkg/cmdrunner/cmdrunner.go b/wavesrv/pkg/cmdrunner/cmdrunner.go index 62056e560..c16e65027 100644 --- a/wavesrv/pkg/cmdrunner/cmdrunner.go +++ b/wavesrv/pkg/cmdrunner/cmdrunner.go @@ -1242,12 +1242,13 @@ func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime tim exitCode = 1 } ck := base.MakeCommandKey(cmd.ScreenId, cmd.LineId) - donePk := packet.MakeCmdDonePacket(ck) - donePk.Ts = time.Now().UnixMilli() - donePk.ExitCode = exitCode - donePk.DurationMs = duration.Milliseconds() + doneInfo := sstore.CmdDoneDataValues{ + Ts: time.Now().UnixMilli(), + ExitCode: exitCode, + DurationMs: duration.Milliseconds(), + } update := scbus.MakeUpdatePacket() - err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus) + err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, doneInfo, cmdStatus) if err != nil { // nothing to do log.Printf("error updating cmddoneinfo: %v\n", err) @@ -2623,12 +2624,13 @@ func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt exitCode = 1 } ck := base.MakeCommandKey(cmd.ScreenId, cmd.LineId) - donePk := packet.MakeCmdDonePacket(ck) - donePk.Ts = time.Now().UnixMilli() - donePk.ExitCode = exitCode - donePk.DurationMs = duration.Milliseconds() + doneInfo := sstore.CmdDoneDataValues{ + Ts: time.Now().UnixMilli(), + ExitCode: exitCode, + DurationMs: duration.Milliseconds(), + } update := scbus.MakeUpdatePacket() - err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus) + err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, doneInfo, cmdStatus) if err != nil { // nothing to do log.Printf("error updating cmddoneinfo (in openai): %v\n", err) @@ -2783,12 +2785,13 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, clientId string, opts *sstore exitCode = 1 } ck := base.MakeCommandKey(cmd.ScreenId, cmd.LineId) - donePk := packet.MakeCmdDonePacket(ck) - donePk.Ts = time.Now().UnixMilli() - donePk.ExitCode = exitCode - donePk.DurationMs = duration.Milliseconds() + doneInfo := sstore.CmdDoneDataValues{ + Ts: time.Now().UnixMilli(), + ExitCode: exitCode, + DurationMs: duration.Milliseconds(), + } update := scbus.MakeUpdatePacket() - err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus) + err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, doneInfo, cmdStatus) if err != nil { // nothing to do log.Printf("error updating cmddoneinfo (in openai): %v\n", err) diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index 0328905cd..d01b7dc43 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -2012,30 +2012,33 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru removeCmdWait(runPacket.CK) } }() - + runningCmdType := &RunCmdType{ + CK: runPacket.CK, + SessionId: sessionId, + ScreenId: screenId, + RemotePtr: remotePtr, + RunPacket: runPacket, + EphemeralOpts: rcOpts.EphemeralOpts, + } // RegisterRpc + WaitForResponse is used to get any waveshell side errors // waveshell will either return an error (in a ResponsePacketType) or a CmdStartPacketType msh.ServerProc.Output.RegisterRpc(runPacket.ReqId) - err = shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket) - if err != nil { - return nil, nil, fmt.Errorf("sending run packet to remote: %w", err) - } - rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId) - if rtnPk == nil { - return nil, nil, ctx.Err() - } - startPk, ok := rtnPk.(*packet.CmdStartPacketType) - if !ok { - respPk, ok := rtnPk.(*packet.ResponsePacketType) - if !ok { - return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) - } - if respPk.Error != "" { - return nil, nil, respPk.Err() - } - return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) - } - + go func() { + startPk, err := msh.sendRunPacketAndReturnResponse(runPacket) + runCmdUpdateFn(runPacket.CK, func() { + if err != nil { + // the cmd failed (never started) + msh.handleCmdStartError(runningCmdType, err) + return + } + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + err = sstore.UpdateCmdStartInfo(ctx, runPacket.CK, startPk.Pid, startPk.MShellPid) + if err != nil { + log.Printf("error updating cmd start info (in remote.RunCommand): %v\n", err) + } + }) + }() // command is now successfully runnning status := sstore.CmdStatusRunning if runPacket.Detached { @@ -2051,8 +2054,6 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru StatePtr: *statePtr, TermOpts: makeTermOpts(runPacket), Status: status, - CmdPid: startPk.Pid, - RemotePid: startPk.MShellPid, ExitCode: 0, DurationMs: 0, RunOut: nil, @@ -2065,18 +2066,36 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru return nil, nil, fmt.Errorf("cannot create local ptyout file for running command: %v", err) } } - runningCmdType := &RunCmdType{ - CK: runPacket.CK, - SessionId: sessionId, - ScreenId: screenId, - RemotePtr: remotePtr, - RunPacket: runPacket, - EphemeralOpts: rcOpts.EphemeralOpts} msh.AddRunningCmd(runningCmdType) - return cmd, func() { removeCmdWait(runPacket.CK) }, nil } +// no context because it is called as a goroutine +func (msh *MShellProc) sendRunPacketAndReturnResponse(runPacket *packet.RunPacketType) (*packet.CmdStartPacketType, error) { + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket) + if err != nil { + return nil, fmt.Errorf("sending run packet to remote: %w", err) + } + rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId) + if rtnPk == nil { + return nil, ctx.Err() + } + startPk, ok := rtnPk.(*packet.CmdStartPacketType) + if !ok { + respPk, ok := rtnPk.(*packet.ResponsePacketType) + if !ok { + return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) + } + if respPk.Error != "" { + return nil, respPk.Err() + } + return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) + } + return startPk, nil +} + // helper func to construct the proper error given what information we have func makePSCLineError(existingPSC base.CommandKey, line *sstore.LineType, lineErr error) error { if lineErr != nil { @@ -2342,6 +2361,42 @@ func (msh *MShellProc) updateRIWithFinalState(ctx context.Context, rct *RunCmdTy return sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, newStateDiff) } +func (msh *MShellProc) handleCmdStartError(rct *RunCmdType, startErr error) { + if rct == nil { + log.Printf("handleCmdStartError, no rct\n") + return + } + defer msh.RemoveRunningCmd(rct.CK) + if rct.EphemeralOpts != nil { + // nothing to do for ephemeral commands besides remove the running command + return + } + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + update := scbus.MakeUpdatePacket() + errOutputStr := fmt.Sprintf("%serror: %v%s\n", utilfn.AnsiRedColor(), startErr, utilfn.AnsiResetColor()) + msh.writeToCmdPtyOut(ctx, rct.ScreenId, rct.CK.GetCmdId(), []byte(errOutputStr)) + doneInfo := sstore.CmdDoneDataValues{ + Ts: time.Now().UnixMilli(), + ExitCode: 1, + DurationMs: 0, + } + err := sstore.UpdateCmdDoneInfo(ctx, update, rct.CK, doneInfo, sstore.CmdStatusError) + if err != nil { + log.Printf("error updating cmddone info (in handleCmdStartError): %v\n", err) + return + } + screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, rct.CK.GetGroupId(), rct.CK.GetCmdId()) + if err != nil { + log.Printf("error trying to update screen focus type (in handleCmdDonePacket): %v\n", err) + // fall-through (nothing to do) + } + if screen != nil { + update.AddUpdate(*screen) + } + scbus.MainUpdateBus.DoUpdate(update) +} + func (msh *MShellProc) handleCmdDonePacket(rct *RunCmdType, donePk *packet.CmdDonePacketType) { if rct == nil { log.Printf("cmddone packet received, but no running command found for it %q\n", donePk.CK) @@ -2359,7 +2414,12 @@ func (msh *MShellProc) handleCmdDonePacket(rct *RunCmdType, donePk *packet.CmdDo update := scbus.MakeUpdatePacket() if rct.EphemeralOpts == nil { // only update DB for non-ephemeral commands - err := sstore.UpdateCmdDoneInfo(ctx, update, donePk.CK, donePk, sstore.CmdStatusDone) + cmdDoneInfo := sstore.CmdDoneDataValues{ + Ts: donePk.Ts, + ExitCode: donePk.ExitCode, + DurationMs: donePk.DurationMs, + } + err := sstore.UpdateCmdDoneInfo(ctx, update, donePk.CK, cmdDoneInfo, sstore.CmdStatusDone) if err != nil { log.Printf("error updating cmddone info (in handleCmdDonePacket): %v\n", err) return @@ -2453,6 +2513,19 @@ func (msh *MShellProc) ResetDataPos(ck base.CommandKey) { msh.DataPosMap.Delete(ck) } +func (msh *MShellProc) writeToCmdPtyOut(ctx context.Context, screenId string, lineId string, data []byte) error { + dataPos := msh.DataPosMap.Get(base.MakeCommandKey(screenId, lineId)) + update, err := sstore.AppendToCmdPtyBlob(ctx, screenId, lineId, data, dataPos) + if err != nil { + return err + } + utilfn.IncSyncMap(msh.DataPosMap, base.MakeCommandKey(screenId, lineId), int64(len(data))) + if update != nil { + scbus.MainUpdateBus.DoScreenUpdate(screenId, update) + } + return nil +} + func (msh *MShellProc) handleDataPacket(rct *RunCmdType, dataPk *packet.DataPacketType, dataPosMap *utilfn.SyncMap[base.CommandKey, int64]) { if rct == nil { log.Printf("error handling data packet: no running cmd found %s\n", dataPk.CK) diff --git a/wavesrv/pkg/sstore/dbops.go b/wavesrv/pkg/sstore/dbops.go index ca35f87a7..f28a3da61 100644 --- a/wavesrv/pkg/sstore/dbops.go +++ b/wavesrv/pkg/sstore/dbops.go @@ -743,10 +743,21 @@ func UpdateCmdForRestart(ctx context.Context, ck base.CommandKey, ts int64, cmdP }) } -func UpdateCmdDoneInfo(ctx context.Context, update *scbus.ModelUpdatePacketType, ck base.CommandKey, donePk *packet.CmdDonePacketType, status string) error { - if donePk == nil { - return fmt.Errorf("invalid cmddone packet") - } +func UpdateCmdStartInfo(ctx context.Context, ck base.CommandKey, cmdPid int, mshellPid int) error { + return WithTx(ctx, func(tx *TxWrap) error { + query := `UPDATE cmd SET cmdpid = ?, remotepid = ? WHERE screenid = ? AND lineid = ?` + tx.Exec(query, cmdPid, mshellPid, ck.GetGroupId(), lineIdFromCK(ck)) + return nil + }) +} + +type CmdDoneDataValues struct { + Ts int64 + ExitCode int + DurationMs int64 +} + +func UpdateCmdDoneInfo(ctx context.Context, update *scbus.ModelUpdatePacketType, ck base.CommandKey, donePk CmdDoneDataValues, status string) error { if ck.IsEmpty() { return fmt.Errorf("cannot update cmddoneinfo, empty ck") }