waveterm/pkg/cmdrunner/resolver.go
2022-08-26 16:24:07 -07:00

294 lines
8.0 KiB
Go

package cmdrunner
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/scripthaus-dev/sh2-server/pkg/remote"
"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
// Resolution flags for resolveIds. Each flag occupies its own bit so they can
// be OR-ed together. The plain flags make a missing value an error; the *Opt
// variants attempt resolution but tolerate a missing value.
const (
	R_Session = 1 << iota // 1
	R_Screen              // 2
	R_Window              // 4
	R_Remote              // 8
	R_SessionOpt          // 16
	R_ScreenOpt           // 32
	R_WindowOpt           // 64
	R_RemoteOpt           // 128
)
// resolvedIds bundles the values produced by resolveIds. Fields that were not
// requested through the rtype bitmask keep their zero value.
type resolvedIds struct {
	// Session, screen and window identifiers; "" when not requested or not supplied.
	SessionId, ScreenId, WindowId string

	// RemotePtr identifies the resolved remote (remote id plus optional name).
	RemotePtr sstore.RemotePtrType
	// RemoteState is the state pointer handed back by resolveRemote.
	RemoteState *sstore.RemoteState
	// RemoteDisplayName is the display name handed back by resolveRemote.
	RemoteDisplayName string
	// RState is a copy of the remote.RemoteState handed back by resolveRemote.
	RState remote.RemoteState
}
// resolveByPosition picks an element of ids using a positional argument.
//
// posStr must match positionRe (declared elsewhere in the package); otherwise
// "" is returned. The forms handled here are:
//   - "+" / "-": the item after / before curId, wrapping around the ends
//   - "+N" / "-N": N items after / before curId, clamped to the ends
//   - "N": the N-th item (1-based), clamped to the ends
//
// If curId is not present in ids, the first item is treated as current.
// Returns "" when ids is empty or posStr is not a position.
func resolveByPosition(ids []string, curId string, posStr string) string {
	if len(ids) == 0 || !positionRe.MatchString(posStr) {
		return ""
	}
	curIdx := 1 // 1-based; if no match, curIdx will be first item
	for idx, id := range ids {
		if id == curId {
			curIdx = idx + 1
			break
		}
	}
	isRelative := strings.HasPrefix(posStr, "+") || strings.HasPrefix(posStr, "-")
	isWrap := posStr == "+" || posStr == "-"
	var pos int
	switch posStr {
	case "+":
		pos = 1
	case "-":
		pos = -1
	default:
		// The error is ignored on purpose: on a range error Atoi returns the
		// nearest representable int, which the clamping below handles, and on
		// any other failure it returns 0, which clamps to the first item.
		pos, _ = strconv.Atoi(posStr)
	}
	if isRelative {
		// Clamp the offset to +/-len(ids) before adding it. Any larger offset
		// lands on the same end of the list anyway, and clamping first keeps
		// curIdx+pos from overflowing for huge values such as "+9999999999999999999"
		// (which previously wrapped negative and selected the first item).
		if pos > len(ids) {
			pos = len(ids)
		} else if pos < -len(ids) {
			pos = -len(ids)
		}
		pos += curIdx
	}
	if pos < 1 {
		if isWrap {
			pos = len(ids)
		} else {
			pos = 1
		}
	}
	if pos > len(ids) {
		if isWrap {
			pos = 1
		} else {
			pos = len(ids)
		}
	}
	return ids[pos-1]
}
// resolveIds resolves the session, window, screen and remote referenced by
// pk's kwargs ("session", "window", "screen", "remote").
//
// rtype is a bitmask of the R_* constants. A plain flag (e.g. R_Session) makes
// the corresponding value mandatory and yields an error such as "no session"
// when it is missing; an *Opt flag resolves the value when present and leaves
// the field at its zero value otherwise. rtype == 0 resolves nothing.
//
// On error the partially filled resolvedIds is returned alongside the error.
func resolveIds(ctx context.Context, pk *scpacket.FeCommandPacketType, rtype int) (resolvedIds, error) {
	rtn := resolvedIds{}
	if rtype == 0 {
		return rtn, nil
	}
	var err error
	if rtype&(R_Session|R_SessionOpt) != 0 {
		rtn.SessionId, err = resolveSessionId(pk)
		if err != nil {
			return rtn, err
		}
		if rtn.SessionId == "" && rtype&R_Session != 0 {
			return rtn, fmt.Errorf("no session")
		}
	}
	if rtype&(R_Window|R_WindowOpt) != 0 {
		rtn.WindowId, err = resolveWindowId(pk, rtn.SessionId)
		if err != nil {
			return rtn, err
		}
		if rtn.WindowId == "" && rtype&R_Window != 0 {
			return rtn, fmt.Errorf("no window")
		}
	}
	if rtype&(R_Screen|R_ScreenOpt) != 0 {
		rtn.ScreenId, err = resolveScreenId(ctx, pk, rtn.SessionId)
		if err != nil {
			return rtn, err
		}
		if rtn.ScreenId == "" && rtype&R_Screen != 0 {
			return rtn, fmt.Errorf("no screen")
		}
	}
	if rtype&(R_Remote|R_RemoteOpt) != 0 {
		rname, rptr, state, rstate, err := resolveRemote(ctx, pk.Kwargs["remote"], rtn.SessionId, rtn.WindowId)
		if err != nil {
			return rtn, err
		}
		if rptr == nil {
			// resolveRemote returns all-nil when no remote was given. That is
			// only an error if the remote is mandatory; for R_RemoteOpt the
			// remote fields are simply left at their zero values (previously
			// this path dereferenced the nil pointers and panicked).
			if rtype&R_Remote != 0 {
				return rtn, fmt.Errorf("no remote")
			}
			return rtn, nil
		}
		rtn.RemoteDisplayName = rname
		rtn.RemotePtr = *rptr
		rtn.RemoteState = state
		if rstate != nil {
			rtn.RState = *rstate
		}
	}
	return rtn, nil
}
// resolveSessionScreen resolves screenArg to a screen id within the session
// identified by sessionId.
//
// screenArg is interpreted, in order, as:
//  1. a 1-based index into the session's screen list (when it parses as an
//     integer; out-of-range values are an error rather than falling through),
//  2. an exact screen id or screen name.
//
// An error is returned when the screens cannot be loaded or nothing matches.
func resolveSessionScreen(ctx context.Context, sessionId string, screenArg string) (string, error) {
	screens, err := sstore.GetSessionScreens(ctx, sessionId)
	if err != nil {
		// Wrap the store error so the root cause is not lost.
		return "", fmt.Errorf("could not retrieve screens for session=%s: %w", sessionId, err)
	}
	if screenNum, err := strconv.Atoi(screenArg); err == nil {
		if screenNum < 1 || screenNum > len(screens) {
			return "", fmt.Errorf("could not resolve screen #%d (out of range), valid screens 1-%d", screenNum, len(screens))
		}
		return screens[screenNum-1].ScreenId, nil
	}
	for _, screen := range screens {
		if screen.ScreenId == screenArg || screen.Name == screenArg {
			return screen.ScreenId, nil
		}
	}
	return "", fmt.Errorf("could not resolve screen '%s' (name/id not found)", screenArg)
}
// getSessionIds returns the SessionId of every entry in sarr, preserving order.
// The result is always non-nil and has the same length as sarr.
func getSessionIds(sarr []*sstore.SessionType) []string {
	ids := make([]string, 0, len(sarr))
	for _, sess := range sarr {
		ids = append(ids, sess.SessionId)
	}
	return ids
}
// partialUUIDRe matches exactly eight lowercase hexadecimal characters.
var partialUUIDRe = regexp.MustCompile(`^[0-9a-f]{8}$`)

// isPartialUUID reports whether s consists of exactly eight lowercase hex
// characters, i.e. whether it matches partialUUIDRe.
func isPartialUUID(s string) bool {
	if len(s) != 8 {
		// The pattern can only ever match an eight-byte string.
		return false
	}
	return partialUUIDRe.MatchString(s)
}
// resolveSession resolves sessionArg to a session id.
//
// bareSessions is the candidate list; when nil it is loaded via
// sstore.GetBareSessions. curSession (optional) names the current session and
// is itself resolved against the same list so that relative positions in
// sessionArg have an anchor.
//
// sessionArg is tried, in order, as:
//  1. a position (see resolveByPosition),
//  2. an exact session id, an exact session name, or an eight-character
//     lowercase-hex prefix of a session id (first match wins),
//  3. a unique prefix of a session name.
//
// An error is returned when the sessions cannot be loaded, when a name prefix
// is ambiguous, or when nothing matches.
func resolveSession(ctx context.Context, sessionArg string, curSession string, bareSessions []*sstore.SessionType) (string, error) {
	if bareSessions == nil {
		var err error
		bareSessions, err = sstore.GetBareSessions(ctx)
		if err != nil {
			// Wrap the store error so the root cause is not lost.
			return "", fmt.Errorf("could not retrieve bare sessions: %w", err)
		}
	}
	var curSessionId string
	if curSession != "" {
		// Best effort: if the current session cannot be resolved, curSessionId
		// stays "" and resolveByPosition falls back to the first session.
		curSessionId, _ = resolveSession(ctx, curSession, "", bareSessions)
	}
	sids := getSessionIds(bareSessions)
	if rtnId := resolveByPosition(sids, curSessionId, sessionArg); rtnId != "" {
		return rtnId, nil
	}
	tryPuid := isPartialUUID(sessionArg)
	var prefixMatches []string
	var lastPrefixMatchId string
	for _, session := range bareSessions {
		if session.SessionId == sessionArg || session.Name == sessionArg || (tryPuid && strings.HasPrefix(session.SessionId, sessionArg)) {
			return session.SessionId, nil
		}
		// Name-prefix matches are only collected here; they are used after the
		// loop so that an exact match anywhere in the list takes precedence.
		if strings.HasPrefix(session.Name, sessionArg) {
			prefixMatches = append(prefixMatches, session.Name)
			lastPrefixMatchId = session.SessionId
		}
	}
	if len(prefixMatches) == 1 {
		return lastPrefixMatchId, nil
	}
	if len(prefixMatches) > 1 {
		return "", fmt.Errorf("could not resolve session '%s', ambiguous prefix matched multiple sessions: %s", sessionArg, formatStrs(prefixMatches, "and", true))
	}
	return "", fmt.Errorf("could not resolve session '%s' (name/id/pos not found)", sessionArg)
}
// resolveSessionId reads the "session" kwarg from pk. An absent or empty value
// yields ("", nil); any other value must parse as a UUID and is returned
// unchanged, otherwise an error is returned.
func resolveSessionId(pk *scpacket.FeCommandPacketType) (string, error) {
	arg := pk.Kwargs["session"]
	if arg == "" {
		return "", nil
	}
	_, parseErr := uuid.Parse(arg)
	if parseErr != nil {
		return "", fmt.Errorf("invalid sessionid '%s'", arg)
	}
	return arg, nil
}
// resolveWindowId reads the "window" kwarg from pk. An absent or empty value
// yields ("", nil); any other value must parse as a UUID and is returned
// unchanged, otherwise an error is returned. sessionId is currently unused.
func resolveWindowId(pk *scpacket.FeCommandPacketType, sessionId string) (string, error) {
	arg := pk.Kwargs["window"]
	if arg == "" {
		return "", nil
	}
	_, parseErr := uuid.Parse(arg)
	if parseErr != nil {
		return "", fmt.Errorf("invalid windowid '%s'", arg)
	}
	return arg, nil
}
// resolveScreenId reads the "screen" kwarg from pk and turns it into a screen
// id. An absent or empty value yields ("", nil). A value that parses as a UUID
// is returned as-is; anything else is looked up within sessionId via
// resolveSessionScreen, which requires a non-empty sessionId.
func resolveScreenId(ctx context.Context, pk *scpacket.FeCommandPacketType, sessionId string) (string, error) {
	arg := pk.Kwargs["screen"]
	if arg == "" {
		return "", nil
	}
	_, parseErr := uuid.Parse(arg)
	switch {
	case parseErr == nil:
		// Already a full screen id.
		return arg, nil
	case sessionId == "":
		return "", fmt.Errorf("cannot resolve screen without session")
	default:
		return resolveSessionScreen(ctx, sessionId, arg)
	}
}
// parseFullRemoteRef splits a remote reference into its colon-separated parts.
// One pair of enclosing square brackets, if present, is stripped first.
//
// Accepted shapes:
//
//	remote              -> ("", remote, "")
//	@user:remote        -> (@user, remote, "")
//	remote:name         -> ("", remote, name)
//	user:remote:name    -> (user, remote, name)
//
// More than three fields is an error.
//
// returns (remoteuserref, remoteref, name, error)
func parseFullRemoteRef(fullRemoteRef string) (string, string, string, error) {
	bracketed := strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]")
	if bracketed {
		fullRemoteRef = strings.TrimSuffix(strings.TrimPrefix(fullRemoteRef, "["), "]")
	}
	parts := strings.Split(fullRemoteRef, ":")
	switch len(parts) {
	case 1:
		return "", parts[0], "", nil
	case 2:
		// A leading "@" marks the first field as a user ref rather than a remote.
		if strings.HasPrefix(parts[0], "@") {
			return parts[0], parts[1], "", nil
		}
		return "", parts[0], parts[1], nil
	case 3:
		return parts[0], parts[1], parts[2], nil
	default:
		return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef)
	}
}
// resolveRemote resolves a full remote reference (see parseFullRemoteRef) in
// the context of the given session and window.
//
// An empty fullRemoteRef is not an error: every result is zero/nil. References
// carrying a user part are rejected. The display name is the remote's alias
// when set, otherwise its canonical name, with ":<name>" appended when the
// reference included a name. The returned state is the one stored for the
// session/window, or the remote's DefaultState when none is stored.
//
// returns (remoteDisplayName, remoteptr, state, rstate, err)
func resolveRemote(ctx context.Context, fullRemoteRef string, sessionId string, windowId string) (string, *sstore.RemotePtrType, *sstore.RemoteState, *remote.RemoteState, error) {
	if fullRemoteRef == "" {
		return "", nil, nil, nil, nil
	}
	userRef, remoteRef, remoteName, err := parseFullRemoteRef(fullRemoteRef)
	switch {
	case err != nil:
		return "", nil, nil, nil, err
	case userRef != "":
		return "", nil, nil, nil, fmt.Errorf("invalid remote '%s', cannot resolve remote userid '%s'", fullRemoteRef, userRef)
	}

	rstate := remote.ResolveRemoteRef(remoteRef)
	if rstate == nil {
		return "", nil, nil, nil, fmt.Errorf("cannot resolve remote '%s': not found", fullRemoteRef)
	}

	rptr := &sstore.RemotePtrType{RemoteId: rstate.RemoteId, Name: remoteName}
	storedState, err := sstore.GetRemoteState(ctx, sessionId, windowId, *rptr)
	if err != nil {
		return "", nil, nil, nil, fmt.Errorf("cannot resolve remote state '%s': %w", fullRemoteRef, err)
	}

	// Prefer the alias for display, falling back to the canonical name.
	displayName := rstate.RemoteAlias
	if displayName == "" {
		displayName = rstate.RemoteCanonicalName
	}
	if rptr.Name != "" {
		displayName = fmt.Sprintf("%s:%s", displayName, rptr.Name)
	}

	// No stored state for this session/window: use the remote's default.
	var effectiveState *sstore.RemoteState = storedState
	if storedState == nil {
		effectiveState = rstate.DefaultState
	}
	return displayName, rptr, effectiveState, rstate, nil
}