standardize error reporting, rpc gets resp, command get cmderr, other errors are just sent as messages

This commit is contained in:
sawka 2022-07-05 17:45:46 -07:00
parent ef362e5ee9
commit 0c204e8b2b
5 changed files with 190 additions and 167 deletions

View File

@ -13,132 +13,131 @@ import (
"strings"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/server"
"github.com/scripthaus-dev/mshell/pkg/shexec"
"golang.org/x/sys/unix"
)
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
err := shexec.ValidateRunPacket(pk)
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
return
}
fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
return
}
cmd, err := shexec.MakeRunnerExec(pk.CK)
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
return
}
cmdStdin, err := cmd.StdinPipe()
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
return
}
// touch ptyout file (should exist for tailer to work correctly)
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
return
}
ptyOutFd.Close() // just opened to create the file, can close right after
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
return
}
defer runnerOutFd.Close()
cmd.Stdout = runnerOutFd
cmd.Stderr = runnerOutFd
err = cmd.Start()
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
return
}
go func() {
err = packet.SendPacket(cmdStdin, pk)
if err != nil {
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
return
}
cmdStdin.Close()
// func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
// err := shexec.ValidateRunPacket(pk)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
// return
// }
// fileNames, err := base.GetCommandFileNames(pk.CK)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
// return
// }
// cmd, err := shexec.MakeRunnerExec(pk.CK)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
// return
// }
// cmdStdin, err := cmd.StdinPipe()
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
// return
// }
// // touch ptyout file (should exist for tailer to work correctly)
// ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
// return
// }
// ptyOutFd.Close() // just opened to create the file, can close right after
// runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
// return
// }
// defer runnerOutFd.Close()
// cmd.Stdout = runnerOutFd
// cmd.Stderr = runnerOutFd
// err = cmd.Start()
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
// return
// }
// go func() {
// err = packet.SendPacket(cmdStdin, pk)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
// return
// }
// cmdStdin.Close()
// clean up zombies
cmd.Wait()
}()
}
// // clean up zombies
// cmd.Wait()
// }()
// }
func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
err := tailer.AddWatch(pk)
if err != nil {
return err
}
return nil
}
// func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
// err := tailer.AddWatch(pk)
// if err != nil {
// return err
// }
// return nil
// }
func doMain() {
homeDir := base.GetHomeDir()
err := os.Chdir(homeDir)
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
return
}
_, err = base.GetMShellPath()
if err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
packetParser := packet.MakePacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
tailer, err := cmdtail.MakeTailer(sender)
if err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
go tailer.Run()
initPacket := shexec.MakeInitPacket()
sender.SendPacket(initPacket)
for pk := range packetParser.MainCh {
if pk.GetType() == packet.RunPacketStr {
doMainRun(pk.(*packet.RunPacketType), sender)
continue
}
if pk.GetType() == packet.GetCmdPacketStr {
err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender)
if err != nil {
errPk := packet.MakeErrorPacket(err.Error())
sender.SendPacket(errPk)
continue
}
continue
}
if pk.GetType() == packet.CdPacketStr {
cdPacket := pk.(*packet.CdPacketType)
err := os.Chdir(cdPacket.Dir)
resp := packet.MakeResponsePacket(cdPacket.ReqId)
if err != nil {
resp.Error = err.Error()
} else {
resp.Success = true
}
sender.SendPacket(resp)
continue
}
if pk.GetType() == packet.ErrorPacketStr {
errPk := pk.(*packet.ErrorPacketType)
errPk.Error = "invalid packet sent to mshell: " + errPk.Error
sender.SendPacket(errPk)
continue
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
}
}
// func doMain() {
// homeDir := base.GetHomeDir()
// err := os.Chdir(homeDir)
// if err != nil {
// packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
// return
// }
// _, err = base.GetMShellPath()
// if err != nil {
// packet.SendErrorPacket(os.Stdout, err.Error())
// return
// }
// packetParser := packet.MakePacketParser(os.Stdin)
// sender := packet.MakePacketSender(os.Stdout)
// tailer, err := cmdtail.MakeTailer(sender)
// if err != nil {
// packet.SendErrorPacket(os.Stdout, err.Error())
// return
// }
// go tailer.Run()
// initPacket := shexec.MakeInitPacket()
// sender.SendPacket(initPacket)
// for pk := range packetParser.MainCh {
// if pk.GetType() == packet.RunPacketStr {
// doMainRun(pk.(*packet.RunPacketType), sender)
// continue
// }
// if pk.GetType() == packet.GetCmdPacketStr {
// err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender)
// if err != nil {
// errPk := packet.MakeErrorPacket(err.Error())
// sender.SendPacket(errPk)
// continue
// }
// continue
// }
// if pk.GetType() == packet.CdPacketStr {
// cdPacket := pk.(*packet.CdPacketType)
// err := os.Chdir(cdPacket.Dir)
// resp := packet.MakeResponsePacket(cdPacket.ReqId)
// if err != nil {
// resp.Error = err.Error()
// } else {
// resp.Success = true
// }
// sender.SendPacket(resp)
// continue
// }
// if pk.GetType() == packet.ErrorPacketStr {
// errPk := pk.(*packet.ErrorPacketType)
// errPk.Error = "invalid packet sent to mshell: " + errPk.Error
// sender.SendPacket(errPk)
// continue
// }
// sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
// }
// }
func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType, error) {
rpb := packet.MakeRunPacketBuilder()
@ -168,28 +167,24 @@ func handleSingle() {
}
runPacket, err := readFullRunPacket(packetParser)
if err != nil {
ck := base.CommandKey("")
if runPacket != nil {
ck = runPacket.CK
}
sender.SendCKErrorPacket(ck, err.Error())
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
err = shexec.ValidateRunPacket(runPacket)
if err != nil {
sender.SendCKErrorPacket(runPacket.CK, err.Error())
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
if runPacket.Detached {
err := shexec.RunCommandDetached(runPacket, sender)
if err != nil {
sender.SendCKErrorPacket(runPacket.CK, err.Error())
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
} else {
cmd, err := shexec.RunCommandSimple(runPacket, sender)
if err != nil {
sender.SendCKErrorPacket(runPacket.CK, fmt.Sprintf("error running command: %v", err))
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("error running command: %w", err))
return
}
defer cmd.Close()

View File

@ -19,10 +19,11 @@ import (
"github.com/scripthaus-dev/mshell/pkg/base"
)
// remote: init, run, ping, data, cmdstart, cmddone
// remote(detached): init, run, cmdstart
// server: init, run, ping, cmdstart, cmddone, cd, resp, getcmd, untailcmd, cmddata, input, data, [comp]
// all: error, message
// single : <init, >run, >cmddata, >cmddone, <cmdstart, <>data, <>dataack, <cmddone
// single(detached): <init, >run, >cmddata, >cmddone, <cmdstart
// server : <init, >run, >cmddata, >cmddone, <cmdstart, <>data, <>dataack, <cmddone
// >cd, >getcmd, >untailcmd, >input, <resp
// all : <>error, <>message, <>ping, <raw
var GlobalDebug = false
@ -37,7 +38,7 @@ const (
DataEndPacketStr = "dataend"
ResponsePacketStr = "resp" // rpc-response
DonePacketStr = "done"
ErrorPacketStr = "error"
CmdErrorPacketStr = "cmderror"
MessagePacketStr = "message"
GetCmdPacketStr = "getcmd" // rpc
UntailCmdPacketStr = "untailcmd" // rpc
@ -57,7 +58,7 @@ func init() {
TypeStrToFactory[PingPacketStr] = reflect.TypeOf(PingPacketType{})
TypeStrToFactory[ResponsePacketStr] = reflect.TypeOf(ResponsePacketType{})
TypeStrToFactory[DonePacketStr] = reflect.TypeOf(DonePacketType{})
TypeStrToFactory[ErrorPacketStr] = reflect.TypeOf(ErrorPacketType{})
TypeStrToFactory[CmdErrorPacketStr] = reflect.TypeOf(CmdErrorPacketType{})
TypeStrToFactory[MessagePacketStr] = reflect.TypeOf(MessagePacketType{})
TypeStrToFactory[CmdStartPacketStr] = reflect.TypeOf(CmdStartPacketType{})
TypeStrToFactory[CmdDonePacketStr] = reflect.TypeOf(CmdDonePacketType{})
@ -325,8 +326,12 @@ func (p *ResponsePacketType) GetResponseId() string {
return p.RespId
}
func MakeResponsePacket(reqId string) *ResponsePacketType {
return &ResponsePacketType{Type: ResponsePacketStr, RespId: reqId}
func MakeErrorResponsePacket(reqId string, err error) *ResponsePacketType {
return &ResponsePacketType{Type: ResponsePacketStr, RespId: reqId, Error: err.Error()}
}
func MakeResponsePacket(reqId string, data interface{}) *ResponsePacketType {
return &ResponsePacketType{Type: ResponsePacketStr, RespId: reqId, Success: true, Data: data}
}
type RawPacketType struct {
@ -488,21 +493,26 @@ type BarePacketType struct {
Type string `json:"type"`
}
type ErrorPacketType struct {
Type string `json:"type"`
Error string `json:"error"`
type CmdErrorPacketType struct {
Type string `json:"type"`
CK base.CommandKey `json:"ck"`
Error string `json:"error"`
}
func (*ErrorPacketType) GetType() string {
return ErrorPacketStr
func (*CmdErrorPacketType) GetType() string {
return CmdErrorPacketStr
}
func (p *ErrorPacketType) String() string {
func (p *CmdErrorPacketType) GetCK() base.CommandKey {
return p.CK
}
func (p *CmdErrorPacketType) String() string {
return fmt.Sprintf("error[%s]", p.Error)
}
func MakeErrorPacket(errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr}
func MakeCmdErrorPacket(ck base.CommandKey, err error) *CmdErrorPacketType {
return &CmdErrorPacketType{Type: CmdErrorPacketStr, CK: ck, Error: err.Error()}
}
type PacketType interface {
@ -594,8 +604,8 @@ func SendPacket(w io.Writer, packet PacketType) error {
return nil
}
func SendErrorPacket(w io.Writer, errorStr string) error {
return SendPacket(w, MakeErrorPacket(errorStr))
func SendCmdError(w io.Writer, ck base.CommandKey, err error) error {
return SendPacket(w, MakeCmdErrorPacket(ck, err))
}
type PacketSender struct {
@ -679,12 +689,18 @@ func (sender *PacketSender) SendPacket(pk PacketType) error {
return nil
}
func (sender *PacketSender) SendErrorPacket(errVal string) error {
return sender.SendPacket(MakeErrorPacket(errVal))
func (sender *PacketSender) SendCmdError(ck base.CommandKey, err error) error {
return sender.SendPacket(MakeCmdErrorPacket(ck, err))
}
func (sender *PacketSender) SendCKErrorPacket(ck base.CommandKey, errVal string) error {
return sender.SendPacket(MakeErrorPacket(errVal))
func (sender *PacketSender) SendErrorResponse(reqId string, err error) error {
pk := MakeErrorResponsePacket(reqId, err)
return sender.SendPacket(pk)
}
func (sender *PacketSender) SendResponse(reqId string, data interface{}) error {
pk := MakeResponsePacket(reqId, data)
return sender.SendPacket(pk)
}
func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error {
@ -698,8 +714,8 @@ type UnknownPacketReporter interface {
type DefaultUPR struct{}
func (DefaultUPR) UnknownPacket(pk PacketType) {
if pk.GetType() == ErrorPacketStr {
errPacket := pk.(*ErrorPacketType)
if pk.GetType() == CmdErrorPacketStr {
errPacket := pk.(*CmdErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "[error] %s\n", errPacket.Error)
} else if pk.GetType() == RawPacketStr {

View File

@ -8,7 +8,6 @@ package packet
import (
"bufio"
"fmt"
"io"
"strconv"
"strings"
@ -18,6 +17,7 @@ import (
type PacketParser struct {
Lock *sync.Mutex
MainCh chan PacketType
Err error
}
func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser {
@ -46,6 +46,20 @@ func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser {
return rtnParser
}
func (p *PacketParser) GetErr() error {
p.Lock.Lock()
defer p.Lock.Unlock()
return p.Err
}
func (p *PacketParser) SetErr(err error) {
p.Lock.Lock()
defer p.Lock.Unlock()
if p.Err == nil {
p.Err = err
}
}
func MakePacketParser(input io.Reader) *PacketParser {
parser := &PacketParser{
Lock: &sync.Mutex{},
@ -62,8 +76,7 @@ func MakePacketParser(input io.Reader) *PacketParser {
return
}
if err != nil {
errPacket := MakeErrorPacket(fmt.Sprintf("reading packets from input: %v", err))
parser.MainCh <- errPacket
parser.SetErr(err)
return
}
if line == "\n" {
@ -86,9 +99,8 @@ func MakePacketParser(input io.Reader) *PacketParser {
}
pk, err := ParseJsonPacket([]byte(line[bracePos:]))
if err != nil {
errPk := MakeErrorPacket(fmt.Sprintf("parsing packet json from input: %v", err))
parser.MainCh <- errPk
return
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
continue
}
if pk.GetType() == DonePacketStr {
return

View File

@ -71,14 +71,14 @@ func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
m.Sender.SendErrorPacket(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
m.Sender.SendMessage(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
return
}
m.Lock.Lock()
fdContext := m.FdContextMap[ck]
m.Lock.Unlock()
if fdContext == nil {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("no server context for ck '%s'", ck))
m.Sender.SendCmdError(ck, fmt.Errorf("no server context for ck '%s'", ck))
return
}
if pk.GetType() == packet.DataPacketStr {
@ -89,7 +89,7 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
m.Sender.SendPacket(pk)
return
} else {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("invalid packet '%s' received", packet.AsExtType(pk)))
m.Sender.SendCmdError(ck, fmt.Errorf("invalid packet '%s' received", packet.AsExtType(pk)))
return
}
}
@ -114,7 +114,7 @@ func (m *MServer) RemoveFdContext(ck base.CommandKey) {
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
if err := runPacket.CK.Validate("packet"); err != nil {
m.Sender.SendErrorPacket(fmt.Sprintf("server run packets require valid ck: %s", err))
m.Sender.SendResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
}
fdContext := m.MakeServerFdContext(runPacket.CK)
@ -125,7 +125,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
m.Sender.SendPacket(donePk)
}
if err != nil {
m.Sender.SendCKErrorPacket(runPacket.CK, err.Error())
m.Sender.SendErrorResponse(runPacket.ReqId, err)
}
}()
}
@ -183,7 +183,7 @@ func RunServer() (int, error) {
server.ProcessCommandPacket(cmdPk)
continue
}
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsString(pk)))
server.Sender.SendMessage(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsString(pk)))
continue
}
return 0, nil

View File

@ -906,35 +906,35 @@ func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) e
fmt.Printf("sender done! start: %v\n", startPacket)
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdin.Fd()))
if err != nil {
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot dup2 stdin to /dev/null: %w", err))
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdout.Fd()))
if err != nil {
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot dup2 stdin to /dev/null: %w", err))
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
err = unix.Dup2(int(nullFd.Fd()), int(os.Stderr.Fd()))
if err != nil {
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot dup2 stdin to /dev/null: %w", err))
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
cmd.DetachedOutput.SendPacket(startPacket)
}()
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open ptyout file '%s': %v", fileNames.PtyOutFile, err))
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot open ptyout file '%s': %w", fileNames.PtyOutFile, err))
// don't return (command is already running)
}
go func() {
// copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmdPty)
if copyErr != nil {
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("copying pty output to ptyout file: %w", copyErr))
}
}()
go func() {
// copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
if copyFifoErr != nil {
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("reading from stdin fifo: %w", copyFifoErr))
}
}()
donePacket := cmd.WaitForCommand()