Merging wave server code into mono-repo

This commit is contained in:
sawka 2023-10-16 13:22:23 -07:00
commit a4c0128c89
92 changed files with 22462 additions and 0 deletions

887
wavesrv/cmd/main-server.go Normal file
View File

@ -0,0 +1,887 @@
package main
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"mime/multipart"
"net/http"
"os"
"os/signal"
"path/filepath"
"regexp"
"runtime/debug"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/apishell/pkg/server"
"github.com/commandlinedev/prompt-server/pkg/cmdrunner"
"github.com/commandlinedev/prompt-server/pkg/pcloud"
"github.com/commandlinedev/prompt-server/pkg/remote"
"github.com/commandlinedev/prompt-server/pkg/rtnstate"
"github.com/commandlinedev/prompt-server/pkg/scbase"
"github.com/commandlinedev/prompt-server/pkg/scpacket"
"github.com/commandlinedev/prompt-server/pkg/scws"
"github.com/commandlinedev/prompt-server/pkg/sstore"
"github.com/commandlinedev/prompt-server/pkg/wsshell"
)
type WebFnType = func(http.ResponseWriter, *http.Request)
const HttpReadTimeout = 5 * time.Second
const HttpWriteTimeout = 21 * time.Second
const HttpMaxHeaderBytes = 60000
const HttpTimeoutDuration = 21 * time.Second
const MainServerAddr = "127.0.0.1:1619" // PromptServer, P=16, S=19, PS=1619
const WebSocketServerAddr = "127.0.0.1:1623" // PromptWebsock, P=16, W=23, PW=1623
const MainServerDevAddr = "127.0.0.1:8090"
const WebSocketServerDevAddr = "127.0.0.1:8091"
const WSStateReconnectTime = 30 * time.Second
const WSStatePacketChSize = 20
const InitialTelemetryWait = 30 * time.Second
const TelemetryTick = 30 * time.Minute
const TelemetryInterval = 8 * time.Hour
const MaxWriteFileMemSize = 20 * (1024 * 1024) // 20M
var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
var GlobalAuthKey string
var BuildTime = "0"
var shutdownOnce sync.Once
var ContentTypeHeaderValidRe = regexp.MustCompile(`^\w+/[\w.+-]+$`)
type ClientActiveState struct {
Fg bool `json:"fg"`
Active bool `json:"active"`
Open bool `json:"open"`
}
func setWSState(state *scws.WSState) {
GlobalLock.Lock()
defer GlobalLock.Unlock()
WSStateMap[state.ClientId] = state
}
func getWSState(clientId string) *scws.WSState {
GlobalLock.Lock()
defer GlobalLock.Unlock()
return WSStateMap[clientId]
}
func removeWSStateAfterTimeout(clientId string, connectTime time.Time, waitDuration time.Duration) {
go func() {
time.Sleep(waitDuration)
GlobalLock.Lock()
defer GlobalLock.Unlock()
state := WSStateMap[clientId]
if state == nil || state.ConnectTime != connectTime {
return
}
delete(WSStateMap, clientId)
state.UnWatchScreen()
}()
}
func HandleWs(w http.ResponseWriter, r *http.Request) {
shell, err := wsshell.StartWS(w, r)
if err != nil {
log.Printf("WebSocket Upgrade Failed %T: %v\n", w, err)
return
}
defer shell.Conn.Close()
clientId := r.URL.Query().Get("clientid")
if clientId == "" {
close(shell.WriteChan)
return
}
state := getWSState(clientId)
if state == nil {
state = scws.MakeWSState(clientId, GlobalAuthKey)
state.ReplaceShell(shell)
setWSState(state)
} else {
state.UpdateConnectTime()
state.ReplaceShell(shell)
}
stateConnectTime := state.GetConnectTime()
defer func() {
removeWSStateAfterTimeout(clientId, stateConnectTime, WSStateReconnectTime)
}()
log.Printf("WebSocket opened %s %s\n", state.ClientId, shell.RemoteAddr)
state.RunWSRead()
}
// todo: sync multiple writes to the same fifoName into a single go-routine and do liveness checking on fifo
// if this returns an error, likely the fifo is dead and the cmd should be marked as 'done'
func writeToFifo(fifoName string, data []byte) error {
rwfd, err := os.OpenFile(fifoName, os.O_RDWR, 0600)
if err != nil {
return err
}
defer rwfd.Close()
fifoWriter, err := os.OpenFile(fifoName, os.O_WRONLY, 0600) // blocking open (open won't block because of rwfd)
if err != nil {
return err
}
defer fifoWriter.Close()
// this *could* block if the fifo buffer is full
// unlikely because if the reader is dead, and len(data) < pipe size, then the buffer will be empty and will clear after rwfd is closed
_, err = fifoWriter.Write(data)
if err != nil {
return err
}
return nil
}
func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
cdata, err := sstore.EnsureClientData(r.Context())
if err != nil {
WriteJsonError(w, err)
return
}
cdata = cdata.Clean()
WriteJsonSuccess(w, cdata)
return
}
func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
var winSize sstore.ClientWinSizeType
err := decoder.Decode(&winSize)
if err != nil {
WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
return
}
err = sstore.SetWinSize(r.Context(), winSize)
if err != nil {
WriteJsonError(w, fmt.Errorf("error setting winsize: %w", err))
return
}
WriteJsonSuccess(w, true)
return
}
// params: fg, active, open
func HandleLogActiveState(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
var activeState ClientActiveState
err := decoder.Decode(&activeState)
if err != nil {
WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
return
}
activity := sstore.ActivityUpdate{}
if activeState.Fg {
activity.FgMinutes = 1
}
if activeState.Active {
activity.ActiveMinutes = 1
}
if activeState.Open {
activity.OpenMinutes = 1
}
activity.NumConns = remote.NumRemotes()
err = sstore.UpdateCurrentActivity(r.Context(), activity)
if err != nil {
WriteJsonError(w, fmt.Errorf("error updating activity: %w", err))
return
}
WriteJsonSuccess(w, true)
return
}
// params: screenid
func HandleGetScreenLines(w http.ResponseWriter, r *http.Request) {
qvals := r.URL.Query()
screenId := qvals.Get("screenid")
if _, err := uuid.Parse(screenId); err != nil {
WriteJsonError(w, fmt.Errorf("invalid screenid: %w", err))
return
}
screenLines, err := sstore.GetScreenLinesById(r.Context(), screenId)
if err != nil {
WriteJsonError(w, err)
return
}
WriteJsonSuccess(w, screenLines)
return
}
func HandleRtnState(w http.ResponseWriter, r *http.Request) {
defer func() {
r := recover()
if r == nil {
return
}
log.Printf("[error] in handlertnstate: %v\n", r)
debug.PrintStack()
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("panic: %v", r)))
return
}()
qvals := r.URL.Query()
screenId := qvals.Get("screenid")
lineId := qvals.Get("lineid")
if screenId == "" || lineId == "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("must specify screenid and lineid")))
return
}
if _, err := uuid.Parse(screenId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
return
}
if _, err := uuid.Parse(lineId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
return
}
data, err := rtnstate.GetRtnStateDiff(r.Context(), screenId, lineId)
if err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("cannot get rtnstate diff: %v", err)))
return
}
w.WriteHeader(http.StatusOK)
w.Write(data)
return
}
func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
qvals := r.URL.Query()
remoteId := qvals.Get("remoteid")
if remoteId == "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("must specify remoteid")))
return
}
if _, err := uuid.Parse(remoteId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid remoteid: %v", err)))
return
}
realOffset, data, err := remote.ReadRemotePty(r.Context(), remoteId)
if err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("error reading ptyout file: %v", err)))
return
}
w.Header().Set("X-PtyDataOffset", strconv.FormatInt(realOffset, 10))
w.WriteHeader(http.StatusOK)
w.Write(data)
return
}
func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
qvals := r.URL.Query()
screenId := qvals.Get("screenid")
lineId := qvals.Get("lineid")
if screenId == "" || lineId == "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("must specify screenid and lineid")))
return
}
if _, err := uuid.Parse(screenId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
return
}
if _, err := uuid.Parse(lineId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
return
}
realOffset, data, err := sstore.ReadFullPtyOutFile(r.Context(), screenId, lineId)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("error reading ptyout file: %v", err)))
return
}
w.Header().Set("X-PtyDataOffset", strconv.FormatInt(realOffset, 10))
w.WriteHeader(http.StatusOK)
w.Write(data)
}
type writeFileParamsType struct {
ScreenId string `json:"screenid"`
LineId string `json:"lineid"`
Path string `json:"path"`
UseTemp bool `json:"usetemp,omitempty"`
}
func parseWriteFileParams(r *http.Request) (*writeFileParamsType, multipart.File, error) {
err := r.ParseMultipartForm(MaxWriteFileMemSize)
if err != nil {
return nil, nil, fmt.Errorf("cannot parse multipart form data: %v", err)
}
form := r.MultipartForm
if len(form.Value["params"]) == 0 {
return nil, nil, fmt.Errorf("no params found")
}
paramsStr := form.Value["params"][0]
var params writeFileParamsType
err = json.Unmarshal([]byte(paramsStr), &params)
if err != nil {
return nil, nil, fmt.Errorf("bad params json: %v", err)
}
if len(form.File["data"]) == 0 {
return nil, nil, fmt.Errorf("no data found")
}
fileHeader := form.File["data"][0]
file, err := fileHeader.Open()
if err != nil {
return nil, nil, fmt.Errorf("error opening multipart data file: %v", err)
}
return &params, file, nil
}
func HandleWriteFile(w http.ResponseWriter, r *http.Request) {
defer func() {
r := recover()
if r == nil {
return
}
log.Printf("[error] in write-file: %v\n", r)
debug.PrintStack()
WriteJsonError(w, fmt.Errorf("panic: %v", r))
return
}()
w.Header().Set("Cache-Control", "no-cache")
params, mpFile, err := parseWriteFileParams(r)
if err != nil {
WriteJsonError(w, fmt.Errorf("error parsing multipart form params: %w", err))
return
}
if params.ScreenId == "" || params.LineId == "" || params.Path == "" {
WriteJsonError(w, fmt.Errorf("invalid params, must set screenid, lineid, and path"))
return
}
if _, err := uuid.Parse(params.ScreenId); err != nil {
WriteJsonError(w, fmt.Errorf("invalid screenid: %v", err))
return
}
if _, err := uuid.Parse(params.LineId); err != nil {
WriteJsonError(w, fmt.Errorf("invalid lineid: %v", err))
return
}
_, cmd, err := sstore.GetLineCmdByLineId(r.Context(), params.ScreenId, params.LineId)
if err != nil {
WriteJsonError(w, fmt.Errorf("cannot retrieve line/cmd: %v", err))
return
}
if cmd == nil {
WriteJsonError(w, fmt.Errorf("line not found"))
return
}
if cmd.Remote.RemoteId == "" {
WriteJsonError(w, fmt.Errorf("invalid line, no remote"))
return
}
msh := remote.GetRemoteById(cmd.Remote.RemoteId)
if msh == nil {
WriteJsonError(w, fmt.Errorf("invalid line, cannot resolve remote"))
return
}
rrState := msh.GetRemoteRuntimeState()
fullPath, err := rrState.ExpandHomeDir(params.Path)
if err != nil {
WriteJsonError(w, fmt.Errorf("error expanding homedir: %v", err))
return
}
cwd := cmd.FeState["cwd"]
writePk := packet.MakeWriteFilePacket()
writePk.ReqId = uuid.New().String()
writePk.UseTemp = params.UseTemp
if filepath.IsAbs(fullPath) {
writePk.Path = fullPath
} else {
writePk.Path = filepath.Join(cwd, fullPath)
}
iter, err := msh.PacketRpcIter(r.Context(), writePk)
if err != nil {
WriteJsonError(w, fmt.Errorf("error: %v", err))
return
}
// first packet should be WriteFileReady
readyIf, err := iter.Next(r.Context())
if err != nil {
WriteJsonError(w, fmt.Errorf("error while getting ready response: %w", err))
return
}
readyPk, ok := readyIf.(*packet.WriteFileReadyPacketType)
if !ok {
WriteJsonError(w, fmt.Errorf("bad ready packet received: %T", readyIf))
return
}
if readyPk.Error != "" {
WriteJsonError(w, fmt.Errorf("ready error: %s", readyPk.Error))
return
}
var buffer [server.MaxFileDataPacketSize]byte
bufSlice := buffer[:]
for {
dataPk := packet.MakeFileDataPacket(writePk.ReqId)
nr, err := io.ReadFull(mpFile, bufSlice)
if err == io.ErrUnexpectedEOF || err == io.EOF {
dataPk.Eof = true
} else if err != nil {
dataErr := fmt.Errorf("error reading file data: %v", err)
dataPk.Error = dataErr.Error()
msh.SendFileData(dataPk)
WriteJsonError(w, dataErr)
return
}
if nr > 0 {
dataPk.Data = make([]byte, nr)
copy(dataPk.Data, bufSlice[0:nr])
}
msh.SendFileData(dataPk)
if dataPk.Eof {
break
}
// slight throttle for sending packets
time.Sleep(10 * time.Millisecond)
}
doneIf, err := iter.Next(r.Context())
if err != nil {
WriteJsonError(w, fmt.Errorf("error while getting done response: %w", err))
return
}
donePk, ok := doneIf.(*packet.WriteFileDonePacketType)
if !ok {
WriteJsonError(w, fmt.Errorf("bad done packet received: %T", doneIf))
return
}
if donePk.Error != "" {
WriteJsonError(w, fmt.Errorf("dne error: %s", donePk.Error))
return
}
WriteJsonSuccess(w, nil)
return
}
func HandleReadFile(w http.ResponseWriter, r *http.Request) {
qvals := r.URL.Query()
screenId := qvals.Get("screenid")
lineId := qvals.Get("lineid")
path := qvals.Get("path") // validate path?
contentType := qvals.Get("mimetype")
if contentType == "" {
contentType = "application/octet-stream"
}
if screenId == "" || lineId == "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("must specify sessionid, screenid, and lineid")))
return
}
if path == "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("must specify path")))
return
}
if _, err := uuid.Parse(screenId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
return
}
if _, err := uuid.Parse(lineId); err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
return
}
if !ContentTypeHeaderValidRe.MatchString(contentType) {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid mimetype specified")))
return
}
_, cmd, err := sstore.GetLineCmdByLineId(r.Context(), screenId, lineId)
if err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
return
}
if cmd == nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid line, no cmd")))
return
}
if cmd.Remote.RemoteId == "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid line, no remote")))
return
}
msh := remote.GetRemoteById(cmd.Remote.RemoteId)
if msh == nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("invalid line, cannot resolve remote")))
return
}
rrState := msh.GetRemoteRuntimeState()
fullPath, err := rrState.ExpandHomeDir(path)
if err != nil {
WriteJsonError(w, fmt.Errorf("error expanding homedir: %v", err))
return
}
streamPk := packet.MakeStreamFilePacket()
streamPk.ReqId = uuid.New().String()
cwd := cmd.FeState["cwd"]
if filepath.IsAbs(fullPath) {
streamPk.Path = fullPath
} else {
streamPk.Path = filepath.Join(cwd, fullPath)
}
iter, err := msh.StreamFile(r.Context(), streamPk)
if err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("error trying to stream file: %v", err)))
return
}
defer iter.Close()
respIf, err := iter.Next(r.Context())
if err != nil {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("error getting streamfile response: %v", err)))
return
}
resp, ok := respIf.(*packet.StreamFileResponseType)
if !ok {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("bad response packet type: %T", respIf)))
return
}
if resp.Error != "" {
w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("error response: %s", resp.Error)))
return
}
infoJson, _ := json.Marshal(resp.Info)
w.Header().Set("X-FileInfo", base64.StdEncoding.EncodeToString(infoJson))
w.Header().Set("Content-Type", contentType)
w.WriteHeader(http.StatusOK)
for {
dataPkIf, err := iter.Next(r.Context())
if err != nil {
log.Printf("error in read-file while getting data: %v\n", err)
break
}
if dataPkIf == nil {
break
}
dataPk, ok := dataPkIf.(*packet.FileDataPacketType)
if !ok {
log.Printf("error in read-file, invalid data packet type: %T", dataPkIf)
break
}
if dataPk.Error != "" {
log.Printf("in read-file, data packet error: %s", dataPk.Error)
break
}
w.Write(dataPk.Data)
}
return
}
func WriteJsonError(w http.ResponseWriter, errVal error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
errMap := make(map[string]interface{})
errMap["error"] = errVal.Error()
barr, _ := json.Marshal(errMap)
w.Write(barr)
return
}
func WriteJsonSuccess(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
rtnMap := make(map[string]interface{})
rtnMap["success"] = true
if data != nil {
rtnMap["data"] = data
}
barr, err := json.Marshal(rtnMap)
if err != nil {
WriteJsonError(w, err)
return
}
w.WriteHeader(200)
w.Write(barr)
return
}
func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
defer func() {
r := recover()
if r == nil {
return
}
log.Printf("[error] in run-command: %v\n", r)
debug.PrintStack()
WriteJsonError(w, fmt.Errorf("panic: %v", r))
return
}()
w.Header().Set("Cache-Control", "no-cache")
decoder := json.NewDecoder(r.Body)
var commandPk scpacket.FeCommandPacketType
err := decoder.Decode(&commandPk)
if err != nil {
WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
return
}
update, err := cmdrunner.HandleCommand(r.Context(), &commandPk)
if err != nil {
WriteJsonError(w, err)
return
}
if update != nil {
update.Clean()
}
WriteJsonSuccess(w, update)
return
}
func AuthKeyWrap(fn WebFnType) WebFnType {
return func(w http.ResponseWriter, r *http.Request) {
reqAuthKey := r.Header.Get("X-AuthKey")
if reqAuthKey == "" {
w.WriteHeader(500)
w.Write([]byte("no x-authkey header"))
return
}
if reqAuthKey != GlobalAuthKey {
w.WriteHeader(500)
w.Write([]byte("x-authkey header is invalid"))
return
}
w.Header().Set("Cache-Control", "no-cache")
fn(w, r)
}
}
func runWebSocketServer() {
gr := mux.NewRouter()
gr.HandleFunc("/ws", HandleWs)
serverAddr := WebSocketServerAddr
if scbase.IsDevMode() {
serverAddr = WebSocketServerDevAddr
}
server := &http.Server{
Addr: serverAddr,
ReadTimeout: HttpReadTimeout,
WriteTimeout: HttpWriteTimeout,
MaxHeaderBytes: HttpMaxHeaderBytes,
Handler: gr,
}
server.SetKeepAlivesEnabled(false)
log.Printf("Running websocket server on %s\n", serverAddr)
err := server.ListenAndServe()
if err != nil {
log.Printf("[error] trying to run websocket server: %v\n", err)
}
}
func test() error {
return nil
}
func sendTelemetryWrapper() {
defer func() {
r := recover()
if r == nil {
return
}
log.Printf("[error] in sendTelemetryWrapper: %v\n", r)
debug.PrintStack()
return
}()
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
err := pcloud.SendTelemetry(ctx, false)
if err != nil {
log.Printf("[error] sending telemetry: %v\n", err)
}
}
func telemetryLoop() {
var lastSent time.Time
time.Sleep(InitialTelemetryWait)
for {
dur := time.Now().Sub(lastSent)
if lastSent.IsZero() || dur >= TelemetryInterval {
lastSent = time.Now()
sendTelemetryWrapper()
}
time.Sleep(TelemetryTick)
}
}
// watch stdin, kill server if stdin is closed
func stdinReadWatch() {
buf := make([]byte, 1024)
for {
_, err := os.Stdin.Read(buf)
if err != nil {
doShutdown(fmt.Sprintf("stdin closed/error (%v)", err))
break
}
}
}
// ignore SIGHUP
func installSignalHandlers() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGHUP)
go func() {
for sig := range sigCh {
doShutdown(fmt.Sprintf("got signal %v", sig))
break
}
}()
}
func doShutdown(reason string) {
shutdownOnce.Do(func() {
log.Printf("[prompt] local server %v, start shutdown\n", reason)
sendTelemetryWrapper()
log.Printf("[prompt] closing db connection\n")
sstore.CloseDB()
log.Printf("[prompt] *** shutting down local server\n")
time.Sleep(1 * time.Second)
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
time.Sleep(5 * time.Second)
syscall.Kill(syscall.Getpid(), syscall.SIGKILL)
})
}
func main() {
scbase.BuildTime = BuildTime
if len(os.Args) >= 2 && os.Args[1] == "--test" {
log.Printf("running test fn\n")
err := test()
if err != nil {
log.Printf("[error] %v\n", err)
}
return
}
scHomeDir := scbase.GetPromptHomeDir()
log.Printf("[prompt] *** starting local server\n")
log.Printf("[prompt] local server version %s+%s\n", scbase.PromptVersion, scbase.BuildTime)
log.Printf("[prompt] homedir = %q\n", scHomeDir)
scLock, err := scbase.AcquirePromptLock()
if err != nil || scLock == nil {
log.Printf("[error] cannot acquire prompt lock: %v\n", err)
return
}
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
err := sstore.MigrateCommandOpts(os.Args[1:])
if err != nil {
log.Printf("[error] migrate cmd: %v\n", err)
}
return
}
authKey, err := scbase.ReadPromptAuthKey()
if err != nil {
log.Printf("[error] %v\n", err)
return
}
GlobalAuthKey = authKey
err = sstore.TryMigrateUp()
if err != nil {
log.Printf("[error] migrate up: %v\n", err)
return
}
clientData, err := sstore.EnsureClientData(context.Background())
if err != nil {
log.Printf("[error] ensuring client data: %v\n", err)
return
}
log.Printf("userid = %s\n", clientData.UserId)
err = sstore.EnsureLocalRemote(context.Background())
if err != nil {
log.Printf("[error] ensuring local remote: %v\n", err)
return
}
_, err = sstore.EnsureDefaultSession(context.Background())
if err != nil {
log.Printf("[error] ensuring default session: %v\n", err)
return
}
err = remote.LoadRemotes(context.Background())
if err != nil {
log.Printf("[error] loading remotes: %v\n", err)
return
}
err = sstore.HangupAllRunningCmds(context.Background())
if err != nil {
log.Printf("[error] calling HUP on all running commands: %v\n", err)
}
err = sstore.ReInitFocus(context.Background())
if err != nil {
log.Printf("[error] resetting screen focus: %v\n", err)
}
log.Printf("PCLOUD_ENDPOINT=%s\n", pcloud.GetEndpoint())
err = sstore.UpdateCurrentActivity(context.Background(), sstore.ActivityUpdate{NumConns: remote.NumRemotes()}) // set at least one record into activity
if err != nil {
log.Printf("[error] updating activity: %v\n", err)
}
installSignalHandlers()
go telemetryLoop()
go stdinReadWatch()
go runWebSocketServer()
go func() {
time.Sleep(10 * time.Second)
pcloud.StartUpdateWriter()
}()
gr := mux.NewRouter()
gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut))
gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty))
gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState))
gr.HandleFunc("/api/get-screen-lines", AuthKeyWrap(HandleGetScreenLines))
gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST")
gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
gr.HandleFunc("/api/log-active-state", AuthKeyWrap(HandleLogActiveState))
gr.HandleFunc("/api/read-file", AuthKeyWrap(HandleReadFile))
gr.HandleFunc("/api/write-file", AuthKeyWrap(HandleWriteFile)).Methods("POST")
serverAddr := MainServerAddr
if scbase.IsDevMode() {
serverAddr = MainServerDevAddr
}
server := &http.Server{
Addr: serverAddr,
ReadTimeout: HttpReadTimeout,
WriteTimeout: HttpWriteTimeout,
MaxHeaderBytes: HttpMaxHeaderBytes,
Handler: http.TimeoutHandler(gr, HttpTimeoutDuration, "Timeout"),
}
server.SetKeepAlivesEnabled(false)
log.Printf("Running main server on %s\n", serverAddr)
err = server.ListenAndServe()
if err != nil {
log.Printf("ERROR: %v\n", err)
}
}

9
wavesrv/db/db.go Normal file
View File

@ -0,0 +1,9 @@
// provides the io/fs for DB migrations
package db
import "embed"
// since embeds must be relative to the package directory, this source file is required
//go:embed migrations/*.sql
var MigrationFS embed.FS

View File

@ -0,0 +1,13 @@
DROP TABLE client;
DROP TABLE session;
DROP TABLE window;
DROP TABLE screen;
DROP TABLE screen_window;
DROP TABLE remote_instance;
DROP TABLE line;
DROP TABLE remote;
DROP TABLE cmd;
DROP TABLE history;
DROP TABLE state_base;
DROP TABLE state_diff;

View File

@ -0,0 +1,167 @@
CREATE TABLE client (
clientid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
activesessionid varchar(36) NOT NULL,
userpublickeybytes blob NOT NULL,
userprivatekeybytes blob NOT NULL,
winsize json NOT NULL
);
CREATE TABLE session (
sessionid varchar(36) PRIMARY KEY,
name varchar(50) NOT NULL,
sessionidx int NOT NULL,
activescreenid varchar(36) NOT NULL,
notifynum int NOT NULL,
archived boolean NOT NULL,
archivedts bigint NOT NULL,
ownerid varchar(36) NOT NULL,
sharemode varchar(12) NOT NULL,
accesskey varchar(36) NOT NULL
);
CREATE TABLE window (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
curremoteownerid varchar(36) NOT NULL,
curremoteid varchar(36) NOT NULL,
curremotename varchar(50) NOT NULL,
nextlinenum int NOT NULL,
winopts json NOT NULL,
ownerid varchar(36) NOT NULL,
sharemode varchar(12) NOT NULL,
shareopts json NOT NULL,
PRIMARY KEY (sessionid, windowid)
);
CREATE TABLE screen (
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
name varchar(50) NOT NULL,
activewindowid varchar(36) NOT NULL,
screenidx int NOT NULL,
screenopts json NOT NULL,
ownerid varchar(36) NOT NULL,
sharemode varchar(12) NOT NULL,
incognito boolean NOT NULL,
archived boolean NOT NULL,
archivedts bigint NOT NULL,
PRIMARY KEY (sessionid, screenid)
);
CREATE TABLE screen_window (
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
name varchar(50) NOT NULL,
layout json NOT NULL,
selectedline int NOT NULL,
anchor json NOT NULL,
focustype varchar(12) NOT NULL,
PRIMARY KEY (sessionid, screenid, windowid)
);
CREATE TABLE remote_instance (
riid varchar(36) PRIMARY KEY,
name varchar(50) NOT NULL,
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL
);
CREATE TABLE state_base (
basehash varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
version varchar(200) NOT NULL,
data blob NOT NULL
);
CREATE TABLE state_diff (
diffhash varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
basehash varchar(36) NOT NULL,
diffhasharr json NOT NULL,
data blob NOT NULL
);
CREATE TABLE line (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
ts bigint NOT NULL,
linenum int NOT NULL,
linenumtemp boolean NOT NULL,
linetype varchar(10) NOT NULL,
linelocal boolean NOT NULL,
text text NOT NULL,
cmdid varchar(36) NOT NULL,
ephemeral boolean NOT NULL,
contentheight int NOT NULL,
star int NOT NULL,
archived boolean NOT NULL,
PRIMARY KEY (sessionid, windowid, lineid)
);
CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY,
physicalid varchar(36) NOT NULL,
remotetype varchar(10) NOT NULL,
remotealias varchar(50) NOT NULL,
remotecanonicalname varchar(200) NOT NULL,
remotesudo boolean NOT NULL,
remoteuser varchar(50) NOT NULL,
remotehost varchar(200) NOT NULL,
connectmode varchar(20) NOT NULL,
autoinstall boolean NOT NULL,
sshopts json NOT NULL,
remoteopts json NOT NULL,
lastconnectts bigint NOT NULL,
local boolean NOT NULL,
archived boolean NOT NULL,
remoteidx int NOT NULL
);
CREATE TABLE cmd (
sessionid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
cmdstr text NOT NULL,
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL,
termopts json NOT NULL,
origtermopts json NOT NULL,
status varchar(10) NOT NULL,
startpk json NOT NULL,
doneinfo json NOT NULL,
runout json NOT NULL,
rtnstate boolean NOT NULL,
rtnbasehash varchar(36) NOT NULL,
rtndiffhasharr json NOT NULL,
PRIMARY KEY (sessionid, cmdid)
);
CREATE TABLE history (
historyid varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
userid varchar(36) NOT NULL,
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
lineid int NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
haderror boolean NOT NULL,
cmdid varchar(36) NOT NULL,
cmdstr text NOT NULL,
ismetacmd boolean,
incognito boolean
);

View File

@ -0,0 +1,3 @@
DROP TABLE activity;
ALTER TABLE client DROP COLUMN clientopts;

View File

@ -0,0 +1,11 @@
CREATE TABLE activity (
day varchar(20) PRIMARY KEY,
uploaded boolean NOT NULL,
tdata json NOT NULL,
tzname varchar(50) NOT NULL,
tzoffset int NOT NULL,
clientversion varchar(20) NOT NULL,
clientarch varchar(20) NOT NULL
);
ALTER TABLE client ADD COLUMN clientopts json NOT NULL DEFAULT '';

View File

@ -0,0 +1,2 @@
ALTER TABLE line DROP COLUMN renderer;

View File

@ -0,0 +1,2 @@
ALTER TABLE line ADD COLUMN renderer varchar(50) NOT NULL DEFAULT '';

View File

@ -0,0 +1,6 @@
DROP TABLE bookmark;
DROP TABLE bookmark_order;
DROP TABLE bookmark_cmd;
ALTER TABLE line DROP COLUMN bookmarked;
ALTER TABLE line DROP COLUMN pinned;

View File

@ -0,0 +1,26 @@
CREATE TABLE bookmark (
bookmarkid varchar(36) PRIMARY KEY,
createdts bigint NOT NULL,
cmdstr text NOT NULL,
alias varchar(50) NOT NULL,
tags json NOT NULL,
description text NOT NULL
);
CREATE TABLE bookmark_order (
tag varchar(50) NOT NULL,
bookmarkid varchar(36) NOT NULL,
orderidx int NOT NULL,
PRIMARY KEY (tag, bookmarkid)
);
CREATE TABLE bookmark_cmd (
bookmarkid varchar(36) NOT NULL,
sessionid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (bookmarkid, sessionid, cmdid)
);
ALTER TABLE line ADD COLUMN bookmarked boolean NOT NULL DEFAULT 0;
ALTER TABLE line ADD COLUMN pinned boolean NOT NULL DEFAULT 0;

View File

@ -0,0 +1,2 @@
ALTER TABLE activity DROP COLUMN buildtime;
ALTER TABLE activity DROP COLUMN osrelease;

View File

@ -0,0 +1,2 @@
ALTER TABLE activity ADD COLUMN buildtime varchar(20) NOT NULL DEFAULT '-';
ALTER TABLE activity ADD COLUMN osrelease varchar(20) NOT NULL DEFAULT '-';

View File

@ -0,0 +1 @@
ALTER TABLE client DROP COLUMN feopts;

View File

@ -0,0 +1,3 @@
ALTER TABLE client ADD COLUMN feopts json NOT NULL DEFAULT '{}';

View File

@ -0,0 +1,3 @@
DROP TABLE playbook;
DROP TABLE playbook_entry;

View File

@ -0,0 +1,16 @@
CREATE TABLE playbook (
playbookid varchar(36) PRIMARY KEY,
playbookname varchar(100) NOT NULL,
description text NOT NULL,
entryids json NOT NULL
);
CREATE TABLE playbook_entry (
entryid varchar(36) PRIMARY KEY,
playbookid varchar(36) NOT NULL,
description text NOT NULL,
alias varchar(50) NOT NULL,
cmdstr text NOT NULL,
createdts bigint NOT NULL,
updatedts bigint NOT NULL
);

View File

@ -0,0 +1,5 @@
ALTER TABLE session ADD COLUMN accesskey DEFAULT '';
ALTER TABLE session ADD COLUMN ownerid DEFAULT '';
DROP TABLE cloud_session;
DROP TABLE cloud_update;

View File

@ -0,0 +1,20 @@
ALTER TABLE session DROP COLUMN accesskey;
ALTER TABLE session DROP COLUMN ownerid;
CREATE TABLE cloud_session (
sessionid varchar(36) PRIMARY KEY,
viewkey varchar(50) NOT NULL,
writekey varchar(50) NOT NULL,
enckey varchar(100) NOT NULL,
enctype varchar(50) NOT NULL,
vts bigint NOT NULL,
acl json NOT NULL
);
CREATE TABLE cloud_update (
updateid varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
updatetype varchar(50) NOT NULL,
updatekeys json NOT NULL
);

View File

@ -0,0 +1,3 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;

View File

@ -0,0 +1,56 @@
CREATE TABLE new_screen (
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
name varchar(50) NOT NULL,
screenidx int NOT NULL,
screenopts json NOT NULL,
ownerid varchar(36) NOT NULL,
sharemode varchar(12) NOT NULL,
curremoteownerid varchar(36) NOT NULL,
curremoteid varchar(36) NOT NULL,
curremotename varchar(50) NOT NULL,
nextlinenum int NOT NULL,
selectedline int NOT NULL,
anchor json NOT NULL,
focustype varchar(12) NOT NULL,
archived boolean NOT NULL,
archivedts bigint NOT NULL,
PRIMARY KEY (sessionid, screenid)
);
INSERT INTO new_screen
SELECT
s.sessionid,
s.screenid,
w.windowid,
s.name,
s.screenidx,
json_patch(s.screenopts, w.winopts),
s.ownerid,
s.sharemode,
w.curremoteownerid,
w.curremoteid,
w.curremotename,
w.nextlinenum,
sw.selectedline,
sw.anchor,
sw.focustype,
s.archived,
s.archivedts
FROM
screen s,
screen_window sw,
window w
WHERE
s.screenid = sw.screenid
AND sw.windowid = w.windowid
;
DROP TABLE screen;
DROP TABLE screen_window;
DROP TABLE window;
ALTER TABLE new_screen RENAME TO screen;

View File

@ -0,0 +1,2 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;

View File

@ -0,0 +1,17 @@
ALTER TABLE remote_instance RENAME COLUMN windowid TO screenid;
ALTER TABLE line RENAME COLUMN windowid TO screenid;
UPDATE remote_instance
SET screenid = COALESCE((SELECT screen.screenid FROM screen WHERE screen.windowid = remote_instance.screenid), '')
WHERE screenid <> ''
;
UPDATE line
SET screenid = COALESCE((SELECT screen.screenid FROM screen WHERE screen.windowid = line.screenid), '')
WHERE screenid <> ''
;
ALTER TABLE history DROP COLUMN windowid;
ALTER TABLE screen DROP COLUMN windowid;

View File

@ -0,0 +1,2 @@
ALTER TABLE cmd DROP COLUMN screenid;

View File

@ -0,0 +1,5 @@
ALTER TABLE cmd ADD COLUMN screenid varchar(36) NOT NULL DEFAULT '';
UPDATE cmd
SET screenid = (SELECT line.screenid FROM line WHERE line.cmdid = cmd.cmdid)
;

View File

@ -0,0 +1 @@
ALTER TABLE history DROP COLUMN linenum;

View File

@ -0,0 +1,6 @@
ALTER TABLE history ADD COLUMN linenum int NOT NULL DEFAULT 0;
UPDATE history
SET linenum = COALESCE((SELECT line.linenum FROM line WHERE line.lineid = history.lineid), 0)
;

View File

@ -0,0 +1,2 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;

View File

@ -0,0 +1,123 @@
DELETE FROM cmd
WHERE screenid = '';
DELETE FROM line
WHERE screenid = '';
DELETE FROM cmd
WHERE cmdid NOT IN (SELECT cmdid FROM line);
DELETE FROM line
WHERE cmdid <> '' AND cmdid NOT IN (SELECT cmdid FROM cmd);
CREATE TABLE new_bookmark_cmd (
bookmarkid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (bookmarkid, screenid, cmdid)
);
INSERT INTO new_bookmark_cmd
SELECT
b.bookmarkid,
c.screenid,
c.cmdid
FROM bookmark_cmd b, cmd c
WHERE b.cmdid = c.cmdid;
DROP TABLE bookmark_cmd;
ALTER TABLE new_bookmark_cmd RENAME TO bookmark_cmd;
ALTER TABLE client ADD COLUMN cmdstoretype varchar(20) DEFAULT 'session';
CREATE TABLE cmd_migrate (
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL
);
INSERT INTO cmd_migrate
SELECT sessionid, screenid, cmdid
FROM cmd;
-- update primary key for screen
CREATE TABLE new_screen (
screenid varchar(36) NOT NULL,
sessionid varchar(36) NOT NULL,
name varchar(50) NOT NULL,
screenidx int NOT NULL,
screenopts json NOT NULL,
ownerid varchar(36) NOT NULL,
sharemode varchar(12) NOT NULL,
curremoteownerid varchar(36) NOT NULL,
curremoteid varchar(36) NOT NULL,
curremotename varchar(50) NOT NULL,
nextlinenum int NOT NULL,
selectedline int NOT NULL,
anchor json NOT NULL,
focustype varchar(12) NOT NULL,
archived boolean NOT NULL,
archivedts bigint NOT NULL,
PRIMARY KEY (screenid)
);
INSERT INTO new_screen
SELECT screenid, sessionid, name, screenidx, screenopts, ownerid, sharemode,
curremoteownerid, curremoteid, curremotename, nextlinenum, selectedline,
anchor, focustype, archived, archivedts
FROM screen;
DROP TABLE screen;
ALTER TABLE new_screen RENAME TO screen;
-- drop sessionid from line
CREATE TABLE new_line (
screenid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
ts bigint NOT NULL,
linenum int NOT NULL,
linenumtemp boolean NOT NULL,
linetype varchar(10) NOT NULL,
linelocal boolean NOT NULL,
text text NOT NULL,
cmdid varchar(36) NOT NULL,
ephemeral boolean NOT NULL,
contentheight int NOT NULL,
star int NOT NULL,
archived boolean NOT NULL,
renderer varchar(50) NOT NULL,
bookmarked boolean NOT NULL,
PRIMARY KEY (screenid, lineid)
);
INSERT INTO new_line
SELECT screenid, userid, lineid, ts, linenum, linenumtemp, linetype, linelocal,
text, cmdid, ephemeral, contentheight, star, archived, renderer, bookmarked
FROM line;
DROP TABLE line;
ALTER TABLE new_line RENAME TO line;
-- drop sessionid from cmd
CREATE TABLE new_cmd (
screenid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
cmdstr text NOT NULL,
rawcmdstr text NOT NULL,
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL,
termopts json NOT NULL,
origtermopts json NOT NULL,
status varchar(10) NOT NULL,
startpk json NOT NULL,
doneinfo json NOT NULL,
runout json NOT NULL,
rtnstate boolean NOT NULL,
rtnbasehash varchar(36) NOT NULL,
rtndiffhasharr json NOT NULL,
PRIMARY KEY (screenid, cmdid)
);
INSERT INTO new_cmd
SELECT screenid, cmdid, remoteownerid, remoteid, remotename, cmdstr, cmdstr,
festate, statebasehash, statediffhasharr, termopts, origtermopts, status, startpk, doneinfo, runout, rtnstate, rtnbasehash, rtndiffhasharr
FROM cmd;
DROP TABLE cmd;
ALTER TABLE new_cmd RENAME TO cmd;

View File

@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS "bookmark_cmd" (
bookmarkid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (bookmarkid, screenid, cmdid)
);
ALTER TABLE line ADD COLUMN bookmarked boolean NOT NULL DEFAULT 0;

View File

@ -0,0 +1,3 @@
DROP TABLE bookmark_cmd;
ALTER TABLE line DROP COLUMN bookmarked;

View File

@ -0,0 +1,4 @@
DROP TABLE screenupdate;
ALTER TABLE screen DROP COLUMN webshareopts;

View File

@ -0,0 +1,10 @@
CREATE TABLE screenupdate (
updateid integer PRIMARY KEY,
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
updatetype varchar(50) NOT NULL,
updatets bigint NOT NULL
);
ALTER TABLE screen ADD COLUMN webshareopts json NOT NULL DEFAULT 'null';

View File

@ -0,0 +1,3 @@
DROP TABLE webptypos;
DROP INDEX idx_screenupdate_ids;

View File

@ -0,0 +1,8 @@
CREATE TABLE webptypos (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
ptypos bigint NOT NULL,
PRIMARY KEY (screenid, lineid)
);
CREATE INDEX idx_screenupdate_ids ON screenupdate (screenid, lineid);

View File

@ -0,0 +1,2 @@
ALTER TABLE remote DROP COLUMN statevars;

View File

@ -0,0 +1,2 @@
ALTER TABLE remote ADD COLUMN statevars json NOT NULL DEFAULT '{}';

View File

@ -0,0 +1,14 @@
ALTER TABLE remote ADD COLUMN remotesudo;
UPDATE remote
SET remotesudo = 1
WHERE json_extract(sshopts, '$.issudo')
;
UPDATE remote
SET sshopts = json_remove(sshopts, '$.issudo')
;
ALTER TABLE remote ADD COLUMN physicalid varchar(36) NOT NULL DEFAULT '';
ALTER TABLE remote DROP COLUMN openaiopts;

View File

@ -0,0 +1,11 @@
UPDATE remote
SET sshopts = json_set(sshopts, '$.issudo', json('true'))
WHERE remotesudo
;
ALTER TABLE remote DROP COLUMN remotesudo;
ALTER TABLE remote DROP COLUMN physicalid;
ALTER TABLE remote ADD COLUMN openaiopts json NOT NULL DEFAULT '{}';

View File

@ -0,0 +1 @@
ALTER TABLE client DROP COLUMN openaiopts;

View File

@ -0,0 +1,2 @@
ALTER TABLE client ADD COLUMN openaiopts json NOT NULL DEFAULT '{}';

View File

@ -0,0 +1,2 @@
-- invalid, will throw an error, cannot migrate down
SELECT x;

View File

@ -0,0 +1,74 @@
-- remove cmdid from line, history, and cmd (use lineid everywhere)
CREATE TABLE cmd_new (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
cmdstr text NOT NULL,
rawcmdstr text NOT NULL,
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL,
termopts json NOT NULL,
origtermopts json NOT NULL,
status varchar(10) NOT NULL,
cmdpid int NOT NULL,
remotepid int NOT NULL,
donets bigint NOT NULL,
exitcode int NOT NULL,
durationms int NOT NULL,
rtnstate boolean NOT NULL,
rtnbasehash varchar(36) NOT NULL,
rtndiffhasharr json NOT NULL,
runout json NOT NULL,
PRIMARY KEY (screenid, lineid)
);
CREATE TABLE cmd_migrate20 (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (screenid, lineid)
);
INSERT INTO cmd_migrate20
SELECT screenid, lineid, cmdid
FROM line
WHERE cmdid <> ''
;
INSERT INTO cmd_new
SELECT
c.screenid,
l.lineid,
c.remoteownerid,
c.remoteid,
c.remotename,
c.cmdstr,
c.rawcmdstr,
c.festate,
c.statebasehash,
c.statediffhasharr,
c.termopts,
c.origtermopts,
c.status,
coalesce(json_extract(startpk, '$.pid'), 0),
coalesce(json_extract(startpk, '$.mshellpid'), 0),
coalesce(json_extract(doneinfo, '$.ts'), 0),
coalesce(json_extract(doneinfo, '$.exitcode'), 0),
coalesce(json_extract(doneinfo, '$.durationms'), 0),
c.rtnstate,
c.rtnbasehash,
c.rtndiffhasharr,
c.runout
FROM cmd c
JOIN line l ON (l.cmdid = c.cmdid);
DROP TABLE cmd;
ALTER TABLE cmd_new RENAME TO cmd;
ALTER TABLE history DROP COLUMN cmdid;
ALTER TABLE line DROP COLUMN cmdid;

View File

@ -0,0 +1 @@
ALTER TABLE line DROP COLUMN linestate;

View File

@ -0,0 +1 @@
ALTER TABLE line ADD COLUMN linestate json NOT NULL DEFAULT '{}';

View File

@ -0,0 +1 @@
-- no down migration

View File

@ -0,0 +1 @@
UPDATE screen SET sharemode = 'local' AND webshareopts = 'null';

219
wavesrv/db/schema.sql Normal file
View File

@ -0,0 +1,219 @@
CREATE TABLE schema_migrations (version uint64,dirty bool);
CREATE UNIQUE INDEX version_unique ON schema_migrations (version);
CREATE TABLE client (
clientid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
activesessionid varchar(36) NOT NULL,
userpublickeybytes blob NOT NULL,
userprivatekeybytes blob NOT NULL,
winsize json NOT NULL
, clientopts json NOT NULL DEFAULT '', feopts json NOT NULL DEFAULT '{}', cmdstoretype varchar(20) DEFAULT 'session', openaiopts json NOT NULL DEFAULT '{}');
CREATE TABLE session (
sessionid varchar(36) PRIMARY KEY,
name varchar(50) NOT NULL,
sessionidx int NOT NULL,
activescreenid varchar(36) NOT NULL,
notifynum int NOT NULL,
archived boolean NOT NULL,
archivedts bigint NOT NULL,
sharemode varchar(12) NOT NULL);
CREATE TABLE remote_instance (
riid varchar(36) PRIMARY KEY,
name varchar(50) NOT NULL,
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL
);
CREATE TABLE state_base (
basehash varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
version varchar(200) NOT NULL,
data blob NOT NULL
);
CREATE TABLE state_diff (
diffhash varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
basehash varchar(36) NOT NULL,
diffhasharr json NOT NULL,
data blob NOT NULL
);
CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL,
remotealias varchar(50) NOT NULL,
remotecanonicalname varchar(200) NOT NULL,
remoteuser varchar(50) NOT NULL,
remotehost varchar(200) NOT NULL,
connectmode varchar(20) NOT NULL,
autoinstall boolean NOT NULL,
sshopts json NOT NULL,
remoteopts json NOT NULL,
lastconnectts bigint NOT NULL,
local boolean NOT NULL,
archived boolean NOT NULL,
remoteidx int NOT NULL
, statevars json NOT NULL DEFAULT '{}', openaiopts json NOT NULL DEFAULT '{}');
CREATE TABLE history (
historyid varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
userid varchar(36) NOT NULL,
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
lineid int NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
haderror boolean NOT NULL,
cmdstr text NOT NULL,
ismetacmd boolean,
incognito boolean
, linenum int NOT NULL DEFAULT 0);
CREATE TABLE activity (
day varchar(20) PRIMARY KEY,
uploaded boolean NOT NULL,
tdata json NOT NULL,
tzname varchar(50) NOT NULL,
tzoffset int NOT NULL,
clientversion varchar(20) NOT NULL,
clientarch varchar(20) NOT NULL
, buildtime varchar(20) NOT NULL DEFAULT '-', osrelease varchar(20) NOT NULL DEFAULT '-');
CREATE TABLE bookmark (
bookmarkid varchar(36) PRIMARY KEY,
createdts bigint NOT NULL,
cmdstr text NOT NULL,
alias varchar(50) NOT NULL,
tags json NOT NULL,
description text NOT NULL
);
CREATE TABLE bookmark_order (
tag varchar(50) NOT NULL,
bookmarkid varchar(36) NOT NULL,
orderidx int NOT NULL,
PRIMARY KEY (tag, bookmarkid)
);
CREATE TABLE playbook (
playbookid varchar(36) PRIMARY KEY,
playbookname varchar(100) NOT NULL,
description text NOT NULL,
entryids json NOT NULL
);
CREATE TABLE playbook_entry (
entryid varchar(36) PRIMARY KEY,
playbookid varchar(36) NOT NULL,
description text NOT NULL,
alias varchar(50) NOT NULL,
cmdstr text NOT NULL,
createdts bigint NOT NULL,
updatedts bigint NOT NULL
);
CREATE TABLE cloud_session (
sessionid varchar(36) PRIMARY KEY,
viewkey varchar(50) NOT NULL,
writekey varchar(50) NOT NULL,
enckey varchar(100) NOT NULL,
enctype varchar(50) NOT NULL,
vts bigint NOT NULL,
acl json NOT NULL
);
CREATE TABLE cloud_update (
updateid varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
updatetype varchar(50) NOT NULL,
updatekeys json NOT NULL
);
CREATE TABLE cmd_migrate (
sessionid varchar(36) NOT NULL,
screenid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL
);
CREATE TABLE IF NOT EXISTS "screen" (
screenid varchar(36) NOT NULL,
sessionid varchar(36) NOT NULL,
name varchar(50) NOT NULL,
screenidx int NOT NULL,
screenopts json NOT NULL,
ownerid varchar(36) NOT NULL,
sharemode varchar(12) NOT NULL,
curremoteownerid varchar(36) NOT NULL,
curremoteid varchar(36) NOT NULL,
curremotename varchar(50) NOT NULL,
nextlinenum int NOT NULL,
selectedline int NOT NULL,
anchor json NOT NULL,
focustype varchar(12) NOT NULL,
archived boolean NOT NULL,
archivedts bigint NOT NULL, webshareopts json NOT NULL DEFAULT 'null',
PRIMARY KEY (screenid)
);
CREATE TABLE IF NOT EXISTS "line" (
screenid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
ts bigint NOT NULL,
linenum int NOT NULL,
linenumtemp boolean NOT NULL,
linetype varchar(10) NOT NULL,
linelocal boolean NOT NULL,
text text NOT NULL,
ephemeral boolean NOT NULL,
contentheight int NOT NULL,
star int NOT NULL,
archived boolean NOT NULL,
renderer varchar(50) NOT NULL, linestate json NOT NULL DEFAULT '{}',
PRIMARY KEY (screenid, lineid)
);
CREATE TABLE screenupdate (
updateid integer PRIMARY KEY,
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
updatetype varchar(50) NOT NULL,
updatets bigint NOT NULL
);
CREATE TABLE webptypos (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
ptypos bigint NOT NULL,
PRIMARY KEY (screenid, lineid)
);
CREATE INDEX idx_screenupdate_ids ON screenupdate (screenid, lineid);
CREATE TABLE cmd_migration (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (screenid, lineid)
);
CREATE TABLE IF NOT EXISTS "cmd" (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
remoteownerid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
cmdstr text NOT NULL,
rawcmdstr text NOT NULL,
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL,
termopts json NOT NULL,
origtermopts json NOT NULL,
status varchar(10) NOT NULL,
cmdpid int NOT NULL,
remotepid int NOT NULL,
donets bigint NOT NULL,
exitcode int NOT NULL,
durationms int NOT NULL,
rtnstate boolean NOT NULL,
rtnbasehash varchar(36) NOT NULL,
rtndiffhasharr json NOT NULL,
runout json NOT NULL,
PRIMARY KEY (screenid, lineid)
);
CREATE TABLE cmd_migrate20 (
screenid varchar(36) NOT NULL,
lineid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (screenid, lineid)
);

28
wavesrv/go.mod Normal file
View File

@ -0,0 +1,28 @@
module github.com/commandlinedev/prompt-server
go 1.18
require (
github.com/alessio/shellescape v1.4.1
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2
github.com/commandlinedev/apishell v0.0.0
github.com/creack/pty v1.1.18
github.com/golang-migrate/migrate/v4 v4.16.2
github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/jmoiron/sqlx v1.3.5
github.com/mattn/go-sqlite3 v1.14.16
github.com/sashabaranov/go-openai v1.9.0
github.com/sawka/txwrap v0.1.2
golang.org/x/crypto v0.7.0
golang.org/x/mod v0.10.0
golang.org/x/sys v0.10.0
mvdan.cc/sh/v3 v3.7.0
)
require (
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
go.uber.org/atomic v1.7.0 // indirect
)

64
wavesrv/go.sum Normal file
View File

@ -0,0 +1,64 @@
github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0=
github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30=
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs=
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang-migrate/migrate/v4 v4.16.2 h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA=
github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.10.1-0.20230524175051-ec119421bb97 h1:3RPlVWzZ/PDqmVuf/FKHARG5EMid/tl7cv54Sw/QRVY=
github.com/rogpeppe/go-internal v1.10.1-0.20230524175051-ec119421bb97/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sashabaranov/go-openai v1.9.0 h1:NoiO++IISxxJ1pRc0n7uZvMGMake0G+FJ1XPwXtprsA=
github.com/sashabaranov/go-openai v1.9.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sawka/txwrap v0.1.2 h1:v8xS0Z1LE7/6vMZA81PYihI+0TSR6Zm1MalzzBIuXKc=
github.com/sawka/txwrap v0.1.2/go.mod h1:T3nlw2gVpuolo6/XEetvBbk1oMXnY978YmBFy1UyHvw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
mvdan.cc/sh/v3 v3.7.0 h1:lSTjdP/1xsddtaKfGg7Myu7DnlHItd3/M2tomOcNNBg=
mvdan.cc/sh/v3 v3.7.0/go.mod h1:K2gwkaesF/D7av7Kxl0HbF5kGOd2ArupNTX3X44+8l8=

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,516 @@
package cmdrunner
import (
"context"
"fmt"
"log"
"regexp"
"strconv"
"strings"
"github.com/commandlinedev/prompt-server/pkg/remote"
"github.com/commandlinedev/prompt-server/pkg/scpacket"
"github.com/commandlinedev/prompt-server/pkg/sstore"
"github.com/google/uuid"
)
const (
R_Session = 1
R_Screen = 2
R_Remote = 8
R_RemoteConnected = 16
)
type resolvedIds struct {
SessionId string
ScreenId string
Remote *ResolvedRemote
}
type ResolvedRemote struct {
DisplayName string
RemotePtr sstore.RemotePtrType
MShell *remote.MShellProc
RState remote.RemoteRuntimeState
RemoteCopy *sstore.RemoteType
StatePtr *sstore.ShellStatePtr
FeState map[string]string
}
type ResolveItem = sstore.ResolveItem
func itemNames(items []ResolveItem) []string {
if len(items) == 0 {
return nil
}
rtn := make([]string, len(items))
for idx, item := range items {
rtn[idx] = item.Name
}
return rtn
}
func sessionsToResolveItems(sessions []*sstore.SessionType) []ResolveItem {
if len(sessions) == 0 {
return nil
}
rtn := make([]ResolveItem, len(sessions))
for idx, session := range sessions {
rtn[idx] = ResolveItem{Name: session.Name, Id: session.SessionId, Hidden: session.Archived}
}
return rtn
}
func screensToResolveItems(screens []*sstore.ScreenType) []ResolveItem {
if len(screens) == 0 {
return nil
}
rtn := make([]ResolveItem, len(screens))
for idx, screen := range screens {
rtn[idx] = ResolveItem{Name: screen.Name, Id: screen.ScreenId, Hidden: screen.Archived}
}
return rtn
}
// 1-indexed
func boundInt(ival int, maxVal int, wrap bool) int {
if maxVal == 0 {
return 0
}
if ival < 1 {
if wrap {
return maxVal
} else {
return 1
}
}
if ival > maxVal {
if wrap {
return 1
} else {
return maxVal
}
}
return ival
}
type posArgType struct {
Pos int
IsWrap bool
IsRelative bool
StartAnchor bool
EndAnchor bool
}
func parsePosArg(posStr string) *posArgType {
if !positionRe.MatchString(posStr) {
return nil
}
if posStr == "+" {
return &posArgType{Pos: 1, IsWrap: true, IsRelative: true}
} else if posStr == "-" {
return &posArgType{Pos: -1, IsWrap: true, IsRelative: true}
} else if posStr == "S" {
return &posArgType{Pos: 0, IsRelative: true, StartAnchor: true}
} else if posStr == "E" {
return &posArgType{Pos: 0, IsRelative: true, EndAnchor: true}
}
if strings.HasPrefix(posStr, "S+") {
pos, _ := strconv.Atoi(posStr[2:])
return &posArgType{Pos: pos, IsRelative: true, StartAnchor: true}
}
if strings.HasPrefix(posStr, "E-") {
pos, _ := strconv.Atoi(posStr[1:])
return &posArgType{Pos: pos, IsRelative: true, EndAnchor: true}
}
if strings.HasPrefix(posStr, "+") || strings.HasPrefix(posStr, "-") {
pos, _ := strconv.Atoi(posStr)
return &posArgType{Pos: pos, IsRelative: true}
}
pos, _ := strconv.Atoi(posStr)
return &posArgType{Pos: pos}
}
func resolveByPosition(isNumeric bool, allItems []ResolveItem, curId string, posStr string) *ResolveItem {
items := make([]ResolveItem, 0, len(allItems))
for _, item := range allItems {
if !item.Hidden {
items = append(items, item)
}
}
if len(items) == 0 {
return nil
}
posArg := parsePosArg(posStr)
if posArg == nil {
return nil
}
var finalPos int
if posArg.IsRelative {
var curIdx int
if posArg.StartAnchor {
curIdx = 1
} else if posArg.EndAnchor {
curIdx = len(items)
} else {
curIdx = 1 // if no match, curIdx will be first item
for idx, item := range items {
if item.Id == curId {
curIdx = idx + 1
break
}
}
}
finalPos = curIdx + posArg.Pos
finalPos = boundInt(finalPos, len(items), posArg.IsWrap)
return &items[finalPos-1]
} else if isNumeric {
// these resolve items have a "Num" set that should be used to look up non-relative positions
// use allItems for numeric resolve
for _, item := range allItems {
if item.Num == posArg.Pos {
return &item
}
}
return nil
} else {
// non-numeric means position is just the index
finalPos = posArg.Pos
if finalPos <= 0 || finalPos > len(items) {
return nil
}
return &items[finalPos-1]
}
}
func resolveRemoteArg(remoteArg string) (*sstore.RemotePtrType, error) {
rrUser, rrRemote, rrName, err := parseFullRemoteRef(remoteArg)
if err != nil {
return nil, err
}
if rrUser != "" {
return nil, fmt.Errorf("remoteusers not supported")
}
msh := remote.GetRemoteByArg(rrRemote)
if msh == nil {
return nil, nil
}
rcopy := msh.GetRemoteCopy()
return &sstore.RemotePtrType{RemoteId: rcopy.RemoteId, Name: rrName}, nil
}
func resolveUiIds(ctx context.Context, pk *scpacket.FeCommandPacketType, rtype int) (resolvedIds, error) {
rtn := resolvedIds{}
uictx := pk.UIContext
if uictx != nil {
rtn.SessionId = uictx.SessionId
rtn.ScreenId = uictx.ScreenId
}
if pk.Kwargs["session"] != "" {
sessionId, err := resolveSessionArg(pk.Kwargs["session"])
if err != nil {
return rtn, err
}
if sessionId != "" {
rtn.SessionId = sessionId
}
}
if pk.Kwargs["screen"] != "" {
screenId, err := resolveScreenArg(rtn.SessionId, pk.Kwargs["screen"])
if err != nil {
return rtn, err
}
if screenId != "" {
rtn.ScreenId = screenId
}
}
var rptr *sstore.RemotePtrType
var err error
if pk.Kwargs["remote"] != "" {
rptr, err = resolveRemoteArg(pk.Kwargs["remote"])
if err != nil {
return rtn, err
}
if rptr == nil {
return rtn, fmt.Errorf("invalid remote argument %q passed, remote not found", pk.Kwargs["remote"])
}
} else if uictx.Remote != nil {
rptr = uictx.Remote
}
if rptr != nil {
err = rptr.Validate()
if err != nil {
return rtn, fmt.Errorf("invalid resolved remote: %v", err)
}
rr, err := ResolveRemoteFromPtr(ctx, rptr, rtn.SessionId, rtn.ScreenId)
if err != nil {
return rtn, err
}
rtn.Remote = rr
}
if rtype&R_Session > 0 && rtn.SessionId == "" {
return rtn, fmt.Errorf("no session")
}
if rtype&R_Screen > 0 && rtn.ScreenId == "" {
return rtn, fmt.Errorf("no screen")
}
if (rtype&R_Remote > 0 || rtype&R_RemoteConnected > 0) && rtn.Remote == nil {
return rtn, fmt.Errorf("no remote")
}
if rtype&R_RemoteConnected > 0 {
if !rtn.Remote.RState.IsConnected() {
err = rtn.Remote.MShell.TryAutoConnect()
if err != nil {
return rtn, fmt.Errorf("error trying to auto-connect remote [%s]: %w", rtn.Remote.DisplayName, err)
}
rrNew, err := ResolveRemoteFromPtr(ctx, rptr, rtn.SessionId, rtn.ScreenId)
if err != nil {
return rtn, err
}
rtn.Remote = rrNew
}
if !rtn.Remote.RState.IsConnected() {
return rtn, fmt.Errorf("remote [%s] is not connected", rtn.Remote.DisplayName)
}
if rtn.Remote.StatePtr == nil || rtn.Remote.FeState == nil {
return rtn, fmt.Errorf("remote [%s] state is not available", rtn.Remote.DisplayName)
}
}
return rtn, nil
}
func resolveSessionScreen(ctx context.Context, sessionId string, screenArg string, curScreenArg string) (*ResolveItem, error) {
screens, err := sstore.GetSessionScreens(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("could not retreive screens for session=%s: %v", sessionId, err)
}
ritems := screensToResolveItems(screens)
return genericResolve(screenArg, curScreenArg, ritems, false, "screen")
}
func resolveSession(ctx context.Context, sessionArg string, curSessionArg string) (*ResolveItem, error) {
bareSessions, err := sstore.GetBareSessions(ctx)
if err != nil {
return nil, err
}
ritems := sessionsToResolveItems(bareSessions)
ritem, err := genericResolve(sessionArg, curSessionArg, ritems, false, "session")
if err != nil {
return nil, err
}
return ritem, nil
}
func resolveLine(ctx context.Context, sessionId string, screenId string, lineArg string, curLineArg string) (*ResolveItem, error) {
lines, err := sstore.GetLineResolveItems(ctx, screenId)
if err != nil {
return nil, fmt.Errorf("could not get lines: %v", err)
}
return genericResolve(lineArg, curLineArg, lines, true, "line")
}
func getSessionIds(sarr []*sstore.SessionType) []string {
rtn := make([]string, len(sarr))
for idx, s := range sarr {
rtn[idx] = s.SessionId
}
return rtn
}
var partialUUIDRe = regexp.MustCompile("^[0-9a-f]{8}$")
func isPartialUUID(s string) bool {
return partialUUIDRe.MatchString(s)
}
func isUUID(s string) bool {
_, err := uuid.Parse(s)
return err == nil
}
func getResolveItemById(id string, items []ResolveItem) *ResolveItem {
if id == "" {
return nil
}
for _, item := range items {
if item.Id == id {
return &item
}
}
return nil
}
func genericResolve(arg string, curArg string, items []ResolveItem, isNumeric bool, typeStr string) (*ResolveItem, error) {
if len(items) == 0 || arg == "" {
return nil, nil
}
var curId string
if curArg != "" {
curItem, _ := genericResolve(curArg, "", items, isNumeric, typeStr)
if curItem != nil {
curId = curItem.Id
}
}
rtnItem := resolveByPosition(isNumeric, items, curId, arg)
if rtnItem != nil {
return rtnItem, nil
}
isUuid := isUUID(arg)
tryPuid := isPartialUUID(arg)
var prefixMatches []ResolveItem
for _, item := range items {
if (isUuid && item.Id == arg) || (tryPuid && strings.HasPrefix(item.Id, arg)) {
return &item, nil
}
if item.Name != "" {
if item.Name == arg {
return &item, nil
}
if !item.Hidden && strings.HasPrefix(item.Name, arg) {
prefixMatches = append(prefixMatches, item)
}
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, fmt.Errorf("could not resolve %s '%s', ambiguious prefix matched multiple %ss: %s", typeStr, arg, typeStr, formatStrs(itemNames(prefixMatches), "and", true))
}
return nil, fmt.Errorf("could not resolve %s '%s' (name/id/pos not found)", typeStr, arg)
}
func resolveSessionId(pk *scpacket.FeCommandPacketType) (string, error) {
sessionId := pk.Kwargs["session"]
if sessionId == "" {
return "", nil
}
if _, err := uuid.Parse(sessionId); err != nil {
return "", fmt.Errorf("invalid sessionid '%s'", sessionId)
}
return sessionId, nil
}
func resolveSessionArg(sessionArg string) (string, error) {
if sessionArg == "" {
return "", nil
}
if _, err := uuid.Parse(sessionArg); err != nil {
return "", fmt.Errorf("invalid session arg specified (must be sessionid) '%s'", sessionArg)
}
return sessionArg, nil
}
func resolveScreenArg(sessionId string, screenArg string) (string, error) {
if screenArg == "" {
return "", nil
}
if _, err := uuid.Parse(screenArg); err != nil {
return "", fmt.Errorf("invalid screen arg specified (must be screenid) '%s'", screenArg)
}
return screenArg, nil
}
func resolveScreenId(ctx context.Context, pk *scpacket.FeCommandPacketType, sessionId string) (string, error) {
screenArg := pk.Kwargs["screen"]
if screenArg == "" {
return "", nil
}
if _, err := uuid.Parse(screenArg); err == nil {
return screenArg, nil
}
if sessionId == "" {
return "", fmt.Errorf("cannot resolve screen without session")
}
ritem, err := resolveSessionScreen(ctx, sessionId, screenArg, "")
if err != nil {
return "", err
}
return ritem.Id, nil
}
// returns (remoteuserref, remoteref, name, error)
func parseFullRemoteRef(fullRemoteRef string) (string, string, string, error) {
if strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]") {
fullRemoteRef = fullRemoteRef[1 : len(fullRemoteRef)-1]
}
fields := strings.Split(fullRemoteRef, ":")
if len(fields) > 3 {
return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef)
}
if len(fields) == 1 {
return "", fields[0], "", nil
}
if len(fields) == 2 {
if strings.HasPrefix(fields[0], "@") {
return fields[0], fields[1], "", nil
}
return "", fields[0], fields[1], nil
}
return fields[0], fields[1], fields[2], nil
}
func ResolveRemoteFromPtr(ctx context.Context, rptr *sstore.RemotePtrType, sessionId string, screenId string) (*ResolvedRemote, error) {
if rptr == nil || rptr.RemoteId == "" {
return nil, nil
}
msh := remote.GetRemoteById(rptr.RemoteId)
if msh == nil {
return nil, fmt.Errorf("invalid remote '%s', not found", rptr.RemoteId)
}
rstate := msh.GetRemoteRuntimeState()
rcopy := msh.GetRemoteCopy()
displayName := rstate.GetDisplayName(rptr)
rtn := &ResolvedRemote{
DisplayName: displayName,
RemotePtr: *rptr,
RState: rstate,
MShell: msh,
RemoteCopy: &rcopy,
StatePtr: nil,
FeState: nil,
}
if sessionId != "" && screenId != "" {
ri, err := sstore.GetRemoteInstance(ctx, sessionId, screenId, *rptr)
if err != nil {
log.Printf("ERROR resolving remote state '%s': %v\n", displayName, err)
// continue with state set to nil
} else {
if ri == nil {
rtn.StatePtr = msh.GetDefaultStatePtr()
rtn.FeState = msh.GetDefaultFeState()
} else {
rtn.StatePtr = &sstore.ShellStatePtr{BaseHash: ri.StateBaseHash, DiffHashArr: ri.StateDiffHashArr}
rtn.FeState = ri.FeState
}
}
}
return rtn, nil
}
// returns (remoteDisplayName, remoteptr, state, rstate, err)
func resolveRemote(ctx context.Context, fullRemoteRef string, sessionId string, screenId string) (string, *sstore.RemotePtrType, *remote.RemoteRuntimeState, error) {
if fullRemoteRef == "" {
return "", nil, nil, nil
}
userRef, remoteRef, remoteName, err := parseFullRemoteRef(fullRemoteRef)
if err != nil {
return "", nil, nil, err
}
if userRef != "" {
return "", nil, nil, fmt.Errorf("invalid remote '%s', cannot resolve remote userid '%s'", fullRemoteRef, userRef)
}
rstate := remote.ResolveRemoteRef(remoteRef)
if rstate == nil {
return "", nil, nil, fmt.Errorf("cannot resolve remote '%s': not found", fullRemoteRef)
}
rptr := sstore.RemotePtrType{RemoteId: rstate.RemoteId, Name: remoteName}
rname := rstate.RemoteCanonicalName
if rstate.RemoteAlias != "" {
rname = rstate.RemoteAlias
}
if rptr.Name != "" {
rname = fmt.Sprintf("%s:%s", rname, rptr.Name)
}
return rname, &rptr, rstate, nil
}

View File

@ -0,0 +1,333 @@
package cmdrunner
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/commandlinedev/apishell/pkg/shexec"
"github.com/commandlinedev/apishell/pkg/simpleexpand"
"github.com/commandlinedev/prompt-server/pkg/scpacket"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
var ValidMetaCmdRe = regexp.MustCompile("^/([a-z_][a-z0-9_-]*)(?::([a-z][a-z0-9_-]*))?$")
type BareMetaCmdDecl struct {
CmdStr string
MetaCmd string
}
var BareMetaCmds = []BareMetaCmdDecl{
BareMetaCmdDecl{"cr", "cr"},
BareMetaCmdDecl{"connect", "cr"},
BareMetaCmdDecl{"clear", "clear"},
BareMetaCmdDecl{"reset", "reset"},
BareMetaCmdDecl{"codeedit", "codeedit"},
BareMetaCmdDecl{"codeview", "codeview"},
BareMetaCmdDecl{"imageview", "imageview"},
BareMetaCmdDecl{"markdownview", "markdownview"},
BareMetaCmdDecl{"mdview", "markdownview"},
BareMetaCmdDecl{"csvview", "csvview"},
}
const (
CmdParseTypePositional = "pos"
CmdParseTypeRaw = "raw"
)
var CmdParseOverrides map[string]string = map[string]string{
"setenv": CmdParseTypePositional,
"unset": CmdParseTypePositional,
"set": CmdParseTypePositional,
"run": CmdParseTypeRaw,
"comment": CmdParseTypeRaw,
"chat": CmdParseTypeRaw,
}
func DumpPacket(pk *scpacket.FeCommandPacketType) {
if pk == nil || pk.MetaCmd == "" {
fmt.Printf("[no metacmd]\n")
return
}
if pk.MetaSubCmd == "" {
fmt.Printf("/%s\n", pk.MetaCmd)
} else {
fmt.Printf("/%s:%s\n", pk.MetaCmd, pk.MetaSubCmd)
}
for _, arg := range pk.Args {
fmt.Printf(" %q\n", arg)
}
for key, val := range pk.Kwargs {
fmt.Printf(" [%s]=%q\n", key, val)
}
}
func isQuoted(source string, w *syntax.Word) bool {
if w == nil {
return false
}
offset := w.Pos().Offset()
if int(offset) >= len(source) {
return false
}
return source[offset] == '"' || source[offset] == '\''
}
func getSourceStr(source string, w *syntax.Word) string {
if w == nil {
return ""
}
offset := w.Pos().Offset()
end := w.End().Offset()
return source[offset:end]
}
func SubMetaCmd(cmd string) string {
switch cmd {
case "s":
return "screen"
case "r":
return "run"
case "c":
return "comment"
case "e":
return "eval"
case "export":
return "setenv"
case "connection":
return "remote"
default:
return cmd
}
}
// returns (metaCmd, metaSubCmd, rest)
// if metaCmd is "" then this isn't a valid metacmd string
func parseMetaCmd(origCommandStr string) (string, string, string) {
commandStr := strings.TrimSpace(origCommandStr)
if len(commandStr) < 2 {
return "run", "", origCommandStr
}
fields := strings.SplitN(commandStr, " ", 2)
firstArg := fields[0]
rest := ""
if len(fields) > 1 {
rest = strings.TrimSpace(fields[1])
}
for _, decl := range BareMetaCmds {
if firstArg == decl.CmdStr {
return decl.MetaCmd, "", rest
}
}
m := ValidMetaCmdRe.FindStringSubmatch(firstArg)
if m == nil {
return "run", "", origCommandStr
}
return SubMetaCmd(m[1]), m[2], rest
}
func onlyPositionalArgs(metaCmd string, metaSubCmd string) bool {
return (CmdParseOverrides[metaCmd] == CmdParseTypePositional) && metaSubCmd == ""
}
func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
return CmdParseOverrides[metaCmd] == CmdParseTypeRaw
}
func setBracketArgs(argMap map[string]string, bracketStr string) error {
bracketStr = strings.TrimSpace(bracketStr)
if bracketStr == "" {
return nil
}
strReader := strings.NewReader(bracketStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
var wordErr error
var ectx simpleexpand.SimpleExpandContext // do not set HomeDir (we don't expand ~ in bracket args)
err := parser.Words(strReader, func(w *syntax.Word) bool {
litStr, _ := simpleexpand.SimpleExpandWord(ectx, w, bracketStr)
eqIdx := strings.Index(litStr, "=")
var varName, varVal string
if eqIdx == -1 {
varName = litStr
} else {
varName = litStr[0:eqIdx]
varVal = litStr[eqIdx+1:]
}
if !shexec.IsValidBashIdentifier(varName) {
wordErr = fmt.Errorf("invalid identifier %s in bracket args", utilfn.ShellQuote(varName, true, 20))
return false
}
if varVal == "" {
varVal = "1"
}
argMap[varName] = varVal
return true
})
if err != nil {
return err
}
if wordErr != nil {
return wordErr
}
return nil
}
var literalRtnStateCommands = []string{".", "source", "unset", "cd", "alias", "unalias", "deactivate"}
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {
if len(callExpr.Args) <= argNum {
return ""
}
arg := callExpr.Args[argNum]
if len(arg.Parts) == 0 {
return ""
}
lit, ok := arg.Parts[0].(*syntax.Lit)
if !ok {
return ""
}
return lit.Value
}
// detects: export, declare, ., source, X=1, unset
func IsReturnStateCommand(cmdStr string) bool {
cmdReader := strings.NewReader(cmdStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(cmdReader, "cmd")
if err != nil {
return false
}
for _, stmt := range file.Stmts {
if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
if len(callExpr.Assigns) > 0 && len(callExpr.Args) == 0 {
return true
}
arg0 := getCallExprLitArg(callExpr, 0)
if arg0 != "" && utilfn.ContainsStr(literalRtnStateCommands, arg0) {
return true
}
if arg0 == "git" {
arg1 := getCallExprLitArg(callExpr, 1)
if arg1 == "checkout" || arg1 == "switch" {
return true
}
}
} else if _, ok := stmt.Cmd.(*syntax.DeclClause); ok {
return true
}
}
return false
}
func EvalBracketArgs(origCmdStr string) (map[string]string, string, error) {
rtn := make(map[string]string)
if strings.HasPrefix(origCmdStr, " ") {
rtn["nohist"] = "1"
}
cmdStr := strings.TrimSpace(origCmdStr)
if !strings.HasPrefix(cmdStr, "[") {
return rtn, origCmdStr, nil
}
rbIdx := strings.Index(cmdStr, "]")
if rbIdx == -1 {
return nil, "", fmt.Errorf("unmatched '[' found in command")
}
bracketStr := cmdStr[1:rbIdx]
restStr := strings.TrimSpace(cmdStr[rbIdx+1:])
err := setBracketArgs(rtn, bracketStr)
if err != nil {
return nil, "", err
}
return rtn, restStr, nil
}
func unescapeBackSlashes(s string) string {
if strings.Index(s, "\\") == -1 {
return s
}
var newStr []rune
var lastSlash bool
for _, r := range s {
if lastSlash {
lastSlash = false
newStr = append(newStr, r)
continue
}
if r == '\\' {
lastSlash = true
continue
}
newStr = append(newStr, r)
}
return string(newStr)
}
func EvalMetaCommand(ctx context.Context, origPk *scpacket.FeCommandPacketType) (*scpacket.FeCommandPacketType, error) {
if len(origPk.Args) == 0 {
return nil, fmt.Errorf("empty command (no fields)")
}
if strings.TrimSpace(origPk.Args[0]) == "" {
return nil, fmt.Errorf("empty command")
}
bracketArgs, cmdStr, err := EvalBracketArgs(origPk.Args[0])
if err != nil {
return nil, err
}
metaCmd, metaSubCmd, commandArgs := parseMetaCmd(cmdStr)
rtnPk := scpacket.MakeFeCommandPacket()
rtnPk.MetaCmd = metaCmd
rtnPk.MetaSubCmd = metaSubCmd
rtnPk.Kwargs = make(map[string]string)
rtnPk.UIContext = origPk.UIContext
rtnPk.RawStr = origPk.RawStr
for key, val := range origPk.Kwargs {
rtnPk.Kwargs[key] = val
}
for key, val := range bracketArgs {
rtnPk.Kwargs[key] = val
}
if onlyRawArgs(metaCmd, metaSubCmd) {
// don't evaluate arguments for /run or /comment
rtnPk.Args = []string{commandArgs}
return rtnPk, nil
}
commandReader := strings.NewReader(commandArgs)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
var words []*syntax.Word
err = parser.Words(commandReader, func(w *syntax.Word) bool {
words = append(words, w)
return true
})
if err != nil {
return nil, fmt.Errorf("parsing metacmd, position %v", err)
}
envMap := make(map[string]string) // later we can add vars like session, screen, remote, and user
cfg := shexec.GetParserConfig(envMap)
// process arguments
for idx, w := range words {
literalVal, err := expand.Literal(cfg, w)
if err != nil {
return nil, fmt.Errorf("error evaluating metacmd argument %d [%s]: %v", idx+1, getSourceStr(commandArgs, w), err)
}
if isQuoted(commandArgs, w) || onlyPositionalArgs(metaCmd, metaSubCmd) {
rtnPk.Args = append(rtnPk.Args, literalVal)
continue
}
eqIdx := strings.Index(literalVal, "=")
if eqIdx != -1 && eqIdx != 0 {
varName := literalVal[:eqIdx]
varVal := literalVal[eqIdx+1:]
rtnPk.Kwargs[varName] = varVal
continue
}
rtnPk.Args = append(rtnPk.Args, unescapeBackSlashes(literalVal))
}
if resolveBool(rtnPk.Kwargs["dump"], false) {
DumpPacket(rtnPk)
}
return rtnPk, nil
}

View File

@ -0,0 +1,57 @@
package cmdrunner
import (
"fmt"
"os"
"testing"
)
func xTestParseAliases(t *testing.T) {
m, err := ParseAliases(`
alias cdg='cd work/gopath/src/github.com/sawka'
alias s='scripthaus'
alias x='ls;ls"'
alias foo="bar \"hello\""
alias x=y
`)
if err != nil {
fmt.Printf("err: %v\n", err)
return
}
fmt.Printf("m: %#v\n", m)
}
func xTestParseFuncs(t *testing.T) {
file, err := os.ReadFile("./linux-decls.txt")
if err != nil {
t.Fatalf("error reading linux-decls: %v", err)
}
m, err := ParseFuncs(string(file))
if err != nil {
t.Fatalf("error parsing funcs: %v", err)
}
for key, val := range m {
fmt.Printf("func: %s %d\n", key, len(val))
}
}
func testRSC(t *testing.T, cmd string, expected bool) {
rtn := IsReturnStateCommand(cmd)
if rtn != expected {
t.Errorf("cmd [%s], rtn=%v, expected=%v", cmd, rtn, expected)
}
}
func TestIsReturnStateCommand(t *testing.T) {
testRSC(t, "FOO=1", true)
testRSC(t, "FOO=1 X=2", true)
testRSC(t, "ls", false)
testRSC(t, "export X", true)
testRSC(t, "export X=1", true)
testRSC(t, "declare -x FOO=1", true)
testRSC(t, "source ./test", true)
testRSC(t, "unset FOO BAR", true)
testRSC(t, "FOO=1; ls", true)
testRSC(t, ". ./test", true)
testRSC(t, "{ FOO=6; }", false)
}

View File

@ -0,0 +1,120 @@
package cmdrunner
import (
"fmt"
"strconv"
"strings"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/apishell/pkg/shexec"
"github.com/commandlinedev/prompt-server/pkg/remote"
"github.com/commandlinedev/prompt-server/pkg/sstore"
)
// PTERM=MxM,Mx25
// PTERM="Mx25!"
// PTERM=80x25,80x35
type PTermOptsType struct {
Rows string
RowsFlex bool
Cols string
ColsFlex bool
}
const PTermMax = "M"
func isDigits(s string) bool {
for _, ch := range s {
if ch < '0' || ch > '9' {
return false
}
}
return true
}
func atoiDefault(s string, def int) int {
ival, err := strconv.Atoi(s)
if err != nil {
return def
}
return ival
}
func parseTermPart(part string, partType string) (string, bool, error) {
flex := true
if strings.HasSuffix(part, "!") {
part = part[:len(part)-1]
flex = false
}
if part == "" {
return PTermMax, flex, nil
}
if part == PTermMax {
return PTermMax, flex, nil
}
if !isDigits(part) {
return "", false, fmt.Errorf("invalid PTERM %s: must be '%s' or [number]", partType, PTermMax)
}
return part, flex, nil
}
func parseSingleTermStr(s string) (*PTermOptsType, error) {
s = strings.TrimSpace(s)
xIdx := strings.Index(s, "x")
if xIdx == -1 {
return nil, fmt.Errorf("invalid PTERM, must include 'x' to separate width and height (e.g. WxH)")
}
rowsPart := s[0:xIdx]
colsPart := s[xIdx+1:]
rows, rowsFlex, err := parseTermPart(rowsPart, "rows")
if err != nil {
return nil, err
}
cols, colsFlex, err := parseTermPart(colsPart, "cols")
if err != nil {
return nil, err
}
return &PTermOptsType{Rows: rows, RowsFlex: rowsFlex, Cols: cols, ColsFlex: colsFlex}, nil
}
func GetUITermOpts(winSize *packet.WinSize, ptermStr string) (*packet.TermOpts, error) {
opts, err := parseSingleTermStr(ptermStr)
if err != nil {
return nil, err
}
termOpts := &packet.TermOpts{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols, Term: remote.DefaultTerm, MaxPtySize: shexec.DefaultMaxPtySize}
if winSize == nil {
winSize = &packet.WinSize{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols}
}
if winSize.Rows == 0 {
winSize.Rows = shexec.DefaultTermRows
}
if winSize.Cols == 0 {
winSize.Cols = shexec.DefaultTermCols
}
if opts.Rows == PTermMax {
termOpts.Rows = winSize.Rows
} else {
termOpts.Rows = atoiDefault(opts.Rows, termOpts.Rows)
}
if opts.Cols == PTermMax {
termOpts.Cols = winSize.Cols
} else {
termOpts.Cols = atoiDefault(opts.Cols, termOpts.Cols)
}
termOpts.MaxPtySize = base.BoundInt64(termOpts.MaxPtySize, shexec.MinMaxPtySize, shexec.MaxMaxPtySize)
termOpts.Cols = base.BoundInt(termOpts.Cols, shexec.MinTermCols, shexec.MaxTermCols)
termOpts.Rows = base.BoundInt(termOpts.Rows, shexec.MinTermRows, shexec.MaxTermRows)
return termOpts, nil
}
func convertTermOpts(pkto *packet.TermOpts) *sstore.TermOpts {
return &sstore.TermOpts{
Rows: int64(pkto.Rows),
Cols: int64(pkto.Cols),
FlexRows: true,
MaxPtySize: pkto.MaxPtySize,
}
}

640
wavesrv/pkg/comp/comp.go Normal file
View File

@ -0,0 +1,640 @@
// scripthaus completion
package comp
import (
"bytes"
"context"
"fmt"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/commandlinedev/apishell/pkg/simpleexpand"
"github.com/commandlinedev/prompt-server/pkg/shparse"
"github.com/commandlinedev/prompt-server/pkg/sstore"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
"mvdan.cc/sh/v3/syntax"
)
const MaxCompQuoteLen = 5000
const (
// local to simplecomp
CGTypeCommand = "command"
CGTypeFile = "file"
CGTypeDir = "directory"
CGTypeVariable = "variable"
// implemented in cmdrunner
CGTypeMeta = "metacmd"
CGTypeCommandMeta = "command+meta"
CGTypeRemote = "remote"
CGTypeRemoteInstance = "remoteinstance"
CGTypeGlobalCmd = "globalcmd"
)
const (
QuoteTypeLiteral = ""
QuoteTypeDQ = "\""
QuoteTypeANSI = "$'"
QuoteTypeSQ = "'"
)
type CompContext struct {
RemotePtr *sstore.RemotePtrType
Cwd string
ForDisplay bool
}
type ParsedWord struct {
Offset int
Word *syntax.Word
PartialWord string
Prefix string
}
type CompPoint struct {
StmtStr string
Words []ParsedWord
CompWord int
CompWordPos int
Prefix string
Suffix string
}
// directories will have a trailing "/"
type CompEntry struct {
Word string
IsMetaCmd bool
}
type CompReturn struct {
CompType string
Entries []CompEntry
HasMore bool
}
var noEscChars []bool
var specialEsc []string
func init() {
noEscChars = make([]bool, 256)
for ch := 0; ch < 256; ch++ {
if (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
ch == '-' || ch == '.' || ch == '/' || ch == ':' || ch == '=' {
noEscChars[byte(ch)] = true
}
}
specialEsc = make([]string, 256)
specialEsc[0x7] = "\\a"
specialEsc[0x8] = "\\b"
specialEsc[0x9] = "\\t"
specialEsc[0xa] = "\\n"
specialEsc[0xb] = "\\v"
specialEsc[0xc] = "\\f"
specialEsc[0xd] = "\\r"
specialEsc[0x1b] = "\\E"
}
func compQuoteDQString(s string, close bool) string {
var buf bytes.Buffer
buf.WriteByte('"')
for _, ch := range s {
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
buf.WriteByte('\\')
buf.WriteRune(ch)
continue
}
buf.WriteRune(ch)
}
if close {
buf.WriteByte('"')
}
return buf.String()
}
func hasGlob(s string) bool {
var lastExtGlob bool
for _, ch := range s {
if ch == '*' || ch == '?' || ch == '[' || ch == '{' {
return true
}
if ch == '+' || ch == '@' || ch == '!' {
lastExtGlob = true
continue
}
if lastExtGlob && ch == '(' {
return true
}
lastExtGlob = false
}
return false
}
func writeUtf8Literal(buf *bytes.Buffer, ch rune) {
var runeArr [utf8.UTFMax]byte
buf.WriteString("$'")
barr := runeArr[:]
byteLen := utf8.EncodeRune(barr, ch)
for i := 0; i < byteLen; i++ {
buf.WriteString("\\x")
buf.WriteByte(utilfn.HexDigits[barr[i]/16])
buf.WriteByte(utilfn.HexDigits[barr[i]%16])
}
buf.WriteByte('\'')
}
func compQuoteLiteralString(s string) string {
var buf bytes.Buffer
for idx, ch := range s {
if ch == 0 {
break
}
if idx == 0 && ch == '~' {
buf.WriteRune(ch)
continue
}
if ch > unicode.MaxASCII {
writeUtf8Literal(&buf, ch)
continue
}
var bch = byte(ch)
if noEscChars[bch] {
buf.WriteRune(ch)
continue
}
if specialEsc[bch] != "" {
buf.WriteString(specialEsc[bch])
continue
}
if !unicode.IsPrint(ch) {
writeUtf8Literal(&buf, ch)
continue
}
buf.WriteByte('\\')
buf.WriteByte(bch)
}
return buf.String()
}
func compQuoteSQString(s string) string {
var buf bytes.Buffer
for _, ch := range s {
if ch == 0 {
break
}
if ch == '\'' {
buf.WriteString("'\\''")
continue
}
var bch byte
if ch <= unicode.MaxASCII {
bch = byte(ch)
}
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
buf.WriteByte('\'')
if bch != 0 && specialEsc[bch] != "" {
buf.WriteString(specialEsc[bch])
} else {
writeUtf8Literal(&buf, ch)
}
buf.WriteByte('\'')
continue
}
buf.WriteByte(bch)
}
return buf.String()
}
func compQuoteString(s string, quoteType string, close bool) string {
if quoteType != QuoteTypeANSI && quoteType != QuoteTypeLiteral {
for _, ch := range s {
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) || ch == '!' {
quoteType = QuoteTypeANSI
break
}
if ch == '\'' {
if quoteType == QuoteTypeSQ {
quoteType = QuoteTypeANSI
break
}
}
}
}
if quoteType == QuoteTypeANSI {
rtn := strconv.QuoteToASCII(s)
rtn = "$'" + strings.ReplaceAll(rtn[1:len(rtn)-1], "'", "\\'")
if close {
rtn = rtn + "'"
}
return rtn
}
if quoteType == QuoteTypeLiteral {
return compQuoteLiteralString(s)
}
if quoteType == QuoteTypeSQ {
rtn := utilfn.ShellQuote(s, false, MaxCompQuoteLen)
if len(rtn) > 0 && rtn[0] != '\'' {
rtn = "'" + rtn + "'"
}
if !close {
rtn = rtn[0 : len(rtn)-1]
}
return rtn
}
// QuoteTypeDQ
return compQuoteDQString(s, close)
}
func (p *CompPoint) wordAsStr(w ParsedWord) string {
if w.Word != nil {
return p.StmtStr[w.Word.Pos().Offset():w.Word.End().Offset()]
}
return w.PartialWord
}
func (p *CompPoint) simpleExpandWord(w ParsedWord) (string, simpleexpand.SimpleExpandInfo) {
ectx := simpleexpand.SimpleExpandContext{}
if w.Word != nil {
return simpleexpand.SimpleExpandWord(ectx, w.Word, p.StmtStr)
}
return simpleexpand.SimpleExpandPartialWord(ectx, w.PartialWord, false)
}
func getQuoteTypePref(str string) string {
if strings.HasPrefix(str, QuoteTypeANSI) {
return QuoteTypeANSI
}
if strings.HasPrefix(str, QuoteTypeDQ) {
return QuoteTypeDQ
}
if strings.HasPrefix(str, QuoteTypeSQ) {
return QuoteTypeSQ
}
return QuoteTypeLiteral
}
func (p *CompPoint) getCompPrefix() (string, simpleexpand.SimpleExpandInfo) {
if p.CompWordPos == 0 {
return "", simpleexpand.SimpleExpandInfo{}
}
pword := p.Words[p.CompWord]
wordStr := p.wordAsStr(pword)
if p.CompWordPos == len(wordStr) {
return p.simpleExpandWord(pword)
}
// TODO we can do better, if p.Word is not nil, we can look for which WordPart
// our pos is in. we can then do a normal word expand on the previous parts
// and a partial on just the current part. this is an uncommon case though
// and has very little upside (even bash does not expand multipart words correctly)
partialWordStr := wordStr[:p.CompWordPos]
return simpleexpand.SimpleExpandPartialWord(simpleexpand.SimpleExpandContext{}, partialWordStr, false)
}
func (p *CompPoint) extendWord(newWord string, newWordComplete bool) utilfn.StrWithPos {
pword := p.Words[p.CompWord]
wordStr := p.wordAsStr(pword)
quotePref := getQuoteTypePref(wordStr)
needsClose := newWordComplete && (len(wordStr) == p.CompWordPos)
wordSuffix := wordStr[p.CompWordPos:]
newQuotedStr := compQuoteString(newWord, quotePref, needsClose)
if needsClose && wordSuffix == "" && !strings.HasSuffix(newWord, "/") {
newQuotedStr = newQuotedStr + " "
}
newPos := len(newQuotedStr)
return utilfn.StrWithPos{Str: newQuotedStr + wordSuffix, Pos: newPos}
}
// returns (extension, complete)
func computeCompExtension(compPrefix string, crtn *CompReturn) (string, bool) {
if crtn == nil || crtn.HasMore {
return "", false
}
compStrs := crtn.GetCompStrs()
lcp := utilfn.LongestPrefix(compPrefix, compStrs)
if lcp == compPrefix || len(lcp) < len(compPrefix) || !strings.HasPrefix(lcp, compPrefix) {
return "", false
}
return lcp[len(compPrefix):], (utilfn.ContainsStr(compStrs, lcp) && !utilfn.IsPrefix(compStrs, lcp))
}
func (p *CompPoint) FullyExtend(crtn *CompReturn) utilfn.StrWithPos {
if crtn == nil || crtn.HasMore {
return utilfn.StrWithPos{Str: p.getOrigStr(), Pos: p.getOrigPos()}
}
compStrs := crtn.GetCompStrs()
compPrefix, _ := p.getCompPrefix()
lcp := utilfn.LongestPrefix(compPrefix, compStrs)
if lcp == compPrefix || len(lcp) < len(compPrefix) || !strings.HasPrefix(lcp, compPrefix) {
return utilfn.StrWithPos{Str: p.getOrigStr(), Pos: p.getOrigPos()}
}
newStr := p.extendWord(lcp, utilfn.ContainsStr(compStrs, lcp))
var buf bytes.Buffer
buf.WriteString(p.Prefix)
for idx, w := range p.Words {
if idx == p.CompWord {
buf.WriteString(w.Prefix)
buf.WriteString(newStr.Str)
} else {
buf.WriteString(w.Prefix)
buf.WriteString(p.wordAsStr(w))
}
}
buf.WriteString(p.Suffix)
compWord := p.Words[p.CompWord]
newPos := len(p.Prefix) + compWord.Offset + len(compWord.Prefix) + newStr.Pos
return utilfn.StrWithPos{Str: buf.String(), Pos: newPos}
}
func (p *CompPoint) dump() {
if p.Prefix != "" {
fmt.Printf("prefix: %s\n", p.Prefix)
}
fmt.Printf("cpos: %d %d\n", p.CompWord, p.CompWordPos)
for idx, w := range p.Words {
fmt.Printf("w[%d]: ", idx)
if w.Prefix != "" {
fmt.Printf("{%s}", w.Prefix)
}
if idx == p.CompWord {
fmt.Printf("%s\n", utilfn.StrWithPos{Str: p.wordAsStr(w), Pos: p.CompWordPos})
} else {
fmt.Printf("%s\n", p.wordAsStr(w))
}
}
if p.Suffix != "" {
fmt.Printf("suffix: %s\n", p.Suffix)
}
fmt.Printf("\n")
}
var SimpleCompGenFns map[string]SimpleCompGenFnType
func isWhitespace(str string) bool {
return strings.TrimSpace(str) == ""
}
func splitInitialWhitespace(str string) (string, string) {
for pos, ch := range str { // rune iteration :/
if !unicode.IsSpace(ch) {
return str[:pos], str[pos:]
}
}
return str, ""
}
func ParseCompPoint(cmdStr utilfn.StrWithPos) *CompPoint {
fullCmdStr := cmdStr.Str
pos := cmdStr.Pos
// fmt.Printf("---\n")
// fmt.Printf("cmd: %s\n", strWithCursor(fullCmdStr, pos))
// first, find the stmt that the pos appears in
cmdReader := strings.NewReader(fullCmdStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
var foundStmt *syntax.Stmt
var lastStmt *syntax.Stmt
var restStartPos int
parser.Stmts(cmdReader, func(stmt *syntax.Stmt) bool { // ignore parse errors (since stmtStr will be the unparsed part)
restStartPos = int(stmt.End().Offset())
lastStmt = stmt
if uint(pos) >= stmt.Pos().Offset() && uint(pos) < stmt.End().Offset() {
foundStmt = stmt
return false
}
// fmt.Printf("stmt: [[%s]] %d:%d (%d)\n", fullCmdStr[stmt.Pos().Offset():stmt.End().Offset()], stmt.Pos().Offset(), stmt.End().Offset(), stmt.Semicolon.Offset())
return true
})
restStr := fullCmdStr[restStartPos:]
if foundStmt == nil && lastStmt != nil && isWhitespace(restStr) && lastStmt.Semicolon.Offset() == 0 {
foundStmt = lastStmt
}
var rtnPoint CompPoint
var stmtStr string
var stmtPos int
if foundStmt != nil {
stmtPos = pos - int(foundStmt.Pos().Offset())
rtnPoint.Prefix = fullCmdStr[:foundStmt.Pos().Offset()]
if isWhitespace(fullCmdStr[foundStmt.End().Offset():]) {
stmtStr = fullCmdStr[foundStmt.Pos().Offset():]
rtnPoint.Suffix = ""
} else {
stmtStr = fullCmdStr[foundStmt.Pos().Offset():foundStmt.End().Offset()]
rtnPoint.Suffix = fullCmdStr[foundStmt.End().Offset():]
}
} else {
stmtStr = restStr
stmtPos = pos - restStartPos
rtnPoint.Prefix = fullCmdStr[:restStartPos]
rtnPoint.Suffix = fullCmdStr[restStartPos+len(stmtStr):]
}
if stmtPos > len(stmtStr) {
// this should not happen and will cause a jump in completed strings
stmtPos = len(stmtStr)
}
// fmt.Printf("found: ((%s))%s((%s))\n", rtnPoint.Prefix, strWithCursor(stmtStr, stmtPos), rtnPoint.Suffix)
// now, find the word that the pos appears in within the stmt above
rtnPoint.StmtStr = stmtStr
stmtReader := strings.NewReader(stmtStr)
lastWordPos := 0
parser.Words(stmtReader, func(w *syntax.Word) bool {
var pword ParsedWord
pword.Offset = lastWordPos
if int(w.Pos().Offset()) > lastWordPos {
pword.Prefix = stmtStr[lastWordPos:w.Pos().Offset()]
}
pword.Word = w
rtnPoint.Words = append(rtnPoint.Words, pword)
lastWordPos = int(w.End().Offset())
return true
})
if lastWordPos < len(stmtStr) {
pword := ParsedWord{Offset: lastWordPos}
pword.Prefix, pword.PartialWord = splitInitialWhitespace(stmtStr[lastWordPos:])
rtnPoint.Words = append(rtnPoint.Words, pword)
}
if len(rtnPoint.Words) == 0 {
rtnPoint.Words = append(rtnPoint.Words, ParsedWord{})
}
for idx, w := range rtnPoint.Words {
wordLen := len(rtnPoint.wordAsStr(w))
if stmtPos > w.Offset && stmtPos <= w.Offset+len(w.Prefix)+wordLen {
rtnPoint.CompWord = idx
rtnPoint.CompWordPos = stmtPos - w.Offset - len(w.Prefix)
if rtnPoint.CompWordPos < 0 {
splitCompWord(&rtnPoint)
}
}
}
return &rtnPoint
}
func splitCompWord(p *CompPoint) {
w := p.Words[p.CompWord]
prefixPos := p.CompWordPos + len(w.Prefix)
w1 := ParsedWord{Offset: w.Offset, Prefix: w.Prefix[:prefixPos]}
w2 := ParsedWord{Offset: w.Offset + prefixPos, Prefix: w.Prefix[prefixPos:], Word: w.Word, PartialWord: w.PartialWord}
p.CompWord = p.CompWord // the same (w1)
p.CompWordPos = 0 // will be at 0 since w1 has a word length of 0
var newWords []ParsedWord
if p.CompWord > 0 {
newWords = append(newWords, p.Words[0:p.CompWord]...)
}
newWords = append(newWords, w1, w2)
newWords = append(newWords, p.Words[p.CompWord+1:]...)
p.Words = newWords
}
func getCompType(compPos shparse.CompletionPos) string {
switch compPos.CompType {
case shparse.CompTypeCommandMeta:
return CGTypeCommandMeta
case shparse.CompTypeCommand:
return CGTypeCommand
case shparse.CompTypeVar:
return CGTypeVariable
case shparse.CompTypeArg, shparse.CompTypeBasic, shparse.CompTypeAssignment:
return CGTypeFile
default:
return CGTypeFile
}
}
func fixupVarPrefix(varPrefix string) string {
if strings.HasPrefix(varPrefix, "${") {
varPrefix = varPrefix[2:]
if strings.HasSuffix(varPrefix, "}") {
varPrefix = varPrefix[:len(varPrefix)-1]
}
} else if strings.HasPrefix(varPrefix, "$") {
varPrefix = varPrefix[1:]
}
return varPrefix
}
func DoCompGen(ctx context.Context, cmdStr utilfn.StrWithPos, compCtx CompContext) (*CompReturn, *utilfn.StrWithPos, error) {
words := shparse.Tokenize(cmdStr.Str)
cmds := shparse.ParseCommands(words)
compPos := shparse.FindCompletionPos(cmds, cmdStr.Pos)
if compPos.CompType == shparse.CompTypeInvalid {
return nil, nil, nil
}
var compPrefix string
if compPos.CompWord != nil {
var info shparse.ExpandInfo
compPrefix, info = shparse.SimpleExpandPrefix(shparse.ExpandContext{}, compPos.CompWord, compPos.CompWordOffset)
if info.HasGlob || info.HasExtGlob || info.HasHistory || info.HasSpecial {
return nil, nil, nil
}
if compPos.CompType != shparse.CompTypeVar && info.HasVar {
return nil, nil, nil
}
if compPos.CompType == shparse.CompTypeVar {
compPrefix = fixupVarPrefix(compPrefix)
}
}
scType := getCompType(compPos)
crtn, err := DoSimpleComp(ctx, scType, compPrefix, compCtx, nil)
if err != nil {
return nil, nil, err
}
if compCtx.ForDisplay {
return crtn, nil, nil
}
extensionStr, extensionComplete := computeCompExtension(compPrefix, crtn)
if extensionStr == "" {
return crtn, nil, nil
}
rtnSP := compPos.Extend(cmdStr, extensionStr, extensionComplete)
return crtn, &rtnSP, nil
}
func DoCompGenOld(ctx context.Context, sp utilfn.StrWithPos, compCtx CompContext) (*CompReturn, *utilfn.StrWithPos, error) {
compPoint := ParseCompPoint(sp)
compType := CGTypeFile
if compPoint.CompWord == 0 {
compType = CGTypeCommandMeta
}
// TODO lookup special types
compPrefix, info := compPoint.getCompPrefix()
if info.HasVar || info.HasGlob || info.HasExtGlob || info.HasHistory || info.HasSpecial {
return nil, nil, nil
}
crtn, err := DoSimpleComp(ctx, compType, compPrefix, compCtx, nil)
if err != nil {
return nil, nil, err
}
if compCtx.ForDisplay {
return crtn, nil, nil
}
rtnSP := compPoint.FullyExtend(crtn)
return crtn, &rtnSP, nil
}
func SortCompReturnEntries(c *CompReturn) {
sort.Slice(c.Entries, func(i int, j int) bool {
e1 := c.Entries[i]
e2 := c.Entries[j]
if e1.Word < e2.Word {
return true
}
if e1.Word == e2.Word && e1.IsMetaCmd && !e2.IsMetaCmd {
return true
}
return false
})
}
func CombineCompReturn(compType string, c1 *CompReturn, c2 *CompReturn) *CompReturn {
if c1 == nil {
return c2
}
if c2 == nil {
return c1
}
var rtn CompReturn
rtn.CompType = compType
rtn.HasMore = c1.HasMore || c2.HasMore
rtn.Entries = append([]CompEntry{}, c1.Entries...)
rtn.Entries = append(rtn.Entries, c2.Entries...)
SortCompReturnEntries(&rtn)
return &rtn
}
func (c *CompReturn) GetCompStrs() []string {
rtn := make([]string, len(c.Entries))
for idx, entry := range c.Entries {
rtn[idx] = entry.Word
}
return rtn
}
func (c *CompReturn) GetCompDisplayStrs() []string {
rtn := make([]string, len(c.Entries))
for idx, entry := range c.Entries {
if entry.IsMetaCmd {
rtn[idx] = "^" + entry.Word
} else {
rtn[idx] = entry.Word
}
}
return rtn
}
func (p CompPoint) getOrigPos() int {
pword := p.Words[p.CompWord]
return len(p.Prefix) + pword.Offset + len(pword.Prefix) + p.CompWordPos
}
func (p CompPoint) getOrigStr() string {
return p.Prefix + p.StmtStr + p.Suffix
}

View File

@ -0,0 +1,106 @@
package comp
import (
"fmt"
"strings"
"testing"
)
func parseToSP(s string) StrWithPos {
idx := strings.Index(s, "[*]")
if idx == -1 {
return StrWithPos{Str: s}
}
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: idx}
}
func testParse(cmdStr string, pos int) {
fmt.Printf("cmd: %s\n", strWithCursor(cmdStr, pos))
p := ParseCompPoint(StrWithPos{Str: cmdStr, Pos: pos})
p.dump()
}
func _Test1(t *testing.T) {
testParse("ls ", 3)
testParse("ls ", 4)
testParse("ls -l foo", 4)
testParse("ls foo; cd h", 12)
testParse("ls foo; cd h;", 13)
testParse("ls & foo; cd h", 12)
testParse("ls \"he", 6)
testParse("ls;", 3)
testParse("ls;", 2)
testParse("ls; cd x; ls", 8)
testParse("cd \"foo ", 8)
testParse("ls; { ls f", 10)
testParse("ls; { ls -l; ls f", 17)
testParse("ls $(ls f", 9)
}
func testMiniExtend(t *testing.T, p *CompPoint, newWord string, complete bool, expectedStr string) {
newSP := p.extendWord(newWord, complete)
expectedSP := parseToSP(expectedStr)
if newSP != expectedSP {
t.Fatalf("not equal: [%s] != [%s]", newSP, expectedSP)
} else {
fmt.Printf("extend: %s\n", newSP)
}
}
func Test2(t *testing.T) {
p := ParseCompPoint(parseToSP("ls f[*]"))
testMiniExtend(t, p, "foo", false, "foo[*]")
testMiniExtend(t, p, "foo", true, "foo [*]")
testMiniExtend(t, p, "foo bar", true, "'foo bar' [*]")
testMiniExtend(t, p, "foo'bar", true, `$'foo\'bar' [*]`)
p = ParseCompPoint(parseToSP("ls f[*]more"))
testMiniExtend(t, p, "foo", false, "foo[*]more")
testMiniExtend(t, p, "foo bar", false, `'foo bar[*]more`)
testMiniExtend(t, p, "foo bar", true, `'foo bar[*]more`)
testMiniExtend(t, p, "foo's", true, `$'foo\'s[*]more`)
}
func testParseRT(t *testing.T, origSP StrWithPos) {
p := ParseCompPoint(origSP)
newSP := StrWithPos{Str: p.getOrigStr(), Pos: p.getOrigPos()}
if origSP != newSP {
t.Fatalf("not equal: [%s] != [%s]", origSP, newSP)
}
}
func Test3(t *testing.T) {
testParseRT(t, parseToSP("ls f[*]"))
testParseRT(t, parseToSP("ls f[*]; more $FOO"))
testParseRT(t, parseToSP("hello; ls [*]f"))
testParseRT(t, parseToSP("ls -l; ./foo he[*]ll more; touch foo &"))
}
func testExtend(t *testing.T, origStr string, compStrs []string, expectedStr string) {
origSP := parseToSP(origStr)
expectedSP := parseToSP(expectedStr)
p := ParseCompPoint(origSP)
crtn := compsToCompReturn(compStrs, false)
newSP := p.FullyExtend(crtn)
if newSP != expectedSP {
t.Fatalf("comp-fail: %s + %v => [%s] expected[%s]", origSP, compStrs, newSP, expectedSP)
} else {
fmt.Printf("comp: %s + %v => [%s]\n", origSP, compStrs, newSP)
}
}
func Test4(t *testing.T) {
testExtend(t, "ls f[*]", []string{"foo"}, "ls foo [*]")
testExtend(t, "ls f[*]", []string{"foox", "fooy"}, "ls foo[*]")
testExtend(t, "w; ls f[*]; touch x", []string{"foo"}, "w; ls foo [*]; touch x")
testExtend(t, "w; ls f[*] more; touch x", []string{"foo"}, "w; ls foo [*] more; touch x")
testExtend(t, "w; ls f[*]oo; touch x", []string{"foo"}, "w; ls foo[*]oo; touch x")
testExtend(t, `ls "f[*]`, []string{"foo"}, `ls "foo" [*]`)
testExtend(t, `ls 'f[*]`, []string{"foo"}, `ls 'foo' [*]`)
testExtend(t, `ls $'f[*]`, []string{"foo"}, `ls $'foo' [*]`)
testExtend(t, `ls f[*]`, []string{"foo/"}, `ls foo/[*]`)
testExtend(t, `ls f[*]`, []string{"foo bar"}, `ls 'foo bar' [*]`)
testExtend(t, `ls f[*]`, []string{"f\x01\x02"}, `ls $'f\x01\x02' [*]`)
testExtend(t, `ls "foo [*]`, []string{"foo bar"}, `ls "foo bar" [*]`)
testExtend(t, `ls f[*]`, []string{"foo's"}, `ls $'foo\'s' [*]`)
}

View File

@ -0,0 +1,100 @@
package comp
import (
"context"
"fmt"
"sync"
"github.com/google/uuid"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/prompt-server/pkg/remote"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
)
var globalLock = &sync.Mutex{}
var simpleCompMap = map[string]SimpleCompGenFnType{
CGTypeCommand: simpleCompCommand,
CGTypeFile: simpleCompFile,
CGTypeDir: simpleCompDir,
CGTypeVariable: simpleCompVar,
}
type SimpleCompGenFnType = func(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error)
func RegisterSimpleCompFn(compType string, fn SimpleCompGenFnType) {
globalLock.Lock()
defer globalLock.Unlock()
if _, ok := simpleCompMap[compType]; ok {
panic(fmt.Sprintf("simpleCompFn %q already registered", compType))
}
simpleCompMap[compType] = fn
}
func getSimpleCompFn(compType string) SimpleCompGenFnType {
globalLock.Lock()
defer globalLock.Unlock()
return simpleCompMap[compType]
}
func DoSimpleComp(ctx context.Context, compType string, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
compFn := getSimpleCompFn(compType)
if compFn == nil {
return nil, fmt.Errorf("no simple comp fn for %q", compType)
}
crtn, err := compFn(ctx, prefix, compCtx, args)
if err != nil {
return nil, err
}
crtn.CompType = compType
return crtn, nil
}
func compsToCompReturn(comps []string, hasMore bool) *CompReturn {
var rtn CompReturn
rtn.HasMore = hasMore
for _, comp := range comps {
rtn.Entries = append(rtn.Entries, CompEntry{Word: comp})
}
return &rtn
}
func doCompGen(ctx context.Context, prefix string, compType string, compCtx CompContext) (*CompReturn, error) {
if !packet.IsValidCompGenType(compType) {
return nil, fmt.Errorf("/_compgen invalid type '%s'", compType)
}
msh := remote.GetRemoteById(compCtx.RemotePtr.RemoteId)
if msh == nil {
return nil, fmt.Errorf("invalid remote '%s', not found", compCtx.RemotePtr)
}
cgPacket := packet.MakeCompGenPacket()
cgPacket.ReqId = uuid.New().String()
cgPacket.CompType = compType
cgPacket.Prefix = prefix
cgPacket.Cwd = compCtx.Cwd
resp, err := msh.PacketRpc(ctx, cgPacket)
if err != nil {
return nil, err
}
if err = resp.Err(); err != nil {
return nil, err
}
comps := utilfn.GetStrArr(resp.Data, "comps")
hasMore := utilfn.GetBool(resp.Data, "hasmore")
return compsToCompReturn(comps, hasMore), nil
}
func simpleCompFile(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
return doCompGen(ctx, prefix, CGTypeFile, compCtx)
}
func simpleCompDir(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
return doCompGen(ctx, prefix, CGTypeDir, compCtx)
}
func simpleCompVar(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
return doCompGen(ctx, prefix, CGTypeVariable, compCtx)
}
func simpleCompCommand(ctx context.Context, prefix string, compCtx CompContext, args []interface{}) (*CompReturn, error) {
return doCompGen(ctx, prefix, CGTypeCommand, compCtx)
}

View File

@ -0,0 +1,212 @@
package dbutil
import (
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strconv"
)
func QuickSetStr(strVal *string, m map[string]interface{}, name string) {
v, ok := m[name]
if !ok {
return
}
ival, ok := v.(int64)
if ok {
*strVal = strconv.FormatInt(ival, 10)
return
}
str, ok := v.(string)
if !ok {
return
}
*strVal = str
}
func QuickSetInt(ival *int, m map[string]interface{}, name string) {
v, ok := m[name]
if !ok {
return
}
sqlInt, ok := v.(int)
if ok {
*ival = sqlInt
return
}
sqlInt64, ok := v.(int64)
if ok {
*ival = int(sqlInt64)
return
}
}
func QuickSetInt64(ival *int64, m map[string]interface{}, name string) {
v, ok := m[name]
if !ok {
return
}
sqlInt64, ok := v.(int64)
if ok {
*ival = sqlInt64
return
}
sqlInt, ok := v.(int)
if ok {
*ival = int64(sqlInt)
return
}
}
func QuickSetBool(bval *bool, m map[string]interface{}, name string) {
v, ok := m[name]
if !ok {
return
}
sqlInt, ok := v.(int64)
if ok {
if sqlInt > 0 {
*bval = true
}
return
}
sqlBool, ok := v.(bool)
if ok {
*bval = sqlBool
}
}
func QuickSetBytes(bval *[]byte, m map[string]interface{}, name string) {
v, ok := m[name]
if !ok {
return
}
sqlBytes, ok := v.([]byte)
if ok {
*bval = sqlBytes
}
}
func getByteArr(m map[string]any, name string, def string) ([]byte, bool) {
v, ok := m[name]
if !ok {
return nil, false
}
barr, ok := v.([]byte)
if !ok {
str, ok := v.(string)
if !ok {
return nil, false
}
barr = []byte(str)
}
if len(barr) == 0 {
barr = []byte(def)
}
return barr, true
}
func QuickSetJson(ptr interface{}, m map[string]interface{}, name string) {
barr, ok := getByteArr(m, name, "{}")
if !ok {
return
}
json.Unmarshal(barr, ptr)
}
func QuickSetNullableJson(ptr interface{}, m map[string]interface{}, name string) {
barr, ok := getByteArr(m, name, "null")
if !ok {
return
}
json.Unmarshal(barr, ptr)
}
func QuickSetJsonArr(ptr interface{}, m map[string]interface{}, name string) {
barr, ok := getByteArr(m, name, "[]")
if !ok {
return
}
json.Unmarshal(barr, ptr)
}
func CheckNil(v interface{}) bool {
rv := reflect.ValueOf(v)
if !rv.IsValid() {
return true
}
switch rv.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func QuickNullableJson(v interface{}) string {
if CheckNil(v) {
return "null"
}
barr, _ := json.Marshal(v)
return string(barr)
}
func QuickJson(v interface{}) string {
if CheckNil(v) {
return "{}"
}
barr, _ := json.Marshal(v)
return string(barr)
}
func QuickJsonBytes(v interface{}) []byte {
if CheckNil(v) {
return []byte("{}")
}
barr, _ := json.Marshal(v)
return barr
}
func QuickJsonArr(v interface{}) string {
if CheckNil(v) {
return "[]"
}
barr, _ := json.Marshal(v)
return string(barr)
}
func QuickJsonArrBytes(v interface{}) []byte {
if CheckNil(v) {
return []byte("[]")
}
barr, _ := json.Marshal(v)
return barr
}
func QuickScanJson(ptr interface{}, val interface{}) error {
barrVal, ok := val.([]byte)
if !ok {
strVal, ok := val.(string)
if !ok {
return fmt.Errorf("cannot scan '%T' into '%T'", val, ptr)
}
barrVal = []byte(strVal)
}
if len(barrVal) == 0 {
barrVal = []byte("{}")
}
return json.Unmarshal(barrVal, ptr)
}
func QuickValueJson(v interface{}) (driver.Value, error) {
if CheckNil(v) {
return "{}", nil
}
barr, err := json.Marshal(v)
if err != nil {
return nil, err
}
return string(barr), nil
}

234
wavesrv/pkg/dbutil/map.go Normal file
View File

@ -0,0 +1,234 @@
package dbutil
import (
"fmt"
"reflect"
"strings"
"github.com/sawka/txwrap"
)
type DBMappable interface {
UseDBMap()
}
type MapEntry[T any] struct {
Key string
Val T
}
type MapConverter interface {
ToMap() map[string]interface{}
FromMap(map[string]interface{}) bool
}
type HasSimpleKey interface {
GetSimpleKey() string
}
type HasSimpleInt64Key interface {
GetSimpleKey() int64
}
type MapConverterPtr[T any] interface {
MapConverter
*T
}
type DBMappablePtr[T any] interface {
DBMappable
*T
}
func FromMap[PT MapConverterPtr[T], T any](m map[string]any) PT {
if len(m) == 0 {
return nil
}
rtn := PT(new(T))
ok := rtn.FromMap(m)
if !ok {
return nil
}
return rtn
}
func GetMapGen[PT MapConverterPtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) PT {
m := tx.GetMap(query, args...)
return FromMap[PT](m)
}
func GetMappable[PT DBMappablePtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) PT {
m := tx.GetMap(query, args...)
if len(m) == 0 {
return nil
}
rtn := PT(new(T))
FromDBMap(rtn, m)
return rtn
}
func SelectMappable[PT DBMappablePtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) []PT {
var rtn []PT
marr := tx.SelectMaps(query, args...)
for _, m := range marr {
if len(m) == 0 {
continue
}
val := PT(new(T))
FromDBMap(val, m)
rtn = append(rtn, val)
}
return rtn
}
func SelectMapsGen[PT MapConverterPtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) []PT {
var rtn []PT
marr := tx.SelectMaps(query, args...)
for _, m := range marr {
val := FromMap[PT](m)
if val != nil {
rtn = append(rtn, val)
}
}
return rtn
}
func SelectSimpleMap[T any](tx *txwrap.TxWrap, query string, args ...interface{}) map[string]T {
var rtn []MapEntry[T]
tx.Select(&rtn, query, args...)
if len(rtn) == 0 {
return nil
}
rtnMap := make(map[string]T)
for _, entry := range rtn {
rtnMap[entry.Key] = entry.Val
}
return rtnMap
}
func MakeGenMap[T HasSimpleKey](arr []T) map[string]T {
rtn := make(map[string]T)
for _, val := range arr {
rtn[val.GetSimpleKey()] = val
}
return rtn
}
func MakeGenMapInt64[T HasSimpleInt64Key](arr []T) map[int64]T {
rtn := make(map[int64]T)
for _, val := range arr {
rtn[val.GetSimpleKey()] = val
}
return rtn
}
func isStructType(rt reflect.Type) bool {
if rt.Kind() == reflect.Struct {
return true
}
if rt.Kind() == reflect.Pointer && rt.Elem().Kind() == reflect.Struct {
return true
}
return false
}
func isByteArrayType(t reflect.Type) bool {
return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8
}
func isStringMapType(t reflect.Type) bool {
return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String
}
func ToDBMap(v DBMappable, useBytes bool) map[string]interface{} {
if CheckNil(v) {
return nil
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructToDBMap", v))
}
rt := rv.Type()
m := make(map[string]interface{})
numFields := rt.NumField()
for i := 0; i < numFields; i++ {
field := rt.Field(i)
fieldVal := rv.FieldByIndex(field.Index)
dbName := field.Tag.Get("dbmap")
if dbName == "" {
dbName = strings.ToLower(field.Name)
}
if dbName == "-" {
continue
}
if isByteArrayType(field.Type) {
m[dbName] = fieldVal.Interface()
} else if field.Type.Kind() == reflect.Slice {
if useBytes {
m[dbName] = QuickJsonArrBytes(fieldVal.Interface())
} else {
m[dbName] = QuickJsonArr(fieldVal.Interface())
}
} else if isStructType(field.Type) || isStringMapType(field.Type) {
if useBytes {
m[dbName] = QuickJsonBytes(fieldVal.Interface())
} else {
m[dbName] = QuickJson(fieldVal.Interface())
}
} else {
m[dbName] = fieldVal.Interface()
}
}
return m
}
func FromDBMap(v DBMappable, m map[string]interface{}) {
if CheckNil(v) {
panic("StructFromDBMap, v cannot be nil")
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructFromDBMap", v))
}
rt := rv.Type()
numFields := rt.NumField()
for i := 0; i < numFields; i++ {
field := rt.Field(i)
fieldVal := rv.FieldByIndex(field.Index)
dbName := field.Tag.Get("dbmap")
if dbName == "" {
dbName = strings.ToLower(field.Name)
}
if dbName == "-" {
continue
}
if isByteArrayType(field.Type) {
barrVal := fieldVal.Addr().Interface()
QuickSetBytes(barrVal.(*[]byte), m, dbName)
} else if field.Type.Kind() == reflect.Slice {
QuickSetJsonArr(fieldVal.Addr().Interface(), m, dbName)
} else if isStructType(field.Type) || isStringMapType(field.Type) {
QuickSetJson(fieldVal.Addr().Interface(), m, dbName)
} else if field.Type.Kind() == reflect.String {
strVal := fieldVal.Addr().Interface()
QuickSetStr(strVal.(*string), m, dbName)
} else if field.Type.Kind() == reflect.Int64 {
intVal := fieldVal.Addr().Interface()
QuickSetInt64(intVal.(*int64), m, dbName)
} else if field.Type.Kind() == reflect.Int {
intVal := fieldVal.Addr().Interface()
QuickSetInt(intVal.(*int), m, dbName)
} else if field.Type.Kind() == reflect.Bool {
boolVal := fieldVal.Addr().Interface()
QuickSetBool(boolVal.(*bool), m, dbName)
} else {
panic(fmt.Sprintf("StructFromDBMap invalid field type %v in %T", fieldVal.Type(), v))
}
}
}

View File

@ -0,0 +1,112 @@
// Utility functions for generating and reading public/private keypairs.
package keygen
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"math/big"
"os"
"time"
)
const p384Params = "BgUrgQQAIg=="
// Creates a keypair with CN=[id], private key at keyFileName, and
// public key certificate at certFileName.
func CreateKeyPair(keyFileName string, certFileName string, id string) error {
privateKey, err := CreatePrivateKey(keyFileName)
if err != nil {
return err
}
err = CreateCertificate(certFileName, privateKey, id)
if err != nil {
return err
}
return nil
}
// Creates a private key at keyFileName (ECDSA, secp384r1 (P-384)), PEM format
func CreatePrivateKey() (*ecdsa.PrivateKey, error) {
curve := elliptic.P384() // secp384r1
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return nil, fmt.Errorf("Error generating P-384 key err:%w", err)
}
keyFile, err := os.Create(keyFileName)
if err != nil {
return nil, fmt.Errorf("error opening file:%s err:%w", keyFileName, err)
}
defer keyFile.Close()
pkBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, fmt.Errorf("Error MarshalPKCS8PrivateKey err:%w", err)
}
paramsBytes, err := base64.StdEncoding.DecodeString(p384Params)
if err != nil {
return nil, fmt.Errorf("Error decoding bytes for P-384 EC PARAMETERS err:%w", err)
}
var pemParamsBlock = &pem.Block{
Type: "EC PARAMETERS",
Bytes: paramsBytes,
}
err = pem.Encode(keyFile, pemParamsBlock)
if err != nil {
return nil, fmt.Errorf("Error writing EC PARAMETERS pem block err:%w", err)
}
var pemPrivateBlock = &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: pkBytes,
}
err = pem.Encode(keyFile, pemPrivateBlock)
if err != nil {
return nil, fmt.Errorf("Error writing EC PRIVATE KEY pem block err:%w", err)
}
return privateKey, nil
}
// Creates a public key certificate at certFileName using privateKey with CN=[id].
func CreateCertificate(certFileName string, privateKey *ecdsa.PrivateKey, id string) error {
serialNumber, err := rand.Int(rand.Reader, big.NewInt(1000000000000))
if err != nil {
return fmt.Errorf("Cannot generate serial number err:%w", err)
}
notBefore, err := time.Parse("Jan 2 15:04:05 2006", "Jan 1 00:00:00 2020")
if err != nil {
return fmt.Errorf("Cannot Parse Date err:%w", err)
}
notAfter, err := time.Parse("Jan 2 15:04:05 2006", "Jan 1 00:00:00 2030")
if err != nil {
return fmt.Errorf("Cannot Parse Date err:%w", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: id,
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return fmt.Errorf("Error running x509.CreateCertificate err:%v\n", err)
}
certFile, err := os.Create(certFileName)
if err != nil {
return fmt.Errorf("Error opening file:%s err:%w", certFileName, err)
}
defer certFile.Close()
err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
if err != nil {
return fmt.Errorf("Error writing CERTIFICATE pem block err:%w", err)
}
return nil
}

View File

@ -0,0 +1,99 @@
package mapqueue
import (
"fmt"
"log"
"runtime/debug"
"sync"
)
type MQEntry struct {
Lock *sync.Mutex
Running bool
Queue chan func()
}
type MapQueue struct {
Lock *sync.Mutex
M map[string]*MQEntry
QueueSize int
}
func MakeMapQueue(queueSize int) *MapQueue {
rtn := &MapQueue{
Lock: &sync.Mutex{},
M: make(map[string]*MQEntry),
QueueSize: queueSize,
}
return rtn
}
func (mq *MapQueue) getEntry(id string) *MQEntry {
mq.Lock.Lock()
defer mq.Lock.Unlock()
entry := mq.M[id]
if entry == nil {
entry = &MQEntry{
Lock: &sync.Mutex{},
Running: false,
Queue: make(chan func(), mq.QueueSize),
}
mq.M[id] = entry
}
return entry
}
func (entry *MQEntry) add(fn func()) error {
select {
case entry.Queue <- fn:
break
default:
return fmt.Errorf("input queue full")
}
entry.tryRun()
return nil
}
func runFn(fn func()) {
defer func() {
r := recover()
if r == nil {
return
}
log.Printf("[error] panic in MQEntry runFn: %v\n", r)
debug.PrintStack()
return
}()
fn()
}
func (entry *MQEntry) tryRun() {
entry.Lock.Lock()
defer entry.Lock.Unlock()
if entry.Running {
return
}
if len(entry.Queue) > 0 {
entry.Running = true
go entry.run()
}
}
func (entry *MQEntry) run() {
for fn := range entry.Queue {
runFn(fn)
}
entry.Lock.Lock()
entry.Running = false
entry.Lock.Unlock()
entry.tryRun()
}
func (mq *MapQueue) Enqueue(id string, fn func()) error {
entry := mq.getEntry(id)
err := entry.add(fn)
if err != nil {
return fmt.Errorf("cannot enqueue: %v", err)
}
return nil
}

View File

@ -0,0 +1,628 @@
package pcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/commandlinedev/prompt-server/pkg/dbutil"
"github.com/commandlinedev/prompt-server/pkg/rtnstate"
"github.com/commandlinedev/prompt-server/pkg/scbase"
"github.com/commandlinedev/prompt-server/pkg/sstore"
)
const PCloudEndpoint = "https://api.getprompt.dev/central"
const PCloudEndpointVarName = "PCLOUD_ENDPOINT"
const APIVersion = 1
const MaxPtyUpdateSize = (128 * 1024)
const MaxUpdatesPerReq = 10
const MaxUpdatesToDeDup = 1000
const MaxUpdateWriterErrors = 3
const PCloudDefaultTimeout = 5 * time.Second
const PCloudWebShareUpdateTimeout = 15 * time.Second
// setting to 1M to be safe (max is 6M for API-GW + Lambda, but there is base64 encoding and upload time)
// we allow one extra update past this estimated size
const MaxUpdatePayloadSize = 1 * (1024 * 1024)
const TelemetryUrl = "/telemetry"
const NoTelemetryUrl = "/no-telemetry"
const WebShareUpdateUrl = "/auth/web-share-update"
var updateWriterLock = &sync.Mutex{}
var updateWriterRunning = false
var updateWriterNumFailures = 0
type AuthInfo struct {
UserId string `json:"userid"`
ClientId string `json:"clientid"`
AuthKey string `json:"authkey"`
}
func GetEndpoint() string {
if !scbase.IsDevMode() {
return PCloudEndpoint
}
endpoint := os.Getenv(PCloudEndpointVarName)
if endpoint == "" || !strings.HasPrefix(endpoint, "https://") {
panic("Invalid PCloud dev endpoint, PCLOUD_ENDPOINT not set or invalid")
}
return endpoint
}
func makeAuthPostReq(ctx context.Context, apiUrl string, authInfo AuthInfo, data interface{}) (*http.Request, error) {
var dataReader io.Reader
if data != nil {
byteArr, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s request: %v", apiUrl, err)
}
dataReader = bytes.NewReader(byteArr)
}
fullUrl := GetEndpoint() + apiUrl
req, err := http.NewRequestWithContext(ctx, "POST", fullUrl, dataReader)
if err != nil {
return nil, fmt.Errorf("error creating %s request: %v", apiUrl, err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-PromptAPIVersion", strconv.Itoa(APIVersion))
req.Header.Set("X-PromptAPIUrl", apiUrl)
req.Header.Set("X-PromptUserId", authInfo.UserId)
req.Header.Set("X-PromptClientId", authInfo.ClientId)
req.Header.Set("X-PromptAuthKey", authInfo.AuthKey)
req.Close = true
return req, nil
}
func makeAnonPostReq(ctx context.Context, apiUrl string, data interface{}) (*http.Request, error) {
var dataReader io.Reader
if data != nil {
byteArr, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s request: %v", apiUrl, err)
}
dataReader = bytes.NewReader(byteArr)
}
fullUrl := GetEndpoint() + apiUrl
req, err := http.NewRequestWithContext(ctx, "POST", fullUrl, dataReader)
if err != nil {
return nil, fmt.Errorf("error creating %s request: %v", apiUrl, err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-PromptAPIVersion", strconv.Itoa(APIVersion))
req.Header.Set("X-PromptAPIUrl", apiUrl)
req.Close = true
return req, nil
}
func doRequest(req *http.Request, outputObj interface{}) (*http.Response, error) {
apiUrl := req.Header.Get("X-PromptAPIUrl")
log.Printf("[pcloud] sending request %s %v\n", req.Method, req.URL)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error contacting pcloud %q service: %v", apiUrl, err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return resp, fmt.Errorf("error reading %q response body: %v", apiUrl, err)
}
if resp.StatusCode != http.StatusOK {
return resp, fmt.Errorf("error contacting pcloud %q service: %s", apiUrl, resp.Status)
}
if outputObj != nil && resp.Header.Get("Content-Type") == "application/json" {
err = json.Unmarshal(bodyBytes, outputObj)
if err != nil {
return resp, fmt.Errorf("error decoding json: %v", err)
}
}
return resp, nil
}
func SendTelemetry(ctx context.Context, force bool) error {
clientData, err := sstore.EnsureClientData(ctx)
if err != nil {
return fmt.Errorf("cannot retrieve client data: %v", err)
}
if !force && clientData.ClientOpts.NoTelemetry {
return nil
}
activity, err := sstore.GetNonUploadedActivity(ctx)
if err != nil {
return fmt.Errorf("cannot get activity: %v", err)
}
if len(activity) == 0 {
return nil
}
log.Printf("[pcloud] sending telemetry data\n")
dayStr := sstore.GetCurDayStr()
input := TelemetryInputType{UserId: clientData.UserId, ClientId: clientData.ClientId, CurDay: dayStr, Activity: activity}
req, err := makeAnonPostReq(ctx, TelemetryUrl, input)
if err != nil {
return err
}
_, err = doRequest(req, nil)
if err != nil {
return err
}
err = sstore.MarkActivityAsUploaded(ctx, activity)
if err != nil {
return fmt.Errorf("error marking activity as uploaded: %v", err)
}
return nil
}
func SendNoTelemetryUpdate(ctx context.Context, noTelemetryVal bool) error {
clientData, err := sstore.EnsureClientData(ctx)
if err != nil {
return fmt.Errorf("cannot retrieve client data: %v", err)
}
req, err := makeAnonPostReq(ctx, NoTelemetryUrl, NoTelemetryInputType{ClientId: clientData.ClientId, Value: noTelemetryVal})
if err != nil {
return err
}
_, err = doRequest(req, nil)
if err != nil {
return err
}
return nil
}
func getAuthInfo(ctx context.Context) (AuthInfo, error) {
clientData, err := sstore.EnsureClientData(ctx)
if err != nil {
return AuthInfo{}, fmt.Errorf("cannot retrieve client data: %v", err)
}
return AuthInfo{UserId: clientData.UserId, ClientId: clientData.ClientId}, nil
}
func defaultError(err error, estr string) error {
if err != nil {
return err
}
return errors.New(estr)
}
func MakeScreenNewUpdate(screen *sstore.ScreenType, webShareOpts sstore.ScreenWebShareOpts) *WebShareUpdateType {
rtn := &WebShareUpdateType{
ScreenId: screen.ScreenId,
UpdateId: -1,
UpdateType: sstore.UpdateType_ScreenNew,
UpdateTs: time.Now().UnixMilli(),
}
rtn.Screen = &WebShareScreenType{
ScreenId: screen.ScreenId,
SelectedLine: int(screen.SelectedLine),
ShareName: webShareOpts.ShareName,
ViewKey: webShareOpts.ViewKey,
}
return rtn
}
func MakeScreenDelUpdate(screen *sstore.ScreenType, screenId string) *WebShareUpdateType {
rtn := &WebShareUpdateType{
ScreenId: screenId,
UpdateId: -1,
UpdateType: sstore.UpdateType_ScreenDel,
UpdateTs: time.Now().UnixMilli(),
}
return rtn
}
func makeWebShareUpdate(ctx context.Context, update *sstore.ScreenUpdateType) (*WebShareUpdateType, error) {
rtn := &WebShareUpdateType{
ScreenId: update.ScreenId,
LineId: update.LineId,
UpdateId: update.UpdateId,
UpdateType: update.UpdateType,
UpdateTs: update.UpdateTs,
}
switch update.UpdateType {
case sstore.UpdateType_ScreenNew:
screen, err := sstore.GetScreenById(ctx, update.ScreenId)
if err != nil || screen == nil {
return nil, fmt.Errorf("error getting screen: %v", defaultError(err, "not found"))
}
rtn.Screen, err = webScreenFromScreen(screen)
if err != nil {
return nil, fmt.Errorf("error converting screen to web-screen: %v", err)
}
case sstore.UpdateType_ScreenDel:
break
case sstore.UpdateType_ScreenName, sstore.UpdateType_ScreenSelectedLine:
screen, err := sstore.GetScreenById(ctx, update.ScreenId)
if err != nil {
return nil, fmt.Errorf("error getting screen: %v", err)
}
if screen == nil || screen.WebShareOpts == nil {
return nil, fmt.Errorf("invalid screen, not webshared (makeWebScreenUpdate)")
}
if update.UpdateType == sstore.UpdateType_ScreenName {
rtn.SVal = screen.WebShareOpts.ShareName
} else if update.UpdateType == sstore.UpdateType_ScreenSelectedLine {
rtn.IVal = int64(screen.SelectedLine)
}
case sstore.UpdateType_LineNew:
line, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
if err != nil || line == nil {
return nil, fmt.Errorf("error getting line/cmd: %v", defaultError(err, "not found"))
}
rtn.Line, err = webLineFromLine(line)
if err != nil {
return nil, fmt.Errorf("error converting line to web-line: %v", err)
}
if cmd != nil {
rtn.Cmd, err = webCmdFromCmd(update.LineId, cmd)
if err != nil {
return nil, fmt.Errorf("error converting cmd to web-cmd: %v", err)
}
}
case sstore.UpdateType_LineDel:
break
case sstore.UpdateType_LineRenderer, sstore.UpdateType_LineContentHeight:
line, err := sstore.GetLineById(ctx, update.ScreenId, update.LineId)
if err != nil || line == nil {
return nil, fmt.Errorf("error getting line: %v", defaultError(err, "not found"))
}
if update.UpdateType == sstore.UpdateType_LineRenderer {
rtn.SVal = line.Renderer
} else if update.UpdateType == sstore.UpdateType_LineContentHeight {
rtn.IVal = line.ContentHeight
}
case sstore.UpdateType_CmdStatus:
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
if err != nil || cmd == nil {
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
}
rtn.SVal = cmd.Status
case sstore.UpdateType_CmdTermOpts:
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
if err != nil || cmd == nil {
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
}
rtn.TermOpts = &cmd.TermOpts
case sstore.UpdateType_CmdExitCode, sstore.UpdateType_CmdDurationMs:
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
if err != nil || cmd == nil {
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
}
if update.UpdateType == sstore.UpdateType_CmdExitCode {
rtn.IVal = int64(cmd.ExitCode)
} else if update.UpdateType == sstore.UpdateType_CmdDurationMs {
rtn.IVal = int64(cmd.DurationMs)
}
case sstore.UpdateType_CmdRtnState:
_, cmd, err := sstore.GetLineCmdByLineId(ctx, update.ScreenId, update.LineId)
if err != nil || cmd == nil {
return nil, fmt.Errorf("error getting cmd: %v", defaultError(err, "not found"))
}
data, err := rtnstate.GetRtnStateDiff(ctx, update.ScreenId, cmd.LineId)
if err != nil {
return nil, fmt.Errorf("cannot compute rtnstate: %v", err)
}
rtn.SVal = string(data)
case sstore.UpdateType_PtyPos:
ptyPos, err := sstore.GetWebPtyPos(ctx, update.ScreenId, update.LineId)
if err != nil {
return nil, fmt.Errorf("error getting ptypos: %v", err)
}
realOffset, data, err := sstore.ReadPtyOutFile(ctx, update.ScreenId, update.LineId, ptyPos, MaxPtyUpdateSize+1)
if err != nil {
return nil, fmt.Errorf("error getting ptydata: %v", err)
}
if len(data) == 0 {
return nil, nil
}
if len(data) > MaxPtyUpdateSize {
rtn.PtyData = &WebSharePtyData{PtyPos: realOffset, Data: data[0:MaxPtyUpdateSize], Eof: false}
} else {
rtn.PtyData = &WebSharePtyData{PtyPos: realOffset, Data: data, Eof: true}
}
case sstore.UpdateType_LineState:
// TODO implement!
default:
return nil, fmt.Errorf("unsupported update type (pcloud/makeWebScreenUpdate): %s\n", update.UpdateType)
}
return rtn, nil
}
func finalizeWebScreenUpdate(ctx context.Context, webUpdate *WebShareUpdateType) error {
switch webUpdate.UpdateType {
case sstore.UpdateType_PtyPos:
newPos := webUpdate.PtyData.PtyPos + int64(len(webUpdate.PtyData.Data))
err := sstore.SetWebPtyPos(ctx, webUpdate.ScreenId, webUpdate.LineId, newPos)
if err != nil {
return err
}
case sstore.UpdateType_LineDel:
err := sstore.DeleteWebPtyPos(ctx, webUpdate.ScreenId, webUpdate.LineId)
if err != nil {
return err
}
}
err := sstore.RemoveScreenUpdate(ctx, webUpdate.UpdateId)
if err != nil {
// this is not great, this *should* never fail and is not easy to recover from
return err
}
return nil
}
type webShareResponseType struct {
Success bool `json:"success"`
Data []*WebShareUpdateResponseType `json:"data"`
}
func convertUpdate(update *sstore.ScreenUpdateType) *WebShareUpdateType {
webUpdate, err := makeWebShareUpdate(context.Background(), update)
if err != nil || webUpdate == nil {
if err != nil {
log.Printf("[pcloud] error create web-share update updateid:%d: %v", update.UpdateId, err)
}
// if err, or no web update created, remove the screenupdate
removeErr := sstore.RemoveScreenUpdate(context.Background(), update.UpdateId)
if removeErr != nil {
// ignore this error too (although this is really problematic, there is nothing to do)
log.Printf("[pcloud] error removing screen update updateid:%d: %v", update.UpdateId, removeErr)
}
}
return webUpdate
}
func DoSyncWebUpdate(webUpdate *WebShareUpdateType) error {
authInfo, err := getAuthInfo(context.Background())
if err != nil {
return fmt.Errorf("could not get authinfo for request: %v", err)
}
ctx, cancelFn := context.WithTimeout(context.Background(), PCloudDefaultTimeout)
defer cancelFn()
req, err := makeAuthPostReq(ctx, WebShareUpdateUrl, authInfo, []*WebShareUpdateType{webUpdate})
if err != nil {
return fmt.Errorf("cannot create auth-post-req for %s: %v", WebShareUpdateUrl, err)
}
var resp webShareResponseType
_, err = doRequest(req, &resp)
if err != nil {
return err
}
if len(resp.Data) == 0 {
return fmt.Errorf("invalid response received from server")
}
urt := resp.Data[0]
if urt.Error != "" {
return errors.New(urt.Error)
}
return nil
}
func DoWebUpdates(webUpdates []*WebShareUpdateType) error {
if len(webUpdates) == 0 {
return nil
}
authInfo, err := getAuthInfo(context.Background())
if err != nil {
return fmt.Errorf("could not get authinfo for request: %v", err)
}
ctx, cancelFn := context.WithTimeout(context.Background(), PCloudWebShareUpdateTimeout)
defer cancelFn()
req, err := makeAuthPostReq(ctx, WebShareUpdateUrl, authInfo, webUpdates)
if err != nil {
return fmt.Errorf("cannot create auth-post-req for %s: %v", WebShareUpdateUrl, err)
}
var resp webShareResponseType
_, err = doRequest(req, &resp)
if err != nil {
return err
}
respMap := dbutil.MakeGenMapInt64(resp.Data)
for _, update := range webUpdates {
err = finalizeWebScreenUpdate(context.Background(), update)
if err != nil {
// ignore this error (nothing to do)
log.Printf("[pcloud] error finalizing web-update: %v\n", err)
}
resp := respMap[update.UpdateId]
if resp == nil {
resp = &WebShareUpdateResponseType{Success: false, Error: "resp not found"}
}
if resp.Error != "" {
log.Printf("[pcloud] error updateid:%d, type:%s %s/%s err:%v\n", update.UpdateId, update.UpdateType, update.ScreenId, update.LineId, resp.Error)
}
}
return nil
}
func setUpdateWriterRunning(running bool) {
updateWriterLock.Lock()
defer updateWriterLock.Unlock()
updateWriterRunning = running
}
func GetUpdateWriterRunning() bool {
updateWriterLock.Lock()
defer updateWriterLock.Unlock()
return updateWriterRunning
}
func StartUpdateWriter() {
updateWriterLock.Lock()
defer updateWriterLock.Unlock()
if updateWriterRunning {
return
}
updateWriterRunning = true
go runWebShareUpdateWriter()
}
func computeUpdateWriterBackoff() time.Duration {
updateWriterLock.Lock()
numFailures := updateWriterNumFailures
updateWriterLock.Unlock()
switch numFailures {
case 0:
return 0
case 1:
return 1 * time.Second
case 2:
return 2 * time.Second
case 3:
return 5 * time.Second
case 4:
return time.Minute
case 5:
return 5 * time.Minute
case 6:
return time.Hour
default:
return time.Hour
}
}
func incrementUpdateWriterNumFailures() {
updateWriterLock.Lock()
defer updateWriterLock.Unlock()
updateWriterNumFailures++
}
func ResetUpdateWriterNumFailures() {
updateWriterLock.Lock()
defer updateWriterLock.Unlock()
updateWriterNumFailures = 0
}
func GetUpdateWriterNumFailures() int {
updateWriterLock.Lock()
defer updateWriterLock.Unlock()
return updateWriterNumFailures
}
type updateKey struct {
ScreenId string
LineId string
UpdateType string
}
func DeDupUpdates(ctx context.Context, updateArr []*sstore.ScreenUpdateType) ([]*sstore.ScreenUpdateType, error) {
var rtn []*sstore.ScreenUpdateType
var idsToDelete []int64
umap := make(map[updateKey]bool)
for _, update := range updateArr {
key := updateKey{ScreenId: update.ScreenId, LineId: update.LineId, UpdateType: update.UpdateType}
if umap[key] {
idsToDelete = append(idsToDelete, update.UpdateId)
continue
}
umap[key] = true
rtn = append(rtn, update)
}
if len(idsToDelete) > 0 {
err := sstore.RemoveScreenUpdates(ctx, idsToDelete)
if err != nil {
return nil, fmt.Errorf("error trying to delete screenupdates: %v\n", err)
}
}
return rtn, nil
}
func runWebShareUpdateWriter() {
defer func() {
setUpdateWriterRunning(false)
}()
log.Printf("[pcloud] starting update writer\n")
numErrors := 0
for {
if numErrors > MaxUpdateWriterErrors {
log.Printf("[pcloud] update-writer, too many errors, exiting\n")
break
}
time.Sleep(100 * time.Millisecond)
fullUpdateArr, err := sstore.GetScreenUpdates(context.Background(), MaxUpdatesToDeDup)
if err != nil {
log.Printf("[pcloud] error retrieving updates: %v", err)
time.Sleep(1 * time.Second)
numErrors++
continue
}
updateArr, err := DeDupUpdates(context.Background(), fullUpdateArr)
if err != nil {
log.Printf("[pcloud] error deduping screenupdates: %v", err)
time.Sleep(1 * time.Second)
numErrors++
continue
}
numErrors = 0
var webUpdateArr []*WebShareUpdateType
totalSize := 0
for _, update := range updateArr {
webUpdate := convertUpdate(update)
if webUpdate == nil {
continue
}
webUpdateArr = append(webUpdateArr, webUpdate)
totalSize += webUpdate.GetEstimatedSize()
if totalSize > MaxUpdatePayloadSize {
break
}
}
if len(webUpdateArr) == 0 {
sstore.UpdateWriterCheckMoreData()
continue
}
err = DoWebUpdates(webUpdateArr)
if err != nil {
incrementUpdateWriterNumFailures()
backoffTime := computeUpdateWriterBackoff()
log.Printf("[pcloud] error processing %d web-updates (backoff=%v): %v\n", len(webUpdateArr), backoffTime, err)
updateBackoffSleep(backoffTime)
continue
}
log.Printf("[pcloud] sent %d web-updates\n", len(webUpdateArr))
var debugStrs []string
for _, webUpdate := range webUpdateArr {
debugStrs = append(debugStrs, webUpdate.String())
}
log.Printf("[pcloud] updates: %s\n", strings.Join(debugStrs, " "))
ResetUpdateWriterNumFailures()
}
}
// todo fix this, set deadline, check with condition variable, backoff then just needs to notify
func updateBackoffSleep(backoffTime time.Duration) {
var totalSleep time.Duration
for {
sleepTime := time.Second
totalSleep += sleepTime
time.Sleep(sleepTime)
if totalSleep >= backoffTime {
break
}
numFailures := GetUpdateWriterNumFailures()
if numFailures == 0 {
break
}
}
}

View File

@ -0,0 +1,196 @@
package pcloud
import (
"context"
"encoding/json"
"fmt"
"github.com/commandlinedev/prompt-server/pkg/remote"
"github.com/commandlinedev/prompt-server/pkg/rtnstate"
"github.com/commandlinedev/prompt-server/pkg/sstore"
)
type NoTelemetryInputType struct {
ClientId string `json:"clientid"`
Value bool `json:"value"`
}
type TelemetryInputType struct {
UserId string `json:"userid"`
ClientId string `json:"clientid"`
CurDay string `json:"curday"`
Activity []*sstore.ActivityType `json:"activity"`
}
type WebShareUpdateType struct {
ScreenId string `json:"screenid"`
LineId string `json:"lineid"`
UpdateId int64 `json:"updateid"`
UpdateType string `json:"updatetype"`
UpdateTs int64 `json:"updatets"`
Screen *WebShareScreenType `json:"screen,omitempty"`
Line *WebShareLineType `json:"line,omitempty"`
Cmd *WebShareCmdType `json:"cmd,omitempty"`
PtyData *WebSharePtyData `json:"ptydata,omitempty"`
SVal string `json:"sval,omitempty"`
IVal int64 `json:"ival,omitempty"`
BVal bool `json:"bval,omitempty"`
TermOpts *sstore.TermOpts `json:"termopts,omitempty"`
}
const EstimatedSizePadding = 100
func (update *WebShareUpdateType) GetEstimatedSize() int {
barr, _ := json.Marshal(update)
return len(barr) + 100
}
func (update *WebShareUpdateType) String() string {
var idStr string
if update.LineId != "" && update.ScreenId != "" {
idStr = fmt.Sprintf("%s:%s", update.ScreenId[0:8], update.LineId[0:8])
} else if update.ScreenId != "" {
idStr = update.ScreenId[0:8]
}
if update.UpdateType == sstore.UpdateType_PtyPos && update.PtyData != nil {
return fmt.Sprintf("ptydata[%s][%d:%d]", idStr, update.PtyData.PtyPos, len(update.PtyData.Data))
}
return fmt.Sprintf("%s[%s]", update.UpdateType, idStr)
}
type WebShareUpdateResponseType struct {
UpdateId int64 `json:"updateid"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
func (ur *WebShareUpdateResponseType) GetSimpleKey() int64 {
return ur.UpdateId
}
type WebShareRemote struct {
RemoteId string `json:"remoteid"`
Alias string `json:"alias,omitempty"`
CanonicalName string `json:"canonicalname"`
Name string `json:"name,omitempty"`
HomeDir string `json:"homedir,omitempty"`
IsRoot bool `json:"isroot,omitempty"`
}
type WebShareScreenType struct {
ScreenId string `json:"screenid"`
ShareName string `json:"sharename"`
ViewKey string `json:"viewkey"`
SelectedLine int `json:"selectedline"`
}
func webRemoteFromRemote(rptr sstore.RemotePtrType, r *sstore.RemoteType) *WebShareRemote {
return &WebShareRemote{
RemoteId: r.RemoteId,
Alias: r.RemoteAlias,
CanonicalName: r.RemoteCanonicalName,
Name: rptr.Name,
HomeDir: r.StateVars["home"],
IsRoot: r.StateVars["remoteuser"] == "root",
}
}
func webScreenFromScreen(s *sstore.ScreenType) (*WebShareScreenType, error) {
if s == nil || s.ScreenId == "" {
return nil, fmt.Errorf("invalid nil screen")
}
if s.WebShareOpts == nil {
return nil, fmt.Errorf("invalid screen, no WebShareOpts")
}
if s.WebShareOpts.ViewKey == "" {
return nil, fmt.Errorf("invalid screen, no ViewKey")
}
var shareName string
if s.WebShareOpts.ShareName != "" {
shareName = s.WebShareOpts.ShareName
} else {
shareName = s.Name
}
return &WebShareScreenType{ScreenId: s.ScreenId, ShareName: shareName, ViewKey: s.WebShareOpts.ViewKey, SelectedLine: int(s.SelectedLine)}, nil
}
type WebShareLineType struct {
LineId string `json:"lineid"`
Ts int64 `json:"ts"`
LineNum int64 `json:"linenum"`
LineType string `json:"linetype"`
ContentHeight int64 `json:"contentheight"`
Renderer string `json:"renderer,omitempty"`
Text string `json:"text,omitempty"`
}
func webLineFromLine(line *sstore.LineType) (*WebShareLineType, error) {
rtn := &WebShareLineType{
LineId: line.LineId,
Ts: line.Ts,
LineNum: line.LineNum,
LineType: line.LineType,
ContentHeight: line.ContentHeight,
Renderer: line.Renderer,
Text: line.Text,
}
return rtn, nil
}
type WebShareCmdType struct {
LineId string `json:"lineid"`
CmdStr string `json:"cmdstr"`
RawCmdStr string `json:"rawcmdstr"`
Remote *WebShareRemote `json:"remote"`
FeState sstore.FeStateType `json:"festate"`
TermOpts sstore.TermOpts `json:"termopts"`
Status string `json:"status"`
CmdPid int `json:"cmdpid"`
RemotePid int `json:"remotepid"`
DoneTs int64 `json:"donets,omitempty"`
ExitCode int `json:"exitcode,omitempty"`
DurationMs int `json:"durationms,omitempty"`
RtnState bool `json:"rtnstate,omitempty"`
RtnStateStr string `json:"rtnstatestr,omitempty"`
}
func webCmdFromCmd(lineId string, cmd *sstore.CmdType) (*WebShareCmdType, error) {
if cmd.Remote.RemoteId == "" {
return nil, fmt.Errorf("invalid cmd, remoteptr has no remoteid")
}
remote := remote.GetRemoteCopyById(cmd.Remote.RemoteId)
if remote == nil {
return nil, fmt.Errorf("invalid cmd, cannot retrieve remote:%s", cmd.Remote.RemoteId)
}
webRemote := webRemoteFromRemote(cmd.Remote, remote)
rtn := &WebShareCmdType{
LineId: lineId,
CmdStr: cmd.CmdStr,
RawCmdStr: cmd.RawCmdStr,
Remote: webRemote,
FeState: cmd.FeState,
TermOpts: cmd.TermOpts,
Status: cmd.Status,
CmdPid: cmd.CmdPid,
RemotePid: cmd.RemotePid,
DoneTs: cmd.DoneTs,
ExitCode: cmd.ExitCode,
DurationMs: cmd.DurationMs,
RtnState: cmd.RtnState,
}
if cmd.RtnState {
barr, err := rtnstate.GetRtnStateDiff(context.Background(), cmd.ScreenId, cmd.LineId)
if err != nil {
return nil, fmt.Errorf("error creating rtnstate diff for cmd:%s: %v", cmd.LineId, err)
}
rtn.RtnStateStr = string(barr)
}
return rtn, nil
}
type WebSharePtyData struct {
PtyPos int64 `json:"ptypos"`
Data []byte `json:"data"`
Eof bool `json:"-"` // internal use
}

View File

@ -0,0 +1,199 @@
package promptenc
import (
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"reflect"
ccp "golang.org/x/crypto/chacha20poly1305"
)
const EncTagName = "enc"
const EncFieldIndicator = "*"
type Encryptor struct {
Key []byte
AEAD cipher.AEAD
}
type HasOData interface {
GetOData() string
}
func readRandBytes(n int) ([]byte, error) {
rtn := make([]byte, n)
_, err := io.ReadFull(rand.Reader, rtn)
return rtn, err
}
func MakeRandomEncryptor() (*Encryptor, error) {
key, err := readRandBytes(ccp.KeySize)
if err != nil {
return nil, err
}
rtn := &Encryptor{Key: key}
rtn.AEAD, err = ccp.NewX(rtn.Key)
if err != nil {
return nil, err
}
return rtn, nil
}
func MakeEncryptor(key []byte) (*Encryptor, error) {
var err error
rtn := &Encryptor{Key: key}
rtn.AEAD, err = ccp.NewX(rtn.Key)
if err != nil {
return nil, err
}
return rtn, nil
}
func MakeEncryptorB64(key64 string) (*Encryptor, error) {
keyBytes, err := base64.RawURLEncoding.DecodeString(key64)
if err != nil {
return nil, err
}
return MakeEncryptor(keyBytes)
}
func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error) {
outputBuf := make([]byte, enc.AEAD.NonceSize()+enc.AEAD.Overhead()+len(plainText))
nonce := outputBuf[0:enc.AEAD.NonceSize()]
_, err := io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, err
}
// we're going to append the cipherText to nonce. so the encrypted data is [nonce][ciphertext]
// note that outputbuf should be the correct size to hold the rtn value
rtn := enc.AEAD.Seal(nonce, nonce, plainText, []byte(odata))
return rtn, nil
}
func (enc *Encryptor) DecryptData(encData []byte, odata string) (map[string]interface{}, error) {
minLen := enc.AEAD.NonceSize() + enc.AEAD.Overhead()
if len(encData) < minLen {
return nil, fmt.Errorf("invalid encdata, len:%d is less than minimum len:%d", len(encData), minLen)
}
m := make(map[string]interface{})
nonce := encData[0:enc.AEAD.NonceSize()]
cipherText := encData[enc.AEAD.NonceSize():]
plainText, err := enc.AEAD.Open(nil, nonce, cipherText, []byte(odata))
if err != nil {
return nil, err
}
err = json.Unmarshal(plainText, &m)
if err != nil {
return nil, err
}
return m, nil
}
type EncryptMeta struct {
EncField *reflect.StructField
PlainFields map[string]reflect.StructField
}
func isByteArrayType(t reflect.Type) bool {
return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8
}
func metaFromType(v interface{}) (*EncryptMeta, error) {
if v == nil {
return nil, fmt.Errorf("Encryptor cannot encrypt nil")
}
rt := reflect.TypeOf(v)
if rt.Kind() != reflect.Pointer {
return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer type", v)
}
rtElem := rt.Elem()
if rtElem.Kind() != reflect.Struct {
return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer to struct type", v)
}
meta := &EncryptMeta{}
meta.PlainFields = make(map[string]reflect.StructField)
numFields := rtElem.NumField()
for i := 0; i < numFields; i++ {
field := rtElem.Field(i)
encTag := field.Tag.Get(EncTagName)
if encTag == "" {
continue
}
if encTag == EncFieldIndicator {
if meta.EncField != nil {
return nil, fmt.Errorf("Encryptor, type %T has two enc fields set (*)", v)
}
if !isByteArrayType(field.Type) {
return nil, fmt.Errorf("Encryptor, type %T enc field %q is not []byte", v, field.Name)
}
meta.EncField = &field
continue
}
if _, found := meta.PlainFields[encTag]; found {
return nil, fmt.Errorf("Encryptor, type %T has two enc fields with tag %q", v, encTag)
}
meta.PlainFields[encTag] = field
}
if meta.EncField == nil {
return nil, fmt.Errorf("Encryptor, type %T has no enc (*) field", v)
}
return meta, nil
}
func (enc *Encryptor) EncryptODS(v HasOData) error {
odata := v.GetOData()
return enc.EncryptStructFields(v, odata)
}
func (enc *Encryptor) DecryptODS(v HasOData) error {
odata := v.GetOData()
return enc.DecryptStructFields(v, odata)
}
func (enc *Encryptor) EncryptStructFields(v interface{}, odata string) error {
encMeta, err := metaFromType(v)
if err != nil {
return err
}
rvPtr := reflect.ValueOf(v)
rv := rvPtr.Elem()
m := make(map[string]interface{})
for jsonKey, field := range encMeta.PlainFields {
fieldVal := rv.FieldByIndex(field.Index)
m[jsonKey] = fieldVal.Interface()
}
barr, err := json.Marshal(m)
if err != nil {
return err
}
cipherText, err := enc.EncryptData(barr, odata)
if err != nil {
return err
}
encFieldValue := rv.FieldByIndex(encMeta.EncField.Index)
encFieldValue.SetBytes(cipherText)
return nil
}
func (enc *Encryptor) DecryptStructFields(v interface{}, odata string) error {
encMeta, err := metaFromType(v)
if err != nil {
return err
}
rvPtr := reflect.ValueOf(v)
rv := rvPtr.Elem()
cipherText := rv.FieldByIndex(encMeta.EncField.Index).Bytes()
m, err := enc.DecryptData(cipherText, odata)
if err != nil {
return err
}
for jsonKey, field := range encMeta.PlainFields {
val := m[jsonKey]
rv.FieldByIndex(field.Index).Set(reflect.ValueOf(val))
}
return nil
}

View File

@ -0,0 +1,53 @@
package remote
import (
"fmt"
"sync"
)
type CircleLog struct {
Lock *sync.Mutex
StartPos int
Log []string
MaxSize int
}
func MakeCircleLog(maxSize int) *CircleLog {
if maxSize <= 0 {
panic("invalid maxsize, must be >= 0")
}
rtn := &CircleLog{
Lock: &sync.Mutex{},
StartPos: 0,
Log: make([]string, 0, maxSize),
MaxSize: maxSize,
}
return rtn
}
func (l *CircleLog) Add(s string) {
l.Lock.Lock()
defer l.Lock.Unlock()
if len(l.Log) < l.MaxSize {
l.Log = append(l.Log, s)
return
}
l.Log[l.StartPos] = s
l.StartPos = (l.StartPos + 1) % l.MaxSize
}
func (l *CircleLog) Addf(sfmt string, args ...interface{}) {
// no lock here, since l.Add() is synchronized
s := fmt.Sprintf(sfmt, args...)
l.Add(s)
}
func (l *CircleLog) GetEntries() []string {
l.Lock.Lock()
defer l.Lock.Unlock()
rtn := make([]string, len(l.Log))
for i := 0; i < len(l.Log); i++ {
rtn[i] = l.Log[(l.StartPos+i)%l.MaxSize]
}
return rtn
}

View File

@ -0,0 +1,147 @@
package openai
import (
"context"
"fmt"
"io"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/prompt-server/pkg/sstore"
)
// https://github.com/tiktoken-go/tokenizer
const DefaultMaxTokens = 1000
const DefaultModel = "gpt-3.5-turbo"
const DefaultStreamChanSize = 10
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
if resp.Usage.TotalTokens == 0 {
return nil
}
return &packet.OpenAIUsageType{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
var rtn []openaiapi.ChatCompletionMessage
for _, p := range prompt {
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
rtn = append(rtn, msg)
}
return rtn
}
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
if opts == nil {
return nil, fmt.Errorf("no openai opts found")
}
if opts.Model == "" {
return nil, fmt.Errorf("no openai model specified")
}
if opts.APIToken == "" {
return nil, fmt.Errorf("no api token")
}
client := openaiapi.NewClient(opts.APIToken)
req := openaiapi.ChatCompletionRequest{
Model: opts.Model,
Messages: convertPrompt(prompt),
MaxTokens: opts.MaxTokens,
}
if opts.MaxChoices > 1 {
req.N = opts.MaxChoices
}
apiResp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
return nil, fmt.Errorf("error calling openai API: %v", err)
}
if len(apiResp.Choices) == 0 {
return nil, fmt.Errorf("no response received")
}
return marshalResponse(apiResp), nil
}
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
if opts == nil {
return nil, fmt.Errorf("no openai opts found")
}
if opts.Model == "" {
return nil, fmt.Errorf("no openai model specified")
}
if opts.APIToken == "" {
return nil, fmt.Errorf("no api token")
}
client := openaiapi.NewClient(opts.APIToken)
req := openaiapi.ChatCompletionRequest{
Model: opts.Model,
Messages: convertPrompt(prompt),
MaxTokens: opts.MaxTokens,
Stream: true,
}
if opts.MaxChoices > 1 {
req.N = opts.MaxChoices
}
apiResp, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
return nil, fmt.Errorf("error calling openai API: %v", err)
}
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
go func() {
sentHeader := false
defer close(rtn)
for {
streamResp, err := apiResp.Recv()
if err == io.EOF {
break
}
if err != nil {
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
rtn <- errPk
break
}
if streamResp.Model != "" && !sentHeader {
pk := packet.MakeOpenAIPacket()
pk.Model = streamResp.Model
pk.Created = streamResp.Created
rtn <- pk
sentHeader = true
}
for _, choice := range streamResp.Choices {
pk := packet.MakeOpenAIPacket()
pk.Index = choice.Index
pk.Text = choice.Delta.Content
pk.FinishReason = choice.FinishReason
rtn <- pk
}
}
}()
return rtn, err
}
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
var rtn []*packet.OpenAIPacketType
headerPk := packet.MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
headerPk.Usage = convertUsage(resp)
rtn = append(rtn, headerPk)
for _, choice := range resp.Choices {
choicePk := packet.MakeOpenAIPacket()
choicePk.Index = choice.Index
choicePk.Text = choice.Message.Content
choicePk.FinishReason = choice.FinishReason
rtn = append(rtn, choicePk)
}
return rtn
}
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
errPk := packet.MakeOpenAIPacket()
errPk.FinishReason = "error"
errPk.Error = errStr
return errPk
}

2104
wavesrv/pkg/remote/remote.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,67 @@
package remote
import (
"github.com/commandlinedev/apishell/pkg/base"
)
func startCmdWait(ck base.CommandKey) {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
GlobalStore.CmdWaitMap[ck] = nil
}
func pushCmdWaitIfRequired(ck base.CommandKey, fn func()) bool {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
fns, ok := GlobalStore.CmdWaitMap[ck]
if !ok {
return false
}
fns = append(fns, fn)
GlobalStore.CmdWaitMap[ck] = fns
return true
}
func runCmdUpdateFn(ck base.CommandKey, fn func()) {
pushed := pushCmdWaitIfRequired(ck, fn)
if pushed {
return
}
fn()
}
func runCmdWaitFns(ck base.CommandKey) {
for {
fn := removeFirstCmdWaitFn(ck)
if fn == nil {
break
}
fn()
}
}
func removeFirstCmdWaitFn(ck base.CommandKey) func() {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
fns := GlobalStore.CmdWaitMap[ck]
if len(fns) == 0 {
delete(GlobalStore.CmdWaitMap, ck)
return nil
}
fn := fns[0]
GlobalStore.CmdWaitMap[ck] = fns[1:]
return fn
}
func removeCmdWait(ck base.CommandKey) {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
fns := GlobalStore.CmdWaitMap[ck]
if len(fns) == 0 {
delete(GlobalStore.CmdWaitMap, ck)
return
}
go runCmdWaitFns(ck)
}

View File

@ -0,0 +1,196 @@
package rtnstate
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/alessio/shellescape"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/apishell/pkg/shexec"
"github.com/commandlinedev/apishell/pkg/simpleexpand"
"github.com/commandlinedev/prompt-server/pkg/sstore"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
"mvdan.cc/sh/v3/syntax"
)
func parseAliasStmt(stmt *syntax.Stmt, sourceStr string) (string, string, error) {
cmd := stmt.Cmd
callExpr, ok := cmd.(*syntax.CallExpr)
if !ok {
return "", "", fmt.Errorf("wrong cmd type for alias")
}
if len(callExpr.Args) != 2 {
return "", "", fmt.Errorf("wrong number of words in alias expr wordslen=%d", len(callExpr.Args))
}
firstWord := callExpr.Args[0]
if firstWord.Lit() != "alias" {
return "", "", fmt.Errorf("invalid alias cmd word (not 'alias')")
}
secondWord := callExpr.Args[1]
var ectx simpleexpand.SimpleExpandContext // no homedir, do not want ~ expansion
val, _ := simpleexpand.SimpleExpandWord(ectx, secondWord, sourceStr)
eqIdx := strings.Index(val, "=")
if eqIdx == -1 {
return "", "", fmt.Errorf("no '=' in alias definition")
}
return val[0:eqIdx], val[eqIdx+1:], nil
}
func ParseAliases(aliases string) (map[string]string, error) {
r := strings.NewReader(aliases)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "aliases")
if err != nil {
return nil, err
}
rtn := make(map[string]string)
for _, stmt := range file.Stmts {
aliasName, aliasVal, err := parseAliasStmt(stmt, aliases)
if err != nil {
// fmt.Printf("stmt-err: %v\n", err)
continue
}
if aliasName != "" {
rtn[aliasName] = aliasVal
}
}
return rtn, nil
}
func parseFuncStmt(stmt *syntax.Stmt, source string) (string, string, error) {
cmd := stmt.Cmd
funcDecl, ok := cmd.(*syntax.FuncDecl)
if !ok {
return "", "", fmt.Errorf("cmd not FuncDecl")
}
name := funcDecl.Name.Value
// fmt.Printf("func: [%s]\n", name)
funcBody := funcDecl.Body
// fmt.Printf(" %d:%d\n", funcBody.Cmd.Pos().Offset(), funcBody.Cmd.End().Offset())
bodyStr := source[funcBody.Cmd.Pos().Offset():funcBody.Cmd.End().Offset()]
// fmt.Printf("<<<\n%s\n>>>\n", bodyStr)
// fmt.Printf("\n")
return name, bodyStr, nil
}
func ParseFuncs(funcs string) (map[string]string, error) {
r := strings.NewReader(funcs)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "funcs")
if err != nil {
return nil, err
}
rtn := make(map[string]string)
for _, stmt := range file.Stmts {
funcName, funcVal, err := parseFuncStmt(stmt, funcs)
if err != nil {
// TODO where to put parse errors
continue
}
if strings.HasPrefix(funcName, "_mshell_") {
continue
}
if funcName != "" {
rtn[funcName] = funcVal
}
}
return rtn, nil
}
const MaxDiffKeyLen = 40
const MaxDiffValLen = 50
var IgnoreVars = map[string]bool{"PROMPT": true, "PROMPT_VERSION": true, "MSHELL": true}
func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newState packet.ShellState) {
if newState.Cwd != oldState.Cwd {
buf.WriteString(fmt.Sprintf("cwd %s\n", newState.Cwd))
}
if !bytes.Equal(newState.ShellVars, oldState.ShellVars) {
newEnvMap := shexec.DeclMapFromState(&newState)
oldEnvMap := shexec.DeclMapFromState(&oldState)
for key, newVal := range newEnvMap {
if IgnoreVars[key] {
continue
}
oldVal, found := oldEnvMap[key]
if !found || !shexec.DeclsEqual(false, oldVal, newVal) {
var exportStr string
if newVal.IsExport() {
exportStr = "export "
}
buf.WriteString(fmt.Sprintf("%s%s=%s\n", exportStr, utilfn.EllipsisStr(key, MaxDiffKeyLen), utilfn.EllipsisStr(newVal.Value, MaxDiffValLen)))
}
}
for key, _ := range oldEnvMap {
if IgnoreVars[key] {
continue
}
_, found := newEnvMap[key]
if !found {
buf.WriteString(fmt.Sprintf("unset %s\n", utilfn.EllipsisStr(key, MaxDiffKeyLen)))
}
}
}
if newState.Aliases != oldState.Aliases {
newAliasMap, _ := ParseAliases(newState.Aliases)
oldAliasMap, _ := ParseAliases(oldState.Aliases)
for aliasName, newAliasVal := range newAliasMap {
oldAliasVal, found := oldAliasMap[aliasName]
if !found || newAliasVal != oldAliasVal {
buf.WriteString(fmt.Sprintf("alias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
}
}
for aliasName, _ := range oldAliasMap {
_, found := newAliasMap[aliasName]
if !found {
buf.WriteString(fmt.Sprintf("unalias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
}
}
}
if newState.Funcs != oldState.Funcs {
newFuncMap, _ := ParseFuncs(newState.Funcs)
oldFuncMap, _ := ParseFuncs(oldState.Funcs)
for funcName, newFuncVal := range newFuncMap {
oldFuncVal, found := oldFuncMap[funcName]
if !found || newFuncVal != oldFuncVal {
buf.WriteString(fmt.Sprintf("function %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
}
}
for funcName, _ := range oldFuncMap {
_, found := newFuncMap[funcName]
if !found {
buf.WriteString(fmt.Sprintf("unset -f %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
}
}
}
}
func GetRtnStateDiff(ctx context.Context, screenId string, lineId string) ([]byte, error) {
cmd, err := sstore.GetCmdByScreenId(ctx, screenId, lineId)
if err != nil {
return nil, err
}
if cmd == nil {
return nil, nil
}
if !cmd.RtnState {
return nil, nil
}
if cmd.RtnStatePtr.IsEmpty() {
return nil, nil
}
var outputBytes bytes.Buffer
initialState, err := sstore.GetFullState(ctx, cmd.StatePtr)
if err != nil {
return nil, fmt.Errorf("getting initial full state: %v", err)
}
rtnState, err := sstore.GetFullState(ctx, cmd.RtnStatePtr)
if err != nil {
return nil, fmt.Errorf("getting rtn full state: %v", err)
}
displayStateUpdateDiff(&outputBytes, *initialState, *rtnState)
return outputBytes.Bytes(), nil
}

View File

@ -0,0 +1,407 @@
package scbase
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"os/exec"
"os/user"
"path"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/google/uuid"
"golang.org/x/mod/semver"
"golang.org/x/sys/unix"
)
const HomeVarName = "HOME"
const PromptHomeVarName = "PROMPT_HOME"
const PromptDevVarName = "PROMPT_DEV"
const SessionsDirBaseName = "sessions"
const ScreensDirBaseName = "screens"
const PromptLockFile = "prompt.lock"
const PromptDirName = "prompt"
const PromptDevDirName = "prompt-dev"
const PromptAppPathVarName = "PROMPT_APP_PATH"
const PromptVersion = "v0.4.0"
const PromptAuthKeyFileName = "prompt.authkey"
const MShellVersion = "v0.3.0"
const DefaultMacOSShell = "/bin/bash"
var SessionDirCache = make(map[string]string)
var ScreenDirCache = make(map[string]string)
var BaseLock = &sync.Mutex{}
var BuildTime = "-"
func IsDevMode() bool {
pdev := os.Getenv(PromptDevVarName)
return pdev != ""
}
// must match js
func GetPromptHomeDir() string {
scHome := os.Getenv(PromptHomeVarName)
if scHome == "" {
homeVar := os.Getenv(HomeVarName)
if homeVar == "" {
homeVar = "/"
}
pdev := os.Getenv(PromptDevVarName)
if pdev != "" {
scHome = path.Join(homeVar, PromptDevDirName)
} else {
scHome = path.Join(homeVar, PromptDirName)
}
}
return scHome
}
func MShellBinaryDir() string {
appPath := os.Getenv(PromptAppPathVarName)
if appPath == "" {
appPath = "."
}
if IsDevMode() {
return path.Join(appPath, "dev-bin")
}
return path.Join(appPath, "bin", "mshell")
}
func MShellBinaryPath(version string, goos string, goarch string) (string, error) {
if !base.ValidGoArch(goos, goarch) {
return "", fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
}
binaryDir := MShellBinaryDir()
versionStr := semver.MajorMinor(version)
if versionStr == "" {
return "", fmt.Errorf("invalid mshell version: %q", version)
}
fileName := fmt.Sprintf("mshell-%s-%s.%s", versionStr, goos, goarch)
fullFileName := path.Join(binaryDir, fileName)
return fullFileName, nil
}
func LocalMShellBinaryPath() (string, error) {
return MShellBinaryPath(MShellVersion, runtime.GOOS, runtime.GOARCH)
}
func MShellBinaryReader(version string, goos string, goarch string) (io.ReadCloser, error) {
mshellPath, err := MShellBinaryPath(version, goos, goarch)
if err != nil {
return nil, err
}
fd, err := os.Open(mshellPath)
if err != nil {
return nil, fmt.Errorf("cannot open mshell binary %q: %v", mshellPath, err)
}
return fd, nil
}
func createPromptAuthKeyFile(fileName string) (string, error) {
fd, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return "", err
}
defer fd.Close()
keyStr := GenPromptUUID()
_, err = fd.Write([]byte(keyStr))
if err != nil {
return "", err
}
return keyStr, nil
}
func ReadPromptAuthKey() (string, error) {
homeDir := GetPromptHomeDir()
err := ensureDir(homeDir)
if err != nil {
return "", fmt.Errorf("cannot find/create PROMPT_HOME directory %q", homeDir)
}
fileName := path.Join(homeDir, PromptAuthKeyFileName)
fd, err := os.Open(fileName)
if err != nil && errors.Is(err, fs.ErrNotExist) {
return createPromptAuthKeyFile(fileName)
}
if err != nil {
return "", fmt.Errorf("error opening prompt authkey:%s: %v", fileName, err)
}
defer fd.Close()
buf, err := io.ReadAll(fd)
if err != nil {
return "", fmt.Errorf("error reading prompt authkey:%s: %v", fileName, err)
}
keyStr := string(buf)
_, err = uuid.Parse(keyStr)
if err != nil {
return "", fmt.Errorf("invalid authkey:%s format: %v", fileName, err)
}
return keyStr, nil
}
func AcquirePromptLock() (*os.File, error) {
homeDir := GetPromptHomeDir()
err := ensureDir(homeDir)
if err != nil {
return nil, fmt.Errorf("cannot find/create PROMPT_HOME directory %q", homeDir)
}
lockFileName := path.Join(homeDir, PromptLockFile)
fd, err := os.Create(lockFileName)
if err != nil {
return nil, err
}
err = unix.Flock(int(fd.Fd()), unix.LOCK_EX|unix.LOCK_NB)
if err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
// deprecated (v0.1.8)
func EnsureSessionDir(sessionId string) (string, error) {
if sessionId == "" {
return "", fmt.Errorf("cannot get session dir for blank sessionid")
}
BaseLock.Lock()
sdir, ok := SessionDirCache[sessionId]
BaseLock.Unlock()
if ok {
return sdir, nil
}
scHome := GetPromptHomeDir()
sdir = path.Join(scHome, SessionsDirBaseName, sessionId)
err := ensureDir(sdir)
if err != nil {
return "", err
}
BaseLock.Lock()
SessionDirCache[sessionId] = sdir
BaseLock.Unlock()
return sdir, nil
}
// deprecated (v0.1.8)
func GetSessionsDir() string {
promptHome := GetPromptHomeDir()
sdir := path.Join(promptHome, SessionsDirBaseName)
return sdir
}
func EnsureScreenDir(screenId string) (string, error) {
if screenId == "" {
return "", fmt.Errorf("cannot get screen dir for blank sessionid")
}
BaseLock.Lock()
sdir, ok := ScreenDirCache[screenId]
BaseLock.Unlock()
if ok {
return sdir, nil
}
scHome := GetPromptHomeDir()
sdir = path.Join(scHome, ScreensDirBaseName, screenId)
err := ensureDir(sdir)
if err != nil {
return "", err
}
BaseLock.Lock()
ScreenDirCache[screenId] = sdir
BaseLock.Unlock()
return sdir, nil
}
func GetScreensDir() string {
promptHome := GetPromptHomeDir()
sdir := path.Join(promptHome, ScreensDirBaseName)
return sdir
}
func ensureDir(dirName string) error {
info, err := os.Stat(dirName)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(dirName, 0700)
if err != nil {
return err
}
log.Printf("[prompt] created directory %q\n", dirName)
info, err = os.Stat(dirName)
}
if err != nil {
return err
}
if !info.IsDir() {
return fmt.Errorf("'%s' must be a directory", dirName)
}
return nil
}
// deprecated (v0.1.8)
func PtyOutFile_Sessions(sessionId string, cmdId string) (string, error) {
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return "", err
}
if sessionId == "" {
return "", fmt.Errorf("cannot get ptyout file for blank sessionid")
}
if cmdId == "" {
return "", fmt.Errorf("cannot get ptyout file for blank cmdid")
}
return fmt.Sprintf("%s/%s.ptyout.cf", sdir, cmdId), nil
}
func PtyOutFile(screenId string, lineId string) (string, error) {
sdir, err := EnsureScreenDir(screenId)
if err != nil {
return "", err
}
if screenId == "" {
return "", fmt.Errorf("cannot get ptyout file for blank screenid")
}
if lineId == "" {
return "", fmt.Errorf("cannot get ptyout file for blank lineid")
}
return fmt.Sprintf("%s/%s.ptyout.cf", sdir, lineId), nil
}
func GenPromptUUID() string {
for {
rtn := uuid.New().String()
_, err := strconv.Atoi(rtn[0:8])
if err == nil { // do not allow UUIDs where the initial 8 bytes parse to an integer
continue
}
return rtn
}
}
func NumFormatDec(num int64) string {
var signStr string
absNum := num
if absNum < 0 {
absNum = -absNum
signStr = "-"
}
if absNum < 1000 {
// raw num
return signStr + strconv.FormatInt(absNum, 10)
}
if absNum < 1000000 {
// k num
kVal := float64(absNum) / 1000
return signStr + strconv.FormatFloat(kVal, 'f', 2, 64) + "k"
}
if absNum < 1000000000 {
// M num
mVal := float64(absNum) / 1000000
return signStr + strconv.FormatFloat(mVal, 'f', 2, 64) + "m"
} else {
// G num
gVal := float64(absNum) / 1000000000
return signStr + strconv.FormatFloat(gVal, 'f', 2, 64) + "g"
}
}
func NumFormatB2(num int64) string {
var signStr string
absNum := num
if absNum < 0 {
absNum = -absNum
signStr = "-"
}
if absNum < 1024 {
// raw num
return signStr + strconv.FormatInt(absNum, 10)
}
if absNum < 1000000 {
// k num
if absNum%1024 == 0 {
return signStr + strconv.FormatInt(absNum/1024, 10) + "K"
}
kVal := float64(absNum) / 1024
return signStr + strconv.FormatFloat(kVal, 'f', 2, 64) + "K"
}
if absNum < 1000000000 {
// M num
if absNum%(1024*1024) == 0 {
return signStr + strconv.FormatInt(absNum/(1024*1024), 10) + "M"
}
mVal := float64(absNum) / (1024 * 1024)
return signStr + strconv.FormatFloat(mVal, 'f', 2, 64) + "M"
} else {
// G num
if absNum%(1024*1024*1024) == 0 {
return signStr + strconv.FormatInt(absNum/(1024*1024*1024), 10) + "G"
}
gVal := float64(absNum) / (1024 * 1024 * 1024)
return signStr + strconv.FormatFloat(gVal, 'f', 2, 64) + "G"
}
}
func ClientArch() string {
return fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
}
var releaseRegex = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
var osReleaseOnce = &sync.Once{}
var osRelease string
func macOSRelease() string {
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
out, err := exec.CommandContext(ctx, "uname", "-r").CombinedOutput()
if err != nil {
log.Printf("error executing uname -r: %v\n", err)
return "-"
}
releaseStr := strings.TrimSpace(string(out))
if !releaseRegex.MatchString(releaseStr) {
log.Printf("invalid uname -r output: [%s]\n", releaseStr)
return "-"
}
return releaseStr
}
func MacOSRelease() string {
osReleaseOnce.Do(func() {
osRelease = macOSRelease()
})
return osRelease
}
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
// dscl . -read /User/[username] UserShell
// defaults to /bin/bash
func MacUserShell() string {
osUser, err := user.Current()
if err != nil {
log.Printf("error getting current user: %v\n", err)
return DefaultMacOSShell
}
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
userStr := "/Users/" + osUser.Name
out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput()
if err != nil {
log.Printf("error executing macos user shell lookup: %v %q\n", err, string(out))
return DefaultMacOSShell
}
outStr := strings.TrimSpace(string(out))
m := userShellRegexp.FindStringSubmatch(outStr)
if m == nil {
log.Printf("error in format of dscl output: %q\n", outStr)
return DefaultMacOSShell
}
return m[1]
}

View File

@ -0,0 +1,120 @@
package scpacket
import (
"fmt"
"reflect"
"strings"
"github.com/alessio/shellescape"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/prompt-server/pkg/sstore"
)
const FeCommandPacketStr = "fecmd"
const WatchScreenPacketStr = "watchscreen"
const FeInputPacketStr = "feinput"
const RemoteInputPacketStr = "remoteinput"
type FeCommandPacketType struct {
Type string `json:"type"`
MetaCmd string `json:"metacmd"`
MetaSubCmd string `json:"metasubcmd,omitempty"`
Args []string `json:"args,omitempty"`
Kwargs map[string]string `json:"kwargs,omitempty"`
RawStr string `json:"rawstr,omitempty"`
UIContext *UIContextType `json:"uicontext,omitempty"`
Interactive bool `json:"interactive"`
}
func (pk *FeCommandPacketType) GetRawStr() string {
if pk.RawStr != "" {
return pk.RawStr
}
cmd := "/" + pk.MetaCmd
if pk.MetaSubCmd != "" {
cmd = cmd + ":" + pk.MetaSubCmd
}
var args []string
for k, v := range pk.Kwargs {
argStr := fmt.Sprintf("%s=%s", shellescape.Quote(k), shellescape.Quote(v))
args = append(args, argStr)
}
for _, arg := range pk.Args {
args = append(args, shellescape.Quote(arg))
}
if len(args) == 0 {
return cmd
}
return cmd + " " + strings.Join(args, " ")
}
type UIContextType struct {
SessionId string `json:"sessionid"`
ScreenId string `json:"screenid"`
Remote *sstore.RemotePtrType `json:"remote,omitempty"`
WinSize *packet.WinSize `json:"winsize,omitempty"`
Build string `json:"build,omitempty"`
}
type FeInputPacketType struct {
Type string `json:"type"`
CK base.CommandKey `json:"ck"`
Remote sstore.RemotePtrType `json:"remote"`
InputData64 string `json:"inputdata64"`
SigName string `json:"signame,omitempty"`
WinSize *packet.WinSize `json:"winsize,omitempty"`
}
type RemoteInputPacketType struct {
Type string `json:"type"`
RemoteId string `json:"remoteid"`
InputData64 string `json:"inputdata64"`
}
type WatchScreenPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid"`
ScreenId string `json:"screenid"`
Connect bool `json:"connect"`
AuthKey string `json:"authkey"`
}
func init() {
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(FeCommandPacketType{}))
packet.RegisterPacketType(WatchScreenPacketStr, reflect.TypeOf(WatchScreenPacketType{}))
packet.RegisterPacketType(FeInputPacketStr, reflect.TypeOf(FeInputPacketType{}))
packet.RegisterPacketType(RemoteInputPacketStr, reflect.TypeOf(RemoteInputPacketType{}))
}
func (*FeCommandPacketType) GetType() string {
return FeCommandPacketStr
}
func MakeFeCommandPacket() *FeCommandPacketType {
return &FeCommandPacketType{Type: FeCommandPacketStr}
}
func (*FeInputPacketType) GetType() string {
return FeInputPacketStr
}
func MakeFeInputPacket() *FeInputPacketType {
return &FeInputPacketType{Type: FeInputPacketStr}
}
func (*WatchScreenPacketType) GetType() string {
return WatchScreenPacketStr
}
func MakeWatchScreenPacket() *WatchScreenPacketType {
return &WatchScreenPacketType{Type: WatchScreenPacketStr}
}
func MakeRemoteInputPacket() *RemoteInputPacketType {
return &RemoteInputPacketType{Type: RemoteInputPacketStr}
}
func (*RemoteInputPacketType) GetType() string {
return RemoteInputPacketStr
}

305
wavesrv/pkg/scws/scws.go Normal file
View File

@ -0,0 +1,305 @@
package scws
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/google/uuid"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/prompt-server/pkg/mapqueue"
"github.com/commandlinedev/prompt-server/pkg/remote"
"github.com/commandlinedev/prompt-server/pkg/scpacket"
"github.com/commandlinedev/prompt-server/pkg/sstore"
"github.com/commandlinedev/prompt-server/pkg/wsshell"
)
const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
const RemoteInputQueueSize = 100
var RemoteInputMapQueue *mapqueue.MapQueue
func init() {
RemoteInputMapQueue = mapqueue.MakeMapQueue(RemoteInputQueueSize)
}
type WSState struct {
Lock *sync.Mutex
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
UpdateCh chan interface{}
UpdateQueue []interface{}
Authenticated bool
AuthKey string
SessionId string
ScreenId string
}
func MakeWSState(clientId string, authKey string) *WSState {
rtn := &WSState{}
rtn.Lock = &sync.Mutex{}
rtn.ClientId = clientId
rtn.ConnectTime = time.Now()
rtn.AuthKey = authKey
return rtn
}
func (ws *WSState) SetAuthenticated(authVal bool) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.Authenticated = authVal
}
func (ws *WSState) IsAuthenticated() bool {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Authenticated
}
func (ws *WSState) GetShell() *wsshell.WSShell {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Shell
}
func (ws *WSState) WriteUpdate(update interface{}) error {
shell := ws.GetShell()
if shell == nil {
return fmt.Errorf("cannot write update, empty shell")
}
err := shell.WriteJson(update)
if err != nil {
return err
}
return nil
}
func (ws *WSState) UpdateConnectTime() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.ConnectTime = time.Now()
}
func (ws *WSState) GetConnectTime() time.Time {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.ConnectTime
}
func (ws *WSState) WatchScreen(sessionId string, screenId string) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
if ws.SessionId == sessionId && ws.ScreenId == screenId {
return
}
ws.SessionId = sessionId
ws.ScreenId = screenId
ws.UpdateCh = sstore.MainBus.RegisterChannel(ws.ClientId, ws.ScreenId)
go ws.RunUpdates(ws.UpdateCh)
}
func (ws *WSState) UnWatchScreen() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
sstore.MainBus.UnregisterChannel(ws.ClientId)
ws.SessionId = ""
ws.ScreenId = ""
log.Printf("[ws] unwatch screen clientid=%s\n", ws.ClientId)
}
func (ws *WSState) getUpdateCh() chan interface{} {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.UpdateCh
}
func (ws *WSState) RunUpdates(updateCh chan interface{}) {
if updateCh == nil {
panic("invalid nil updateCh passed to RunUpdates")
}
for update := range updateCh {
shell := ws.GetShell()
if shell != nil {
shell.WriteJson(update)
}
}
}
func (ws *WSState) ReplaceShell(shell *wsshell.WSShell) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
if ws.Shell == nil {
ws.Shell = shell
return
}
ws.Shell.Conn.Close()
ws.Shell = shell
return
}
func (ws *WSState) handleConnection() error {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
update, err := sstore.GetAllSessions(ctx)
if err != nil {
return fmt.Errorf("getting sessions: %w", err)
}
remotes := remote.GetAllRemoteRuntimeState()
ifarr := make([]interface{}, len(remotes))
for idx, r := range remotes {
ifarr[idx] = r
}
update.Remotes = ifarr
update.Connect = true
err = ws.Shell.WriteJson(update)
if err != nil {
return err
}
return nil
}
func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error {
if wsPk.SessionId != "" {
if _, err := uuid.Parse(wsPk.SessionId); err != nil {
return fmt.Errorf("invalid watchscreen sessionid: %w", err)
}
}
if wsPk.ScreenId != "" {
if _, err := uuid.Parse(wsPk.ScreenId); err != nil {
return fmt.Errorf("invalid watchscreen screenid: %w", err)
}
}
if wsPk.AuthKey == "" {
ws.SetAuthenticated(false)
return fmt.Errorf("invalid watchscreen, no authkey")
}
if wsPk.AuthKey != ws.AuthKey {
ws.SetAuthenticated(false)
return fmt.Errorf("invalid watchscreen, invalid authkey")
}
ws.SetAuthenticated(true)
if wsPk.SessionId == "" || wsPk.ScreenId == "" {
ws.UnWatchScreen()
} else {
ws.WatchScreen(wsPk.SessionId, wsPk.ScreenId)
log.Printf("[ws %s] watchscreen %s/%s\n", ws.ClientId, wsPk.SessionId, wsPk.ScreenId)
}
if wsPk.Connect {
// log.Printf("[ws %s] watchscreen connect\n", ws.ClientId)
err := ws.handleConnection()
if err != nil {
return fmt.Errorf("connect: %w", err)
}
}
return nil
}
func (ws *WSState) RunWSRead() {
shell := ws.GetShell()
if shell == nil {
return
}
shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error
for msgBytes := range shell.ReadChan {
pk, err := packet.ParseJsonPacket(msgBytes)
if err != nil {
log.Printf("error unmarshalling ws message: %v\n", err)
continue
}
if pk.GetType() == scpacket.WatchScreenPacketStr {
wsPk := pk.(*scpacket.WatchScreenPacketType)
err := ws.handleWatchScreen(wsPk)
if err != nil {
// TODO send errors back to client, likely unrecoverable
log.Printf("[ws %s] error %v\n", ws.ClientId, err)
}
continue
}
isAuth := ws.IsAuthenticated()
if !isAuth {
log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType())
continue
}
if pk.GetType() == scpacket.FeInputPacketStr {
feInputPk := pk.(*scpacket.FeInputPacketType)
if feInputPk.Remote.OwnerId != "" {
log.Printf("[error] cannot send input to remote with ownerid\n")
continue
}
if feInputPk.Remote.RemoteId == "" {
log.Printf("[error] invalid input packet, remoteid is not set\n")
continue
}
err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() {
err = sendCmdInput(feInputPk)
if err != nil {
log.Printf("[error] sending command input: %v\n", err)
}
})
if err != nil {
log.Printf("[error] could not queue sendCmdInput: %v\n", err)
continue
}
continue
}
if pk.GetType() == scpacket.RemoteInputPacketStr {
inputPk := pk.(*scpacket.RemoteInputPacketType)
if inputPk.RemoteId == "" {
log.Printf("[error] invalid remoteinput packet, remoteid is not set\n")
continue
}
go func() {
err = remote.SendRemoteInput(inputPk)
if err != nil {
log.Printf("[error] processing remote input: %v\n", err)
}
}()
continue
}
log.Printf("got ws bad message: %v\n", pk.GetType())
}
}
func sendCmdInput(pk *scpacket.FeInputPacketType) error {
err := pk.CK.Validate("input packet")
if err != nil {
return err
}
if pk.Remote.RemoteId == "" {
return fmt.Errorf("input must set remoteid")
}
msh := remote.GetRemoteById(pk.Remote.RemoteId)
if msh == nil {
return fmt.Errorf("remote %s not found", pk.Remote.RemoteId)
}
if len(pk.InputData64) > 0 {
inputLen := packet.B64DecodedLen(pk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
dataPk := packet.MakeDataPacket()
dataPk.CK = pk.CK
dataPk.FdNum = 0 // stdin
dataPk.Data64 = pk.InputData64
err = msh.SendInput(dataPk)
if err != nil {
return err
}
}
if pk.SigName != "" || pk.WinSize != nil {
siPk := packet.MakeSpecialInputPacket()
siPk.CK = pk.CK
siPk.SigName = pk.SigName
siPk.WinSize = pk.WinSize
err = msh.SendSpecialInput(siPk)
if err != nil {
return err
}
}
return nil
}

288
wavesrv/pkg/shparse/comp.go Normal file
View File

@ -0,0 +1,288 @@
package shparse
import (
"strings"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
)
const (
CompTypeCommandMeta = "command-meta"
CompTypeCommand = "command"
CompTypeArg = "command-arg"
CompTypeInvalid = "invalid"
CompTypeVar = "var"
CompTypeAssignment = "assignment"
CompTypeBasic = "basic"
)
type CompletionPos struct {
RawPos int // the raw position of cursor
SuperOffset int // adjust all offsets in Cmd and CmdWord by SuperOffset
CompType string // see CompType* constants
Cmd *CmdType // nil if between commands or a special completion (otherwise will be a SimpleCommand)
// index into cmd.Words (only set when Cmd is not nil, otherwise we look at CompCommand)
// 0 means command-word
// negative means assignment-words.
// can be past the end of Words (means start new word).
CmdWordPos int
CompWord *WordType // set to the word we are completing (nil if we are starting a new word)
CompWordOffset int // offset into compword (only if CmdWord is not nil)
}
func compTypeFromPos(cmdWordPos int) string {
if cmdWordPos == 0 {
return CompTypeCommand
}
if cmdWordPos < 0 {
return CompTypeAssignment
}
return CompTypeArg
}
func (cmd *CmdType) findCompletionPos_simple(pos int, superOffset int) CompletionPos {
if cmd.Type != CmdTypeSimple {
panic("findCompletetionPos_simple only works for CmdTypeSimple")
}
rtn := CompletionPos{RawPos: pos, SuperOffset: superOffset, Cmd: cmd}
for idx, word := range cmd.AssignmentWords {
startOffset := word.Offset
endOffset := word.Offset + len(word.Raw)
if pos <= startOffset {
// starting a new word at this position (before the current assignment word)
rtn.CmdWordPos = idx - len(cmd.AssignmentWords)
rtn.CompType = CompTypeAssignment
return rtn
}
if pos <= endOffset {
// completing an assignment word
rtn.CmdWordPos = idx - len(cmd.AssignmentWords)
rtn.CompWord = word
rtn.CompWordOffset = pos - word.Offset
rtn.CompType = CompTypeAssignment
return rtn
}
}
var foundWord *WordType
var foundWordIdx int
for idx, word := range cmd.Words {
startOffset := word.Offset
endOffset := word.Offset + len(word.Raw)
if pos <= startOffset {
// starting a new word at this position
rtn.CmdWordPos = idx
rtn.CompType = compTypeFromPos(idx)
return rtn
}
if pos == endOffset && word.Type == WordTypeOp {
// operators are special, they can allow a full-word completion at endpos
continue
}
if pos <= endOffset {
foundWord = word
foundWordIdx = idx
break
}
}
if foundWord != nil {
rtn.CmdWordPos = foundWordIdx
rtn.CompWord = foundWord
rtn.CompWordOffset = pos - foundWord.Offset
if foundWord.uncompletable() {
// invalid completion point
rtn.CompType = CompTypeInvalid
return rtn
}
rtn.CompType = compTypeFromPos(foundWordIdx)
return rtn
}
// past the end, so we're starting a new word in Cmd
rtn.CmdWordPos = len(cmd.Words)
rtn.CompType = CompTypeArg
return rtn
}
func (cmd *CmdType) findCompletionPos_none(pos int, superOffset int) CompletionPos {
rtn := CompletionPos{RawPos: pos, SuperOffset: superOffset}
if cmd.Type != CmdTypeNone {
panic("findCompletionPos_none only works for CmdTypeNone")
}
var foundWord *WordType
for _, word := range cmd.Words {
startOffset := word.Offset
endOffset := word.Offset + len(word.Raw)
if pos <= startOffset {
break
}
if pos <= endOffset {
if pos == endOffset && word.Type == WordTypeOp {
// operators are special, they can allow a full-word completion at endpos
continue
}
foundWord = word
break
}
}
if foundWord == nil {
// just revert to a file completion
rtn.CompType = CompTypeBasic
return rtn
}
foundWordOffset := pos - foundWord.Offset
rtn.CompWord = foundWord
rtn.CompWordOffset = foundWordOffset
if foundWord.uncompletable() {
// ok, we're inside of a word in CmdTypeNone. if we're in an uncompletable word, return CompInvalid
rtn.CompType = CompTypeInvalid
return rtn
}
if foundWordOffset > 0 && foundWordOffset < foundWord.contentStartPos() {
// cursor is in a weird position, between characters of a multi-char prefix (e.g. "$[*]{hello}" or $[*]'hello'). cannot complete.
rtn.CompType = CompTypeInvalid
return rtn
}
// revert to file completion
rtn.CompType = CompTypeBasic
return rtn
}
func findCompletionWordAtPos(words []*WordType, pos int, allowEndMatch bool) *WordType {
// WordTypeSimpleVar is special (always allowEndMatch), if cursor is at the end of SimpleVar it is returned
for _, word := range words {
if pos > word.Offset && pos < word.Offset+len(word.Raw) {
return word
}
if (allowEndMatch || word.Type == WordTypeSimpleVar) && pos == word.Offset+len(word.Raw) {
return word
}
}
return nil
}
// recursively descend down the word, parse commands and find a sub completion point if any.
// return nil if there is no sub completion point in this word
func findCompletionPosInWord(word *WordType, posInWord int, superOffset int) *CompletionPos {
rawPos := word.Offset + posInWord
if word.Type == WordTypeGroup || word.Type == WordTypeDQ || word.Type == WordTypeDDQ {
// need to descend further
if posInWord <= word.contentStartPos() {
return nil
}
if posInWord > word.contentEndPos() {
return nil
}
subWord := findCompletionWordAtPos(word.Subs, posInWord-word.contentStartPos(), false)
if subWord == nil {
return nil
}
return findCompletionPosInWord(subWord, posInWord-(subWord.Offset+word.contentStartPos()), superOffset+(word.Offset+word.contentStartPos()))
}
if word.Type == WordTypeDP || word.Type == WordTypeBQ {
if posInWord < word.contentStartPos() {
return nil
}
if posInWord > word.contentEndPos() {
return nil
}
subCmds := ParseCommands(word.Subs)
newPos := findCompletionPosInternal(subCmds, posInWord-word.contentStartPos(), superOffset+(word.Offset+word.contentStartPos()))
return &newPos
}
if word.Type == WordTypeSimpleVar || word.Type == WordTypeVarBrace {
// special "var" completion
rtn := &CompletionPos{RawPos: rawPos, SuperOffset: superOffset}
rtn.CompType = CompTypeVar
rtn.CompWordOffset = posInWord
rtn.CompWord = word
return rtn
}
return nil
}
// returns the context for completion
// if we are completing in a simple-command, the returns the Cmd. the Cmd can be used for specialized completion (command name, arg position, etc.)
// if we are completing in a word, returns the Word. Word might be a group-word or DQ word, so it may need additional resolution (done in extend)
// otherwise we are going to create a new word to insert at offset (so the context does not matter)
func findCompletionPosCmds(cmds []*CmdType, pos int, superOffset int) CompletionPos {
rtn := CompletionPos{RawPos: pos, SuperOffset: superOffset}
if len(cmds) == 0 {
// set CompCommand because we're starting a new command
rtn.CompType = CompTypeCommand
return rtn
}
for _, cmd := range cmds {
endOffset := cmd.endOffset()
if pos > endOffset || (cmd.Type == CmdTypeNone && pos == endOffset) {
continue
}
startOffset := cmd.offset()
if cmd.Type == CmdTypeSimple {
if pos <= startOffset {
rtn.CompType = CompTypeCommand
return rtn
}
return cmd.findCompletionPos_simple(pos, superOffset)
} else {
// not in a simple-command
// if we're before the none-command, just start a new command
if pos <= startOffset {
rtn.CompType = CompTypeCommand
return rtn
}
return cmd.findCompletionPos_none(pos, superOffset)
}
}
// past the end
lastCmd := cmds[len(cmds)-1]
if lastCmd.Type == CmdTypeSimple {
// just extend last command
rtn.Cmd = lastCmd
rtn.CmdWordPos = len(lastCmd.Words)
rtn.CompType = CompTypeArg
return rtn
}
// use lastCmd.NoneComplete to see if last command ended on a "separator". use that to set CompCommand
if lastCmd.NoneComplete {
rtn.CompType = CompTypeCommand
} else {
rtn.CompType = CompTypeBasic
}
return rtn
}
func findCompletionPosInternal(cmds []*CmdType, pos int, superOffset int) CompletionPos {
cpos := findCompletionPosCmds(cmds, pos, superOffset)
if cpos.CompWord == nil {
return cpos
}
subPos := findCompletionPosInWord(cpos.CompWord, cpos.CompWordOffset, superOffset)
if subPos != nil {
return *subPos
}
return cpos
}
func FindCompletionPos(cmds []*CmdType, pos int) CompletionPos {
cpos := findCompletionPosInternal(cmds, pos, 0)
if cpos.CompType == CompTypeCommand && cpos.SuperOffset == 0 && cpos.CompWord != nil && cpos.CompWord.Offset == 0 && strings.HasPrefix(string(cpos.CompWord.Raw), "/") {
cpos.CompType = CompTypeCommandMeta
}
return cpos
}
func (cpos CompletionPos) Extend(origStr utilfn.StrWithPos, extensionStr string, extensionComplete bool) utilfn.StrWithPos {
compWord := cpos.CompWord
if compWord == nil {
compWord = MakeEmptyWord(WordTypeLit, nil, cpos.RawPos, true)
}
realOffset := compWord.Offset + cpos.SuperOffset
if strings.HasSuffix(extensionStr, "/") {
extensionComplete = false
}
rtnSP := Extend(compWord, cpos.CompWordOffset, extensionStr, extensionComplete)
origRunes := []rune(origStr.Str)
rtnSP = rtnSP.Prepend(string(origRunes[0:realOffset]))
rtnSP = rtnSP.Append(string(origRunes[realOffset+len(compWord.Raw):]))
return rtnSP
}

View File

@ -0,0 +1,258 @@
package shparse
import (
"bytes"
"fmt"
"mvdan.cc/sh/v3/expand"
)
const MaxExpandLen = 64 * 1024
type ExpandInfo struct {
HasTilde bool // only ~ as the first character when SimpleExpandContext.HomeDir is set
HasVar bool // $x, $$, ${...}
HasGlob bool // *, ?, [, {
HasExtGlob bool // ?(...) ... ?*+@!
HasHistory bool // ! (anywhere)
HasSpecial bool // subshell, arith
}
type ExpandContext struct {
HomeDir string
}
func expandSQ(buf *bytes.Buffer, rawLit []rune) {
// no info specials
buf.WriteString(string(rawLit))
}
// TODO implement our own ANSI single quote formatter
func expandANSISQ(buf *bytes.Buffer, rawLit []rune) {
// no info specials
str, _, _ := expand.Format(nil, string(rawLit), nil)
buf.WriteString(str)
}
func expandLiteral(buf *bytes.Buffer, info *ExpandInfo, rawLit []rune) {
var lastBackSlash bool
var lastExtGlob bool
var lastDollar bool
for _, ch := range rawLit {
if ch == 0 {
break
}
if lastBackSlash {
lastBackSlash = false
if ch == '\n' {
// special case, backslash *and* newline are ignored
continue
}
buf.WriteRune(ch)
continue
}
if ch == '\\' {
lastBackSlash = true
lastExtGlob = false
lastDollar = false
continue
}
if ch == '*' || ch == '?' || ch == '[' || ch == '{' {
info.HasGlob = true
}
if ch == '`' {
info.HasSpecial = true
}
if ch == '!' {
info.HasHistory = true
}
if lastExtGlob && ch == '(' {
info.HasExtGlob = true
}
if lastDollar && (ch != ' ' && ch != '"' && ch != '\'' && ch != '(' || ch != '[') {
info.HasVar = true
}
if lastDollar && (ch == '(' || ch == '[') {
info.HasSpecial = true
}
lastExtGlob = (ch == '?' || ch == '*' || ch == '+' || ch == '@' || ch == '!')
lastDollar = (ch == '$')
buf.WriteRune(ch)
}
if lastBackSlash {
buf.WriteByte('\\')
}
}
// will also work for partial double quoted strings
func expandDQLiteral(buf *bytes.Buffer, info *ExpandInfo, rawVal []rune) {
var lastBackSlash bool
var lastDollar bool
for _, ch := range rawVal {
if ch == 0 {
break
}
if lastBackSlash {
lastBackSlash = false
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
buf.WriteRune(ch)
continue
}
buf.WriteRune('\\')
buf.WriteRune(ch)
continue
}
if ch == '\\' {
lastBackSlash = true
lastDollar = false
continue
}
// similar to expandLiteral, but no globbing
if ch == '`' {
info.HasSpecial = true
}
if ch == '!' {
info.HasHistory = true
}
if lastDollar && (ch != ' ' && ch != '"' && ch != '\'' && ch != '(' || ch != '[') {
info.HasVar = true
}
if lastDollar && (ch == '(' || ch == '[') {
info.HasSpecial = true
}
lastDollar = (ch == '$')
buf.WriteRune(ch)
}
// in a valid parsed DQ string, you cannot have a trailing backslash (because \" would not end the string)
// still putting the case here though in case we ever deal with incomplete strings (e.g. completion)
if lastBackSlash {
buf.WriteByte('\\')
}
}
func simpleExpandSubs(buf *bytes.Buffer, info *ExpandInfo, ectx ExpandContext, word *WordType, pos int) {
fmt.Printf("expand subs: %v\n", word)
parts := word.Subs
startPos := word.contentStartPos()
for _, part := range parts {
remainingLen := pos - startPos
if remainingLen <= 0 {
break
}
simpleExpandWord(buf, info, ectx, part, remainingLen)
startPos += len(part.Raw)
}
}
func canExpand(ectx ExpandContext, wtype string) bool {
return wtype == WordTypeLit || wtype == WordTypeSQ || wtype == WordTypeDSQ ||
wtype == WordTypeDQ || wtype == WordTypeDDQ || wtype == WordTypeGroup
}
func simpleExpandWord(buf *bytes.Buffer, info *ExpandInfo, ectx ExpandContext, word *WordType, pos int) {
if canExpand(ectx, word.Type) {
if pos >= word.contentEndPos() {
pos = word.contentEndPos()
}
if pos <= word.contentStartPos() {
return
}
} else {
if pos >= len(word.Raw) {
pos = len(word.Raw)
}
if pos <= 0 {
return
}
}
switch word.Type {
case WordTypeLit:
if word.QC.cur() == WordTypeDQ {
expandDQLiteral(buf, info, word.Raw[:pos])
return
}
expandLiteral(buf, info, word.Raw[:pos])
case WordTypeSQ:
expandSQ(buf, word.Raw[word.contentStartPos():pos])
return
case WordTypeDSQ:
expandANSISQ(buf, word.Raw[word.contentStartPos():pos])
return
case WordTypeDQ, WordTypeDDQ:
simpleExpandSubs(buf, info, ectx, word, pos)
return
case WordTypeGroup:
simpleExpandSubs(buf, info, ectx, word, pos)
return
// not expanded
case WordTypeSimpleVar:
info.HasVar = true
buf.WriteString(string(word.Raw[:pos]))
return
// not expanded
case WordTypeVarBrace:
info.HasVar = true
buf.WriteString(string(word.Raw[:pos]))
return
default:
info.HasSpecial = true
buf.WriteString(string(word.Raw[:pos]))
return
}
}
func SimpleExpandPrefix(ectx ExpandContext, word *WordType, pos int) (string, ExpandInfo) {
var buf bytes.Buffer
var info ExpandInfo
simpleExpandWord(&buf, &info, ectx, word, pos)
return buf.String(), info
}
func SimpleExpand(ectx ExpandContext, word *WordType) (string, ExpandInfo) {
return SimpleExpandPrefix(ectx, word, len(word.Raw))
}
// returns varname (no '$') and ok (whether this is a valid varname expansion)
func SimpleVarNamePrefix(ectx ExpandContext, word *WordType, pos int) (string, bool) {
if word.Type != WordTypeSimpleVar && word.Type != WordTypeVarBrace {
return "", false
}
if word.Type == WordTypeSimpleVar {
if pos == 0 {
return "", false
}
if pos == 1 {
return "", true
}
if pos > len(word.Raw) {
pos = len(word.Raw)
}
return string(word.Raw[1:pos]), true
}
// word.Type == WordTypeVarBrace
// knock '${' off the front, then see if the rest is a valid var name.
if pos == 0 || pos == 1 {
return "", false
}
if pos == 2 {
return "", true
}
if pos > word.contentEndPos() {
pos = word.contentEndPos()
}
rawVarName := word.Raw[2:pos]
if isSimpleVarName(rawVarName) {
return string(rawVarName), true
}
return "", false
}

View File

@ -0,0 +1,410 @@
package shparse
import (
"bytes"
"unicode"
"unicode/utf8"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
)
var noEscChars []bool
var specialEsc []string
func init() {
noEscChars = make([]bool, 256)
for ch := 0; ch < 256; ch++ {
if (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
ch == '-' || ch == '.' || ch == '/' || ch == ':' || ch == '=' || ch == '_' {
noEscChars[byte(ch)] = true
}
}
specialEsc = make([]string, 256)
specialEsc[0x7] = "\\a"
specialEsc[0x8] = "\\b"
specialEsc[0x9] = "\\t"
specialEsc[0xa] = "\\n"
specialEsc[0xb] = "\\v"
specialEsc[0xc] = "\\f"
specialEsc[0xd] = "\\r"
specialEsc[0x1b] = "\\E"
}
func getUtf8Literal(ch rune) string {
var buf bytes.Buffer
var runeArr [utf8.UTFMax]byte
barr := runeArr[:]
byteLen := utf8.EncodeRune(barr, ch)
for i := 0; i < byteLen; i++ {
buf.WriteString("\\x")
buf.WriteByte(utilfn.HexDigits[barr[i]/16])
buf.WriteByte(utilfn.HexDigits[barr[i]%16])
}
return buf.String()
}
func (w *WordType) writeString(s string) {
for _, ch := range s {
w.writeRune(ch)
}
}
func (w *WordType) writeRune(ch rune) {
wmeta := wordMetaMap[w.Type]
if w.Complete && wmeta.SuffixLen == 1 {
w.Raw = append(w.Raw[0:len(w.Raw)-1], ch, w.Raw[len(w.Raw)-1])
return
}
if w.Complete && wmeta.SuffixLen == 2 {
w.Raw = append(w.Raw[0:len(w.Raw)-2], ch, w.Raw[len(w.Raw)-2], w.Raw[len(w.Raw)-1])
return
}
// not complete or SuffixLen == 0 (2+ is not supported)
w.Raw = append(w.Raw, ch)
return
}
type extendContext struct {
Input []*WordType
InputPos int
QC QuoteContext
Rtn []*WordType
CurWord *WordType
Intention string
}
func makeExtendContext(qc QuoteContext, word *WordType) *extendContext {
rtn := &extendContext{QC: qc}
if word == nil {
rtn.Intention = WordTypeLit
return rtn
} else {
rtn.Intention = word.Type
rtn.Rtn = []*WordType{word}
rtn.CurWord = word
return rtn
}
}
func (ec *extendContext) appendWord(w *WordType) {
ec.Rtn = append(ec.Rtn, w)
ec.CurWord = w
}
func (ec *extendContext) ensureCurWord() {
if ec.CurWord == nil || ec.CurWord.Type != ec.Intention {
ec.CurWord = MakeEmptyWord(ec.Intention, ec.QC, 0, true)
ec.Rtn = append(ec.Rtn, ec.CurWord)
}
}
// grp, dq, ddq
func extendWithSubs(word *WordType, wordPos int, extStr string, complete bool) utilfn.StrWithPos {
wmeta := wordMetaMap[word.Type]
if word.Type == WordTypeGroup {
atEnd := (wordPos == len(word.Raw))
subWord := findCompletionWordAtPos(word.Subs, wordPos, true)
if subWord == nil {
strPos := Extend(MakeEmptyWord(WordTypeLit, word.QC, 0, true), 0, extStr, atEnd)
strPos = strPos.Prepend(string(word.Raw[0:wordPos]))
strPos = strPos.Append(string(word.Raw[wordPos:]))
return strPos
} else {
subComplete := complete && atEnd
strPos := Extend(subWord, wordPos-subWord.Offset, extStr, subComplete)
strPos = strPos.Prepend(string(word.Raw[0:subWord.Offset]))
strPos = strPos.Append(string(word.Raw[subWord.Offset+len(subWord.Raw):]))
return strPos
}
} else if word.Type == WordTypeDQ || word.Type == WordTypeDDQ {
if wordPos < word.contentStartPos() {
wordPos = word.contentStartPos()
}
atEnd := (wordPos >= len(word.Raw)-wmeta.SuffixLen)
subWord := findCompletionWordAtPos(word.Subs, wordPos-wmeta.PrefixLen, true)
quoteBalance := !atEnd
if subWord == nil {
realOffset := wordPos
strPos, wordOpen := extendInternal(MakeEmptyWord(WordTypeLit, word.QC.push(WordTypeDQ), 0, true), 0, extStr, false, quoteBalance)
strPos = strPos.Prepend(string(word.Raw[0:realOffset]))
var requiredSuffix string
if wordOpen {
requiredSuffix = wmeta.getSuffix()
}
if atEnd {
if complete {
return utilfn.StrWithPos{Str: strPos.Str + requiredSuffix + " ", Pos: strPos.Pos + len(requiredSuffix) + 1}
} else {
if word.Complete && requiredSuffix != "" {
return strPos.Append(requiredSuffix)
}
return strPos
}
}
strPos = strPos.Append(string(word.Raw[wordPos:]))
return strPos
} else {
realOffset := subWord.Offset + wmeta.PrefixLen
strPos, wordOpen := extendInternal(subWord, wordPos-realOffset, extStr, false, quoteBalance)
strPos = strPos.Prepend(string(word.Raw[0:realOffset]))
var requiredSuffix string
if wordOpen {
requiredSuffix = wmeta.getSuffix()
}
if atEnd {
if complete {
return utilfn.StrWithPos{Str: strPos.Str + requiredSuffix + " ", Pos: strPos.Pos + len(requiredSuffix) + 1}
} else {
if word.Complete && requiredSuffix != "" {
return strPos.Append(requiredSuffix)
}
return strPos
}
}
strPos = strPos.Append(string(word.Raw[realOffset+len(subWord.Raw):]))
return strPos
}
} else {
return utilfn.StrWithPos{Str: string(word.Raw), Pos: wordPos}
}
}
// lit, svar, varb, sq, dsq
func extendLeafCh(buf *bytes.Buffer, wordOpen *bool, wtype string, qc QuoteContext, ch rune) {
switch wtype {
case WordTypeSimpleVar, WordTypeVarBrace:
extendVar(buf, ch)
case WordTypeLit:
if qc.cur() == WordTypeDQ {
extendDQLit(buf, wordOpen, ch)
} else {
extendLit(buf, ch)
}
case WordTypeSQ:
extendSQ(buf, wordOpen, ch)
case WordTypeDSQ:
extendDSQ(buf, wordOpen, ch)
default:
return
}
}
func getWordOpenStr(wtype string, qc QuoteContext) string {
if wtype == WordTypeLit {
if qc.cur() == WordTypeDQ {
return "\""
} else {
return ""
}
}
wmeta := wordMetaMap[wtype]
return wmeta.getPrefix()
}
// lit, svar, varb sq, dsq
func extendLeaf(buf *bytes.Buffer, wordOpen *bool, word *WordType, wordPos int, extStr string) {
for _, ch := range extStr {
extendLeafCh(buf, wordOpen, word.Type, word.QC, ch)
}
}
// lit, grp, svar, dq, ddq, varb, sq, dsq
// returns (strwithpos, dq-closed)
func extendInternal(word *WordType, wordPos int, extStr string, complete bool, requiresQuoteBalance bool) (utilfn.StrWithPos, bool) {
if extStr == "" {
return utilfn.StrWithPos{Str: string(word.Raw), Pos: wordPos}, true
}
if word.canHaveSubs() {
return extendWithSubs(word, wordPos, extStr, complete), true
}
var buf bytes.Buffer
isEOW := wordPos >= word.contentEndPos()
if isEOW {
wordPos = word.contentEndPos()
}
if wordPos < word.contentStartPos() {
wordPos = word.contentStartPos()
}
if wordPos > 0 {
buf.WriteString(string(word.Raw[0:word.contentStartPos()])) // write the prefix
}
if wordPos > word.contentStartPos() {
buf.WriteString(string(word.Raw[word.contentStartPos():wordPos]))
}
wordOpen := true
extendLeaf(&buf, &wordOpen, word, wordPos, extStr)
if isEOW {
// end-of-word, write the suffix (and optional ' '). return the end of the string
wmeta := wordMetaMap[word.Type]
rtnPos := utf8.RuneCount(buf.Bytes())
buf.WriteString(wmeta.getSuffix())
if !wordOpen && requiresQuoteBalance {
buf.WriteString(getWordOpenStr(word.Type, word.QC))
wordOpen = true
}
if complete {
buf.WriteRune(' ')
return utilfn.StrWithPos{Str: buf.String(), Pos: utf8.RuneCount(buf.Bytes())}, wordOpen
} else {
return utilfn.StrWithPos{Str: buf.String(), Pos: rtnPos}, wordOpen
}
}
// completion in the middle of a word (no ' ')
rtnPos := utf8.RuneCount(buf.Bytes())
if !wordOpen {
// always required since there is a suffix
buf.WriteString(getWordOpenStr(word.Type, word.QC))
wordOpen = true
}
buf.WriteString(string(word.Raw[wordPos:])) // write the suffix
return utilfn.StrWithPos{Str: buf.String(), Pos: rtnPos}, wordOpen
}
// lit, grp, svar, dq, ddq, varb, sq, dsq
func Extend(word *WordType, wordPos int, extStr string, complete bool) utilfn.StrWithPos {
rtn, _ := extendInternal(word, wordPos, extStr, complete, false)
return rtn
}
func (ec *extendContext) extend(ch rune) {
if ch == 0 {
return
}
return
}
func isVarNameChar(ch rune) bool {
return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9')
}
func extendVar(buf *bytes.Buffer, ch rune) {
if ch == 0 {
return
}
if !isVarNameChar(ch) {
return
}
buf.WriteRune(ch)
}
func getSpecialEscape(ch rune) string {
if ch > unicode.MaxASCII {
return ""
}
return specialEsc[byte(ch)]
}
func writeSpecial(buf *bytes.Buffer, ch rune, wrap bool) {
if wrap {
buf.WriteRune('$')
buf.WriteRune('\'')
}
sesc := getSpecialEscape(ch)
if sesc != "" {
buf.WriteString(sesc)
} else {
utf8Lit := getUtf8Literal(ch)
buf.WriteString(utf8Lit)
}
if wrap {
buf.WriteRune('\'')
}
}
func extendLit(buf *bytes.Buffer, ch rune) {
if ch == 0 {
return
}
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
writeSpecial(buf, ch, true)
return
}
var bch = byte(ch)
if noEscChars[bch] {
buf.WriteRune(ch)
return
}
buf.WriteRune('\\')
buf.WriteRune(ch)
return
}
func extendDSQ(buf *bytes.Buffer, wordOpen *bool, ch rune) {
if ch == 0 {
return
}
if !*wordOpen {
buf.WriteRune('$')
buf.WriteRune('\'')
*wordOpen = true
}
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
writeSpecial(buf, ch, false)
return
}
if ch == '\'' {
buf.WriteRune('\\')
buf.WriteRune(ch)
return
}
buf.WriteRune(ch)
return
}
func extendSQ(buf *bytes.Buffer, wordOpen *bool, ch rune) {
if ch == 0 {
return
}
if ch == '\'' {
if *wordOpen {
buf.WriteRune('\'')
*wordOpen = false
}
buf.WriteRune('\\')
buf.WriteRune('\'')
return
}
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
if *wordOpen {
buf.WriteRune('\'')
*wordOpen = false
}
writeSpecial(buf, ch, true)
return
}
if !*wordOpen {
buf.WriteRune('\'')
*wordOpen = true
}
buf.WriteRune(ch)
return
}
func extendDQLit(buf *bytes.Buffer, wordOpen *bool, ch rune) {
if ch == 0 {
return
}
if ch > unicode.MaxASCII || !unicode.IsPrint(ch) {
if *wordOpen {
buf.WriteRune('"')
*wordOpen = false
}
writeSpecial(buf, ch, true)
return
}
if !*wordOpen {
buf.WriteRune('"')
*wordOpen = true
}
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
buf.WriteRune('\\')
buf.WriteRune(ch)
return
}
buf.WriteRune(ch)
return
}

View File

@ -0,0 +1,693 @@
package shparse
import (
"bytes"
"fmt"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
)
//
// cmds := cmd (sep cmd)*
// sep := ';' | '&' | '&&' | '||' | '|' | '\n'
// cmd := simple-cmd | compound-command redirect-list?
// compound-command := brace-group | subshell | for-clause | case-clause | if-clause | while-clause | until-clause
// brace-group := '{' cmds '}'
// subshell := '(' cmds ')'
// simple-command := cmd-prefix cmd-word (io-redirect)*
// cmd-prefix := (io-redirect | assignment)*
// cmd-suffix := (io-redirect | word)*
// cmd-name := word
// cmd-word := word
// io-redirect := (io-number? io-file) | (io-number? io-here)
// io-file := ('<' | '<&' | '>' | '>&' | '>>' | '>|' ) filename
// io-here := ('<<' | '<<-') here_end
// here-end := word
// if-clause := 'if' compound-list 'then' compound-list else-part 'fi'
// else-part := 'elif' compound-list 'then' compound-list
// | 'elif' compount-list 'then' compound-list else-part
// | 'else' compound-list
// compound-list := linebreak term sep?
//
//
//
// A correctly-formed brace expansion must contain unquoted opening and closing braces, and at least one unquoted comma or a valid sequence expression
// Any incorrectly formed brace expansion is left unchanged.
//
// ambiguity between $((...)) and $((ls); ls)
// ambiguity between foo=([0]=hell) and foo=([abc)
// tokenization https://pubs.opengroup.org/onlinepubs/7908799/xcu/chap2.html#tag_001_003
// can-extend: WordTypeLit, WordTypeSimpleVar, WordTypeVarBrace, WordTypeDQ, WordTypeDDQ, WordTypeSQ, WordTypeDSQ
const (
WordTypeRaw = "raw"
WordTypeLit = "lit" // (can-extend)
WordTypeOp = "op" // single: & ; | ( ) < > \n multi(2): && || ;; << >> <& >& <> >| (( multi(3): <<- ('((' requires special processing)
WordTypeKey = "key" // if then else elif fi do done case esac while until for in { } ! (( [[
WordTypeGroup = "grp" // contains other words e.g. "hello"foo'bar'$x (has-subs) (can-extend)
WordTypeSimpleVar = "svar" // simplevar $ (can-extend)
WordTypeDQ = "dq" // " (quote-context) (can-extend) (has-subs)
WordTypeDDQ = "ddq" // $" (can-extend) (has-subs) (for quotecontext, uses WordTypeDQ)
WordTypeVarBrace = "varb" // ${ (quote-context) (can-extend) (internals not parsed)
WordTypeDP = "dp" // $( (quote-context) (has-subs)
WordTypeBQ = "bq" // ` (quote-context) (has-subs)
WordTypeSQ = "sq" // ' (can-extend)
WordTypeDSQ = "dsq" // $' (can-extend)
WordTypeDPP = "dpp" // $(( (internals not parsed)
WordTypePP = "pp" // (( (internals not parsed)
WordTypeDB = "db" // $[ (internals not parsed)
)
const (
CmdTypeNone = "none" // holds control structures: '(' ')' 'for' 'while' etc.
CmdTypeSimple = "simple" // holds real commands
)
type WordType struct {
Type string
Offset int
QC QuoteContext
Raw []rune
Complete bool
Prefix []rune
Subs []*WordType
}
type CmdType struct {
Type string
AssignmentWords []*WordType
Words []*WordType
NoneComplete bool // set to true when last-word is a "separator"
}
type QuoteContext []string
var wordMetaMap map[string]wordMeta
// same order as https://www.gnu.org/software/bash/manual/html_node/Reserved-Words.html
var bashReservedWords = []string{
"if", "then", "elif", "else", "fi", "time",
"for", "in", "until", "while", "do", "done",
"case", "esac", "coproc", "select", "function",
"{", "}", "[[", "]]", "!",
}
// special reserved words: "for", "in", "case", "select", "function", "[[", and "]]"
var bashNoneRW = []string{
"if", "then",
"elif", "else", "fi", "time",
"until", "while", "do", "done",
"esac", "coproc",
"{", "}", "!",
}
type wordMeta struct {
Type string
EmptyWord []rune
PrefixLen int
SuffixLen int
CanExtend bool
QuoteContext bool
}
func (m wordMeta) getSuffix() string {
if m.SuffixLen == 0 {
return ""
}
return string(m.EmptyWord[len(m.EmptyWord)-m.SuffixLen:])
}
func (m wordMeta) getPrefix() string {
if m.PrefixLen == 0 {
return ""
}
return string(m.EmptyWord[:m.PrefixLen])
}
func makeWordMeta(wtype string, emptyWord string, prefixLen int, suffixLen int, canExtend bool, quoteContext bool) {
if len(emptyWord) != prefixLen+suffixLen {
panic(fmt.Sprintf("invalid empty word %s %d %d", emptyWord, prefixLen, suffixLen))
}
wordMetaMap[wtype] = wordMeta{wtype, []rune(emptyWord), prefixLen, suffixLen, canExtend, quoteContext}
}
func init() {
wordMetaMap = make(map[string]wordMeta)
makeWordMeta(WordTypeRaw, "", 0, 0, false, false)
makeWordMeta(WordTypeLit, "", 0, 0, true, false)
makeWordMeta(WordTypeOp, "", 0, 0, false, false)
makeWordMeta(WordTypeKey, "", 0, 0, false, false)
makeWordMeta(WordTypeGroup, "", 0, 0, false, false)
makeWordMeta(WordTypeSimpleVar, "$", 1, 0, true, false)
makeWordMeta(WordTypeVarBrace, "${}", 2, 1, true, true)
makeWordMeta(WordTypeDQ, `""`, 1, 1, true, true)
makeWordMeta(WordTypeDDQ, `$""`, 2, 1, true, true)
makeWordMeta(WordTypeDP, "$()", 2, 1, false, false)
makeWordMeta(WordTypeBQ, "``", 1, 1, false, false)
makeWordMeta(WordTypeSQ, "''", 1, 1, true, false)
makeWordMeta(WordTypeDSQ, "$''", 2, 1, true, false)
makeWordMeta(WordTypeDPP, "$(())", 3, 2, false, false)
makeWordMeta(WordTypePP, "(())", 2, 2, false, false)
makeWordMeta(WordTypeDB, "$[]", 2, 1, false, false)
}
func MakeEmptyWord(wtype string, qc QuoteContext, offset int, complete bool) *WordType {
meta := wordMetaMap[wtype]
if meta.Type == "" {
meta = wordMetaMap[WordTypeRaw]
}
rtn := &WordType{Type: meta.Type, QC: qc, Offset: offset, Complete: complete}
if len(meta.EmptyWord) > 0 {
if complete {
rtn.Raw = append([]rune(nil), meta.EmptyWord...)
} else {
rtn.Raw = append([]rune(nil), []rune(meta.getPrefix())...)
}
}
return rtn
}
func (qc QuoteContext) push(q string) QuoteContext {
rtn := make([]string, 0, len(qc)+1)
rtn = append(rtn, qc...)
rtn = append(rtn, q)
return rtn
}
func (qc QuoteContext) cur() string {
if len(qc) == 0 {
return ""
}
return qc[len(qc)-1]
}
func (qc QuoteContext) clone() QuoteContext {
if len(qc) == 0 {
return nil
}
return append([]string(nil), qc...)
}
func makeRepeatStr(ch byte, slen int) string {
if slen == 0 {
return ""
}
rtn := make([]byte, slen)
for i := 0; i < slen; i++ {
rtn[i] = ch
}
return string(rtn)
}
func (w *WordType) isBlank() bool {
return w.Type == WordTypeLit && len(w.Raw) == 0
}
func (w *WordType) contentEndPos() int {
if !w.Complete {
return len(w.Raw)
}
wmeta := wordMetaMap[w.Type]
return len(w.Raw) - wmeta.SuffixLen
}
func (w *WordType) contentStartPos() int {
wmeta := wordMetaMap[w.Type]
return wmeta.PrefixLen
}
func (w *WordType) canHaveSubs() bool {
switch w.Type {
case WordTypeGroup, WordTypeDQ, WordTypeDDQ, WordTypeDP, WordTypeBQ:
return true
default:
return false
}
}
func (w *WordType) uncompletable() bool {
switch w.Type {
case WordTypeRaw, WordTypeOp, WordTypeKey, WordTypeDPP, WordTypePP, WordTypeDB, WordTypeBQ, WordTypeDP:
return true
default:
return false
}
}
func (w *WordType) stringWithPos(pos int) string {
notCompleteFlag := " "
if !w.Complete {
notCompleteFlag = "*"
}
str := string(w.Raw)
if pos != -1 {
str = utilfn.StrWithPos{Str: str, Pos: pos}.String()
}
return fmt.Sprintf("%-4s[%3d]%s %s%q", w.Type, w.Offset, notCompleteFlag, makeRepeatStr('_', len(w.Prefix)), str)
}
func (w *WordType) String() string {
notCompleteFlag := " "
if !w.Complete {
notCompleteFlag = "*"
}
return fmt.Sprintf("%-4s[%3d]%s %s%q", w.Type, w.Offset, notCompleteFlag, makeRepeatStr('_', len(w.Prefix)), string(w.Raw))
}
// offset = -1 for don't show
func dumpWords(words []*WordType, indentStr string, offset int) {
wrotePos := false
for _, word := range words {
posInWord := false
if !wrotePos && offset != -1 && offset <= word.Offset {
fmt.Printf("%s* [%3d] [*]\n", indentStr, offset)
wrotePos = true
}
if !wrotePos && offset != -1 && offset < word.Offset+len(word.Raw) {
fmt.Printf("%s%s\n", indentStr, word.stringWithPos(offset-word.Offset))
wrotePos = true
posInWord = true
} else {
fmt.Printf("%s%s\n", indentStr, word.String())
}
if len(word.Subs) > 0 {
if posInWord {
wmeta := wordMetaMap[word.Type]
dumpWords(word.Subs, indentStr+" ", offset-word.Offset-wmeta.PrefixLen)
} else {
dumpWords(word.Subs, indentStr+" ", -1)
}
}
}
}
func dumpCommands(cmds []*CmdType, indentStr string, pos int) {
for _, cmd := range cmds {
fmt.Printf("%sCMD: %s [%d] pos:%d\n", indentStr, cmd.Type, len(cmd.Words), pos)
dumpWords(cmd.AssignmentWords, indentStr+" *", pos)
dumpWords(cmd.Words, indentStr+" ", pos)
}
}
func wordsToStr(words []*WordType) string {
var buf bytes.Buffer
for _, word := range words {
if len(word.Prefix) > 0 {
buf.WriteString(string(word.Prefix))
}
buf.WriteString(string(word.Raw))
}
return buf.String()
}
// recognizes reserved words in first position
func convertToAnyReservedWord(w *WordType) bool {
if w == nil || w.Type != WordTypeLit {
return false
}
rawVal := string(w.Raw)
for _, rw := range bashReservedWords {
if rawVal == rw {
w.Type = WordTypeKey
return true
}
}
return false
}
// recognizes the specific reserved-word given only ('in' and 'do' in 'for', 'case', and 'select' commands)
func convertToReservedWord(w *WordType, reservedWord string) {
if w == nil || w.Type != WordTypeLit {
return
}
if string(w.Raw) == reservedWord {
w.Type = WordTypeKey
}
}
func isNoneReservedWord(w *WordType) bool {
if w.Type != WordTypeKey {
return false
}
rawVal := string(w.Raw)
for _, rw := range bashNoneRW {
if rawVal == rw {
return true
}
}
return false
}
type parseCmdState struct {
Input []*WordType
InputPos int
Rtn []*CmdType
Cur *CmdType
}
func (state *parseCmdState) isEof() bool {
return state.InputPos >= len(state.Input)
}
func (state *parseCmdState) curWord() *WordType {
if state.isEof() {
return nil
}
return state.Input[state.InputPos]
}
func (state *parseCmdState) lastCmd() *CmdType {
if len(state.Rtn) == 0 {
return nil
}
return state.Rtn[len(state.Rtn)-1]
}
func (state *parseCmdState) makeNoneCmd(sep bool) {
if state.Cur == nil || state.Cur.Type != CmdTypeNone {
state.Cur = &CmdType{Type: CmdTypeNone}
state.Rtn = append(state.Rtn, state.Cur)
}
state.Cur.Words = append(state.Cur.Words, state.curWord())
if sep {
state.Cur.NoneComplete = true
state.Cur = nil
}
state.InputPos++
}
func (state *parseCmdState) handleKeyword(word *WordType) bool {
if word.Type != WordTypeKey {
return false
}
if isNoneReservedWord(word) {
state.makeNoneCmd(true)
return true
}
rw := string(word.Raw)
if rw == "[[" {
// just ignore everything between [[ and ]]
for !state.isEof() {
curWord := state.curWord()
if curWord.Type == WordTypeLit && string(curWord.Raw) == "]]" {
convertToReservedWord(curWord, "]]")
state.makeNoneCmd(false)
break
}
state.makeNoneCmd(false)
}
return true
}
if rw == "case" {
// ignore everything between "case" and "esac"
for !state.isEof() {
curWord := state.curWord()
if curWord.Type == WordTypeKey && string(curWord.Raw) == "esac" {
state.makeNoneCmd(false)
break
}
state.makeNoneCmd(false)
}
return true
}
if rw == "for" || rw == "select" {
// ignore until a "do"
for !state.isEof() {
curWord := state.curWord()
if curWord.Type == WordTypeKey && string(curWord.Raw) == "do" {
state.makeNoneCmd(true)
break
}
state.makeNoneCmd(false)
}
return true
}
if rw == "in" {
// the "for" and "case" clauses should skip "in". so encountering an "in" here is a syntax error.
// just treat it as a none and allow a new command after.
state.makeNoneCmd(false)
return true
}
if rw == "function" {
// ignore until '{'
for !state.isEof() {
curWord := state.curWord()
if curWord.Type == WordTypeKey && string(curWord.Raw) == "{" {
state.makeNoneCmd(true)
break
}
state.makeNoneCmd(false)
}
return true
}
state.makeNoneCmd(true)
return true
}
func isCmdSeparatorOp(word *WordType) bool {
if word.Type != WordTypeOp {
return false
}
opVal := string(word.Raw)
return opVal == ";" || opVal == "\n" || opVal == "&" || opVal == "|" || opVal == "|&" || opVal == "&&" || opVal == "||" || opVal == "(" || opVal == ")"
}
func (state *parseCmdState) handleOp(word *WordType) bool {
opVal := string(word.Raw)
// sequential separators
if opVal == ";" || opVal == "\n" {
state.makeNoneCmd(true)
return true
}
// separator
if opVal == "&" {
state.makeNoneCmd(true)
return true
}
// pipelines
if opVal == "|" || opVal == "|&" {
state.makeNoneCmd(true)
return true
}
// lists
if opVal == "&&" || opVal == "||" {
state.makeNoneCmd(true)
return true
}
// subshell
if opVal == "(" || opVal == ")" {
state.makeNoneCmd(true)
return true
}
return false
}
func wordSliceBoundedIdx(words []*WordType, idx int) *WordType {
if idx >= len(words) {
return nil
}
return words[idx]
}
// note that a newline "op" can appear in the third position of "for" or "case". the "in" keyword is still converted because of wordNum == 0
func identifyReservedWords(words []*WordType) {
wordNum := 0
lastReserved := false
for idx, word := range words {
if wordNum == 0 || lastReserved {
convertToAnyReservedWord(word)
}
if word.Type == WordTypeKey {
rwVal := string(word.Raw)
switch rwVal {
case "for":
lastReserved = false
third := wordSliceBoundedIdx(words, idx+2)
convertToReservedWord(third, "in")
convertToReservedWord(third, "do")
case "case":
lastReserved = false
third := wordSliceBoundedIdx(words, idx+2)
convertToReservedWord(third, "in")
case "in":
lastReserved = false
default:
lastReserved = true
}
continue
}
lastReserved = false
if isCmdSeparatorOp(word) {
wordNum = 0
continue
}
wordNum++
}
}
func ResetWordOffsets(words []*WordType, startIdx int) {
pos := startIdx
for _, word := range words {
pos += len(word.Prefix)
word.Offset = pos
if len(word.Subs) > 0 {
ResetWordOffsets(word.Subs, 0)
}
pos += len(word.Raw)
}
}
func CommandsToWords(cmds []*CmdType) []*WordType {
var rtn []*WordType
for _, cmd := range cmds {
rtn = append(rtn, cmd.Words...)
}
return rtn
}
func (c *CmdType) stripPrefix() []rune {
if len(c.AssignmentWords) > 0 {
w := c.AssignmentWords[0]
prefix := w.Prefix
if len(prefix) == 0 {
return nil
}
newWord := *w
newWord.Prefix = nil
c.AssignmentWords[0] = &newWord
return prefix
}
if len(c.Words) > 0 {
w := c.Words[0]
prefix := w.Prefix
if len(prefix) == 0 {
return nil
}
newWord := *w
newWord.Prefix = nil
c.Words[0] = &newWord
return prefix
}
return nil
}
func (c *CmdType) isEmpty() bool {
return len(c.AssignmentWords) == 0 && len(c.Words) == 0
}
func (c *CmdType) lastWord() *WordType {
if len(c.Words) > 0 {
return c.Words[len(c.Words)-1]
}
if len(c.AssignmentWords) > 0 {
return c.AssignmentWords[len(c.AssignmentWords)-1]
}
return nil
}
func (c *CmdType) firstWord() *WordType {
if len(c.AssignmentWords) > 0 {
return c.AssignmentWords[0]
}
if len(c.Words) > 0 {
return c.Words[0]
}
return nil
}
func (c *CmdType) offset() int {
firstWord := c.firstWord()
if firstWord == nil {
return 0
}
return firstWord.Offset
}
func (c *CmdType) endOffset() int {
lastWord := c.lastWord()
if lastWord == nil {
return 0
}
return lastWord.Offset + len(lastWord.Raw)
}
func indexInRunes(arr []rune, ch rune) int {
for idx, r := range arr {
if r == ch {
return idx
}
}
return -1
}
func isAssignmentWord(w *WordType) bool {
if w.Type == WordTypeLit || w.Type == WordTypeGroup {
eqIdx := indexInRunes(w.Raw, '=')
if eqIdx == -1 {
return false
}
prefix := w.Raw[0:eqIdx]
return isSimpleVarName(prefix)
}
return false
}
// simple commands steal whitespace from subsequent commands
func cmdWhitespaceFixup(cmds []*CmdType) {
for idx := 0; idx < len(cmds)-1; idx++ {
cmd := cmds[idx]
if cmd.Type != CmdTypeSimple || cmd.isEmpty() {
continue
}
nextCmd := cmds[idx+1]
nextPrefix := nextCmd.stripPrefix()
if len(nextPrefix) > 0 {
blankWord := &WordType{Type: WordTypeLit, QC: cmd.lastWord().QC, Offset: cmd.endOffset() + len(nextPrefix), Prefix: nextPrefix, Complete: true}
cmd.Words = append(cmd.Words, blankWord)
}
}
}
func ParseCommands(words []*WordType) []*CmdType {
identifyReservedWords(words)
state := parseCmdState{Input: words}
for {
if state.isEof() {
break
}
word := state.curWord()
if word.Type == WordTypeKey {
done := state.handleKeyword(word)
if done {
continue
}
}
if word.Type == WordTypeOp {
done := state.handleOp(word)
if done {
continue
}
}
if state.Cur == nil || state.Cur.Type != CmdTypeSimple {
state.Cur = &CmdType{Type: CmdTypeSimple}
state.Rtn = append(state.Rtn, state.Cur)
}
if len(state.Cur.Words) == 0 && isAssignmentWord(word) {
state.Cur.AssignmentWords = append(state.Cur.AssignmentWords, word)
} else {
state.Cur.Words = append(state.Cur.Words, word)
}
state.InputPos++
}
cmdWhitespaceFixup(state.Rtn)
return state.Rtn
}

View File

@ -0,0 +1,219 @@
package shparse
import (
"fmt"
"testing"
"github.com/commandlinedev/prompt-server/pkg/utilfn"
)
// $(ls f[*]); ./x
// ls f => raw["ls f"] -> lit["ls f"] -> lit["ls"] lit["f"]
// w; ls foo; => raw["w; ls foo;"]
// ls&"ls" => raw["ls&ls"] => lit["ls&"] dq["ls"] => lit["ls"] key["&"] dq["ls"]
// ls $x; echo `ls f => raw["ls $x; echo `ls f"]
// > echo $foo{x,y}
func testParse(t *testing.T, s string) {
words := Tokenize(s)
fmt.Printf("parse <<\n%s\n>>\n", s)
dumpWords(words, " ", 8)
outStr := wordsToStr(words)
if outStr != s {
t.Errorf("tokenization output does not match input: %q => %q", s, outStr)
}
fmt.Printf("------\n\n")
}
func Test1(t *testing.T) {
testParse(t, "ls")
testParse(t, "ls 'foo'")
testParse(t, `ls "hello" $'\''`)
testParse(t, `ls "foo`)
testParse(t, `echo $11 $xyz $ `)
testParse(t, `echo $(ls ${x:"hello"} foo`)
testParse(t, `ls ${x:"hello"} $[2+2] $((5 * 10)) $(ls; ls&)`)
testParse(t, `ls;ls&./foo > out 2> "out2"`)
testParse(t, `(( x = 5)); ls& cd ~/work/"hello again"`)
testParse(t, `echo "hello"abc$(ls)$x${y:foo}`)
testParse(t, `echo $(ls; ./x "foo")`)
testParse(t, `echo $(ls; (cd foo; ls); (cd bar; ls))xyz`)
testParse(t, `echo "$x ${y:-foo}"`)
testParse(t, `command="$(echo "$input" | sed -e "s/^[ \t]*\([^ \t]*\)[ \t]*.*$/\1/g")"`)
testParse(t, `echo $(ls $)`)
testParse(t, `echo ${x:-hello\}"}"} 2nd`)
testParse(t, `echo "$(ls "foo") more $x"`)
testParse(t, "echo `ls $x \"hello $x\" \\`ls\\`; ./foo`")
testParse(t, `echo $"hello $x $(ls)"`)
testParse(t, "echo 'hello'\nls\n")
testParse(t, "echo 'hello'abc$'\a'")
}
func lastWord(words []*WordType) *WordType {
if len(words) == 0 {
return nil
}
return words[len(words)-1]
}
func testExtend(t *testing.T, startStr string, extendStr string, complete bool, expStr string) {
startSP := utilfn.ParseToSP(startStr)
words := Tokenize(startSP.Str)
word := findCompletionWordAtPos(words, startSP.Pos, true)
if word == nil {
word = MakeEmptyWord(WordTypeLit, nil, startSP.Pos, true)
}
outSP := Extend(word, startSP.Pos-word.Offset, extendStr, complete)
expSP := utilfn.ParseToSP(expStr)
fmt.Printf("extend: [%s] + %q => [%s]\n", startStr, extendStr, outSP)
if outSP != expSP {
t.Errorf("extension does not match: [%s] + %q => [%s] expected [%s]\n", startStr, extendStr, outSP, expSP)
}
}
func Test2(t *testing.T) {
testExtend(t, `he[*]`, "llo", false, "hello[*]")
testExtend(t, `he[*]`, "llo", true, "hello [*]")
testExtend(t, `'mi[*]e`, "k", false, "'mik[*]e")
testExtend(t, `'mi[*]e`, "k", true, "'mik[*]e")
testExtend(t, `'mi[*]'`, "ke", true, "'mike' [*]")
testExtend(t, `'mi'[*]`, "ke", true, "'mike' [*]")
testExtend(t, `'mi[*]'`, "ke", false, "'mike[*]'")
testExtend(t, `'mi'[*]`, "ke", false, "'mike[*]'")
testExtend(t, `$f[*]`, "oo", false, "$foo[*]")
testExtend(t, `${f}[*]`, "oo", false, "${foo[*]}")
testExtend(t, `${f[*]}`, "oo", true, "${foo} [*]")
testExtend(t, `[*]`, "more stuff", false, `more\ stuff[*]`)
testExtend(t, `[*]`, "hello\amike", false, `hello$'\a'mike[*]`)
testExtend(t, `$'he[*]'`, "\x01\x02\x0a", true, `$'he\x01\x02\n' [*]`)
testExtend(t, `${x}\ [*]ll$y`, "e", false, `${x}\ e[*]ll$y`)
testExtend(t, `"he[*]"`, "$$o", true, `"he\$\$o" [*]`)
testExtend(t, `"h[*]llo"`, "e", false, `"he[*]llo"`)
testExtend(t, `"h[*]llo"`, "e", true, `"he[*]llo"`)
testExtend(t, `"[*]${h}llo"`, "e\x01", true, `"e"$'\x01'[*]"${h}llo"`)
testExtend(t, `"${h}llo[*]"`, "e\x01", true, `"${h}lloe"$'\x01' [*]`)
testExtend(t, `"${h}llo[*]"`, "e\x01", false, `"${h}lloe"$'\x01'[*]`)
testExtend(t, `"${h}ll[*]o"`, "e\x01", false, `"${h}lle"$'\x01'[*]"o"`)
testExtend(t, `"ab[*]c${x}def"`, "\x01", false, `"ab"$'\x01'[*]"c${x}def"`)
testExtend(t, `'ab[*]ef'`, "\x01", false, `'ab'$'\x01'[*]'ef'`)
// testExtend(t, `'he'`, "llo", `'hello'`)
// testExtend(t, `'he'`, "'", `'he'\'''`)
// testExtend(t, `'he'`, "'\x01", `'he'\'$'\x01'''`)
// testExtend(t, `he`, "llo", `hello`)
// testExtend(t, `he`, "l*l'\x01\x07o", `hel\*l\'$'\x01'$'\a'o`)
// testExtend(t, `$x`, "fo|o", `$xfoo`)
// testExtend(t, `${x`, "fo|o", `${xfoo`)
// testExtend(t, `$'f`, "oo", `$'foo`)
// testExtend(t, `$'f`, "'\x01\x07o", `$'f\'\x01\ao`)
// testExtend(t, `"f"`, "oo", `"foo"`)
// testExtend(t, `"mi"`, "ke's \"hello\"", `"mike's \"hello\""`)
// testExtend(t, `"t"`, "t\x01\x07", `"tt"$'\x01'$'\a'""`)
}
func testParseCommands(t *testing.T, str string) {
fmt.Printf("parse: %q\n", str)
words := Tokenize(str)
cmds := ParseCommands(words)
dumpCommands(cmds, " ", -1)
fmt.Printf("\n")
}
func TestCmd(t *testing.T) {
testParseCommands(t, "ls foo")
testParseCommands(t, "function foo () { echo hello; }")
testParseCommands(t, "ls foo && ls bar; ./run $x hello | xargs foo; ")
testParseCommands(t, "if [[ 2 > 1 ]]; then echo hello\nelse echo world; echo next; done")
testParseCommands(t, "case lots of stuff; i don\\'t know how to parse; esac; ls foo")
testParseCommands(t, "(ls & ./x \n \n); for x in $vars 3; do { echo $x; ls foo ; } done")
testParseCommands(t, `ls f"oo" "${x:"hello$y"}"`)
testParseCommands(t, `x="foo $y" z=10 ls`)
}
func testCompPos(t *testing.T, cmdStr string, compType string, hasCommand bool, cmdWordPos int, hasWord bool, superOffset int) {
cmdSP := utilfn.ParseToSP(cmdStr)
words := Tokenize(cmdSP.Str)
cmds := ParseCommands(words)
cpos := FindCompletionPos(cmds, cmdSP.Pos)
fmt.Printf("testCompPos [%d] %q => [%s] %v\n", cmdSP.Pos, cmdStr, cpos.CompType, cpos)
if cpos.CompType != compType {
t.Errorf("testCompPos %q => invalid comp-type %q, expected %q", cmdStr, cpos.CompType, compType)
}
if cpos.CompWord != nil {
fmt.Printf(" found-word: %d %s\n", cpos.CompWordOffset, cpos.CompWord.stringWithPos(cpos.CompWordOffset))
}
if cpos.Cmd != nil {
fmt.Printf(" found-cmd: ")
dumpCommands([]*CmdType{cpos.Cmd}, " ", cpos.RawPos)
}
dumpCommands(cmds, " ", cmdSP.Pos)
fmt.Printf("\n")
if cpos.RawPos+cpos.SuperOffset != cmdSP.Pos {
t.Errorf("testCompPos %q => bad rawpos:%d superoffset:%d expected:%d", cmdStr, cpos.RawPos, cpos.SuperOffset, cmdSP.Pos)
}
if (cpos.Cmd != nil) != hasCommand {
t.Errorf("testCompPos %q => bad has-command exp:%v", cmdStr, hasCommand)
}
if (cpos.CompWord != nil) != hasWord {
t.Errorf("testCompPos %q => bad has-word exp:%v", cmdStr, hasWord)
}
if cpos.CmdWordPos != cmdWordPos {
t.Errorf("testCompPos %q => bad cmd-word-pos got:%d exp:%d", cmdStr, cpos.CmdWordPos, cmdWordPos)
}
if cpos.SuperOffset != superOffset {
t.Errorf("testCompPos %q => bad super-offset got:%d exp:%d", cmdStr, cpos.SuperOffset, superOffset)
}
}
func TestCompPos(t *testing.T) {
testCompPos(t, "ls [*]foo", CompTypeArg, true, 1, false, 0)
testCompPos(t, "ls foo [*];", CompTypeArg, true, 2, false, 0)
testCompPos(t, "ls foo ;[*]", CompTypeCommand, false, 0, false, 0)
testCompPos(t, "ls foo >[*]> ./bar", CompTypeInvalid, true, 2, true, 0)
testCompPos(t, "l[*]s", CompTypeCommand, true, 0, true, 0)
testCompPos(t, "ls[*]", CompTypeCommand, true, 0, true, 0)
testCompPos(t, "x=10 { (ls ./f[*] more); ls }", CompTypeArg, true, 1, true, 0)
testCompPos(t, "for x in 1[*] 2 3; do ", CompTypeBasic, false, 0, true, 0)
testCompPos(t, "for[*] x in 1 2 3;", CompTypeInvalid, false, 0, true, 0)
testCompPos(t, `ls "abc $(ls -l t[*])" && foo`, CompTypeArg, true, 2, true, 10)
testCompPos(t, "ls ${abc:$(ls -l [*])}", CompTypeVar, false, 0, true, 0) // we don't sub-parse inside of ${} (so this returns "var" right now)
testCompPos(t, `ls abc"$(ls $"echo $(ls ./[*]x) foo)" `, CompTypeArg, true, 1, true, 21)
testCompPos(t, `ls "abc$d[*]"`, CompTypeVar, false, 0, true, 4)
testCompPos(t, `ls "abc$d$'a[*]`, CompTypeArg, true, 1, true, 0)
testCompPos(t, `ls $[*]'foo`, CompTypeArg, true, 1, true, 0)
testCompPos(t, `echo $TE[*]`, CompTypeVar, false, 0, true, 0)
}
func testExpand(t *testing.T, str string, pos int, expStr string, expInfo *ExpandInfo) {
ectx := ExpandContext{HomeDir: "/Users/mike"}
words := Tokenize(str)
if len(words) == 0 {
t.Errorf("could not tokenize any words from %q", str)
return
}
word := words[0]
output, info := SimpleExpandPrefix(ectx, word, pos)
if output != expStr {
t.Errorf("error expanding %q, output:%q exp:%q", str, output, expStr)
} else {
fmt.Printf("expand: %q (%d) => %q\n", str, pos, output)
}
if expInfo != nil {
if info != *expInfo {
t.Errorf("error expanding %q, info:%v exp:%v", str, info, expInfo)
}
}
}
func TestExpand(t *testing.T) {
testExpand(t, "hello", 3, "hel", nil)
testExpand(t, "he\\$xabc", 6, "he$xa", nil)
testExpand(t, "he${x}abc", 6, "he${x}", nil)
testExpand(t, "'hello\"mike'", 8, "hello\"m", nil)
testExpand(t, `$'abc\x01def`, 10, "abc\x01d", nil)
testExpand(t, `$((2 + 2))`, 6, "$((2 +", &ExpandInfo{HasSpecial: true})
testExpand(t, `abc"def"`, 6, "abcde", nil)
testExpand(t, `"abc$x$'"'""`, 12, "abc$x\"", nil)
testExpand(t, `'he'\''s'`, 9, "he's", nil)
}

View File

@ -0,0 +1,601 @@
package shparse
import (
"fmt"
"unicode"
)
// from bash source
//
// shell_meta_chars "()<>;&|"
//
type tokenizeOutputState struct {
Rtn []*WordType
CurWord *WordType
SavedPrefix []rune
}
func copyRunes(rarr []rune) []rune {
if len(rarr) == 0 {
return nil
}
return append([]rune(nil), rarr...)
}
// does not set CurWord
func (state *tokenizeOutputState) appendStandaloneWord(word *WordType) {
state.delimitCurWord()
if len(state.SavedPrefix) > 0 {
word.Prefix = state.SavedPrefix
state.SavedPrefix = nil
}
state.Rtn = append(state.Rtn, word)
}
func (state *tokenizeOutputState) appendWord(word *WordType) {
if len(state.SavedPrefix) > 0 {
word.Prefix = state.SavedPrefix
state.SavedPrefix = nil
}
if state.CurWord == nil {
state.CurWord = word
return
}
state.ensureGroupWord()
word.Offset = word.Offset - state.CurWord.Offset
state.CurWord.Subs = append(state.CurWord.Subs, word)
state.CurWord.Raw = append(state.CurWord.Raw, word.Raw...)
}
func (state *tokenizeOutputState) ensureGroupWord() {
if state.CurWord == nil {
panic("invalid state, cannot make group word when CurWord is nil")
}
if state.CurWord.Type == WordTypeGroup {
return
}
// moves the prefix from CurWord to the new group word, resets offsets
groupWord := &WordType{
Type: WordTypeGroup,
Offset: state.CurWord.Offset,
QC: state.CurWord.QC,
Raw: copyRunes(state.CurWord.Raw),
Complete: true,
Prefix: state.CurWord.Prefix,
}
state.CurWord.Prefix = nil
state.CurWord.Offset = 0
groupWord.Subs = []*WordType{state.CurWord}
state.CurWord = groupWord
}
func ungroupWord(groupWord *WordType) []*WordType {
if groupWord.Type != WordTypeGroup {
return []*WordType{groupWord}
}
rtn := groupWord.Subs
if len(groupWord.Prefix) > 0 && len(rtn) > 0 {
newPrefix := append([]rune{}, groupWord.Prefix...)
newPrefix = append(newPrefix, rtn[0].Prefix...)
rtn[0].Prefix = newPrefix
}
for _, word := range rtn {
word.Offset = word.Offset + groupWord.Offset
}
return rtn
}
func (state *tokenizeOutputState) ensureLitCurWord(pc *parseContext) {
if state.CurWord == nil {
state.CurWord = pc.makeWord(WordTypeLit, 0, true)
state.CurWord.Prefix = state.SavedPrefix
state.SavedPrefix = nil
return
}
if state.CurWord.Type == WordTypeLit {
return
}
state.ensureGroupWord()
lastWord := state.CurWord.Subs[len(state.CurWord.Subs)-1]
if lastWord.Type != WordTypeLit {
if len(state.SavedPrefix) > 0 {
panic("invalid state, there can be no saved prefix")
}
litWord := pc.makeWord(WordTypeLit, 0, true)
litWord.Offset = litWord.Offset - state.CurWord.Offset
state.CurWord.Subs = append(state.CurWord.Subs, litWord)
}
}
func (state *tokenizeOutputState) delimitCurWord() {
if state.CurWord != nil {
state.Rtn = append(state.Rtn, state.CurWord)
state.CurWord = nil
}
}
func (state *tokenizeOutputState) delimitWithSpace(spaceCh rune) {
state.delimitCurWord()
state.SavedPrefix = append(state.SavedPrefix, spaceCh)
}
func (state *tokenizeOutputState) appendLiteral(pc *parseContext, ch rune) {
state.ensureLitCurWord(pc)
if state.CurWord.Type == WordTypeLit {
state.CurWord.Raw = append(state.CurWord.Raw, ch)
} else if state.CurWord.Type == WordTypeGroup {
lastWord := state.CurWord.Subs[len(state.CurWord.Subs)-1]
if lastWord.Type != WordTypeLit {
panic(fmt.Sprintf("invalid curword type (group) %q", state.CurWord.Type))
}
lastWord.Raw = append(lastWord.Raw, ch)
state.CurWord.Raw = append(state.CurWord.Raw, ch)
} else {
panic(fmt.Sprintf("invalid curword type %q", state.CurWord.Type))
}
}
func (state *tokenizeOutputState) finish(pc *parseContext) {
state.delimitCurWord()
if len(state.SavedPrefix) > 0 {
state.ensureLitCurWord(pc)
state.delimitCurWord()
}
}
func (c *parseContext) tokenizeVarBrace() ([]*WordType, bool) {
state := &tokenizeOutputState{}
eofExit := false
for {
ch := c.cur()
if ch == 0 {
eofExit = true
break
}
if ch == '}' {
c.Pos++
break
}
var quoteWord *WordType
if ch == '\'' {
quoteWord = c.parseStrSQ()
}
if quoteWord == nil && ch == '"' {
quoteWord = c.parseStrDQ()
}
isNextBrace := c.at(1) == '}'
if quoteWord == nil && ch == '$' && !isNextBrace {
quoteWord = c.parseStrANSI()
if quoteWord == nil {
quoteWord = c.parseStrDDQ()
}
if quoteWord == nil {
quoteWord = c.parseExpansion()
}
}
if quoteWord != nil {
state.appendWord(quoteWord)
continue
}
if ch == '\\' && c.at(1) != 0 {
state.appendLiteral(c, ch)
state.appendLiteral(c, c.at(1))
c.Pos += 2
continue
}
state.appendLiteral(c, ch)
c.Pos++
}
return state.Rtn, eofExit
}
func (c *parseContext) tokenizeDQ() ([]*WordType, bool) {
state := &tokenizeOutputState{}
eofExit := false
for {
ch := c.cur()
if ch == 0 {
eofExit = true
break
}
if ch == '"' {
c.Pos++
break
}
if ch == '$' && c.at(1) != 0 {
quoteWord := c.parseStrANSI()
if quoteWord == nil {
quoteWord = c.parseStrDDQ()
}
if quoteWord == nil {
quoteWord = c.parseExpansion()
}
if quoteWord != nil {
state.appendWord(quoteWord)
continue
}
}
if ch == '\\' && c.at(1) != 0 {
state.appendLiteral(c, ch)
state.appendLiteral(c, c.at(1))
c.Pos += 2
continue
}
state.appendLiteral(c, ch)
c.Pos++
}
state.finish(c)
if len(state.Rtn) == 0 {
return nil, eofExit
}
if len(state.Rtn) == 1 && state.Rtn[0].Type == WordTypeGroup {
return ungroupWord(state.Rtn[0]), eofExit
}
return state.Rtn, eofExit
}
// returns (words, eofexit)
// backticks (WordTypeBQ) handle backslash in a special way, but that seems to mainly effect execution (not completion)
// de_backslash => removes initial backslash in \`, \\, and \$ before execution
func (c *parseContext) tokenizeRaw() ([]*WordType, bool) {
state := &tokenizeOutputState{}
isExpSubShell := c.QC.cur() == WordTypeDP
isInBQ := c.QC.cur() == WordTypeBQ
parenLevel := 0
eofExit := false
for {
ch := c.cur()
if ch == 0 {
eofExit = true
break
}
if isExpSubShell && ch == ')' && parenLevel == 0 {
c.Pos++
break
}
if isInBQ && ch == '`' {
c.Pos++
break
}
// fmt.Printf("ch %d %q\n", c.Pos, string([]rune{ch}))
foundOp, newOffset := c.parseOp(0)
if foundOp {
opVal := string(c.Input[c.Pos : c.Pos+newOffset])
if opVal == "(" {
arithWord := c.parseArith(true)
if arithWord != nil {
state.appendStandaloneWord(arithWord)
continue
} else {
parenLevel++
}
}
if opVal == ")" {
parenLevel--
}
opWord := c.makeWord(WordTypeOp, newOffset, true)
state.appendStandaloneWord(opWord)
continue
}
var quoteWord *WordType
if ch == '\'' {
quoteWord = c.parseStrSQ()
}
if quoteWord == nil && ch == '"' {
quoteWord = c.parseStrDQ()
}
if quoteWord == nil && ch == '`' {
quoteWord = c.parseStrBQ()
}
isNextParen := isExpSubShell && c.at(1) == ')'
if quoteWord == nil && ch == '$' && !isNextParen {
quoteWord = c.parseStrANSI()
if quoteWord == nil {
quoteWord = c.parseStrDDQ()
}
if quoteWord == nil {
quoteWord = c.parseExpansion()
}
}
if quoteWord != nil {
state.appendWord(quoteWord)
continue
}
if ch == '\\' && c.at(1) != 0 {
state.appendLiteral(c, ch)
state.appendLiteral(c, c.at(1))
c.Pos += 2
continue
}
if ch == '\n' {
newlineWord := c.makeWord(WordTypeOp, 1, true)
state.appendStandaloneWord(newlineWord)
continue
}
if unicode.IsSpace(ch) {
state.delimitWithSpace(ch)
c.Pos++
continue
}
state.appendLiteral(c, ch)
c.Pos++
}
state.finish(c)
return state.Rtn, eofExit
}
type parseContext struct {
Input []rune
Pos int
QC QuoteContext
}
func (c *parseContext) clone(pos int, newQuote string) *parseContext {
rtn := parseContext{Input: c.Input[pos:], QC: c.QC}
if newQuote != "" {
rtn.QC = rtn.QC.push(newQuote)
}
return &rtn
}
func (c *parseContext) at(offset int) rune {
pos := c.Pos + offset
if pos < 0 || pos >= len(c.Input) {
return 0
}
return c.Input[pos]
}
func (c *parseContext) eof() bool {
return c.Pos >= len(c.Input)
}
func (c *parseContext) cur() rune {
return c.at(0)
}
func (c *parseContext) match(ch rune) bool {
return c.at(0) == ch
}
func (c *parseContext) match2(ch rune, ch2 rune) bool {
return c.at(0) == ch && c.at(1) == ch2
}
func (c *parseContext) match3(ch rune, ch2 rune, ch3 rune) bool {
return c.at(0) == ch && c.at(1) == ch2 && c.at(2) == ch3
}
func (c *parseContext) makeWord(t string, length int, complete bool) *WordType {
rtn := &WordType{Type: t}
rtn.Offset = c.Pos
rtn.QC = c.QC
rtn.Raw = copyRunes(c.Input[c.Pos : c.Pos+length])
rtn.Complete = complete
c.Pos += length
return rtn
}
// returns (found, newOffset)
// shell_meta_chars "()<>;&|"
// possible to maybe add ;;& &>> &> |& ;&
func (c *parseContext) parseOp(offset int) (bool, int) {
ch := c.at(offset)
if ch == '(' || ch == ')' || ch == '<' || ch == '>' || ch == ';' || ch == '&' || ch == '|' {
ch2 := c.at(offset + 1)
if ch2 == 0 {
return true, offset + 1
}
r2 := string([]rune{ch, ch2})
if r2 == "<<" {
ch3 := c.at(offset + 2)
if ch3 == '-' || ch3 == '<' {
return true, offset + 3 // "<<-" or "<<<"
}
return true, offset + 2 // "<<"
}
if r2 == ">>" || r2 == "&&" || r2 == "||" || r2 == ";;" || r2 == "<<" || r2 == "<&" || r2 == ">&" || r2 == "<>" || r2 == ">|" {
// we don't return '((' here (requires special processing)
return true, offset + 2
}
return true, offset + 1
}
return false, 0
}
// returns (new-offset, complete)
func (c *parseContext) skipToChar(offset int, endCh rune, allowEsc bool) (int, bool) {
for {
ch := c.at(offset)
if ch == 0 {
return offset, false
}
if allowEsc && ch == '\\' {
if c.at(offset+1) == 0 {
return offset + 1, false
}
offset += 2
continue
}
if ch == endCh {
return offset + 1, true
}
offset++
}
}
// returns (new-offset, complete)
func (c *parseContext) skipToChar2(offset int, endCh rune, endCh2 rune, allowEsc bool) (int, bool) {
for {
ch := c.at(offset)
ch2 := c.at(offset + 1)
if ch == 0 {
return offset, false
}
if ch2 == 0 {
return offset + 1, false
}
if allowEsc && ch == '\\' {
offset += 2
continue
}
if ch == endCh && ch2 == endCh2 {
return offset + 2, true
}
offset++
}
}
func (c *parseContext) parseStrSQ() *WordType {
if !c.match('\'') {
return nil
}
newOffset, complete := c.skipToChar(1, '\'', false)
w := c.makeWord(WordTypeSQ, newOffset, complete)
return w
}
func (c *parseContext) parseStrDQ() *WordType {
if !c.match('"') {
return nil
}
newContext := c.clone(c.Pos+1, WordTypeDQ)
subWords, eofExit := newContext.tokenizeDQ()
newOffset := newContext.Pos + 1
w := c.makeWord(WordTypeDQ, newOffset, !eofExit)
w.Subs = subWords
return w
}
func (c *parseContext) parseStrDDQ() *WordType {
if !c.match2('$', '"') {
return nil
}
newContext := c.clone(c.Pos+2, WordTypeDQ) // use WordTypeDQ (not DDQ)
subWords, eofExit := newContext.tokenizeDQ()
newOffset := newContext.Pos + 2
w := c.makeWord(WordTypeDDQ, newOffset, !eofExit)
w.Subs = subWords
return w
}
func (c *parseContext) parseStrBQ() *WordType {
if !c.match('`') {
return nil
}
newContext := c.clone(c.Pos+1, WordTypeBQ)
subWords, eofExit := newContext.tokenizeRaw()
newOffset := newContext.Pos + 1
w := c.makeWord(WordTypeBQ, newOffset, !eofExit)
w.Subs = subWords
return w
}
func (c *parseContext) parseStrANSI() *WordType {
if !c.match2('$', '\'') {
return nil
}
newOffset, complete := c.skipToChar(2, '\'', true)
w := c.makeWord(WordTypeDSQ, newOffset, complete)
return w
}
func (c *parseContext) parseArith(mustComplete bool) *WordType {
if !c.match2('(', '(') {
return nil
}
newOffset, complete := c.skipToChar2(2, ')', ')', false)
if mustComplete && !complete {
return nil
}
w := c.makeWord(WordTypePP, newOffset, complete)
return w
}
func (c *parseContext) parseExpansion() *WordType {
if !c.match('$') {
return nil
}
if c.match3('$', '(', '(') {
newOffset, complete := c.skipToChar2(3, ')', ')', false)
w := c.makeWord(WordTypeDPP, newOffset, complete)
return w
}
if c.match2('$', '(') {
// subshell
newContext := c.clone(c.Pos+2, WordTypeDP)
subWords, eofExit := newContext.tokenizeRaw()
newOffset := newContext.Pos + 2
w := c.makeWord(WordTypeDP, newOffset, !eofExit)
w.Subs = subWords
return w
}
if c.match2('$', '[') {
// deprecated arith expansion
newOffset, complete := c.skipToChar(2, ']', false)
w := c.makeWord(WordTypeDB, newOffset, complete)
return w
}
if c.match2('$', '{') {
// variable expansion
newContext := c.clone(c.Pos+2, WordTypeVarBrace)
_, eofExit := newContext.tokenizeVarBrace()
newOffset := newContext.Pos + 2
w := c.makeWord(WordTypeVarBrace, newOffset, !eofExit)
return w
}
ch2 := c.at(1)
if ch2 == 0 || unicode.IsSpace(ch2) {
// no expansion
return nil
}
newOffset := c.parseSimpleVarName(1)
if newOffset > 1 {
// simple variable name
w := c.makeWord(WordTypeSimpleVar, newOffset, true)
return w
}
if ch2 == '*' || ch2 == '@' || ch2 == '#' || ch2 == '?' || ch2 == '-' || ch2 == '$' || ch2 == '!' || (ch2 >= '0' && ch2 <= '9') {
// single character variable name, e.g. $@, $_, $1, etc.
w := c.makeWord(WordTypeSimpleVar, 2, true)
return w
}
return nil
}
// returns newOffset
func (c *parseContext) parseSimpleVarName(offset int) int {
first := true
for {
ch := c.at(offset)
if ch == 0 {
return offset
}
if (ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) || (!first && ch >= '0' && ch <= '9') {
first = false
offset++
continue
}
return offset
}
}
func isSimpleVarName(rstr []rune) bool {
if len(rstr) == 0 {
return false
}
for idx, ch := range rstr {
if (ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) || ((idx != 0) && ch >= '0' && ch <= '9') {
continue
}
return false
}
return true
}
func Tokenize(cmd string) []*WordType {
c := &parseContext{Input: []rune(cmd)}
rtn, _ := c.tokenizeRaw()
return rtn
}

2654
wavesrv/pkg/sstore/dbops.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,184 @@
package sstore
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io/fs"
"log"
"os"
"path"
"github.com/commandlinedev/apishell/pkg/cirfile"
"github.com/commandlinedev/prompt-server/pkg/scbase"
"github.com/google/uuid"
)
func CreateCmdPtyFile(ctx context.Context, screenId string, lineId string, maxSize int64) error {
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
if err != nil {
return err
}
f, err := cirfile.CreateCirFile(ptyOutFileName, maxSize)
if err != nil {
return err
}
return f.Close()
}
func StatCmdPtyFile(ctx context.Context, screenId string, lineId string) (*cirfile.Stat, error) {
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
if err != nil {
return nil, err
}
return cirfile.StatCirFile(ctx, ptyOutFileName)
}
func AppendToCmdPtyBlob(ctx context.Context, screenId string, lineId string, data []byte, pos int64) (*PtyDataUpdate, error) {
if screenId == "" {
return nil, fmt.Errorf("cannot append to PtyBlob, screenid is not set")
}
if pos < 0 {
return nil, fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos)
}
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
if err != nil {
return nil, err
}
f, err := cirfile.OpenCirFile(ptyOutFileName)
if err != nil {
return nil, err
}
defer f.Close()
err = f.WriteAt(ctx, data, pos)
if err != nil {
return nil, err
}
data64 := base64.StdEncoding.EncodeToString(data)
update := &PtyDataUpdate{
ScreenId: screenId,
LineId: lineId,
PtyPos: pos,
PtyData64: data64,
PtyDataLen: int64(len(data)),
}
err = MaybeInsertPtyPosUpdate(ctx, screenId, lineId)
if err != nil {
// just log
log.Printf("error inserting ptypos update %s/%s: %v\n", screenId, lineId, err)
}
return update, nil
}
// returns (real-offset, data, err)
func ReadFullPtyOutFile(ctx context.Context, screenId string, lineId string) (int64, []byte, error) {
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
if err != nil {
return 0, nil, err
}
f, err := cirfile.OpenCirFile(ptyOutFileName)
if err != nil {
return 0, nil, err
}
defer f.Close()
return f.ReadAll(ctx)
}
// returns (real-offset, data, err)
func ReadPtyOutFile(ctx context.Context, screenId string, lineId string, offset int64, maxSize int64) (int64, []byte, error) {
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
if err != nil {
return 0, nil, err
}
f, err := cirfile.OpenCirFile(ptyOutFileName)
if err != nil {
return 0, nil, err
}
defer f.Close()
return f.ReadAtWithMax(ctx, offset, maxSize)
}
type SessionDiskSizeType struct {
NumFiles int
TotalSize int64
ErrorCount int
Location string
}
func directorySize(dirName string) (SessionDiskSizeType, error) {
var rtn SessionDiskSizeType
rtn.Location = dirName
entries, err := os.ReadDir(dirName)
if err != nil {
return rtn, err
}
for _, entry := range entries {
if entry.IsDir() {
rtn.ErrorCount++
continue
}
finfo, err := entry.Info()
if err != nil {
rtn.ErrorCount++
continue
}
rtn.NumFiles++
rtn.TotalSize += finfo.Size()
}
return rtn, nil
}
func SessionDiskSize(sessionId string) (SessionDiskSizeType, error) {
sessionDir, err := scbase.EnsureSessionDir(sessionId)
if err != nil {
return SessionDiskSizeType{}, err
}
return directorySize(sessionDir)
}
func FullSessionDiskSize() (map[string]SessionDiskSizeType, error) {
sdir := scbase.GetSessionsDir()
entries, err := os.ReadDir(sdir)
if err != nil {
return nil, err
}
rtn := make(map[string]SessionDiskSizeType)
for _, entry := range entries {
if !entry.IsDir() {
continue
}
name := entry.Name()
_, err = uuid.Parse(name)
if err != nil {
continue
}
diskSize, err := directorySize(path.Join(sdir, name))
if err != nil {
continue
}
rtn[name] = diskSize
}
return rtn, nil
}
func DeletePtyOutFile(ctx context.Context, screenId string, lineId string) error {
ptyOutFileName, err := scbase.PtyOutFile(screenId, lineId)
if err != nil {
return err
}
err = os.Remove(ptyOutFileName)
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
func DeleteScreenDir(ctx context.Context, screenId string) error {
screenDir, err := scbase.EnsureScreenDir(screenId)
if err != nil {
return fmt.Errorf("error getting screendir: %w", err)
}
log.Printf("remove-all %s\n", screenDir)
return os.RemoveAll(screenDir)
}

33
wavesrv/pkg/sstore/map.go Normal file
View File

@ -0,0 +1,33 @@
package sstore
import (
"context"
)
func WithTxRtn[RT any](ctx context.Context, fn func(tx *TxWrap) (RT, error)) (RT, error) {
var rtn RT
txErr := WithTx(ctx, func(tx *TxWrap) error {
temp, err := fn(tx)
if err != nil {
return err
}
rtn = temp
return nil
})
return rtn, txErr
}
func WithTxRtn3[RT1 any, RT2 any](ctx context.Context, fn func(tx *TxWrap) (RT1, RT2, error)) (RT1, RT2, error) {
var rtn1 RT1
var rtn2 RT2
txErr := WithTx(ctx, func(tx *TxWrap) error {
temp1, temp2, err := fn(tx)
if err != nil {
return err
}
rtn1 = temp1
rtn2 = temp2
return nil
})
return rtn1, rtn2, txErr
}

View File

@ -0,0 +1,213 @@
package sstore
import (
"fmt"
"io"
"log"
"os"
"strconv"
"time"
sh2db "github.com/commandlinedev/prompt-server/db"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "github.com/mattn/go-sqlite3"
"github.com/golang-migrate/migrate/v4"
)
const MaxMigration = 22
const MigratePrimaryScreenVersion = 9
const CmdScreenSpecialMigration = 13
const CmdLineSpecialMigration = 20
func MakeMigrate() (*migrate.Migrate, error) {
fsVar, err := iofs.New(sh2db.MigrationFS, "migrations")
if err != nil {
return nil, fmt.Errorf("opening iofs: %w", err)
}
// migrationPathUrl := fmt.Sprintf("file://%s", path.Join(wd, "db", "migrations"))
dbUrl := fmt.Sprintf("sqlite3://%s", GetDBName())
m, err := migrate.NewWithSourceInstance("iofs", fsVar, dbUrl)
// m, err := migrate.New(migrationPathUrl, dbUrl)
if err != nil {
return nil, fmt.Errorf("making migration db[%s]: %w", GetDBName(), err)
}
return m, nil
}
func copyFile(srcFile string, dstFile string) error {
if srcFile == dstFile {
return fmt.Errorf("cannot copy %s to itself", srcFile)
}
srcFd, err := os.Open(srcFile)
if err != nil {
return fmt.Errorf("cannot open %s: %v", err)
}
defer srcFd.Close()
dstFd, err := os.OpenFile(dstFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("cannot open destination file %s: %v", err)
}
_, err = io.Copy(dstFd, srcFd)
if err != nil {
dstFd.Close()
return fmt.Errorf("error copying file: %v", err)
}
return dstFd.Close()
}
func MigrateUpStep(m *migrate.Migrate, newVersion uint) error {
startTime := time.Now()
err := m.Migrate(newVersion)
if err != nil {
return err
}
if newVersion == CmdScreenSpecialMigration {
mErr := RunMigration13()
if mErr != nil {
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
}
}
if newVersion == CmdLineSpecialMigration {
mErr := RunMigration20()
if mErr != nil {
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
}
}
log.Printf("[db] migration v%d, elapsed %v\n", newVersion, time.Since(startTime))
return nil
}
func MigrateUp(targetVersion uint) error {
m, err := MakeMigrate()
if err != nil {
return err
}
curVersion, dirty, err := MigrateVersion(m)
if dirty {
return fmt.Errorf("cannot migrate up, database is dirty")
}
if err != nil {
return fmt.Errorf("cannot get current migration version: %v", err)
}
if curVersion >= targetVersion {
return nil
}
log.Printf("[db] migrating from %d to %d\n", curVersion, targetVersion)
log.Printf("[db] backing up database %s to %s\n", DBFileName, DBFileNameBackup)
err = copyFile(GetDBName(), GetDBBackupName())
if err != nil {
return fmt.Errorf("error creating database backup: %v", err)
}
for newVersion := curVersion + 1; newVersion <= targetVersion; newVersion++ {
err = MigrateUpStep(m, newVersion)
if err != nil {
return fmt.Errorf("during migration v%d: %w", newVersion, err)
}
}
log.Printf("[db] migration done, new version = %d\n", targetVersion)
return nil
}
// returns curVersion, dirty, error
func MigrateVersion(m *migrate.Migrate) (uint, bool, error) {
if m == nil {
var err error
m, err = MakeMigrate()
if err != nil {
return 0, false, err
}
}
curVersion, dirty, err := m.Version()
if err == migrate.ErrNilVersion {
return 0, false, nil
}
return curVersion, dirty, err
}
func MigrateDown() error {
m, err := MakeMigrate()
if err != nil {
return err
}
err = m.Down()
if err != nil {
return err
}
return nil
}
func MigrateGoto(n uint) error {
curVersion, _, _ := MigrateVersion(nil)
if curVersion == n {
return nil
}
if curVersion < n {
return MigrateUp(n)
}
m, err := MakeMigrate()
if err != nil {
return err
}
err = m.Migrate(n)
if err != nil {
return err
}
return nil
}
func TryMigrateUp() error {
curVersion, _, _ := MigrateVersion(nil)
log.Printf("[db] db version = %d\n", curVersion)
if curVersion >= MaxMigration {
return nil
}
err := MigrateUp(MaxMigration)
if err != nil {
return err
}
return MigratePrintVersion()
}
func MigratePrintVersion() error {
version, dirty, err := MigrateVersion(nil)
if err != nil {
return fmt.Errorf("error getting db version: %v", err)
}
if dirty {
return fmt.Errorf("error db is dirty, version=%d", version)
}
log.Printf("[db] version=%d\n", version)
return nil
}
func MigrateCommandOpts(opts []string) error {
var err error
if opts[0] == "--migrate-up" {
fmt.Printf("migrate-up %v\n", GetDBName())
time.Sleep(3 * time.Second)
err = MigrateUp(MaxMigration)
} else if opts[0] == "--migrate-down" {
fmt.Printf("migrate-down %v\n", GetDBName())
time.Sleep(3 * time.Second)
err = MigrateDown()
} else if opts[0] == "--migrate-goto" {
n, err := strconv.Atoi(opts[1])
if err == nil {
fmt.Printf("migrate-goto %v => %d\n", GetDBName(), n)
time.Sleep(3 * time.Second)
err = MigrateGoto(uint(n))
}
} else {
err = fmt.Errorf("invalid migration command")
}
if err != nil && err.Error() == migrate.ErrNoChange.Error() {
err = nil
}
if err != nil {
return err
}
return MigratePrintVersion()
}

View File

@ -0,0 +1,19 @@
package sstore
import (
"github.com/commandlinedev/prompt-server/pkg/dbutil"
)
var quickSetStr = dbutil.QuickSetStr
var quickSetInt64 = dbutil.QuickSetInt64
var quickSetInt = dbutil.QuickSetInt
var quickSetBool = dbutil.QuickSetBool
var quickSetBytes = dbutil.QuickSetBytes
var quickSetJson = dbutil.QuickSetJson
var quickSetNullableJson = dbutil.QuickSetNullableJson
var quickSetJsonArr = dbutil.QuickSetJsonArr
var quickNullableJson = dbutil.QuickNullableJson
var quickJson = dbutil.QuickJson
var quickJsonArr = dbutil.QuickJsonArr
var quickScanJson = dbutil.QuickScanJson
var quickValueJson = dbutil.QuickValueJson

1295
wavesrv/pkg/sstore/sstore.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,154 @@
package sstore
import (
"context"
"fmt"
"log"
"os"
"time"
"github.com/commandlinedev/prompt-server/pkg/scbase"
)
const MigrationChunkSize = 10
type cmdMigration13Type struct {
SessionId string
ScreenId string
CmdId string
}
type cmdMigration20Type struct {
ScreenId string
LineId string
CmdId string
}
func getSliceChunk[T any](slice []T, chunkSize int) ([]T, []T) {
if chunkSize >= len(slice) {
return slice, nil
}
return slice[0:chunkSize], slice[chunkSize:]
}
func RunMigration20() error {
ctx := context.Background()
startTime := time.Now()
var migrations []cmdMigration20Type
txErr := WithTx(ctx, func(tx *TxWrap) error {
tx.Select(&migrations, `SELECT * FROM cmd_migrate20`)
return nil
})
if txErr != nil {
return fmt.Errorf("trying to get cmd20 migrations: %w", txErr)
}
log.Printf("[db] got %d cmd-line migrations\n", len(migrations))
for len(migrations) > 0 {
var mchunk []cmdMigration20Type
mchunk, migrations = getSliceChunk(migrations, MigrationChunkSize)
err := processMigration20Chunk(ctx, mchunk)
if err != nil {
return fmt.Errorf("cmd migration failed on chunk: %w", err)
}
}
log.Printf("[db] cmd line migration done: %v\n", time.Since(startTime))
return nil
}
func processMigration20Chunk(ctx context.Context, mchunk []cmdMigration20Type) error {
for _, mig := range mchunk {
newFile, err := scbase.PtyOutFile(mig.ScreenId, mig.LineId)
if err != nil {
log.Printf("ptyoutfile(lineid) error: %v\n", err)
continue
}
oldFile, err := scbase.PtyOutFile(mig.ScreenId, mig.CmdId)
if err != nil {
log.Printf("ptyoutfile(cmdid) error: %v\n", err)
continue
}
err = os.Rename(oldFile, newFile)
if err != nil {
log.Printf("error renaming %s => %s: %v\n", oldFile, newFile, err)
continue
}
}
txErr := WithTx(ctx, func(tx *TxWrap) error {
for _, mig := range mchunk {
query := `DELETE FROM cmd_migrate20 WHERE cmdid = ?`
tx.Exec(query, mig.CmdId)
}
return nil
})
if txErr != nil {
return txErr
}
return nil
}
func RunMigration13() error {
ctx := context.Background()
startTime := time.Now()
var migrations []cmdMigration13Type
txErr := WithTx(ctx, func(tx *TxWrap) error {
tx.Select(&migrations, `SELECT * FROM cmd_migrate`)
return nil
})
if txErr != nil {
return fmt.Errorf("trying to get cmd13 migrations: %w", txErr)
}
log.Printf("[db] got %d cmd-screen migrations\n", len(migrations))
for len(migrations) > 0 {
var mchunk []cmdMigration13Type
mchunk, migrations = getSliceChunk(migrations, MigrationChunkSize)
err := processMigration13Chunk(ctx, mchunk)
if err != nil {
return fmt.Errorf("cmd migration failed on chunk: %w", err)
}
}
err := os.RemoveAll(scbase.GetSessionsDir())
if err != nil {
return fmt.Errorf("cannot remove old sessions dir %s: %w\n", scbase.GetSessionsDir(), err)
}
txErr = WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE client SET cmdstoretype = 'screen'`
tx.Exec(query)
return nil
})
if txErr != nil {
return fmt.Errorf("cannot change client cmdstoretype: %w", err)
}
log.Printf("[db] cmd screen migration done: %v\n", time.Since(startTime))
return nil
}
func processMigration13Chunk(ctx context.Context, mchunk []cmdMigration13Type) error {
for _, mig := range mchunk {
newFile, err := scbase.PtyOutFile(mig.ScreenId, mig.CmdId)
if err != nil {
log.Printf("ptyoutfile error: %v\n", err)
continue
}
oldFile, err := scbase.PtyOutFile_Sessions(mig.SessionId, mig.CmdId)
if err != nil {
log.Printf("ptyoutfile_sessions error: %v\n", err)
continue
}
err = os.Rename(oldFile, newFile)
if err != nil {
log.Printf("error renaming %s => %s: %v\n", oldFile, newFile, err)
continue
}
}
txErr := WithTx(ctx, func(tx *TxWrap) error {
for _, mig := range mchunk {
query := `DELETE FROM cmd_migrate WHERE cmdid = ?`
tx.Exec(query, mig.CmdId)
}
return nil
})
if txErr != nil {
return txErr
}
return nil
}

View File

@ -0,0 +1,228 @@
package sstore
import (
"fmt"
"log"
"sync"
)
var MainBus *UpdateBus = MakeUpdateBus()
const PtyDataUpdateStr = "pty"
const ModelUpdateStr = "model"
const UpdateChSize = 100
type UpdatePacket interface {
UpdateType() string
Clean()
}
type PtyDataUpdate struct {
ScreenId string `json:"screenid,omitempty"`
LineId string `json:"lineid,omitempty"`
RemoteId string `json:"remoteid,omitempty"`
PtyPos int64 `json:"ptypos"`
PtyData64 string `json:"ptydata64"`
PtyDataLen int64 `json:"ptydatalen"`
}
func (*PtyDataUpdate) UpdateType() string {
return PtyDataUpdateStr
}
func (pdu *PtyDataUpdate) Clean() {}
type ModelUpdate struct {
Sessions []*SessionType `json:"sessions,omitempty"`
ActiveSessionId string `json:"activesessionid,omitempty"`
Screens []*ScreenType `json:"screens,omitempty"`
ScreenLines *ScreenLinesType `json:"screenlines,omitempty"`
Line *LineType `json:"line,omitempty"`
Lines []*LineType `json:"lines,omitempty"`
Cmd *CmdType `json:"cmd,omitempty"`
CmdLine *CmdLineType `json:"cmdline,omitempty"`
Info *InfoMsgType `json:"info,omitempty"`
ClearInfo bool `json:"clearinfo,omitempty"`
Remotes []interface{} `json:"remotes,omitempty"` // []*remote.RemoteState
History *HistoryInfoType `json:"history,omitempty"`
Interactive bool `json:"interactive"`
Connect bool `json:"connect,omitempty"`
MainView string `json:"mainview,omitempty"`
Bookmarks []*BookmarkType `json:"bookmarks,omitempty"`
SelectedBookmark string `json:"selectedbookmark,omitempty"`
HistoryViewData *HistoryViewData `json:"historyviewdata,omitempty"`
ClientData *ClientData `json:"clientdata,omitempty"`
RemoteView *RemoteViewType `json:"remoteview,omitempty"`
}
func (*ModelUpdate) UpdateType() string {
return ModelUpdateStr
}
func (update *ModelUpdate) Clean() {
if update == nil {
return
}
update.ClientData = update.ClientData.Clean()
}
type RemoteViewType struct {
RemoteShowAll bool `json:"remoteshowall,omitempty"`
PtyRemoteId string `json:"ptyremoteid,omitempty"`
RemoteEdit *RemoteEditType `json:"remoteedit,omitempty"`
}
func InfoMsgUpdate(infoMsgFmt string, args ...interface{}) *ModelUpdate {
msg := fmt.Sprintf(infoMsgFmt, args...)
return &ModelUpdate{
Info: &InfoMsgType{InfoMsg: msg},
}
}
type HistoryViewData struct {
Items []*HistoryItemType `json:"items"`
Offset int `json:"offset"`
RawOffset int `json:"rawoffset"`
NextRawOffset int `json:"nextrawoffset"`
HasMore bool `json:"hasmore"`
Lines []*LineType `json:"lines"`
Cmds []*CmdType `json:"cmds"`
}
type RemoteEditType struct {
RemoteEdit bool `json:"remoteedit"`
RemoteId string `json:"remoteid,omitempty"`
ErrorStr string `json:"errorstr,omitempty"`
InfoStr string `json:"infostr,omitempty"`
KeyStr string `json:"keystr,omitempty"`
HasPassword bool `json:"haspassword,omitempty"`
}
type InfoMsgType struct {
InfoTitle string `json:"infotitle"`
InfoError string `json:"infoerror,omitempty"`
InfoMsg string `json:"infomsg,omitempty"`
InfoMsgHtml bool `json:"infomsghtml,omitempty"`
WebShareLink bool `json:"websharelink,omitempty"`
InfoComps []string `json:"infocomps,omitempty"`
InfoCompsMore bool `json:"infocompssmore,omitempty"`
InfoLines []string `json:"infolines,omitempty"`
TimeoutMs int64 `json:"timeoutms,omitempty"`
}
type HistoryInfoType struct {
HistoryType string `json:"historytype"`
SessionId string `json:"sessionid,omitempty"`
ScreenId string `json:"screenid,omitempty"`
Items []*HistoryItemType `json:"items"`
Show bool `json:"show"`
}
type CmdLineType struct {
CmdLine string `json:"cmdline"`
CursorPos int `json:"cursorpos"`
}
type UpdateChannel struct {
ScreenId string
ClientId string
Ch chan interface{}
}
func (uch UpdateChannel) Match(screenId string) bool {
if screenId == "" {
return true
}
return screenId == uch.ScreenId
}
type UpdateBus struct {
Lock *sync.Mutex
Channels map[string]UpdateChannel
}
func MakeUpdateBus() *UpdateBus {
return &UpdateBus{
Lock: &sync.Mutex{},
Channels: make(map[string]UpdateChannel),
}
}
// always returns a new channel
func (bus *UpdateBus) RegisterChannel(clientId string, screenId string) chan interface{} {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[clientId]
if found {
close(uch.Ch)
uch.ScreenId = screenId
uch.Ch = make(chan interface{}, UpdateChSize)
} else {
uch = UpdateChannel{
ClientId: clientId,
ScreenId: screenId,
Ch: make(chan interface{}, UpdateChSize),
}
}
bus.Channels[clientId] = uch
return uch.Ch
}
func (bus *UpdateBus) UnregisterChannel(clientId string) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[clientId]
if found {
close(uch.Ch)
delete(bus.Channels, clientId)
}
}
func (bus *UpdateBus) SendUpdate(update UpdatePacket) {
if update == nil {
return
}
update.Clean()
bus.Lock.Lock()
defer bus.Lock.Unlock()
for _, uch := range bus.Channels {
select {
case uch.Ch <- update:
default:
log.Printf("[error] dropped update on updatebus uch clientid=%s\n", uch.ClientId)
}
}
}
func (bus *UpdateBus) SendScreenUpdate(screenId string, update UpdatePacket) {
if update == nil {
return
}
update.Clean()
bus.Lock.Lock()
defer bus.Lock.Unlock()
for _, uch := range bus.Channels {
if uch.Match(screenId) {
select {
case uch.Ch <- update:
default:
log.Printf("[error] dropped update on updatebus uch clientid=%s\n", uch.ClientId)
}
}
}
}
func MakeSessionsUpdateForRemote(sessionId string, ri *RemoteInstance) []*SessionType {
return []*SessionType{
&SessionType{
SessionId: sessionId,
Remotes: []*RemoteInstance{ri},
},
}
}
type BookmarksViewType struct {
Bookmarks []*BookmarkType `json:"bookmarks"`
}

View File

@ -0,0 +1,132 @@
package utilfn
import (
"bytes"
"encoding/binary"
"fmt"
"strings"
)
const LineDiffVersion = 0
type LineDiffType struct {
Lines []int
NewData []string
}
// simple encoding
// a 0 means read a line from NewData
// a non-zero number means read the 1-indexed line from OldData
func applyDiff(oldData []string, diff LineDiffType) ([]string, error) {
rtn := make([]string, 0, len(diff.Lines))
newDataPos := 0
for i := 0; i < len(diff.Lines); i++ {
if diff.Lines[i] == 0 {
if newDataPos >= len(diff.NewData) {
return nil, fmt.Errorf("not enough newdata for diff")
}
rtn = append(rtn, diff.NewData[newDataPos])
newDataPos++
} else {
idx := diff.Lines[i] - 1 // 1-indexed
if idx < 0 || idx >= len(oldData) {
return nil, fmt.Errorf("diff index out of bounds %d old-data-len:%d", idx, len(oldData))
}
rtn = append(rtn, oldData[idx])
}
}
return rtn, nil
}
func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
l := binary.PutUvarint(viBuf, uint64(ival))
buf.Write(viBuf[0:l])
}
// simple encoding
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
// [version] [len-varint] [varint]xlen... newdata (bytes)
func encodeDiff(diff LineDiffType) []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, 0)
putUVarint(&buf, viBuf, len(diff.Lines))
for _, val := range diff.Lines {
putUVarint(&buf, viBuf, val)
}
for _, str := range diff.NewData {
buf.WriteString(str)
buf.WriteByte('\n')
}
return buf.Bytes()
}
func decodeDiff(diffBytes []byte) (LineDiffType, error) {
var rtn LineDiffType
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return rtn, fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version != LineDiffVersion {
return rtn, fmt.Errorf("invalid diff, bad version: %d", version)
}
linesLen64, err := binary.ReadUvarint(r)
if err != nil {
return rtn, fmt.Errorf("invalid diff, cannot read lines length: %v", err)
}
linesLen := int(linesLen64)
rtn.Lines = make([]int, linesLen)
for idx := 0; idx < linesLen; idx++ {
vi, err := binary.ReadUvarint(r)
if err != nil {
return rtn, fmt.Errorf("invalid diff, cannot read line %d: %v", idx, err)
}
rtn.Lines[idx] = int(vi)
}
restOfInput := string(r.Bytes())
rtn.NewData = strings.Split(restOfInput, "\n")
return rtn, nil
}
func makeDiff(oldData []string, newData []string) LineDiffType {
var rtn LineDiffType
oldDataMap := make(map[string]int) // 1-indexed
for idx, str := range oldData {
if _, found := oldDataMap[str]; found {
continue
}
oldDataMap[str] = idx + 1
}
rtn.Lines = make([]int, len(newData))
for idx, str := range newData {
oldIdx, found := oldDataMap[str]
if found {
rtn.Lines[idx] = oldIdx
} else {
rtn.Lines[idx] = 0
rtn.NewData = append(rtn.NewData, str)
}
}
return rtn
}
func MakeDiff(str1 string, str2 string) []byte {
str1Arr := strings.Split(str1, "\n")
str2Arr := strings.Split(str2, "\n")
diff := makeDiff(str1Arr, str2Arr)
return encodeDiff(diff)
}
func ApplyDiff(str1 string, diffBytes []byte) (string, error) {
diff, err := decodeDiff(diffBytes)
if err != nil {
return "", err
}
str1Arr := strings.Split(str1, "\n")
str2Arr, err := applyDiff(str1Arr, diff)
if err != nil {
return "", err
}
return strings.Join(str2Arr, "\n"), nil
}

View File

@ -0,0 +1,208 @@
package utilfn
import (
"crypto/sha1"
"encoding/base64"
"regexp"
"strings"
"unicode/utf8"
)
var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}
func GetStrArr(v interface{}, field string) []string {
if v == nil {
return nil
}
m, ok := v.(map[string]interface{})
if !ok {
return nil
}
fieldVal := m[field]
if fieldVal == nil {
return nil
}
iarr, ok := fieldVal.([]interface{})
if !ok {
return nil
}
var sarr []string
for _, iv := range iarr {
if sv, ok := iv.(string); ok {
sarr = append(sarr, sv)
}
}
return sarr
}
func GetBool(v interface{}, field string) bool {
if v == nil {
return false
}
m, ok := v.(map[string]interface{})
if !ok {
return false
}
fieldVal := m[field]
if fieldVal == nil {
return false
}
bval, ok := fieldVal.(bool)
if !ok {
return false
}
return bval
}
var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`)
// minimum maxlen=6
func ShellQuote(val string, forceQuote bool, maxLen int) string {
if maxLen < 6 {
maxLen = 6
}
rtn := val
if needsQuoteRe.MatchString(val) {
rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'"
}
if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") {
if len(rtn) > maxLen {
return rtn[0:maxLen-4] + "..." + rtn[0:1]
}
return rtn
}
if forceQuote {
if len(rtn) > maxLen-2 {
return "\"" + rtn[0:maxLen-5] + "...\""
}
return "\"" + rtn + "\""
} else {
if len(rtn) > maxLen {
return rtn[0:maxLen-3] + "..."
}
return rtn
}
}
func EllipsisStr(s string, maxLen int) string {
if maxLen < 4 {
maxLen = 4
}
if len(s) > maxLen {
return s[0:maxLen-3] + "..."
}
return s
}
func LongestPrefix(root string, strs []string) string {
if len(strs) == 0 {
return root
}
if len(strs) == 1 {
comp := strs[0]
if len(comp) >= len(root) && strings.HasPrefix(comp, root) {
if strings.HasSuffix(comp, "/") {
return strs[0]
}
return strs[0]
}
}
lcp := strs[0]
for i := 1; i < len(strs); i++ {
s := strs[i]
for j := 0; j < len(lcp); j++ {
if j >= len(s) || lcp[j] != s[j] {
lcp = lcp[0:j]
break
}
}
}
if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) {
return root
}
return lcp
}
func ContainsStr(strs []string, test string) bool {
for _, s := range strs {
if s == test {
return true
}
}
return false
}
func IsPrefix(strs []string, test string) bool {
for _, s := range strs {
if len(s) > len(test) && strings.HasPrefix(s, test) {
return true
}
}
return false
}
type StrWithPos struct {
Str string
Pos int // this is a 'rune' position (not a byte position)
}
func (sp StrWithPos) String() string {
return strWithCursor(sp.Str, sp.Pos)
}
func ParseToSP(s string) StrWithPos {
idx := strings.Index(s, "[*]")
if idx == -1 {
return StrWithPos{Str: s}
}
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])}
}
func strWithCursor(str string, pos int) string {
if pos < 0 {
return "[*]_" + str
}
if pos >= len(str) {
if pos > len(str) {
return str + "_[*]"
}
return str + "[*]"
}
var rtn []rune
for _, ch := range str {
if len(rtn) == pos {
rtn = append(rtn, '[', '*', ']')
}
rtn = append(rtn, ch)
}
return string(rtn)
}
func (sp StrWithPos) Prepend(str string) StrWithPos {
return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos}
}
func (sp StrWithPos) Append(str string) StrWithPos {
return StrWithPos{Str: sp.Str + str, Pos: sp.Pos}
}
// returns base64 hash of data
func Sha1Hash(data []byte) string {
hvalRaw := sha1.Sum(data)
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
return hval
}
func ChunkSlice[T any](s []T, chunkSize int) [][]T {
var rtn [][]T
for len(rtn) > 0 {
if len(s) <= chunkSize {
rtn = append(rtn, s)
break
}
rtn = append(rtn, s[:chunkSize])
s = s[chunkSize:]
}
return rtn
}

View File

@ -0,0 +1,48 @@
package utilfn
import (
"fmt"
"testing"
)
const Str1 = `
hello
line #2
more
stuff
apple
`
const Str2 = `
line #2
apple
grapes
banana
`
const Str3 = `
more
stuff
banana
coconut
`
func testDiff(t *testing.T, str1 string, str2 string) {
diffBytes := MakeDiff(str1, str2)
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err := ApplyDiff(str1, diffBytes)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
if out != str2 {
t.Errorf("bad diff output")
}
}
func TestDiff(t *testing.T) {
testDiff(t, Str1, Str2)
testDiff(t, Str2, Str3)
testDiff(t, Str1, Str3)
testDiff(t, Str3, Str1)
}

View File

@ -0,0 +1,189 @@
package wsshell
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
const readWaitTimeout = 15 * time.Second
const writeWaitTimeout = 10 * time.Second
const pingPeriodTickTime = 10 * time.Second
const initialPingTime = 1 * time.Second
var upgrader = websocket.Upgrader{
ReadBufferSize: 4 * 1024,
WriteBufferSize: 32 * 1024,
HandshakeTimeout: 1 * time.Second,
CheckOrigin: func(r *http.Request) bool { return true },
}
type WSShell struct {
Conn *websocket.Conn
RemoteAddr string
ConnId string
Query url.Values
OpenTime time.Time
NumPings int
LastPing time.Time
LastRecv time.Time
Header http.Header
CloseChan chan bool
WriteChan chan []byte
ReadChan chan []byte
}
func (ws *WSShell) NonBlockingWrite(data []byte) bool {
select {
case ws.WriteChan <- data:
return true
default:
return false
}
}
func (ws *WSShell) WritePing() error {
now := time.Now()
pingMessage := map[string]interface{}{"type": "ping", "stime": now.Unix()}
jsonVal, _ := json.Marshal(pingMessage)
_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
err := ws.Conn.WriteMessage(websocket.TextMessage, jsonVal)
ws.NumPings++
ws.LastPing = now
if err != nil {
return err
}
return nil
}
func (ws *WSShell) WriteJson(val interface{}) error {
if ws.IsClosed() {
return fmt.Errorf("cannot write packet, empty or closed wsshell")
}
barr, err := json.Marshal(val)
if err != nil {
return err
}
ws.WriteChan <- barr
return nil
}
func (ws *WSShell) WritePump() {
ticker := time.NewTicker(initialPingTime)
defer func() {
ticker.Stop()
ws.Conn.Close()
}()
initialPing := true
for {
select {
case <-ticker.C:
err := ws.WritePing()
if err != nil {
log.Printf("WritePump %s err: %v\n", ws.RemoteAddr, err)
return
}
if initialPing {
initialPing = false
ticker.Reset(pingPeriodTickTime)
}
case msgBytes, ok := <-ws.WriteChan:
if !ok {
return
}
_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
err := ws.Conn.WriteMessage(websocket.TextMessage, msgBytes)
if err != nil {
log.Printf("WritePump %s err: %v\n", ws.RemoteAddr, err)
return
}
}
}
}
func (ws *WSShell) ReadPump() {
readWait := readWaitTimeout
defer func() {
ws.Conn.Close()
}()
ws.Conn.SetReadLimit(4096)
ws.Conn.SetReadDeadline(time.Now().Add(readWait))
for {
_, message, err := ws.Conn.ReadMessage()
if err != nil {
log.Printf("ReadPump %s Err: %v\n", ws.RemoteAddr, err)
break
}
jmsg := map[string]interface{}{}
err = json.Unmarshal(message, &jmsg)
if err != nil {
log.Printf("Error unmarshalling json: %v\n", err)
break
}
ws.Conn.SetReadDeadline(time.Now().Add(readWait))
ws.LastRecv = time.Now()
if str, ok := jmsg["type"].(string); ok && str == "pong" {
// nothing
continue
}
if str, ok := jmsg["type"].(string); ok && str == "ping" {
now := time.Now()
pongMessage := map[string]interface{}{"type": "pong", "stime": now.Unix()}
jsonVal, _ := json.Marshal(pongMessage)
ws.WriteChan <- jsonVal
continue
}
ws.ReadChan <- message
}
}
func (ws *WSShell) IsClosed() bool {
select {
case <-ws.CloseChan:
return true
default:
return false
}
}
func StartWS(w http.ResponseWriter, r *http.Request) (*WSShell, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}
ws := WSShell{Conn: conn, ConnId: uuid.New().String(), OpenTime: time.Now()}
ws.CloseChan = make(chan bool)
ws.WriteChan = make(chan []byte, 10)
ws.ReadChan = make(chan []byte, 10)
ws.RemoteAddr = r.RemoteAddr
ws.Query = r.URL.Query()
ws.Header = r.Header
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ws.WritePump()
}()
wg.Add(1)
go func() {
defer wg.Done()
ws.ReadPump()
}()
go func() {
wg.Wait()
close(ws.CloseChan)
close(ws.ReadChan)
}()
return &ws, nil
}

16
wavesrv/scripthaus.md Normal file
View File

@ -0,0 +1,16 @@
# SH2 Server Commands
```bash
# @scripthaus command dump-schema-dev
sqlite3 /Users/mike/prompt-dev/prompt.db .schema > db/schema.sql
```
```bash
# @scripthaus command opendb-dev
sqlite3 /Users/mike/prompt-dev/prompt.db
```
```bash
# @scripthaus command build
go build -ldflags "-X main.BuildTime=$(date +'%Y%m%d%H%M')" -o bin/local-server ./cmd
```