Merging wave server code into mono-repo (commit a4c0128c89)

wavesrv/cmd/main-server.go (new file, 887 lines)
@@ -0,0 +1,887 @@
package main

import (
    "context"
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "io/fs"
    "log"
    "mime/multipart"
    "net/http"
    "os"
    "os/signal"
    "path/filepath"
    "regexp"
    "runtime/debug"
    "strconv"
    "strings"
    "sync"
    "syscall"
    "time"

    "github.com/google/uuid"
    "github.com/gorilla/mux"

    "github.com/commandlinedev/apishell/pkg/packet"
    "github.com/commandlinedev/apishell/pkg/server"
    "github.com/commandlinedev/prompt-server/pkg/cmdrunner"
    "github.com/commandlinedev/prompt-server/pkg/pcloud"
    "github.com/commandlinedev/prompt-server/pkg/remote"
    "github.com/commandlinedev/prompt-server/pkg/rtnstate"
    "github.com/commandlinedev/prompt-server/pkg/scbase"
    "github.com/commandlinedev/prompt-server/pkg/scpacket"
    "github.com/commandlinedev/prompt-server/pkg/scws"
    "github.com/commandlinedev/prompt-server/pkg/sstore"
    "github.com/commandlinedev/prompt-server/pkg/wsshell"
)

type WebFnType = func(http.ResponseWriter, *http.Request)

const HttpReadTimeout = 5 * time.Second
const HttpWriteTimeout = 21 * time.Second
const HttpMaxHeaderBytes = 60000
const HttpTimeoutDuration = 21 * time.Second

const MainServerAddr = "127.0.0.1:1619"      // PromptServer, P=16, S=19, PS=1619
const WebSocketServerAddr = "127.0.0.1:1623" // PromptWebsock, P=16, W=23, PW=1623
const MainServerDevAddr = "127.0.0.1:8090"
const WebSocketServerDevAddr = "127.0.0.1:8091"
const WSStateReconnectTime = 30 * time.Second
const WSStatePacketChSize = 20

const InitialTelemetryWait = 30 * time.Second
const TelemetryTick = 30 * time.Minute
const TelemetryInterval = 8 * time.Hour

const MaxWriteFileMemSize = 20 * (1024 * 1024) // 20M

var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
var GlobalAuthKey string
var BuildTime = "0"
var shutdownOnce sync.Once
var ContentTypeHeaderValidRe = regexp.MustCompile(`^\w+/[\w.+-]+$`)

type ClientActiveState struct {
    Fg     bool `json:"fg"`
    Active bool `json:"active"`
    Open   bool `json:"open"`
}

func setWSState(state *scws.WSState) {
    GlobalLock.Lock()
    defer GlobalLock.Unlock()
    WSStateMap[state.ClientId] = state
}

func getWSState(clientId string) *scws.WSState {
    GlobalLock.Lock()
    defer GlobalLock.Unlock()
    return WSStateMap[clientId]
}

func removeWSStateAfterTimeout(clientId string, connectTime time.Time, waitDuration time.Duration) {
    go func() {
        time.Sleep(waitDuration)
        GlobalLock.Lock()
        defer GlobalLock.Unlock()
        state := WSStateMap[clientId]
        if state == nil || state.ConnectTime != connectTime {
            return
        }
        delete(WSStateMap, clientId)
        state.UnWatchScreen()
    }()
}

func HandleWs(w http.ResponseWriter, r *http.Request) {
    shell, err := wsshell.StartWS(w, r)
    if err != nil {
        log.Printf("WebSocket Upgrade Failed %T: %v\n", w, err)
        return
    }
    defer shell.Conn.Close()
    clientId := r.URL.Query().Get("clientid")
    if clientId == "" {
        close(shell.WriteChan)
        return
    }
    state := getWSState(clientId)
    if state == nil {
        state = scws.MakeWSState(clientId, GlobalAuthKey)
        state.ReplaceShell(shell)
        setWSState(state)
    } else {
        state.UpdateConnectTime()
        state.ReplaceShell(shell)
    }
    stateConnectTime := state.GetConnectTime()
    defer func() {
        removeWSStateAfterTimeout(clientId, stateConnectTime, WSStateReconnectTime)
    }()
    log.Printf("WebSocket opened %s %s\n", state.ClientId, shell.RemoteAddr)
    state.RunWSRead()
}

// todo: sync multiple writes to the same fifoName into a single go-routine and do liveness checking on fifo
// if this returns an error, likely the fifo is dead and the cmd should be marked as 'done'
func writeToFifo(fifoName string, data []byte) error {
    rwfd, err := os.OpenFile(fifoName, os.O_RDWR, 0600)
    if err != nil {
        return err
    }
    defer rwfd.Close()
    fifoWriter, err := os.OpenFile(fifoName, os.O_WRONLY, 0600) // blocking open (open won't block because of rwfd)
    if err != nil {
        return err
    }
    defer fifoWriter.Close()
    // this *could* block if the fifo buffer is full
    // unlikely because if the reader is dead, and len(data) < pipe size, then the buffer will be empty and will clear after rwfd is closed
    _, err = fifoWriter.Write(data)
    if err != nil {
        return err
    }
    return nil
}

func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
    cdata, err := sstore.EnsureClientData(r.Context())
    if err != nil {
        WriteJsonError(w, err)
        return
    }
    cdata = cdata.Clean()
    WriteJsonSuccess(w, cdata)
    return
}

func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
    decoder := json.NewDecoder(r.Body)
    var winSize sstore.ClientWinSizeType
    err := decoder.Decode(&winSize)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
        return
    }
    err = sstore.SetWinSize(r.Context(), winSize)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error setting winsize: %w", err))
        return
    }
    WriteJsonSuccess(w, true)
    return
}

// params: fg, active, open
func HandleLogActiveState(w http.ResponseWriter, r *http.Request) {
    decoder := json.NewDecoder(r.Body)
    var activeState ClientActiveState
    err := decoder.Decode(&activeState)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
        return
    }
    activity := sstore.ActivityUpdate{}
    if activeState.Fg {
        activity.FgMinutes = 1
    }
    if activeState.Active {
        activity.ActiveMinutes = 1
    }
    if activeState.Open {
        activity.OpenMinutes = 1
    }
    activity.NumConns = remote.NumRemotes()
    err = sstore.UpdateCurrentActivity(r.Context(), activity)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error updating activity: %w", err))
        return
    }
    WriteJsonSuccess(w, true)
    return
}

// params: screenid
func HandleGetScreenLines(w http.ResponseWriter, r *http.Request) {
    qvals := r.URL.Query()
    screenId := qvals.Get("screenid")
    if _, err := uuid.Parse(screenId); err != nil {
        WriteJsonError(w, fmt.Errorf("invalid screenid: %w", err))
        return
    }
    screenLines, err := sstore.GetScreenLinesById(r.Context(), screenId)
    if err != nil {
        WriteJsonError(w, err)
        return
    }
    WriteJsonSuccess(w, screenLines)
    return
}

func HandleRtnState(w http.ResponseWriter, r *http.Request) {
    defer func() {
        r := recover()
        if r == nil {
            return
        }
        log.Printf("[error] in handlertnstate: %v\n", r)
        debug.PrintStack()
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("panic: %v", r)))
        return
    }()
    qvals := r.URL.Query()
    screenId := qvals.Get("screenid")
    lineId := qvals.Get("lineid")
    if screenId == "" || lineId == "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("must specify screenid and lineid")))
        return
    }
    if _, err := uuid.Parse(screenId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
        return
    }
    if _, err := uuid.Parse(lineId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
        return
    }
    data, err := rtnstate.GetRtnStateDiff(r.Context(), screenId, lineId)
    if err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("cannot get rtnstate diff: %v", err)))
        return
    }
    w.WriteHeader(http.StatusOK)
    w.Write(data)
    return
}

func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
    qvals := r.URL.Query()
    remoteId := qvals.Get("remoteid")
    if remoteId == "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("must specify remoteid")))
        return
    }
    if _, err := uuid.Parse(remoteId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid remoteid: %v", err)))
        return
    }
    realOffset, data, err := remote.ReadRemotePty(r.Context(), remoteId)
    if err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("error reading ptyout file: %v", err)))
        return
    }
    w.Header().Set("X-PtyDataOffset", strconv.FormatInt(realOffset, 10))
    w.WriteHeader(http.StatusOK)
    w.Write(data)
    return
}

func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
    qvals := r.URL.Query()
    screenId := qvals.Get("screenid")
    lineId := qvals.Get("lineid")
    if screenId == "" || lineId == "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("must specify screenid and lineid")))
        return
    }
    if _, err := uuid.Parse(screenId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
        return
    }
    if _, err := uuid.Parse(lineId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
        return
    }
    realOffset, data, err := sstore.ReadFullPtyOutFile(r.Context(), screenId, lineId)
    if err != nil {
        if errors.Is(err, fs.ErrNotExist) {
            w.WriteHeader(http.StatusOK)
            return
        }
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("error reading ptyout file: %v", err)))
        return
    }
    w.Header().Set("X-PtyDataOffset", strconv.FormatInt(realOffset, 10))
    w.WriteHeader(http.StatusOK)
    w.Write(data)
}

type writeFileParamsType struct {
    ScreenId string `json:"screenid"`
    LineId   string `json:"lineid"`
    Path     string `json:"path"`
    UseTemp  bool   `json:"usetemp,omitempty"`
}

func parseWriteFileParams(r *http.Request) (*writeFileParamsType, multipart.File, error) {
    err := r.ParseMultipartForm(MaxWriteFileMemSize)
    if err != nil {
        return nil, nil, fmt.Errorf("cannot parse multipart form data: %v", err)
    }
    form := r.MultipartForm
    if len(form.Value["params"]) == 0 {
        return nil, nil, fmt.Errorf("no params found")
    }
    paramsStr := form.Value["params"][0]
    var params writeFileParamsType
    err = json.Unmarshal([]byte(paramsStr), &params)
    if err != nil {
        return nil, nil, fmt.Errorf("bad params json: %v", err)
    }
    if len(form.File["data"]) == 0 {
        return nil, nil, fmt.Errorf("no data found")
    }
    fileHeader := form.File["data"][0]
    file, err := fileHeader.Open()
    if err != nil {
        return nil, nil, fmt.Errorf("error opening multipart data file: %v", err)
    }
    return &params, file, nil
}

func HandleWriteFile(w http.ResponseWriter, r *http.Request) {
    defer func() {
        r := recover()
        if r == nil {
            return
        }
        log.Printf("[error] in write-file: %v\n", r)
        debug.PrintStack()
        WriteJsonError(w, fmt.Errorf("panic: %v", r))
        return
    }()
    w.Header().Set("Cache-Control", "no-cache")
    params, mpFile, err := parseWriteFileParams(r)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error parsing multipart form params: %w", err))
        return
    }
    if params.ScreenId == "" || params.LineId == "" || params.Path == "" {
        WriteJsonError(w, fmt.Errorf("invalid params, must set screenid, lineid, and path"))
        return
    }
    if _, err := uuid.Parse(params.ScreenId); err != nil {
        WriteJsonError(w, fmt.Errorf("invalid screenid: %v", err))
        return
    }
    if _, err := uuid.Parse(params.LineId); err != nil {
        WriteJsonError(w, fmt.Errorf("invalid lineid: %v", err))
        return
    }
    _, cmd, err := sstore.GetLineCmdByLineId(r.Context(), params.ScreenId, params.LineId)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("cannot retrieve line/cmd: %v", err))
        return
    }
    if cmd == nil {
        WriteJsonError(w, fmt.Errorf("line not found"))
        return
    }
    if cmd.Remote.RemoteId == "" {
        WriteJsonError(w, fmt.Errorf("invalid line, no remote"))
        return
    }
    msh := remote.GetRemoteById(cmd.Remote.RemoteId)
    if msh == nil {
        WriteJsonError(w, fmt.Errorf("invalid line, cannot resolve remote"))
        return
    }
    rrState := msh.GetRemoteRuntimeState()
    fullPath, err := rrState.ExpandHomeDir(params.Path)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error expanding homedir: %v", err))
        return
    }
    cwd := cmd.FeState["cwd"]
    writePk := packet.MakeWriteFilePacket()
    writePk.ReqId = uuid.New().String()
    writePk.UseTemp = params.UseTemp
    if filepath.IsAbs(fullPath) {
        writePk.Path = fullPath
    } else {
        writePk.Path = filepath.Join(cwd, fullPath)
    }
    iter, err := msh.PacketRpcIter(r.Context(), writePk)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error: %v", err))
        return
    }
    // first packet should be WriteFileReady
    readyIf, err := iter.Next(r.Context())
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error while getting ready response: %w", err))
        return
    }
    readyPk, ok := readyIf.(*packet.WriteFileReadyPacketType)
    if !ok {
        WriteJsonError(w, fmt.Errorf("bad ready packet received: %T", readyIf))
        return
    }
    if readyPk.Error != "" {
        WriteJsonError(w, fmt.Errorf("ready error: %s", readyPk.Error))
        return
    }
    var buffer [server.MaxFileDataPacketSize]byte
    bufSlice := buffer[:]
    for {
        dataPk := packet.MakeFileDataPacket(writePk.ReqId)
        nr, err := io.ReadFull(mpFile, bufSlice)
        if err == io.ErrUnexpectedEOF || err == io.EOF {
            dataPk.Eof = true
        } else if err != nil {
            dataErr := fmt.Errorf("error reading file data: %v", err)
            dataPk.Error = dataErr.Error()
            msh.SendFileData(dataPk)
            WriteJsonError(w, dataErr)
            return
        }
        if nr > 0 {
            dataPk.Data = make([]byte, nr)
            copy(dataPk.Data, bufSlice[0:nr])
        }
        msh.SendFileData(dataPk)
        if dataPk.Eof {
            break
        }
        // slight throttle for sending packets
        time.Sleep(10 * time.Millisecond)
    }
    doneIf, err := iter.Next(r.Context())
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error while getting done response: %w", err))
        return
    }
    donePk, ok := doneIf.(*packet.WriteFileDonePacketType)
    if !ok {
        WriteJsonError(w, fmt.Errorf("bad done packet received: %T", doneIf))
        return
    }
    if donePk.Error != "" {
        WriteJsonError(w, fmt.Errorf("done error: %s", donePk.Error))
        return
    }
    WriteJsonSuccess(w, nil)
    return
}

func HandleReadFile(w http.ResponseWriter, r *http.Request) {
    qvals := r.URL.Query()
    screenId := qvals.Get("screenid")
    lineId := qvals.Get("lineid")
    path := qvals.Get("path") // validate path?
    contentType := qvals.Get("mimetype")
    if contentType == "" {
        contentType = "application/octet-stream"
    }
    if screenId == "" || lineId == "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("must specify sessionid, screenid, and lineid")))
        return
    }
    if path == "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("must specify path")))
        return
    }
    if _, err := uuid.Parse(screenId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
        return
    }
    if _, err := uuid.Parse(lineId); err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
        return
    }
    if !ContentTypeHeaderValidRe.MatchString(contentType) {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid mimetype specified")))
        return
    }
    _, cmd, err := sstore.GetLineCmdByLineId(r.Context(), screenId, lineId)
    if err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
        return
    }
    if cmd == nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid line, no cmd")))
        return
    }
    if cmd.Remote.RemoteId == "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid line, no remote")))
        return
    }
    msh := remote.GetRemoteById(cmd.Remote.RemoteId)
    if msh == nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("invalid line, cannot resolve remote")))
        return
    }
    rrState := msh.GetRemoteRuntimeState()
    fullPath, err := rrState.ExpandHomeDir(path)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error expanding homedir: %v", err))
        return
    }
    streamPk := packet.MakeStreamFilePacket()
    streamPk.ReqId = uuid.New().String()
    cwd := cmd.FeState["cwd"]
    if filepath.IsAbs(fullPath) {
        streamPk.Path = fullPath
    } else {
        streamPk.Path = filepath.Join(cwd, fullPath)
    }
    iter, err := msh.StreamFile(r.Context(), streamPk)
    if err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("error trying to stream file: %v", err)))
        return
    }
    defer iter.Close()
    respIf, err := iter.Next(r.Context())
    if err != nil {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("error getting streamfile response: %v", err)))
        return
    }
    resp, ok := respIf.(*packet.StreamFileResponseType)
    if !ok {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("bad response packet type: %T", respIf)))
        return
    }
    if resp.Error != "" {
        w.WriteHeader(500)
        w.Write([]byte(fmt.Sprintf("error response: %s", resp.Error)))
        return
    }
    infoJson, _ := json.Marshal(resp.Info)
    w.Header().Set("X-FileInfo", base64.StdEncoding.EncodeToString(infoJson))
    w.Header().Set("Content-Type", contentType)
    w.WriteHeader(http.StatusOK)
    for {
        dataPkIf, err := iter.Next(r.Context())
        if err != nil {
            log.Printf("error in read-file while getting data: %v\n", err)
            break
        }
        if dataPkIf == nil {
            break
        }
        dataPk, ok := dataPkIf.(*packet.FileDataPacketType)
        if !ok {
            log.Printf("error in read-file, invalid data packet type: %T", dataPkIf)
            break
        }
        if dataPk.Error != "" {
            log.Printf("in read-file, data packet error: %s", dataPk.Error)
            break
        }
        w.Write(dataPk.Data)
    }
    return
}

func WriteJsonError(w http.ResponseWriter, errVal error) {
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(200)
    errMap := make(map[string]interface{})
    errMap["error"] = errVal.Error()
    barr, _ := json.Marshal(errMap)
    w.Write(barr)
    return
}

func WriteJsonSuccess(w http.ResponseWriter, data interface{}) {
    w.Header().Set("Content-Type", "application/json")
    rtnMap := make(map[string]interface{})
    rtnMap["success"] = true
    if data != nil {
        rtnMap["data"] = data
    }
    barr, err := json.Marshal(rtnMap)
    if err != nil {
        WriteJsonError(w, err)
        return
    }
    w.WriteHeader(200)
    w.Write(barr)
    return
}

func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
    defer func() {
        r := recover()
        if r == nil {
            return
        }
        log.Printf("[error] in run-command: %v\n", r)
        debug.PrintStack()
        WriteJsonError(w, fmt.Errorf("panic: %v", r))
        return
    }()
    w.Header().Set("Cache-Control", "no-cache")
    decoder := json.NewDecoder(r.Body)
    var commandPk scpacket.FeCommandPacketType
    err := decoder.Decode(&commandPk)
    if err != nil {
        WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
        return
    }
    update, err := cmdrunner.HandleCommand(r.Context(), &commandPk)
    if err != nil {
        WriteJsonError(w, err)
        return
    }
    if update != nil {
        update.Clean()
    }
    WriteJsonSuccess(w, update)
    return
}

func AuthKeyWrap(fn WebFnType) WebFnType {
    return func(w http.ResponseWriter, r *http.Request) {
        reqAuthKey := r.Header.Get("X-AuthKey")
        if reqAuthKey == "" {
            w.WriteHeader(500)
            w.Write([]byte("no x-authkey header"))
            return
        }
        if reqAuthKey != GlobalAuthKey {
            w.WriteHeader(500)
            w.Write([]byte("x-authkey header is invalid"))
            return
        }
        w.Header().Set("Cache-Control", "no-cache")
        fn(w, r)
    }
}

func runWebSocketServer() {
    gr := mux.NewRouter()
    gr.HandleFunc("/ws", HandleWs)
    serverAddr := WebSocketServerAddr
    if scbase.IsDevMode() {
        serverAddr = WebSocketServerDevAddr
    }
    server := &http.Server{
        Addr:           serverAddr,
        ReadTimeout:    HttpReadTimeout,
        WriteTimeout:   HttpWriteTimeout,
        MaxHeaderBytes: HttpMaxHeaderBytes,
        Handler:        gr,
    }
    server.SetKeepAlivesEnabled(false)
    log.Printf("Running websocket server on %s\n", serverAddr)
    err := server.ListenAndServe()
    if err != nil {
        log.Printf("[error] trying to run websocket server: %v\n", err)
    }
}

func test() error {
    return nil
}

func sendTelemetryWrapper() {
    defer func() {
        r := recover()
        if r == nil {
            return
        }
        log.Printf("[error] in sendTelemetryWrapper: %v\n", r)
        debug.PrintStack()
        return
    }()
    ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancelFn()
    err := pcloud.SendTelemetry(ctx, false)
    if err != nil {
        log.Printf("[error] sending telemetry: %v\n", err)
    }
}

func telemetryLoop() {
    var lastSent time.Time
    time.Sleep(InitialTelemetryWait)
    for {
        dur := time.Now().Sub(lastSent)
        if lastSent.IsZero() || dur >= TelemetryInterval {
            lastSent = time.Now()
            sendTelemetryWrapper()
        }
        time.Sleep(TelemetryTick)
    }
}

// watch stdin, kill server if stdin is closed
func stdinReadWatch() {
    buf := make([]byte, 1024)
    for {
        _, err := os.Stdin.Read(buf)
        if err != nil {
            doShutdown(fmt.Sprintf("stdin closed/error (%v)", err))
            break
        }
    }
}

// ignore SIGHUP
func installSignalHandlers() {
    sigCh := make(chan os.Signal, 1)
    signal.Notify(sigCh, syscall.SIGHUP)
    go func() {
        for sig := range sigCh {
            doShutdown(fmt.Sprintf("got signal %v", sig))
            break
        }
    }()
}

func doShutdown(reason string) {
    shutdownOnce.Do(func() {
        log.Printf("[prompt] local server %v, start shutdown\n", reason)
        sendTelemetryWrapper()
        log.Printf("[prompt] closing db connection\n")
        sstore.CloseDB()
        log.Printf("[prompt] *** shutting down local server\n")
        time.Sleep(1 * time.Second)
        syscall.Kill(syscall.Getpid(), syscall.SIGINT)
        time.Sleep(5 * time.Second)
        syscall.Kill(syscall.Getpid(), syscall.SIGKILL)
    })
}

func main() {
    scbase.BuildTime = BuildTime

    if len(os.Args) >= 2 && os.Args[1] == "--test" {
        log.Printf("running test fn\n")
        err := test()
        if err != nil {
            log.Printf("[error] %v\n", err)
        }
        return
    }

    scHomeDir := scbase.GetPromptHomeDir()
    log.Printf("[prompt] *** starting local server\n")
    log.Printf("[prompt] local server version %s+%s\n", scbase.PromptVersion, scbase.BuildTime)
    log.Printf("[prompt] homedir = %q\n", scHomeDir)

    scLock, err := scbase.AcquirePromptLock()
    if err != nil || scLock == nil {
        log.Printf("[error] cannot acquire prompt lock: %v\n", err)
        return
    }
    if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
        err := sstore.MigrateCommandOpts(os.Args[1:])
        if err != nil {
            log.Printf("[error] migrate cmd: %v\n", err)
        }
        return
    }
    authKey, err := scbase.ReadPromptAuthKey()
    if err != nil {
        log.Printf("[error] %v\n", err)
        return
    }
    GlobalAuthKey = authKey
    err = sstore.TryMigrateUp()
    if err != nil {
        log.Printf("[error] migrate up: %v\n", err)
        return
    }
    clientData, err := sstore.EnsureClientData(context.Background())
    if err != nil {
        log.Printf("[error] ensuring client data: %v\n", err)
        return
    }
    log.Printf("userid = %s\n", clientData.UserId)
    err = sstore.EnsureLocalRemote(context.Background())
    if err != nil {
        log.Printf("[error] ensuring local remote: %v\n", err)
        return
    }
    _, err = sstore.EnsureDefaultSession(context.Background())
    if err != nil {
        log.Printf("[error] ensuring default session: %v\n", err)
        return
    }
    err = remote.LoadRemotes(context.Background())
    if err != nil {
        log.Printf("[error] loading remotes: %v\n", err)
        return
    }

    err = sstore.HangupAllRunningCmds(context.Background())
    if err != nil {
        log.Printf("[error] calling HUP on all running commands: %v\n", err)
    }
    err = sstore.ReInitFocus(context.Background())
    if err != nil {
        log.Printf("[error] resetting screen focus: %v\n", err)
    }

    log.Printf("PCLOUD_ENDPOINT=%s\n", pcloud.GetEndpoint())
    err = sstore.UpdateCurrentActivity(context.Background(), sstore.ActivityUpdate{NumConns: remote.NumRemotes()}) // set at least one record into activity
    if err != nil {
        log.Printf("[error] updating activity: %v\n", err)
    }
    installSignalHandlers()
    go telemetryLoop()
    go stdinReadWatch()
    go runWebSocketServer()
    go func() {
        time.Sleep(10 * time.Second)
        pcloud.StartUpdateWriter()
    }()
    gr := mux.NewRouter()
    gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut))
    gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty))
    gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState))
    gr.HandleFunc("/api/get-screen-lines", AuthKeyWrap(HandleGetScreenLines))
    gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST")
    gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
    gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
    gr.HandleFunc("/api/log-active-state", AuthKeyWrap(HandleLogActiveState))
    gr.HandleFunc("/api/read-file", AuthKeyWrap(HandleReadFile))
    gr.HandleFunc("/api/write-file", AuthKeyWrap(HandleWriteFile)).Methods("POST")
    serverAddr := MainServerAddr
    if scbase.IsDevMode() {
        serverAddr = MainServerDevAddr
    }
    server := &http.Server{
        Addr:           serverAddr,
        ReadTimeout:    HttpReadTimeout,
        WriteTimeout:   HttpWriteTimeout,
        MaxHeaderBytes: HttpMaxHeaderBytes,
        Handler:        http.TimeoutHandler(gr, HttpTimeoutDuration, "Timeout"),
    }
    server.SetKeepAlivesEnabled(false)
    log.Printf("Running main server on %s\n", serverAddr)
    err = server.ListenAndServe()
    if err != nil {
        log.Printf("ERROR: %v\n", err)
    }
}
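For reference: every /api/* route above is wrapped in AuthKeyWrap, so callers must send an X-AuthKey header, and responses come back as {"success": true, "data": ...} from WriteJsonSuccess or {"error": "..."} from WriteJsonError (both with HTTP 200). A minimal client-side sketch, not part of this commit, assuming the server is running in dev mode on 127.0.0.1:8090 and that authKey was obtained out of band (the server itself reads it via scbase.ReadPromptAuthKey):

package main

import (
    "encoding/json"
    "fmt"
    "log"
    "net/http"
)

func main() {
    authKey := "..." // assumption: the prompt auth key, obtained out of band
    req, err := http.NewRequest("GET", "http://127.0.0.1:8090/api/get-client-data", nil)
    if err != nil {
        log.Fatal(err)
    }
    // AuthKeyWrap rejects requests that are missing this header or carry the wrong key
    req.Header.Set("X-AuthKey", authKey)
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()
    // WriteJsonSuccess/WriteJsonError both return JSON bodies with HTTP 200
    var body map[string]interface{}
    if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
        log.Fatal(err)
    }
    fmt.Printf("success=%v error=%v\n", body["success"], body["error"])
}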
wavesrv/db/db.go (new file, 9 lines)
@@ -0,0 +1,9 @@
// provides the io/fs for DB migrations
package db

import "embed"

// since embeds must be relative to the package directory, this source file is required

//go:embed migrations/*.sql
var MigrationFS embed.FS
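db.go only exposes the embedded migration files; the migration run itself happens in sstore.TryMigrateUp (called from main-server.go above), which is not part of this diff. The schema_migrations(version, dirty) table in wavesrv/db/schema.sql suggests golang-migrate style tooling; the following is a minimal sketch, under that assumption, of how an embedded FS like MigrationFS is typically wired up (the db import path and database location are hypothetical):

package main

import (
    "log"

    "github.com/golang-migrate/migrate/v4"
    _ "github.com/golang-migrate/migrate/v4/database/sqlite3"
    "github.com/golang-migrate/migrate/v4/source/iofs"

    "example.com/wavesrv/db" // assumed import path for the db package above
)

func main() {
    // wrap the embedded migrations directory as a migrate source driver
    src, err := iofs.New(db.MigrationFS, "migrations")
    if err != nil {
        log.Fatal(err)
    }
    // hypothetical database path; the real server derives its DB location elsewhere
    m, err := migrate.NewWithSourceInstance("iofs", src, "sqlite3:///tmp/prompt.db")
    if err != nil {
        log.Fatal(err)
    }
    if err := m.Up(); err != nil && err != migrate.ErrNoChange {
        log.Fatal(err)
    }
}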
wavesrv/db/migrations/000001_init.down.sql (new file, 13 lines)
@@ -0,0 +1,13 @@
DROP TABLE client;
DROP TABLE session;
DROP TABLE window;
DROP TABLE screen;
DROP TABLE screen_window;
DROP TABLE remote_instance;
DROP TABLE line;
DROP TABLE remote;
DROP TABLE cmd;
DROP TABLE history;
DROP TABLE state_base;
DROP TABLE state_diff;
wavesrv/db/migrations/000001_init.up.sql (new file, 167 lines)
@@ -0,0 +1,167 @@
CREATE TABLE client (
    clientid varchar(36) NOT NULL,
    userid varchar(36) NOT NULL,
    activesessionid varchar(36) NOT NULL,
    userpublickeybytes blob NOT NULL,
    userprivatekeybytes blob NOT NULL,
    winsize json NOT NULL
);

CREATE TABLE session (
    sessionid varchar(36) PRIMARY KEY,
    name varchar(50) NOT NULL,
    sessionidx int NOT NULL,
    activescreenid varchar(36) NOT NULL,
    notifynum int NOT NULL,
    archived boolean NOT NULL,
    archivedts bigint NOT NULL,
    ownerid varchar(36) NOT NULL,
    sharemode varchar(12) NOT NULL,
    accesskey varchar(36) NOT NULL
);

CREATE TABLE window (
    sessionid varchar(36) NOT NULL,
    windowid varchar(36) NOT NULL,
    curremoteownerid varchar(36) NOT NULL,
    curremoteid varchar(36) NOT NULL,
    curremotename varchar(50) NOT NULL,
    nextlinenum int NOT NULL,
    winopts json NOT NULL,
    ownerid varchar(36) NOT NULL,
    sharemode varchar(12) NOT NULL,
    shareopts json NOT NULL,
    PRIMARY KEY (sessionid, windowid)
);

CREATE TABLE screen (
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    name varchar(50) NOT NULL,
    activewindowid varchar(36) NOT NULL,
    screenidx int NOT NULL,
    screenopts json NOT NULL,
    ownerid varchar(36) NOT NULL,
    sharemode varchar(12) NOT NULL,
    incognito boolean NOT NULL,
    archived boolean NOT NULL,
    archivedts bigint NOT NULL,
    PRIMARY KEY (sessionid, screenid)
);

CREATE TABLE screen_window (
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    windowid varchar(36) NOT NULL,
    name varchar(50) NOT NULL,
    layout json NOT NULL,
    selectedline int NOT NULL,
    anchor json NOT NULL,
    focustype varchar(12) NOT NULL,
    PRIMARY KEY (sessionid, screenid, windowid)
);

CREATE TABLE remote_instance (
    riid varchar(36) PRIMARY KEY,
    name varchar(50) NOT NULL,
    sessionid varchar(36) NOT NULL,
    windowid varchar(36) NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    festate json NOT NULL,
    statebasehash varchar(36) NOT NULL,
    statediffhasharr json NOT NULL
);

CREATE TABLE state_base (
    basehash varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    version varchar(200) NOT NULL,
    data blob NOT NULL
);

CREATE TABLE state_diff (
    diffhash varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    basehash varchar(36) NOT NULL,
    diffhasharr json NOT NULL,
    data blob NOT NULL
);

CREATE TABLE line (
    sessionid varchar(36) NOT NULL,
    windowid varchar(36) NOT NULL,
    userid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    ts bigint NOT NULL,
    linenum int NOT NULL,
    linenumtemp boolean NOT NULL,
    linetype varchar(10) NOT NULL,
    linelocal boolean NOT NULL,
    text text NOT NULL,
    cmdid varchar(36) NOT NULL,
    ephemeral boolean NOT NULL,
    contentheight int NOT NULL,
    star int NOT NULL,
    archived boolean NOT NULL,
    PRIMARY KEY (sessionid, windowid, lineid)
);

CREATE TABLE remote (
    remoteid varchar(36) PRIMARY KEY,
    physicalid varchar(36) NOT NULL,
    remotetype varchar(10) NOT NULL,
    remotealias varchar(50) NOT NULL,
    remotecanonicalname varchar(200) NOT NULL,
    remotesudo boolean NOT NULL,
    remoteuser varchar(50) NOT NULL,
    remotehost varchar(200) NOT NULL,
    connectmode varchar(20) NOT NULL,
    autoinstall boolean NOT NULL,
    sshopts json NOT NULL,
    remoteopts json NOT NULL,
    lastconnectts bigint NOT NULL,
    local boolean NOT NULL,
    archived boolean NOT NULL,
    remoteidx int NOT NULL
);

CREATE TABLE cmd (
    sessionid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    remotename varchar(50) NOT NULL,
    cmdstr text NOT NULL,
    festate json NOT NULL,
    statebasehash varchar(36) NOT NULL,
    statediffhasharr json NOT NULL,
    termopts json NOT NULL,
    origtermopts json NOT NULL,
    status varchar(10) NOT NULL,
    startpk json NOT NULL,
    doneinfo json NOT NULL,
    runout json NOT NULL,
    rtnstate boolean NOT NULL,
    rtnbasehash varchar(36) NOT NULL,
    rtndiffhasharr json NOT NULL,
    PRIMARY KEY (sessionid, cmdid)
);

CREATE TABLE history (
    historyid varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    userid varchar(36) NOT NULL,
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    windowid varchar(36) NOT NULL,
    lineid int NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    remotename varchar(50) NOT NULL,
    haderror boolean NOT NULL,
    cmdid varchar(36) NOT NULL,
    cmdstr text NOT NULL,
    ismetacmd boolean,
    incognito boolean
);
wavesrv/db/migrations/000002_activity.down.sql (new file, 3 lines)
@@ -0,0 +1,3 @@
DROP TABLE activity;

ALTER TABLE client DROP COLUMN clientopts;

wavesrv/db/migrations/000002_activity.up.sql (new file, 11 lines)
@@ -0,0 +1,11 @@
CREATE TABLE activity (
    day varchar(20) PRIMARY KEY,
    uploaded boolean NOT NULL,
    tdata json NOT NULL,
    tzname varchar(50) NOT NULL,
    tzoffset int NOT NULL,
    clientversion varchar(20) NOT NULL,
    clientarch varchar(20) NOT NULL
);

ALTER TABLE client ADD COLUMN clientopts json NOT NULL DEFAULT '';

wavesrv/db/migrations/000003_renderer.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE line DROP COLUMN renderer;

wavesrv/db/migrations/000003_renderer.up.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE line ADD COLUMN renderer varchar(50) NOT NULL DEFAULT '';

wavesrv/db/migrations/000004_bookmarks.down.sql (new file, 6 lines)
@@ -0,0 +1,6 @@
DROP TABLE bookmark;
DROP TABLE bookmark_order;
DROP TABLE bookmark_cmd;

ALTER TABLE line DROP COLUMN bookmarked;
ALTER TABLE line DROP COLUMN pinned;

wavesrv/db/migrations/000004_bookmarks.up.sql (new file, 26 lines)
@@ -0,0 +1,26 @@
CREATE TABLE bookmark (
    bookmarkid varchar(36) PRIMARY KEY,
    createdts bigint NOT NULL,
    cmdstr text NOT NULL,
    alias varchar(50) NOT NULL,
    tags json NOT NULL,
    description text NOT NULL
);

CREATE TABLE bookmark_order (
    tag varchar(50) NOT NULL,
    bookmarkid varchar(36) NOT NULL,
    orderidx int NOT NULL,
    PRIMARY KEY (tag, bookmarkid)
);

CREATE TABLE bookmark_cmd (
    bookmarkid varchar(36) NOT NULL,
    sessionid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    PRIMARY KEY (bookmarkid, sessionid, cmdid)
);

ALTER TABLE line ADD COLUMN bookmarked boolean NOT NULL DEFAULT 0;
ALTER TABLE line ADD COLUMN pinned boolean NOT NULL DEFAULT 0;
wavesrv/db/migrations/000005_buildtime.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE activity DROP COLUMN buildtime;
ALTER TABLE activity DROP COLUMN osrelease;

wavesrv/db/migrations/000005_buildtime.up.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE activity ADD COLUMN buildtime varchar(20) NOT NULL DEFAULT '-';
ALTER TABLE activity ADD COLUMN osrelease varchar(20) NOT NULL DEFAULT '-';

wavesrv/db/migrations/000006_feopts.down.sql (new file, 1 line)
@@ -0,0 +1 @@
ALTER TABLE client DROP COLUMN feopts;

wavesrv/db/migrations/000006_feopts.up.sql (new file, 3 lines)
@@ -0,0 +1,3 @@
ALTER TABLE client ADD COLUMN feopts json NOT NULL DEFAULT '{}';

wavesrv/db/migrations/000007_playbooks.down.sql (new file, 3 lines)
@@ -0,0 +1,3 @@
DROP TABLE playbook;

DROP TABLE playbook_entry;

wavesrv/db/migrations/000007_playbooks.up.sql (new file, 16 lines)
@@ -0,0 +1,16 @@
CREATE TABLE playbook (
    playbookid varchar(36) PRIMARY KEY,
    playbookname varchar(100) NOT NULL,
    description text NOT NULL,
    entryids json NOT NULL
);

CREATE TABLE playbook_entry (
    entryid varchar(36) PRIMARY KEY,
    playbookid varchar(36) NOT NULL,
    description text NOT NULL,
    alias varchar(50) NOT NULL,
    cmdstr text NOT NULL,
    createdts bigint NOT NULL,
    updatedts bigint NOT NULL
);

wavesrv/db/migrations/000008_cloudsession.down.sql (new file, 5 lines)
@@ -0,0 +1,5 @@
ALTER TABLE session ADD COLUMN accesskey DEFAULT '';
ALTER TABLE session ADD COLUMN ownerid DEFAULT '';

DROP TABLE cloud_session;
DROP TABLE cloud_update;

wavesrv/db/migrations/000008_cloudsession.up.sql (new file, 20 lines)
@@ -0,0 +1,20 @@
ALTER TABLE session DROP COLUMN accesskey;
ALTER TABLE session DROP COLUMN ownerid;

CREATE TABLE cloud_session (
    sessionid varchar(36) PRIMARY KEY,
    viewkey varchar(50) NOT NULL,
    writekey varchar(50) NOT NULL,
    enckey varchar(100) NOT NULL,
    enctype varchar(50) NOT NULL,
    vts bigint NOT NULL,
    acl json NOT NULL
);

CREATE TABLE cloud_update (
    updateid varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    updatetype varchar(50) NOT NULL,
    updatekeys json NOT NULL
);

wavesrv/db/migrations/000009_screenprimary.down.sql (new file, 3 lines)
@@ -0,0 +1,3 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;
wavesrv/db/migrations/000009_screenprimary.up.sql (new file, 56 lines)
@@ -0,0 +1,56 @@
CREATE TABLE new_screen (
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    windowid varchar(36) NOT NULL,
    name varchar(50) NOT NULL,
    screenidx int NOT NULL,
    screenopts json NOT NULL,
    ownerid varchar(36) NOT NULL,
    sharemode varchar(12) NOT NULL,
    curremoteownerid varchar(36) NOT NULL,
    curremoteid varchar(36) NOT NULL,
    curremotename varchar(50) NOT NULL,
    nextlinenum int NOT NULL,
    selectedline int NOT NULL,
    anchor json NOT NULL,
    focustype varchar(12) NOT NULL,
    archived boolean NOT NULL,
    archivedts bigint NOT NULL,
    PRIMARY KEY (sessionid, screenid)
);

INSERT INTO new_screen
SELECT
    s.sessionid,
    s.screenid,
    w.windowid,
    s.name,
    s.screenidx,
    json_patch(s.screenopts, w.winopts),
    s.ownerid,
    s.sharemode,
    w.curremoteownerid,
    w.curremoteid,
    w.curremotename,
    w.nextlinenum,
    sw.selectedline,
    sw.anchor,
    sw.focustype,
    s.archived,
    s.archivedts
FROM
    screen s,
    screen_window sw,
    window w
WHERE
    s.screenid = sw.screenid
    AND sw.windowid = w.windowid
;

DROP TABLE screen;
DROP TABLE screen_window;
DROP TABLE window;

ALTER TABLE new_screen RENAME TO screen;

wavesrv/db/migrations/000010_removewindowid.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;

wavesrv/db/migrations/000010_removewindowid.up.sql (new file, 17 lines)
@@ -0,0 +1,17 @@
ALTER TABLE remote_instance RENAME COLUMN windowid TO screenid;
ALTER TABLE line RENAME COLUMN windowid TO screenid;

UPDATE remote_instance
SET screenid = COALESCE((SELECT screen.screenid FROM screen WHERE screen.windowid = remote_instance.screenid), '')
WHERE screenid <> ''
;

UPDATE line
SET screenid = COALESCE((SELECT screen.screenid FROM screen WHERE screen.windowid = line.screenid), '')
WHERE screenid <> ''
;

ALTER TABLE history DROP COLUMN windowid;
ALTER TABLE screen DROP COLUMN windowid;

wavesrv/db/migrations/000011_cmdscreenid.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE cmd DROP COLUMN screenid;

wavesrv/db/migrations/000011_cmdscreenid.up.sql (new file, 5 lines)
@@ -0,0 +1,5 @@
ALTER TABLE cmd ADD COLUMN screenid varchar(36) NOT NULL DEFAULT '';

UPDATE cmd
SET screenid = (SELECT line.screenid FROM line WHERE line.cmdid = cmd.cmdid)
;
wavesrv/db/migrations/000012_historylinenum.down.sql (new file, 1 line)
@@ -0,0 +1 @@
ALTER TABLE history DROP COLUMN linenum;

wavesrv/db/migrations/000012_historylinenum.up.sql (new file, 6 lines)
@@ -0,0 +1,6 @@
ALTER TABLE history ADD COLUMN linenum int NOT NULL DEFAULT 0;

UPDATE history
SET linenum = COALESCE((SELECT line.linenum FROM line WHERE line.lineid = history.lineid), 0)
;

wavesrv/db/migrations/000013_cmdmigration.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;
wavesrv/db/migrations/000013_cmdmigration.up.sql (new file, 123 lines)
@@ -0,0 +1,123 @@
DELETE FROM cmd
WHERE screenid = '';

DELETE FROM line
WHERE screenid = '';

DELETE FROM cmd
WHERE cmdid NOT IN (SELECT cmdid FROM line);

DELETE FROM line
WHERE cmdid <> '' AND cmdid NOT IN (SELECT cmdid FROM cmd);

CREATE TABLE new_bookmark_cmd (
    bookmarkid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    PRIMARY KEY (bookmarkid, screenid, cmdid)
);
INSERT INTO new_bookmark_cmd
SELECT
    b.bookmarkid,
    c.screenid,
    c.cmdid
FROM bookmark_cmd b, cmd c
WHERE b.cmdid = c.cmdid;
DROP TABLE bookmark_cmd;
ALTER TABLE new_bookmark_cmd RENAME TO bookmark_cmd;

ALTER TABLE client ADD COLUMN cmdstoretype varchar(20) DEFAULT 'session';

CREATE TABLE cmd_migrate (
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL
);
INSERT INTO cmd_migrate
SELECT sessionid, screenid, cmdid
FROM cmd;

-- update primary key for screen
CREATE TABLE new_screen (
    screenid varchar(36) NOT NULL,
    sessionid varchar(36) NOT NULL,
    name varchar(50) NOT NULL,
    screenidx int NOT NULL,
    screenopts json NOT NULL,
    ownerid varchar(36) NOT NULL,
    sharemode varchar(12) NOT NULL,
    curremoteownerid varchar(36) NOT NULL,
    curremoteid varchar(36) NOT NULL,
    curremotename varchar(50) NOT NULL,
    nextlinenum int NOT NULL,
    selectedline int NOT NULL,
    anchor json NOT NULL,
    focustype varchar(12) NOT NULL,
    archived boolean NOT NULL,
    archivedts bigint NOT NULL,
    PRIMARY KEY (screenid)
);
INSERT INTO new_screen
SELECT screenid, sessionid, name, screenidx, screenopts, ownerid, sharemode,
       curremoteownerid, curremoteid, curremotename, nextlinenum, selectedline,
       anchor, focustype, archived, archivedts
FROM screen;
DROP TABLE screen;
ALTER TABLE new_screen RENAME TO screen;

-- drop sessionid from line
CREATE TABLE new_line (
    screenid varchar(36) NOT NULL,
    userid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    ts bigint NOT NULL,
    linenum int NOT NULL,
    linenumtemp boolean NOT NULL,
    linetype varchar(10) NOT NULL,
    linelocal boolean NOT NULL,
    text text NOT NULL,
    cmdid varchar(36) NOT NULL,
    ephemeral boolean NOT NULL,
    contentheight int NOT NULL,
    star int NOT NULL,
    archived boolean NOT NULL,
    renderer varchar(50) NOT NULL,
    bookmarked boolean NOT NULL,
    PRIMARY KEY (screenid, lineid)
);
INSERT INTO new_line
SELECT screenid, userid, lineid, ts, linenum, linenumtemp, linetype, linelocal,
       text, cmdid, ephemeral, contentheight, star, archived, renderer, bookmarked
FROM line;
DROP TABLE line;
ALTER TABLE new_line RENAME TO line;

-- drop sessionid from cmd
CREATE TABLE new_cmd (
    screenid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    remotename varchar(50) NOT NULL,
    cmdstr text NOT NULL,
    rawcmdstr text NOT NULL,
    festate json NOT NULL,
    statebasehash varchar(36) NOT NULL,
    statediffhasharr json NOT NULL,
    termopts json NOT NULL,
    origtermopts json NOT NULL,
    status varchar(10) NOT NULL,
    startpk json NOT NULL,
    doneinfo json NOT NULL,
    runout json NOT NULL,
    rtnstate boolean NOT NULL,
    rtnbasehash varchar(36) NOT NULL,
    rtndiffhasharr json NOT NULL,
    PRIMARY KEY (screenid, cmdid)
);
INSERT INTO new_cmd
SELECT screenid, cmdid, remoteownerid, remoteid, remotename, cmdstr, cmdstr,
       festate, statebasehash, statediffhasharr, termopts, origtermopts, status, startpk, doneinfo, runout, rtnstate, rtnbasehash, rtndiffhasharr
FROM cmd;
DROP TABLE cmd;
ALTER TABLE new_cmd RENAME TO cmd;
wavesrv/db/migrations/000014_simplifybookmarks.down.sql (new file, 9 lines)
@@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS "bookmark_cmd" (
    bookmarkid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    PRIMARY KEY (bookmarkid, screenid, cmdid)
);

ALTER TABLE line ADD COLUMN bookmarked boolean NOT NULL DEFAULT 0;

wavesrv/db/migrations/000014_simplifybookmarks.up.sql (new file, 3 lines)
@@ -0,0 +1,3 @@
DROP TABLE bookmark_cmd;
ALTER TABLE line DROP COLUMN bookmarked;

wavesrv/db/migrations/000015_lineupdates.down.sql (new file, 4 lines)
@@ -0,0 +1,4 @@
DROP TABLE screenupdate;

ALTER TABLE screen DROP COLUMN webshareopts;

wavesrv/db/migrations/000015_lineupdates.up.sql (new file, 10 lines)
@@ -0,0 +1,10 @@
CREATE TABLE screenupdate (
    updateid integer PRIMARY KEY,
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    updatetype varchar(50) NOT NULL,
    updatets bigint NOT NULL
);

ALTER TABLE screen ADD COLUMN webshareopts json NOT NULL DEFAULT 'null';

wavesrv/db/migrations/000016_webptypos.down.sql (new file, 3 lines)
@@ -0,0 +1,3 @@
DROP TABLE webptypos;

DROP INDEX idx_screenupdate_ids;

wavesrv/db/migrations/000016_webptypos.up.sql (new file, 8 lines)
@@ -0,0 +1,8 @@
CREATE TABLE webptypos (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    ptypos bigint NOT NULL,
    PRIMARY KEY (screenid, lineid)
);

CREATE INDEX idx_screenupdate_ids ON screenupdate (screenid, lineid);

wavesrv/db/migrations/000017_remotevars.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE remote DROP COLUMN statevars;

wavesrv/db/migrations/000017_remotevars.up.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE remote ADD COLUMN statevars json NOT NULL DEFAULT '{}';
wavesrv/db/migrations/000018_modremote.down.sql (new file, 14 lines)
@@ -0,0 +1,14 @@
ALTER TABLE remote ADD COLUMN remotesudo;

UPDATE remote
SET remotesudo = 1
WHERE json_extract(sshopts, '$.issudo')
;

UPDATE remote
SET sshopts = json_remove(sshopts, '$.issudo')
;

ALTER TABLE remote ADD COLUMN physicalid varchar(36) NOT NULL DEFAULT '';

ALTER TABLE remote DROP COLUMN openaiopts;

wavesrv/db/migrations/000018_modremote.up.sql (new file, 11 lines)
@@ -0,0 +1,11 @@
UPDATE remote
SET sshopts = json_set(sshopts, '$.issudo', json('true'))
WHERE remotesudo
;

ALTER TABLE remote DROP COLUMN remotesudo;

ALTER TABLE remote DROP COLUMN physicalid;

ALTER TABLE remote ADD COLUMN openaiopts json NOT NULL DEFAULT '{}';

wavesrv/db/migrations/000019_clientopenai.down.sql (new file, 1 line)
@@ -0,0 +1 @@
ALTER TABLE client DROP COLUMN openaiopts;

wavesrv/db/migrations/000019_clientopenai.up.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
ALTER TABLE client ADD COLUMN openaiopts json NOT NULL DEFAULT '{}';

wavesrv/db/migrations/000020_linecmd.down.sql (new file, 2 lines)
@@ -0,0 +1,2 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;
wavesrv/db/migrations/000020_linecmd.up.sql (new file, 74 lines)
@@ -0,0 +1,74 @@
-- remove cmdid from line, history, and cmd (use lineid everywhere)

CREATE TABLE cmd_new (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    remotename varchar(50) NOT NULL,
    cmdstr text NOT NULL,
    rawcmdstr text NOT NULL,
    festate json NOT NULL,
    statebasehash varchar(36) NOT NULL,
    statediffhasharr json NOT NULL,
    termopts json NOT NULL,
    origtermopts json NOT NULL,
    status varchar(10) NOT NULL,
    cmdpid int NOT NULL,
    remotepid int NOT NULL,
    donets bigint NOT NULL,
    exitcode int NOT NULL,
    durationms int NOT NULL,
    rtnstate boolean NOT NULL,
    rtnbasehash varchar(36) NOT NULL,
    rtndiffhasharr json NOT NULL,
    runout json NOT NULL,
    PRIMARY KEY (screenid, lineid)
);

CREATE TABLE cmd_migrate20 (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    PRIMARY KEY (screenid, lineid)
);

INSERT INTO cmd_migrate20
SELECT screenid, lineid, cmdid
FROM line
WHERE cmdid <> ''
;

INSERT INTO cmd_new
SELECT
    c.screenid,
    l.lineid,
    c.remoteownerid,
    c.remoteid,
    c.remotename,
    c.cmdstr,
    c.rawcmdstr,
    c.festate,
    c.statebasehash,
    c.statediffhasharr,
    c.termopts,
    c.origtermopts,
    c.status,
    coalesce(json_extract(startpk, '$.pid'), 0),
    coalesce(json_extract(startpk, '$.mshellpid'), 0),
    coalesce(json_extract(doneinfo, '$.ts'), 0),
    coalesce(json_extract(doneinfo, '$.exitcode'), 0),
    coalesce(json_extract(doneinfo, '$.durationms'), 0),
    c.rtnstate,
    c.rtnbasehash,
    c.rtndiffhasharr,
    c.runout
FROM cmd c
JOIN line l ON (l.cmdid = c.cmdid);

DROP TABLE cmd;

ALTER TABLE cmd_new RENAME TO cmd;

ALTER TABLE history DROP COLUMN cmdid;
ALTER TABLE line DROP COLUMN cmdid;
wavesrv/db/migrations/000021_linestate.down.sql (new file, 1 line)
@@ -0,0 +1 @@
ALTER TABLE line DROP COLUMN linestate;

wavesrv/db/migrations/000021_linestate.up.sql (new file, 1 line)
@@ -0,0 +1 @@
ALTER TABLE line ADD COLUMN linestate json NOT NULL DEFAULT '{}';

wavesrv/db/migrations/000022_endwebshare.down.sql (new file, 1 line)
@@ -0,0 +1 @@
-- no down migration

wavesrv/db/migrations/000022_endwebshare.up.sql (new file, 1 line)
@@ -0,0 +1 @@
UPDATE screen SET sharemode = 'local' AND webshareopts = 'null';
219
wavesrv/db/schema.sql
Normal file
@@ -0,0 +1,219 @@
CREATE TABLE schema_migrations (version uint64,dirty bool);
CREATE UNIQUE INDEX version_unique ON schema_migrations (version);
CREATE TABLE client (
    clientid varchar(36) NOT NULL,
    userid varchar(36) NOT NULL,
    activesessionid varchar(36) NOT NULL,
    userpublickeybytes blob NOT NULL,
    userprivatekeybytes blob NOT NULL,
    winsize json NOT NULL
, clientopts json NOT NULL DEFAULT '', feopts json NOT NULL DEFAULT '{}', cmdstoretype varchar(20) DEFAULT 'session', openaiopts json NOT NULL DEFAULT '{}');
CREATE TABLE session (
    sessionid varchar(36) PRIMARY KEY,
    name varchar(50) NOT NULL,
    sessionidx int NOT NULL,
    activescreenid varchar(36) NOT NULL,
    notifynum int NOT NULL,
    archived boolean NOT NULL,
    archivedts bigint NOT NULL,
    sharemode varchar(12) NOT NULL);
CREATE TABLE remote_instance (
    riid varchar(36) PRIMARY KEY,
    name varchar(50) NOT NULL,
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    festate json NOT NULL,
    statebasehash varchar(36) NOT NULL,
    statediffhasharr json NOT NULL
);
CREATE TABLE state_base (
    basehash varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    version varchar(200) NOT NULL,
    data blob NOT NULL
);
CREATE TABLE state_diff (
    diffhash varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    basehash varchar(36) NOT NULL,
    diffhasharr json NOT NULL,
    data blob NOT NULL
);
CREATE TABLE remote (
    remoteid varchar(36) PRIMARY KEY,
    remotetype varchar(10) NOT NULL,
    remotealias varchar(50) NOT NULL,
    remotecanonicalname varchar(200) NOT NULL,
    remoteuser varchar(50) NOT NULL,
    remotehost varchar(200) NOT NULL,
    connectmode varchar(20) NOT NULL,
    autoinstall boolean NOT NULL,
    sshopts json NOT NULL,
    remoteopts json NOT NULL,
    lastconnectts bigint NOT NULL,
    local boolean NOT NULL,
    archived boolean NOT NULL,
    remoteidx int NOT NULL
, statevars json NOT NULL DEFAULT '{}', openaiopts json NOT NULL DEFAULT '{}');
CREATE TABLE history (
    historyid varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    userid varchar(36) NOT NULL,
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    lineid int NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    remotename varchar(50) NOT NULL,
    haderror boolean NOT NULL,
    cmdstr text NOT NULL,
    ismetacmd boolean,
    incognito boolean
, linenum int NOT NULL DEFAULT 0);
CREATE TABLE activity (
    day varchar(20) PRIMARY KEY,
    uploaded boolean NOT NULL,
    tdata json NOT NULL,
    tzname varchar(50) NOT NULL,
    tzoffset int NOT NULL,
    clientversion varchar(20) NOT NULL,
    clientarch varchar(20) NOT NULL
, buildtime varchar(20) NOT NULL DEFAULT '-', osrelease varchar(20) NOT NULL DEFAULT '-');
CREATE TABLE bookmark (
    bookmarkid varchar(36) PRIMARY KEY,
    createdts bigint NOT NULL,
    cmdstr text NOT NULL,
    alias varchar(50) NOT NULL,
    tags json NOT NULL,
    description text NOT NULL
);
CREATE TABLE bookmark_order (
    tag varchar(50) NOT NULL,
    bookmarkid varchar(36) NOT NULL,
    orderidx int NOT NULL,
    PRIMARY KEY (tag, bookmarkid)
);
CREATE TABLE playbook (
    playbookid varchar(36) PRIMARY KEY,
    playbookname varchar(100) NOT NULL,
    description text NOT NULL,
    entryids json NOT NULL
);
CREATE TABLE playbook_entry (
    entryid varchar(36) PRIMARY KEY,
    playbookid varchar(36) NOT NULL,
    description text NOT NULL,
    alias varchar(50) NOT NULL,
    cmdstr text NOT NULL,
    createdts bigint NOT NULL,
    updatedts bigint NOT NULL
);
CREATE TABLE cloud_session (
    sessionid varchar(36) PRIMARY KEY,
    viewkey varchar(50) NOT NULL,
    writekey varchar(50) NOT NULL,
    enckey varchar(100) NOT NULL,
    enctype varchar(50) NOT NULL,
    vts bigint NOT NULL,
    acl json NOT NULL
);
CREATE TABLE cloud_update (
    updateid varchar(36) PRIMARY KEY,
    ts bigint NOT NULL,
    updatetype varchar(50) NOT NULL,
    updatekeys json NOT NULL
);
CREATE TABLE cmd_migrate (
    sessionid varchar(36) NOT NULL,
    screenid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL
);
CREATE TABLE IF NOT EXISTS "screen" (
    screenid varchar(36) NOT NULL,
    sessionid varchar(36) NOT NULL,
    name varchar(50) NOT NULL,
    screenidx int NOT NULL,
    screenopts json NOT NULL,
    ownerid varchar(36) NOT NULL,
    sharemode varchar(12) NOT NULL,
    curremoteownerid varchar(36) NOT NULL,
    curremoteid varchar(36) NOT NULL,
    curremotename varchar(50) NOT NULL,
    nextlinenum int NOT NULL,
    selectedline int NOT NULL,
    anchor json NOT NULL,
    focustype varchar(12) NOT NULL,
    archived boolean NOT NULL,
    archivedts bigint NOT NULL, webshareopts json NOT NULL DEFAULT 'null',
    PRIMARY KEY (screenid)
);
CREATE TABLE IF NOT EXISTS "line" (
    screenid varchar(36) NOT NULL,
    userid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    ts bigint NOT NULL,
    linenum int NOT NULL,
    linenumtemp boolean NOT NULL,
    linetype varchar(10) NOT NULL,
    linelocal boolean NOT NULL,
    text text NOT NULL,
    ephemeral boolean NOT NULL,
    contentheight int NOT NULL,
    star int NOT NULL,
    archived boolean NOT NULL,
    renderer varchar(50) NOT NULL, linestate json NOT NULL DEFAULT '{}',
    PRIMARY KEY (screenid, lineid)
);
CREATE TABLE screenupdate (
    updateid integer PRIMARY KEY,
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    updatetype varchar(50) NOT NULL,
    updatets bigint NOT NULL
);
CREATE TABLE webptypos (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    ptypos bigint NOT NULL,
    PRIMARY KEY (screenid, lineid)
);
CREATE INDEX idx_screenupdate_ids ON screenupdate (screenid, lineid);
CREATE TABLE cmd_migration (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    PRIMARY KEY (screenid, lineid)
);
CREATE TABLE IF NOT EXISTS "cmd" (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    remoteownerid varchar(36) NOT NULL,
    remoteid varchar(36) NOT NULL,
    remotename varchar(50) NOT NULL,
    cmdstr text NOT NULL,
    rawcmdstr text NOT NULL,
    festate json NOT NULL,
    statebasehash varchar(36) NOT NULL,
    statediffhasharr json NOT NULL,
    termopts json NOT NULL,
    origtermopts json NOT NULL,
    status varchar(10) NOT NULL,
    cmdpid int NOT NULL,
    remotepid int NOT NULL,
    donets bigint NOT NULL,
    exitcode int NOT NULL,
    durationms int NOT NULL,
    rtnstate boolean NOT NULL,
    rtnbasehash varchar(36) NOT NULL,
    rtndiffhasharr json NOT NULL,
    runout json NOT NULL,
    PRIMARY KEY (screenid, lineid)
);
CREATE TABLE cmd_migrate20 (
    screenid varchar(36) NOT NULL,
    lineid varchar(36) NOT NULL,
    cmdid varchar(36) NOT NULL,
    PRIMARY KEY (screenid, lineid)
);
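Editor's note (not part of this commit): schema.sql is the flattened snapshot of the same database after migration 000022, with cmd and line keyed by (screenid, lineid). Since go.mod (below) pulls in jmoiron/sqlx and mattn/go-sqlite3, here is a minimal, hypothetical sketch of reading one of these tables with that stack; the struct, query, DB path, and screenid value are illustrative assumptions, not wavesrv's own store code (which lives in pkg/sstore).

package main

import (
    "fmt"
    "log"

    "github.com/jmoiron/sqlx"
    _ "github.com/mattn/go-sqlite3" // registers the "sqlite3" database/sql driver
)

// LineRow maps a few columns of the "line" table from schema.sql (field set is illustrative).
type LineRow struct {
    ScreenId string `db:"screenid"`
    LineId   string `db:"lineid"`
    LineNum  int64  `db:"linenum"`
    LineType string `db:"linetype"`
}

func main() {
    db, err := sqlx.Open("sqlite3", "prompt.db") // hypothetical DB path
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    var lines []LineRow
    q := `SELECT screenid, lineid, linenum, linetype FROM line WHERE screenid = ? ORDER BY linenum`
    if err := db.Select(&lines, q, "00000000-0000-0000-0000-000000000001"); err != nil {
        log.Fatal(err)
    }
    for _, l := range lines {
        fmt.Printf("%5d %-8s %s\n", l.LineNum, l.LineType, l.LineId)
    }
}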
28
wavesrv/go.mod
Normal file
@@ -0,0 +1,28 @@
module github.com/commandlinedev/prompt-server

go 1.18

require (
    github.com/alessio/shellescape v1.4.1
    github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2
    github.com/commandlinedev/apishell v0.0.0
    github.com/creack/pty v1.1.18
    github.com/golang-migrate/migrate/v4 v4.16.2
    github.com/google/uuid v1.3.0
    github.com/gorilla/mux v1.8.0
    github.com/gorilla/websocket v1.5.0
    github.com/jmoiron/sqlx v1.3.5
    github.com/mattn/go-sqlite3 v1.14.16
    github.com/sashabaranov/go-openai v1.9.0
    github.com/sawka/txwrap v0.1.2
    golang.org/x/crypto v0.7.0
    golang.org/x/mod v0.10.0
    golang.org/x/sys v0.10.0
    mvdan.cc/sh/v3 v3.7.0
)

require (
    github.com/hashicorp/errwrap v1.1.0 // indirect
    github.com/hashicorp/go-multierror v1.1.1 // indirect
    go.uber.org/atomic v1.7.0 // indirect
)
64
wavesrv/go.sum
Normal file
@@ -0,0 +1,64 @@
github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0=
github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30=
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs=
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang-migrate/migrate/v4 v4.16.2 h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA=
github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.10.1-0.20230524175051-ec119421bb97 h1:3RPlVWzZ/PDqmVuf/FKHARG5EMid/tl7cv54Sw/QRVY=
github.com/rogpeppe/go-internal v1.10.1-0.20230524175051-ec119421bb97/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sashabaranov/go-openai v1.9.0 h1:NoiO++IISxxJ1pRc0n7uZvMGMake0G+FJ1XPwXtprsA=
github.com/sashabaranov/go-openai v1.9.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sawka/txwrap v0.1.2 h1:v8xS0Z1LE7/6vMZA81PYihI+0TSR6Zm1MalzzBIuXKc=
github.com/sawka/txwrap v0.1.2/go.mod h1:T3nlw2gVpuolo6/XEetvBbk1oMXnY978YmBFy1UyHvw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
mvdan.cc/sh/v3 v3.7.0 h1:lSTjdP/1xsddtaKfGg7Myu7DnlHItd3/M2tomOcNNBg=
mvdan.cc/sh/v3 v3.7.0/go.mod h1:K2gwkaesF/D7av7Kxl0HbF5kGOd2ArupNTX3X44+8l8=
3828
wavesrv/pkg/cmdrunner/cmdrunner.go
Normal file
File diff suppressed because it is too large
1986
wavesrv/pkg/cmdrunner/linux-decls.txt
Normal file
File diff suppressed because it is too large
516
wavesrv/pkg/cmdrunner/resolver.go
Normal file
@@ -0,0 +1,516 @@
package cmdrunner

import (
    "context"
    "fmt"
    "log"
    "regexp"
    "strconv"
    "strings"

    "github.com/commandlinedev/prompt-server/pkg/remote"
    "github.com/commandlinedev/prompt-server/pkg/scpacket"
    "github.com/commandlinedev/prompt-server/pkg/sstore"
    "github.com/google/uuid"
)

const (
    R_Session         = 1
    R_Screen          = 2
    R_Remote          = 8
    R_RemoteConnected = 16
)

type resolvedIds struct {
    SessionId string
    ScreenId  string
    Remote    *ResolvedRemote
}

type ResolvedRemote struct {
    DisplayName string
    RemotePtr   sstore.RemotePtrType
    MShell      *remote.MShellProc
    RState      remote.RemoteRuntimeState
    RemoteCopy  *sstore.RemoteType
    StatePtr    *sstore.ShellStatePtr
    FeState     map[string]string
}

type ResolveItem = sstore.ResolveItem

func itemNames(items []ResolveItem) []string {
    if len(items) == 0 {
        return nil
    }
    rtn := make([]string, len(items))
    for idx, item := range items {
        rtn[idx] = item.Name
    }
    return rtn
}

func sessionsToResolveItems(sessions []*sstore.SessionType) []ResolveItem {
    if len(sessions) == 0 {
        return nil
    }
    rtn := make([]ResolveItem, len(sessions))
    for idx, session := range sessions {
        rtn[idx] = ResolveItem{Name: session.Name, Id: session.SessionId, Hidden: session.Archived}
    }
    return rtn
}

func screensToResolveItems(screens []*sstore.ScreenType) []ResolveItem {
    if len(screens) == 0 {
        return nil
    }
    rtn := make([]ResolveItem, len(screens))
    for idx, screen := range screens {
        rtn[idx] = ResolveItem{Name: screen.Name, Id: screen.ScreenId, Hidden: screen.Archived}
    }
    return rtn
}

// 1-indexed
func boundInt(ival int, maxVal int, wrap bool) int {
    if maxVal == 0 {
        return 0
    }
    if ival < 1 {
        if wrap {
            return maxVal
        } else {
            return 1
        }
    }
    if ival > maxVal {
        if wrap {
            return 1
        } else {
            return maxVal
        }
    }
    return ival
}

type posArgType struct {
    Pos         int
    IsWrap      bool
    IsRelative  bool
    StartAnchor bool
    EndAnchor   bool
}

func parsePosArg(posStr string) *posArgType {
    if !positionRe.MatchString(posStr) {
        return nil
    }
    if posStr == "+" {
        return &posArgType{Pos: 1, IsWrap: true, IsRelative: true}
    } else if posStr == "-" {
        return &posArgType{Pos: -1, IsWrap: true, IsRelative: true}
    } else if posStr == "S" {
        return &posArgType{Pos: 0, IsRelative: true, StartAnchor: true}
    } else if posStr == "E" {
        return &posArgType{Pos: 0, IsRelative: true, EndAnchor: true}
    }
    if strings.HasPrefix(posStr, "S+") {
        pos, _ := strconv.Atoi(posStr[2:])
        return &posArgType{Pos: pos, IsRelative: true, StartAnchor: true}
    }
    if strings.HasPrefix(posStr, "E-") {
        pos, _ := strconv.Atoi(posStr[1:])
        return &posArgType{Pos: pos, IsRelative: true, EndAnchor: true}
    }
    if strings.HasPrefix(posStr, "+") || strings.HasPrefix(posStr, "-") {
        pos, _ := strconv.Atoi(posStr)
        return &posArgType{Pos: pos, IsRelative: true}
    }
    pos, _ := strconv.Atoi(posStr)
    return &posArgType{Pos: pos}
}

func resolveByPosition(isNumeric bool, allItems []ResolveItem, curId string, posStr string) *ResolveItem {
    items := make([]ResolveItem, 0, len(allItems))
    for _, item := range allItems {
        if !item.Hidden {
            items = append(items, item)
        }
    }
    if len(items) == 0 {
        return nil
    }
    posArg := parsePosArg(posStr)
    if posArg == nil {
        return nil
    }
    var finalPos int
    if posArg.IsRelative {
        var curIdx int
        if posArg.StartAnchor {
            curIdx = 1
        } else if posArg.EndAnchor {
            curIdx = len(items)
        } else {
            curIdx = 1 // if no match, curIdx will be first item
            for idx, item := range items {
                if item.Id == curId {
                    curIdx = idx + 1
                    break
                }
            }
        }
        finalPos = curIdx + posArg.Pos
        finalPos = boundInt(finalPos, len(items), posArg.IsWrap)
        return &items[finalPos-1]
    } else if isNumeric {
        // these resolve items have a "Num" set that should be used to look up non-relative positions
        // use allItems for numeric resolve
        for _, item := range allItems {
            if item.Num == posArg.Pos {
                return &item
            }
        }
        return nil
    } else {
        // non-numeric means position is just the index
        finalPos = posArg.Pos
        if finalPos <= 0 || finalPos > len(items) {
            return nil
        }
        return &items[finalPos-1]
    }
}

func resolveRemoteArg(remoteArg string) (*sstore.RemotePtrType, error) {
    rrUser, rrRemote, rrName, err := parseFullRemoteRef(remoteArg)
    if err != nil {
        return nil, err
    }
    if rrUser != "" {
        return nil, fmt.Errorf("remoteusers not supported")
    }
    msh := remote.GetRemoteByArg(rrRemote)
    if msh == nil {
        return nil, nil
    }
    rcopy := msh.GetRemoteCopy()
    return &sstore.RemotePtrType{RemoteId: rcopy.RemoteId, Name: rrName}, nil
}

func resolveUiIds(ctx context.Context, pk *scpacket.FeCommandPacketType, rtype int) (resolvedIds, error) {
    rtn := resolvedIds{}
    uictx := pk.UIContext
    if uictx != nil {
        rtn.SessionId = uictx.SessionId
        rtn.ScreenId = uictx.ScreenId
    }
    if pk.Kwargs["session"] != "" {
        sessionId, err := resolveSessionArg(pk.Kwargs["session"])
        if err != nil {
            return rtn, err
        }
        if sessionId != "" {
            rtn.SessionId = sessionId
        }
    }
    if pk.Kwargs["screen"] != "" {
        screenId, err := resolveScreenArg(rtn.SessionId, pk.Kwargs["screen"])
        if err != nil {
            return rtn, err
        }
        if screenId != "" {
            rtn.ScreenId = screenId
        }
    }
    var rptr *sstore.RemotePtrType
    var err error
    if pk.Kwargs["remote"] != "" {
        rptr, err = resolveRemoteArg(pk.Kwargs["remote"])
        if err != nil {
            return rtn, err
        }
        if rptr == nil {
            return rtn, fmt.Errorf("invalid remote argument %q passed, remote not found", pk.Kwargs["remote"])
        }
    } else if uictx.Remote != nil {
        rptr = uictx.Remote
    }
    if rptr != nil {
        err = rptr.Validate()
        if err != nil {
            return rtn, fmt.Errorf("invalid resolved remote: %v", err)
        }
        rr, err := ResolveRemoteFromPtr(ctx, rptr, rtn.SessionId, rtn.ScreenId)
        if err != nil {
            return rtn, err
        }
        rtn.Remote = rr
    }
    if rtype&R_Session > 0 && rtn.SessionId == "" {
        return rtn, fmt.Errorf("no session")
    }
    if rtype&R_Screen > 0 && rtn.ScreenId == "" {
        return rtn, fmt.Errorf("no screen")
    }
    if (rtype&R_Remote > 0 || rtype&R_RemoteConnected > 0) && rtn.Remote == nil {
        return rtn, fmt.Errorf("no remote")
    }
    if rtype&R_RemoteConnected > 0 {
        if !rtn.Remote.RState.IsConnected() {
            err = rtn.Remote.MShell.TryAutoConnect()
            if err != nil {
                return rtn, fmt.Errorf("error trying to auto-connect remote [%s]: %w", rtn.Remote.DisplayName, err)
            }
            rrNew, err := ResolveRemoteFromPtr(ctx, rptr, rtn.SessionId, rtn.ScreenId)
            if err != nil {
                return rtn, err
            }
            rtn.Remote = rrNew
        }
        if !rtn.Remote.RState.IsConnected() {
            return rtn, fmt.Errorf("remote [%s] is not connected", rtn.Remote.DisplayName)
        }
        if rtn.Remote.StatePtr == nil || rtn.Remote.FeState == nil {
            return rtn, fmt.Errorf("remote [%s] state is not available", rtn.Remote.DisplayName)
        }
    }
    return rtn, nil
}

func resolveSessionScreen(ctx context.Context, sessionId string, screenArg string, curScreenArg string) (*ResolveItem, error) {
    screens, err := sstore.GetSessionScreens(ctx, sessionId)
    if err != nil {
        return nil, fmt.Errorf("could not retrieve screens for session=%s: %v", sessionId, err)
    }
    ritems := screensToResolveItems(screens)
    return genericResolve(screenArg, curScreenArg, ritems, false, "screen")
}

func resolveSession(ctx context.Context, sessionArg string, curSessionArg string) (*ResolveItem, error) {
    bareSessions, err := sstore.GetBareSessions(ctx)
    if err != nil {
        return nil, err
    }
    ritems := sessionsToResolveItems(bareSessions)
    ritem, err := genericResolve(sessionArg, curSessionArg, ritems, false, "session")
    if err != nil {
        return nil, err
    }
    return ritem, nil
}

func resolveLine(ctx context.Context, sessionId string, screenId string, lineArg string, curLineArg string) (*ResolveItem, error) {
    lines, err := sstore.GetLineResolveItems(ctx, screenId)
    if err != nil {
        return nil, fmt.Errorf("could not get lines: %v", err)
    }
    return genericResolve(lineArg, curLineArg, lines, true, "line")
}

func getSessionIds(sarr []*sstore.SessionType) []string {
    rtn := make([]string, len(sarr))
    for idx, s := range sarr {
        rtn[idx] = s.SessionId
    }
    return rtn
}

var partialUUIDRe = regexp.MustCompile("^[0-9a-f]{8}$")

func isPartialUUID(s string) bool {
    return partialUUIDRe.MatchString(s)
}

func isUUID(s string) bool {
    _, err := uuid.Parse(s)
    return err == nil
}

func getResolveItemById(id string, items []ResolveItem) *ResolveItem {
    if id == "" {
        return nil
    }
    for _, item := range items {
        if item.Id == id {
            return &item
        }
    }
    return nil
}

func genericResolve(arg string, curArg string, items []ResolveItem, isNumeric bool, typeStr string) (*ResolveItem, error) {
    if len(items) == 0 || arg == "" {
        return nil, nil
    }
    var curId string
    if curArg != "" {
        curItem, _ := genericResolve(curArg, "", items, isNumeric, typeStr)
        if curItem != nil {
            curId = curItem.Id
        }
    }
    rtnItem := resolveByPosition(isNumeric, items, curId, arg)
    if rtnItem != nil {
        return rtnItem, nil
    }
    isUuid := isUUID(arg)
    tryPuid := isPartialUUID(arg)
    var prefixMatches []ResolveItem
    for _, item := range items {
        if (isUuid && item.Id == arg) || (tryPuid && strings.HasPrefix(item.Id, arg)) {
            return &item, nil
        }
        if item.Name != "" {
            if item.Name == arg {
                return &item, nil
            }
            if !item.Hidden && strings.HasPrefix(item.Name, arg) {
                prefixMatches = append(prefixMatches, item)
            }
        }
    }
    if len(prefixMatches) == 1 {
        return &prefixMatches[0], nil
    }
    if len(prefixMatches) > 1 {
        return nil, fmt.Errorf("could not resolve %s '%s', ambiguous prefix matched multiple %ss: %s", typeStr, arg, typeStr, formatStrs(itemNames(prefixMatches), "and", true))
    }
    return nil, fmt.Errorf("could not resolve %s '%s' (name/id/pos not found)", typeStr, arg)
}

func resolveSessionId(pk *scpacket.FeCommandPacketType) (string, error) {
    sessionId := pk.Kwargs["session"]
    if sessionId == "" {
        return "", nil
    }
    if _, err := uuid.Parse(sessionId); err != nil {
        return "", fmt.Errorf("invalid sessionid '%s'", sessionId)
    }
    return sessionId, nil
}

func resolveSessionArg(sessionArg string) (string, error) {
    if sessionArg == "" {
        return "", nil
    }
    if _, err := uuid.Parse(sessionArg); err != nil {
        return "", fmt.Errorf("invalid session arg specified (must be sessionid) '%s'", sessionArg)
    }
    return sessionArg, nil
}

func resolveScreenArg(sessionId string, screenArg string) (string, error) {
    if screenArg == "" {
        return "", nil
    }
    if _, err := uuid.Parse(screenArg); err != nil {
        return "", fmt.Errorf("invalid screen arg specified (must be screenid) '%s'", screenArg)
    }
    return screenArg, nil
}

func resolveScreenId(ctx context.Context, pk *scpacket.FeCommandPacketType, sessionId string) (string, error) {
    screenArg := pk.Kwargs["screen"]
    if screenArg == "" {
        return "", nil
    }
    if _, err := uuid.Parse(screenArg); err == nil {
        return screenArg, nil
    }
    if sessionId == "" {
        return "", fmt.Errorf("cannot resolve screen without session")
    }
    ritem, err := resolveSessionScreen(ctx, sessionId, screenArg, "")
    if err != nil {
        return "", err
    }
    return ritem.Id, nil
}

// returns (remoteuserref, remoteref, name, error)
func parseFullRemoteRef(fullRemoteRef string) (string, string, string, error) {
    if strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]") {
        fullRemoteRef = fullRemoteRef[1 : len(fullRemoteRef)-1]
    }
    fields := strings.Split(fullRemoteRef, ":")
    if len(fields) > 3 {
        return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef)
    }
    if len(fields) == 1 {
        return "", fields[0], "", nil
    }
    if len(fields) == 2 {
        if strings.HasPrefix(fields[0], "@") {
            return fields[0], fields[1], "", nil
        }
        return "", fields[0], fields[1], nil
    }
    return fields[0], fields[1], fields[2], nil
}

func ResolveRemoteFromPtr(ctx context.Context, rptr *sstore.RemotePtrType, sessionId string, screenId string) (*ResolvedRemote, error) {
    if rptr == nil || rptr.RemoteId == "" {
        return nil, nil
    }
    msh := remote.GetRemoteById(rptr.RemoteId)
    if msh == nil {
        return nil, fmt.Errorf("invalid remote '%s', not found", rptr.RemoteId)
    }
    rstate := msh.GetRemoteRuntimeState()
    rcopy := msh.GetRemoteCopy()
    displayName := rstate.GetDisplayName(rptr)
    rtn := &ResolvedRemote{
        DisplayName: displayName,
        RemotePtr:   *rptr,
        RState:      rstate,
        MShell:      msh,
        RemoteCopy:  &rcopy,
        StatePtr:    nil,
        FeState:     nil,
    }
    if sessionId != "" && screenId != "" {
        ri, err := sstore.GetRemoteInstance(ctx, sessionId, screenId, *rptr)
        if err != nil {
            log.Printf("ERROR resolving remote state '%s': %v\n", displayName, err)
            // continue with state set to nil
        } else {
            if ri == nil {
                rtn.StatePtr = msh.GetDefaultStatePtr()
                rtn.FeState = msh.GetDefaultFeState()
            } else {
                rtn.StatePtr = &sstore.ShellStatePtr{BaseHash: ri.StateBaseHash, DiffHashArr: ri.StateDiffHashArr}
                rtn.FeState = ri.FeState
            }
        }
    }
    return rtn, nil
}

// returns (remoteDisplayName, remoteptr, rstate, err)
func resolveRemote(ctx context.Context, fullRemoteRef string, sessionId string, screenId string) (string, *sstore.RemotePtrType, *remote.RemoteRuntimeState, error) {
    if fullRemoteRef == "" {
        return "", nil, nil, nil
    }
    userRef, remoteRef, remoteName, err := parseFullRemoteRef(fullRemoteRef)
    if err != nil {
        return "", nil, nil, err
    }
    if userRef != "" {
        return "", nil, nil, fmt.Errorf("invalid remote '%s', cannot resolve remote userid '%s'", fullRemoteRef, userRef)
    }
    rstate := remote.ResolveRemoteRef(remoteRef)
    if rstate == nil {
        return "", nil, nil, fmt.Errorf("cannot resolve remote '%s': not found", fullRemoteRef)
    }
    rptr := sstore.RemotePtrType{RemoteId: rstate.RemoteId, Name: remoteName}
    rname := rstate.RemoteCanonicalName
    if rstate.RemoteAlias != "" {
        rname = rstate.RemoteAlias
    }
    if rptr.Name != "" {
        rname = fmt.Sprintf("%s:%s", rname, rptr.Name)
    }
    return rname, &rptr, rstate, nil
}
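Editor's note (not part of this commit): the resolver above accepts a small position grammar for sessions/screens/lines: "+" and "-" move relative to the current item with wraparound, "S"/"E" (and "S+n"/"E-n") anchor to the start or end of the visible list, and bare numbers are absolute positions, with boundInt clamping or wrapping the result. A minimal, hypothetical test-style sketch (it could live in a resolver_test.go, which this commit does not include); it assumes positionRe, declared in cmdrunner.go and not shown in this diff, accepts the "+" form.

package cmdrunner

import "testing"

// TestPosArgSketch is an illustrative walk through parsePosArg/boundInt above.
func TestPosArgSketch(t *testing.T) {
    // "+" parses as a relative, wrapping move of +1 (assumes positionRe matches "+").
    p := parsePosArg("+")
    if p == nil || !p.IsRelative || !p.IsWrap || p.Pos != 1 {
        t.Fatalf("unexpected parse for %q: %+v", "+", p)
    }
    // With wrap=true, stepping past the last of 3 items wraps back to 1.
    if got := boundInt(4, 3, true); got != 1 {
        t.Errorf("boundInt wrap: got %d, want 1", got)
    }
    // With wrap=false, positions below 1 clamp to the first item.
    if got := boundInt(0, 3, false); got != 1 {
        t.Errorf("boundInt clamp: got %d, want 1", got)
    }
}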
333
wavesrv/pkg/cmdrunner/shparse.go
Normal file
@@ -0,0 +1,333 @@
package cmdrunner

import (
    "context"
    "fmt"
    "regexp"
    "strings"

    "github.com/commandlinedev/apishell/pkg/shexec"
    "github.com/commandlinedev/apishell/pkg/simpleexpand"
    "github.com/commandlinedev/prompt-server/pkg/scpacket"
    "github.com/commandlinedev/prompt-server/pkg/utilfn"
    "mvdan.cc/sh/v3/expand"
    "mvdan.cc/sh/v3/syntax"
)

var ValidMetaCmdRe = regexp.MustCompile("^/([a-z_][a-z0-9_-]*)(?::([a-z][a-z0-9_-]*))?$")

type BareMetaCmdDecl struct {
    CmdStr  string
    MetaCmd string
}

var BareMetaCmds = []BareMetaCmdDecl{
    BareMetaCmdDecl{"cr", "cr"},
    BareMetaCmdDecl{"connect", "cr"},
    BareMetaCmdDecl{"clear", "clear"},
    BareMetaCmdDecl{"reset", "reset"},
    BareMetaCmdDecl{"codeedit", "codeedit"},
    BareMetaCmdDecl{"codeview", "codeview"},
    BareMetaCmdDecl{"imageview", "imageview"},
    BareMetaCmdDecl{"markdownview", "markdownview"},
    BareMetaCmdDecl{"mdview", "markdownview"},
    BareMetaCmdDecl{"csvview", "csvview"},
}

const (
    CmdParseTypePositional = "pos"
    CmdParseTypeRaw        = "raw"
)

var CmdParseOverrides map[string]string = map[string]string{
    "setenv":  CmdParseTypePositional,
    "unset":   CmdParseTypePositional,
    "set":     CmdParseTypePositional,
    "run":     CmdParseTypeRaw,
    "comment": CmdParseTypeRaw,
    "chat":    CmdParseTypeRaw,
}

func DumpPacket(pk *scpacket.FeCommandPacketType) {
    if pk == nil || pk.MetaCmd == "" {
        fmt.Printf("[no metacmd]\n")
        return
    }
    if pk.MetaSubCmd == "" {
        fmt.Printf("/%s\n", pk.MetaCmd)
    } else {
        fmt.Printf("/%s:%s\n", pk.MetaCmd, pk.MetaSubCmd)
    }
    for _, arg := range pk.Args {
        fmt.Printf("  %q\n", arg)
    }
    for key, val := range pk.Kwargs {
        fmt.Printf("  [%s]=%q\n", key, val)
    }
}

func isQuoted(source string, w *syntax.Word) bool {
    if w == nil {
        return false
    }
    offset := w.Pos().Offset()
    if int(offset) >= len(source) {
        return false
    }
    return source[offset] == '"' || source[offset] == '\''
}

func getSourceStr(source string, w *syntax.Word) string {
    if w == nil {
        return ""
    }
    offset := w.Pos().Offset()
    end := w.End().Offset()
    return source[offset:end]
}

func SubMetaCmd(cmd string) string {
    switch cmd {
    case "s":
        return "screen"
    case "r":
        return "run"
    case "c":
        return "comment"
    case "e":
        return "eval"
    case "export":
        return "setenv"
    case "connection":
        return "remote"
    default:
        return cmd
    }
}

// returns (metaCmd, metaSubCmd, rest)
// if metaCmd is "" then this isn't a valid metacmd string
func parseMetaCmd(origCommandStr string) (string, string, string) {
    commandStr := strings.TrimSpace(origCommandStr)
    if len(commandStr) < 2 {
        return "run", "", origCommandStr
    }
    fields := strings.SplitN(commandStr, " ", 2)
    firstArg := fields[0]
    rest := ""
    if len(fields) > 1 {
        rest = strings.TrimSpace(fields[1])
    }
    for _, decl := range BareMetaCmds {
        if firstArg == decl.CmdStr {
            return decl.MetaCmd, "", rest
        }
    }
    m := ValidMetaCmdRe.FindStringSubmatch(firstArg)
    if m == nil {
        return "run", "", origCommandStr
    }
    return SubMetaCmd(m[1]), m[2], rest
}

func onlyPositionalArgs(metaCmd string, metaSubCmd string) bool {
    return (CmdParseOverrides[metaCmd] == CmdParseTypePositional) && metaSubCmd == ""
}

func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
    return CmdParseOverrides[metaCmd] == CmdParseTypeRaw
}

func setBracketArgs(argMap map[string]string, bracketStr string) error {
    bracketStr = strings.TrimSpace(bracketStr)
    if bracketStr == "" {
        return nil
    }
    strReader := strings.NewReader(bracketStr)
    parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
    var wordErr error
    var ectx simpleexpand.SimpleExpandContext // do not set HomeDir (we don't expand ~ in bracket args)
    err := parser.Words(strReader, func(w *syntax.Word) bool {
        litStr, _ := simpleexpand.SimpleExpandWord(ectx, w, bracketStr)
        eqIdx := strings.Index(litStr, "=")
        var varName, varVal string
        if eqIdx == -1 {
            varName = litStr
        } else {
            varName = litStr[0:eqIdx]
            varVal = litStr[eqIdx+1:]
        }
        if !shexec.IsValidBashIdentifier(varName) {
            wordErr = fmt.Errorf("invalid identifier %s in bracket args", utilfn.ShellQuote(varName, true, 20))
            return false
        }
        if varVal == "" {
            varVal = "1"
        }
        argMap[varName] = varVal
        return true
    })
    if err != nil {
        return err
    }
    if wordErr != nil {
        return wordErr
    }
    return nil
}

var literalRtnStateCommands = []string{".", "source", "unset", "cd", "alias", "unalias", "deactivate"}

func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {
    if len(callExpr.Args) <= argNum {
        return ""
    }
    arg := callExpr.Args[argNum]
    if len(arg.Parts) == 0 {
        return ""
    }
    lit, ok := arg.Parts[0].(*syntax.Lit)
    if !ok {
        return ""
    }
    return lit.Value
}

// detects: export, declare, ., source, X=1, unset
func IsReturnStateCommand(cmdStr string) bool {
    cmdReader := strings.NewReader(cmdStr)
    parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
    file, err := parser.Parse(cmdReader, "cmd")
    if err != nil {
        return false
    }
    for _, stmt := range file.Stmts {
        if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
            if len(callExpr.Assigns) > 0 && len(callExpr.Args) == 0 {
                return true
            }
            arg0 := getCallExprLitArg(callExpr, 0)
            if arg0 != "" && utilfn.ContainsStr(literalRtnStateCommands, arg0) {
                return true
            }
            if arg0 == "git" {
                arg1 := getCallExprLitArg(callExpr, 1)
                if arg1 == "checkout" || arg1 == "switch" {
                    return true
                }
            }
        } else if _, ok := stmt.Cmd.(*syntax.DeclClause); ok {
            return true
        }
    }
    return false
}

func EvalBracketArgs(origCmdStr string) (map[string]string, string, error) {
    rtn := make(map[string]string)
    if strings.HasPrefix(origCmdStr, " ") {
        rtn["nohist"] = "1"
    }
    cmdStr := strings.TrimSpace(origCmdStr)
    if !strings.HasPrefix(cmdStr, "[") {
        return rtn, origCmdStr, nil
    }
    rbIdx := strings.Index(cmdStr, "]")
    if rbIdx == -1 {
        return nil, "", fmt.Errorf("unmatched '[' found in command")
    }
    bracketStr := cmdStr[1:rbIdx]
    restStr := strings.TrimSpace(cmdStr[rbIdx+1:])
    err := setBracketArgs(rtn, bracketStr)
    if err != nil {
        return nil, "", err
    }
    return rtn, restStr, nil
}

func unescapeBackSlashes(s string) string {
    if strings.Index(s, "\\") == -1 {
        return s
    }
    var newStr []rune
    var lastSlash bool
    for _, r := range s {
        if lastSlash {
            lastSlash = false
            newStr = append(newStr, r)
            continue
        }
        if r == '\\' {
            lastSlash = true
            continue
        }
        newStr = append(newStr, r)
    }
    return string(newStr)
}

func EvalMetaCommand(ctx context.Context, origPk *scpacket.FeCommandPacketType) (*scpacket.FeCommandPacketType, error) {
    if len(origPk.Args) == 0 {
        return nil, fmt.Errorf("empty command (no fields)")
    }
    if strings.TrimSpace(origPk.Args[0]) == "" {
        return nil, fmt.Errorf("empty command")
    }
    bracketArgs, cmdStr, err := EvalBracketArgs(origPk.Args[0])
    if err != nil {
        return nil, err
    }
    metaCmd, metaSubCmd, commandArgs := parseMetaCmd(cmdStr)
    rtnPk := scpacket.MakeFeCommandPacket()
    rtnPk.MetaCmd = metaCmd
    rtnPk.MetaSubCmd = metaSubCmd
    rtnPk.Kwargs = make(map[string]string)
    rtnPk.UIContext = origPk.UIContext
    rtnPk.RawStr = origPk.RawStr
    for key, val := range origPk.Kwargs {
        rtnPk.Kwargs[key] = val
    }
    for key, val := range bracketArgs {
        rtnPk.Kwargs[key] = val
    }
    if onlyRawArgs(metaCmd, metaSubCmd) {
        // don't evaluate arguments for /run or /comment
        rtnPk.Args = []string{commandArgs}
        return rtnPk, nil
    }
    commandReader := strings.NewReader(commandArgs)
    parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
    var words []*syntax.Word
    err = parser.Words(commandReader, func(w *syntax.Word) bool {
        words = append(words, w)
        return true
    })
    if err != nil {
        return nil, fmt.Errorf("parsing metacmd, position %v", err)
    }
    envMap := make(map[string]string) // later we can add vars like session, screen, remote, and user
    cfg := shexec.GetParserConfig(envMap)
    // process arguments
    for idx, w := range words {
        literalVal, err := expand.Literal(cfg, w)
        if err != nil {
            return nil, fmt.Errorf("error evaluating metacmd argument %d [%s]: %v", idx+1, getSourceStr(commandArgs, w), err)
        }
        if isQuoted(commandArgs, w) || onlyPositionalArgs(metaCmd, metaSubCmd) {
            rtnPk.Args = append(rtnPk.Args, literalVal)
            continue
        }
        eqIdx := strings.Index(literalVal, "=")
        if eqIdx != -1 && eqIdx != 0 {
            varName := literalVal[:eqIdx]
            varVal := literalVal[eqIdx+1:]
            rtnPk.Kwargs[varName] = varVal
            continue
        }
        rtnPk.Args = append(rtnPk.Args, unescapeBackSlashes(literalVal))
    }
    if resolveBool(rtnPk.Kwargs["dump"], false) {
        DumpPacket(rtnPk)
    }
    return rtnPk, nil
}
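Editor's note (not part of this commit): shparse.go turns a raw prompt line into a meta-command packet: a leading "/name" or "/name:sub" (or a bare alias like "cr" or "clear") selects the metacmd, a leading "[...]" block becomes kwargs, and anything else falls through to the implicit /run. A small, hypothetical test-style sketch of those helpers; it assumes shexec.IsValidBashIdentifier (from the apishell package, not shown in this diff) accepts a plain identifier like "nohist".

package cmdrunner

import "testing"

// TestMetaCmdParseSketch is an illustrative walk through parseMetaCmd and EvalBracketArgs above.
func TestMetaCmdParseSketch(t *testing.T) {
    metaCmd, metaSubCmd, rest := parseMetaCmd("/screen:open newname")
    if metaCmd != "screen" || metaSubCmd != "open" || rest != "newname" {
        t.Fatalf("parseMetaCmd: got (%q, %q, %q)", metaCmd, metaSubCmd, rest)
    }
    // A plain command falls through to the implicit /run metacmd with the raw string preserved.
    metaCmd, _, rest = parseMetaCmd("ls -l")
    if metaCmd != "run" || rest != "ls -l" {
        t.Fatalf("parseMetaCmd run fallthrough: got (%q, %q)", metaCmd, rest)
    }
    // "[nohist]" style bracket args are lifted into kwargs ("1" when no value is given).
    kwargs, cmdStr, err := EvalBracketArgs("[nohist] pwd")
    if err != nil || kwargs["nohist"] != "1" || cmdStr != "pwd" {
        t.Fatalf("EvalBracketArgs: kwargs=%v cmd=%q err=%v", kwargs, cmdStr, err)
    }
}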
57
wavesrv/pkg/cmdrunner/shparse_test.go
Normal file
@@ -0,0 +1,57 @@
package cmdrunner

import (
    "fmt"
    "os"
    "testing"
)

func xTestParseAliases(t *testing.T) {
    m, err := ParseAliases(`
alias cdg='cd work/gopath/src/github.com/sawka'
alias s='scripthaus'
alias x='ls;ls"'
alias foo="bar \"hello\""
alias x=y
`)
    if err != nil {
        fmt.Printf("err: %v\n", err)
        return
    }
    fmt.Printf("m: %#v\n", m)
}

func xTestParseFuncs(t *testing.T) {
    file, err := os.ReadFile("./linux-decls.txt")
    if err != nil {
        t.Fatalf("error reading linux-decls: %v", err)
    }
    m, err := ParseFuncs(string(file))
    if err != nil {
        t.Fatalf("error parsing funcs: %v", err)
    }
    for key, val := range m {
        fmt.Printf("func: %s %d\n", key, len(val))
    }
}

func testRSC(t *testing.T, cmd string, expected bool) {
    rtn := IsReturnStateCommand(cmd)
    if rtn != expected {
        t.Errorf("cmd [%s], rtn=%v, expected=%v", cmd, rtn, expected)
    }
}

func TestIsReturnStateCommand(t *testing.T) {
    testRSC(t, "FOO=1", true)
    testRSC(t, "FOO=1 X=2", true)
    testRSC(t, "ls", false)
    testRSC(t, "export X", true)
    testRSC(t, "export X=1", true)
    testRSC(t, "declare -x FOO=1", true)
    testRSC(t, "source ./test", true)
    testRSC(t, "unset FOO BAR", true)
    testRSC(t, "FOO=1; ls", true)
    testRSC(t, ". ./test", true)
    testRSC(t, "{ FOO=6; }", false)
}
120
wavesrv/pkg/cmdrunner/termopts.go
Normal file
@@ -0,0 +1,120 @@
package cmdrunner

import (
    "fmt"
    "strconv"
    "strings"

    "github.com/commandlinedev/apishell/pkg/base"
    "github.com/commandlinedev/apishell/pkg/packet"
    "github.com/commandlinedev/apishell/pkg/shexec"
    "github.com/commandlinedev/prompt-server/pkg/remote"
    "github.com/commandlinedev/prompt-server/pkg/sstore"
)

// PTERM=MxM,Mx25
// PTERM="Mx25!"
// PTERM=80x25,80x35

type PTermOptsType struct {
    Rows     string
    RowsFlex bool
    Cols     string
    ColsFlex bool
}

const PTermMax = "M"

func isDigits(s string) bool {
    for _, ch := range s {
        if ch < '0' || ch > '9' {
            return false
        }
    }
    return true
}

func atoiDefault(s string, def int) int {
    ival, err := strconv.Atoi(s)
    if err != nil {
        return def
    }
    return ival
}

func parseTermPart(part string, partType string) (string, bool, error) {
    flex := true
    if strings.HasSuffix(part, "!") {
        part = part[:len(part)-1]
        flex = false
    }
    if part == "" {
        return PTermMax, flex, nil
    }
    if part == PTermMax {
        return PTermMax, flex, nil
    }
    if !isDigits(part) {
        return "", false, fmt.Errorf("invalid PTERM %s: must be '%s' or [number]", partType, PTermMax)
    }
    return part, flex, nil
}

func parseSingleTermStr(s string) (*PTermOptsType, error) {
    s = strings.TrimSpace(s)
    xIdx := strings.Index(s, "x")
    if xIdx == -1 {
        return nil, fmt.Errorf("invalid PTERM, must include 'x' to separate width and height (e.g. WxH)")
    }
    rowsPart := s[0:xIdx]
    colsPart := s[xIdx+1:]
    rows, rowsFlex, err := parseTermPart(rowsPart, "rows")
    if err != nil {
        return nil, err
    }
    cols, colsFlex, err := parseTermPart(colsPart, "cols")
    if err != nil {
        return nil, err
    }
    return &PTermOptsType{Rows: rows, RowsFlex: rowsFlex, Cols: cols, ColsFlex: colsFlex}, nil
}

func GetUITermOpts(winSize *packet.WinSize, ptermStr string) (*packet.TermOpts, error) {
    opts, err := parseSingleTermStr(ptermStr)
    if err != nil {
        return nil, err
    }
    termOpts := &packet.TermOpts{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols, Term: remote.DefaultTerm, MaxPtySize: shexec.DefaultMaxPtySize}
    if winSize == nil {
        winSize = &packet.WinSize{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols}
    }
    if winSize.Rows == 0 {
        winSize.Rows = shexec.DefaultTermRows
    }
    if winSize.Cols == 0 {
        winSize.Cols = shexec.DefaultTermCols
    }
    if opts.Rows == PTermMax {
        termOpts.Rows = winSize.Rows
    } else {
        termOpts.Rows = atoiDefault(opts.Rows, termOpts.Rows)
    }
    if opts.Cols == PTermMax {
        termOpts.Cols = winSize.Cols
    } else {
        termOpts.Cols = atoiDefault(opts.Cols, termOpts.Cols)
    }
    termOpts.MaxPtySize = base.BoundInt64(termOpts.MaxPtySize, shexec.MinMaxPtySize, shexec.MaxMaxPtySize)
    termOpts.Cols = base.BoundInt(termOpts.Cols, shexec.MinTermCols, shexec.MaxTermCols)
    termOpts.Rows = base.BoundInt(termOpts.Rows, shexec.MinTermRows, shexec.MaxTermRows)
    return termOpts, nil
}

func convertTermOpts(pkto *packet.TermOpts) *sstore.TermOpts {
    return &sstore.TermOpts{
        Rows:       int64(pkto.Rows),
        Cols:       int64(pkto.Cols),
        FlexRows:   true,
        MaxPtySize: pkto.MaxPtySize,
    }
}
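Editor's note (not part of this commit): termopts.go parses PTERM values of the form "RowsxCols", where "M" means "use the current window dimension", a trailing "!" disables flex, and literal numbers are taken directly and then clamped to the shexec min/max bounds. A minimal, hypothetical sketch of calling GetUITermOpts above; the window size and PTERM value are made-up inputs, and the clamping constants come from the apishell shexec package, which is not part of this diff.

package cmdrunner

import (
    "fmt"

    "github.com/commandlinedev/apishell/pkg/packet"
)

// termOptsSketch resolves the PTERM string "Mx25" against a 40x120 window:
// rows follow the window ("M" -> 40) and cols come from the literal 25,
// before the shexec bounds are applied by GetUITermOpts.
func termOptsSketch() {
    winSize := &packet.WinSize{Rows: 40, Cols: 120}
    termOpts, err := GetUITermOpts(winSize, "Mx25")
    if err != nil {
        fmt.Println("parse error:", err)
        return
    }
    fmt.Printf("rows=%d cols=%d maxpty=%d\n", termOpts.Rows, termOpts.Cols, termOpts.MaxPtySize)
}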
640
wavesrv/pkg/comp/comp.go
Normal file
@@ -0,0 +1,640 @@
// scripthaus completion
package comp

import (
    "bytes"
    "context"
    "fmt"
    "sort"
    "strconv"
    "strings"
    "unicode"
    "unicode/utf8"

    "github.com/commandlinedev/apishell/pkg/simpleexpand"
    "github.com/commandlinedev/prompt-server/pkg/shparse"
    "github.com/commandlinedev/prompt-server/pkg/sstore"
    "github.com/commandlinedev/prompt-server/pkg/utilfn"
    "mvdan.cc/sh/v3/syntax"
)

const MaxCompQuoteLen = 5000

const (
    // local to simplecomp
    CGTypeCommand  = "command"
    CGTypeFile     = "file"
    CGTypeDir      = "directory"
    CGTypeVariable = "variable"

    // implemented in cmdrunner
    CGTypeMeta        = "metacmd"
    CGTypeCommandMeta = "command+meta"

    CGTypeRemote         = "remote"
    CGTypeRemoteInstance = "remoteinstance"
    CGTypeGlobalCmd      = "globalcmd"
)

const (
    QuoteTypeLiteral = ""
    QuoteTypeDQ      = "\""
    QuoteTypeANSI    = "$'"
    QuoteTypeSQ      = "'"
)

type CompContext struct {
    RemotePtr  *sstore.RemotePtrType
    Cwd        string
    ForDisplay bool
}

type ParsedWord struct {
    Offset      int
    Word        *syntax.Word
    PartialWord string
    Prefix      string
}

type CompPoint struct {
    StmtStr     string
    Words       []ParsedWord
    CompWord    int
    CompWordPos int
    Prefix      string
    Suffix      string
}

// directories will have a trailing "/"
type CompEntry struct {
    Word      string
    IsMetaCmd bool
}

type CompReturn struct {
    CompType string
    Entries  []CompEntry
    HasMore  bool
}

var noEscChars []bool
var specialEsc []string

func init() {
    noEscChars = make([]bool, 256)
    for ch := 0; ch < 256; ch++ {
        if (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
            ch == '-' || ch == '.' || ch == '/' || ch == ':' || ch == '=' {
            noEscChars[byte(ch)] = true
        }
    }
    specialEsc = make([]string, 256)
    specialEsc[0x7] = "\\a"
    specialEsc[0x8] = "\\b"
    specialEsc[0x9] = "\\t"
    specialEsc[0xa] = "\\n"
    specialEsc[0xb] = "\\v"
    specialEsc[0xc] = "\\f"
    specialEsc[0xd] = "\\r"
    specialEsc[0x1b] = "\\E"
}

func compQuoteDQString(s string, close bool) string {
    var buf bytes.Buffer
    buf.WriteByte('"')
    for _, ch := range s {
        if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
            buf.WriteByte('\\')
            buf.WriteRune(ch)
            continue
        }
        buf.WriteRune(ch)
    }
    if close {
        buf.WriteByte('"')
    }
    return buf.String()
}

func hasGlob(s string) bool {
    var lastExtGlob bool
    for _, ch := range s {
        if ch == '*' || ch == '?' || ch == '[' || ch == '{' {
            return true
        }
        if ch == '+' || ch == '@' || ch == '!' {
            lastExtGlob = true
            continue
        }
        if lastExtGlob && ch == '(' {
            return true
        }
        lastExtGlob = false
    }
    return false
}

func writeUtf8Literal(buf *bytes.Buffer, ch rune) {
    var runeArr [utf8.UTFMax]byte
    buf.WriteString("$'")
    barr := runeArr[:]
    byteLen := utf8.EncodeRune(barr, ch)
    for i := 0; i < byteLen; i++ {
        buf.WriteString("\\x")
        buf.WriteByte(utilfn.HexDigits[barr[i]/16])
        buf.WriteByte(utilfn.HexDigits[barr[i]%16])
    }
    buf.WriteByte('\'')
}

func compQuoteLiteralString(s string) string {
    var buf bytes.Buffer
    for idx, ch := range s {
        if ch == 0 {
            break
        }
        if idx == 0 && ch == '~' {
            buf.WriteRune(ch)
            continue
        }
        if ch > unicode.MaxASCII {
            writeUtf8Literal(&buf, ch)
            continue
        }
        var bch = byte(ch)
        if noEscChars[bch] {
            buf.WriteRune(ch)
            continue
        }
        if specialEsc[bch] != "" {
            buf.WriteString(specialEsc[bch])
            continue
        }
        if !unicode.IsPrint(ch) {
            writeUtf8Literal(&buf, ch)
            continue
        }
        buf.WriteByte('\\')
        buf.WriteByte(bch)
    }
    return buf.String()
}

func compQuoteSQString(s string) string {
    var buf bytes.Buffer
    for _, ch := range s {
        if ch == 0 {
            break
        }
        if ch == '\'' {
            buf.WriteString("'\\''")
            continue
        }
        var bch byte
        if ch <= unicode.MaxASCII {
            bch = byte(ch)
        }
        if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
            buf.WriteByte('\'')
            if bch != 0 && specialEsc[bch] != "" {
                buf.WriteString(specialEsc[bch])
            } else {
                writeUtf8Literal(&buf, ch)
            }
            buf.WriteByte('\'')
            continue
        }
        buf.WriteByte(bch)
    }
    return buf.String()
}

func compQuoteString(s string, quoteType string, close bool) string {
    if quoteType != QuoteTypeANSI && quoteType != QuoteTypeLiteral {
        for _, ch := range s {
            if ch > unicode.MaxASCII || !unicode.IsPrint(ch) || ch == '!' {
                quoteType = QuoteTypeANSI
                break
            }
            if ch == '\'' {
                if quoteType == QuoteTypeSQ {
                    quoteType = QuoteTypeANSI
                    break
                }
            }
        }
    }
    if quoteType == QuoteTypeANSI {
        rtn := strconv.QuoteToASCII(s)
        rtn = "$'" + strings.ReplaceAll(rtn[1:len(rtn)-1], "'", "\\'")
        if close {
            rtn = rtn + "'"
        }
        return rtn
    }
    if quoteType == QuoteTypeLiteral {
        return compQuoteLiteralString(s)
    }
    if quoteType == QuoteTypeSQ {
        rtn := utilfn.ShellQuote(s, false, MaxCompQuoteLen)
        if len(rtn) > 0 && rtn[0] != '\'' {
            rtn = "'" + rtn + "'"
        }
        if !close {
            rtn = rtn[0 : len(rtn)-1]
        }
        return rtn
    }
    // QuoteTypeDQ
    return compQuoteDQString(s, close)
}

func (p *CompPoint) wordAsStr(w ParsedWord) string {
    if w.Word != nil {
        return p.StmtStr[w.Word.Pos().Offset():w.Word.End().Offset()]
    }
    return w.PartialWord
}

func (p *CompPoint) simpleExpandWord(w ParsedWord) (string, simpleexpand.SimpleExpandInfo) {
    ectx := simpleexpand.SimpleExpandContext{}
    if w.Word != nil {
        return simpleexpand.SimpleExpandWord(ectx, w.Word, p.StmtStr)
    }
    return simpleexpand.SimpleExpandPartialWord(ectx, w.PartialWord, false)
}

func getQuoteTypePref(str string) string {
    if strings.HasPrefix(str, QuoteTypeANSI) {
        return QuoteTypeANSI
    }
    if strings.HasPrefix(str, QuoteTypeDQ) {
        return QuoteTypeDQ
    }
    if strings.HasPrefix(str, QuoteTypeSQ) {
        return QuoteTypeSQ
    }
    return QuoteTypeLiteral
}

func (p *CompPoint) getCompPrefix() (string, simpleexpand.SimpleExpandInfo) {
    if p.CompWordPos == 0 {
        return "", simpleexpand.SimpleExpandInfo{}
    }
    pword := p.Words[p.CompWord]
    wordStr := p.wordAsStr(pword)
    if p.CompWordPos == len(wordStr) {
        return p.simpleExpandWord(pword)
    }
    // TODO we can do better, if p.Word is not nil, we can look for which WordPart
    // our pos is in. we can then do a normal word expand on the previous parts
    // and a partial on just the current part. this is an uncommon case though
    // and has very little upside (even bash does not expand multipart words correctly)
    partialWordStr := wordStr[:p.CompWordPos]
    return simpleexpand.SimpleExpandPartialWord(simpleexpand.SimpleExpandContext{}, partialWordStr, false)
}

func (p *CompPoint) extendWord(newWord string, newWordComplete bool) utilfn.StrWithPos {
    pword := p.Words[p.CompWord]
    wordStr := p.wordAsStr(pword)
    quotePref := getQuoteTypePref(wordStr)
    needsClose := newWordComplete && (len(wordStr) == p.CompWordPos)
    wordSuffix := wordStr[p.CompWordPos:]
    newQuotedStr := compQuoteString(newWord, quotePref, needsClose)
    if needsClose && wordSuffix == "" && !strings.HasSuffix(newWord, "/") {
        newQuotedStr = newQuotedStr + " "
    }
    newPos := len(newQuotedStr)
    return utilfn.StrWithPos{Str: newQuotedStr + wordSuffix, Pos: newPos}
}

// returns (extension, complete)
func computeCompExtension(compPrefix string, crtn *CompReturn) (string, bool) {
    if crtn == nil || crtn.HasMore {
        return "", false
    }
    compStrs := crtn.GetCompStrs()
    lcp := utilfn.LongestPrefix(compPrefix, compStrs)
    if lcp == compPrefix || len(lcp) < len(compPrefix) || !strings.HasPrefix(lcp, compPrefix) {
        return "", false
    }
    return lcp[len(compPrefix):], (utilfn.ContainsStr(compStrs, lcp) && !utilfn.IsPrefix(compStrs, lcp))
}

func (p *CompPoint) FullyExtend(crtn *CompReturn) utilfn.StrWithPos {
    if crtn == nil || crtn.HasMore {
        return utilfn.StrWithPos{Str: p.getOrigStr(), Pos: p.getOrigPos()}
    }
    compStrs := crtn.GetCompStrs()
    compPrefix, _ := p.getCompPrefix()
    lcp := utilfn.LongestPrefix(compPrefix, compStrs)
    if lcp == compPrefix || len(lcp) < len(compPrefix) || !strings.HasPrefix(lcp, compPrefix) {
        return utilfn.StrWithPos{Str: p.getOrigStr(), Pos: p.getOrigPos()}
    }
    newStr := p.extendWord(lcp, utilfn.ContainsStr(compStrs, lcp))
    var buf bytes.Buffer
    buf.WriteString(p.Prefix)
    for idx, w := range p.Words {
        if idx == p.CompWord {
            buf.WriteString(w.Prefix)
            buf.WriteString(newStr.Str)
        } else {
            buf.WriteString(w.Prefix)
            buf.WriteString(p.wordAsStr(w))
        }
    }
    buf.WriteString(p.Suffix)
    compWord := p.Words[p.CompWord]
    newPos := len(p.Prefix) + compWord.Offset + len(compWord.Prefix) + newStr.Pos
    return utilfn.StrWithPos{Str: buf.String(), Pos: newPos}
}

func (p *CompPoint) dump() {
    if p.Prefix != "" {
        fmt.Printf("prefix: %s\n", p.Prefix)
    }
    fmt.Printf("cpos: %d %d\n", p.CompWord, p.CompWordPos)
    for idx, w := range p.Words {
        fmt.Printf("w[%d]: ", idx)
        if w.Prefix != "" {
            fmt.Printf("{%s}", w.Prefix)
        }
        if idx == p.CompWord {
            fmt.Printf("%s\n", utilfn.StrWithPos{Str: p.wordAsStr(w), Pos: p.CompWordPos})
        } else {
            fmt.Printf("%s\n", p.wordAsStr(w))
        }
    }
    if p.Suffix != "" {
        fmt.Printf("suffix: %s\n", p.Suffix)
    }
    fmt.Printf("\n")
}

var SimpleCompGenFns map[string]SimpleCompGenFnType

func isWhitespace(str string) bool {
    return strings.TrimSpace(str) == ""
}

func splitInitialWhitespace(str string) (string, string) {
    for pos, ch := range str { // rune iteration :/
        if !unicode.IsSpace(ch) {
            return str[:pos], str[pos:]
        }
    }
    return str, ""
}

func ParseCompPoint(cmdStr utilfn.StrWithPos) *CompPoint {
    fullCmdStr := cmdStr.Str
    pos := cmdStr.Pos
    // fmt.Printf("---\n")
    // fmt.Printf("cmd: %s\n", strWithCursor(fullCmdStr, pos))

    // first, find the stmt that the pos appears in
    cmdReader := strings.NewReader(fullCmdStr)
    parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
    var foundStmt *syntax.Stmt
    var lastStmt *syntax.Stmt
    var restStartPos int
    parser.Stmts(cmdReader, func(stmt *syntax.Stmt) bool { // ignore parse errors (since stmtStr will be the unparsed part)
        restStartPos = int(stmt.End().Offset())
        lastStmt = stmt
        if uint(pos) >= stmt.Pos().Offset() && uint(pos) < stmt.End().Offset() {
            foundStmt = stmt
            return false
        }
        // fmt.Printf("stmt: [[%s]] %d:%d (%d)\n", fullCmdStr[stmt.Pos().Offset():stmt.End().Offset()], stmt.Pos().Offset(), stmt.End().Offset(), stmt.Semicolon.Offset())
|
||||
return true
|
||||
})
|
||||
restStr := fullCmdStr[restStartPos:]
|
||||
if foundStmt == nil && lastStmt != nil && isWhitespace(restStr) && lastStmt.Semicolon.Offset() == 0 {
|
||||
foundStmt = lastStmt
|
||||
}
|
||||
var rtnPoint CompPoint
|
||||
var stmtStr string
|
||||
var stmtPos int
|
||||
if foundStmt != nil {
|
||||
stmtPos = pos - int(foundStmt.Pos().Offset())
|
||||
rtnPoint.Prefix = fullCmdStr[:foundStmt.Pos().Offset()]
|
||||
if isWhitespace(fullCmdStr[foundStmt.End().Offset():]) {
|
||||
stmtStr = fullCmdStr[foundStmt.Pos().Offset():]
|
||||
rtnPoint.Suffix = ""
|
||||
} else {
|
||||
stmtStr = fullCmdStr[foundStmt.Pos().Offset():foundStmt.End().Offset()]
|
||||
rtnPoint.Suffix = fullCmdStr[foundStmt.End().Offset():]
|
||||
}
|
||||
} else {
|
||||
stmtStr = restStr
|
||||
stmtPos = pos - restStartPos
|
||||
rtnPoint.Prefix = fullCmdStr[:restStartPos]
|
||||
rtnPoint.Suffix = fullCmdStr[restStartPos+len(stmtStr):]
|
||||
}
|
||||
if stmtPos > len(stmtStr) {
|
||||
// this should not happen and will cause a jump in completed strings
|
||||
stmtPos = len(stmtStr)
|
||||
}
|
||||
// fmt.Printf("found: ((%s))%s((%s))\n", rtnPoint.Prefix, strWithCursor(stmtStr, stmtPos), rtnPoint.Suffix)
|
||||
|
||||
// now, find the word that the pos appears in within the stmt above
|
||||
rtnPoint.StmtStr = stmtStr
|
||||
stmtReader := strings.NewReader(stmtStr)
|
||||
lastWordPos := 0
|
||||
parser.Words(stmtReader, func(w *syntax.Word) bool {
|
||||
var pword ParsedWord
|
||||
pword.Offset = lastWordPos
|
||||
if int(w.Pos().Offset()) > lastWordPos {
|
||||
pword.Prefix = stmtStr[lastWordPos:w.Pos().Offset()]
|
||||
}
|
||||
pword.Word = w
|
||||
rtnPoint.Words = append(rtnPoint.Words, pword)
|
||||
lastWordPos = int(w.End().Offset())
|
||||
return true
|
||||
})
|
||||
if lastWordPos < len(stmtStr) {
|
||||
pword := ParsedWord{Offset: lastWordPos}
|
||||
pword.Prefix, pword.PartialWord = splitInitialWhitespace(stmtStr[lastWordPos:])
|
||||
rtnPoint.Words = append(rtnPoint.Words, pword)
|
||||
}
|
||||
if len(rtnPoint.Words) == 0 {
|
||||
rtnPoint.Words = append(rtnPoint.Words, ParsedWord{})
|
||||
}
|
||||
for idx, w := range rtnPoint.Words {
|
||||
wordLen := len(rtnPoint.wordAsStr(w))
|
||||
if stmtPos > w.Offset && stmtPos <= w.Offset+len(w.Prefix)+wordLen {
|
||||
rtnPoint.CompWord = idx
|
||||
rtnPoint.CompWordPos = stmtPos - w.Offset - len(w.Prefix)
|
||||
if rtnPoint.CompWordPos < 0 {
|
||||
splitCompWord(&rtnPoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &rtnPoint
|
||||
}
|
||||
|
||||
func splitCompWord(p *CompPoint) {
|
||||
w := p.Words[p.CompWord]
|
||||
prefixPos := p.CompWordPos + len(w.Prefix)
|
||||
|
||||
w1 := ParsedWord{Offset: w.Offset, Prefix: w.Prefix[:prefixPos]}
|
||||
w2 := ParsedWord{Offset: w.Offset + prefixPos, Prefix: w.Prefix[prefixPos:], Word: w.Word, PartialWord: w.PartialWord}
|
||||
p.CompWord = p.CompWord // the same (w1)
|
||||
p.CompWordPos = 0 // will be at 0 since w1 has a word length of 0
|
||||
var newWords []ParsedWord
|
||||
if p.CompWord > 0 {
|
||||
newWords = append(newWords, p.Words[0:p.CompWord]...)
|
||||
}
|
||||
newWords = append(newWords, w1, w2)
|
||||
newWords = append(newWords, p.Words[p.CompWord+1:]...)
|
||||
p.Words = newWords
|
||||
}
|
||||
|
||||
func getCompType(compPos shparse.CompletionPos) string {
|
||||
switch compPos.CompType {
|
||||
case shparse.CompTypeCommandMeta:
|
||||
return CGTypeCommandMeta
|
||||
|
||||
case shparse.CompTypeCommand:
|
||||
return CGTypeCommand
|
||||
|
||||
case shparse.CompTypeVar:
|
||||
return CGTypeVariable
|
||||
|
||||
case shparse.CompTypeArg, shparse.CompTypeBasic, shparse.CompTypeAssignment:
|
||||
return CGTypeFile
|
||||
|
||||
default:
|
||||
return CGTypeFile
|
||||
}
|
||||
}
|
||||
|
||||
func fixupVarPrefix(varPrefix string) string {
|
||||
if strings.HasPrefix(varPrefix, "${") {
|
||||
varPrefix = varPrefix[2:]
|
||||
if strings.HasSuffix(varPrefix, "}") {
|
||||
varPrefix = varPrefix[:len(varPrefix)-1]
|
||||
}
|
||||
} else if strings.HasPrefix(varPrefix, "$") {
|
||||
varPrefix = varPrefix[1:]
|
||||
}
|
||||
return varPrefix
|
||||
}
|
||||
|
||||
func DoCompGen(ctx context.Context, cmdStr utilfn.StrWithPos, compCtx CompContext) (*CompReturn, *utilfn.StrWithPos, error) {
|
||||
words := shparse.Tokenize(cmdStr.Str)
|
||||
cmds := shparse.ParseCommands(words)
|
||||
compPos := shparse.FindCompletionPos(cmds, cmdStr.Pos)
|
||||
if compPos.CompType == shparse.CompTypeInvalid {
|
||||
return nil, nil, nil
|
||||
}
|
||||
var compPrefix string
|
||||
if compPos.CompWord != nil {
|
||||
var info shparse.ExpandInfo
|
||||
compPrefix, info = shparse.SimpleExpandPrefix(shparse.ExpandContext{}, compPos.CompWord, compPos.CompWordOffset)
|
||||
if info.HasGlob || info.HasExtGlob || info.HasHistory || info.HasSpecial {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if compPos.CompType != shparse.CompTypeVar && info.HasVar {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if compPos.CompType == shparse.CompTypeVar {
|
||||
compPrefix = fixupVarPrefix(compPrefix)
|
||||
}
|
||||
}
|
||||
scType := getCompType(compPos)
|
||||
crtn, err := DoSimpleComp(ctx, scType, compPrefix, compCtx, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if compCtx.ForDisplay {
|
||||
return crtn, nil, nil
|
||||
}
|
||||
extensionStr, extensionComplete := computeCompExtension(compPrefix, crtn)
|
||||
if extensionStr == "" {
|
||||
return crtn, nil, nil
|
||||
}
|
||||
rtnSP := compPos.Extend(cmdStr, extensionStr, extensionComplete)
|
||||
return crtn, &rtnSP, nil
|
||||
}
|
||||
|
||||
func DoCompGenOld(ctx context.Context, sp utilfn.StrWithPos, compCtx CompContext) (*CompReturn, *utilfn.StrWithPos, error) {
|
||||
compPoint := ParseCompPoint(sp)
|
||||
compType := CGTypeFile
|
||||
if compPoint.CompWord == 0 {
|
||||
compType = CGTypeCommandMeta
|
||||
}
|
||||
// TODO lookup special types
|
||||
compPrefix, info := compPoint.getCompPrefix()
|
||||
if info.HasVar || info.HasGlob || info.HasExtGlob || info.HasHistory || info.HasSpecial {
|
||||
return nil, nil, nil
|
||||
}
|
||||
crtn, err := DoSimpleComp(ctx, compType, compPrefix, compCtx, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if compCtx.ForDisplay {
|
||||
return crtn, nil, nil
|
||||
}
|
||||
rtnSP := compPoint.FullyExtend(crtn)
|
||||
return crtn, &rtnSP, nil
|
||||
}
|
||||
|
||||
func SortCompReturnEntries(c *CompReturn) {
|
||||
sort.Slice(c.Entries, func(i int, j int) bool {
|
||||
e1 := c.Entries[i]
|
||||
e2 := c.Entries[j]
|
||||
if e1.Word < e2.Word {
|
||||
return true
|
||||
}
|
||||
if e1.Word == e2.Word && e1.IsMetaCmd && !e2.IsMetaCmd {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func CombineCompReturn(compType string, c1 *CompReturn, c2 *CompReturn) *CompReturn {
|
||||
if c1 == nil {
|
||||
return c2
|
||||
}
|
||||
if c2 == nil {
|
||||
return c1
|
||||
}
|
||||
var rtn CompReturn
|
||||
rtn.CompType = compType
|
||||
rtn.HasMore = c1.HasMore || c2.HasMore
|
||||
rtn.Entries = append([]CompEntry{}, c1.Entries...)
|
||||
rtn.Entries = append(rtn.Entries, c2.Entries...)
|
||||
SortCompReturnEntries(&rtn)
|
||||
return &rtn
|
||||
}
|
||||
|
||||
func (c *CompReturn) GetCompStrs() []string {
|
||||
rtn := make([]string, len(c.Entries))
|
||||
for idx, entry := range c.Entries {
|
||||
rtn[idx] = entry.Word
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (c *CompReturn) GetCompDisplayStrs() []string {
|
||||
rtn := make([]string, len(c.Entries))
|
||||
for idx, entry := range c.Entries {
|
||||
if entry.IsMetaCmd {
|
||||
rtn[idx] = "^" + entry.Word
|
||||
} else {
|
||||
rtn[idx] = entry.Word
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (p CompPoint) getOrigPos() int {
|
||||
pword := p.Words[p.CompWord]
|
||||
return len(p.Prefix) + pword.Offset + len(pword.Prefix) + p.CompWordPos
|
||||
}
|
||||
|
||||
func (p CompPoint) getOrigStr() string {
|
||||
return p.Prefix + p.StmtStr + p.Suffix
|
||||
}
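
// Example sketch: the expected split between "extension" and "complete" from
// computeCompExtension, inferred from Test4 in comp_test.go; the concrete
// values are assumptions, not a spec.
func exampleComputeCompExtension() {
	// single match: the whole word completes, so the extended word is treated
	// as finished (quote closed / trailing space appended by the caller)
	ext, complete := computeCompExtension("f", compsToCompReturn([]string{"foo"}, false))
	fmt.Printf("ext=%q complete=%v\n", ext, complete) // expected: ext="oo" complete=true

	// several matches sharing a prefix: only the common part is extended
	ext, complete = computeCompExtension("f", compsToCompReturn([]string{"foox", "fooy"}, false))
	fmt.Printf("ext=%q complete=%v\n", ext, complete) // expected: ext="oo" complete=false
}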
|
106
wavesrv/pkg/comp/comp_test.go
Normal file
@ -0,0 +1,106 @@
|
||||
package comp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func parseToSP(s string) StrWithPos {
|
||||
idx := strings.Index(s, "[*]")
|
||||
if idx == -1 {
|
||||
return StrWithPos{Str: s}
|
||||
}
|
||||
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: idx}
|
||||
}
|
||||
|
||||
func testParse(cmdStr string, pos int) {
|
||||
fmt.Printf("cmd: %s\n", strWithCursor(cmdStr, pos))
|
||||
p := ParseCompPoint(StrWithPos{Str: cmdStr, Pos: pos})
|
||||
p.dump()
|
||||
}
|
||||
|
||||
func _Test1(t *testing.T) {
|
||||
testParse("ls ", 3)
|
||||
testParse("ls ", 4)
|
||||
testParse("ls -l foo", 4)
|
||||
testParse("ls foo; cd h", 12)
|
||||
testParse("ls foo; cd h;", 13)
|
||||
testParse("ls & foo; cd h", 12)
|
||||
testParse("ls \"he", 6)
|
||||
testParse("ls;", 3)
|
||||
testParse("ls;", 2)
|
||||
testParse("ls; cd x; ls", 8)
|
||||
testParse("cd \"foo ", 8)
|
||||
testParse("ls; { ls f", 10)
|
||||
testParse("ls; { ls -l; ls f", 17)
|
||||
testParse("ls $(ls f", 9)
|
||||
}
|
||||
|
||||
func testMiniExtend(t *testing.T, p *CompPoint, newWord string, complete bool, expectedStr string) {
|
||||
newSP := p.extendWord(newWord, complete)
|
||||
expectedSP := parseToSP(expectedStr)
|
||||
if newSP != expectedSP {
|
||||
t.Fatalf("not equal: [%s] != [%s]", newSP, expectedSP)
|
||||
} else {
|
||||
fmt.Printf("extend: %s\n", newSP)
|
||||
}
|
||||
}
|
||||
|
||||
func Test2(t *testing.T) {
|
||||
p := ParseCompPoint(parseToSP("ls f[*]"))
|
||||
testMiniExtend(t, p, "foo", false, "foo[*]")
|
||||
testMiniExtend(t, p, "foo", true, "foo [*]")
|
||||
testMiniExtend(t, p, "foo bar", true, "'foo bar' [*]")
|
||||
testMiniExtend(t, p, "foo'bar", true, `$'foo\'bar' [*]`)
|
||||
|
||||
p = ParseCompPoint(parseToSP("ls f[*]more"))
|
||||
testMiniExtend(t, p, "foo", false, "foo[*]more")
|
||||
testMiniExtend(t, p, "foo bar", false, `'foo bar[*]more`)
|
||||
testMiniExtend(t, p, "foo bar", true, `'foo bar[*]more`)
|
||||
testMiniExtend(t, p, "foo's", true, `$'foo\'s[*]more`)
|
||||
}
|
||||
|
||||
func testParseRT(t *testing.T, origSP StrWithPos) {
|
||||
p := ParseCompPoint(origSP)
|
||||
newSP := StrWithPos{Str: p.getOrigStr(), Pos: p.getOrigPos()}
|
||||
if origSP != newSP {
|
||||
t.Fatalf("not equal: [%s] != [%s]", origSP, newSP)
|
||||
}
|
||||
}
|
||||
|
||||
func Test3(t *testing.T) {
|
||||
testParseRT(t, parseToSP("ls f[*]"))
|
||||
testParseRT(t, parseToSP("ls f[*]; more $FOO"))
|
||||
testParseRT(t, parseToSP("hello; ls [*]f"))
|
||||
testParseRT(t, parseToSP("ls -l; ./foo he[*]ll more; touch foo &"))
|
||||
}
|
||||
|
||||
func testExtend(t *testing.T, origStr string, compStrs []string, expectedStr string) {
|
||||
origSP := parseToSP(origStr)
|
||||
expectedSP := parseToSP(expectedStr)
|
||||
p := ParseCompPoint(origSP)
|
||||
crtn := compsToCompReturn(compStrs, false)
|
||||
newSP := p.FullyExtend(crtn)
|
||||
if newSP != expectedSP {
|
||||
t.Fatalf("comp-fail: %s + %v => [%s] expected[%s]", origSP, compStrs, newSP, expectedSP)
|
||||
} else {
|
||||
fmt.Printf("comp: %s + %v => [%s]\n", origSP, compStrs, newSP)
|
||||
}
|
||||
}
|
||||
|
||||
func Test4(t *testing.T) {
|
||||
testExtend(t, "ls f[*]", []string{"foo"}, "ls foo [*]")
|
||||
testExtend(t, "ls f[*]", []string{"foox", "fooy"}, "ls foo[*]")
|
||||
testExtend(t, "w; ls f[*]; touch x", []string{"foo"}, "w; ls foo [*]; touch x")
|
||||
testExtend(t, "w; ls f[*] more; touch x", []string{"foo"}, "w; ls foo [*] more; touch x")
|
||||
testExtend(t, "w; ls f[*]oo; touch x", []string{"foo"}, "w; ls foo[*]oo; touch x")
|
||||
testExtend(t, `ls "f[*]`, []string{"foo"}, `ls "foo" [*]`)
|
||||
testExtend(t, `ls 'f[*]`, []string{"foo"}, `ls 'foo' [*]`)
|
||||
testExtend(t, `ls $'f[*]`, []string{"foo"}, `ls $'foo' [*]`)
|
||||
testExtend(t, `ls f[*]`, []string{"foo/"}, `ls foo/[*]`)
|
||||
testExtend(t, `ls f[*]`, []string{"foo bar"}, `ls 'foo bar' [*]`)
|
||||
testExtend(t, `ls f[*]`, []string{"f\x01\x02"}, `ls $'f\x01\x02' [*]`)
|
||||
testExtend(t, `ls "foo [*]`, []string{"foo bar"}, `ls "foo bar" [*]`)
|
||||
testExtend(t, `ls f[*]`, []string{"foo's"}, `ls $'foo\'s' [*]`)
|
||||
}
|
100
wavesrv/pkg/comp/simplecomp.go
Normal file
@ -0,0 +1,100 @@
|
||||
package comp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/prompt-server/pkg/remote"
|
||||
"github.com/commandlinedev/prompt-server/pkg/utilfn"
|
||||
)
|
||||
|
||||
var globalLock = &sync.Mutex{}
|
||||
var simpleCompMap = map[string]SimpleCompGenFnType{
|
||||
CGTypeCommand: simpleCompCommand,
|
||||
CGTypeFile: simpleCompFile,
|
||||
CGTypeDir: simpleCompDir,
|
||||
CGTypeVariable: simpleCompVar,
|
||||
}
|
||||
|
||||
type SimpleCompGenFnType = func(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error)
|
||||
|
||||
func RegisterSimpleCompFn(compType string, fn SimpleCompGenFnType) {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
if _, ok := simpleCompMap[compType]; ok {
|
||||
panic(fmt.Sprintf("simpleCompFn %q already registered", compType))
|
||||
}
|
||||
simpleCompMap[compType] = fn
|
||||
}
|
||||
|
||||
func getSimpleCompFn(compType string) SimpleCompGenFnType {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
return simpleCompMap[compType]
|
||||
}
|
||||
|
||||
func DoSimpleComp(ctx context.Context, compType string, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
|
||||
compFn := getSimpleCompFn(compType)
|
||||
if compFn == nil {
|
||||
return nil, fmt.Errorf("no simple comp fn for %q", compType)
|
||||
}
|
||||
crtn, err := compFn(ctx, prefix, compCtx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
crtn.CompType = compType
|
||||
return crtn, nil
|
||||
}
|
||||
|
||||
func compsToCompReturn(comps []string, hasMore bool) *CompReturn {
|
||||
var rtn CompReturn
|
||||
rtn.HasMore = hasMore
|
||||
for _, comp := range comps {
|
||||
rtn.Entries = append(rtn.Entries, CompEntry{Word: comp})
|
||||
}
|
||||
return &rtn
|
||||
}
|
||||
|
||||
func doCompGen(ctx context.Context, prefix string, compType string, compCtx CompContext) (*CompReturn, error) {
|
||||
if !packet.IsValidCompGenType(compType) {
|
||||
return nil, fmt.Errorf("/_compgen invalid type '%s'", compType)
|
||||
}
|
||||
msh := remote.GetRemoteById(compCtx.RemotePtr.RemoteId)
|
||||
if msh == nil {
|
||||
return nil, fmt.Errorf("invalid remote '%s', not found", compCtx.RemotePtr)
|
||||
}
|
||||
cgPacket := packet.MakeCompGenPacket()
|
||||
cgPacket.ReqId = uuid.New().String()
|
||||
cgPacket.CompType = compType
|
||||
cgPacket.Prefix = prefix
|
||||
cgPacket.Cwd = compCtx.Cwd
|
||||
resp, err := msh.PacketRpc(ctx, cgPacket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = resp.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
comps := utilfn.GetStrArr(resp.Data, "comps")
|
||||
hasMore := utilfn.GetBool(resp.Data, "hasmore")
|
||||
return compsToCompReturn(comps, hasMore), nil
|
||||
}
|
||||
|
||||
func simpleCompFile(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
|
||||
return doCompGen(ctx, prefix, CGTypeFile, compCtx)
|
||||
}
|
||||
|
||||
func simpleCompDir(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
|
||||
return doCompGen(ctx, prefix, CGTypeDir, compCtx)
|
||||
}
|
||||
|
||||
func simpleCompVar(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
|
||||
return doCompGen(ctx, prefix, CGTypeVariable, compCtx)
|
||||
}
|
||||
|
||||
func simpleCompCommand(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
|
||||
return doCompGen(ctx, prefix, CGTypeCommand, compCtx)
|
||||
}
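
// Example sketch: how an additional completion source could be registered so
// DoSimpleComp can dispatch to it.  The comp-type name "comptype-color" and
// the word list are hypothetical; only RegisterSimpleCompFn, CompReturn, and
// CompEntry from this package are real.
func registerColorCompExample() {
	RegisterSimpleCompFn("comptype-color", func(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
		var rtn CompReturn
		for _, word := range []string{"red", "green", "blue"} {
			// manual prefix check (avoids pulling in strings just for the example)
			if len(word) >= len(prefix) && word[:len(prefix)] == prefix {
				rtn.Entries = append(rtn.Entries, CompEntry{Word: word})
			}
		}
		return &rtn, nil
	})
}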
|
212
wavesrv/pkg/dbutil/dbutil.go
Normal file
@ -0,0 +1,212 @@
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func QuickSetStr(strVal *string, m map[string]interface{}, name string) {
|
||||
v, ok := m[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ival, ok := v.(int64)
|
||||
if ok {
|
||||
*strVal = strconv.FormatInt(ival, 10)
|
||||
return
|
||||
}
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
*strVal = str
|
||||
}
|
||||
|
||||
func QuickSetInt(ival *int, m map[string]interface{}, name string) {
|
||||
v, ok := m[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sqlInt, ok := v.(int)
|
||||
if ok {
|
||||
*ival = sqlInt
|
||||
return
|
||||
}
|
||||
sqlInt64, ok := v.(int64)
|
||||
if ok {
|
||||
*ival = int(sqlInt64)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func QuickSetInt64(ival *int64, m map[string]interface{}, name string) {
|
||||
v, ok := m[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sqlInt64, ok := v.(int64)
|
||||
if ok {
|
||||
*ival = sqlInt64
|
||||
return
|
||||
}
|
||||
sqlInt, ok := v.(int)
|
||||
if ok {
|
||||
*ival = int64(sqlInt)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func QuickSetBool(bval *bool, m map[string]interface{}, name string) {
|
||||
v, ok := m[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sqlInt, ok := v.(int64)
|
||||
if ok {
|
||||
if sqlInt > 0 {
|
||||
*bval = true
|
||||
}
|
||||
return
|
||||
}
|
||||
sqlBool, ok := v.(bool)
|
||||
if ok {
|
||||
*bval = sqlBool
|
||||
}
|
||||
}
|
||||
|
||||
func QuickSetBytes(bval *[]byte, m map[string]interface{}, name string) {
|
||||
v, ok := m[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sqlBytes, ok := v.([]byte)
|
||||
if ok {
|
||||
*bval = sqlBytes
|
||||
}
|
||||
}
|
||||
|
||||
func getByteArr(m map[string]any, name string, def string) ([]byte, bool) {
|
||||
v, ok := m[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
barr, ok := v.([]byte)
|
||||
if !ok {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
barr = []byte(str)
|
||||
}
|
||||
if len(barr) == 0 {
|
||||
barr = []byte(def)
|
||||
}
|
||||
return barr, true
|
||||
}
|
||||
|
||||
func QuickSetJson(ptr interface{}, m map[string]interface{}, name string) {
|
||||
barr, ok := getByteArr(m, name, "{}")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
json.Unmarshal(barr, ptr)
|
||||
}
|
||||
|
||||
func QuickSetNullableJson(ptr interface{}, m map[string]interface{}, name string) {
|
||||
barr, ok := getByteArr(m, name, "null")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
json.Unmarshal(barr, ptr)
|
||||
}
|
||||
|
||||
func QuickSetJsonArr(ptr interface{}, m map[string]interface{}, name string) {
|
||||
barr, ok := getByteArr(m, name, "[]")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
json.Unmarshal(barr, ptr)
|
||||
}
|
||||
|
||||
func CheckNil(v interface{}) bool {
|
||||
rv := reflect.ValueOf(v)
|
||||
if !rv.IsValid() {
|
||||
return true
|
||||
}
|
||||
switch rv.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
|
||||
return rv.IsNil()
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func QuickNullableJson(v interface{}) string {
|
||||
if CheckNil(v) {
|
||||
return "null"
|
||||
}
|
||||
barr, _ := json.Marshal(v)
|
||||
return string(barr)
|
||||
}
|
||||
|
||||
func QuickJson(v interface{}) string {
|
||||
if CheckNil(v) {
|
||||
return "{}"
|
||||
}
|
||||
barr, _ := json.Marshal(v)
|
||||
return string(barr)
|
||||
}
|
||||
|
||||
func QuickJsonBytes(v interface{}) []byte {
|
||||
if CheckNil(v) {
|
||||
return []byte("{}")
|
||||
}
|
||||
barr, _ := json.Marshal(v)
|
||||
return barr
|
||||
}
|
||||
|
||||
func QuickJsonArr(v interface{}) string {
|
||||
if CheckNil(v) {
|
||||
return "[]"
|
||||
}
|
||||
barr, _ := json.Marshal(v)
|
||||
return string(barr)
|
||||
}
|
||||
|
||||
func QuickJsonArrBytes(v interface{}) []byte {
|
||||
if CheckNil(v) {
|
||||
return []byte("[]")
|
||||
}
|
||||
barr, _ := json.Marshal(v)
|
||||
return barr
|
||||
}
|
||||
|
||||
func QuickScanJson(ptr interface{}, val interface{}) error {
|
||||
barrVal, ok := val.([]byte)
|
||||
if !ok {
|
||||
strVal, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot scan '%T' into '%T'", val, ptr)
|
||||
}
|
||||
barrVal = []byte(strVal)
|
||||
}
|
||||
if len(barrVal) == 0 {
|
||||
barrVal = []byte("{}")
|
||||
}
|
||||
return json.Unmarshal(barrVal, ptr)
|
||||
}
|
||||
|
||||
func QuickValueJson(v interface{}) (driver.Value, error) {
|
||||
if CheckNil(v) {
|
||||
return "{}", nil
|
||||
}
|
||||
barr, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return string(barr), nil
|
||||
}
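
// Example sketch: the intended round trip for these helpers.  QuickJson and
// friends serialize Go values for storage (never returning an empty string),
// and the QuickSet* helpers copy the loosely typed values that come back from
// a row scan into typed fields.  The column names below are hypothetical.
func exampleQuickSetUsage() {
	row := map[string]interface{}{
		"name":     "screen-1",
		"linenum":  int64(42), // integer columns typically scan as int64
		"archived": int64(1),  // booleans are stored as 0/1
		"opts":     `{"color":"red"}`,
	}
	var name string
	var lineNum int64
	var archived bool
	var opts map[string]interface{}
	QuickSetStr(&name, row, "name")
	QuickSetInt64(&lineNum, row, "linenum")
	QuickSetBool(&archived, row, "archived")
	QuickSetJson(&opts, row, "opts")
	fmt.Printf("name=%s linenum=%d archived=%v opts=%v\n", name, lineNum, archived, opts)

	// writing side: nil values serialize to "{}" / "[]" instead of NULL
	fmt.Println(QuickJson(nil), QuickJsonArr(nil))
}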
|
234
wavesrv/pkg/dbutil/map.go
Normal file
@ -0,0 +1,234 @@
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/sawka/txwrap"
|
||||
)
|
||||
|
||||
type DBMappable interface {
|
||||
UseDBMap()
|
||||
}
|
||||
|
||||
type MapEntry[T any] struct {
|
||||
Key string
|
||||
Val T
|
||||
}
|
||||
|
||||
type MapConverter interface {
|
||||
ToMap() map[string]interface{}
|
||||
FromMap(map[string]interface{}) bool
|
||||
}
|
||||
|
||||
type HasSimpleKey interface {
|
||||
GetSimpleKey() string
|
||||
}
|
||||
|
||||
type HasSimpleInt64Key interface {
|
||||
GetSimpleKey() int64
|
||||
}
|
||||
|
||||
type MapConverterPtr[T any] interface {
|
||||
MapConverter
|
||||
*T
|
||||
}
|
||||
|
||||
type DBMappablePtr[T any] interface {
|
||||
DBMappable
|
||||
*T
|
||||
}
|
||||
|
||||
func FromMap[PT MapConverterPtr[T], T any](m map[string]any) PT {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtn := PT(new(T))
|
||||
ok := rtn.FromMap(m)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func GetMapGen[PT MapConverterPtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) PT {
|
||||
m := tx.GetMap(query, args...)
|
||||
return FromMap[PT](m)
|
||||
}
|
||||
|
||||
func GetMappable[PT DBMappablePtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) PT {
|
||||
m := tx.GetMap(query, args...)
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtn := PT(new(T))
|
||||
FromDBMap(rtn, m)
|
||||
return rtn
|
||||
}
|
||||
|
||||
func SelectMappable[PT DBMappablePtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) []PT {
|
||||
var rtn []PT
|
||||
marr := tx.SelectMaps(query, args...)
|
||||
for _, m := range marr {
|
||||
if len(m) == 0 {
|
||||
continue
|
||||
}
|
||||
val := PT(new(T))
|
||||
FromDBMap(val, m)
|
||||
rtn = append(rtn, val)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func SelectMapsGen[PT MapConverterPtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) []PT {
|
||||
var rtn []PT
|
||||
marr := tx.SelectMaps(query, args...)
|
||||
for _, m := range marr {
|
||||
val := FromMap[PT](m)
|
||||
if val != nil {
|
||||
rtn = append(rtn, val)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func SelectSimpleMap[T any](tx *txwrap.TxWrap, query string, args ...interface{}) map[string]T {
|
||||
var rtn []MapEntry[T]
|
||||
tx.Select(&rtn, query, args...)
|
||||
if len(rtn) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtnMap := make(map[string]T)
|
||||
for _, entry := range rtn {
|
||||
rtnMap[entry.Key] = entry.Val
|
||||
}
|
||||
return rtnMap
|
||||
}
|
||||
|
||||
func MakeGenMap[T HasSimpleKey](arr []T) map[string]T {
|
||||
rtn := make(map[string]T)
|
||||
for _, val := range arr {
|
||||
rtn[val.GetSimpleKey()] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func MakeGenMapInt64[T HasSimpleInt64Key](arr []T) map[int64]T {
|
||||
rtn := make(map[int64]T)
|
||||
for _, val := range arr {
|
||||
rtn[val.GetSimpleKey()] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func isStructType(rt reflect.Type) bool {
|
||||
if rt.Kind() == reflect.Struct {
|
||||
return true
|
||||
}
|
||||
if rt.Kind() == reflect.Pointer && rt.Elem().Kind() == reflect.Struct {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isByteArrayType(t reflect.Type) bool {
|
||||
return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8
|
||||
}
|
||||
|
||||
func isStringMapType(t reflect.Type) bool {
|
||||
return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String
|
||||
}
|
||||
|
||||
func ToDBMap(v DBMappable, useBytes bool) map[string]interface{} {
|
||||
if CheckNil(v) {
|
||||
return nil
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
if rv.Kind() != reflect.Struct {
|
||||
panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructToDBMap", v))
|
||||
}
|
||||
rt := rv.Type()
|
||||
m := make(map[string]interface{})
|
||||
numFields := rt.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := rt.Field(i)
|
||||
fieldVal := rv.FieldByIndex(field.Index)
|
||||
dbName := field.Tag.Get("dbmap")
|
||||
if dbName == "" {
|
||||
dbName = strings.ToLower(field.Name)
|
||||
}
|
||||
if dbName == "-" {
|
||||
continue
|
||||
}
|
||||
if isByteArrayType(field.Type) {
|
||||
m[dbName] = fieldVal.Interface()
|
||||
} else if field.Type.Kind() == reflect.Slice {
|
||||
if useBytes {
|
||||
m[dbName] = QuickJsonArrBytes(fieldVal.Interface())
|
||||
} else {
|
||||
m[dbName] = QuickJsonArr(fieldVal.Interface())
|
||||
}
|
||||
} else if isStructType(field.Type) || isStringMapType(field.Type) {
|
||||
if useBytes {
|
||||
m[dbName] = QuickJsonBytes(fieldVal.Interface())
|
||||
} else {
|
||||
m[dbName] = QuickJson(fieldVal.Interface())
|
||||
}
|
||||
} else {
|
||||
m[dbName] = fieldVal.Interface()
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func FromDBMap(v DBMappable, m map[string]interface{}) {
|
||||
if CheckNil(v) {
|
||||
panic("StructFromDBMap, v cannot be nil")
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
if rv.Kind() != reflect.Struct {
|
||||
panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructFromDBMap", v))
|
||||
}
|
||||
rt := rv.Type()
|
||||
numFields := rt.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := rt.Field(i)
|
||||
fieldVal := rv.FieldByIndex(field.Index)
|
||||
dbName := field.Tag.Get("dbmap")
|
||||
if dbName == "" {
|
||||
dbName = strings.ToLower(field.Name)
|
||||
}
|
||||
if dbName == "-" {
|
||||
continue
|
||||
}
|
||||
if isByteArrayType(field.Type) {
|
||||
barrVal := fieldVal.Addr().Interface()
|
||||
QuickSetBytes(barrVal.(*[]byte), m, dbName)
|
||||
} else if field.Type.Kind() == reflect.Slice {
|
||||
QuickSetJsonArr(fieldVal.Addr().Interface(), m, dbName)
|
||||
} else if isStructType(field.Type) || isStringMapType(field.Type) {
|
||||
QuickSetJson(fieldVal.Addr().Interface(), m, dbName)
|
||||
} else if field.Type.Kind() == reflect.String {
|
||||
strVal := fieldVal.Addr().Interface()
|
||||
QuickSetStr(strVal.(*string), m, dbName)
|
||||
} else if field.Type.Kind() == reflect.Int64 {
|
||||
intVal := fieldVal.Addr().Interface()
|
||||
QuickSetInt64(intVal.(*int64), m, dbName)
|
||||
} else if field.Type.Kind() == reflect.Int {
|
||||
intVal := fieldVal.Addr().Interface()
|
||||
QuickSetInt(intVal.(*int), m, dbName)
|
||||
} else if field.Type.Kind() == reflect.Bool {
|
||||
boolVal := fieldVal.Addr().Interface()
|
||||
QuickSetBool(boolVal.(*bool), m, dbName)
|
||||
} else {
|
||||
panic(fmt.Sprintf("StructFromDBMap invalid field type %v in %T", fieldVal.Type(), v))
|
||||
}
|
||||
}
|
||||
}
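
// Example sketch: what ToDBMap/FromDBMap do with a tagged struct.  Plain
// fields map to lower-cased column names (or the "dbmap" tag), slice/struct/
// map fields are stored as JSON, and "-" skips a field.  The exampleRow type
// is hypothetical.
type exampleRow struct {
	RowId   string `dbmap:"rowid"`
	Count   int64
	Tags    []string          // persisted as a JSON array string
	Scratch map[string]string `dbmap:"-"` // not persisted
}

func (exampleRow) UseDBMap() {}

func exampleDBMapRoundTrip() {
	orig := &exampleRow{RowId: "r1", Count: 7, Tags: []string{"a", "b"}}
	m := ToDBMap(orig, false) // map[count:7 rowid:r1 tags:["a","b"]]
	var back exampleRow
	FromDBMap(&back, m)
	fmt.Printf("%+v\n", back)
}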
|
112
wavesrv/pkg/keygen/keygen.go
Normal file
@ -0,0 +1,112 @@
|
||||
// Utility functions for generating and reading public/private keypairs.
|
||||
package keygen
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const p384Params = "BgUrgQQAIg==" // DER-encoded OID 1.3.132.0.34 (secp384r1), used for the "EC PARAMETERS" PEM block
|
||||
|
||||
// Creates a keypair with CN=[id], private key at keyFileName, and
|
||||
// public key certificate at certFileName.
|
||||
func CreateKeyPair(keyFileName string, certFileName string, id string) error {
|
||||
privateKey, err := CreatePrivateKey(keyFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = CreateCertificate(certFileName, privateKey, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a private key at keyFileName (ECDSA, secp384r1 (P-384)), PEM format
|
||||
func CreatePrivateKey(keyFileName string) (*ecdsa.PrivateKey, error) {
|
||||
curve := elliptic.P384() // secp384r1
|
||||
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error generating P-384 key err:%w", err)
|
||||
}
|
||||
keyFile, err := os.Create(keyFileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening file:%s err:%w", keyFileName, err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
pkBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error MarshalPKCS8PrivateKey err:%w", err)
|
||||
}
|
||||
paramsBytes, err := base64.StdEncoding.DecodeString(p384Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error decoding bytes for P-384 EC PARAMETERS err:%w", err)
|
||||
}
|
||||
var pemParamsBlock = &pem.Block{
|
||||
Type: "EC PARAMETERS",
|
||||
Bytes: paramsBytes,
|
||||
}
|
||||
err = pem.Encode(keyFile, pemParamsBlock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error writing EC PARAMETERS pem block err:%w", err)
|
||||
}
|
||||
var pemPrivateBlock = &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: pkBytes,
|
||||
}
|
||||
err = pem.Encode(keyFile, pemPrivateBlock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error writing EC PRIVATE KEY pem block err:%w", err)
|
||||
}
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// Creates a public key certificate at certFileName using privateKey with CN=[id].
|
||||
func CreateCertificate(certFileName string, privateKey *ecdsa.PrivateKey, id string) error {
|
||||
serialNumber, err := rand.Int(rand.Reader, big.NewInt(1000000000000))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Cannot generate serial number err:%w", err)
|
||||
}
|
||||
notBefore, err := time.Parse("Jan 2 15:04:05 2006", "Jan 1 00:00:00 2020")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Cannot Parse Date err:%w", err)
|
||||
}
|
||||
notAfter, err := time.Parse("Jan 2 15:04:05 2006", "Jan 1 00:00:00 2030")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Cannot Parse Date err:%w", err)
|
||||
}
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: id,
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error running x509.CreateCertificate err:%v\n", err)
|
||||
}
|
||||
certFile, err := os.Create(certFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error opening file:%s err:%w", certFileName, err)
|
||||
}
|
||||
defer certFile.Close()
|
||||
err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error writing CERTIFICATE pem block err:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
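
// Example sketch: the expected call pattern.  A single CreateKeyPair call
// writes the PEM private key and the self-signed client certificate.  The
// paths and id below are hypothetical.
func exampleCreateKeyPair() error {
	keyFileName := "/tmp/prompt-client.key"
	certFileName := "/tmp/prompt-client.cert"
	clientId := "client-1234"
	if err := CreateKeyPair(keyFileName, certFileName, clientId); err != nil {
		return fmt.Errorf("cannot create client keypair: %w", err)
	}
	return nil
}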
|
99
wavesrv/pkg/mapqueue/mapqueue.go
Normal file
@ -0,0 +1,99 @@
|
||||
package mapqueue
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type MQEntry struct {
|
||||
Lock *sync.Mutex
|
||||
Running bool
|
||||
Queue chan func()
|
||||
}
|
||||
|
||||
type MapQueue struct {
|
||||
Lock *sync.Mutex
|
||||
M map[string]*MQEntry
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
func MakeMapQueue(queueSize int) *MapQueue {
|
||||
rtn := &MapQueue{
|
||||
Lock: &sync.Mutex{},
|
||||
M: make(map[string]*MQEntry),
|
||||
QueueSize: queueSize,
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (mq *MapQueue) getEntry(id string) *MQEntry {
|
||||
mq.Lock.Lock()
|
||||
defer mq.Lock.Unlock()
|
||||
entry := mq.M[id]
|
||||
if entry == nil {
|
||||
entry = &MQEntry{
|
||||
Lock: &sync.Mutex{},
|
||||
Running: false,
|
||||
Queue: make(chan func(), mq.QueueSize),
|
||||
}
|
||||
mq.M[id] = entry
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func (entry *MQEntry) add(fn func()) error {
|
||||
select {
|
||||
case entry.Queue <- fn:
|
||||
break
|
||||
default:
|
||||
return fmt.Errorf("input queue full")
|
||||
}
|
||||
entry.tryRun()
|
||||
return nil
|
||||
}
|
||||
|
||||
func runFn(fn func()) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
log.Printf("[error] panic in MQEntry runFn: %v\n", r)
|
||||
debug.PrintStack()
|
||||
return
|
||||
}()
|
||||
fn()
|
||||
}
|
||||
|
||||
func (entry *MQEntry) tryRun() {
|
||||
entry.Lock.Lock()
|
||||
defer entry.Lock.Unlock()
|
||||
if entry.Running {
|
||||
return
|
||||
}
|
||||
if len(entry.Queue) > 0 {
|
||||
entry.Running = true
|
||||
go entry.run()
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *MQEntry) run() {
|
||||
for fn := range entry.Queue {
|
||||
runFn(fn)
|
||||
}
|
||||
entry.Lock.Lock()
|
||||
entry.Running = false
|
||||
entry.Lock.Unlock()
|
||||
entry.tryRun()
|
||||
}
|
||||
|
||||
func (mq *MapQueue) Enqueue(id string, fn func()) error {
|
||||
entry := mq.getEntry(id)
|
||||
err := entry.add(fn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot enqueue: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
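
// Example sketch: intended usage of MapQueue.  Functions enqueued under the
// same id run one at a time in FIFO order on a single goroutine, while
// different ids run concurrently.  The id strings are hypothetical.
func exampleMapQueueUsage() {
	mq := MakeMapQueue(100)
	for i := 0; i < 3; i++ {
		i := i
		if err := mq.Enqueue("screen-1", func() {
			fmt.Printf("screen-1 job %d\n", i) // runs serially, in order
		}); err != nil {
			fmt.Printf("queue full: %v\n", err) // Enqueue fails when the per-id channel is full
		}
	}
	// a different id gets its own queue and runner goroutine
	_ = mq.Enqueue("screen-2", func() { fmt.Printf("screen-2 job\n") })
}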
|
628
wavesrv/pkg/pcloud/pcloud.go
Normal file
@ -0,0 +1,628 @@
|
||||
package pcloud
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/dbutil"
|
||||
"github.com/commandlinedev/prompt-server/pkg/rtnstate"
|
||||
"github.com/commandlinedev/prompt-server/pkg/scbase"
|
||||
"github.com/commandlinedev/prompt-server/pkg/sstore"
|
||||
)
|
||||
|
||||
const PCloudEndpoint = "https://api.getprompt.dev/central"
|
||||
const PCloudEndpointVarName = "PCLOUD_ENDPOINT"
|
||||
const APIVersion = 1
|
||||
const MaxPtyUpdateSize = (128 * 1024)
|
||||
const MaxUpdatesPerReq = 10
|
||||
const MaxUpdatesToDeDup = 1000
|
||||
const MaxUpdateWriterErrors = 3
|
||||
const PCloudDefaultTimeout = 5 * time.Second
|
||||
const PCloudWebShareUpdateTimeout = 15 * time.Second
|
||||
|
||||
// setting to 1M to be safe (max is 6M for API-GW + Lambda, but there is base64 encoding and upload time)
|
||||
// we allow one extra update past this estimated size
|
||||
const MaxUpdatePayloadSize = 1 * (1024 * 1024)
|
||||
|
||||
const TelemetryUrl = "/telemetry"
|
||||
const NoTelemetryUrl = "/no-telemetry"
|
||||
const WebShareUpdateUrl = "/auth/web-share-update"
|
||||
|
||||
var updateWriterLock = &sync.Mutex{}
|
||||
var updateWriterRunning = false
|
||||
var updateWriterNumFailures = 0
|
||||
|
||||
type AuthInfo struct {
|
||||
UserId string `json:"userid"`
|
||||
ClientId string `json:"clientid"`
|
||||
AuthKey string `json:"authkey"`
|
||||
}
|
||||
|
||||
func GetEndpoint() string {
|
||||
if !scbase.IsDevMode() {
|
||||
return PCloudEndpoint
|
||||
}
|
||||
endpoint := os.Getenv(PCloudEndpointVarName)
|
||||
if endpoint == "" || !strings.HasPrefix(endpoint, "https://") {
|
||||
panic("Invalid PCloud dev endpoint, PCLOUD_ENDPOINT not set or invalid")
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func makeAuthPostReq(ctx context.Context, apiUrl string, authInfo AuthInfo, data interface{}) (*http.Request, error) {
|
||||
var dataReader io.Reader
|
||||
if data != nil {
|
||||
byteArr, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s request: %v", apiUrl, err)
|
||||
}
|
||||
dataReader = bytes.NewReader(byteArr)
|
||||
}
|
||||
fullUrl := GetEndpoint() + apiUrl
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullUrl, dataReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating %s request: %v", apiUrl, err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-PromptAPIVersion", strconv.Itoa(APIVersion))
|
||||
req.Header.Set("X-PromptAPIUrl", apiUrl)
|
||||
req.Header.Set("X-PromptUserId", authInfo.UserId)
|
||||
req.Header.Set("X-PromptClientId", authInfo.ClientId)
|
||||
req.Header.Set("X-PromptAuthKey", authInfo.AuthKey)
|
||||
req.Close = true
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func makeAnonPostReq(ctx context.Context, apiUrl string, data interface{}) (*http.Request, error) {
|
||||
var dataReader io.Reader
|
||||
if data != nil {
|
||||
byteArr, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s request: %v", apiUrl, err)
|
||||
}
|
||||
dataReader = bytes.NewReader(byteArr)
|
||||
}
|
||||
fullUrl := GetEndpoint() + apiUrl
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullUrl, dataReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating %s request: %v", apiUrl, err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-PromptAPIVersion", strconv.Itoa(APIVersion))
|
||||
req.Header.Set("X-PromptAPIUrl", apiUrl)
|
||||
req.Close = true
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func doRequest(req *http.Request, outputObj interface{}) (*http.Response, error) {
|
||||
apiUrl := req.Header.Get("X-PromptAPIUrl")
|
||||
log.Printf("[pcloud] sending request %s %v\n", req.Method, req.URL)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error contacting pcloud %q service: %v", apiUrl, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("error reading %q response body: %v", apiUrl, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return resp, fmt.Errorf("error contacting pcloud %q service: %s", apiUrl, resp.Status)
|
||||
}
|
||||
if outputObj != nil && resp.Header.Get("Content-Type") == "application/json" {
|
||||
err = json.Unmarshal(bodyBytes, outputObj)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("error decoding json: %v", err)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func SendTelemetry(ctx context.Context, force bool) error {
|
||||
clientData, err := sstore.EnsureClientData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot retrieve client data: %v", err)
|
||||
}
|
||||
if !force && clientData.ClientOpts.NoTelemetry {
|
||||
return nil
|
||||
}
|
||||
activity, err := sstore.GetNonUploadedActivity(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get activity: %v", err)
|
||||
}
|
||||
if len(activity) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Printf("[pcloud] sending telemetry data\n")
|
||||
dayStr := sstore.GetCurDayStr()
|
||||
input := TelemetryInputType{UserId: clientData.UserId, ClientId: clientData.ClientId, CurDay: dayStr, Activity: activity}
|
||||
req, err := makeAnonPostReq(ctx, TelemetryUrl, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = doRequest(req, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sstore.MarkActivityAsUploaded(ctx, activity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marking activity as uploaded: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SendNoTelemetryUpdate(ctx context.Context, noTelemetryVal bool) error {
|
||||
clientData, err := sstore.EnsureClientData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot retrieve client data: %v", err)
|
||||
}
|
||||
req, err := makeAnonPostReq(ctx, NoTelemetryUrl, NoTelemetryInputType{ClientId: clientData.ClientId, Value: noTelemetryVal})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = doRequest(req, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAuthInfo(ctx context.Context) (AuthInfo, error) {
|
||||
clientData, err := sstore.EnsureClientData(ctx)
|
||||
if err != nil {
|
||||
return AuthInfo{}, fmt.Errorf("cannot retrieve client data: %v", err)
|
||||
}
|
||||
return AuthInfo{UserId: clientData.UserId, ClientId: clientData.ClientId}, nil
|
||||
}
|
||||
|
||||
func defaultError(err error, estr string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(estr)
|
||||
}
|
||||
|
||||
func MakeScreenNewUpdate(screen *sstore.ScreenType, webShareOpts sstore.ScreenWebShareOpts) *WebShareUpdateType {
|
||||
rtn := &WebShareUpdateType{
|
||||
ScreenId: screen.ScreenId,
|
||||
UpdateId: -1,
|
||||
UpdateType: sstore.UpdateType_ScreenNew,
|
||||
UpdateTs: time.Now().UnixMilli(),
|
||||
}
|
||||
rtn.Screen = &WebShareScreenType{
|
||||
ScreenId: screen.ScreenId,
|
||||
SelectedLine: int(screen.SelectedLine),
|
||||
ShareName: webShareOpts.ShareName,
|
||||
ViewKey: webShareOpts.ViewKey,
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func MakeScreenDelUpdate(screen *sstore.ScreenType, screenId string) *WebShareUpdateType {
|
||||
rtn := &WebShareUpdateType{
|
||||
ScreenId: screenId,
|
||||
UpdateId: -1,
|
||||
UpdateType: sstore.UpdateType_ScreenDel,
|
||||
UpdateTs: time.Now().UnixMilli(),
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func makeWebShareUpdate(ctx context.Context, update *sstore.ScreenUpdateType) (*WebShareUpdateType, error) {
|
||||
rtn := &WebShareUpdateType{
|
||||
ScreenId: update.ScreenId,
|
||||
LineId: update.LineId,
|
||||
UpdateId: update.UpdateId,
|
||||
UpdateType: update.UpdateType,
|
||||
UpdateTs: update.UpdateTs,
|
||||
}
|
||||
switch update.UpdateType {
|
||||
case sstore.UpdateType_ScreenNew:
|
||||
screen, err := sstore.GetScreenById(ctx, update.ScreenId)
|
||||
if err != nil || screen == nil {
|
||||
return nil, fmt.Errorf("error getting screen: %v", defaultError(err, "not found"))
|
||||
}
|
||||
rtn.Screen, err = webScreenFromScreen(screen)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting screen to web-screen: %v", err)
|
||||
}
|
||||
|
||||
case sstore.UpdateType_ScreenDel:
|
||||
break
|
||||
|
||||
case sstore.UpdateType_ScreenName, sstore.UpdateType_ScreenSelectedLine:
|
||||
screen, err := sstore.GetScreenById(ctx, update.ScreenId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting screen: %v", err)
|
||||
}
|
||||
if screen == nil || screen.WebShareOpts == nil {
|
||||
return nil, fmt.Errorf("invalid screen, not webshared (makeWebScreenUpdate)")
|
||||
}
|
||||
if update.UpdateType == sstore.UpdateType_ScreenName {
|
||||
rtn.SVal = screen.WebShareOpts.ShareName
|
||||
} else if update.UpdateType == sstore.UpdateType_ScreenSelectedLine {
|
||||
rtn.IVal = int64(screen.SelectedLine)
|
||||
}
|
||||
|
||||
case sstore.UpdateType_LineNew:
|
||||
line, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil || line == nil {
|
||||
return nil, fmt.Errorf("error getting line/cmd: %v", defaultError(err, "not found"))
|
||||
}
|
||||
rtn.Line, err = webLineFromLine(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting line to web-line: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
rtn.Cmd, err = webCmdFromCmd(update.LineId, cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting cmd to web-cmd: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
case sstore.UpdateType_LineDel:
|
||||
break
|
||||
|
||||
case sstore.UpdateType_LineRenderer, sstore.UpdateType_LineContentHeight:
|
||||
line, err := sstore.GetLineById(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil || line == nil {
|
||||
return nil, fmt.Errorf("error getting line: %v", defaultError(err, "not found"))
|
||||
}
|
||||
if update.UpdateType == sstore.UpdateType_LineRenderer {
|
||||
rtn.SVal = line.Renderer
|
||||
} else if update.UpdateType == sstore.UpdateType_LineContentHeight {
|
||||
rtn.IVal = line.ContentHeight
|
||||
}
|
||||
|
||||
case sstore.UpdateType_CmdStatus:
|
||||
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil || cmd == nil {
|
||||
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
|
||||
}
|
||||
rtn.SVal = cmd.Status
|
||||
|
||||
case sstore.UpdateType_CmdTermOpts:
|
||||
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil || cmd == nil {
|
||||
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
|
||||
}
|
||||
rtn.TermOpts = &cmd.TermOpts
|
||||
|
||||
case sstore.UpdateType_CmdExitCode, sstore.UpdateType_CmdDurationMs:
|
||||
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil || cmd == nil {
|
||||
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
|
||||
}
|
||||
if update.UpdateType == sstore.UpdateType_CmdExitCode {
|
||||
rtn.IVal = int64(cmd.ExitCode)
|
||||
} else if update.UpdateType == sstore.UpdateType_CmdDurationMs {
|
||||
rtn.IVal = int64(cmd.DurationMs)
|
||||
}
|
||||
|
||||
case sstore.UpdateType_CmdRtnState:
|
||||
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil || cmd == nil {
|
||||
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
|
||||
}
|
||||
data, err := rtnstate.GetRtnStateDiff(ctx, update.ScreenId, cmd.LineId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot compute rtnstate: %v", err)
|
||||
}
|
||||
rtn.SVal = string(data)
|
||||
|
||||
case sstore.UpdateType_PtyPos:
|
||||
ptyPos, err := sstore.GetWebPtyPos(ctx, update.ScreenId, update.LineId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting ptypos: %v", err)
|
||||
}
|
||||
realOffset, data, err := sstore.ReadPtyOutFile(ctx, update.ScreenId, update.LineId, ptyPos, MaxPtyUpdateSize+1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting ptydata: %v", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if len(data) > MaxPtyUpdateSize {
|
||||
rtn.PtyData = &WebSharePtyData{PtyPos: realOffset, Data: data[0:MaxPtyUpdateSize], Eof: false}
|
||||
} else {
|
||||
rtn.PtyData = &WebSharePtyData{PtyPos: realOffset, Data: data, Eof: true}
|
||||
}
|
||||
|
||||
case sstore.UpdateType_LineState:
|
||||
// TODO implement!
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported update type (pcloud/makeWebScreenUpdate): %s\n", update.UpdateType)
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func finalizeWebScreenUpdate(ctx context.Context, webUpdate *WebShareUpdateType) error {
|
||||
switch webUpdate.UpdateType {
|
||||
case sstore.UpdateType_PtyPos:
|
||||
newPos := webUpdate.PtyData.PtyPos + int64(len(webUpdate.PtyData.Data))
|
||||
err := sstore.SetWebPtyPos(ctx, webUpdate.ScreenId, webUpdate.LineId, newPos)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case sstore.UpdateType_LineDel:
|
||||
err := sstore.DeleteWebPtyPos(ctx, webUpdate.ScreenId, webUpdate.LineId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := sstore.RemoveScreenUpdate(ctx, webUpdate.UpdateId)
|
||||
if err != nil {
|
||||
// this is not great, this *should* never fail and is not easy to recover from
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type webShareResponseType struct {
|
||||
Success bool `json:"success"`
|
||||
Data []*WebShareUpdateResponseType `json:"data"`
|
||||
}
|
||||
|
||||
func convertUpdate(update *sstore.ScreenUpdateType) *WebShareUpdateType {
|
||||
webUpdate, err := makeWebShareUpdate(context.Background(), update)
|
||||
if err != nil || webUpdate == nil {
|
||||
if err != nil {
|
||||
log.Printf("[pcloud] error create web-share update updateid:%d: %v", update.UpdateId, err)
|
||||
}
|
||||
// if err, or no web update created, remove the screenupdate
|
||||
removeErr := sstore.RemoveScreenUpdate(context.Background(), update.UpdateId)
|
||||
if removeErr != nil {
|
||||
// ignore this error too (although this is really problematic, there is nothing to do)
|
||||
log.Printf("[pcloud] error removing screen update updateid:%d: %v", update.UpdateId, removeErr)
|
||||
}
|
||||
}
|
||||
return webUpdate
|
||||
}
|
||||
|
||||
func DoSyncWebUpdate(webUpdate *WebShareUpdateType) error {
|
||||
authInfo, err := getAuthInfo(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get authinfo for request: %v", err)
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), PCloudDefaultTimeout)
|
||||
defer cancelFn()
|
||||
req, err := makeAuthPostReq(ctx, WebShareUpdateUrl, authInfo, []*WebShareUpdateType{webUpdate})
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create auth-post-req for %s: %v", WebShareUpdateUrl, err)
|
||||
}
|
||||
var resp webShareResponseType
|
||||
_, err = doRequest(req, &resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.Data) == 0 {
|
||||
return fmt.Errorf("invalid response received from server")
|
||||
}
|
||||
urt := resp.Data[0]
|
||||
if urt.Error != "" {
|
||||
return errors.New(urt.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoWebUpdates(webUpdates []*WebShareUpdateType) error {
|
||||
if len(webUpdates) == 0 {
|
||||
return nil
|
||||
}
|
||||
authInfo, err := getAuthInfo(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get authinfo for request: %v", err)
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), PCloudWebShareUpdateTimeout)
|
||||
defer cancelFn()
|
||||
req, err := makeAuthPostReq(ctx, WebShareUpdateUrl, authInfo, webUpdates)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create auth-post-req for %s: %v", WebShareUpdateUrl, err)
|
||||
}
|
||||
var resp webShareResponseType
|
||||
_, err = doRequest(req, &resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
respMap := dbutil.MakeGenMapInt64(resp.Data)
|
||||
for _, update := range webUpdates {
|
||||
err = finalizeWebScreenUpdate(context.Background(), update)
|
||||
if err != nil {
|
||||
// ignore this error (nothing to do)
|
||||
log.Printf("[pcloud] error finalizing web-update: %v\n", err)
|
||||
}
|
||||
resp := respMap[update.UpdateId]
|
||||
if resp == nil {
|
||||
resp = &WebShareUpdateResponseType{Success: false, Error: "resp not found"}
|
||||
}
|
||||
if resp.Error != "" {
|
||||
log.Printf("[pcloud] error updateid:%d, type:%s %s/%s err:%v\n", update.UpdateId, update.UpdateType, update.ScreenId, update.LineId, resp.Error)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUpdateWriterRunning(running bool) {
|
||||
updateWriterLock.Lock()
|
||||
defer updateWriterLock.Unlock()
|
||||
updateWriterRunning = running
|
||||
}
|
||||
|
||||
func GetUpdateWriterRunning() bool {
|
||||
updateWriterLock.Lock()
|
||||
defer updateWriterLock.Unlock()
|
||||
return updateWriterRunning
|
||||
}
|
||||
|
||||
func StartUpdateWriter() {
|
||||
updateWriterLock.Lock()
|
||||
defer updateWriterLock.Unlock()
|
||||
if updateWriterRunning {
|
||||
return
|
||||
}
|
||||
updateWriterRunning = true
|
||||
go runWebShareUpdateWriter()
|
||||
}
|
||||
|
||||
func computeUpdateWriterBackoff() time.Duration {
|
||||
updateWriterLock.Lock()
|
||||
numFailures := updateWriterNumFailures
|
||||
updateWriterLock.Unlock()
|
||||
switch numFailures {
|
||||
case 0:
|
||||
return 0
|
||||
case 1:
|
||||
return 1 * time.Second
|
||||
case 2:
|
||||
return 2 * time.Second
|
||||
case 3:
|
||||
return 5 * time.Second
|
||||
case 4:
|
||||
return time.Minute
|
||||
case 5:
|
||||
return 5 * time.Minute
|
||||
case 6:
|
||||
return time.Hour
|
||||
default:
|
||||
return time.Hour
|
||||
}
|
||||
}
|
||||
|
||||
func incrementUpdateWriterNumFailures() {
|
||||
updateWriterLock.Lock()
|
||||
defer updateWriterLock.Unlock()
|
||||
updateWriterNumFailures++
|
||||
}
|
||||
|
||||
func ResetUpdateWriterNumFailures() {
|
||||
updateWriterLock.Lock()
|
||||
defer updateWriterLock.Unlock()
|
||||
updateWriterNumFailures = 0
|
||||
}
|
||||
|
||||
func GetUpdateWriterNumFailures() int {
|
||||
updateWriterLock.Lock()
|
||||
defer updateWriterLock.Unlock()
|
||||
return updateWriterNumFailures
|
||||
}
|
||||
|
||||
type updateKey struct {
|
||||
ScreenId string
|
||||
LineId string
|
||||
UpdateType string
|
||||
}
|
||||
|
||||
func DeDupUpdates(ctx context.Context, updateArr []*sstore.ScreenUpdateType) ([]*sstore.ScreenUpdateType, error) {
|
||||
var rtn []*sstore.ScreenUpdateType
|
||||
var idsToDelete []int64
|
||||
umap := make(map[updateKey]bool)
|
||||
for _, update := range updateArr {
|
||||
key := updateKey{ScreenId: update.ScreenId, LineId: update.LineId, UpdateType: update.UpdateType}
|
||||
if umap[key] {
|
||||
idsToDelete = append(idsToDelete, update.UpdateId)
|
||||
continue
|
||||
}
|
||||
umap[key] = true
|
||||
rtn = append(rtn, update)
|
||||
}
|
||||
if len(idsToDelete) > 0 {
|
||||
err := sstore.RemoveScreenUpdates(ctx, idsToDelete)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error trying to delete screenupdates: %v\n", err)
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
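
// dedupKeyExample is an illustrative sketch (not part of the original commit):
// it shows the rule DeDupUpdates applies above -- the (ScreenId, LineId,
// UpdateType) triple is the dedup key, the first queued update for a key wins,
// and the UpdateIds of later duplicates are the ones passed to
// sstore.RemoveScreenUpdates for deletion. The ids and type strings below are
// hypothetical placeholders.
func dedupKeyExample() {
	updates := []*sstore.ScreenUpdateType{
		{UpdateId: 1, ScreenId: "s1", LineId: "l1", UpdateType: "line"},
		{UpdateId: 2, ScreenId: "s1", LineId: "l1", UpdateType: "line"},   // same key as id 1, would be deleted
		{UpdateId: 3, ScreenId: "s1", LineId: "l1", UpdateType: "ptypos"}, // different type, kept
	}
	seen := make(map[updateKey]bool)
	for _, u := range updates {
		key := updateKey{ScreenId: u.ScreenId, LineId: u.LineId, UpdateType: u.UpdateType}
		if seen[key] {
			log.Printf("[pcloud] duplicate updateid=%d would be removed\n", u.UpdateId)
			continue
		}
		seen[key] = true
	}
}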
|
||||
|
||||
func runWebShareUpdateWriter() {
|
||||
defer func() {
|
||||
setUpdateWriterRunning(false)
|
||||
}()
|
||||
log.Printf("[pcloud] starting update writer\n")
|
||||
numErrors := 0
|
||||
for {
|
||||
if numErrors > MaxUpdateWriterErrors {
|
||||
log.Printf("[pcloud] update-writer, too many errors, exiting\n")
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
fullUpdateArr, err := sstore.GetScreenUpdates(context.Background(), MaxUpdatesToDeDup)
|
||||
if err != nil {
|
||||
log.Printf("[pcloud] error retrieving updates: %v", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
numErrors++
|
||||
continue
|
||||
}
|
||||
updateArr, err := DeDupUpdates(context.Background(), fullUpdateArr)
|
||||
if err != nil {
|
||||
log.Printf("[pcloud] error deduping screenupdates: %v", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
numErrors++
|
||||
continue
|
||||
}
|
||||
numErrors = 0
|
||||
|
||||
var webUpdateArr []*WebShareUpdateType
|
||||
totalSize := 0
|
||||
for _, update := range updateArr {
|
||||
webUpdate := convertUpdate(update)
|
||||
if webUpdate == nil {
|
||||
continue
|
||||
}
|
||||
webUpdateArr = append(webUpdateArr, webUpdate)
|
||||
totalSize += webUpdate.GetEstimatedSize()
|
||||
if totalSize > MaxUpdatePayloadSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(webUpdateArr) == 0 {
|
||||
sstore.UpdateWriterCheckMoreData()
|
||||
continue
|
||||
}
|
||||
err = DoWebUpdates(webUpdateArr)
|
||||
if err != nil {
|
||||
incrementUpdateWriterNumFailures()
|
||||
backoffTime := computeUpdateWriterBackoff()
|
||||
log.Printf("[pcloud] error processing %d web-updates (backoff=%v): %v\n", len(webUpdateArr), backoffTime, err)
|
||||
updateBackoffSleep(backoffTime)
|
||||
continue
|
||||
}
|
||||
log.Printf("[pcloud] sent %d web-updates\n", len(webUpdateArr))
|
||||
var debugStrs []string
|
||||
for _, webUpdate := range webUpdateArr {
|
||||
debugStrs = append(debugStrs, webUpdate.String())
|
||||
}
|
||||
log.Printf("[pcloud] updates: %s\n", strings.Join(debugStrs, " "))
|
||||
ResetUpdateWriterNumFailures()
|
||||
}
|
||||
}
|
||||
|
||||
// todo fix this, set deadline, check with condition variable, backoff then just needs to notify
|
||||
func updateBackoffSleep(backoffTime time.Duration) {
|
||||
var totalSleep time.Duration
|
||||
for {
|
||||
sleepTime := time.Second
|
||||
totalSleep += sleepTime
|
||||
time.Sleep(sleepTime)
|
||||
if totalSleep >= backoffTime {
|
||||
break
|
||||
}
|
||||
numFailures := GetUpdateWriterNumFailures()
|
||||
if numFailures == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
196
wavesrv/pkg/pcloud/pclouddata.go
Normal file
@ -0,0 +1,196 @@
|
||||
package pcloud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/remote"
|
||||
"github.com/commandlinedev/prompt-server/pkg/rtnstate"
|
||||
"github.com/commandlinedev/prompt-server/pkg/sstore"
|
||||
)
|
||||
|
||||
type NoTelemetryInputType struct {
|
||||
ClientId string `json:"clientid"`
|
||||
Value bool `json:"value"`
|
||||
}
|
||||
|
||||
type TelemetryInputType struct {
|
||||
UserId string `json:"userid"`
|
||||
ClientId string `json:"clientid"`
|
||||
CurDay string `json:"curday"`
|
||||
Activity []*sstore.ActivityType `json:"activity"`
|
||||
}
|
||||
|
||||
type WebShareUpdateType struct {
|
||||
ScreenId string `json:"screenid"`
|
||||
LineId string `json:"lineid"`
|
||||
UpdateId int64 `json:"updateid"`
|
||||
UpdateType string `json:"updatetype"`
|
||||
UpdateTs int64 `json:"updatets"`
|
||||
|
||||
Screen *WebShareScreenType `json:"screen,omitempty"`
|
||||
Line *WebShareLineType `json:"line,omitempty"`
|
||||
Cmd *WebShareCmdType `json:"cmd,omitempty"`
|
||||
PtyData *WebSharePtyData `json:"ptydata,omitempty"`
|
||||
SVal string `json:"sval,omitempty"`
|
||||
IVal int64 `json:"ival,omitempty"`
|
||||
BVal bool `json:"bval,omitempty"`
|
||||
TermOpts *sstore.TermOpts `json:"termopts,omitempty"`
|
||||
}
|
||||
|
||||
const EstimatedSizePadding = 100
|
||||
|
||||
func (update *WebShareUpdateType) GetEstimatedSize() int {
|
||||
barr, _ := json.Marshal(update)
|
||||
return len(barr) + EstimatedSizePadding
|
||||
}
|
||||
|
||||
func (update *WebShareUpdateType) String() string {
|
||||
var idStr string
|
||||
if update.LineId != "" && update.ScreenId != "" {
|
||||
idStr = fmt.Sprintf("%s:%s", update.ScreenId[0:8], update.LineId[0:8])
|
||||
} else if update.ScreenId != "" {
|
||||
idStr = update.ScreenId[0:8]
|
||||
}
|
||||
if update.UpdateType == sstore.UpdateType_PtyPos && update.PtyData != nil {
|
||||
return fmt.Sprintf("ptydata[%s][%d:%d]", idStr, update.PtyData.PtyPos, len(update.PtyData.Data))
|
||||
}
|
||||
return fmt.Sprintf("%s[%s]", update.UpdateType, idStr)
|
||||
}
|
||||
|
||||
type WebShareUpdateResponseType struct {
|
||||
UpdateId int64 `json:"updateid"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (ur *WebShareUpdateResponseType) GetSimpleKey() int64 {
|
||||
return ur.UpdateId
|
||||
}
|
||||
|
||||
type WebShareRemote struct {
|
||||
RemoteId string `json:"remoteid"`
|
||||
Alias string `json:"alias,omitempty"`
|
||||
CanonicalName string `json:"canonicalname"`
|
||||
Name string `json:"name,omitempty"`
|
||||
HomeDir string `json:"homedir,omitempty"`
|
||||
IsRoot bool `json:"isroot,omitempty"`
|
||||
}
|
||||
|
||||
type WebShareScreenType struct {
|
||||
ScreenId string `json:"screenid"`
|
||||
ShareName string `json:"sharename"`
|
||||
ViewKey string `json:"viewkey"`
|
||||
SelectedLine int `json:"selectedline"`
|
||||
}
|
||||
|
||||
func webRemoteFromRemote(rptr sstore.RemotePtrType, r *sstore.RemoteType) *WebShareRemote {
|
||||
return &WebShareRemote{
|
||||
RemoteId: r.RemoteId,
|
||||
Alias: r.RemoteAlias,
|
||||
CanonicalName: r.RemoteCanonicalName,
|
||||
Name: rptr.Name,
|
||||
HomeDir: r.StateVars["home"],
|
||||
IsRoot: r.StateVars["remoteuser"] == "root",
|
||||
}
|
||||
}
|
||||
|
||||
func webScreenFromScreen(s *sstore.ScreenType) (*WebShareScreenType, error) {
|
||||
if s == nil || s.ScreenId == "" {
|
||||
return nil, fmt.Errorf("invalid nil screen")
|
||||
}
|
||||
if s.WebShareOpts == nil {
|
||||
return nil, fmt.Errorf("invalid screen, no WebShareOpts")
|
||||
}
|
||||
if s.WebShareOpts.ViewKey == "" {
|
||||
return nil, fmt.Errorf("invalid screen, no ViewKey")
|
||||
}
|
||||
var shareName string
|
||||
if s.WebShareOpts.ShareName != "" {
|
||||
shareName = s.WebShareOpts.ShareName
|
||||
} else {
|
||||
shareName = s.Name
|
||||
}
|
||||
return &WebShareScreenType{ScreenId: s.ScreenId, ShareName: shareName, ViewKey: s.WebShareOpts.ViewKey, SelectedLine: int(s.SelectedLine)}, nil
|
||||
}
|
||||
|
||||
type WebShareLineType struct {
|
||||
LineId string `json:"lineid"`
|
||||
Ts int64 `json:"ts"`
|
||||
LineNum int64 `json:"linenum"`
|
||||
LineType string `json:"linetype"`
|
||||
ContentHeight int64 `json:"contentheight"`
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
func webLineFromLine(line *sstore.LineType) (*WebShareLineType, error) {
|
||||
rtn := &WebShareLineType{
|
||||
LineId: line.LineId,
|
||||
Ts: line.Ts,
|
||||
LineNum: line.LineNum,
|
||||
LineType: line.LineType,
|
||||
ContentHeight: line.ContentHeight,
|
||||
Renderer: line.Renderer,
|
||||
Text: line.Text,
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
type WebShareCmdType struct {
|
||||
LineId string `json:"lineid"`
|
||||
CmdStr string `json:"cmdstr"`
|
||||
RawCmdStr string `json:"rawcmdstr"`
|
||||
Remote *WebShareRemote `json:"remote"`
|
||||
FeState sstore.FeStateType `json:"festate"`
|
||||
TermOpts sstore.TermOpts `json:"termopts"`
|
||||
Status string `json:"status"`
|
||||
CmdPid int `json:"cmdpid"`
|
||||
RemotePid int `json:"remotepid"`
|
||||
DoneTs int64 `json:"donets,omitempty"`
|
||||
ExitCode int `json:"exitcode,omitempty"`
|
||||
DurationMs int `json:"durationms,omitempty"`
|
||||
RtnState bool `json:"rtnstate,omitempty"`
|
||||
RtnStateStr string `json:"rtnstatestr,omitempty"`
|
||||
}
|
||||
|
||||
func webCmdFromCmd(lineId string, cmd *sstore.CmdType) (*WebShareCmdType, error) {
|
||||
if cmd.Remote.RemoteId == "" {
|
||||
return nil, fmt.Errorf("invalid cmd, remoteptr has no remoteid")
|
||||
}
|
||||
	remoteCopy := remote.GetRemoteCopyById(cmd.Remote.RemoteId)
	if remoteCopy == nil {
		return nil, fmt.Errorf("invalid cmd, cannot retrieve remote:%s", cmd.Remote.RemoteId)
	}
	webRemote := webRemoteFromRemote(cmd.Remote, remoteCopy)
|
||||
rtn := &WebShareCmdType{
|
||||
LineId: lineId,
|
||||
CmdStr: cmd.CmdStr,
|
||||
RawCmdStr: cmd.RawCmdStr,
|
||||
Remote: webRemote,
|
||||
FeState: cmd.FeState,
|
||||
TermOpts: cmd.TermOpts,
|
||||
Status: cmd.Status,
|
||||
CmdPid: cmd.CmdPid,
|
||||
RemotePid: cmd.RemotePid,
|
||||
DoneTs: cmd.DoneTs,
|
||||
ExitCode: cmd.ExitCode,
|
||||
DurationMs: cmd.DurationMs,
|
||||
RtnState: cmd.RtnState,
|
||||
}
|
||||
if cmd.RtnState {
|
||||
barr, err := rtnstate.GetRtnStateDiff(context.Background(), cmd.ScreenId, cmd.LineId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating rtnstate diff for cmd:%s: %v", cmd.LineId, err)
|
||||
}
|
||||
rtn.RtnStateStr = string(barr)
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
type WebSharePtyData struct {
|
||||
PtyPos int64 `json:"ptypos"`
|
||||
Data []byte `json:"data"`
|
||||
Eof bool `json:"-"` // internal use
|
||||
}
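
// exampleWebShareUpdate is an illustrative sketch (not part of the original
// commit): it assembles one update record the way the update writer does and
// shows what String() and GetEstimatedSize() are used for (log lines and
// payload batching against MaxUpdatePayloadSize). The ids below are
// hypothetical.
func exampleWebShareUpdate() {
	update := &WebShareUpdateType{
		ScreenId:   "11111111-2222-3333-4444-555555555555",
		LineId:     "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
		UpdateId:   42,
		UpdateType: sstore.UpdateType_PtyPos,
		PtyData:    &WebSharePtyData{PtyPos: 0, Data: []byte("hello\r\n")},
	}
	// String() -> "ptydata[11111111:aaaaaaaa][0:7]"
	// GetEstimatedSize() -> JSON size plus padding
	fmt.Printf("%s size=%d\n", update.String(), update.GetEstimatedSize())
}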
|
199
wavesrv/pkg/promptenc/promptenc.go
Normal file
@ -0,0 +1,199 @@
|
||||
package promptenc
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
ccp "golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const EncTagName = "enc"
|
||||
const EncFieldIndicator = "*"
|
||||
|
||||
type Encryptor struct {
|
||||
Key []byte
|
||||
AEAD cipher.AEAD
|
||||
}
|
||||
|
||||
type HasOData interface {
|
||||
GetOData() string
|
||||
}
|
||||
|
||||
func readRandBytes(n int) ([]byte, error) {
|
||||
rtn := make([]byte, n)
|
||||
_, err := io.ReadFull(rand.Reader, rtn)
|
||||
return rtn, err
|
||||
}
|
||||
|
||||
func MakeRandomEncryptor() (*Encryptor, error) {
|
||||
key, err := readRandBytes(ccp.KeySize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn := &Encryptor{Key: key}
|
||||
rtn.AEAD, err = ccp.NewX(rtn.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func MakeEncryptor(key []byte) (*Encryptor, error) {
|
||||
var err error
|
||||
rtn := &Encryptor{Key: key}
|
||||
rtn.AEAD, err = ccp.NewX(rtn.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func MakeEncryptorB64(key64 string) (*Encryptor, error) {
|
||||
keyBytes, err := base64.RawURLEncoding.DecodeString(key64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return MakeEncryptor(keyBytes)
|
||||
}
|
||||
|
||||
func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error) {
|
||||
outputBuf := make([]byte, enc.AEAD.NonceSize()+enc.AEAD.Overhead()+len(plainText))
|
||||
nonce := outputBuf[0:enc.AEAD.NonceSize()]
|
||||
_, err := io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// we're going to append the cipherText to nonce. so the encrypted data is [nonce][ciphertext]
|
||||
// note that outputbuf should be the correct size to hold the rtn value
|
||||
rtn := enc.AEAD.Seal(nonce, nonce, plainText, []byte(odata))
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (enc *Encryptor) DecryptData(encData []byte, odata string) (map[string]interface{}, error) {
|
||||
minLen := enc.AEAD.NonceSize() + enc.AEAD.Overhead()
|
||||
if len(encData) < minLen {
|
||||
return nil, fmt.Errorf("invalid encdata, len:%d is less than minimum len:%d", len(encData), minLen)
|
||||
}
|
||||
m := make(map[string]interface{})
|
||||
nonce := encData[0:enc.AEAD.NonceSize()]
|
||||
cipherText := encData[enc.AEAD.NonceSize():]
|
||||
plainText, err := enc.AEAD.Open(nil, nonce, cipherText, []byte(odata))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(plainText, &m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type EncryptMeta struct {
|
||||
EncField *reflect.StructField
|
||||
PlainFields map[string]reflect.StructField
|
||||
}
|
||||
|
||||
func isByteArrayType(t reflect.Type) bool {
|
||||
return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8
|
||||
}
|
||||
|
||||
func metaFromType(v interface{}) (*EncryptMeta, error) {
|
||||
if v == nil {
|
||||
return nil, fmt.Errorf("Encryptor cannot encrypt nil")
|
||||
}
|
||||
rt := reflect.TypeOf(v)
|
||||
if rt.Kind() != reflect.Pointer {
|
||||
return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer type", v)
|
||||
}
|
||||
rtElem := rt.Elem()
|
||||
if rtElem.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer to struct type", v)
|
||||
}
|
||||
meta := &EncryptMeta{}
|
||||
meta.PlainFields = make(map[string]reflect.StructField)
|
||||
numFields := rtElem.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := rtElem.Field(i)
|
||||
encTag := field.Tag.Get(EncTagName)
|
||||
if encTag == "" {
|
||||
continue
|
||||
}
|
||||
if encTag == EncFieldIndicator {
|
||||
if meta.EncField != nil {
|
||||
return nil, fmt.Errorf("Encryptor, type %T has two enc fields set (*)", v)
|
||||
}
|
||||
if !isByteArrayType(field.Type) {
|
||||
return nil, fmt.Errorf("Encryptor, type %T enc field %q is not []byte", v, field.Name)
|
||||
}
|
||||
meta.EncField = &field
|
||||
continue
|
||||
}
|
||||
if _, found := meta.PlainFields[encTag]; found {
|
||||
return nil, fmt.Errorf("Encryptor, type %T has two enc fields with tag %q", v, encTag)
|
||||
}
|
||||
meta.PlainFields[encTag] = field
|
||||
}
|
||||
if meta.EncField == nil {
|
||||
return nil, fmt.Errorf("Encryptor, type %T has no enc (*) field", v)
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (enc *Encryptor) EncryptODS(v HasOData) error {
|
||||
odata := v.GetOData()
|
||||
return enc.EncryptStructFields(v, odata)
|
||||
}
|
||||
|
||||
func (enc *Encryptor) DecryptODS(v HasOData) error {
|
||||
odata := v.GetOData()
|
||||
return enc.DecryptStructFields(v, odata)
|
||||
}
|
||||
|
||||
func (enc *Encryptor) EncryptStructFields(v interface{}, odata string) error {
|
||||
encMeta, err := metaFromType(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rvPtr := reflect.ValueOf(v)
|
||||
rv := rvPtr.Elem()
|
||||
m := make(map[string]interface{})
|
||||
for jsonKey, field := range encMeta.PlainFields {
|
||||
fieldVal := rv.FieldByIndex(field.Index)
|
||||
m[jsonKey] = fieldVal.Interface()
|
||||
}
|
||||
barr, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cipherText, err := enc.EncryptData(barr, odata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
encFieldValue := rv.FieldByIndex(encMeta.EncField.Index)
|
||||
encFieldValue.SetBytes(cipherText)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (enc *Encryptor) DecryptStructFields(v interface{}, odata string) error {
|
||||
encMeta, err := metaFromType(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rvPtr := reflect.ValueOf(v)
|
||||
rv := rvPtr.Elem()
|
||||
cipherText := rv.FieldByIndex(encMeta.EncField.Index).Bytes()
|
||||
m, err := enc.DecryptData(cipherText, odata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for jsonKey, field := range encMeta.PlainFields {
|
||||
val := m[jsonKey]
|
||||
rv.FieldByIndex(field.Index).Set(reflect.ValueOf(val))
|
||||
}
|
||||
return nil
|
||||
}
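
// encExample / exampleRoundTrip are an illustrative sketch (not part of the
// original commit) of the "enc" tag convention consumed by
// EncryptStructFields/DecryptStructFields: fields with a named tag are
// JSON-marshalled and sealed, and the single "*"-tagged []byte field holds the
// resulting [nonce][ciphertext]. odata is bound as AEAD associated data, so
// the blob only decrypts with the same odata string. String fields are used
// here because decrypted values come back through a map[string]interface{}.
type encExample struct {
	Name  string `enc:"name"`
	Owner string `enc:"owner"`
	Data  []byte `enc:"*"`
}

func exampleRoundTrip() error {
	enc, err := MakeRandomEncryptor()
	if err != nil {
		return err
	}
	v := &encExample{Name: "prompt", Owner: "alice"}
	if err := enc.EncryptStructFields(v, "example-odata"); err != nil {
		return err
	}
	out := &encExample{Data: v.Data}
	if err := enc.DecryptStructFields(out, "example-odata"); err != nil {
		return err
	}
	// out.Name == "prompt" && out.Owner == "alice"
	return nil
}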
|
53
wavesrv/pkg/remote/circlelog.go
Normal file
@ -0,0 +1,53 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type CircleLog struct {
|
||||
Lock *sync.Mutex
|
||||
StartPos int
|
||||
Log []string
|
||||
MaxSize int
|
||||
}
|
||||
|
||||
func MakeCircleLog(maxSize int) *CircleLog {
|
||||
if maxSize <= 0 {
|
||||
panic("invalid maxsize, must be >= 0")
|
||||
}
|
||||
rtn := &CircleLog{
|
||||
Lock: &sync.Mutex{},
|
||||
StartPos: 0,
|
||||
Log: make([]string, 0, maxSize),
|
||||
MaxSize: maxSize,
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (l *CircleLog) Add(s string) {
|
||||
l.Lock.Lock()
|
||||
defer l.Lock.Unlock()
|
||||
if len(l.Log) < l.MaxSize {
|
||||
l.Log = append(l.Log, s)
|
||||
return
|
||||
}
|
||||
l.Log[l.StartPos] = s
|
||||
l.StartPos = (l.StartPos + 1) % l.MaxSize
|
||||
}
|
||||
|
||||
func (l *CircleLog) Addf(sfmt string, args ...interface{}) {
|
||||
// no lock here, since l.Add() is synchronized
|
||||
s := fmt.Sprintf(sfmt, args...)
|
||||
l.Add(s)
|
||||
}
|
||||
|
||||
func (l *CircleLog) GetEntries() []string {
|
||||
l.Lock.Lock()
|
||||
defer l.Lock.Unlock()
|
||||
rtn := make([]string, len(l.Log))
|
||||
for i := 0; i < len(l.Log); i++ {
|
||||
rtn[i] = l.Log[(l.StartPos+i)%l.MaxSize]
|
||||
}
|
||||
return rtn
|
||||
}
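
// exampleCircleLog is an illustrative sketch (not part of the original
// commit): the log keeps the last MaxSize entries, overwriting the oldest
// once full, and GetEntries returns them oldest-first.
func exampleCircleLog() {
	clog := MakeCircleLog(3)
	clog.Add("a")
	clog.Add("b")
	clog.Addf("c-%d", 1)
	clog.Add("d")                         // overwrites "a"
	fmt.Printf("%v\n", clog.GetEntries()) // [b c-1 d]
}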
|
147
wavesrv/pkg/remote/openai/openai.go
Normal file
@ -0,0 +1,147 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/prompt-server/pkg/sstore"
|
||||
)
|
||||
|
||||
// https://github.com/tiktoken-go/tokenizer
|
||||
|
||||
const DefaultMaxTokens = 1000
|
||||
const DefaultModel = "gpt-3.5-turbo"
|
||||
const DefaultStreamChanSize = 10
|
||||
|
||||
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
|
||||
if resp.Usage.TotalTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
return &packet.OpenAIUsageType{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
||||
var rtn []openaiapi.ChatCompletionMessage
|
||||
for _, p := range prompt {
|
||||
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
||||
rtn = append(rtn, msg)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("no openai opts found")
|
||||
}
|
||||
if opts.Model == "" {
|
||||
return nil, fmt.Errorf("no openai model specified")
|
||||
}
|
||||
if opts.APIToken == "" {
|
||||
return nil, fmt.Errorf("no api token")
|
||||
}
|
||||
client := openaiapi.NewClient(opts.APIToken)
|
||||
req := openaiapi.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Messages: convertPrompt(prompt),
|
||||
MaxTokens: opts.MaxTokens,
|
||||
}
|
||||
if opts.MaxChoices > 1 {
|
||||
req.N = opts.MaxChoices
|
||||
}
|
||||
apiResp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calling openai API: %v", err)
|
||||
}
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no response received")
|
||||
}
|
||||
return marshalResponse(apiResp), nil
|
||||
}
|
||||
|
||||
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("no openai opts found")
|
||||
}
|
||||
if opts.Model == "" {
|
||||
return nil, fmt.Errorf("no openai model specified")
|
||||
}
|
||||
if opts.APIToken == "" {
|
||||
return nil, fmt.Errorf("no api token")
|
||||
}
|
||||
client := openaiapi.NewClient(opts.APIToken)
|
||||
req := openaiapi.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Messages: convertPrompt(prompt),
|
||||
MaxTokens: opts.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
if opts.MaxChoices > 1 {
|
||||
req.N = opts.MaxChoices
|
||||
}
|
||||
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calling openai API: %v", err)
|
||||
}
|
||||
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
|
||||
go func() {
|
||||
sentHeader := false
|
||||
defer close(rtn)
|
||||
for {
|
||||
streamResp, err := apiResp.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
|
||||
rtn <- errPk
|
||||
break
|
||||
}
|
||||
if streamResp.Model != "" && !sentHeader {
|
||||
pk := packet.MakeOpenAIPacket()
|
||||
pk.Model = streamResp.Model
|
||||
pk.Created = streamResp.Created
|
||||
rtn <- pk
|
||||
sentHeader = true
|
||||
}
|
||||
for _, choice := range streamResp.Choices {
|
||||
pk := packet.MakeOpenAIPacket()
|
||||
pk.Index = choice.Index
|
||||
pk.Text = choice.Delta.Content
|
||||
pk.FinishReason = choice.FinishReason
|
||||
rtn <- pk
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
|
||||
var rtn []*packet.OpenAIPacketType
|
||||
headerPk := packet.MakeOpenAIPacket()
|
||||
headerPk.Model = resp.Model
|
||||
headerPk.Created = resp.Created
|
||||
headerPk.Usage = convertUsage(resp)
|
||||
rtn = append(rtn, headerPk)
|
||||
for _, choice := range resp.Choices {
|
||||
choicePk := packet.MakeOpenAIPacket()
|
||||
choicePk.Index = choice.Index
|
||||
choicePk.Text = choice.Message.Content
|
||||
choicePk.FinishReason = choice.FinishReason
|
||||
rtn = append(rtn, choicePk)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
|
||||
errPk := packet.MakeOpenAIPacket()
|
||||
errPk.FinishReason = "error"
|
||||
errPk.Error = errStr
|
||||
return errPk
|
||||
}
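
// exampleStream is an illustrative sketch (not part of the original commit)
// showing how a caller consumes RunCompletionStream. The channel is closed by
// the reader goroutine when the stream ends, and errors arrive in-band as
// packets with a non-empty Error field. The token and prompt below are
// hypothetical placeholders.
func exampleStream(ctx context.Context) error {
	opts := &sstore.OpenAIOptsType{Model: DefaultModel, APIToken: "sk-placeholder", MaxTokens: DefaultMaxTokens}
	prompt := []sstore.OpenAIPromptMessageType{{Role: "user", Content: "say hello"}}
	ch, err := RunCompletionStream(ctx, opts, prompt)
	if err != nil {
		return err
	}
	for pk := range ch {
		if pk.Error != "" {
			return fmt.Errorf("stream error: %s", pk.Error)
		}
		fmt.Print(pk.Text)
	}
	return nil
}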
|
2104
wavesrv/pkg/remote/remote.go
Normal file
File diff suppressed because it is too large
67
wavesrv/pkg/remote/updatequeue.go
Normal file
@ -0,0 +1,67 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
)
|
||||
|
||||
func startCmdWait(ck base.CommandKey) {
|
||||
GlobalStore.Lock.Lock()
|
||||
defer GlobalStore.Lock.Unlock()
|
||||
GlobalStore.CmdWaitMap[ck] = nil
|
||||
}
|
||||
|
||||
func pushCmdWaitIfRequired(ck base.CommandKey, fn func()) bool {
|
||||
GlobalStore.Lock.Lock()
|
||||
defer GlobalStore.Lock.Unlock()
|
||||
fns, ok := GlobalStore.CmdWaitMap[ck]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
fns = append(fns, fn)
|
||||
GlobalStore.CmdWaitMap[ck] = fns
|
||||
return true
|
||||
}
|
||||
|
||||
func runCmdUpdateFn(ck base.CommandKey, fn func()) {
|
||||
pushed := pushCmdWaitIfRequired(ck, fn)
|
||||
if pushed {
|
||||
return
|
||||
}
|
||||
fn()
|
||||
}
|
||||
|
||||
func runCmdWaitFns(ck base.CommandKey) {
|
||||
for {
|
||||
fn := removeFirstCmdWaitFn(ck)
|
||||
if fn == nil {
|
||||
break
|
||||
}
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
func removeFirstCmdWaitFn(ck base.CommandKey) func() {
|
||||
GlobalStore.Lock.Lock()
|
||||
defer GlobalStore.Lock.Unlock()
|
||||
|
||||
fns := GlobalStore.CmdWaitMap[ck]
|
||||
if len(fns) == 0 {
|
||||
delete(GlobalStore.CmdWaitMap, ck)
|
||||
return nil
|
||||
}
|
||||
fn := fns[0]
|
||||
GlobalStore.CmdWaitMap[ck] = fns[1:]
|
||||
return fn
|
||||
}
|
||||
|
||||
func removeCmdWait(ck base.CommandKey) {
|
||||
GlobalStore.Lock.Lock()
|
||||
defer GlobalStore.Lock.Unlock()
|
||||
|
||||
fns := GlobalStore.CmdWaitMap[ck]
|
||||
if len(fns) == 0 {
|
||||
delete(GlobalStore.CmdWaitMap, ck)
|
||||
return
|
||||
}
|
||||
go runCmdWaitFns(ck)
|
||||
}
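
// exampleCmdWaitFlow is an illustrative sketch (not part of the original
// commit) of the intended lifecycle, assuming GlobalStore has been
// initialized by the remote package: startCmdWait registers a command key so
// update callbacks are buffered instead of run immediately, and removeCmdWait
// later flushes the buffered callbacks in order on a separate goroutine.
func exampleCmdWaitFlow(ck base.CommandKey) {
	done := make(chan bool, 2)
	startCmdWait(ck)                            // start buffering callbacks for this command
	runCmdUpdateFn(ck, func() { done <- true }) // buffered, not run yet
	runCmdUpdateFn(ck, func() { done <- true }) // buffered
	removeCmdWait(ck)                           // flushes both callbacks via runCmdWaitFns
	<-done
	<-done
}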
|
196
wavesrv/pkg/rtnstate/rtnstate.go
Normal file
@ -0,0 +1,196 @@
|
||||
package rtnstate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/apishell/pkg/shexec"
|
||||
"github.com/commandlinedev/apishell/pkg/simpleexpand"
|
||||
"github.com/commandlinedev/prompt-server/pkg/sstore"
|
||||
"github.com/commandlinedev/prompt-server/pkg/utilfn"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
func parseAliasStmt(stmt *syntax.Stmt, sourceStr string) (string, string, error) {
|
||||
cmd := stmt.Cmd
|
||||
callExpr, ok := cmd.(*syntax.CallExpr)
|
||||
if !ok {
|
||||
return "", "", fmt.Errorf("wrong cmd type for alias")
|
||||
}
|
||||
if len(callExpr.Args) != 2 {
|
||||
return "", "", fmt.Errorf("wrong number of words in alias expr wordslen=%d", len(callExpr.Args))
|
||||
}
|
||||
firstWord := callExpr.Args[0]
|
||||
if firstWord.Lit() != "alias" {
|
||||
return "", "", fmt.Errorf("invalid alias cmd word (not 'alias')")
|
||||
}
|
||||
secondWord := callExpr.Args[1]
|
||||
var ectx simpleexpand.SimpleExpandContext // no homedir, do not want ~ expansion
|
||||
val, _ := simpleexpand.SimpleExpandWord(ectx, secondWord, sourceStr)
|
||||
eqIdx := strings.Index(val, "=")
|
||||
if eqIdx == -1 {
|
||||
return "", "", fmt.Errorf("no '=' in alias definition")
|
||||
}
|
||||
return val[0:eqIdx], val[eqIdx+1:], nil
|
||||
}
|
||||
|
||||
func ParseAliases(aliases string) (map[string]string, error) {
|
||||
r := strings.NewReader(aliases)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "aliases")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
for _, stmt := range file.Stmts {
|
||||
aliasName, aliasVal, err := parseAliasStmt(stmt, aliases)
|
||||
if err != nil {
|
||||
// fmt.Printf("stmt-err: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if aliasName != "" {
|
||||
rtn[aliasName] = aliasVal
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
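
// exampleParseAliases is an illustrative sketch (not part of the original
// commit): ParseAliases takes the raw text of `alias` definitions (as emitted
// by bash) and returns name -> value pairs; statements that are not simple
// `alias name=value` forms are skipped.
func exampleParseAliases() {
	src := "alias ll='ls -l'\nalias gs='git status'\n"
	m, err := ParseAliases(src)
	if err != nil {
		return
	}
	fmt.Printf("ll=%q gs=%q\n", m["ll"], m["gs"]) // typically: ll="ls -l" gs="git status"
}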
|
||||
|
||||
func parseFuncStmt(stmt *syntax.Stmt, source string) (string, string, error) {
|
||||
cmd := stmt.Cmd
|
||||
funcDecl, ok := cmd.(*syntax.FuncDecl)
|
||||
if !ok {
|
||||
return "", "", fmt.Errorf("cmd not FuncDecl")
|
||||
}
|
||||
name := funcDecl.Name.Value
|
||||
// fmt.Printf("func: [%s]\n", name)
|
||||
funcBody := funcDecl.Body
|
||||
// fmt.Printf(" %d:%d\n", funcBody.Cmd.Pos().Offset(), funcBody.Cmd.End().Offset())
|
||||
bodyStr := source[funcBody.Cmd.Pos().Offset():funcBody.Cmd.End().Offset()]
|
||||
// fmt.Printf("<<<\n%s\n>>>\n", bodyStr)
|
||||
// fmt.Printf("\n")
|
||||
return name, bodyStr, nil
|
||||
}
|
||||
|
||||
func ParseFuncs(funcs string) (map[string]string, error) {
|
||||
r := strings.NewReader(funcs)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "funcs")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
for _, stmt := range file.Stmts {
|
||||
funcName, funcVal, err := parseFuncStmt(stmt, funcs)
|
||||
if err != nil {
|
||||
// TODO where to put parse errors
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(funcName, "_mshell_") {
|
||||
continue
|
||||
}
|
||||
if funcName != "" {
|
||||
rtn[funcName] = funcVal
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
const MaxDiffKeyLen = 40
|
||||
const MaxDiffValLen = 50
|
||||
|
||||
var IgnoreVars = map[string]bool{"PROMPT": true, "PROMPT_VERSION": true, "MSHELL": true}
|
||||
|
||||
func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newState packet.ShellState) {
|
||||
if newState.Cwd != oldState.Cwd {
|
||||
buf.WriteString(fmt.Sprintf("cwd %s\n", newState.Cwd))
|
||||
}
|
||||
if !bytes.Equal(newState.ShellVars, oldState.ShellVars) {
|
||||
newEnvMap := shexec.DeclMapFromState(&newState)
|
||||
oldEnvMap := shexec.DeclMapFromState(&oldState)
|
||||
for key, newVal := range newEnvMap {
|
||||
if IgnoreVars[key] {
|
||||
continue
|
||||
}
|
||||
oldVal, found := oldEnvMap[key]
|
||||
if !found || !shexec.DeclsEqual(false, oldVal, newVal) {
|
||||
var exportStr string
|
||||
if newVal.IsExport() {
|
||||
exportStr = "export "
|
||||
}
|
||||
buf.WriteString(fmt.Sprintf("%s%s=%s\n", exportStr, utilfn.EllipsisStr(key, MaxDiffKeyLen), utilfn.EllipsisStr(newVal.Value, MaxDiffValLen)))
|
||||
}
|
||||
}
|
||||
for key := range oldEnvMap {
|
||||
if IgnoreVars[key] {
|
||||
continue
|
||||
}
|
||||
_, found := newEnvMap[key]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unset %s\n", utilfn.EllipsisStr(key, MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
}
|
||||
if newState.Aliases != oldState.Aliases {
|
||||
newAliasMap, _ := ParseAliases(newState.Aliases)
|
||||
oldAliasMap, _ := ParseAliases(oldState.Aliases)
|
||||
for aliasName, newAliasVal := range newAliasMap {
|
||||
oldAliasVal, found := oldAliasMap[aliasName]
|
||||
if !found || newAliasVal != oldAliasVal {
|
||||
buf.WriteString(fmt.Sprintf("alias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
for aliasName := range oldAliasMap {
|
||||
_, found := newAliasMap[aliasName]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unalias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
}
|
||||
if newState.Funcs != oldState.Funcs {
|
||||
newFuncMap, _ := ParseFuncs(newState.Funcs)
|
||||
oldFuncMap, _ := ParseFuncs(oldState.Funcs)
|
||||
for funcName, newFuncVal := range newFuncMap {
|
||||
oldFuncVal, found := oldFuncMap[funcName]
|
||||
if !found || newFuncVal != oldFuncVal {
|
||||
buf.WriteString(fmt.Sprintf("function %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
for funcName := range oldFuncMap {
|
||||
_, found := newFuncMap[funcName]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unset -f %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetRtnStateDiff(ctx context.Context, screenId string, lineId string) ([]byte, error) {
|
||||
cmd, err := sstore.GetCmdByScreenId(ctx, screenId, lineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cmd == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if !cmd.RtnState {
|
||||
return nil, nil
|
||||
}
|
||||
if cmd.RtnStatePtr.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var outputBytes bytes.Buffer
|
||||
initialState, err := sstore.GetFullState(ctx, cmd.StatePtr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting initial full state: %v", err)
|
||||
}
|
||||
rtnState, err := sstore.GetFullState(ctx, cmd.RtnStatePtr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting rtn full state: %v", err)
|
||||
}
|
||||
displayStateUpdateDiff(&outputBytes, *initialState, *rtnState)
|
||||
return outputBytes.Bytes(), nil
|
||||
}
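
// Illustrative note (not part of the original commit): the diff returned by
// GetRtnStateDiff is the plain-text summary built by displayStateUpdateDiff,
// one change per line, for example (values hypothetical):
//
//	cwd /home/user/project
//	export EDITOR=vim
//	unset TMPVAR
//	alias ll
//	unalias gs
//	function deploy
//	unset -f oldfn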
|
407
wavesrv/pkg/scbase/scbase.go
Normal file
@ -0,0 +1,407 @@
|
||||
package scbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const HomeVarName = "HOME"
|
||||
const PromptHomeVarName = "PROMPT_HOME"
|
||||
const PromptDevVarName = "PROMPT_DEV"
|
||||
const SessionsDirBaseName = "sessions"
|
||||
const ScreensDirBaseName = "screens"
|
||||
const PromptLockFile = "prompt.lock"
|
||||
const PromptDirName = "prompt"
|
||||
const PromptDevDirName = "prompt-dev"
|
||||
const PromptAppPathVarName = "PROMPT_APP_PATH"
|
||||
const PromptVersion = "v0.4.0"
|
||||
const PromptAuthKeyFileName = "prompt.authkey"
|
||||
const MShellVersion = "v0.3.0"
|
||||
const DefaultMacOSShell = "/bin/bash"
|
||||
|
||||
var SessionDirCache = make(map[string]string)
|
||||
var ScreenDirCache = make(map[string]string)
|
||||
var BaseLock = &sync.Mutex{}
|
||||
var BuildTime = "-"
|
||||
|
||||
func IsDevMode() bool {
|
||||
pdev := os.Getenv(PromptDevVarName)
|
||||
return pdev != ""
|
||||
}
|
||||
|
||||
// must match js
|
||||
func GetPromptHomeDir() string {
|
||||
scHome := os.Getenv(PromptHomeVarName)
|
||||
if scHome == "" {
|
||||
homeVar := os.Getenv(HomeVarName)
|
||||
if homeVar == "" {
|
||||
homeVar = "/"
|
||||
}
|
||||
pdev := os.Getenv(PromptDevVarName)
|
||||
if pdev != "" {
|
||||
scHome = path.Join(homeVar, PromptDevDirName)
|
||||
} else {
|
||||
scHome = path.Join(homeVar, PromptDirName)
|
||||
}
|
||||
|
||||
}
|
||||
return scHome
|
||||
}
|
||||
|
||||
func MShellBinaryDir() string {
|
||||
appPath := os.Getenv(PromptAppPathVarName)
|
||||
if appPath == "" {
|
||||
appPath = "."
|
||||
}
|
||||
if IsDevMode() {
|
||||
return path.Join(appPath, "dev-bin")
|
||||
}
|
||||
return path.Join(appPath, "bin", "mshell")
|
||||
}
|
||||
|
||||
func MShellBinaryPath(version string, goos string, goarch string) (string, error) {
|
||||
if !base.ValidGoArch(goos, goarch) {
|
||||
return "", fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
|
||||
}
|
||||
binaryDir := MShellBinaryDir()
|
||||
versionStr := semver.MajorMinor(version)
|
||||
if versionStr == "" {
|
||||
return "", fmt.Errorf("invalid mshell version: %q", version)
|
||||
}
|
||||
fileName := fmt.Sprintf("mshell-%s-%s.%s", versionStr, goos, goarch)
|
||||
fullFileName := path.Join(binaryDir, fileName)
|
||||
return fullFileName, nil
|
||||
}
|
||||
|
||||
func LocalMShellBinaryPath() (string, error) {
|
||||
return MShellBinaryPath(MShellVersion, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
func MShellBinaryReader(version string, goos string, goarch string) (io.ReadCloser, error) {
|
||||
mshellPath, err := MShellBinaryPath(version, goos, goarch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fd, err := os.Open(mshellPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot open mshell binary %q: %v", mshellPath, err)
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func createPromptAuthKeyFile(fileName string) (string, error) {
|
||||
fd, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer fd.Close()
|
||||
keyStr := GenPromptUUID()
|
||||
_, err = fd.Write([]byte(keyStr))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return keyStr, nil
|
||||
}
|
||||
|
||||
func ReadPromptAuthKey() (string, error) {
|
||||
homeDir := GetPromptHomeDir()
|
||||
err := ensureDir(homeDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot find/create PROMPT_HOME directory %q", homeDir)
|
||||
}
|
||||
fileName := path.Join(homeDir, PromptAuthKeyFileName)
|
||||
fd, err := os.Open(fileName)
|
||||
if err != nil && errors.Is(err, fs.ErrNotExist) {
|
||||
return createPromptAuthKeyFile(fileName)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error opening prompt authkey:%s: %v", fileName, err)
|
||||
}
|
||||
defer fd.Close()
|
||||
buf, err := io.ReadAll(fd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading prompt authkey:%s: %v", fileName, err)
|
||||
}
|
||||
keyStr := string(buf)
|
||||
_, err = uuid.Parse(keyStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid authkey:%s format: %v", fileName, err)
|
||||
}
|
||||
return keyStr, nil
|
||||
}
|
||||
|
||||
func AcquirePromptLock() (*os.File, error) {
|
||||
homeDir := GetPromptHomeDir()
|
||||
err := ensureDir(homeDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find/create PROMPT_HOME directory %q", homeDir)
|
||||
}
|
||||
lockFileName := path.Join(homeDir, PromptLockFile)
|
||||
fd, err := os.Create(lockFileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = unix.Flock(int(fd.Fd()), unix.LOCK_EX|unix.LOCK_NB)
|
||||
if err != nil {
|
||||
fd.Close()
|
||||
return nil, err
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// deprecated (v0.1.8)
|
||||
func EnsureSessionDir(sessionId string) (string, error) {
|
||||
if sessionId == "" {
|
||||
return "", fmt.Errorf("cannot get session dir for blank sessionid")
|
||||
}
|
||||
BaseLock.Lock()
|
||||
sdir, ok := SessionDirCache[sessionId]
|
||||
BaseLock.Unlock()
|
||||
if ok {
|
||||
return sdir, nil
|
||||
}
|
||||
scHome := GetPromptHomeDir()
|
||||
sdir = path.Join(scHome, SessionsDirBaseName, sessionId)
|
||||
err := ensureDir(sdir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
BaseLock.Lock()
|
||||
SessionDirCache[sessionId] = sdir
|
||||
BaseLock.Unlock()
|
||||
return sdir, nil
|
||||
}
|
||||
|
||||
// deprecated (v0.1.8)
|
||||
func GetSessionsDir() string {
|
||||
promptHome := GetPromptHomeDir()
|
||||
sdir := path.Join(promptHome, SessionsDirBaseName)
|
||||
return sdir
|
||||
}
|
||||
|
||||
func EnsureScreenDir(screenId string) (string, error) {
|
||||
if screenId == "" {
|
||||
return "", fmt.Errorf("cannot get screen dir for blank sessionid")
|
||||
}
|
||||
BaseLock.Lock()
|
||||
sdir, ok := ScreenDirCache[screenId]
|
||||
BaseLock.Unlock()
|
||||
if ok {
|
||||
return sdir, nil
|
||||
}
|
||||
scHome := GetPromptHomeDir()
|
||||
sdir = path.Join(scHome, ScreensDirBaseName, screenId)
|
||||
err := ensureDir(sdir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
BaseLock.Lock()
|
||||
ScreenDirCache[screenId] = sdir
|
||||
BaseLock.Unlock()
|
||||
return sdir, nil
|
||||
}
|
||||
|
||||
func GetScreensDir() string {
|
||||
promptHome := GetPromptHomeDir()
|
||||
sdir := path.Join(promptHome, ScreensDirBaseName)
|
||||
return sdir
|
||||
}
|
||||
|
||||
func ensureDir(dirName string) error {
|
||||
info, err := os.Stat(dirName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
err = os.MkdirAll(dirName, 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[prompt] created directory %q\n", dirName)
|
||||
info, err = os.Stat(dirName)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("'%s' must be a directory", dirName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deprecated (v0.1.8)
|
||||
func PtyOutFile_Sessions(sessionId string, cmdId string) (string, error) {
	if sessionId == "" {
		return "", fmt.Errorf("cannot get ptyout file for blank sessionid")
	}
	if cmdId == "" {
		return "", fmt.Errorf("cannot get ptyout file for blank cmdid")
	}
	sdir, err := EnsureSessionDir(sessionId)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%s/%s.ptyout.cf", sdir, cmdId), nil
}
|
||||
|
||||
func PtyOutFile(screenId string, lineId string) (string, error) {
	if screenId == "" {
		return "", fmt.Errorf("cannot get ptyout file for blank screenid")
	}
	if lineId == "" {
		return "", fmt.Errorf("cannot get ptyout file for blank lineid")
	}
	sdir, err := EnsureScreenDir(screenId)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%s/%s.ptyout.cf", sdir, lineId), nil
}
|
||||
|
||||
func GenPromptUUID() string {
|
||||
for {
|
||||
rtn := uuid.New().String()
|
||||
_, err := strconv.Atoi(rtn[0:8])
|
||||
if err == nil { // do not allow UUIDs where the initial 8 characters parse as an integer
|
||||
continue
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
}
|
||||
|
||||
func NumFormatDec(num int64) string {
|
||||
var signStr string
|
||||
absNum := num
|
||||
if absNum < 0 {
|
||||
absNum = -absNum
|
||||
signStr = "-"
|
||||
}
|
||||
if absNum < 1000 {
|
||||
// raw num
|
||||
return signStr + strconv.FormatInt(absNum, 10)
|
||||
}
|
||||
if absNum < 1000000 {
|
||||
// k num
|
||||
kVal := float64(absNum) / 1000
|
||||
return signStr + strconv.FormatFloat(kVal, 'f', 2, 64) + "k"
|
||||
}
|
||||
if absNum < 1000000000 {
|
||||
// M num
|
||||
mVal := float64(absNum) / 1000000
|
||||
return signStr + strconv.FormatFloat(mVal, 'f', 2, 64) + "m"
|
||||
} else {
|
||||
// G num
|
||||
gVal := float64(absNum) / 1000000000
|
||||
return signStr + strconv.FormatFloat(gVal, 'f', 2, 64) + "g"
|
||||
}
|
||||
}
|
||||
|
||||
func NumFormatB2(num int64) string {
|
||||
var signStr string
|
||||
absNum := num
|
||||
if absNum < 0 {
|
||||
absNum = -absNum
|
||||
signStr = "-"
|
||||
}
|
||||
if absNum < 1024 {
|
||||
// raw num
|
||||
return signStr + strconv.FormatInt(absNum, 10)
|
||||
}
|
||||
if absNum < 1000000 {
|
||||
// k num
|
||||
if absNum%1024 == 0 {
|
||||
return signStr + strconv.FormatInt(absNum/1024, 10) + "K"
|
||||
}
|
||||
kVal := float64(absNum) / 1024
|
||||
return signStr + strconv.FormatFloat(kVal, 'f', 2, 64) + "K"
|
||||
}
|
||||
if absNum < 1000000000 {
|
||||
// M num
|
||||
if absNum%(1024*1024) == 0 {
|
||||
return signStr + strconv.FormatInt(absNum/(1024*1024), 10) + "M"
|
||||
}
|
||||
mVal := float64(absNum) / (1024 * 1024)
|
||||
return signStr + strconv.FormatFloat(mVal, 'f', 2, 64) + "M"
|
||||
} else {
|
||||
// G num
|
||||
if absNum%(1024*1024*1024) == 0 {
|
||||
return signStr + strconv.FormatInt(absNum/(1024*1024*1024), 10) + "G"
|
||||
}
|
||||
gVal := float64(absNum) / (1024 * 1024 * 1024)
|
||||
return signStr + strconv.FormatFloat(gVal, 'f', 2, 64) + "G"
|
||||
}
|
||||
}
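
// exampleNumFormat is an illustrative sketch (not part of the original
// commit) of sample outputs: NumFormatDec uses decimal thousands with
// lowercase suffixes, NumFormatB2 uses 1024-based units with uppercase
// suffixes and drops the fraction when the value divides evenly.
func exampleNumFormat() {
	fmt.Println(NumFormatDec(950))     // 950
	fmt.Println(NumFormatDec(1500))    // 1.50k
	fmt.Println(NumFormatDec(2500000)) // 2.50m
	fmt.Println(NumFormatB2(2048))     // 2K
	fmt.Println(NumFormatB2(1536))     // 1.50K
	fmt.Println(NumFormatB2(-3145728)) // -3M
}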
|
||||
|
||||
func ClientArch() string {
|
||||
return fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
var releaseRegex = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
|
||||
var osReleaseOnce = &sync.Once{}
|
||||
var osRelease string
|
||||
|
||||
func macOSRelease() string {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
out, err := exec.CommandContext(ctx, "uname", "-r").CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("error executing uname -r: %v\n", err)
|
||||
return "-"
|
||||
}
|
||||
releaseStr := strings.TrimSpace(string(out))
|
||||
if !releaseRegex.MatchString(releaseStr) {
|
||||
log.Printf("invalid uname -r output: [%s]\n", releaseStr)
|
||||
return "-"
|
||||
}
|
||||
return releaseStr
|
||||
}
|
||||
|
||||
func MacOSRelease() string {
|
||||
osReleaseOnce.Do(func() {
|
||||
osRelease = macOSRelease()
|
||||
})
|
||||
return osRelease
|
||||
}
|
||||
|
||||
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
|
||||
|
||||
// dscl . -read /Users/[username] UserShell
|
||||
// defaults to /bin/bash
|
||||
func MacUserShell() string {
|
||||
osUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Printf("error getting current user: %v\n", err)
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
userStr := "/Users/" + osUser.Name
|
||||
out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("error executing macos user shell lookup: %v %q\n", err, string(out))
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
outStr := strings.TrimSpace(string(out))
|
||||
m := userShellRegexp.FindStringSubmatch(outStr)
|
||||
if m == nil {
|
||||
log.Printf("error in format of dscl output: %q\n", outStr)
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
return m[1]
|
||||
}
|
120
wavesrv/pkg/scpacket/scpacket.go
Normal file
@ -0,0 +1,120 @@
|
||||
package scpacket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/prompt-server/pkg/sstore"
|
||||
)
|
||||
|
||||
const FeCommandPacketStr = "fecmd"
|
||||
const WatchScreenPacketStr = "watchscreen"
|
||||
const FeInputPacketStr = "feinput"
|
||||
const RemoteInputPacketStr = "remoteinput"
|
||||
|
||||
type FeCommandPacketType struct {
|
||||
Type string `json:"type"`
|
||||
MetaCmd string `json:"metacmd"`
|
||||
MetaSubCmd string `json:"metasubcmd,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Kwargs map[string]string `json:"kwargs,omitempty"`
|
||||
RawStr string `json:"rawstr,omitempty"`
|
||||
UIContext *UIContextType `json:"uicontext,omitempty"`
|
||||
Interactive bool `json:"interactive"`
|
||||
}
|
||||
|
||||
func (pk *FeCommandPacketType) GetRawStr() string {
|
||||
if pk.RawStr != "" {
|
||||
return pk.RawStr
|
||||
}
|
||||
cmd := "/" + pk.MetaCmd
|
||||
if pk.MetaSubCmd != "" {
|
||||
cmd = cmd + ":" + pk.MetaSubCmd
|
||||
}
|
||||
var args []string
|
||||
for k, v := range pk.Kwargs {
|
||||
argStr := fmt.Sprintf("%s=%s", shellescape.Quote(k), shellescape.Quote(v))
|
||||
args = append(args, argStr)
|
||||
}
|
||||
for _, arg := range pk.Args {
|
||||
args = append(args, shellescape.Quote(arg))
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return cmd
|
||||
}
|
||||
return cmd + " " + strings.Join(args, " ")
|
||||
}
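
// exampleGetRawStr is an illustrative sketch (not part of the original
// commit) of how GetRawStr reconstructs a command string when RawStr is not
// set. Kwargs is a map, so kwarg ordering in the output is not deterministic.
func exampleGetRawStr() {
	pk := MakeFeCommandPacket()
	pk.MetaCmd = "screen"
	pk.MetaSubCmd = "set"
	pk.Kwargs = map[string]string{"name": "dev work"}
	pk.Args = []string{"extra arg"}
	fmt.Println(pk.GetRawStr()) // /screen:set name='dev work' 'extra arg'
}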
|
||||
|
||||
type UIContextType struct {
|
||||
SessionId string `json:"sessionid"`
|
||||
ScreenId string `json:"screenid"`
|
||||
Remote *sstore.RemotePtrType `json:"remote,omitempty"`
|
||||
WinSize *packet.WinSize `json:"winsize,omitempty"`
|
||||
Build string `json:"build,omitempty"`
|
||||
}
|
||||
|
||||
type FeInputPacketType struct {
|
||||
Type string `json:"type"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
Remote sstore.RemotePtrType `json:"remote"`
|
||||
InputData64 string `json:"inputdata64"`
|
||||
SigName string `json:"signame,omitempty"`
|
||||
WinSize *packet.WinSize `json:"winsize,omitempty"`
|
||||
}
|
||||
|
||||
type RemoteInputPacketType struct {
|
||||
Type string `json:"type"`
|
||||
RemoteId string `json:"remoteid"`
|
||||
InputData64 string `json:"inputdata64"`
|
||||
}
|
||||
|
||||
type WatchScreenPacketType struct {
|
||||
Type string `json:"type"`
|
||||
SessionId string `json:"sessionid"`
|
||||
ScreenId string `json:"screenid"`
|
||||
Connect bool `json:"connect"`
|
||||
AuthKey string `json:"authkey"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(FeCommandPacketType{}))
|
||||
packet.RegisterPacketType(WatchScreenPacketStr, reflect.TypeOf(WatchScreenPacketType{}))
|
||||
packet.RegisterPacketType(FeInputPacketStr, reflect.TypeOf(FeInputPacketType{}))
|
||||
packet.RegisterPacketType(RemoteInputPacketStr, reflect.TypeOf(RemoteInputPacketType{}))
|
||||
}
|
||||
|
||||
func (*FeCommandPacketType) GetType() string {
|
||||
return FeCommandPacketStr
|
||||
}
|
||||
|
||||
func MakeFeCommandPacket() *FeCommandPacketType {
|
||||
return &FeCommandPacketType{Type: FeCommandPacketStr}
|
||||
}
|
||||
|
||||
func (*FeInputPacketType) GetType() string {
|
||||
return FeInputPacketStr
|
||||
}
|
||||
|
||||
func MakeFeInputPacket() *FeInputPacketType {
|
||||
return &FeInputPacketType{Type: FeInputPacketStr}
|
||||
}
|
||||
|
||||
func (*WatchScreenPacketType) GetType() string {
|
||||
return WatchScreenPacketStr
|
||||
}
|
||||
|
||||
func MakeWatchScreenPacket() *WatchScreenPacketType {
|
||||
return &WatchScreenPacketType{Type: WatchScreenPacketStr}
|
||||
}
|
||||
|
||||
func MakeRemoteInputPacket() *RemoteInputPacketType {
|
||||
return &RemoteInputPacketType{Type: RemoteInputPacketStr}
|
||||
}
|
||||
|
||||
func (*RemoteInputPacketType) GetType() string {
|
||||
return RemoteInputPacketStr
|
||||
}
|
305
wavesrv/pkg/scws/scws.go
Normal file
@ -0,0 +1,305 @@
|
||||
package scws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/prompt-server/pkg/mapqueue"
|
||||
"github.com/commandlinedev/prompt-server/pkg/remote"
|
||||
"github.com/commandlinedev/prompt-server/pkg/scpacket"
|
||||
"github.com/commandlinedev/prompt-server/pkg/sstore"
|
||||
"github.com/commandlinedev/prompt-server/pkg/wsshell"
|
||||
)
|
||||
|
||||
const WSStatePacketChSize = 20
|
||||
const MaxInputDataSize = 1000
|
||||
const RemoteInputQueueSize = 100
|
||||
|
||||
var RemoteInputMapQueue *mapqueue.MapQueue
|
||||
|
||||
func init() {
|
||||
RemoteInputMapQueue = mapqueue.MakeMapQueue(RemoteInputQueueSize)
|
||||
}
|
||||
|
||||
type WSState struct {
|
||||
Lock *sync.Mutex
|
||||
ClientId string
|
||||
ConnectTime time.Time
|
||||
Shell *wsshell.WSShell
|
||||
UpdateCh chan interface{}
|
||||
UpdateQueue []interface{}
|
||||
Authenticated bool
|
||||
AuthKey string
|
||||
|
||||
SessionId string
|
||||
ScreenId string
|
||||
}
|
||||
|
||||
func MakeWSState(clientId string, authKey string) *WSState {
|
||||
rtn := &WSState{}
|
||||
rtn.Lock = &sync.Mutex{}
|
||||
rtn.ClientId = clientId
|
||||
rtn.ConnectTime = time.Now()
|
||||
rtn.AuthKey = authKey
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (ws *WSState) SetAuthenticated(authVal bool) {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
ws.Authenticated = authVal
|
||||
}
|
||||
|
||||
func (ws *WSState) IsAuthenticated() bool {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
return ws.Authenticated
|
||||
}
|
||||
|
||||
func (ws *WSState) GetShell() *wsshell.WSShell {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
return ws.Shell
|
||||
}
|
||||
|
||||
func (ws *WSState) WriteUpdate(update interface{}) error {
|
||||
shell := ws.GetShell()
|
||||
if shell == nil {
|
||||
return fmt.Errorf("cannot write update, empty shell")
|
||||
}
|
||||
err := shell.WriteJson(update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WSState) UpdateConnectTime() {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
ws.ConnectTime = time.Now()
|
||||
}
|
||||
|
||||
func (ws *WSState) GetConnectTime() time.Time {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
return ws.ConnectTime
|
||||
}
|
||||
|
||||
func (ws *WSState) WatchScreen(sessionId string, screenId string) {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
if ws.SessionId == sessionId && ws.ScreenId == screenId {
|
||||
return
|
||||
}
|
||||
ws.SessionId = sessionId
|
||||
ws.ScreenId = screenId
|
||||
ws.UpdateCh = sstore.MainBus.RegisterChannel(ws.ClientId, ws.ScreenId)
|
||||
go ws.RunUpdates(ws.UpdateCh)
|
||||
}
|
||||
|
||||
func (ws *WSState) UnWatchScreen() {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
sstore.MainBus.UnregisterChannel(ws.ClientId)
|
||||
ws.SessionId = ""
|
||||
ws.ScreenId = ""
|
||||
log.Printf("[ws] unwatch screen clientid=%s\n", ws.ClientId)
|
||||
}
|
||||
|
||||
func (ws *WSState) getUpdateCh() chan interface{} {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
return ws.UpdateCh
|
||||
}
|
||||
|
||||
func (ws *WSState) RunUpdates(updateCh chan interface{}) {
|
||||
if updateCh == nil {
|
||||
panic("invalid nil updateCh passed to RunUpdates")
|
||||
}
|
||||
for update := range updateCh {
|
||||
shell := ws.GetShell()
|
||||
if shell != nil {
|
||||
shell.WriteJson(update)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WSState) ReplaceShell(shell *wsshell.WSShell) {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
if ws.Shell == nil {
|
||||
ws.Shell = shell
|
||||
return
|
||||
}
|
||||
ws.Shell.Conn.Close()
|
||||
ws.Shell = shell
|
||||
return
|
||||
}
|
||||
|
||||
func (ws *WSState) handleConnection() error {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
update, err := sstore.GetAllSessions(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting sessions: %w", err)
|
||||
}
|
||||
remotes := remote.GetAllRemoteRuntimeState()
|
||||
ifarr := make([]interface{}, len(remotes))
|
||||
for idx, r := range remotes {
|
||||
ifarr[idx] = r
|
||||
}
|
||||
update.Remotes = ifarr
|
||||
update.Connect = true
|
||||
err = ws.Shell.WriteJson(update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error {
|
||||
if wsPk.SessionId != "" {
|
||||
if _, err := uuid.Parse(wsPk.SessionId); err != nil {
|
||||
return fmt.Errorf("invalid watchscreen sessionid: %w", err)
|
||||
}
|
||||
}
|
||||
if wsPk.ScreenId != "" {
|
||||
if _, err := uuid.Parse(wsPk.ScreenId); err != nil {
|
||||
return fmt.Errorf("invalid watchscreen screenid: %w", err)
|
||||
}
|
||||
}
|
||||
if wsPk.AuthKey == "" {
|
||||
ws.SetAuthenticated(false)
|
||||
return fmt.Errorf("invalid watchscreen, no authkey")
|
||||
}
|
||||
if wsPk.AuthKey != ws.AuthKey {
|
||||
ws.SetAuthenticated(false)
|
||||
return fmt.Errorf("invalid watchscreen, invalid authkey")
|
||||
}
|
||||
ws.SetAuthenticated(true)
|
||||
if wsPk.SessionId == "" || wsPk.ScreenId == "" {
|
||||
ws.UnWatchScreen()
|
||||
} else {
|
||||
ws.WatchScreen(wsPk.SessionId, wsPk.ScreenId)
|
||||
log.Printf("[ws %s] watchscreen %s/%s\n", ws.ClientId, wsPk.SessionId, wsPk.ScreenId)
|
||||
}
|
||||
if wsPk.Connect {
|
||||
// log.Printf("[ws %s] watchscreen connect\n", ws.ClientId)
|
||||
err := ws.handleConnection()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WSState) RunWSRead() {
|
||||
shell := ws.GetShell()
|
||||
if shell == nil {
|
||||
return
|
||||
}
|
||||
shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error
|
||||
for msgBytes := range shell.ReadChan {
|
||||
pk, err := packet.ParseJsonPacket(msgBytes)
|
||||
if err != nil {
|
||||
log.Printf("error unmarshalling ws message: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.WatchScreenPacketStr {
|
||||
wsPk := pk.(*scpacket.WatchScreenPacketType)
|
||||
err := ws.handleWatchScreen(wsPk)
|
||||
if err != nil {
|
||||
// TODO send errors back to client, likely unrecoverable
|
||||
log.Printf("[ws %s] error %v\n", ws.ClientId, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
isAuth := ws.IsAuthenticated()
|
||||
if !isAuth {
|
||||
log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType())
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.FeInputPacketStr {
|
||||
feInputPk := pk.(*scpacket.FeInputPacketType)
|
||||
if feInputPk.Remote.OwnerId != "" {
|
||||
log.Printf("[error] cannot send input to remote with ownerid\n")
|
||||
continue
|
||||
}
|
||||
if feInputPk.Remote.RemoteId == "" {
|
||||
log.Printf("[error] invalid input packet, remoteid is not set\n")
|
||||
continue
|
||||
}
|
||||
			err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() {
				sendErr := sendCmdInput(feInputPk)
				if sendErr != nil {
					log.Printf("[error] sending command input: %v\n", sendErr)
				}
			})
|
||||
if err != nil {
|
||||
log.Printf("[error] could not queue sendCmdInput: %v\n", err)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.RemoteInputPacketStr {
|
||||
inputPk := pk.(*scpacket.RemoteInputPacketType)
|
||||
if inputPk.RemoteId == "" {
|
||||
log.Printf("[error] invalid remoteinput packet, remoteid is not set\n")
|
||||
continue
|
||||
}
|
||||
			go func() {
				sendErr := remote.SendRemoteInput(inputPk)
				if sendErr != nil {
					log.Printf("[error] processing remote input: %v\n", sendErr)
				}
			}()
|
||||
continue
|
||||
}
|
||||
log.Printf("got ws bad message: %v\n", pk.GetType())
|
||||
}
|
||||
}
|
||||
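RunWSRead serializes FeInput packets per remote through RemoteInputMapQueue.Enqueue, whose definition lives elsewhere in the repo. A minimal sketch of the assumed semantics (items for the same remote id run in order, different remotes run independently); the type and buffer size here are illustrative, not the real implementation:

package main

import (
	"fmt"
	"sync"
	"time"
)

// keyedQueue runs queued functions for the same key sequentially on one
// goroutine per key. This is a simplified sketch: channels are buffered and
// never closed, and no error is returned on overflow.
type keyedQueue struct {
	lock   sync.Mutex
	queues map[string]chan func()
}

func newKeyedQueue() *keyedQueue {
	return &keyedQueue{queues: make(map[string]chan func())}
}

func (q *keyedQueue) Enqueue(key string, fn func()) {
	q.lock.Lock()
	ch, ok := q.queues[key]
	if !ok {
		ch = make(chan func(), 16)
		q.queues[key] = ch
		go func() {
			for work := range ch {
				work()
			}
		}()
	}
	q.lock.Unlock()
	ch <- fn
}

func main() {
	q := newKeyedQueue()
	for i := 0; i < 3; i++ {
		n := i
		q.Enqueue("remote-1", func() { fmt.Println("remote-1 input", n) })
	}
	time.Sleep(100 * time.Millisecond)
}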
|
||||
func sendCmdInput(pk *scpacket.FeInputPacketType) error {
|
||||
err := pk.CK.Validate("input packet")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pk.Remote.RemoteId == "" {
|
||||
return fmt.Errorf("input must set remoteid")
|
||||
}
|
||||
msh := remote.GetRemoteById(pk.Remote.RemoteId)
|
||||
if msh == nil {
|
||||
return fmt.Errorf("remote %s not found", pk.Remote.RemoteId)
|
||||
}
|
||||
if len(pk.InputData64) > 0 {
|
||||
inputLen := packet.B64DecodedLen(pk.InputData64)
|
||||
if inputLen > MaxInputDataSize {
|
||||
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
|
||||
}
|
||||
dataPk := packet.MakeDataPacket()
|
||||
dataPk.CK = pk.CK
|
||||
dataPk.FdNum = 0 // stdin
|
||||
dataPk.Data64 = pk.InputData64
|
||||
err = msh.SendInput(dataPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if pk.SigName != "" || pk.WinSize != nil {
|
||||
siPk := packet.MakeSpecialInputPacket()
|
||||
siPk.CK = pk.CK
|
||||
siPk.SigName = pk.SigName
|
||||
siPk.WinSize = pk.WinSize
|
||||
err = msh.SendSpecialInput(siPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
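For reference, the stdin size guard in sendCmdInput can be reproduced with the standard library alone. A minimal sketch, assuming packet.B64DecodedLen behaves like base64.StdEncoding.DecodedLen; maxInputDataSize here is an illustrative constant, not the package's real MaxInputDataSize:

package main

import (
	"encoding/base64"
	"fmt"
)

// maxInputDataSize mirrors the MaxInputDataSize guard above (value assumed).
const maxInputDataSize = 64 * 1024

// checkInputSize rejects oversized base64 stdin payloads before they would be
// wrapped in a data packet and forwarded to the remote.
func checkInputSize(inputData64 string) error {
	decodedLen := base64.StdEncoding.DecodedLen(len(inputData64))
	if decodedLen > maxInputDataSize {
		return fmt.Errorf("input data size too large, len=%d (max=%d)", decodedLen, maxInputDataSize)
	}
	return nil
}

func main() {
	payload := base64.StdEncoding.EncodeToString([]byte("ls -la\n"))
	if err := checkInputSize(payload); err != nil {
		fmt.Println("rejected:", err)
		return
	}
	fmt.Println("ok, forward to remote stdin")
}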
288
wavesrv/pkg/shparse/comp.go
Normal file
@ -0,0 +1,288 @@
|
||||
package shparse
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/utilfn"
|
||||
)
|
||||
|
||||
const (
|
||||
CompTypeCommandMeta = "command-meta"
|
||||
CompTypeCommand = "command"
|
||||
CompTypeArg = "command-arg"
|
||||
CompTypeInvalid = "invalid"
|
||||
CompTypeVar = "var"
|
||||
CompTypeAssignment = "assignment"
|
||||
CompTypeBasic = "basic"
|
||||
)
|
||||
|
||||
type CompletionPos struct {
|
||||
RawPos int // the raw position of cursor
|
||||
SuperOffset int // adjust all offsets in Cmd and CmdWord by SuperOffset
|
||||
|
||||
CompType string // see CompType* constants
|
||||
Cmd *CmdType // nil if between commands or a special completion (otherwise will be a SimpleCommand)
|
||||
// index into cmd.Words (only set when Cmd is not nil; otherwise CompType alone describes the position)
|
||||
// 0 means command-word
|
||||
// negative means assignment-words.
|
||||
// can be past the end of Words (means start new word).
|
||||
CmdWordPos int
|
||||
CompWord *WordType // set to the word we are completing (nil if we are starting a new word)
|
||||
CompWordOffset int // offset into CompWord (only set when CompWord is not nil)
|
||||
}
|
||||
|
||||
func compTypeFromPos(cmdWordPos int) string {
|
||||
if cmdWordPos == 0 {
|
||||
return CompTypeCommand
|
||||
}
|
||||
if cmdWordPos < 0 {
|
||||
return CompTypeAssignment
|
||||
}
|
||||
return CompTypeArg
|
||||
}
|
||||
|
||||
func (cmd *CmdType) findCompletionPos_simple(pos int, superOffset int) CompletionPos {
|
||||
if cmd.Type != CmdTypeSimple {
|
||||
panic("findCompletetionPos_simple only works for CmdTypeSimple")
|
||||
}
|
||||
rtn := CompletionPos{RawPos: pos, SuperOffset: superOffset, Cmd: cmd}
|
||||
for idx, word := range cmd.AssignmentWords {
|
||||
startOffset := word.Offset
|
||||
endOffset := word.Offset + len(word.Raw)
|
||||
if pos <= startOffset {
|
||||
// starting a new word at this position (before the current assignment word)
|
||||
rtn.CmdWordPos = idx - len(cmd.AssignmentWords)
|
||||
rtn.CompType = CompTypeAssignment
|
||||
return rtn
|
||||
}
|
||||
if pos <= endOffset {
|
||||
// completing an assignment word
|
||||
rtn.CmdWordPos = idx - len(cmd.AssignmentWords)
|
||||
rtn.CompWord = word
|
||||
rtn.CompWordOffset = pos - word.Offset
|
||||
rtn.CompType = CompTypeAssignment
|
||||
return rtn
|
||||
}
|
||||
}
|
||||
var foundWord *WordType
|
||||
var foundWordIdx int
|
||||
for idx, word := range cmd.Words {
|
||||
startOffset := word.Offset
|
||||
endOffset := word.Offset + len(word.Raw)
|
||||
if pos <= startOffset {
|
||||
// starting a new word at this position
|
||||
rtn.CmdWordPos = idx
|
||||
rtn.CompType = compTypeFromPos(idx)
|
||||
return rtn
|
||||
}
|
||||
if pos == endOffset && word.Type == WordTypeOp {
|
||||
// operators are special, they can allow a full-word completion at endpos
|
||||
continue
|
||||
}
|
||||
if pos <= endOffset {
|
||||
foundWord = word
|
||||
foundWordIdx = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundWord != nil {
|
||||
rtn.CmdWordPos = foundWordIdx
|
||||
rtn.CompWord = foundWord
|
||||
rtn.CompWordOffset = pos - foundWord.Offset
|
||||
if foundWord.uncompletable() {
|
||||
// invalid completion point
|
||||
rtn.CompType = CompTypeInvalid
|
||||
return rtn
|
||||
}
|
||||
rtn.CompType = compTypeFromPos(foundWordIdx)
|
||||
return rtn
|
||||
}
|
||||
// past the end, so we're starting a new word in Cmd
|
||||
rtn.CmdWordPos = len(cmd.Words)
|
||||
rtn.CompType = CompTypeArg
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (cmd *CmdType) findCompletionPos_none(pos int, superOffset int) CompletionPos {
|
||||
rtn := CompletionPos{RawPos: pos, SuperOffset: superOffset}
|
||||
if cmd.Type != CmdTypeNone {
|
||||
panic("findCompletionPos_none only works for CmdTypeNone")
|
||||
}
|
||||
var foundWord *WordType
|
||||
for _, word := range cmd.Words {
|
||||
startOffset := word.Offset
|
||||
endOffset := word.Offset + len(word.Raw)
|
||||
if pos <= startOffset {
|
||||
break
|
||||
}
|
||||
if pos <= endOffset {
|
||||
if pos == endOffset && word.Type == WordTypeOp {
|
||||
// operators are special, they can allow a full-word completion at endpos
|
||||
continue
|
||||
}
|
||||
foundWord = word
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundWord == nil {
|
||||
// just revert to a file completion
|
||||
rtn.CompType = CompTypeBasic
|
||||
return rtn
|
||||
}
|
||||
foundWordOffset := pos - foundWord.Offset
|
||||
rtn.CompWord = foundWord
|
||||
rtn.CompWordOffset = foundWordOffset
|
||||
if foundWord.uncompletable() {
|
||||
// ok, we're inside of a word in CmdTypeNone. if we're in an uncompletable word, return CompInvalid
|
||||
rtn.CompType = CompTypeInvalid
|
||||
return rtn
|
||||
}
|
||||
if foundWordOffset > 0 && foundWordOffset < foundWord.contentStartPos() {
|
||||
// cursor is in a weird position, between characters of a multi-char prefix (e.g. "$[*]{hello}" or $[*]'hello'). cannot complete.
|
||||
rtn.CompType = CompTypeInvalid
|
||||
return rtn
|
||||
}
|
||||
// revert to file completion
|
||||
rtn.CompType = CompTypeBasic
|
||||
return rtn
|
||||
}
|
||||
|
||||
func findCompletionWordAtPos(words []*WordType, pos int, allowEndMatch bool) *WordType {
|
||||
// WordTypeSimpleVar is special (always allowEndMatch), if cursor is at the end of SimpleVar it is returned
|
||||
for _, word := range words {
|
||||
if pos > word.Offset && pos < word.Offset+len(word.Raw) {
|
||||
return word
|
||||
}
|
||||
if (allowEndMatch || word.Type == WordTypeSimpleVar) && pos == word.Offset+len(word.Raw) {
|
||||
return word
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// recursively descend down the word, parse commands and find a sub completion point if any.
|
||||
// return nil if there is no sub completion point in this word
|
||||
func findCompletionPosInWord(word *WordType, posInWord int, superOffset int) *CompletionPos {
|
||||
rawPos := word.Offset + posInWord
|
||||
if word.Type == WordTypeGroup || word.Type == WordTypeDQ || word.Type == WordTypeDDQ {
|
||||
// need to descend further
|
||||
if posInWord <= word.contentStartPos() {
|
||||
return nil
|
||||
}
|
||||
if posInWord > word.contentEndPos() {
|
||||
return nil
|
||||
}
|
||||
subWord := findCompletionWordAtPos(word.Subs, posInWord-word.contentStartPos(), false)
|
||||
if subWord == nil {
|
||||
return nil
|
||||
}
|
||||
return findCompletionPosInWord(subWord, posInWord-(subWord.Offset+word.contentStartPos()), superOffset+(word.Offset+word.contentStartPos()))
|
||||
}
|
||||
if word.Type == WordTypeDP || word.Type == WordTypeBQ {
|
||||
if posInWord < word.contentStartPos() {
|
||||
return nil
|
||||
}
|
||||
if posInWord > word.contentEndPos() {
|
||||
return nil
|
||||
}
|
||||
subCmds := ParseCommands(word.Subs)
|
||||
newPos := findCompletionPosInternal(subCmds, posInWord-word.contentStartPos(), superOffset+(word.Offset+word.contentStartPos()))
|
||||
return &newPos
|
||||
}
|
||||
if word.Type == WordTypeSimpleVar || word.Type == WordTypeVarBrace {
|
||||
// special "var" completion
|
||||
rtn := &CompletionPos{RawPos: rawPos, SuperOffset: superOffset}
|
||||
rtn.CompType = CompTypeVar
|
||||
rtn.CompWordOffset = posInWord
|
||||
rtn.CompWord = word
|
||||
return rtn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns the context for completion
|
||||
// if we are completing in a simple-command, the returns the Cmd. the Cmd can be used for specialized completion (command name, arg position, etc.)
|
||||
// if we are completing in a word, returns the Word. Word might be a group-word or DQ word, so it may need additional resolution (done in extend)
|
||||
// otherwise we are going to create a new word to insert at offset (so the context does not matter)
|
||||
func findCompletionPosCmds(cmds []*CmdType, pos int, superOffset int) CompletionPos {
|
||||
rtn := CompletionPos{RawPos: pos, SuperOffset: superOffset}
|
||||
if len(cmds) == 0 {
|
||||
// starting a new command, so complete a command name
|
||||
rtn.CompType = CompTypeCommand
|
||||
return rtn
|
||||
}
|
||||
for _, cmd := range cmds {
|
||||
endOffset := cmd.endOffset()
|
||||
if pos > endOffset || (cmd.Type == CmdTypeNone && pos == endOffset) {
|
||||
continue
|
||||
}
|
||||
startOffset := cmd.offset()
|
||||
if cmd.Type == CmdTypeSimple {
|
||||
if pos <= startOffset {
|
||||
rtn.CompType = CompTypeCommand
|
||||
return rtn
|
||||
}
|
||||
return cmd.findCompletionPos_simple(pos, superOffset)
|
||||
} else {
|
||||
// not in a simple-command
|
||||
// if we're before the none-command, just start a new command
|
||||
if pos <= startOffset {
|
||||
rtn.CompType = CompTypeCommand
|
||||
return rtn
|
||||
}
|
||||
return cmd.findCompletionPos_none(pos, superOffset)
|
||||
}
|
||||
}
|
||||
// past the end
|
||||
lastCmd := cmds[len(cmds)-1]
|
||||
if lastCmd.Type == CmdTypeSimple {
|
||||
// just extend last command
|
||||
rtn.Cmd = lastCmd
|
||||
rtn.CmdWordPos = len(lastCmd.Words)
|
||||
rtn.CompType = CompTypeArg
|
||||
return rtn
|
||||
}
|
||||
// use lastCmd.NoneComplete to see if the last command ended on a "separator"; if so, complete a new command name
|
||||
if lastCmd.NoneComplete {
|
||||
rtn.CompType = CompTypeCommand
|
||||
} else {
|
||||
rtn.CompType = CompTypeBasic
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func findCompletionPosInternal(cmds []*CmdType, pos int, superOffset int) CompletionPos {
|
||||
cpos := findCompletionPosCmds(cmds, pos, superOffset)
|
||||
if cpos.CompWord == nil {
|
||||
return cpos
|
||||
}
|
||||
subPos := findCompletionPosInWord(cpos.CompWord, cpos.CompWordOffset, superOffset)
|
||||
if subPos != nil {
|
||||
return *subPos
|
||||
}
|
||||
return cpos
|
||||
}
|
||||
|
||||
func FindCompletionPos(cmds []*CmdType, pos int) CompletionPos {
|
||||
cpos := findCompletionPosInternal(cmds, pos, 0)
|
||||
if cpos.CompType == CompTypeCommand && cpos.SuperOffset == 0 && cpos.CompWord != nil && cpos.CompWord.Offset == 0 && strings.HasPrefix(string(cpos.CompWord.Raw), "/") {
|
||||
cpos.CompType = CompTypeCommandMeta
|
||||
}
|
||||
return cpos
|
||||
}
|
||||
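As a usage sketch (assuming the package's Tokenize function, which the tests later in this diff exercise): completing inside a command substitution descends into the inner command, and SuperOffset records where the inner parse began in the outer string.

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	// Cursor at the end of an unfinished command substitution.
	cmdStr := "echo $(ls"
	words := shparse.Tokenize(cmdStr)
	cmds := shparse.ParseCommands(words)

	cpos := shparse.FindCompletionPos(cmds, len(cmdStr))
	// Inside "$(", the word "ls" is in command position, so this should
	// report CompType "command" with SuperOffset pointing past "echo $(".
	fmt.Printf("comp-type=%s super-offset=%d raw-pos=%d\n",
		cpos.CompType, cpos.SuperOffset, cpos.RawPos)
}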
|
||||
func (cpos CompletionPos) Extend(origStr utilfn.StrWithPos, extensionStr string, extensionComplete bool) utilfn.StrWithPos {
|
||||
compWord := cpos.CompWord
|
||||
if compWord == nil {
|
||||
compWord = MakeEmptyWord(WordTypeLit, nil, cpos.RawPos, true)
|
||||
}
|
||||
realOffset := compWord.Offset + cpos.SuperOffset
|
||||
if strings.HasSuffix(extensionStr, "/") {
|
||||
extensionComplete = false
|
||||
}
|
||||
rtnSP := Extend(compWord, cpos.CompWordOffset, extensionStr, extensionComplete)
|
||||
origRunes := []rune(origStr.Str)
|
||||
rtnSP = rtnSP.Prepend(string(origRunes[0:realOffset]))
|
||||
rtnSP = rtnSP.Append(string(origRunes[realOffset+len(compWord.Raw):]))
|
||||
return rtnSP
|
||||
}
|
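Putting the pieces together, a completion engine would find the position and then splice the chosen extension back into the original line. A minimal sketch; Tokenize and utilfn.StrWithPos are assumed from elsewhere in the repo, as used by the package tests:

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
	"github.com/commandlinedev/prompt-server/pkg/utilfn"
)

func main() {
	cmdStr := "ls fo"
	cursor := len(cmdStr)

	words := shparse.Tokenize(cmdStr)
	cmds := shparse.ParseCommands(words)
	cpos := shparse.FindCompletionPos(cmds, cursor)

	// Splice the completion "o/" into the original string. The trailing "/"
	// keeps the word open, so Extend does not append a space even though
	// extensionComplete is true.
	orig := utilfn.StrWithPos{Str: cmdStr, Pos: cursor}
	out := cpos.Extend(orig, "o/", true)
	fmt.Printf("%q (cursor at %d)\n", out.Str, out.Pos)
}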
258
wavesrv/pkg/shparse/expand.go
Normal file
@ -0,0 +1,258 @@
|
||||
package shparse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
)
|
||||
|
||||
const MaxExpandLen = 64 * 1024
|
||||
|
||||
type ExpandInfo struct {
|
||||
HasTilde bool // only ~ as the first character when ExpandContext.HomeDir is set
|
||||
HasVar bool // $x, $$, ${...}
|
||||
HasGlob bool // *, ?, [, {
|
||||
HasExtGlob bool // ?(...) ... ?*+@!
|
||||
HasHistory bool // ! (anywhere)
|
||||
HasSpecial bool // subshell, arith
|
||||
}
|
||||
|
||||
type ExpandContext struct {
|
||||
HomeDir string
|
||||
}
|
||||
|
||||
func expandSQ(buf *bytes.Buffer, rawLit []rune) {
|
||||
// single-quoted content sets no ExpandInfo flags
|
||||
buf.WriteString(string(rawLit))
|
||||
}
|
||||
|
||||
// TODO implement our own ANSI single quote formatter
|
||||
func expandANSISQ(buf *bytes.Buffer, rawLit []rune) {
|
||||
// ANSI-C quoted content sets no ExpandInfo flags
|
||||
str, _, _ := expand.Format(nil, string(rawLit), nil)
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
func expandLiteral(buf *bytes.Buffer, info *ExpandInfo, rawLit []rune) {
|
||||
var lastBackSlash bool
|
||||
var lastExtGlob bool
|
||||
var lastDollar bool
|
||||
for _, ch := range rawLit {
|
||||
if ch == 0 {
|
||||
break
|
||||
}
|
||||
if lastBackSlash {
|
||||
lastBackSlash = false
|
||||
if ch == '\n' {
|
||||
// special case, backslash *and* newline are ignored
|
||||
continue
|
||||
}
|
||||
buf.WriteRune(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
lastBackSlash = true
|
||||
lastExtGlob = false
|
||||
lastDollar = false
|
||||
continue
|
||||
}
|
||||
if ch == '*' || ch == '?' || ch == '[' || ch == '{' {
|
||||
info.HasGlob = true
|
||||
}
|
||||
if ch == '`' {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
if ch == '!' {
|
||||
info.HasHistory = true
|
||||
}
|
||||
if lastExtGlob && ch == '(' {
|
||||
info.HasExtGlob = true
|
||||
}
|
||||
if lastDollar && (ch != ' ' && ch != '"' && ch != '\'' && ch != '(' && ch != '[') {
|
||||
info.HasVar = true
|
||||
}
|
||||
if lastDollar && (ch == '(' || ch == '[') {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
lastExtGlob = (ch == '?' || ch == '*' || ch == '+' || ch == '@' || ch == '!')
|
||||
lastDollar = (ch == '$')
|
||||
buf.WriteRune(ch)
|
||||
}
|
||||
if lastBackSlash {
|
||||
buf.WriteByte('\\')
|
||||
}
|
||||
}
|
||||
|
||||
// will also work for partial double quoted strings
|
||||
func expandDQLiteral(buf *bytes.Buffer, info *ExpandInfo, rawVal []rune) {
|
||||
var lastBackSlash bool
|
||||
var lastDollar bool
|
||||
for _, ch := range rawVal {
|
||||
if ch == 0 {
|
||||
break
|
||||
}
|
||||
if lastBackSlash {
|
||||
lastBackSlash = false
|
||||
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
|
||||
buf.WriteRune(ch)
|
||||
continue
|
||||
}
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
lastBackSlash = true
|
||||
lastDollar = false
|
||||
continue
|
||||
}
|
||||
|
||||
// similar to expandLiteral, but no globbing
|
||||
if ch == '`' {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
if ch == '!' {
|
||||
info.HasHistory = true
|
||||
}
|
||||
if lastDollar && (ch != ' ' && ch != '"' && ch != '\'' && ch != '(' && ch != '[') {
|
||||
info.HasVar = true
|
||||
}
|
||||
if lastDollar && (ch == '(' || ch == '[') {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
lastDollar = (ch == '$')
|
||||
buf.WriteRune(ch)
|
||||
}
|
||||
// in a valid parsed DQ string, you cannot have a trailing backslash (because \" would not end the string)
|
||||
// still putting the case here though in case we ever deal with incomplete strings (e.g. completion)
|
||||
if lastBackSlash {
|
||||
buf.WriteByte('\\')
|
||||
}
|
||||
}
|
||||
|
||||
func simpleExpandSubs(buf *bytes.Buffer, info *ExpandInfo, ectx ExpandContext, word *WordType, pos int) {
|
||||
fmt.Printf("expand subs: %v\n", word)
|
||||
parts := word.Subs
|
||||
startPos := word.contentStartPos()
|
||||
for _, part := range parts {
|
||||
remainingLen := pos - startPos
|
||||
if remainingLen <= 0 {
|
||||
break
|
||||
}
|
||||
simpleExpandWord(buf, info, ectx, part, remainingLen)
|
||||
startPos += len(part.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func canExpand(ectx ExpandContext, wtype string) bool {
|
||||
return wtype == WordTypeLit || wtype == WordTypeSQ || wtype == WordTypeDSQ ||
|
||||
wtype == WordTypeDQ || wtype == WordTypeDDQ || wtype == WordTypeGroup
|
||||
}
|
||||
|
||||
func simpleExpandWord(buf *bytes.Buffer, info *ExpandInfo, ectx ExpandContext, word *WordType, pos int) {
|
||||
if canExpand(ectx, word.Type) {
|
||||
if pos >= word.contentEndPos() {
|
||||
pos = word.contentEndPos()
|
||||
}
|
||||
if pos <= word.contentStartPos() {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if pos >= len(word.Raw) {
|
||||
pos = len(word.Raw)
|
||||
}
|
||||
if pos <= 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch word.Type {
|
||||
case WordTypeLit:
|
||||
if word.QC.cur() == WordTypeDQ {
|
||||
expandDQLiteral(buf, info, word.Raw[:pos])
|
||||
return
|
||||
}
|
||||
expandLiteral(buf, info, word.Raw[:pos])
|
||||
|
||||
case WordTypeSQ:
|
||||
expandSQ(buf, word.Raw[word.contentStartPos():pos])
|
||||
return
|
||||
|
||||
case WordTypeDSQ:
|
||||
expandANSISQ(buf, word.Raw[word.contentStartPos():pos])
|
||||
return
|
||||
|
||||
case WordTypeDQ, WordTypeDDQ:
|
||||
simpleExpandSubs(buf, info, ectx, word, pos)
|
||||
return
|
||||
|
||||
case WordTypeGroup:
|
||||
simpleExpandSubs(buf, info, ectx, word, pos)
|
||||
return
|
||||
|
||||
// not expanded
|
||||
case WordTypeSimpleVar:
|
||||
info.HasVar = true
|
||||
buf.WriteString(string(word.Raw[:pos]))
|
||||
return
|
||||
|
||||
// not expanded
|
||||
case WordTypeVarBrace:
|
||||
info.HasVar = true
|
||||
buf.WriteString(string(word.Raw[:pos]))
|
||||
return
|
||||
|
||||
default:
|
||||
info.HasSpecial = true
|
||||
buf.WriteString(string(word.Raw[:pos]))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func SimpleExpandPrefix(ectx ExpandContext, word *WordType, pos int) (string, ExpandInfo) {
|
||||
var buf bytes.Buffer
|
||||
var info ExpandInfo
|
||||
simpleExpandWord(&buf, &info, ectx, word, pos)
|
||||
return buf.String(), info
|
||||
}
|
||||
|
||||
func SimpleExpand(ectx ExpandContext, word *WordType) (string, ExpandInfo) {
|
||||
return SimpleExpandPrefix(ectx, word, len(word.Raw))
|
||||
}
|
||||
|
||||
// returns varname (no '$') and ok (whether this is a valid varname expansion)
|
||||
func SimpleVarNamePrefix(ectx ExpandContext, word *WordType, pos int) (string, bool) {
|
||||
if word.Type != WordTypeSimpleVar && word.Type != WordTypeVarBrace {
|
||||
return "", false
|
||||
}
|
||||
if word.Type == WordTypeSimpleVar {
|
||||
if pos == 0 {
|
||||
return "", false
|
||||
}
|
||||
if pos == 1 {
|
||||
return "", true
|
||||
}
|
||||
if pos > len(word.Raw) {
|
||||
pos = len(word.Raw)
|
||||
}
|
||||
return string(word.Raw[1:pos]), true
|
||||
}
|
||||
|
||||
// word.Type == WordTypeVarBrace
|
||||
// knock '${' off the front, then see if the rest is a valid var name.
|
||||
if pos == 0 || pos == 1 {
|
||||
return "", false
|
||||
}
|
||||
if pos == 2 {
|
||||
return "", true
|
||||
}
|
||||
if pos > word.contentEndPos() {
|
||||
pos = word.contentEndPos()
|
||||
}
|
||||
rawVarName := word.Raw[2:pos]
|
||||
if isSimpleVarName(rawVarName) {
|
||||
return string(rawVarName), true
|
||||
}
|
||||
return "", false
|
||||
}
|
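A usage sketch of prefix expansion (Tokenize assumed from elsewhere in the package): expand the word under the cursor up to the cursor position, then use the ExpandInfo flags to decide whether a plain file-name completion is appropriate.

package main

import (
	"fmt"
	"os"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	cmdStr := "ls $HOME/do"
	words := shparse.Tokenize(cmdStr)
	cmds := shparse.ParseCommands(words)
	cpos := shparse.FindCompletionPos(cmds, len(cmdStr))
	if cpos.CompWord == nil {
		fmt.Println("no word under cursor")
		return
	}

	ectx := shparse.ExpandContext{HomeDir: os.Getenv("HOME")}
	prefix, info := shparse.SimpleExpandPrefix(ectx, cpos.CompWord, cpos.CompWordOffset)
	fmt.Printf("prefix=%q hasVar=%v hasGlob=%v\n", prefix, info.HasVar, info.HasGlob)

	// An unexpanded variable or glob in the prefix means naive file-name
	// completion against the literal prefix would be wrong.
	if info.HasVar || info.HasGlob || info.HasSpecial || info.HasHistory {
		fmt.Println("prefix needs shell expansion before completing file names")
	}
}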
410
wavesrv/pkg/shparse/extend.go
Normal file
@ -0,0 +1,410 @@
|
||||
package shparse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/utilfn"
|
||||
)
|
||||
|
||||
var noEscChars []bool
|
||||
var specialEsc []string
|
||||
|
||||
func init() {
|
||||
noEscChars = make([]bool, 256)
|
||||
for ch := 0; ch < 256; ch++ {
|
||||
if (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
|
||||
ch == '-' || ch == '.' || ch == '/' || ch == ':' || ch == '=' || ch == '_' {
|
||||
noEscChars[byte(ch)] = true
|
||||
}
|
||||
}
|
||||
specialEsc = make([]string, 256)
|
||||
specialEsc[0x7] = "\\a"
|
||||
specialEsc[0x8] = "\\b"
|
||||
specialEsc[0x9] = "\\t"
|
||||
specialEsc[0xa] = "\\n"
|
||||
specialEsc[0xb] = "\\v"
|
||||
specialEsc[0xc] = "\\f"
|
||||
specialEsc[0xd] = "\\r"
|
||||
specialEsc[0x1b] = "\\E"
|
||||
}
|
||||
|
||||
func getUtf8Literal(ch rune) string {
|
||||
var buf bytes.Buffer
|
||||
var runeArr [utf8.UTFMax]byte
|
||||
barr := runeArr[:]
|
||||
byteLen := utf8.EncodeRune(barr, ch)
|
||||
for i := 0; i < byteLen; i++ {
|
||||
buf.WriteString("\\x")
|
||||
buf.WriteByte(utilfn.HexDigits[barr[i]/16])
|
||||
buf.WriteByte(utilfn.HexDigits[barr[i]%16])
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
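getUtf8Literal emits each UTF-8 byte of a rune as a \xHH escape so writeSpecial below can wrap it in bash ANSI-C ($'...') quoting. A standalone sketch of the same idea using only the standard library; utilfn.HexDigits is assumed to be the lowercase hex alphabet:

package main

import (
	"fmt"
	"unicode/utf8"
)

// hexLiteral mirrors getUtf8Literal: every UTF-8 byte of the rune is written
// as a \xHH escape, suitable for embedding inside $'...' quoting.
func hexLiteral(ch rune) string {
	var buf [utf8.UTFMax]byte
	n := utf8.EncodeRune(buf[:], ch)
	out := ""
	for i := 0; i < n; i++ {
		out += fmt.Sprintf("\\x%02x", buf[i])
	}
	return out
}

func main() {
	fmt.Println(hexLiteral('é'))                 // \xc3\xa9
	fmt.Println("$'" + hexLiteral('\x01') + "'") // $'\x01'
}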
|
||||
func (w *WordType) writeString(s string) {
|
||||
for _, ch := range s {
|
||||
w.writeRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WordType) writeRune(ch rune) {
|
||||
wmeta := wordMetaMap[w.Type]
|
||||
if w.Complete && wmeta.SuffixLen == 1 {
|
||||
w.Raw = append(w.Raw[0:len(w.Raw)-1], ch, w.Raw[len(w.Raw)-1])
|
||||
return
|
||||
}
|
||||
if w.Complete && wmeta.SuffixLen == 2 {
|
||||
w.Raw = append(w.Raw[0:len(w.Raw)-2], ch, w.Raw[len(w.Raw)-2], w.Raw[len(w.Raw)-1])
|
||||
return
|
||||
}
|
||||
// not complete or SuffixLen == 0 (2+ is not supported)
|
||||
w.Raw = append(w.Raw, ch)
|
||||
return
|
||||
}
|
||||
|
||||
type extendContext struct {
|
||||
Input []*WordType
|
||||
InputPos int
|
||||
QC QuoteContext
|
||||
Rtn []*WordType
|
||||
CurWord *WordType
|
||||
Intention string
|
||||
}
|
||||
|
||||
func makeExtendContext(qc QuoteContext, word *WordType) *extendContext {
|
||||
rtn := &extendContext{QC: qc}
|
||||
if word == nil {
|
||||
rtn.Intention = WordTypeLit
|
||||
return rtn
|
||||
} else {
|
||||
rtn.Intention = word.Type
|
||||
rtn.Rtn = []*WordType{word}
|
||||
rtn.CurWord = word
|
||||
return rtn
|
||||
}
|
||||
}
|
||||
|
||||
func (ec *extendContext) appendWord(w *WordType) {
|
||||
ec.Rtn = append(ec.Rtn, w)
|
||||
ec.CurWord = w
|
||||
}
|
||||
|
||||
func (ec *extendContext) ensureCurWord() {
|
||||
if ec.CurWord == nil || ec.CurWord.Type != ec.Intention {
|
||||
ec.CurWord = MakeEmptyWord(ec.Intention, ec.QC, 0, true)
|
||||
ec.Rtn = append(ec.Rtn, ec.CurWord)
|
||||
}
|
||||
}
|
||||
|
||||
// grp, dq, ddq
|
||||
func extendWithSubs(word *WordType, wordPos int, extStr string, complete bool) utilfn.StrWithPos {
|
||||
wmeta := wordMetaMap[word.Type]
|
||||
if word.Type == WordTypeGroup {
|
||||
atEnd := (wordPos == len(word.Raw))
|
||||
subWord := findCompletionWordAtPos(word.Subs, wordPos, true)
|
||||
if subWord == nil {
|
||||
strPos := Extend(MakeEmptyWord(WordTypeLit, word.QC, 0, true), 0, extStr, atEnd)
|
||||
strPos = strPos.Prepend(string(word.Raw[0:wordPos]))
|
||||
strPos = strPos.Append(string(word.Raw[wordPos:]))
|
||||
return strPos
|
||||
} else {
|
||||
subComplete := complete && atEnd
|
||||
strPos := Extend(subWord, wordPos-subWord.Offset, extStr, subComplete)
|
||||
strPos = strPos.Prepend(string(word.Raw[0:subWord.Offset]))
|
||||
strPos = strPos.Append(string(word.Raw[subWord.Offset+len(subWord.Raw):]))
|
||||
return strPos
|
||||
}
|
||||
} else if word.Type == WordTypeDQ || word.Type == WordTypeDDQ {
|
||||
if wordPos < word.contentStartPos() {
|
||||
wordPos = word.contentStartPos()
|
||||
}
|
||||
atEnd := (wordPos >= len(word.Raw)-wmeta.SuffixLen)
|
||||
subWord := findCompletionWordAtPos(word.Subs, wordPos-wmeta.PrefixLen, true)
|
||||
quoteBalance := !atEnd
|
||||
if subWord == nil {
|
||||
realOffset := wordPos
|
||||
strPos, wordOpen := extendInternal(MakeEmptyWord(WordTypeLit, word.QC.push(WordTypeDQ), 0, true), 0, extStr, false, quoteBalance)
|
||||
strPos = strPos.Prepend(string(word.Raw[0:realOffset]))
|
||||
var requiredSuffix string
|
||||
if wordOpen {
|
||||
requiredSuffix = wmeta.getSuffix()
|
||||
}
|
||||
if atEnd {
|
||||
if complete {
|
||||
return utilfn.StrWithPos{Str: strPos.Str + requiredSuffix + " ", Pos: strPos.Pos + len(requiredSuffix) + 1}
|
||||
} else {
|
||||
if word.Complete && requiredSuffix != "" {
|
||||
return strPos.Append(requiredSuffix)
|
||||
}
|
||||
return strPos
|
||||
}
|
||||
}
|
||||
strPos = strPos.Append(string(word.Raw[wordPos:]))
|
||||
return strPos
|
||||
} else {
|
||||
realOffset := subWord.Offset + wmeta.PrefixLen
|
||||
strPos, wordOpen := extendInternal(subWord, wordPos-realOffset, extStr, false, quoteBalance)
|
||||
strPos = strPos.Prepend(string(word.Raw[0:realOffset]))
|
||||
var requiredSuffix string
|
||||
if wordOpen {
|
||||
requiredSuffix = wmeta.getSuffix()
|
||||
}
|
||||
if atEnd {
|
||||
if complete {
|
||||
return utilfn.StrWithPos{Str: strPos.Str + requiredSuffix + " ", Pos: strPos.Pos + len(requiredSuffix) + 1}
|
||||
} else {
|
||||
if word.Complete && requiredSuffix != "" {
|
||||
return strPos.Append(requiredSuffix)
|
||||
}
|
||||
return strPos
|
||||
}
|
||||
}
|
||||
strPos = strPos.Append(string(word.Raw[realOffset+len(subWord.Raw):]))
|
||||
return strPos
|
||||
}
|
||||
} else {
|
||||
return utilfn.StrWithPos{Str: string(word.Raw), Pos: wordPos}
|
||||
}
|
||||
}
|
||||
|
||||
// lit, svar, varb, sq, dsq
|
||||
func extendLeafCh(buf *bytes.Buffer, wordOpen *bool, wtype string, qc QuoteContext, ch rune) {
|
||||
switch wtype {
|
||||
case WordTypeSimpleVar, WordTypeVarBrace:
|
||||
extendVar(buf, ch)
|
||||
|
||||
case WordTypeLit:
|
||||
if qc.cur() == WordTypeDQ {
|
||||
extendDQLit(buf, wordOpen, ch)
|
||||
} else {
|
||||
extendLit(buf, ch)
|
||||
}
|
||||
|
||||
case WordTypeSQ:
|
||||
extendSQ(buf, wordOpen, ch)
|
||||
|
||||
case WordTypeDSQ:
|
||||
extendDSQ(buf, wordOpen, ch)
|
||||
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func getWordOpenStr(wtype string, qc QuoteContext) string {
|
||||
if wtype == WordTypeLit {
|
||||
if qc.cur() == WordTypeDQ {
|
||||
return "\""
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
wmeta := wordMetaMap[wtype]
|
||||
return wmeta.getPrefix()
|
||||
}
|
||||
|
||||
// lit, svar, varb sq, dsq
|
||||
func extendLeaf(buf *bytes.Buffer, wordOpen *bool, word *WordType, wordPos int, extStr string) {
|
||||
for _, ch := range extStr {
|
||||
extendLeafCh(buf, wordOpen, word.Type, word.QC, ch)
|
||||
}
|
||||
}
|
||||
|
||||
// lit, grp, svar, dq, ddq, varb, sq, dsq
|
||||
// returns (strwithpos, dq-closed)
|
||||
func extendInternal(word *WordType, wordPos int, extStr string, complete bool, requiresQuoteBalance bool) (utilfn.StrWithPos, bool) {
|
||||
if extStr == "" {
|
||||
return utilfn.StrWithPos{Str: string(word.Raw), Pos: wordPos}, true
|
||||
}
|
||||
if word.canHaveSubs() {
|
||||
return extendWithSubs(word, wordPos, extStr, complete), true
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
isEOW := wordPos >= word.contentEndPos()
|
||||
if isEOW {
|
||||
wordPos = word.contentEndPos()
|
||||
}
|
||||
if wordPos < word.contentStartPos() {
|
||||
wordPos = word.contentStartPos()
|
||||
}
|
||||
if wordPos > 0 {
|
||||
buf.WriteString(string(word.Raw[0:word.contentStartPos()])) // write the prefix
|
||||
}
|
||||
if wordPos > word.contentStartPos() {
|
||||
buf.WriteString(string(word.Raw[word.contentStartPos():wordPos]))
|
||||
}
|
||||
wordOpen := true
|
||||
extendLeaf(&buf, &wordOpen, word, wordPos, extStr)
|
||||
if isEOW {
|
||||
// end-of-word, write the suffix (and optional ' '). return the end of the string
|
||||
wmeta := wordMetaMap[word.Type]
|
||||
rtnPos := utf8.RuneCount(buf.Bytes())
|
||||
buf.WriteString(wmeta.getSuffix())
|
||||
if !wordOpen && requiresQuoteBalance {
|
||||
buf.WriteString(getWordOpenStr(word.Type, word.QC))
|
||||
wordOpen = true
|
||||
}
|
||||
if complete {
|
||||
buf.WriteRune(' ')
|
||||
return utilfn.StrWithPos{Str: buf.String(), Pos: utf8.RuneCount(buf.Bytes())}, wordOpen
|
||||
} else {
|
||||
return utilfn.StrWithPos{Str: buf.String(), Pos: rtnPos}, wordOpen
|
||||
}
|
||||
}
|
||||
// completion in the middle of a word (no ' ')
|
||||
rtnPos := utf8.RuneCount(buf.Bytes())
|
||||
if !wordOpen {
|
||||
// always required since there is a suffix
|
||||
buf.WriteString(getWordOpenStr(word.Type, word.QC))
|
||||
wordOpen = true
|
||||
}
|
||||
buf.WriteString(string(word.Raw[wordPos:])) // write the suffix
|
||||
return utilfn.StrWithPos{Str: buf.String(), Pos: rtnPos}, wordOpen
|
||||
}
|
||||
|
||||
// lit, grp, svar, dq, ddq, varb, sq, dsq
|
||||
func Extend(word *WordType, wordPos int, extStr string, complete bool) utilfn.StrWithPos {
|
||||
rtn, _ := extendInternal(word, wordPos, extStr, complete, false)
|
||||
return rtn
|
||||
}
|
||||
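A small usage sketch of Extend and the escaping rules above (the expected string follows the package tests later in this diff): extending a bare literal word backslash-escapes spaces and wraps non-printable characters in ANSI-C quotes.

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	// Start from an empty bare word and extend it with text containing a
	// space and a BEL (\a) character.
	word := shparse.MakeEmptyWord(shparse.WordTypeLit, nil, 0, true)
	out := shparse.Extend(word, 0, "more stuff\a", false)
	fmt.Println(out.Str) // expected: more\ stuff$'\a'
}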
|
||||
func (ec *extendContext) extend(ch rune) {
|
||||
if ch == 0 {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func isVarNameChar(ch rune) bool {
|
||||
return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9')
|
||||
}
|
||||
|
||||
func extendVar(buf *bytes.Buffer, ch rune) {
|
||||
if ch == 0 {
|
||||
return
|
||||
}
|
||||
if !isVarNameChar(ch) {
|
||||
return
|
||||
}
|
||||
buf.WriteRune(ch)
|
||||
}
|
||||
|
||||
func getSpecialEscape(ch rune) string {
|
||||
if ch > unicode.MaxASCII {
|
||||
return ""
|
||||
}
|
||||
return specialEsc[byte(ch)]
|
||||
}
|
||||
|
||||
func writeSpecial(buf *bytes.Buffer, ch rune, wrap bool) {
|
||||
if wrap {
|
||||
buf.WriteRune('$')
|
||||
buf.WriteRune('\'')
|
||||
}
|
||||
sesc := getSpecialEscape(ch)
|
||||
if sesc != "" {
|
||||
buf.WriteString(sesc)
|
||||
} else {
|
||||
utf8Lit := getUtf8Literal(ch)
|
||||
buf.WriteString(utf8Lit)
|
||||
}
|
||||
if wrap {
|
||||
buf.WriteRune('\'')
|
||||
}
|
||||
}
|
||||
|
||||
func extendLit(buf *bytes.Buffer, ch rune) {
|
||||
if ch == 0 {
|
||||
return
|
||||
}
|
||||
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
|
||||
writeSpecial(buf, ch, true)
|
||||
return
|
||||
}
|
||||
var bch = byte(ch)
|
||||
if noEscChars[bch] {
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
||||
|
||||
func extendDSQ(buf *bytes.Buffer, wordOpen *bool, ch rune) {
|
||||
if ch == 0 {
|
||||
return
|
||||
}
|
||||
if !*wordOpen {
|
||||
buf.WriteRune('$')
|
||||
buf.WriteRune('\'')
|
||||
*wordOpen = true
|
||||
}
|
||||
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
|
||||
writeSpecial(buf, ch, false)
|
||||
return
|
||||
}
|
||||
if ch == '\'' {
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
||||
|
||||
func extendSQ(buf *bytes.Buffer, wordOpen *bool, ch rune) {
|
||||
if ch == 0 {
|
||||
return
|
||||
}
|
||||
if ch == '\'' {
|
||||
if *wordOpen {
|
||||
buf.WriteRune('\'')
|
||||
*wordOpen = false
|
||||
}
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune('\'')
|
||||
return
|
||||
}
|
||||
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
|
||||
if *wordOpen {
|
||||
buf.WriteRune('\'')
|
||||
*wordOpen = false
|
||||
}
|
||||
writeSpecial(buf, ch, true)
|
||||
return
|
||||
}
|
||||
if !*wordOpen {
|
||||
buf.WriteRune('\'')
|
||||
*wordOpen = true
|
||||
}
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
||||
|
||||
func extendDQLit(buf *bytes.Buffer, wordOpen *bool, ch rune) {
|
||||
if ch == 0 {
|
||||
return
|
||||
}
|
||||
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
|
||||
if *wordOpen {
|
||||
buf.WriteRune('"')
|
||||
*wordOpen = false
|
||||
}
|
||||
writeSpecial(buf, ch, true)
|
||||
return
|
||||
}
|
||||
if !*wordOpen {
|
||||
buf.WriteRune('"')
|
||||
*wordOpen = true
|
||||
}
|
||||
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
||||
buf.WriteRune(ch)
|
||||
return
|
||||
}
|
693
wavesrv/pkg/shparse/shparse.go
Normal file
@ -0,0 +1,693 @@
|
||||
package shparse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/utilfn"
|
||||
)
|
||||
|
||||
//
|
||||
// cmds := cmd (sep cmd)*
|
||||
// sep := ';' | '&' | '&&' | '||' | '|' | '\n'
|
||||
// cmd := simple-cmd | compound-command redirect-list?
|
||||
// compound-command := brace-group | subshell | for-clause | case-clause | if-clause | while-clause | until-clause
|
||||
// brace-group := '{' cmds '}'
|
||||
// subshell := '(' cmds ')'
|
||||
// simple-command := cmd-prefix cmd-word (io-redirect)*
|
||||
// cmd-prefix := (io-redirect | assignment)*
|
||||
// cmd-suffix := (io-redirect | word)*
|
||||
// cmd-name := word
|
||||
// cmd-word := word
|
||||
// io-redirect := (io-number? io-file) | (io-number? io-here)
|
||||
// io-file := ('<' | '<&' | '>' | '>&' | '>>' | '>|' ) filename
|
||||
// io-here := ('<<' | '<<-') here_end
|
||||
// here-end := word
|
||||
// if-clause := 'if' compound-list 'then' compound-list else-part 'fi'
|
||||
// else-part := 'elif' compound-list 'then' compound-list
|
||||
// | 'elif' compound-list 'then' compound-list else-part
|
||||
// | 'else' compound-list
|
||||
// compound-list := linebreak term sep?
|
||||
//
|
||||
//
|
||||
//
|
||||
// A correctly-formed brace expansion must contain unquoted opening and closing braces, and at least one unquoted comma or a valid sequence expression
|
||||
// Any incorrectly formed brace expansion is left unchanged.
|
||||
//
|
||||
// ambiguity between $((...)) and $((ls); ls)
|
||||
// ambiguity between foo=([0]=hell) and foo=([abc)
|
||||
// tokenization https://pubs.opengroup.org/onlinepubs/7908799/xcu/chap2.html#tag_001_003
|
||||
|
||||
// can-extend: WordTypeLit, WordTypeSimpleVar, WordTypeVarBrace, WordTypeDQ, WordTypeDDQ, WordTypeSQ, WordTypeDSQ
|
||||
const (
|
||||
WordTypeRaw = "raw"
|
||||
WordTypeLit = "lit" // (can-extend)
|
||||
WordTypeOp = "op" // single: & ; | ( ) < > \n multi(2): && || ;; << >> <& >& <> >| (( multi(3): <<- ('((' requires special processing)
|
||||
WordTypeKey = "key" // if then else elif fi do done case esac while until for in { } ! (( [[
|
||||
WordTypeGroup = "grp" // contains other words e.g. "hello"foo'bar'$x (has-subs) (can-extend)
|
||||
WordTypeSimpleVar = "svar" // simplevar $ (can-extend)
|
||||
|
||||
WordTypeDQ = "dq" // " (quote-context) (can-extend) (has-subs)
|
||||
WordTypeDDQ = "ddq" // $" (can-extend) (has-subs) (for quotecontext, uses WordTypeDQ)
|
||||
WordTypeVarBrace = "varb" // ${ (quote-context) (can-extend) (internals not parsed)
|
||||
WordTypeDP = "dp" // $( (quote-context) (has-subs)
|
||||
WordTypeBQ = "bq" // ` (quote-context) (has-subs)
|
||||
|
||||
WordTypeSQ = "sq" // ' (can-extend)
|
||||
WordTypeDSQ = "dsq" // $' (can-extend)
|
||||
WordTypeDPP = "dpp" // $(( (internals not parsed)
|
||||
WordTypePP = "pp" // (( (internals not parsed)
|
||||
WordTypeDB = "db" // $[ (internals not parsed)
|
||||
)
|
||||
|
||||
const (
|
||||
CmdTypeNone = "none" // holds control structures: '(' ')' 'for' 'while' etc.
|
||||
CmdTypeSimple = "simple" // holds real commands
|
||||
)
|
||||
|
||||
type WordType struct {
|
||||
Type string
|
||||
Offset int
|
||||
QC QuoteContext
|
||||
Raw []rune
|
||||
Complete bool
|
||||
Prefix []rune
|
||||
Subs []*WordType
|
||||
}
|
||||
|
||||
type CmdType struct {
|
||||
Type string
|
||||
AssignmentWords []*WordType
|
||||
Words []*WordType
|
||||
NoneComplete bool // set to true when last-word is a "separator"
|
||||
}
|
||||
|
||||
type QuoteContext []string
|
||||
|
||||
var wordMetaMap map[string]wordMeta
|
||||
|
||||
// same order as https://www.gnu.org/software/bash/manual/html_node/Reserved-Words.html
|
||||
var bashReservedWords = []string{
|
||||
"if", "then", "elif", "else", "fi", "time",
|
||||
"for", "in", "until", "while", "do", "done",
|
||||
"case", "esac", "coproc", "select", "function",
|
||||
"{", "}", "[[", "]]", "!",
|
||||
}
|
||||
|
||||
// special reserved words: "for", "in", "case", "select", "function", "[[", and "]]"
|
||||
|
||||
var bashNoneRW = []string{
|
||||
"if", "then",
|
||||
"elif", "else", "fi", "time",
|
||||
"until", "while", "do", "done",
|
||||
"esac", "coproc",
|
||||
"{", "}", "!",
|
||||
}
|
||||
|
||||
type wordMeta struct {
|
||||
Type string
|
||||
EmptyWord []rune
|
||||
PrefixLen int
|
||||
SuffixLen int
|
||||
CanExtend bool
|
||||
QuoteContext bool
|
||||
}
|
||||
|
||||
func (m wordMeta) getSuffix() string {
|
||||
if m.SuffixLen == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(m.EmptyWord[len(m.EmptyWord)-m.SuffixLen:])
|
||||
}
|
||||
|
||||
func (m wordMeta) getPrefix() string {
|
||||
if m.PrefixLen == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(m.EmptyWord[:m.PrefixLen])
|
||||
}
|
||||
|
||||
func makeWordMeta(wtype string, emptyWord string, prefixLen int, suffixLen int, canExtend bool, quoteContext bool) {
|
||||
if len(emptyWord) != prefixLen+suffixLen {
|
||||
panic(fmt.Sprintf("invalid empty word %s %d %d", emptyWord, prefixLen, suffixLen))
|
||||
}
|
||||
wordMetaMap[wtype] = wordMeta{wtype, []rune(emptyWord), prefixLen, suffixLen, canExtend, quoteContext}
|
||||
}
|
||||
|
||||
func init() {
|
||||
wordMetaMap = make(map[string]wordMeta)
|
||||
makeWordMeta(WordTypeRaw, "", 0, 0, false, false)
|
||||
makeWordMeta(WordTypeLit, "", 0, 0, true, false)
|
||||
makeWordMeta(WordTypeOp, "", 0, 0, false, false)
|
||||
makeWordMeta(WordTypeKey, "", 0, 0, false, false)
|
||||
makeWordMeta(WordTypeGroup, "", 0, 0, false, false)
|
||||
makeWordMeta(WordTypeSimpleVar, "$", 1, 0, true, false)
|
||||
makeWordMeta(WordTypeVarBrace, "${}", 2, 1, true, true)
|
||||
makeWordMeta(WordTypeDQ, `""`, 1, 1, true, true)
|
||||
makeWordMeta(WordTypeDDQ, `$""`, 2, 1, true, true)
|
||||
makeWordMeta(WordTypeDP, "$()", 2, 1, false, false)
|
||||
makeWordMeta(WordTypeBQ, "``", 1, 1, false, false)
|
||||
makeWordMeta(WordTypeSQ, "''", 1, 1, true, false)
|
||||
makeWordMeta(WordTypeDSQ, "$''", 2, 1, true, false)
|
||||
makeWordMeta(WordTypeDPP, "$(())", 3, 2, false, false)
|
||||
makeWordMeta(WordTypePP, "(())", 2, 2, false, false)
|
||||
makeWordMeta(WordTypeDB, "$[]", 2, 1, false, false)
|
||||
}
|
||||
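The wordMeta table drives MakeEmptyWord: a complete empty word carries both the prefix and suffix of its type, while an incomplete one carries only the prefix. A quick sketch:

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	complete := shparse.MakeEmptyWord(shparse.WordTypeDSQ, nil, 0, true)
	open := shparse.MakeEmptyWord(shparse.WordTypeDSQ, nil, 0, false)
	fmt.Printf("complete: %q\n", string(complete.Raw)) // "$''"
	fmt.Printf("open:     %q\n", string(open.Raw))     // "$'"
}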
|
||||
func MakeEmptyWord(wtype string, qc QuoteContext, offset int, complete bool) *WordType {
|
||||
meta := wordMetaMap[wtype]
|
||||
if meta.Type == "" {
|
||||
meta = wordMetaMap[WordTypeRaw]
|
||||
}
|
||||
rtn := &WordType{Type: meta.Type, QC: qc, Offset: offset, Complete: complete}
|
||||
if len(meta.EmptyWord) > 0 {
|
||||
if complete {
|
||||
rtn.Raw = append([]rune(nil), meta.EmptyWord...)
|
||||
} else {
|
||||
rtn.Raw = append([]rune(nil), []rune(meta.getPrefix())...)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (qc QuoteContext) push(q string) QuoteContext {
|
||||
rtn := make([]string, 0, len(qc)+1)
|
||||
rtn = append(rtn, qc...)
|
||||
rtn = append(rtn, q)
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (qc QuoteContext) cur() string {
|
||||
if len(qc) == 0 {
|
||||
return ""
|
||||
}
|
||||
return qc[len(qc)-1]
|
||||
}
|
||||
|
||||
func (qc QuoteContext) clone() QuoteContext {
|
||||
if len(qc) == 0 {
|
||||
return nil
|
||||
}
|
||||
return append([]string(nil), qc...)
|
||||
}
|
||||
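QuoteContext is what lets the extension code escape differently depending on the enclosing quotes: a literal extended inside a double-quoted context escapes $, `, \ and " instead of spaces and globs (see extendDQLit in extend.go above). A minimal sketch, with expected strings derived from those escaping rules:

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	bare := shparse.MakeEmptyWord(shparse.WordTypeLit, nil, 0, true)
	inDQ := shparse.MakeEmptyWord(shparse.WordTypeLit, shparse.QuoteContext{shparse.WordTypeDQ}, 0, true)

	fmt.Println(shparse.Extend(bare, 0, "a$b c", false).Str) // expected: a\$b\ c
	fmt.Println(shparse.Extend(inDQ, 0, "a$b c", false).Str) // expected: a\$b c
}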
|
||||
func makeRepeatStr(ch byte, slen int) string {
|
||||
if slen == 0 {
|
||||
return ""
|
||||
}
|
||||
rtn := make([]byte, slen)
|
||||
for i := 0; i < slen; i++ {
|
||||
rtn[i] = ch
|
||||
}
|
||||
return string(rtn)
|
||||
}
|
||||
|
||||
func (w *WordType) isBlank() bool {
|
||||
return w.Type == WordTypeLit && len(w.Raw) == 0
|
||||
}
|
||||
|
||||
func (w *WordType) contentEndPos() int {
|
||||
if !w.Complete {
|
||||
return len(w.Raw)
|
||||
}
|
||||
wmeta := wordMetaMap[w.Type]
|
||||
return len(w.Raw) - wmeta.SuffixLen
|
||||
}
|
||||
|
||||
func (w *WordType) contentStartPos() int {
|
||||
wmeta := wordMetaMap[w.Type]
|
||||
return wmeta.PrefixLen
|
||||
}
|
||||
|
||||
func (w *WordType) canHaveSubs() bool {
|
||||
switch w.Type {
|
||||
case WordTypeGroup, WordTypeDQ, WordTypeDDQ, WordTypeDP, WordTypeBQ:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WordType) uncompletable() bool {
|
||||
switch w.Type {
|
||||
case WordTypeRaw, WordTypeOp, WordTypeKey, WordTypeDPP, WordTypePP, WordTypeDB, WordTypeBQ, WordTypeDP:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WordType) stringWithPos(pos int) string {
|
||||
notCompleteFlag := " "
|
||||
if !w.Complete {
|
||||
notCompleteFlag = "*"
|
||||
}
|
||||
str := string(w.Raw)
|
||||
if pos != -1 {
|
||||
str = utilfn.StrWithPos{Str: str, Pos: pos}.String()
|
||||
}
|
||||
return fmt.Sprintf("%-4s[%3d]%s %s%q", w.Type, w.Offset, notCompleteFlag, makeRepeatStr('_', len(w.Prefix)), str)
|
||||
}
|
||||
|
||||
func (w *WordType) String() string {
|
||||
notCompleteFlag := " "
|
||||
if !w.Complete {
|
||||
notCompleteFlag = "*"
|
||||
}
|
||||
return fmt.Sprintf("%-4s[%3d]%s %s%q", w.Type, w.Offset, notCompleteFlag, makeRepeatStr('_', len(w.Prefix)), string(w.Raw))
|
||||
}
|
||||
|
||||
// offset = -1 for don't show
|
||||
func dumpWords(words []*WordType, indentStr string, offset int) {
|
||||
wrotePos := false
|
||||
for _, word := range words {
|
||||
posInWord := false
|
||||
if !wrotePos && offset != -1 && offset <= word.Offset {
|
||||
fmt.Printf("%s* [%3d] [*]\n", indentStr, offset)
|
||||
wrotePos = true
|
||||
}
|
||||
if !wrotePos && offset != -1 && offset < word.Offset+len(word.Raw) {
|
||||
fmt.Printf("%s%s\n", indentStr, word.stringWithPos(offset-word.Offset))
|
||||
wrotePos = true
|
||||
posInWord = true
|
||||
} else {
|
||||
fmt.Printf("%s%s\n", indentStr, word.String())
|
||||
}
|
||||
if len(word.Subs) > 0 {
|
||||
if posInWord {
|
||||
wmeta := wordMetaMap[word.Type]
|
||||
dumpWords(word.Subs, indentStr+" ", offset-word.Offset-wmeta.PrefixLen)
|
||||
} else {
|
||||
dumpWords(word.Subs, indentStr+" ", -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dumpCommands(cmds []*CmdType, indentStr string, pos int) {
|
||||
for _, cmd := range cmds {
|
||||
fmt.Printf("%sCMD: %s [%d] pos:%d\n", indentStr, cmd.Type, len(cmd.Words), pos)
|
||||
dumpWords(cmd.AssignmentWords, indentStr+" *", pos)
|
||||
dumpWords(cmd.Words, indentStr+" ", pos)
|
||||
}
|
||||
}
|
||||
|
||||
func wordsToStr(words []*WordType) string {
|
||||
var buf bytes.Buffer
|
||||
for _, word := range words {
|
||||
if len(word.Prefix) > 0 {
|
||||
buf.WriteString(string(word.Prefix))
|
||||
}
|
||||
buf.WriteString(string(word.Raw))
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// recognizes reserved words in first position
|
||||
func convertToAnyReservedWord(w *WordType) bool {
|
||||
if w == nil || w.Type != WordTypeLit {
|
||||
return false
|
||||
}
|
||||
rawVal := string(w.Raw)
|
||||
for _, rw := range bashReservedWords {
|
||||
if rawVal == rw {
|
||||
w.Type = WordTypeKey
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// recognizes the specific reserved-word given only ('in' and 'do' in 'for', 'case', and 'select' commands)
|
||||
func convertToReservedWord(w *WordType, reservedWord string) {
|
||||
if w == nil || w.Type != WordTypeLit {
|
||||
return
|
||||
}
|
||||
if string(w.Raw) == reservedWord {
|
||||
w.Type = WordTypeKey
|
||||
}
|
||||
}
|
||||
|
||||
func isNoneReservedWord(w *WordType) bool {
|
||||
if w.Type != WordTypeKey {
|
||||
return false
|
||||
}
|
||||
rawVal := string(w.Raw)
|
||||
for _, rw := range bashNoneRW {
|
||||
if rawVal == rw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type parseCmdState struct {
|
||||
Input []*WordType
|
||||
InputPos int
|
||||
|
||||
Rtn []*CmdType
|
||||
Cur *CmdType
|
||||
}
|
||||
|
||||
func (state *parseCmdState) isEof() bool {
|
||||
return state.InputPos >= len(state.Input)
|
||||
}
|
||||
|
||||
func (state *parseCmdState) curWord() *WordType {
|
||||
if state.isEof() {
|
||||
return nil
|
||||
}
|
||||
return state.Input[state.InputPos]
|
||||
}
|
||||
|
||||
func (state *parseCmdState) lastCmd() *CmdType {
|
||||
if len(state.Rtn) == 0 {
|
||||
return nil
|
||||
}
|
||||
return state.Rtn[len(state.Rtn)-1]
|
||||
}
|
||||
|
||||
func (state *parseCmdState) makeNoneCmd(sep bool) {
|
||||
if state.Cur == nil || state.Cur.Type != CmdTypeNone {
|
||||
state.Cur = &CmdType{Type: CmdTypeNone}
|
||||
state.Rtn = append(state.Rtn, state.Cur)
|
||||
}
|
||||
state.Cur.Words = append(state.Cur.Words, state.curWord())
|
||||
if sep {
|
||||
state.Cur.NoneComplete = true
|
||||
state.Cur = nil
|
||||
}
|
||||
state.InputPos++
|
||||
}
|
||||
|
||||
func (state *parseCmdState) handleKeyword(word *WordType) bool {
|
||||
if word.Type != WordTypeKey {
|
||||
return false
|
||||
}
|
||||
if isNoneReservedWord(word) {
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
rw := string(word.Raw)
|
||||
if rw == "[[" {
|
||||
// just ignore everything between [[ and ]]
|
||||
for !state.isEof() {
|
||||
curWord := state.curWord()
|
||||
if curWord.Type == WordTypeLit && string(curWord.Raw) == "]]" {
|
||||
convertToReservedWord(curWord, "]]")
|
||||
state.makeNoneCmd(false)
|
||||
break
|
||||
}
|
||||
state.makeNoneCmd(false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if rw == "case" {
|
||||
// ignore everything between "case" and "esac"
|
||||
for !state.isEof() {
|
||||
curWord := state.curWord()
|
||||
if curWord.Type == WordTypeKey && string(curWord.Raw) == "esac" {
|
||||
state.makeNoneCmd(false)
|
||||
break
|
||||
}
|
||||
state.makeNoneCmd(false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if rw == "for" || rw == "select" {
|
||||
// ignore until a "do"
|
||||
for !state.isEof() {
|
||||
curWord := state.curWord()
|
||||
if curWord.Type == WordTypeKey && string(curWord.Raw) == "do" {
|
||||
state.makeNoneCmd(true)
|
||||
break
|
||||
}
|
||||
state.makeNoneCmd(false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if rw == "in" {
|
||||
// the "for" and "case" clauses should skip "in". so encountering an "in" here is a syntax error.
|
||||
// just treat it as a none and allow a new command after.
|
||||
state.makeNoneCmd(false)
|
||||
return true
|
||||
}
|
||||
if rw == "function" {
|
||||
// ignore until '{'
|
||||
for !state.isEof() {
|
||||
curWord := state.curWord()
|
||||
if curWord.Type == WordTypeKey && string(curWord.Raw) == "{" {
|
||||
state.makeNoneCmd(true)
|
||||
break
|
||||
}
|
||||
state.makeNoneCmd(false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
|
||||
func isCmdSeparatorOp(word *WordType) bool {
|
||||
if word.Type != WordTypeOp {
|
||||
return false
|
||||
}
|
||||
opVal := string(word.Raw)
|
||||
return opVal == ";" || opVal == "\n" || opVal == "&" || opVal == "|" || opVal == "|&" || opVal == "&&" || opVal == "||" || opVal == "(" || opVal == ")"
|
||||
}
|
||||
|
||||
func (state *parseCmdState) handleOp(word *WordType) bool {
|
||||
opVal := string(word.Raw)
|
||||
// sequential separators
|
||||
if opVal == ";" || opVal == "\n" {
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
// separator
|
||||
if opVal == "&" {
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
// pipelines
|
||||
if opVal == "|" || opVal == "|&" {
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
// lists
|
||||
if opVal == "&&" || opVal == "||" {
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
// subshell
|
||||
if opVal == "(" || opVal == ")" {
|
||||
state.makeNoneCmd(true)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func wordSliceBoundedIdx(words []*WordType, idx int) *WordType {
|
||||
if idx >= len(words) {
|
||||
return nil
|
||||
}
|
||||
return words[idx]
|
||||
}
|
||||
|
||||
// note that a newline "op" can appear in the third position of "for" or "case". the "in" keyword is still converted because of wordNum == 0
|
||||
func identifyReservedWords(words []*WordType) {
|
||||
wordNum := 0
|
||||
lastReserved := false
|
||||
for idx, word := range words {
|
||||
if wordNum == 0 || lastReserved {
|
||||
convertToAnyReservedWord(word)
|
||||
}
|
||||
if word.Type == WordTypeKey {
|
||||
rwVal := string(word.Raw)
|
||||
switch rwVal {
|
||||
case "for":
|
||||
lastReserved = false
|
||||
third := wordSliceBoundedIdx(words, idx+2)
|
||||
convertToReservedWord(third, "in")
|
||||
convertToReservedWord(third, "do")
|
||||
|
||||
case "case":
|
||||
lastReserved = false
|
||||
third := wordSliceBoundedIdx(words, idx+2)
|
||||
convertToReservedWord(third, "in")
|
||||
|
||||
case "in":
|
||||
lastReserved = false
|
||||
|
||||
default:
|
||||
lastReserved = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastReserved = false
|
||||
if isCmdSeparatorOp(word) {
|
||||
wordNum = 0
|
||||
continue
|
||||
}
|
||||
wordNum++
|
||||
}
|
||||
}
|
||||
|
||||
func ResetWordOffsets(words []*WordType, startIdx int) {
|
||||
pos := startIdx
|
||||
for _, word := range words {
|
||||
pos += len(word.Prefix)
|
||||
word.Offset = pos
|
||||
if len(word.Subs) > 0 {
|
||||
ResetWordOffsets(word.Subs, 0)
|
||||
}
|
||||
pos += len(word.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func CommandsToWords(cmds []*CmdType) []*WordType {
|
||||
var rtn []*WordType
|
||||
for _, cmd := range cmds {
|
||||
rtn = append(rtn, cmd.Words...)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (c *CmdType) stripPrefix() []rune {
|
||||
if len(c.AssignmentWords) > 0 {
|
||||
w := c.AssignmentWords[0]
|
||||
prefix := w.Prefix
|
||||
if len(prefix) == 0 {
|
||||
return nil
|
||||
}
|
||||
newWord := *w
|
||||
newWord.Prefix = nil
|
||||
c.AssignmentWords[0] = &newWord
|
||||
return prefix
|
||||
}
|
||||
if len(c.Words) > 0 {
|
||||
w := c.Words[0]
|
||||
prefix := w.Prefix
|
||||
if len(prefix) == 0 {
|
||||
return nil
|
||||
}
|
||||
newWord := *w
|
||||
newWord.Prefix = nil
|
||||
c.Words[0] = &newWord
|
||||
return prefix
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CmdType) isEmpty() bool {
|
||||
return len(c.AssignmentWords) == 0 && len(c.Words) == 0
|
||||
}
|
||||
|
||||
func (c *CmdType) lastWord() *WordType {
|
||||
if len(c.Words) > 0 {
|
||||
return c.Words[len(c.Words)-1]
|
||||
}
|
||||
if len(c.AssignmentWords) > 0 {
|
||||
return c.AssignmentWords[len(c.AssignmentWords)-1]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CmdType) firstWord() *WordType {
|
||||
if len(c.AssignmentWords) > 0 {
|
||||
return c.AssignmentWords[0]
|
||||
}
|
||||
if len(c.Words) > 0 {
|
||||
return c.Words[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CmdType) offset() int {
|
||||
firstWord := c.firstWord()
|
||||
if firstWord == nil {
|
||||
return 0
|
||||
}
|
||||
return firstWord.Offset
|
||||
}
|
||||
|
||||
func (c *CmdType) endOffset() int {
|
||||
lastWord := c.lastWord()
|
||||
if lastWord == nil {
|
||||
return 0
|
||||
}
|
||||
return lastWord.Offset + len(lastWord.Raw)
|
||||
}
|
||||
|
||||
func indexInRunes(arr []rune, ch rune) int {
|
||||
for idx, r := range arr {
|
||||
if r == ch {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func isAssignmentWord(w *WordType) bool {
|
||||
if w.Type == WordTypeLit || w.Type == WordTypeGroup {
|
||||
eqIdx := indexInRunes(w.Raw, '=')
|
||||
if eqIdx == -1 {
|
||||
return false
|
||||
}
|
||||
prefix := w.Raw[0:eqIdx]
|
||||
return isSimpleVarName(prefix)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// simple commands steal whitespace from subsequent commands
|
||||
func cmdWhitespaceFixup(cmds []*CmdType) {
|
||||
for idx := 0; idx < len(cmds)-1; idx++ {
|
||||
cmd := cmds[idx]
|
||||
if cmd.Type != CmdTypeSimple || cmd.isEmpty() {
|
||||
continue
|
||||
}
|
||||
nextCmd := cmds[idx+1]
|
||||
nextPrefix := nextCmd.stripPrefix()
|
||||
if len(nextPrefix) > 0 {
|
||||
blankWord := &WordType{Type: WordTypeLit, QC: cmd.lastWord().QC, Offset: cmd.endOffset() + len(nextPrefix), Prefix: nextPrefix, Complete: true}
|
||||
cmd.Words = append(cmd.Words, blankWord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ParseCommands(words []*WordType) []*CmdType {
|
||||
identifyReservedWords(words)
|
||||
state := parseCmdState{Input: words}
|
||||
for {
|
||||
if state.isEof() {
|
||||
break
|
||||
}
|
||||
word := state.curWord()
|
||||
if word.Type == WordTypeKey {
|
||||
done := state.handleKeyword(word)
|
||||
if done {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if word.Type == WordTypeOp {
|
||||
done := state.handleOp(word)
|
||||
if done {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if state.Cur == nil || state.Cur.Type != CmdTypeSimple {
|
||||
state.Cur = &CmdType{Type: CmdTypeSimple}
|
||||
state.Rtn = append(state.Rtn, state.Cur)
|
||||
}
|
||||
if len(state.Cur.Words) == 0 && isAssignmentWord(word) {
|
||||
state.Cur.AssignmentWords = append(state.Cur.AssignmentWords, word)
|
||||
} else {
|
||||
state.Cur.Words = append(state.Cur.Words, word)
|
||||
}
|
||||
state.InputPos++
|
||||
}
|
||||
cmdWhitespaceFixup(state.Rtn)
|
||||
return state.Rtn
|
||||
}
|
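A usage sketch (Tokenize assumed from elsewhere in the package): separators such as && become their own CmdTypeNone commands, while leading NAME=value words are collected as assignment words of the following simple command.

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	words := shparse.Tokenize(`x=1 ls foo && echo done`)
	cmds := shparse.ParseCommands(words)
	for _, cmd := range cmds {
		fmt.Printf("type=%-6s assignments=%d words=%d\n",
			cmd.Type, len(cmd.AssignmentWords), len(cmd.Words))
	}
}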
219
wavesrv/pkg/shparse/shparse_test.go
Normal file
@ -0,0 +1,219 @@
|
||||
package shparse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/utilfn"
|
||||
)
|
||||
|
||||
// $(ls f[*]); ./x
|
||||
// ls f => raw["ls f"] -> lit["ls f"] -> lit["ls"] lit["f"]
|
||||
// w; ls foo; => raw["w; ls foo;"]
|
||||
// ls&"ls" => raw["ls&ls"] => lit["ls&"] dq["ls"] => lit["ls"] key["&"] dq["ls"]
|
||||
// ls $x; echo `ls f => raw["ls $x; echo `ls f"]
|
||||
// > echo $foo{x,y}
|
||||
|
||||
func testParse(t *testing.T, s string) {
|
||||
words := Tokenize(s)
|
||||
|
||||
fmt.Printf("parse <<\n%s\n>>\n", s)
|
||||
dumpWords(words, " ", 8)
|
||||
outStr := wordsToStr(words)
|
||||
if outStr != s {
|
||||
t.Errorf("tokenization output does not match input: %q => %q", s, outStr)
|
||||
}
|
||||
fmt.Printf("------\n\n")
|
||||
}
|
||||
|
||||
func Test1(t *testing.T) {
|
||||
testParse(t, "ls")
|
||||
testParse(t, "ls 'foo'")
|
||||
testParse(t, `ls "hello" $'\''`)
|
||||
testParse(t, `ls "foo`)
|
||||
testParse(t, `echo $11 $xyz $ `)
|
||||
testParse(t, `echo $(ls ${x:"hello"} foo`)
|
||||
testParse(t, `ls ${x:"hello"} $[2+2] $((5 * 10)) $(ls; ls&)`)
|
||||
testParse(t, `ls;ls&./foo > out 2> "out2"`)
|
||||
testParse(t, `(( x = 5)); ls& cd ~/work/"hello again"`)
|
||||
testParse(t, `echo "hello"abc$(ls)$x${y:foo}`)
|
||||
testParse(t, `echo $(ls; ./x "foo")`)
|
||||
testParse(t, `echo $(ls; (cd foo; ls); (cd bar; ls))xyz`)
|
||||
testParse(t, `echo "$x ${y:-foo}"`)
|
||||
testParse(t, `command="$(echo "$input" | sed -e "s/^[ \t]*\([^ \t]*\)[ \t]*.*$/\1/g")"`)
|
||||
testParse(t, `echo $(ls $)`)
|
||||
testParse(t, `echo ${x:-hello\}"}"} 2nd`)
|
||||
testParse(t, `echo "$(ls "foo") more $x"`)
|
||||
testParse(t, "echo `ls $x \"hello $x\" \\`ls\\`; ./foo`")
|
||||
testParse(t, `echo $"hello $x $(ls)"`)
|
||||
testParse(t, "echo 'hello'\nls\n")
|
||||
testParse(t, "echo 'hello'abc$'\a'")
|
||||
}
|
||||
|
||||
func lastWord(words []*WordType) *WordType {
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
return words[len(words)-1]
|
||||
}
|
||||
|
||||
func testExtend(t *testing.T, startStr string, extendStr string, complete bool, expStr string) {
|
||||
startSP := utilfn.ParseToSP(startStr)
|
||||
words := Tokenize(startSP.Str)
|
||||
word := findCompletionWordAtPos(words, startSP.Pos, true)
|
||||
if word == nil {
|
||||
word = MakeEmptyWord(WordTypeLit, nil, startSP.Pos, true)
|
||||
}
|
||||
outSP := Extend(word, startSP.Pos-word.Offset, extendStr, complete)
|
||||
expSP := utilfn.ParseToSP(expStr)
|
||||
fmt.Printf("extend: [%s] + %q => [%s]\n", startStr, extendStr, outSP)
|
||||
if outSP != expSP {
|
||||
t.Errorf("extension does not match: [%s] + %q => [%s] expected [%s]\n", startStr, extendStr, outSP, expSP)
|
||||
}
|
||||
}
|
||||
|
||||
func Test2(t *testing.T) {
|
||||
testExtend(t, `he[*]`, "llo", false, "hello[*]")
|
||||
testExtend(t, `he[*]`, "llo", true, "hello [*]")
|
||||
testExtend(t, `'mi[*]e`, "k", false, "'mik[*]e")
|
||||
testExtend(t, `'mi[*]e`, "k", true, "'mik[*]e")
|
||||
testExtend(t, `'mi[*]'`, "ke", true, "'mike' [*]")
|
||||
testExtend(t, `'mi'[*]`, "ke", true, "'mike' [*]")
|
||||
testExtend(t, `'mi[*]'`, "ke", false, "'mike[*]'")
|
||||
testExtend(t, `'mi'[*]`, "ke", false, "'mike[*]'")
|
||||
testExtend(t, `$f[*]`, "oo", false, "$foo[*]")
|
||||
testExtend(t, `${f}[*]`, "oo", false, "${foo[*]}")
|
||||
testExtend(t, `${f[*]}`, "oo", true, "${foo} [*]")
|
||||
testExtend(t, `[*]`, "more stuff", false, `more\ stuff[*]`)
|
||||
testExtend(t, `[*]`, "hello\amike", false, `hello$'\a'mike[*]`)
|
||||
testExtend(t, `$'he[*]'`, "\x01\x02\x0a", true, `$'he\x01\x02\n' [*]`)
|
||||
testExtend(t, `${x}\ [*]ll$y`, "e", false, `${x}\ e[*]ll$y`)
|
||||
testExtend(t, `"he[*]"`, "$$o", true, `"he\$\$o" [*]`)
|
||||
testExtend(t, `"h[*]llo"`, "e", false, `"he[*]llo"`)
|
||||
testExtend(t, `"h[*]llo"`, "e", true, `"he[*]llo"`)
|
||||
testExtend(t, `"[*]${h}llo"`, "e\x01", true, `"e"$'\x01'[*]"${h}llo"`)
|
||||
testExtend(t, `"${h}llo[*]"`, "e\x01", true, `"${h}lloe"$'\x01' [*]`)
|
||||
testExtend(t, `"${h}llo[*]"`, "e\x01", false, `"${h}lloe"$'\x01'[*]`)
|
||||
testExtend(t, `"${h}ll[*]o"`, "e\x01", false, `"${h}lle"$'\x01'[*]"o"`)
|
||||
testExtend(t, `"ab[*]c${x}def"`, "\x01", false, `"ab"$'\x01'[*]"c${x}def"`)
|
||||
testExtend(t, `'ab[*]ef'`, "\x01", false, `'ab'$'\x01'[*]'ef'`)
|
||||
|
||||
// testExtend(t, `'he'`, "llo", `'hello'`)
|
||||
// testExtend(t, `'he'`, "'", `'he'\'''`)
|
||||
// testExtend(t, `'he'`, "'\x01", `'he'\'$'\x01'''`)
|
||||
// testExtend(t, `he`, "llo", `hello`)
|
||||
// testExtend(t, `he`, "l*l'\x01\x07o", `hel\*l\'$'\x01'$'\a'o`)
|
||||
// testExtend(t, `$x`, "fo|o", `$xfoo`)
|
||||
// testExtend(t, `${x`, "fo|o", `${xfoo`)
|
||||
// testExtend(t, `$'f`, "oo", `$'foo`)
|
||||
// testExtend(t, `$'f`, "'\x01\x07o", `$'f\'\x01\ao`)
|
||||
// testExtend(t, `"f"`, "oo", `"foo"`)
|
||||
// testExtend(t, `"mi"`, "ke's \"hello\"", `"mike's \"hello\""`)
|
||||
// testExtend(t, `"t"`, "t\x01\x07", `"tt"$'\x01'$'\a'""`)
|
||||
}
|
||||
|
||||
func testParseCommands(t *testing.T, str string) {
|
||||
fmt.Printf("parse: %q\n", str)
|
||||
words := Tokenize(str)
|
||||
cmds := ParseCommands(words)
|
||||
dumpCommands(cmds, " ", -1)
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
|
||||
func TestCmd(t *testing.T) {
|
||||
testParseCommands(t, "ls foo")
|
||||
testParseCommands(t, "function foo () { echo hello; }")
|
||||
testParseCommands(t, "ls foo && ls bar; ./run $x hello | xargs foo; ")
|
||||
testParseCommands(t, "if [[ 2 > 1 ]]; then echo hello\nelse echo world; echo next; done")
|
||||
testParseCommands(t, "case lots of stuff; i don\\'t know how to parse; esac; ls foo")
|
||||
testParseCommands(t, "(ls & ./x \n \n); for x in $vars 3; do { echo $x; ls foo ; } done")
|
||||
testParseCommands(t, `ls f"oo" "${x:"hello$y"}"`)
|
||||
testParseCommands(t, `x="foo $y" z=10 ls`)
|
||||
}
|
||||
|
||||
func testCompPos(t *testing.T, cmdStr string, compType string, hasCommand bool, cmdWordPos int, hasWord bool, superOffset int) {
|
||||
cmdSP := utilfn.ParseToSP(cmdStr)
|
||||
words := Tokenize(cmdSP.Str)
|
||||
cmds := ParseCommands(words)
|
||||
cpos := FindCompletionPos(cmds, cmdSP.Pos)
|
||||
fmt.Printf("testCompPos [%d] %q => [%s] %v\n", cmdSP.Pos, cmdStr, cpos.CompType, cpos)
|
||||
if cpos.CompType != compType {
|
||||
t.Errorf("testCompPos %q => invalid comp-type %q, expected %q", cmdStr, cpos.CompType, compType)
|
||||
}
|
||||
if cpos.CompWord != nil {
|
||||
fmt.Printf(" found-word: %d %s\n", cpos.CompWordOffset, cpos.CompWord.stringWithPos(cpos.CompWordOffset))
|
||||
}
|
||||
if cpos.Cmd != nil {
|
||||
fmt.Printf(" found-cmd: ")
|
||||
dumpCommands([]*CmdType{cpos.Cmd}, " ", cpos.RawPos)
|
||||
}
|
||||
dumpCommands(cmds, " ", cmdSP.Pos)
|
||||
fmt.Printf("\n")
|
||||
if cpos.RawPos+cpos.SuperOffset != cmdSP.Pos {
|
||||
t.Errorf("testCompPos %q => bad rawpos:%d superoffset:%d expected:%d", cmdStr, cpos.RawPos, cpos.SuperOffset, cmdSP.Pos)
|
||||
}
|
||||
if (cpos.Cmd != nil) != hasCommand {
|
||||
t.Errorf("testCompPos %q => bad has-command exp:%v", cmdStr, hasCommand)
|
||||
}
|
||||
if (cpos.CompWord != nil) != hasWord {
|
||||
t.Errorf("testCompPos %q => bad has-word exp:%v", cmdStr, hasWord)
|
||||
}
|
||||
if cpos.CmdWordPos != cmdWordPos {
|
||||
t.Errorf("testCompPos %q => bad cmd-word-pos got:%d exp:%d", cmdStr, cpos.CmdWordPos, cmdWordPos)
|
||||
}
|
||||
if cpos.SuperOffset != superOffset {
|
||||
t.Errorf("testCompPos %q => bad super-offset got:%d exp:%d", cmdStr, cpos.SuperOffset, superOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompPos(t *testing.T) {
|
||||
testCompPos(t, "ls [*]foo", CompTypeArg, true, 1, false, 0)
|
||||
testCompPos(t, "ls foo [*];", CompTypeArg, true, 2, false, 0)
|
||||
testCompPos(t, "ls foo ;[*]", CompTypeCommand, false, 0, false, 0)
|
||||
testCompPos(t, "ls foo >[*]> ./bar", CompTypeInvalid, true, 2, true, 0)
|
||||
testCompPos(t, "l[*]s", CompTypeCommand, true, 0, true, 0)
|
||||
testCompPos(t, "ls[*]", CompTypeCommand, true, 0, true, 0)
|
||||
testCompPos(t, "x=10 { (ls ./f[*] more); ls }", CompTypeArg, true, 1, true, 0)
|
||||
testCompPos(t, "for x in 1[*] 2 3; do ", CompTypeBasic, false, 0, true, 0)
|
||||
testCompPos(t, "for[*] x in 1 2 3;", CompTypeInvalid, false, 0, true, 0)
|
||||
testCompPos(t, `ls "abc $(ls -l t[*])" && foo`, CompTypeArg, true, 2, true, 10)
|
||||
testCompPos(t, "ls ${abc:$(ls -l [*])}", CompTypeVar, false, 0, true, 0) // we don't sub-parse inside of ${} (so this returns "var" right now)
|
||||
testCompPos(t, `ls abc"$(ls $"echo $(ls ./[*]x) foo)" `, CompTypeArg, true, 1, true, 21)
|
||||
testCompPos(t, `ls "abc$d[*]"`, CompTypeVar, false, 0, true, 4)
|
||||
testCompPos(t, `ls "abc$d$'a[*]`, CompTypeArg, true, 1, true, 0)
|
||||
testCompPos(t, `ls $[*]'foo`, CompTypeArg, true, 1, true, 0)
|
||||
testCompPos(t, `echo $TE[*]`, CompTypeVar, false, 0, true, 0)
|
||||
}
|
||||
|
||||
func testExpand(t *testing.T, str string, pos int, expStr string, expInfo *ExpandInfo) {
|
||||
ectx := ExpandContext{HomeDir: "/Users/mike"}
|
||||
words := Tokenize(str)
|
||||
if len(words) == 0 {
|
||||
t.Errorf("could not tokenize any words from %q", str)
|
||||
return
|
||||
}
|
||||
word := words[0]
|
||||
output, info := SimpleExpandPrefix(ectx, word, pos)
|
||||
if output != expStr {
|
||||
t.Errorf("error expanding %q, output:%q exp:%q", str, output, expStr)
|
||||
} else {
|
||||
fmt.Printf("expand: %q (%d) => %q\n", str, pos, output)
|
||||
}
|
||||
if expInfo != nil {
|
||||
if info != *expInfo {
|
||||
t.Errorf("error expanding %q, info:%v exp:%v", str, info, expInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpand(t *testing.T) {
|
||||
testExpand(t, "hello", 3, "hel", nil)
|
||||
testExpand(t, "he\\$xabc", 6, "he$xa", nil)
|
||||
testExpand(t, "he${x}abc", 6, "he${x}", nil)
|
||||
testExpand(t, "'hello\"mike'", 8, "hello\"m", nil)
|
||||
testExpand(t, `$'abc\x01def`, 10, "abc\x01d", nil)
|
||||
testExpand(t, `$((2 + 2))`, 6, "$((2 +", &ExpandInfo{HasSpecial: true})
|
||||
testExpand(t, `abc"def"`, 6, "abcde", nil)
|
||||
testExpand(t, `"abc$x$'"'""`, 12, "abc$x\"", nil)
|
||||
testExpand(t, `'he'\''s'`, 9, "he's", nil)
|
||||
}
|
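For orientation, a hedged sketch of the call chain the completion tests above exercise; it assumes the same shparse package and fmt import, and uses only the exported helpers seen in testCompPos (the input string and cursor offset are illustrative):

// sketch only: mirrors the lookup done in testCompPos above
func sketchCompletionLookup() {
	cmdStr := `ls fo`        // illustrative command line
	cursorPos := len(cmdStr) // cursor at the end, as a rune offset

	words := Tokenize(cmdStr)
	cmds := ParseCommands(words)
	cpos := FindCompletionPos(cmds, cursorPos)

	fmt.Printf("comp-type=%s cmd-word-pos=%d super-offset=%d\n", cpos.CompType, cpos.CmdWordPos, cpos.SuperOffset)
	if cpos.CompWord != nil {
		fmt.Printf("completing inside a word, offset=%d\n", cpos.CompWordOffset)
	}
}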
601
wavesrv/pkg/shparse/tokenize.go
Normal file
@ -0,0 +1,601 @@
|
||||
package shparse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// from bash source
|
||||
//
|
||||
// shell_meta_chars "()<>;&|"
|
||||
//
|
||||
|
||||
type tokenizeOutputState struct {
|
||||
Rtn []*WordType
|
||||
CurWord *WordType
|
||||
SavedPrefix []rune
|
||||
}
|
||||
|
||||
func copyRunes(rarr []rune) []rune {
|
||||
if len(rarr) == 0 {
|
||||
return nil
|
||||
}
|
||||
return append([]rune(nil), rarr...)
|
||||
}
|
||||
|
||||
// does not set CurWord
|
||||
func (state *tokenizeOutputState) appendStandaloneWord(word *WordType) {
|
||||
state.delimitCurWord()
|
||||
if len(state.SavedPrefix) > 0 {
|
||||
word.Prefix = state.SavedPrefix
|
||||
state.SavedPrefix = nil
|
||||
}
|
||||
state.Rtn = append(state.Rtn, word)
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) appendWord(word *WordType) {
|
||||
if len(state.SavedPrefix) > 0 {
|
||||
word.Prefix = state.SavedPrefix
|
||||
state.SavedPrefix = nil
|
||||
}
|
||||
if state.CurWord == nil {
|
||||
state.CurWord = word
|
||||
return
|
||||
}
|
||||
state.ensureGroupWord()
|
||||
word.Offset = word.Offset - state.CurWord.Offset
|
||||
state.CurWord.Subs = append(state.CurWord.Subs, word)
|
||||
state.CurWord.Raw = append(state.CurWord.Raw, word.Raw...)
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) ensureGroupWord() {
|
||||
if state.CurWord == nil {
|
||||
panic("invalid state, cannot make group word when CurWord is nil")
|
||||
}
|
||||
if state.CurWord.Type == WordTypeGroup {
|
||||
return
|
||||
}
|
||||
// moves the prefix from CurWord to the new group word, resets offsets
|
||||
groupWord := &WordType{
|
||||
Type: WordTypeGroup,
|
||||
Offset: state.CurWord.Offset,
|
||||
QC: state.CurWord.QC,
|
||||
Raw: copyRunes(state.CurWord.Raw),
|
||||
Complete: true,
|
||||
Prefix: state.CurWord.Prefix,
|
||||
}
|
||||
state.CurWord.Prefix = nil
|
||||
state.CurWord.Offset = 0
|
||||
groupWord.Subs = []*WordType{state.CurWord}
|
||||
state.CurWord = groupWord
|
||||
}
|
||||
|
||||
func ungroupWord(groupWord *WordType) []*WordType {
|
||||
if groupWord.Type != WordTypeGroup {
|
||||
return []*WordType{groupWord}
|
||||
}
|
||||
rtn := groupWord.Subs
|
||||
if len(groupWord.Prefix) > 0 && len(rtn) > 0 {
|
||||
newPrefix := append([]rune{}, groupWord.Prefix...)
|
||||
newPrefix = append(newPrefix, rtn[0].Prefix...)
|
||||
rtn[0].Prefix = newPrefix
|
||||
}
|
||||
for _, word := range rtn {
|
||||
word.Offset = word.Offset + groupWord.Offset
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) ensureLitCurWord(pc *parseContext) {
|
||||
if state.CurWord == nil {
|
||||
state.CurWord = pc.makeWord(WordTypeLit, 0, true)
|
||||
state.CurWord.Prefix = state.SavedPrefix
|
||||
state.SavedPrefix = nil
|
||||
return
|
||||
}
|
||||
if state.CurWord.Type == WordTypeLit {
|
||||
return
|
||||
}
|
||||
state.ensureGroupWord()
|
||||
lastWord := state.CurWord.Subs[len(state.CurWord.Subs)-1]
|
||||
if lastWord.Type != WordTypeLit {
|
||||
if len(state.SavedPrefix) > 0 {
|
||||
panic("invalid state, there can be no saved prefix")
|
||||
}
|
||||
litWord := pc.makeWord(WordTypeLit, 0, true)
|
||||
litWord.Offset = litWord.Offset - state.CurWord.Offset
|
||||
state.CurWord.Subs = append(state.CurWord.Subs, litWord)
|
||||
}
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) delimitCurWord() {
|
||||
if state.CurWord != nil {
|
||||
state.Rtn = append(state.Rtn, state.CurWord)
|
||||
state.CurWord = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) delimitWithSpace(spaceCh rune) {
|
||||
state.delimitCurWord()
|
||||
state.SavedPrefix = append(state.SavedPrefix, spaceCh)
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) appendLiteral(pc *parseContext, ch rune) {
|
||||
state.ensureLitCurWord(pc)
|
||||
if state.CurWord.Type == WordTypeLit {
|
||||
state.CurWord.Raw = append(state.CurWord.Raw, ch)
|
||||
} else if state.CurWord.Type == WordTypeGroup {
|
||||
lastWord := state.CurWord.Subs[len(state.CurWord.Subs)-1]
|
||||
if lastWord.Type != WordTypeLit {
|
||||
panic(fmt.Sprintf("invalid curword type (group) %q", state.CurWord.Type))
|
||||
}
|
||||
lastWord.Raw = append(lastWord.Raw, ch)
|
||||
state.CurWord.Raw = append(state.CurWord.Raw, ch)
|
||||
} else {
|
||||
panic(fmt.Sprintf("invalid curword type %q", state.CurWord.Type))
|
||||
}
|
||||
}
|
||||
|
||||
func (state *tokenizeOutputState) finish(pc *parseContext) {
|
||||
state.delimitCurWord()
|
||||
if len(state.SavedPrefix) > 0 {
|
||||
state.ensureLitCurWord(pc)
|
||||
state.delimitCurWord()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *parseContext) tokenizeVarBrace() ([]*WordType, bool) {
|
||||
state := &tokenizeOutputState{}
|
||||
eofExit := false
|
||||
for {
|
||||
ch := c.cur()
|
||||
if ch == 0 {
|
||||
eofExit = true
|
||||
break
|
||||
}
|
||||
if ch == '}' {
|
||||
c.Pos++
|
||||
break
|
||||
}
|
||||
var quoteWord *WordType
|
||||
if ch == '\'' {
|
||||
quoteWord = c.parseStrSQ()
|
||||
}
|
||||
if quoteWord == nil && ch == '"' {
|
||||
quoteWord = c.parseStrDQ()
|
||||
}
|
||||
isNextBrace := c.at(1) == '}'
|
||||
if quoteWord == nil && ch == '$' && !isNextBrace {
|
||||
quoteWord = c.parseStrANSI()
|
||||
if quoteWord == nil {
|
||||
quoteWord = c.parseStrDDQ()
|
||||
}
|
||||
if quoteWord == nil {
|
||||
quoteWord = c.parseExpansion()
|
||||
}
|
||||
}
|
||||
if quoteWord != nil {
|
||||
state.appendWord(quoteWord)
|
||||
continue
|
||||
}
|
||||
if ch == '\\' && c.at(1) != 0 {
|
||||
state.appendLiteral(c, ch)
|
||||
state.appendLiteral(c, c.at(1))
|
||||
c.Pos += 2
|
||||
continue
|
||||
}
|
||||
state.appendLiteral(c, ch)
|
||||
c.Pos++
|
||||
}
|
||||
return state.Rtn, eofExit
|
||||
}
|
||||
|
||||
func (c *parseContext) tokenizeDQ() ([]*WordType, bool) {
|
||||
state := &tokenizeOutputState{}
|
||||
eofExit := false
|
||||
for {
|
||||
ch := c.cur()
|
||||
if ch == 0 {
|
||||
eofExit = true
|
||||
break
|
||||
}
|
||||
if ch == '"' {
|
||||
c.Pos++
|
||||
break
|
||||
}
|
||||
if ch == '$' && c.at(1) != 0 {
|
||||
quoteWord := c.parseStrANSI()
|
||||
if quoteWord == nil {
|
||||
quoteWord = c.parseStrDDQ()
|
||||
}
|
||||
if quoteWord == nil {
|
||||
quoteWord = c.parseExpansion()
|
||||
}
|
||||
if quoteWord != nil {
|
||||
state.appendWord(quoteWord)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ch == '\\' && c.at(1) != 0 {
|
||||
state.appendLiteral(c, ch)
|
||||
state.appendLiteral(c, c.at(1))
|
||||
c.Pos += 2
|
||||
continue
|
||||
}
|
||||
state.appendLiteral(c, ch)
|
||||
c.Pos++
|
||||
}
|
||||
state.finish(c)
|
||||
if len(state.Rtn) == 0 {
|
||||
return nil, eofExit
|
||||
}
|
||||
if len(state.Rtn) == 1 && state.Rtn[0].Type == WordTypeGroup {
|
||||
return ungroupWord(state.Rtn[0]), eofExit
|
||||
}
|
||||
return state.Rtn, eofExit
|
||||
}
|
||||
|
||||
// returns (words, eofexit)
|
||||
// backticks (WordTypeBQ) handle backslash in a special way, but that seems to mainly affect execution (not completion)
|
||||
// de_backslash => removes initial backslash in \`, \\, and \$ before execution
|
||||
func (c *parseContext) tokenizeRaw() ([]*WordType, bool) {
|
||||
state := &tokenizeOutputState{}
|
||||
isExpSubShell := c.QC.cur() == WordTypeDP
|
||||
isInBQ := c.QC.cur() == WordTypeBQ
|
||||
parenLevel := 0
|
||||
eofExit := false
|
||||
for {
|
||||
ch := c.cur()
|
||||
if ch == 0 {
|
||||
eofExit = true
|
||||
break
|
||||
}
|
||||
if isExpSubShell && ch == ')' && parenLevel == 0 {
|
||||
c.Pos++
|
||||
break
|
||||
}
|
||||
if isInBQ && ch == '`' {
|
||||
c.Pos++
|
||||
break
|
||||
}
|
||||
// fmt.Printf("ch %d %q\n", c.Pos, string([]rune{ch}))
|
||||
foundOp, newOffset := c.parseOp(0)
|
||||
if foundOp {
|
||||
opVal := string(c.Input[c.Pos : c.Pos+newOffset])
|
||||
if opVal == "(" {
|
||||
arithWord := c.parseArith(true)
|
||||
if arithWord != nil {
|
||||
state.appendStandaloneWord(arithWord)
|
||||
continue
|
||||
} else {
|
||||
parenLevel++
|
||||
}
|
||||
}
|
||||
if opVal == ")" {
|
||||
parenLevel--
|
||||
}
|
||||
opWord := c.makeWord(WordTypeOp, newOffset, true)
|
||||
state.appendStandaloneWord(opWord)
|
||||
continue
|
||||
}
|
||||
var quoteWord *WordType
|
||||
if ch == '\'' {
|
||||
quoteWord = c.parseStrSQ()
|
||||
}
|
||||
if quoteWord == nil && ch == '"' {
|
||||
quoteWord = c.parseStrDQ()
|
||||
}
|
||||
if quoteWord == nil && ch == '`' {
|
||||
quoteWord = c.parseStrBQ()
|
||||
}
|
||||
isNextParen := isExpSubShell && c.at(1) == ')'
|
||||
if quoteWord == nil && ch == '$' && !isNextParen {
|
||||
quoteWord = c.parseStrANSI()
|
||||
if quoteWord == nil {
|
||||
quoteWord = c.parseStrDDQ()
|
||||
}
|
||||
if quoteWord == nil {
|
||||
quoteWord = c.parseExpansion()
|
||||
}
|
||||
}
|
||||
if quoteWord != nil {
|
||||
state.appendWord(quoteWord)
|
||||
continue
|
||||
}
|
||||
if ch == '\\' && c.at(1) != 0 {
|
||||
state.appendLiteral(c, ch)
|
||||
state.appendLiteral(c, c.at(1))
|
||||
c.Pos += 2
|
||||
continue
|
||||
}
|
||||
if ch == '\n' {
|
||||
newlineWord := c.makeWord(WordTypeOp, 1, true)
|
||||
state.appendStandaloneWord(newlineWord)
|
||||
continue
|
||||
}
|
||||
if unicode.IsSpace(ch) {
|
||||
state.delimitWithSpace(ch)
|
||||
c.Pos++
|
||||
continue
|
||||
}
|
||||
state.appendLiteral(c, ch)
|
||||
c.Pos++
|
||||
}
|
||||
state.finish(c)
|
||||
return state.Rtn, eofExit
|
||||
}
|
||||
|
||||
type parseContext struct {
|
||||
Input []rune
|
||||
Pos int
|
||||
QC QuoteContext
|
||||
}
|
||||
|
||||
func (c *parseContext) clone(pos int, newQuote string) *parseContext {
|
||||
rtn := parseContext{Input: c.Input[pos:], QC: c.QC}
|
||||
if newQuote != "" {
|
||||
rtn.QC = rtn.QC.push(newQuote)
|
||||
}
|
||||
return &rtn
|
||||
}
|
||||
|
||||
func (c *parseContext) at(offset int) rune {
|
||||
pos := c.Pos + offset
|
||||
if pos < 0 || pos >= len(c.Input) {
|
||||
return 0
|
||||
}
|
||||
return c.Input[pos]
|
||||
}
|
||||
|
||||
func (c *parseContext) eof() bool {
|
||||
return c.Pos >= len(c.Input)
|
||||
}
|
||||
|
||||
func (c *parseContext) cur() rune {
|
||||
return c.at(0)
|
||||
}
|
||||
|
||||
func (c *parseContext) match(ch rune) bool {
|
||||
return c.at(0) == ch
|
||||
}
|
||||
|
||||
func (c *parseContext) match2(ch rune, ch2 rune) bool {
|
||||
return c.at(0) == ch && c.at(1) == ch2
|
||||
}
|
||||
|
||||
func (c *parseContext) match3(ch rune, ch2 rune, ch3 rune) bool {
|
||||
return c.at(0) == ch && c.at(1) == ch2 && c.at(2) == ch3
|
||||
}
|
||||
|
||||
func (c *parseContext) makeWord(t string, length int, complete bool) *WordType {
|
||||
rtn := &WordType{Type: t}
|
||||
rtn.Offset = c.Pos
|
||||
rtn.QC = c.QC
|
||||
rtn.Raw = copyRunes(c.Input[c.Pos : c.Pos+length])
|
||||
rtn.Complete = complete
|
||||
c.Pos += length
|
||||
return rtn
|
||||
}
|
||||
|
||||
// returns (found, newOffset)
|
||||
// shell_meta_chars "()<>;&|"
|
||||
// possible to maybe add ;;& &>> &> |& ;&
|
||||
func (c *parseContext) parseOp(offset int) (bool, int) {
|
||||
ch := c.at(offset)
|
||||
if ch == '(' || ch == ')' || ch == '<' || ch == '>' || ch == ';' || ch == '&' || ch == '|' {
|
||||
ch2 := c.at(offset + 1)
|
||||
if ch2 == 0 {
|
||||
return true, offset + 1
|
||||
}
|
||||
r2 := string([]rune{ch, ch2})
|
||||
if r2 == "<<" {
|
||||
ch3 := c.at(offset + 2)
|
||||
if ch3 == '-' || ch3 == '<' {
|
||||
return true, offset + 3 // "<<-" or "<<<"
|
||||
}
|
||||
return true, offset + 2 // "<<"
|
||||
}
|
||||
if r2 == ">>" || r2 == "&&" || r2 == "||" || r2 == ";;" || r2 == "<<" || r2 == "<&" || r2 == ">&" || r2 == "<>" || r2 == ">|" {
|
||||
// we don't return '((' here (requires special processing)
|
||||
return true, offset + 2
|
||||
}
|
||||
return true, offset + 1
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// returns (new-offset, complete)
|
||||
func (c *parseContext) skipToChar(offset int, endCh rune, allowEsc bool) (int, bool) {
|
||||
for {
|
||||
ch := c.at(offset)
|
||||
if ch == 0 {
|
||||
return offset, false
|
||||
}
|
||||
if allowEsc && ch == '\\' {
|
||||
if c.at(offset+1) == 0 {
|
||||
return offset + 1, false
|
||||
}
|
||||
offset += 2
|
||||
continue
|
||||
}
|
||||
if ch == endCh {
|
||||
return offset + 1, true
|
||||
}
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
// returns (new-offset, complete)
|
||||
func (c *parseContext) skipToChar2(offset int, endCh rune, endCh2 rune, allowEsc bool) (int, bool) {
|
||||
for {
|
||||
ch := c.at(offset)
|
||||
ch2 := c.at(offset + 1)
|
||||
if ch == 0 {
|
||||
return offset, false
|
||||
}
|
||||
if ch2 == 0 {
|
||||
return offset + 1, false
|
||||
}
|
||||
if allowEsc && ch == '\\' {
|
||||
offset += 2
|
||||
continue
|
||||
}
|
||||
if ch == endCh && ch2 == endCh2 {
|
||||
return offset + 2, true
|
||||
}
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *parseContext) parseStrSQ() *WordType {
|
||||
if !c.match('\'') {
|
||||
return nil
|
||||
}
|
||||
newOffset, complete := c.skipToChar(1, '\'', false)
|
||||
w := c.makeWord(WordTypeSQ, newOffset, complete)
|
||||
return w
|
||||
}
|
||||
|
||||
func (c *parseContext) parseStrDQ() *WordType {
|
||||
if !c.match('"') {
|
||||
return nil
|
||||
}
|
||||
newContext := c.clone(c.Pos+1, WordTypeDQ)
|
||||
subWords, eofExit := newContext.tokenizeDQ()
|
||||
newOffset := newContext.Pos + 1
|
||||
w := c.makeWord(WordTypeDQ, newOffset, !eofExit)
|
||||
w.Subs = subWords
|
||||
return w
|
||||
}
|
||||
|
||||
func (c *parseContext) parseStrDDQ() *WordType {
|
||||
if !c.match2('$', '"') {
|
||||
return nil
|
||||
}
|
||||
newContext := c.clone(c.Pos+2, WordTypeDQ) // use WordTypeDQ (not DDQ)
|
||||
subWords, eofExit := newContext.tokenizeDQ()
|
||||
newOffset := newContext.Pos + 2
|
||||
w := c.makeWord(WordTypeDDQ, newOffset, !eofExit)
|
||||
w.Subs = subWords
|
||||
return w
|
||||
}
|
||||
|
||||
func (c *parseContext) parseStrBQ() *WordType {
|
||||
if !c.match('`') {
|
||||
return nil
|
||||
}
|
||||
newContext := c.clone(c.Pos+1, WordTypeBQ)
|
||||
subWords, eofExit := newContext.tokenizeRaw()
|
||||
newOffset := newContext.Pos + 1
|
||||
w := c.makeWord(WordTypeBQ, newOffset, !eofExit)
|
||||
w.Subs = subWords
|
||||
return w
|
||||
}
|
||||
|
||||
func (c *parseContext) parseStrANSI() *WordType {
|
||||
if !c.match2('$', '\'') {
|
||||
return nil
|
||||
}
|
||||
newOffset, complete := c.skipToChar(2, '\'', true)
|
||||
w := c.makeWord(WordTypeDSQ, newOffset, complete)
|
||||
return w
|
||||
}
|
||||
|
||||
func (c *parseContext) parseArith(mustComplete bool) *WordType {
|
||||
if !c.match2('(', '(') {
|
||||
return nil
|
||||
}
|
||||
newOffset, complete := c.skipToChar2(2, ')', ')', false)
|
||||
if mustComplete && !complete {
|
||||
return nil
|
||||
}
|
||||
w := c.makeWord(WordTypePP, newOffset, complete)
|
||||
return w
|
||||
}
|
||||
|
||||
func (c *parseContext) parseExpansion() *WordType {
|
||||
if !c.match('$') {
|
||||
return nil
|
||||
}
|
||||
if c.match3('$', '(', '(') {
|
||||
newOffset, complete := c.skipToChar2(3, ')', ')', false)
|
||||
w := c.makeWord(WordTypeDPP, newOffset, complete)
|
||||
return w
|
||||
}
|
||||
if c.match2('$', '(') {
|
||||
// subshell
|
||||
newContext := c.clone(c.Pos+2, WordTypeDP)
|
||||
subWords, eofExit := newContext.tokenizeRaw()
|
||||
newOffset := newContext.Pos + 2
|
||||
w := c.makeWord(WordTypeDP, newOffset, !eofExit)
|
||||
w.Subs = subWords
|
||||
return w
|
||||
}
|
||||
if c.match2('$', '[') {
|
||||
// deprecated arith expansion
|
||||
newOffset, complete := c.skipToChar(2, ']', false)
|
||||
w := c.makeWord(WordTypeDB, newOffset, complete)
|
||||
return w
|
||||
}
|
||||
if c.match2('$', '{') {
|
||||
// variable expansion
|
||||
newContext := c.clone(c.Pos+2, WordTypeVarBrace)
|
||||
_, eofExit := newContext.tokenizeVarBrace()
|
||||
newOffset := newContext.Pos + 2
|
||||
w := c.makeWord(WordTypeVarBrace, newOffset, !eofExit)
|
||||
return w
|
||||
}
|
||||
ch2 := c.at(1)
|
||||
if ch2 == 0 || unicode.IsSpace(ch2) {
|
||||
// no expansion
|
||||
return nil
|
||||
}
|
||||
newOffset := c.parseSimpleVarName(1)
|
||||
if newOffset > 1 {
|
||||
// simple variable name
|
||||
w := c.makeWord(WordTypeSimpleVar, newOffset, true)
|
||||
return w
|
||||
}
|
||||
if ch2 == '*' || ch2 == '@' || ch2 == '#' || ch2 == '?' || ch2 == '-' || ch2 == '$' || ch2 == '!' || (ch2 >= '0' && ch2 <= '9') {
|
||||
// single character variable name, e.g. $@, $_, $1, etc.
|
||||
w := c.makeWord(WordTypeSimpleVar, 2, true)
|
||||
return w
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns newOffset
|
||||
func (c *parseContext) parseSimpleVarName(offset int) int {
|
||||
first := true
|
||||
for {
|
||||
ch := c.at(offset)
|
||||
if ch == 0 {
|
||||
return offset
|
||||
}
|
||||
if (ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) || (!first && ch >= '0' && ch <= '9') {
|
||||
first = false
|
||||
offset++
|
||||
continue
|
||||
}
|
||||
return offset
|
||||
}
|
||||
}
|
||||
|
||||
func isSimpleVarName(rstr []rune) bool {
|
||||
if len(rstr) == 0 {
|
||||
return false
|
||||
}
|
||||
for idx, ch := range rstr {
|
||||
if (ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) || ((idx != 0) && ch >= '0' && ch <= '9') {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Tokenize(cmd string) []*WordType {
|
||||
c := &parseContext{Input: []rune(cmd)}
|
||||
rtn, _ := c.tokenizeRaw()
|
||||
return rtn
|
||||
}
|
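A minimal usage sketch for the tokenizer above; the import path is assumed from the sibling packages in this commit, and the input line is illustrative:

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/shparse"
)

func main() {
	// top-level words come back in order; quoted words and expansions carry Subs
	words := shparse.Tokenize(`ls "foo bar" $(pwd); echo $x`)
	for _, w := range words {
		fmt.Printf("%-12s %q complete=%v subs=%d\n", w.Type, string(w.Raw), w.Complete, len(w.Subs))
	}
}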
2654
wavesrv/pkg/sstore/dbops.go
Normal file
File diff suppressed because it is too large
184
wavesrv/pkg/sstore/fileops.go
Normal file
@ -0,0 +1,184 @@
|
||||
package sstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/cirfile"
|
||||
"github.com/commandlinedev/prompt-server/pkg/scbase"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateCmdPtyFile(ctx context.Context, screenId string, lineId string, maxSize int64) error {
|
||||
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := cirfile.CreateCirFile(ptyOutFileName, maxSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
func StatCmdPtyFile(ctx context.Context, screenId string, lineId string) (*cirfile.Stat, error) {
|
||||
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cirfile.StatCirFile(ctx, ptyOutFileName)
|
||||
}
|
||||
|
||||
func AppendToCmdPtyBlob(ctx context.Context, screenId string, lineId string, data []byte, pos int64) (*PtyDataUpdate, error) {
|
||||
if screenId == "" {
|
||||
return nil, fmt.Errorf("cannot append to PtyBlob, screenid is not set")
|
||||
}
|
||||
if pos < 0 {
|
||||
return nil, fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos)
|
||||
}
|
||||
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := cirfile.OpenCirFile(ptyOutFileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
err = f.WriteAt(ctx, data, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data64 := base64.StdEncoding.EncodeToString(data)
|
||||
update := &PtyDataUpdate{
|
||||
ScreenId: screenId,
|
||||
LineId: lineId,
|
||||
PtyPos: pos,
|
||||
PtyData64: data64,
|
||||
PtyDataLen: int64(len(data)),
|
||||
}
|
||||
err = MaybeInsertPtyPosUpdate(ctx, screenId, lineId)
|
||||
if err != nil {
|
||||
// just log
|
||||
log.Printf("error inserting ptypos update %s/%s: %v\n", screenId, lineId, err)
|
||||
}
|
||||
return update, nil
|
||||
}
|
||||
|
||||
// returns (real-offset, data, err)
|
||||
func ReadFullPtyOutFile(ctx context.Context, screenId string, lineId string) (int64, []byte, error) {
|
||||
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
f, err := cirfile.OpenCirFile(ptyOutFileName)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
return f.ReadAll(ctx)
|
||||
}
|
||||
|
||||
// returns (real-offset, data, err)
|
||||
func ReadPtyOutFile(ctx context.Context, screenId string, lineId string, offset int64, maxSize int64) (int64, []byte, error) {
|
||||
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
f, err := cirfile.OpenCirFile(ptyOutFileName)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
return f.ReadAtWithMax(ctx, offset, maxSize)
|
||||
}
|
||||
|
||||
type SessionDiskSizeType struct {
|
||||
NumFiles int
|
||||
TotalSize int64
|
||||
ErrorCount int
|
||||
Location string
|
||||
}
|
||||
|
||||
func directorySize(dirName string) (SessionDiskSizeType, error) {
|
||||
var rtn SessionDiskSizeType
|
||||
rtn.Location = dirName
|
||||
entries, err := os.ReadDir(dirName)
|
||||
if err != nil {
|
||||
return rtn, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
rtn.ErrorCount++
|
||||
continue
|
||||
}
|
||||
finfo, err := entry.Info()
|
||||
if err != nil {
|
||||
rtn.ErrorCount++
|
||||
continue
|
||||
}
|
||||
rtn.NumFiles++
|
||||
rtn.TotalSize += finfo.Size()
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func SessionDiskSize(sessionId string) (SessionDiskSizeType, error) {
|
||||
sessionDir, err := scbase.EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return SessionDiskSizeType{}, err
|
||||
}
|
||||
return directorySize(sessionDir)
|
||||
}
|
||||
|
||||
func FullSessionDiskSize() (map[string]SessionDiskSizeType, error) {
|
||||
sdir := scbase.GetSessionsDir()
|
||||
entries, err := os.ReadDir(sdir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn := make(map[string]SessionDiskSizeType)
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
_, err = uuid.Parse(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
diskSize, err := directorySize(path.Join(sdir, name))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
rtn[name] = diskSize
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func DeletePtyOutFile(ctx context.Context, screenId string, lineId string) error {
|
||||
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = os.Remove(ptyOutFileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteScreenDir(ctx context.Context, screenId string) error {
|
||||
screenDir, err := scbase.EnsureScreenDir(screenId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting screendir: %w", err)
|
||||
}
|
||||
log.Printf("remove-all %s\n", screenDir)
|
||||
return os.RemoveAll(screenDir)
|
||||
}
|
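A hedged sketch of the pty-file lifecycle defined above; the IDs are placeholders, the 1MB cap is arbitrary, error handling is elided, and it assumes the prompt data directories already exist:

// sketch only (same package assumed); screenId/lineId are placeholder UUID strings
func sketchPtyLifecycle(ctx context.Context, screenId string, lineId string) {
	_ = CreateCmdPtyFile(ctx, screenId, lineId, 1024*1024) // capped circular file
	update, _ := AppendToCmdPtyBlob(ctx, screenId, lineId, []byte("hello\r\n"), 0)
	_ = update // carries base64 pty data + position, suitable for pushing to clients
	_, data, _ := ReadFullPtyOutFile(ctx, screenId, lineId)
	log.Printf("read %d pty bytes back\n", len(data))
}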
33
wavesrv/pkg/sstore/map.go
Normal file
@ -0,0 +1,33 @@
package sstore

import (
	"context"
)

// WithTxRtn runs fn inside a transaction (via WithTx) and returns its single result.
func WithTxRtn[RT any](ctx context.Context, fn func(tx *TxWrap) (RT, error)) (RT, error) {
	var rtn RT
	txErr := WithTx(ctx, func(tx *TxWrap) error {
		temp, err := fn(tx)
		if err != nil {
			return err
		}
		rtn = temp
		return nil
	})
	return rtn, txErr
}

// WithTxRtn3 is the two-result variant of WithTxRtn.
func WithTxRtn3[RT1 any, RT2 any](ctx context.Context, fn func(tx *TxWrap) (RT1, RT2, error)) (RT1, RT2, error) {
	var rtn1 RT1
	var rtn2 RT2
	txErr := WithTx(ctx, func(tx *TxWrap) error {
		temp1, temp2, err := fn(tx)
		if err != nil {
			return err
		}
		rtn1 = temp1
		rtn2 = temp2
		return nil
	})
	return rtn1, rtn2, txErr
}
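A hedged sketch of how the generic transaction helper above tends to be used; the query and the LineType destination are illustrative, mirroring the tx.Select calls seen elsewhere in this package:

// sketch only (same package assumed); the query text is illustrative
func sketchGetLines(ctx context.Context) ([]*LineType, error) {
	return WithTxRtn(ctx, func(tx *TxWrap) ([]*LineType, error) {
		var lines []*LineType
		tx.Select(&lines, `SELECT * FROM line`)
		return lines, nil
	})
}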
213
wavesrv/pkg/sstore/migrate.go
Normal file
@ -0,0 +1,213 @@
|
||||
package sstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
sh2db "github.com/commandlinedev/prompt-server/db"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
)
|
||||
|
||||
const MaxMigration = 22
|
||||
const MigratePrimaryScreenVersion = 9
|
||||
const CmdScreenSpecialMigration = 13
|
||||
const CmdLineSpecialMigration = 20
|
||||
|
||||
func MakeMigrate() (*migrate.Migrate, error) {
|
||||
fsVar, err := iofs.New(sh2db.MigrationFS, "migrations")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening iofs: %w", err)
|
||||
}
|
||||
// migrationPathUrl := fmt.Sprintf("file://%s", path.Join(wd, "db", "migrations"))
|
||||
dbUrl := fmt.Sprintf("sqlite3://%s", GetDBName())
|
||||
m, err := migrate.NewWithSourceInstance("iofs", fsVar, dbUrl)
|
||||
// m, err := migrate.New(migrationPathUrl, dbUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("making migration db[%s]: %w", GetDBName(), err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func copyFile(srcFile string, dstFile string) error {
|
||||
if srcFile == dstFile {
|
||||
return fmt.Errorf("cannot copy %s to itself", srcFile)
|
||||
}
|
||||
srcFd, err := os.Open(srcFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open %s: %v", err)
|
||||
}
|
||||
defer srcFd.Close()
|
||||
dstFd, err := os.OpenFile(dstFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open destination file %s: %v", err)
|
||||
}
|
||||
_, err = io.Copy(dstFd, srcFd)
|
||||
if err != nil {
|
||||
dstFd.Close()
|
||||
return fmt.Errorf("error copying file: %v", err)
|
||||
}
|
||||
return dstFd.Close()
|
||||
}
|
||||
|
||||
func MigrateUpStep(m *migrate.Migrate, newVersion uint) error {
|
||||
startTime := time.Now()
|
||||
err := m.Migrate(newVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newVersion == CmdScreenSpecialMigration {
|
||||
mErr := RunMigration13()
|
||||
if mErr != nil {
|
||||
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
|
||||
}
|
||||
}
|
||||
if newVersion == CmdLineSpecialMigration {
|
||||
mErr := RunMigration20()
|
||||
if mErr != nil {
|
||||
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
|
||||
}
|
||||
}
|
||||
log.Printf("[db] migration v%d, elapsed %v\n", newVersion, time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateUp(targetVersion uint) error {
|
||||
m, err := MakeMigrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
curVersion, dirty, err := MigrateVersion(m)
|
||||
if dirty {
|
||||
return fmt.Errorf("cannot migrate up, database is dirty")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get current migration version: %v", err)
|
||||
}
|
||||
if curVersion >= targetVersion {
|
||||
return nil
|
||||
}
|
||||
log.Printf("[db] migrating from %d to %d\n", curVersion, targetVersion)
|
||||
log.Printf("[db] backing up database %s to %s\n", DBFileName, DBFileNameBackup)
|
||||
err = copyFile(GetDBName(), GetDBBackupName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating database backup: %v", err)
|
||||
}
|
||||
for newVersion := curVersion + 1; newVersion <= targetVersion; newVersion++ {
|
||||
err = MigrateUpStep(m, newVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("during migration v%d: %w", newVersion, err)
|
||||
}
|
||||
}
|
||||
log.Printf("[db] migration done, new version = %d\n", targetVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns curVersion, dirty, error
|
||||
func MigrateVersion(m *migrate.Migrate) (uint, bool, error) {
|
||||
if m == nil {
|
||||
var err error
|
||||
m, err = MakeMigrate()
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
}
|
||||
curVersion, dirty, err := m.Version()
|
||||
if err == migrate.ErrNilVersion {
|
||||
return 0, false, nil
|
||||
}
|
||||
return curVersion, dirty, err
|
||||
}
|
||||
|
||||
func MigrateDown() error {
|
||||
m, err := MakeMigrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = m.Down()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateGoto(n uint) error {
|
||||
curVersion, _, _ := MigrateVersion(nil)
|
||||
if curVersion == n {
|
||||
return nil
|
||||
}
|
||||
if curVersion < n {
|
||||
return MigrateUp(n)
|
||||
}
|
||||
m, err := MakeMigrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = m.Migrate(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TryMigrateUp() error {
|
||||
curVersion, _, _ := MigrateVersion(nil)
|
||||
log.Printf("[db] db version = %d\n", curVersion)
|
||||
if curVersion >= MaxMigration {
|
||||
return nil
|
||||
}
|
||||
err := MigrateUp(MaxMigration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MigratePrintVersion()
|
||||
}
|
||||
|
||||
func MigratePrintVersion() error {
|
||||
version, dirty, err := MigrateVersion(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting db version: %v", err)
|
||||
}
|
||||
if dirty {
|
||||
return fmt.Errorf("error db is dirty, version=%d", version)
|
||||
}
|
||||
log.Printf("[db] version=%d\n", version)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateCommandOpts(opts []string) error {
|
||||
var err error
|
||||
if opts[0] == "--migrate-up" {
|
||||
fmt.Printf("migrate-up %v\n", GetDBName())
|
||||
time.Sleep(3 * time.Second)
|
||||
err = MigrateUp(MaxMigration)
|
||||
} else if opts[0] == "--migrate-down" {
|
||||
fmt.Printf("migrate-down %v\n", GetDBName())
|
||||
time.Sleep(3 * time.Second)
|
||||
err = MigrateDown()
|
||||
} else if opts[0] == "--migrate-goto" {
|
||||
var n int
n, err = strconv.Atoi(opts[1])
if err == nil {
|
||||
fmt.Printf("migrate-goto %v => %d\n", GetDBName(), n)
|
||||
time.Sleep(3 * time.Second)
|
||||
err = MigrateGoto(uint(n))
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("invalid migration command")
|
||||
}
|
||||
if err != nil && err.Error() == migrate.ErrNoChange.Error() {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MigratePrintVersion()
|
||||
}
|
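A hedged sketch of how the two entry points above would typically be wired in at startup and from a maintenance flag; the call sites are illustrative (not taken from main-server.go) and assume log, os, and strings imports plus the sstore package:

// sketch only
func runMigrations() {
	if len(os.Args) > 1 && strings.HasPrefix(os.Args[1], "--migrate-") {
		// maintenance path: --migrate-up / --migrate-down / --migrate-goto N
		if err := sstore.MigrateCommandOpts(os.Args[1:]); err != nil {
			log.Fatalf("migration command failed: %v", err)
		}
		return
	}
	// normal startup: bring the DB up to MaxMigration if needed
	if err := sstore.TryMigrateUp(); err != nil {
		log.Fatalf("db migration failed: %v", err)
	}
}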
19
wavesrv/pkg/sstore/quick.go
Normal file
@ -0,0 +1,19 @@
package sstore

import (
	"github.com/commandlinedev/prompt-server/pkg/dbutil"
)

var quickSetStr = dbutil.QuickSetStr
var quickSetInt64 = dbutil.QuickSetInt64
var quickSetInt = dbutil.QuickSetInt
var quickSetBool = dbutil.QuickSetBool
var quickSetBytes = dbutil.QuickSetBytes
var quickSetJson = dbutil.QuickSetJson
var quickSetNullableJson = dbutil.QuickSetNullableJson
var quickSetJsonArr = dbutil.QuickSetJsonArr
var quickNullableJson = dbutil.QuickNullableJson
var quickJson = dbutil.QuickJson
var quickJsonArr = dbutil.QuickJsonArr
var quickScanJson = dbutil.QuickScanJson
var quickValueJson = dbutil.QuickValueJson
1295
wavesrv/pkg/sstore/sstore.go
Normal file
File diff suppressed because it is too large
154
wavesrv/pkg/sstore/sstore_migrate.go
Normal file
@ -0,0 +1,154 @@
|
||||
package sstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/commandlinedev/prompt-server/pkg/scbase"
|
||||
)
|
||||
|
||||
const MigrationChunkSize = 10
|
||||
|
||||
type cmdMigration13Type struct {
|
||||
SessionId string
|
||||
ScreenId string
|
||||
CmdId string
|
||||
}
|
||||
|
||||
type cmdMigration20Type struct {
|
||||
ScreenId string
|
||||
LineId string
|
||||
CmdId string
|
||||
}
|
||||
|
||||
func getSliceChunk[T any](slice []T, chunkSize int) ([]T, []T) {
|
||||
if chunkSize >= len(slice) {
|
||||
return slice, nil
|
||||
}
|
||||
return slice[0:chunkSize], slice[chunkSize:]
|
||||
}
|
||||
|
||||
func RunMigration20() error {
|
||||
ctx := context.Background()
|
||||
startTime := time.Now()
|
||||
var migrations []cmdMigration20Type
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
tx.Select(&migrations, `SELECT * FROM cmd_migrate20`)
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("trying to get cmd20 migrations: %w", txErr)
|
||||
}
|
||||
log.Printf("[db] got %d cmd-line migrations\n", len(migrations))
|
||||
for len(migrations) > 0 {
|
||||
var mchunk []cmdMigration20Type
|
||||
mchunk, migrations = getSliceChunk(migrations, MigrationChunkSize)
|
||||
err := processMigration20Chunk(ctx, mchunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cmd migration failed on chunk: %w", err)
|
||||
}
|
||||
}
|
||||
log.Printf("[db] cmd line migration done: %v\n", time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func processMigration20Chunk(ctx context.Context, mchunk []cmdMigration20Type) error {
|
||||
for _, mig := range mchunk {
|
||||
newFile, err := scbase.PtyOutFile(mig.ScreenId, mig.LineId)
|
||||
if err != nil {
|
||||
log.Printf("ptyoutfile(lineid) error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
oldFile, err := scbase.PtyOutFile(mig.ScreenId, mig.CmdId)
|
||||
if err != nil {
|
||||
log.Printf("ptyoutfile(cmdid) error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
err = os.Rename(oldFile, newFile)
|
||||
if err != nil {
|
||||
log.Printf("error renaming %s => %s: %v\n", oldFile, newFile, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
for _, mig := range mchunk {
|
||||
query := `DELETE FROM cmd_migrate20 WHERE cmdid = ?`
|
||||
tx.Exec(query, mig.CmdId)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunMigration13() error {
|
||||
ctx := context.Background()
|
||||
startTime := time.Now()
|
||||
var migrations []cmdMigration13Type
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
tx.Select(&migrations, `SELECT * FROM cmd_migrate`)
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("trying to get cmd13 migrations: %w", txErr)
|
||||
}
|
||||
log.Printf("[db] got %d cmd-screen migrations\n", len(migrations))
|
||||
for len(migrations) > 0 {
|
||||
var mchunk []cmdMigration13Type
|
||||
mchunk, migrations = getSliceChunk(migrations, MigrationChunkSize)
|
||||
err := processMigration13Chunk(ctx, mchunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cmd migration failed on chunk: %w", err)
|
||||
}
|
||||
}
|
||||
err := os.RemoveAll(scbase.GetSessionsDir())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot remove old sessions dir %s: %w\n", scbase.GetSessionsDir(), err)
|
||||
}
|
||||
txErr = WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `UPDATE client SET cmdstoretype = 'screen'`
|
||||
tx.Exec(query)
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("cannot change client cmdstoretype: %w", err)
|
||||
}
|
||||
log.Printf("[db] cmd screen migration done: %v\n", time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func processMigration13Chunk(ctx context.Context, mchunk []cmdMigration13Type) error {
|
||||
for _, mig := range mchunk {
|
||||
newFile, err := scbase.PtyOutFile(mig.ScreenId, mig.CmdId)
|
||||
if err != nil {
|
||||
log.Printf("ptyoutfile error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
oldFile, err := scbase.PtyOutFile_Sessions(mig.SessionId, mig.CmdId)
|
||||
if err != nil {
|
||||
log.Printf("ptyoutfile_sessions error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
err = os.Rename(oldFile, newFile)
|
||||
if err != nil {
|
||||
log.Printf("error renaming %s => %s: %v\n", oldFile, newFile, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
for _, mig := range mchunk {
|
||||
query := `DELETE FROM cmd_migrate WHERE cmdid = ?`
|
||||
tx.Exec(query, mig.CmdId)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
return nil
|
||||
}
|
228
wavesrv/pkg/sstore/updatebus.go
Normal file
@ -0,0 +1,228 @@
|
||||
package sstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var MainBus *UpdateBus = MakeUpdateBus()
|
||||
|
||||
const PtyDataUpdateStr = "pty"
|
||||
const ModelUpdateStr = "model"
|
||||
const UpdateChSize = 100
|
||||
|
||||
type UpdatePacket interface {
|
||||
UpdateType() string
|
||||
Clean()
|
||||
}
|
||||
|
||||
type PtyDataUpdate struct {
|
||||
ScreenId string `json:"screenid,omitempty"`
|
||||
LineId string `json:"lineid,omitempty"`
|
||||
RemoteId string `json:"remoteid,omitempty"`
|
||||
PtyPos int64 `json:"ptypos"`
|
||||
PtyData64 string `json:"ptydata64"`
|
||||
PtyDataLen int64 `json:"ptydatalen"`
|
||||
}
|
||||
|
||||
func (*PtyDataUpdate) UpdateType() string {
|
||||
return PtyDataUpdateStr
|
||||
}
|
||||
|
||||
func (pdu *PtyDataUpdate) Clean() {}
|
||||
|
||||
type ModelUpdate struct {
|
||||
Sessions []*SessionType `json:"sessions,omitempty"`
|
||||
ActiveSessionId string `json:"activesessionid,omitempty"`
|
||||
Screens []*ScreenType `json:"screens,omitempty"`
|
||||
ScreenLines *ScreenLinesType `json:"screenlines,omitempty"`
|
||||
Line *LineType `json:"line,omitempty"`
|
||||
Lines []*LineType `json:"lines,omitempty"`
|
||||
Cmd *CmdType `json:"cmd,omitempty"`
|
||||
CmdLine *CmdLineType `json:"cmdline,omitempty"`
|
||||
Info *InfoMsgType `json:"info,omitempty"`
|
||||
ClearInfo bool `json:"clearinfo,omitempty"`
|
||||
Remotes []interface{} `json:"remotes,omitempty"` // []*remote.RemoteState
|
||||
History *HistoryInfoType `json:"history,omitempty"`
|
||||
Interactive bool `json:"interactive"`
|
||||
Connect bool `json:"connect,omitempty"`
|
||||
MainView string `json:"mainview,omitempty"`
|
||||
Bookmarks []*BookmarkType `json:"bookmarks,omitempty"`
|
||||
SelectedBookmark string `json:"selectedbookmark,omitempty"`
|
||||
HistoryViewData *HistoryViewData `json:"historyviewdata,omitempty"`
|
||||
ClientData *ClientData `json:"clientdata,omitempty"`
|
||||
RemoteView *RemoteViewType `json:"remoteview,omitempty"`
|
||||
}
|
||||
|
||||
func (*ModelUpdate) UpdateType() string {
|
||||
return ModelUpdateStr
|
||||
}
|
||||
|
||||
func (update *ModelUpdate) Clean() {
|
||||
if update == nil {
|
||||
return
|
||||
}
|
||||
update.ClientData = update.ClientData.Clean()
|
||||
}
|
||||
|
||||
type RemoteViewType struct {
|
||||
RemoteShowAll bool `json:"remoteshowall,omitempty"`
|
||||
PtyRemoteId string `json:"ptyremoteid,omitempty"`
|
||||
RemoteEdit *RemoteEditType `json:"remoteedit,omitempty"`
|
||||
}
|
||||
|
||||
func InfoMsgUpdate(infoMsgFmt string, args ...interface{}) *ModelUpdate {
|
||||
msg := fmt.Sprintf(infoMsgFmt, args...)
|
||||
return &ModelUpdate{
|
||||
Info: &InfoMsgType{InfoMsg: msg},
|
||||
}
|
||||
}
|
||||
|
||||
type HistoryViewData struct {
|
||||
Items []*HistoryItemType `json:"items"`
|
||||
Offset int `json:"offset"`
|
||||
RawOffset int `json:"rawoffset"`
|
||||
NextRawOffset int `json:"nextrawoffset"`
|
||||
HasMore bool `json:"hasmore"`
|
||||
Lines []*LineType `json:"lines"`
|
||||
Cmds []*CmdType `json:"cmds"`
|
||||
}
|
||||
|
||||
type RemoteEditType struct {
|
||||
RemoteEdit bool `json:"remoteedit"`
|
||||
RemoteId string `json:"remoteid,omitempty"`
|
||||
ErrorStr string `json:"errorstr,omitempty"`
|
||||
InfoStr string `json:"infostr,omitempty"`
|
||||
KeyStr string `json:"keystr,omitempty"`
|
||||
HasPassword bool `json:"haspassword,omitempty"`
|
||||
}
|
||||
|
||||
type InfoMsgType struct {
|
||||
InfoTitle string `json:"infotitle"`
|
||||
InfoError string `json:"infoerror,omitempty"`
|
||||
InfoMsg string `json:"infomsg,omitempty"`
|
||||
InfoMsgHtml bool `json:"infomsghtml,omitempty"`
|
||||
WebShareLink bool `json:"websharelink,omitempty"`
|
||||
InfoComps []string `json:"infocomps,omitempty"`
|
||||
InfoCompsMore bool `json:"infocompssmore,omitempty"`
|
||||
InfoLines []string `json:"infolines,omitempty"`
|
||||
TimeoutMs int64 `json:"timeoutms,omitempty"`
|
||||
}
|
||||
|
||||
type HistoryInfoType struct {
|
||||
HistoryType string `json:"historytype"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
ScreenId string `json:"screenid,omitempty"`
|
||||
Items []*HistoryItemType `json:"items"`
|
||||
Show bool `json:"show"`
|
||||
}
|
||||
|
||||
type CmdLineType struct {
|
||||
CmdLine string `json:"cmdline"`
|
||||
CursorPos int `json:"cursorpos"`
|
||||
}
|
||||
|
||||
type UpdateChannel struct {
|
||||
ScreenId string
|
||||
ClientId string
|
||||
Ch chan interface{}
|
||||
}
|
||||
|
||||
func (uch UpdateChannel) Match(screenId string) bool {
|
||||
if screenId == "" {
|
||||
return true
|
||||
}
|
||||
return screenId == uch.ScreenId
|
||||
}
|
||||
|
||||
type UpdateBus struct {
|
||||
Lock *sync.Mutex
|
||||
Channels map[string]UpdateChannel
|
||||
}
|
||||
|
||||
func MakeUpdateBus() *UpdateBus {
|
||||
return &UpdateBus{
|
||||
Lock: &sync.Mutex{},
|
||||
Channels: make(map[string]UpdateChannel),
|
||||
}
|
||||
}
|
||||
|
||||
// always returns a new channel
|
||||
func (bus *UpdateBus) RegisterChannel(clientId string, screenId string) chan interface{} {
|
||||
bus.Lock.Lock()
|
||||
defer bus.Lock.Unlock()
|
||||
uch, found := bus.Channels[clientId]
|
||||
if found {
|
||||
close(uch.Ch)
|
||||
uch.ScreenId = screenId
|
||||
uch.Ch = make(chan interface{}, UpdateChSize)
|
||||
} else {
|
||||
uch = UpdateChannel{
|
||||
ClientId: clientId,
|
||||
ScreenId: screenId,
|
||||
Ch: make(chan interface{}, UpdateChSize),
|
||||
}
|
||||
}
|
||||
bus.Channels[clientId] = uch
|
||||
return uch.Ch
|
||||
}
|
||||
|
||||
func (bus *UpdateBus) UnregisterChannel(clientId string) {
|
||||
bus.Lock.Lock()
|
||||
defer bus.Lock.Unlock()
|
||||
uch, found := bus.Channels[clientId]
|
||||
if found {
|
||||
close(uch.Ch)
|
||||
delete(bus.Channels, clientId)
|
||||
}
|
||||
}
|
||||
|
||||
func (bus *UpdateBus) SendUpdate(update UpdatePacket) {
|
||||
if update == nil {
|
||||
return
|
||||
}
|
||||
update.Clean()
|
||||
bus.Lock.Lock()
|
||||
defer bus.Lock.Unlock()
|
||||
for _, uch := range bus.Channels {
|
||||
select {
|
||||
case uch.Ch <- update:
|
||||
|
||||
default:
|
||||
log.Printf("[error] dropped update on updatebus uch clientid=%s\n", uch.ClientId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (bus *UpdateBus) SendScreenUpdate(screenId string, update UpdatePacket) {
|
||||
if update == nil {
|
||||
return
|
||||
}
|
||||
update.Clean()
|
||||
bus.Lock.Lock()
|
||||
defer bus.Lock.Unlock()
|
||||
for _, uch := range bus.Channels {
|
||||
if uch.Match(screenId) {
|
||||
select {
|
||||
case uch.Ch <- update:
|
||||
|
||||
default:
|
||||
log.Printf("[error] dropped update on updatebus uch clientid=%s\n", uch.ClientId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MakeSessionsUpdateForRemote(sessionId string, ri *RemoteInstance) []*SessionType {
|
||||
return []*SessionType{
|
||||
&SessionType{
|
||||
SessionId: sessionId,
|
||||
Remotes: []*RemoteInstance{ri},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type BookmarksViewType struct {
|
||||
Bookmarks []*BookmarkType `json:"bookmarks"`
|
||||
}
|
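A hedged sketch of the subscribe/publish flow on the update bus above; the client and screen IDs are placeholders and the receiving goroutine is illustrative:

// sketch only (same package assumed)
func sketchUpdateBus() {
	ch := MainBus.RegisterChannel("client-1", "screen-1") // placeholder ids
	defer MainBus.UnregisterChannel("client-1")

	go func() {
		for update := range ch {
			log.Printf("got update: %T\n", update) // e.g. *ModelUpdate or *PtyDataUpdate
		}
	}()

	// SendUpdate fans out to every registered channel;
	// SendScreenUpdate only reaches channels registered for that screen
	MainBus.SendUpdate(InfoMsgUpdate("hello from %s", "sketch"))
}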
132
wavesrv/pkg/utilfn/linediff.go
Normal file
@ -0,0 +1,132 @@
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const LineDiffVersion = 0
|
||||
|
||||
type LineDiffType struct {
|
||||
Lines []int
|
||||
NewData []string
|
||||
}
|
||||
|
||||
// simple encoding
|
||||
// a 0 means read a line from NewData
|
||||
// a non-zero number means read the 1-indexed line from OldData
|
||||
func applyDiff(oldData []string, diff LineDiffType) ([]string, error) {
|
||||
rtn := make([]string, 0, len(diff.Lines))
|
||||
newDataPos := 0
|
||||
for i := 0; i < len(diff.Lines); i++ {
|
||||
if diff.Lines[i] == 0 {
|
||||
if newDataPos >= len(diff.NewData) {
|
||||
return nil, fmt.Errorf("not enough newdata for diff")
|
||||
}
|
||||
rtn = append(rtn, diff.NewData[newDataPos])
|
||||
newDataPos++
|
||||
} else {
|
||||
idx := diff.Lines[i] - 1 // 1-indexed
|
||||
if idx < 0 || idx >= len(oldData) {
|
||||
return nil, fmt.Errorf("diff index out of bounds %d old-data-len:%d", idx, len(oldData))
|
||||
}
|
||||
rtn = append(rtn, oldData[idx])
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
|
||||
l := binary.PutUvarint(viBuf, uint64(ival))
|
||||
buf.Write(viBuf[0:l])
|
||||
}
|
||||
|
||||
// simple encoding
|
||||
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
|
||||
// [version] [len-varint] [varint]xlen... newdata (bytes)
|
||||
func encodeDiff(diff LineDiffType) []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, 0)
|
||||
putUVarint(&buf, viBuf, len(diff.Lines))
|
||||
for _, val := range diff.Lines {
|
||||
putUVarint(&buf, viBuf, val)
|
||||
}
|
||||
for _, str := range diff.NewData {
|
||||
buf.WriteString(str)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func decodeDiff(diffBytes []byte) (LineDiffType, error) {
|
||||
var rtn LineDiffType
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return rtn, fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version != LineDiffVersion {
|
||||
return rtn, fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
linesLen64, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return rtn, fmt.Errorf("invalid diff, cannot read lines length: %v", err)
|
||||
}
|
||||
linesLen := int(linesLen64)
|
||||
rtn.Lines = make([]int, linesLen)
|
||||
for idx := 0; idx < linesLen; idx++ {
|
||||
vi, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return rtn, fmt.Errorf("invalid diff, cannot read line %d: %v", idx, err)
|
||||
}
|
||||
rtn.Lines[idx] = int(vi)
|
||||
}
|
||||
restOfInput := string(r.Bytes())
|
||||
rtn.NewData = strings.Split(restOfInput, "\n")
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func makeDiff(oldData []string, newData []string) LineDiffType {
|
||||
var rtn LineDiffType
|
||||
oldDataMap := make(map[string]int) // 1-indexed
|
||||
for idx, str := range oldData {
|
||||
if _, found := oldDataMap[str]; found {
|
||||
continue
|
||||
}
|
||||
oldDataMap[str] = idx + 1
|
||||
}
|
||||
rtn.Lines = make([]int, len(newData))
|
||||
for idx, str := range newData {
|
||||
oldIdx, found := oldDataMap[str]
|
||||
if found {
|
||||
rtn.Lines[idx] = oldIdx
|
||||
} else {
|
||||
rtn.Lines[idx] = 0
|
||||
rtn.NewData = append(rtn.NewData, str)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func MakeDiff(str1 string, str2 string) []byte {
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str2Arr := strings.Split(str2, "\n")
|
||||
diff := makeDiff(str1Arr, str2Arr)
|
||||
return encodeDiff(diff)
|
||||
}
|
||||
|
||||
func ApplyDiff(str1 string, diffBytes []byte) (string, error) {
|
||||
diff, err := decodeDiff(diffBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str2Arr, err := applyDiff(str1Arr, diff)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Join(str2Arr, "\n"), nil
|
||||
}
|
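A minimal round-trip sketch for the line-diff encoding above; the import path is assumed from the sibling packages in this commit and the two texts are illustrative:

package main

import (
	"fmt"

	"github.com/commandlinedev/prompt-server/pkg/utilfn"
)

func main() {
	oldText := "hello\nline #2\nmore"
	newText := "line #2\nmore\ngrapes"

	// the diff stores one varint per output line: 0 = take the next NewData line,
	// n > 0 = copy line n (1-indexed) from the old text
	diff := utilfn.MakeDiff(oldText, newText)
	roundTrip, err := utilfn.ApplyDiff(oldText, diff)
	if err != nil {
		panic(err)
	}
	fmt.Println(roundTrip == newText) // true
}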
208
wavesrv/pkg/utilfn/utilfn.go
Normal file
@ -0,0 +1,208 @@
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}
|
||||
|
||||
func GetStrArr(v interface{}, field string) []string {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
fieldVal := m[field]
|
||||
if fieldVal == nil {
|
||||
return nil
|
||||
}
|
||||
iarr, ok := fieldVal.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var sarr []string
|
||||
for _, iv := range iarr {
|
||||
if sv, ok := iv.(string); ok {
|
||||
sarr = append(sarr, sv)
|
||||
}
|
||||
}
|
||||
return sarr
|
||||
}
|
||||
|
||||
func GetBool(v interface{}, field string) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
fieldVal := m[field]
|
||||
if fieldVal == nil {
|
||||
return false
|
||||
}
|
||||
bval, ok := fieldVal.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return bval
|
||||
}
|
||||
|
||||
var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`)

// minimum maxlen=6
func ShellQuote(val string, forceQuote bool, maxLen int) string {
	if maxLen < 6 {
		maxLen = 6
	}
	rtn := val
	if needsQuoteRe.MatchString(val) {
		rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'"
	}
	if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") {
		if len(rtn) > maxLen {
			return rtn[0:maxLen-4] + "..." + rtn[0:1]
		}
		return rtn
	}
	if forceQuote {
		if len(rtn) > maxLen-2 {
			return "\"" + rtn[0:maxLen-5] + "...\""
		}
		return "\"" + rtn + "\""
	} else {
		if len(rtn) > maxLen {
			return rtn[0:maxLen-3] + "..."
		}
		return rtn
	}
}

func EllipsisStr(s string, maxLen int) string {
	if maxLen < 4 {
		maxLen = 4
	}
	if len(s) > maxLen {
		return s[0:maxLen-3] + "..."
	}
	return s
}

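ShellQuote single-quotes any value containing characters outside the needsQuoteRe whitelist (escaping embedded single quotes with the usual '"'"' trick) and truncates long results to maxLen with an ellipsis; EllipsisStr is the plain truncation helper. An illustrative sketch of the expected behavior (the inputs are made up):

```go
package utilfn

import "fmt"

// Illustrative sketch only, not part of this commit.
func ExampleShellQuote() {
	fmt.Println(ShellQuote("hello world", false, 30)) // needs quoting -> single-quoted
	fmt.Println(ShellQuote("abc", true, 30))          // forceQuote -> double-quoted
	fmt.Println(EllipsisStr("abcdefghij", 8))         // plain truncation
	// Output:
	// 'hello world'
	// "abc"
	// abcde...
}
```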
func LongestPrefix(root string, strs []string) string {
	if len(strs) == 0 {
		return root
	}
	if len(strs) == 1 {
		comp := strs[0]
		if len(comp) >= len(root) && strings.HasPrefix(comp, root) {
			if strings.HasSuffix(comp, "/") {
				return strs[0]
			}
			return strs[0]
		}
	}
	lcp := strs[0]
	for i := 1; i < len(strs); i++ {
		s := strs[i]
		for j := 0; j < len(lcp); j++ {
			if j >= len(s) || lcp[j] != s[j] {
				lcp = lcp[0:j]
				break
			}
		}
	}
	if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) {
		return root
	}
	return lcp
}

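LongestPrefix computes the longest common prefix of the candidates, but never returns anything shorter than (or diverging from) the supplied root, which is the behavior wanted for tab completion where root is what the user has already typed. A short illustrative check (the candidate strings are made up):

```go
package utilfn

import "fmt"

// Illustrative sketch only, not part of this commit.
func ExampleLongestPrefix() {
	fmt.Println(LongestPrefix("do", []string{"docker", "docs/"}))            // extends the root
	fmt.Println(LongestPrefix("do", []string{"docker", "docs/", "dotfile"})) // no longer prefix -> root
	// Output:
	// doc
	// do
}
```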
func ContainsStr(strs []string, test string) bool {
	for _, s := range strs {
		if s == test {
			return true
		}
	}
	return false
}

func IsPrefix(strs []string, test string) bool {
	for _, s := range strs {
		if len(s) > len(test) && strings.HasPrefix(s, test) {
			return true
		}
	}
	return false
}

type StrWithPos struct {
	Str string
	Pos int // this is a 'rune' position (not a byte position)
}

func (sp StrWithPos) String() string {
	return strWithCursor(sp.Str, sp.Pos)
}

func ParseToSP(s string) StrWithPos {
	idx := strings.Index(s, "[*]")
	if idx == -1 {
		return StrWithPos{Str: s}
	}
	return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])}
}

func strWithCursor(str string, pos int) string {
	if pos < 0 {
		return "[*]_" + str
	}
	if pos >= len(str) {
		if pos > len(str) {
			return str + "_[*]"
		}
		return str + "[*]"
	}

	var rtn []rune
	for _, ch := range str {
		if len(rtn) == pos {
			rtn = append(rtn, '[', '*', ']')
		}
		rtn = append(rtn, ch)
	}
	return string(rtn)
}

func (sp StrWithPos) Prepend(str string) StrWithPos {
	return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos}
}

func (sp StrWithPos) Append(str string) StrWithPos {
	return StrWithPos{Str: sp.Str + str, Pos: sp.Pos}
}

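StrWithPos models a line of text plus a rune-indexed cursor; ParseToSP reads the cursor from a literal "[*]" marker and String() renders it back, which makes command-line editing state easy to write down in tests. An illustrative round trip (the command text is made up):

```go
package utilfn

import "fmt"

// Illustrative sketch only, not part of this commit.
func ExampleParseToSP() {
	sp := ParseToSP("git sta[*]tus") // "[*]" marks the cursor
	sp = sp.Prepend("sudo ")         // cursor shifts with the prepended prefix
	fmt.Println(sp.Str)
	fmt.Println(sp)
	// Output:
	// sudo git status
	// sudo git sta[*]tus
}
```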
// returns base64 hash of data
func Sha1Hash(data []byte) string {
	hvalRaw := sha1.Sum(data)
	hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
	return hval
}

// ChunkSlice splits s into consecutive chunks of at most chunkSize elements.
func ChunkSlice[T any](s []T, chunkSize int) [][]T {
	var rtn [][]T
	for len(s) > 0 {
		if len(s) <= chunkSize {
			rtn = append(rtn, s)
			break
		}
		rtn = append(rtn, s[:chunkSize])
		s = s[chunkSize:]
	}
	return rtn
}
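ChunkSlice walks the remaining input slice, emitting full chunks until fewer than chunkSize elements are left, which become the final chunk. A quick illustrative check (sample data is made up):

```go
package utilfn

import "fmt"

// Illustrative sketch only, not part of this commit.
func ExampleChunkSlice() {
	fmt.Println(ChunkSlice([]int{1, 2, 3, 4, 5}, 2))
	fmt.Println(len(ChunkSlice([]int{}, 2))) // empty input -> no chunks
	// Output:
	// [[1 2] [3 4] [5]]
	// 0
}
```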
48
wavesrv/pkg/utilfn/utilfn_test.go
Normal file
@@ -0,0 +1,48 @@
package utilfn

import (
	"fmt"
	"testing"
)

const Str1 = `
hello
line #2
more
stuff
apple
`

const Str2 = `
line #2
apple
grapes
banana
`

const Str3 = `
more
stuff
banana
coconut
`

func testDiff(t *testing.T, str1 string, str2 string) {
	diffBytes := MakeDiff(str1, str2)
	fmt.Printf("diff-len: %d\n", len(diffBytes))
	out, err := ApplyDiff(str1, diffBytes)
	if err != nil {
		t.Errorf("error in diff: %v", err)
		return
	}
	if out != str2 {
		t.Errorf("bad diff output")
	}
}

func TestDiff(t *testing.T) {
	testDiff(t, Str1, Str2)
	testDiff(t, Str2, Str3)
	testDiff(t, Str1, Str3)
	testDiff(t, Str3, Str1)
}
189
wavesrv/pkg/wsshell/wsshell.go
Normal file
@@ -0,0 +1,189 @@
package wsshell

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)

const readWaitTimeout = 15 * time.Second
const writeWaitTimeout = 10 * time.Second
const pingPeriodTickTime = 10 * time.Second
const initialPingTime = 1 * time.Second

var upgrader = websocket.Upgrader{
	ReadBufferSize:   4 * 1024,
	WriteBufferSize:  32 * 1024,
	HandshakeTimeout: 1 * time.Second,
	CheckOrigin:      func(r *http.Request) bool { return true },
}

type WSShell struct {
	Conn       *websocket.Conn
	RemoteAddr string
	ConnId     string
	Query      url.Values
	OpenTime   time.Time
	NumPings   int
	LastPing   time.Time
	LastRecv   time.Time
	Header     http.Header

	CloseChan chan bool
	WriteChan chan []byte
	ReadChan  chan []byte
}

func (ws *WSShell) NonBlockingWrite(data []byte) bool {
	select {
	case ws.WriteChan <- data:
		return true

	default:
		return false
	}
}

func (ws *WSShell) WritePing() error {
	now := time.Now()
	pingMessage := map[string]interface{}{"type": "ping", "stime": now.Unix()}
	jsonVal, _ := json.Marshal(pingMessage)
	_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
	err := ws.Conn.WriteMessage(websocket.TextMessage, jsonVal)
	ws.NumPings++
	ws.LastPing = now
	if err != nil {
		return err
	}
	return nil
}

func (ws *WSShell) WriteJson(val interface{}) error {
	if ws.IsClosed() {
		return fmt.Errorf("cannot write packet, empty or closed wsshell")
	}
	barr, err := json.Marshal(val)
	if err != nil {
		return err
	}
	ws.WriteChan <- barr
	return nil
}

// WritePump drains WriteChan onto the websocket and sends application-level
// pings (an initial ping after 1s, then every 10s).
func (ws *WSShell) WritePump() {
	ticker := time.NewTicker(initialPingTime)
	defer func() {
		ticker.Stop()
		ws.Conn.Close()
	}()
	initialPing := true
	for {
		select {
		case <-ticker.C:
			err := ws.WritePing()
			if err != nil {
				log.Printf("WritePump %s err: %v\n", ws.RemoteAddr, err)
				return
			}
			if initialPing {
				initialPing = false
				ticker.Reset(pingPeriodTickTime)
			}

		case msgBytes, ok := <-ws.WriteChan:
			if !ok {
				return
			}
			_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
			err := ws.Conn.WriteMessage(websocket.TextMessage, msgBytes)
			if err != nil {
				log.Printf("WritePump %s err: %v\n", ws.RemoteAddr, err)
				return
			}
		}
	}
}

// ReadPump reads JSON messages off the websocket, answers "ping" messages
// with a "pong", ignores "pong" messages, and forwards everything else to
// ReadChan. The read deadline is pushed forward on every message received.
func (ws *WSShell) ReadPump() {
	readWait := readWaitTimeout
	defer func() {
		ws.Conn.Close()
	}()
	ws.Conn.SetReadLimit(4096)
	ws.Conn.SetReadDeadline(time.Now().Add(readWait))
	for {
		_, message, err := ws.Conn.ReadMessage()
		if err != nil {
			log.Printf("ReadPump %s Err: %v\n", ws.RemoteAddr, err)
			break
		}
		jmsg := map[string]interface{}{}
		err = json.Unmarshal(message, &jmsg)
		if err != nil {
			log.Printf("Error unmarshalling json: %v\n", err)
			break
		}
		ws.Conn.SetReadDeadline(time.Now().Add(readWait))
		ws.LastRecv = time.Now()
		if str, ok := jmsg["type"].(string); ok && str == "pong" {
			// nothing
			continue
		}
		if str, ok := jmsg["type"].(string); ok && str == "ping" {
			now := time.Now()
			pongMessage := map[string]interface{}{"type": "pong", "stime": now.Unix()}
			jsonVal, _ := json.Marshal(pongMessage)
			ws.WriteChan <- jsonVal
			continue
		}
		ws.ReadChan <- message
	}
}

func (ws *WSShell) IsClosed() bool {
	select {
	case <-ws.CloseChan:
		return true

	default:
		return false
	}
}

// StartWS upgrades the HTTP request to a websocket, starts the read and write
// pumps, and closes CloseChan/ReadChan once both pumps have exited.
func StartWS(w http.ResponseWriter, r *http.Request) (*WSShell, error) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return nil, err
	}
	ws := WSShell{Conn: conn, ConnId: uuid.New().String(), OpenTime: time.Now()}
	ws.CloseChan = make(chan bool)
	ws.WriteChan = make(chan []byte, 10)
	ws.ReadChan = make(chan []byte, 10)
	ws.RemoteAddr = r.RemoteAddr
	ws.Query = r.URL.Query()
	ws.Header = r.Header
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		ws.WritePump()
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		ws.ReadPump()
	}()
	go func() {
		wg.Wait()
		close(ws.CloseChan)
		close(ws.ReadChan)
	}()
	return &ws, nil
}
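The WSShell API hides the gorilla/websocket connection behind channels: ReadChan delivers non-ping/pong messages, WriteJson and NonBlockingWrite feed WriteChan, and ReadChan/CloseChan are closed once both pumps exit. A hedged consumer sketch; the /ws route, echo behavior, listen address, and import path are illustrative assumptions, not the handler this commit actually wires up:

```go
package main

import (
	"log"
	"net/http"

	// assumed import path; adjust to the module's real path
	"github.com/commandlinedev/prompt-server/pkg/wsshell"
)

func main() {
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		shell, err := wsshell.StartWS(w, r)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Echo each incoming message back as JSON; the loop ends when the
		// pumps exit and ReadChan is closed.
		go func() {
			for msg := range shell.ReadChan {
				werr := shell.WriteJson(map[string]interface{}{"type": "echo", "data": string(msg)})
				if werr != nil {
					log.Printf("write failed: %v", werr)
					return
				}
			}
		}()
	})
	log.Fatal(http.ListenAndServe("127.0.0.1:9999", nil)) // illustrative address
}
```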
16
wavesrv/scripthaus.md
Normal file
@@ -0,0 +1,16 @@
# SH2 Server Commands

```bash
# @scripthaus command dump-schema-dev
sqlite3 /Users/mike/prompt-dev/prompt.db .schema > db/schema.sql
```

```bash
# @scripthaus command opendb-dev
sqlite3 /Users/mike/prompt-dev/prompt.db
```

```bash
# @scripthaus command build
go build -ldflags "-X main.BuildTime=$(date +'%Y%m%d%H%M')" -o bin/local-server ./cmd
```