save/restore activesessionid, set session name, much more sophisticated session switching logic

This commit is contained in:
sawka 2022-08-26 16:21:19 -07:00
parent 00b88f7f13
commit 46ba21030b
6 changed files with 228 additions and 35 deletions

View File

@ -423,6 +423,12 @@ func main() {
fmt.Printf("[error] migrate up: %v\n", err)
return
}
userData, err := sstore.EnsureClientData(context.Background())
if err != nil {
fmt.Printf("[error] ensuring user data: %v\n", err)
return
}
fmt.Printf("userid = %s\n", userData.UserId)
err = sstore.EnsureLocalRemote(context.Background())
if err != nil {
fmt.Printf("[error] ensuring local remote: %v\n", err)
@ -443,12 +449,6 @@ func main() {
fmt.Printf("[error] ensuring default session: %v\n", err)
return
}
userData, err := sstore.EnsureUserData(context.Background())
if err != nil {
fmt.Printf("[error] ensuring user data: %v\n", err)
return
}
fmt.Printf("userid = %s\n", userData.UserId)
err = remote.LoadRemotes(context.Background())
if err != nil {
fmt.Printf("[error] loading remotes: %v\n", err)

View File

@ -1,5 +1,6 @@
CREATE TABLE client (
userid varchar(36) NOT NULL,
activesessionid varchar(36) NOT NULL,
userpublickeybytes blob NOT NULL,
userprivatekeybytes blob NOT NULL
);

View File

@ -2,6 +2,7 @@ CREATE TABLE schema_migrations (version uint64,dirty bool);
CREATE UNIQUE INDEX version_unique ON schema_migrations (version);
CREATE TABLE client (
userid varchar(36) NOT NULL,
activesessionid varchar(36) NOT NULL,
userpublickeybytes blob NOT NULL,
userprivatekeybytes blob NOT NULL
);

View File

@ -35,6 +35,10 @@ const (
R_RemoteOpt = 128
)
const MaxNameLen = 50
var genericNameRe = regexp.MustCompile("^[a-zA-Z][a-zA-Z0-9_ .()<>,/\"'\\[\\]{}=+$@!*-]*$")
type resolvedIds struct {
SessionId string
ScreenId string
@ -57,6 +61,52 @@ var ValidCommands = []string{
"/remote:show",
}
var positionRe = regexp.MustCompile("^((\\+|-)?[0-9]+|(\\+|-))$")
func resolveByPosition(ids []string, curId string, posStr string) string {
if len(ids) == 0 {
return ""
}
if !positionRe.MatchString(posStr) {
return ""
}
curIdx := 1 // if no match, curIdx will be first item
for idx, id := range ids {
if id == curId {
curIdx = idx + 1
break
}
}
isRelative := strings.HasPrefix(posStr, "+") || strings.HasPrefix(posStr, "-")
isWrap := posStr == "+" || posStr == "-"
var pos int
if isWrap && posStr == "+" {
pos = 1
} else if isWrap && posStr == "-" {
pos = -1
} else {
pos, _ = strconv.Atoi(posStr)
}
if isRelative {
pos = curIdx + pos
}
if pos < 1 {
if isWrap {
pos = len(ids)
} else {
pos = 1
}
}
if pos > len(ids) {
if isWrap {
pos = 1
} else {
pos = len(ids)
}
}
return ids[pos-1]
}
func HandleCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
switch SubMetaCmd(pk.MetaCmd) {
case "run":
@ -137,21 +187,61 @@ func resolveSessionScreen(ctx context.Context, sessionId string, screenArg strin
if screen.ScreenId == screenArg || screen.Name == screenArg {
return screen.ScreenId, nil
}
}
return "", fmt.Errorf("could not resolve screen '%s' (name/id not found)", screenArg)
}
func resolveSession(ctx context.Context, sessionArg string) (string, error) {
sessions, err := sstore.GetBareSessions(ctx)
func getSessionIds(sarr []*sstore.SessionType) []string {
rtn := make([]string, len(sarr))
for idx, s := range sarr {
rtn[idx] = s.SessionId
}
return rtn
}
var partialUUIDRe = regexp.MustCompile("^[0-9a-f]{8}$")
func isPartialUUID(s string) bool {
return partialUUIDRe.MatchString(s)
}
func resolveSession(ctx context.Context, sessionArg string, curSession string, bareSessions []*sstore.SessionType) (string, error) {
if bareSessions == nil {
var err error
bareSessions, err = sstore.GetBareSessions(ctx)
if err != nil {
return "", fmt.Errorf("could not retrive bare sessions")
}
for _, session := range sessions {
if session.SessionId == sessionArg || session.Name == sessionArg {
}
var curSessionId string
if curSession != "" {
curSessionId, _ = resolveSession(ctx, curSession, "", bareSessions)
}
sids := getSessionIds(bareSessions)
rtnId := resolveByPosition(sids, curSessionId, sessionArg)
if rtnId != "" {
return rtnId, nil
}
tryPuid := isPartialUUID(sessionArg)
var prefixMatches []string
var lastPrefixMatchId string
for _, session := range bareSessions {
if session.SessionId == sessionArg || session.Name == sessionArg || (tryPuid && strings.HasPrefix(session.SessionId, sessionArg)) {
return session.SessionId, nil
}
if strings.HasPrefix(session.Name, sessionArg) {
prefixMatches = append(prefixMatches, session.Name)
lastPrefixMatchId = session.SessionId
}
return "", fmt.Errorf("could not resolve sesssion '%s' (name/id not found)", sessionArg)
}
if len(prefixMatches) == 1 {
return lastPrefixMatchId, nil
}
if len(prefixMatches) > 1 {
return "", fmt.Errorf("could not resolve session '%s', ambiguious prefix matched multiple sessions: %s", sessionArg, formatStrs(prefixMatches, "and", true))
}
return "", fmt.Errorf("could not resolve sesssion '%s' (name/id/pos not found)", sessionArg)
}
func resolveSessionId(pk *scpacket.FeCommandPacketType) (string, error) {
@ -880,10 +970,55 @@ func CommentCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
return sstore.ModelUpdate{Line: rtnLine}, nil
}
func maybeQuote(s string, quote bool) string {
if quote {
return fmt.Sprintf("%q", s)
}
return s
}
func formatStrs(strs []string, conj string, quote bool) string {
if len(strs) == 0 {
return "(none)"
}
if len(strs) == 1 {
return maybeQuote(strs[0], quote)
}
if len(strs) == 2 {
return fmt.Sprintf("%s %s %s", maybeQuote(strs[0], quote), conj, maybeQuote(strs[1], quote))
}
var buf bytes.Buffer
for idx := 0; idx < len(strs)-1; idx++ {
buf.WriteString(maybeQuote(strs[idx], quote))
buf.WriteString(", ")
}
buf.WriteString(conj)
buf.WriteString(" ")
buf.WriteString(maybeQuote(strs[len(strs)-1], quote))
return buf.String()
}
func validateName(name string, typeStr string) error {
if len(name) > MaxNameLen {
return fmt.Errorf("%s name too long, max length is %d", typeStr, MaxNameLen)
}
if !genericNameRe.MatchString(name) {
return fmt.Errorf("invalid %s name", typeStr)
}
return nil
}
func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
if pk.MetaSubCmd == "open" || pk.MetaSubCmd == "new" {
activate := resolveBool(pk.Kwargs["activate"], true)
update, err := sstore.InsertSessionWithName(ctx, pk.Kwargs["name"], activate)
newName := pk.Kwargs["name"]
if newName != "" {
err := validateName(newName, "session")
if err != nil {
return nil, err
}
}
update, err := sstore.InsertSessionWithName(ctx, newName, activate)
if err != nil {
return nil, err
}
@ -901,10 +1036,30 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
if bareSession == nil {
return nil, fmt.Errorf("session '%s' not found", ids.SessionId)
}
var varsUpdated []string
if pk.Kwargs["name"] != "" {
newName := pk.Kwargs["name"]
err = validateName(newName, "session")
if err != nil {
return nil, err
}
err = sstore.SetSessionName(ctx, ids.SessionId, newName)
if err != nil {
return nil, fmt.Errorf("setting session name: %v", err)
}
varsUpdated = append(varsUpdated, "name")
}
if pk.Kwargs["pos"] != "" {
}
if len(varsUpdated) == 0 {
return nil, fmt.Errorf("/session:set no updates, can set %s", formatStrs([]string{"name", "pos"}, "or", false))
}
bareSession, err = sstore.GetBareSessionById(ctx, ids.SessionId)
update := sstore.ModelUpdate{
Sessions: nil,
Sessions: []*sstore.SessionType{bareSession},
Info: &sstore.InfoMsgType{
InfoMsg: fmt.Sprintf("[%s]: update", bareSession.Name),
InfoMsg: fmt.Sprintf("[%s]: updated %s", bareSession.Name, formatStrs(varsUpdated, "and", false)),
TimeoutMs: 2000,
},
}
@ -915,13 +1070,24 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
}
firstArg := firstArg(pk)
if firstArg == "" {
return nil, fmt.Errorf("usage /session [session-name|session-id], no param specified")
return nil, fmt.Errorf("usage /session [name|id|pos], no param specified")
}
sessionId, err := resolveSession(ctx, firstArg)
sessionId, err := resolveSession(ctx, firstArg, pk.Kwargs["session"], nil)
if err != nil {
return nil, err
}
return sstore.ModelUpdate{ActiveSessionId: sessionId}, nil
bareSession, err := sstore.GetSessionById(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("could not find session '%s': %v", sessionId, err)
}
update := sstore.ModelUpdate{
ActiveSessionId: sessionId,
Info: &sstore.InfoMsgType{
InfoMsg: fmt.Sprintf("switched to session %q", bareSession.Name),
TimeoutMs: 2000,
},
}
return update, nil
}
func splitLinesForInfo(str string) []string {

View File

@ -132,7 +132,7 @@ func GetSessionHistoryItems(ctx context.Context, sessionId string, maxItems int)
func GetBareSessions(ctx context.Context) ([]*SessionType, error) {
var rtn []*SessionType
err := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT * FROM session`
query := `SELECT * FROM session ORDER BY sessionidx`
tx.SelectWrap(&rtn, query)
return nil
})
@ -142,6 +142,19 @@ func GetBareSessions(ctx context.Context) ([]*SessionType, error) {
return rtn, nil
}
func GetAllSessionIds(ctx context.Context) ([]string, error) {
var rtn []string
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT sessionid from session ORDER by sessionidx`
rtn = tx.SelectStrings(query)
return nil
})
if txErr != nil {
return nil, txErr
}
return rtn, nil
}
func GetBareSessionById(ctx context.Context, sessionId string) (*SessionType, error) {
var rtn SessionType
txErr := WithTx(ctx, func(tx *TxWrap) error {
@ -158,9 +171,10 @@ func GetBareSessionById(ctx context.Context, sessionId string) (*SessionType, er
return &rtn, nil
}
func GetAllSessions(ctx context.Context) ([]*SessionType, error) {
func GetAllSessions(ctx context.Context) (*ModelUpdate, error) {
var rtn []*SessionType
err := WithTx(ctx, func(tx *TxWrap) error {
var activeSessionId string
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT * FROM session`
tx.SelectWrap(&rtn, query)
sessionMap := make(map[string]*SessionType)
@ -203,9 +217,14 @@ func GetAllSessions(ctx context.Context) ([]*SessionType, error) {
s.Remotes = append(s.Remotes, ri)
}
}
query = `SELECT activesessionid FROM client`
activeSessionId = tx.GetString(query)
return nil
})
return rtn, err
if txErr != nil {
return nil, txErr
}
return &ModelUpdate{Sessions: rtn, ActiveSessionId: activeSessionId}, nil
}
func GetWindowById(ctx context.Context, sessionId string, windowId string) (*WindowType, error) {
@ -240,10 +259,11 @@ func GetSessionScreens(ctx context.Context, sessionId string) ([]*ScreenType, er
}
func GetSessionById(ctx context.Context, id string) (*SessionType, error) {
allSessions, err := GetAllSessions(ctx)
allSessionsUpdate, err := GetAllSessions(ctx)
if err != nil {
return nil, err
}
allSessions := allSessionsUpdate.Sessions
for _, session := range allSessions {
if session.SessionId == id {
return session, nil
@ -283,6 +303,10 @@ func InsertSessionWithName(ctx context.Context, sessionName string, activate boo
if err != nil {
return err
}
if activate {
query = `UPDATE client SET activesessionid = ?`
tx.ExecWrap(query, newSessionId)
}
return nil
})
if txErr != nil {
@ -677,12 +701,12 @@ func reorderStrings(strs []string, toMove string, newIndex int) []string {
func ReIndexSessions(ctx context.Context, sessionId string, newIndex int) error {
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT sessionid FROM sessions ORDER BY sessionidx, name, sessionid`
query := `SELECT sessionid FROM session ORDER BY sessionidx, name, sessionid`
ids := tx.SelectStrings(query)
if sessionId != "" {
ids = reorderStrings(ids, sessionId, newIndex)
}
query = `UPDATE sessions SET sessionid = ? WHERE sessionid = ?`
query = `UPDATE session SET sessionid = ? WHERE sessionid = ?`
for idx, id := range ids {
tx.ExecWrap(query, id, idx+1)
}
@ -693,11 +717,11 @@ func ReIndexSessions(ctx context.Context, sessionId string, newIndex int) error
func SetSessionName(ctx context.Context, sessionId string, name string) error {
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT sessionid FROM sessions WHERE sessionid = ?`
query := `SELECT sessionid FROM session WHERE sessionid = ?`
if !tx.Exists(query, sessionId) {
return fmt.Errorf("session does not exist")
}
query = `UPDATE sessions SET name = ? WHERE sessionid = ?`
query = `UPDATE session SET name = ? WHERE sessionid = ?`
tx.ExecWrap(query, name, sessionId)
return nil
})

View File

@ -86,12 +86,13 @@ func GetDB(ctx context.Context) (*sqlx.DB, error) {
return globalDB, globalDBErr
}
type UserData struct {
type ClientData struct {
UserId string `json:"userid"`
UserPrivateKeyBytes []byte `json:"-"`
UserPublicKeyBytes []byte `json:"-"`
UserPrivateKey *ecdsa.PrivateKey
UserPublicKey *ecdsa.PublicKey
ActiveSessionId string `json:"activesessionid"`
}
type SessionType struct {
@ -639,7 +640,7 @@ func EnsureDefaultSession(ctx context.Context) (*SessionType, error) {
return GetSessionByName(ctx, DefaultSessionName)
}
func createUserData(tx *TxWrap) error {
func createClientData(tx *TxWrap) error {
userId := uuid.New().String()
curve := elliptic.P384()
pkey, err := ecdsa.GenerateKey(curve, rand.Reader)
@ -654,14 +655,14 @@ func createUserData(tx *TxWrap) error {
if err != nil {
return fmt.Errorf("marshaling (pkix) public key bytes: %w", err)
}
query := `INSERT INTO client (userid, userpublickeybytes, userprivatekeybytes) VALUES (?, ?, ?)`
query := `INSERT INTO client (userid, activesessionid, userpublickeybytes, userprivatekeybytes) VALUES (?, '', ?, ?)`
tx.ExecWrap(query, userId, pubBytes, pkBytes)
fmt.Printf("create new userid[%s] with public/private keypair\n", userId)
return nil
}
func EnsureUserData(ctx context.Context) (*UserData, error) {
var rtn UserData
func EnsureClientData(ctx context.Context) (*ClientData, error) {
var rtn ClientData
err := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT count(*) FROM client`
count := tx.GetInt(query)
@ -669,7 +670,7 @@ func EnsureUserData(ctx context.Context) (*UserData, error) {
return fmt.Errorf("invalid client database, multiple (%d) rows in client table", count)
}
if count == 0 {
createErr := createUserData(tx)
createErr := createClientData(tx)
if createErr != nil {
return createErr
}