mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-21 21:32:13 +01:00
save/restore activesessionid, set session name, much more sophisticated session switching logic
This commit is contained in:
parent
00b88f7f13
commit
46ba21030b
@ -423,6 +423,12 @@ func main() {
|
||||
fmt.Printf("[error] migrate up: %v\n", err)
|
||||
return
|
||||
}
|
||||
userData, err := sstore.EnsureClientData(context.Background())
|
||||
if err != nil {
|
||||
fmt.Printf("[error] ensuring user data: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("userid = %s\n", userData.UserId)
|
||||
err = sstore.EnsureLocalRemote(context.Background())
|
||||
if err != nil {
|
||||
fmt.Printf("[error] ensuring local remote: %v\n", err)
|
||||
@ -443,12 +449,6 @@ func main() {
|
||||
fmt.Printf("[error] ensuring default session: %v\n", err)
|
||||
return
|
||||
}
|
||||
userData, err := sstore.EnsureUserData(context.Background())
|
||||
if err != nil {
|
||||
fmt.Printf("[error] ensuring user data: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("userid = %s\n", userData.UserId)
|
||||
err = remote.LoadRemotes(context.Background())
|
||||
if err != nil {
|
||||
fmt.Printf("[error] loading remotes: %v\n", err)
|
||||
|
@ -1,5 +1,6 @@
|
||||
CREATE TABLE client (
|
||||
userid varchar(36) NOT NULL,
|
||||
activesessionid varchar(36) NOT NULL,
|
||||
userpublickeybytes blob NOT NULL,
|
||||
userprivatekeybytes blob NOT NULL
|
||||
);
|
||||
|
@ -2,6 +2,7 @@ CREATE TABLE schema_migrations (version uint64,dirty bool);
|
||||
CREATE UNIQUE INDEX version_unique ON schema_migrations (version);
|
||||
CREATE TABLE client (
|
||||
userid varchar(36) NOT NULL,
|
||||
activesessionid varchar(36) NOT NULL,
|
||||
userpublickeybytes blob NOT NULL,
|
||||
userprivatekeybytes blob NOT NULL
|
||||
);
|
||||
|
@ -35,6 +35,10 @@ const (
|
||||
R_RemoteOpt = 128
|
||||
)
|
||||
|
||||
const MaxNameLen = 50
|
||||
|
||||
var genericNameRe = regexp.MustCompile("^[a-zA-Z][a-zA-Z0-9_ .()<>,/\"'\\[\\]{}=+$@!*-]*$")
|
||||
|
||||
type resolvedIds struct {
|
||||
SessionId string
|
||||
ScreenId string
|
||||
@ -57,6 +61,52 @@ var ValidCommands = []string{
|
||||
"/remote:show",
|
||||
}
|
||||
|
||||
var positionRe = regexp.MustCompile("^((\\+|-)?[0-9]+|(\\+|-))$")
|
||||
|
||||
func resolveByPosition(ids []string, curId string, posStr string) string {
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
if !positionRe.MatchString(posStr) {
|
||||
return ""
|
||||
}
|
||||
curIdx := 1 // if no match, curIdx will be first item
|
||||
for idx, id := range ids {
|
||||
if id == curId {
|
||||
curIdx = idx + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
isRelative := strings.HasPrefix(posStr, "+") || strings.HasPrefix(posStr, "-")
|
||||
isWrap := posStr == "+" || posStr == "-"
|
||||
var pos int
|
||||
if isWrap && posStr == "+" {
|
||||
pos = 1
|
||||
} else if isWrap && posStr == "-" {
|
||||
pos = -1
|
||||
} else {
|
||||
pos, _ = strconv.Atoi(posStr)
|
||||
}
|
||||
if isRelative {
|
||||
pos = curIdx + pos
|
||||
}
|
||||
if pos < 1 {
|
||||
if isWrap {
|
||||
pos = len(ids)
|
||||
} else {
|
||||
pos = 1
|
||||
}
|
||||
}
|
||||
if pos > len(ids) {
|
||||
if isWrap {
|
||||
pos = 1
|
||||
} else {
|
||||
pos = len(ids)
|
||||
}
|
||||
}
|
||||
return ids[pos-1]
|
||||
}
|
||||
|
||||
func HandleCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||
switch SubMetaCmd(pk.MetaCmd) {
|
||||
case "run":
|
||||
@ -137,21 +187,61 @@ func resolveSessionScreen(ctx context.Context, sessionId string, screenArg strin
|
||||
if screen.ScreenId == screenArg || screen.Name == screenArg {
|
||||
return screen.ScreenId, nil
|
||||
}
|
||||
|
||||
}
|
||||
return "", fmt.Errorf("could not resolve screen '%s' (name/id not found)", screenArg)
|
||||
}
|
||||
|
||||
func resolveSession(ctx context.Context, sessionArg string) (string, error) {
|
||||
sessions, err := sstore.GetBareSessions(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not retrive bare sessions")
|
||||
func getSessionIds(sarr []*sstore.SessionType) []string {
|
||||
rtn := make([]string, len(sarr))
|
||||
for idx, s := range sarr {
|
||||
rtn[idx] = s.SessionId
|
||||
}
|
||||
for _, session := range sessions {
|
||||
if session.SessionId == sessionArg || session.Name == sessionArg {
|
||||
return session.SessionId, nil
|
||||
return rtn
|
||||
}
|
||||
|
||||
var partialUUIDRe = regexp.MustCompile("^[0-9a-f]{8}$")
|
||||
|
||||
func isPartialUUID(s string) bool {
|
||||
return partialUUIDRe.MatchString(s)
|
||||
}
|
||||
|
||||
func resolveSession(ctx context.Context, sessionArg string, curSession string, bareSessions []*sstore.SessionType) (string, error) {
|
||||
if bareSessions == nil {
|
||||
var err error
|
||||
bareSessions, err = sstore.GetBareSessions(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not retrive bare sessions")
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("could not resolve sesssion '%s' (name/id not found)", sessionArg)
|
||||
var curSessionId string
|
||||
if curSession != "" {
|
||||
curSessionId, _ = resolveSession(ctx, curSession, "", bareSessions)
|
||||
}
|
||||
sids := getSessionIds(bareSessions)
|
||||
rtnId := resolveByPosition(sids, curSessionId, sessionArg)
|
||||
if rtnId != "" {
|
||||
return rtnId, nil
|
||||
}
|
||||
tryPuid := isPartialUUID(sessionArg)
|
||||
var prefixMatches []string
|
||||
var lastPrefixMatchId string
|
||||
for _, session := range bareSessions {
|
||||
if session.SessionId == sessionArg || session.Name == sessionArg || (tryPuid && strings.HasPrefix(session.SessionId, sessionArg)) {
|
||||
return session.SessionId, nil
|
||||
}
|
||||
if strings.HasPrefix(session.Name, sessionArg) {
|
||||
prefixMatches = append(prefixMatches, session.Name)
|
||||
lastPrefixMatchId = session.SessionId
|
||||
}
|
||||
}
|
||||
if len(prefixMatches) == 1 {
|
||||
return lastPrefixMatchId, nil
|
||||
}
|
||||
if len(prefixMatches) > 1 {
|
||||
return "", fmt.Errorf("could not resolve session '%s', ambiguious prefix matched multiple sessions: %s", sessionArg, formatStrs(prefixMatches, "and", true))
|
||||
}
|
||||
return "", fmt.Errorf("could not resolve sesssion '%s' (name/id/pos not found)", sessionArg)
|
||||
}
|
||||
|
||||
func resolveSessionId(pk *scpacket.FeCommandPacketType) (string, error) {
|
||||
@ -880,10 +970,55 @@ func CommentCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
|
||||
return sstore.ModelUpdate{Line: rtnLine}, nil
|
||||
}
|
||||
|
||||
func maybeQuote(s string, quote bool) string {
|
||||
if quote {
|
||||
return fmt.Sprintf("%q", s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func formatStrs(strs []string, conj string, quote bool) string {
|
||||
if len(strs) == 0 {
|
||||
return "(none)"
|
||||
}
|
||||
if len(strs) == 1 {
|
||||
return maybeQuote(strs[0], quote)
|
||||
}
|
||||
if len(strs) == 2 {
|
||||
return fmt.Sprintf("%s %s %s", maybeQuote(strs[0], quote), conj, maybeQuote(strs[1], quote))
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
for idx := 0; idx < len(strs)-1; idx++ {
|
||||
buf.WriteString(maybeQuote(strs[idx], quote))
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
buf.WriteString(conj)
|
||||
buf.WriteString(" ")
|
||||
buf.WriteString(maybeQuote(strs[len(strs)-1], quote))
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func validateName(name string, typeStr string) error {
|
||||
if len(name) > MaxNameLen {
|
||||
return fmt.Errorf("%s name too long, max length is %d", typeStr, MaxNameLen)
|
||||
}
|
||||
if !genericNameRe.MatchString(name) {
|
||||
return fmt.Errorf("invalid %s name", typeStr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||
if pk.MetaSubCmd == "open" || pk.MetaSubCmd == "new" {
|
||||
activate := resolveBool(pk.Kwargs["activate"], true)
|
||||
update, err := sstore.InsertSessionWithName(ctx, pk.Kwargs["name"], activate)
|
||||
newName := pk.Kwargs["name"]
|
||||
if newName != "" {
|
||||
err := validateName(newName, "session")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
update, err := sstore.InsertSessionWithName(ctx, newName, activate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -901,10 +1036,30 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
|
||||
if bareSession == nil {
|
||||
return nil, fmt.Errorf("session '%s' not found", ids.SessionId)
|
||||
}
|
||||
var varsUpdated []string
|
||||
if pk.Kwargs["name"] != "" {
|
||||
newName := pk.Kwargs["name"]
|
||||
err = validateName(newName, "session")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sstore.SetSessionName(ctx, ids.SessionId, newName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting session name: %v", err)
|
||||
}
|
||||
varsUpdated = append(varsUpdated, "name")
|
||||
}
|
||||
if pk.Kwargs["pos"] != "" {
|
||||
|
||||
}
|
||||
if len(varsUpdated) == 0 {
|
||||
return nil, fmt.Errorf("/session:set no updates, can set %s", formatStrs([]string{"name", "pos"}, "or", false))
|
||||
}
|
||||
bareSession, err = sstore.GetBareSessionById(ctx, ids.SessionId)
|
||||
update := sstore.ModelUpdate{
|
||||
Sessions: nil,
|
||||
Sessions: []*sstore.SessionType{bareSession},
|
||||
Info: &sstore.InfoMsgType{
|
||||
InfoMsg: fmt.Sprintf("[%s]: update", bareSession.Name),
|
||||
InfoMsg: fmt.Sprintf("[%s]: updated %s", bareSession.Name, formatStrs(varsUpdated, "and", false)),
|
||||
TimeoutMs: 2000,
|
||||
},
|
||||
}
|
||||
@ -915,13 +1070,24 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
|
||||
}
|
||||
firstArg := firstArg(pk)
|
||||
if firstArg == "" {
|
||||
return nil, fmt.Errorf("usage /session [session-name|session-id], no param specified")
|
||||
return nil, fmt.Errorf("usage /session [name|id|pos], no param specified")
|
||||
}
|
||||
sessionId, err := resolveSession(ctx, firstArg)
|
||||
sessionId, err := resolveSession(ctx, firstArg, pk.Kwargs["session"], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sstore.ModelUpdate{ActiveSessionId: sessionId}, nil
|
||||
bareSession, err := sstore.GetSessionById(ctx, sessionId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find session '%s': %v", sessionId, err)
|
||||
}
|
||||
update := sstore.ModelUpdate{
|
||||
ActiveSessionId: sessionId,
|
||||
Info: &sstore.InfoMsgType{
|
||||
InfoMsg: fmt.Sprintf("switched to session %q", bareSession.Name),
|
||||
TimeoutMs: 2000,
|
||||
},
|
||||
}
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func splitLinesForInfo(str string) []string {
|
||||
|
@ -132,7 +132,7 @@ func GetSessionHistoryItems(ctx context.Context, sessionId string, maxItems int)
|
||||
func GetBareSessions(ctx context.Context) ([]*SessionType, error) {
|
||||
var rtn []*SessionType
|
||||
err := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT * FROM session`
|
||||
query := `SELECT * FROM session ORDER BY sessionidx`
|
||||
tx.SelectWrap(&rtn, query)
|
||||
return nil
|
||||
})
|
||||
@ -142,6 +142,19 @@ func GetBareSessions(ctx context.Context) ([]*SessionType, error) {
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func GetAllSessionIds(ctx context.Context) ([]string, error) {
|
||||
var rtn []string
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT sessionid from session ORDER by sessionidx`
|
||||
rtn = tx.SelectStrings(query)
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return nil, txErr
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func GetBareSessionById(ctx context.Context, sessionId string) (*SessionType, error) {
|
||||
var rtn SessionType
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
@ -158,9 +171,10 @@ func GetBareSessionById(ctx context.Context, sessionId string) (*SessionType, er
|
||||
return &rtn, nil
|
||||
}
|
||||
|
||||
func GetAllSessions(ctx context.Context) ([]*SessionType, error) {
|
||||
func GetAllSessions(ctx context.Context) (*ModelUpdate, error) {
|
||||
var rtn []*SessionType
|
||||
err := WithTx(ctx, func(tx *TxWrap) error {
|
||||
var activeSessionId string
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT * FROM session`
|
||||
tx.SelectWrap(&rtn, query)
|
||||
sessionMap := make(map[string]*SessionType)
|
||||
@ -203,9 +217,14 @@ func GetAllSessions(ctx context.Context) ([]*SessionType, error) {
|
||||
s.Remotes = append(s.Remotes, ri)
|
||||
}
|
||||
}
|
||||
query = `SELECT activesessionid FROM client`
|
||||
activeSessionId = tx.GetString(query)
|
||||
return nil
|
||||
})
|
||||
return rtn, err
|
||||
if txErr != nil {
|
||||
return nil, txErr
|
||||
}
|
||||
return &ModelUpdate{Sessions: rtn, ActiveSessionId: activeSessionId}, nil
|
||||
}
|
||||
|
||||
func GetWindowById(ctx context.Context, sessionId string, windowId string) (*WindowType, error) {
|
||||
@ -240,10 +259,11 @@ func GetSessionScreens(ctx context.Context, sessionId string) ([]*ScreenType, er
|
||||
}
|
||||
|
||||
func GetSessionById(ctx context.Context, id string) (*SessionType, error) {
|
||||
allSessions, err := GetAllSessions(ctx)
|
||||
allSessionsUpdate, err := GetAllSessions(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allSessions := allSessionsUpdate.Sessions
|
||||
for _, session := range allSessions {
|
||||
if session.SessionId == id {
|
||||
return session, nil
|
||||
@ -283,6 +303,10 @@ func InsertSessionWithName(ctx context.Context, sessionName string, activate boo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if activate {
|
||||
query = `UPDATE client SET activesessionid = ?`
|
||||
tx.ExecWrap(query, newSessionId)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
@ -677,12 +701,12 @@ func reorderStrings(strs []string, toMove string, newIndex int) []string {
|
||||
|
||||
func ReIndexSessions(ctx context.Context, sessionId string, newIndex int) error {
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT sessionid FROM sessions ORDER BY sessionidx, name, sessionid`
|
||||
query := `SELECT sessionid FROM session ORDER BY sessionidx, name, sessionid`
|
||||
ids := tx.SelectStrings(query)
|
||||
if sessionId != "" {
|
||||
ids = reorderStrings(ids, sessionId, newIndex)
|
||||
}
|
||||
query = `UPDATE sessions SET sessionid = ? WHERE sessionid = ?`
|
||||
query = `UPDATE session SET sessionid = ? WHERE sessionid = ?`
|
||||
for idx, id := range ids {
|
||||
tx.ExecWrap(query, id, idx+1)
|
||||
}
|
||||
@ -693,11 +717,11 @@ func ReIndexSessions(ctx context.Context, sessionId string, newIndex int) error
|
||||
|
||||
func SetSessionName(ctx context.Context, sessionId string, name string) error {
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT sessionid FROM sessions WHERE sessionid = ?`
|
||||
query := `SELECT sessionid FROM session WHERE sessionid = ?`
|
||||
if !tx.Exists(query, sessionId) {
|
||||
return fmt.Errorf("session does not exist")
|
||||
}
|
||||
query = `UPDATE sessions SET name = ? WHERE sessionid = ?`
|
||||
query = `UPDATE session SET name = ? WHERE sessionid = ?`
|
||||
tx.ExecWrap(query, name, sessionId)
|
||||
return nil
|
||||
})
|
||||
|
@ -86,12 +86,13 @@ func GetDB(ctx context.Context) (*sqlx.DB, error) {
|
||||
return globalDB, globalDBErr
|
||||
}
|
||||
|
||||
type UserData struct {
|
||||
type ClientData struct {
|
||||
UserId string `json:"userid"`
|
||||
UserPrivateKeyBytes []byte `json:"-"`
|
||||
UserPublicKeyBytes []byte `json:"-"`
|
||||
UserPrivateKey *ecdsa.PrivateKey
|
||||
UserPublicKey *ecdsa.PublicKey
|
||||
ActiveSessionId string `json:"activesessionid"`
|
||||
}
|
||||
|
||||
type SessionType struct {
|
||||
@ -639,7 +640,7 @@ func EnsureDefaultSession(ctx context.Context) (*SessionType, error) {
|
||||
return GetSessionByName(ctx, DefaultSessionName)
|
||||
}
|
||||
|
||||
func createUserData(tx *TxWrap) error {
|
||||
func createClientData(tx *TxWrap) error {
|
||||
userId := uuid.New().String()
|
||||
curve := elliptic.P384()
|
||||
pkey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
@ -654,14 +655,14 @@ func createUserData(tx *TxWrap) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling (pkix) public key bytes: %w", err)
|
||||
}
|
||||
query := `INSERT INTO client (userid, userpublickeybytes, userprivatekeybytes) VALUES (?, ?, ?)`
|
||||
query := `INSERT INTO client (userid, activesessionid, userpublickeybytes, userprivatekeybytes) VALUES (?, '', ?, ?)`
|
||||
tx.ExecWrap(query, userId, pubBytes, pkBytes)
|
||||
fmt.Printf("create new userid[%s] with public/private keypair\n", userId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnsureUserData(ctx context.Context) (*UserData, error) {
|
||||
var rtn UserData
|
||||
func EnsureClientData(ctx context.Context) (*ClientData, error) {
|
||||
var rtn ClientData
|
||||
err := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT count(*) FROM client`
|
||||
count := tx.GetInt(query)
|
||||
@ -669,7 +670,7 @@ func EnsureUserData(ctx context.Context) (*UserData, error) {
|
||||
return fmt.Errorf("invalid client database, multiple (%d) rows in client table", count)
|
||||
}
|
||||
if count == 0 {
|
||||
createErr := createUserData(tx)
|
||||
createErr := createClientData(tx)
|
||||
if createErr != nil {
|
||||
return createErr
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user