diff --git a/db/migrations/000008_cloudsession.down.sql b/db/migrations/000008_cloudsession.down.sql index afb70c9b8..bf15963a7 100644 --- a/db/migrations/000008_cloudsession.down.sql +++ b/db/migrations/000008_cloudsession.down.sql @@ -2,3 +2,4 @@ ALTER TABLE session ADD COLUMN accesskey DEFAULT ''; ALTER TABLE session ADD COLUMN ownerid DEFAULT ''; DROP TABLE cloud_session; +DROP TABLE cloud_update; diff --git a/db/migrations/000008_cloudsession.up.sql b/db/migrations/000008_cloudsession.up.sql index 4fba64645..860e61b05 100644 --- a/db/migrations/000008_cloudsession.up.sql +++ b/db/migrations/000008_cloudsession.up.sql @@ -11,3 +11,10 @@ CREATE TABLE cloud_session ( acl json NOT NULL ); +CREATE TABLE cloud_update ( + updateid varchar(36) PRIMARY KEY, + ts bigint NOT NULL, + updatetype varchar(50) NOT NULL, + updatekeys json NOT NULL +); + diff --git a/pkg/cmdrunner/cmdrunner.go b/pkg/cmdrunner/cmdrunner.go index a63f6757c..68e4485cf 100644 --- a/pkg/cmdrunner/cmdrunner.go +++ b/pkg/cmdrunner/cmdrunner.go @@ -128,6 +128,7 @@ func init() { registerCmdFn("session:showall", SessionShowAllCommand) registerCmdFn("session:show", SessionShowCommand) registerCmdFn("session:openshared", SessionOpenSharedCommand) + registerCmdFn("session:opencloud", SessionOpenCloudCommand) registerCmdFn("screen", ScreenCommand) registerCmdFn("screen:archive", ScreenArchiveCommand) @@ -1546,6 +1547,22 @@ func SessionOpenSharedCommand(ctx context.Context, pk *scpacket.FeCommandPacketT return nil, fmt.Errorf("shared sessions are not available in this version of prompt (stay tuned)") } +func SessionOpenCloudCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) { + activate := resolveBool(pk.Kwargs["activate"], true) + newName := pk.Kwargs["name"] + if newName != "" { + err := validateName(newName, "session") + if err != nil { + return nil, err + } + } + update, err := sstore.InsertSessionWithName(ctx, newName, sstore.ShareModeShared, activate) + if err != nil { + return nil, err + } + return update, nil +} + func SessionOpenCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) { activate := resolveBool(pk.Kwargs["activate"], true) newName := pk.Kwargs["name"] @@ -1555,7 +1572,7 @@ func SessionOpenCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) ( return nil, err } } - update, err := sstore.InsertSessionWithName(ctx, newName, activate) + update, err := sstore.InsertSessionWithName(ctx, newName, sstore.ShareModeLocal, activate) if err != nil { return nil, err } diff --git a/pkg/pcloud/pcloud.go b/pkg/pcloud/pcloud.go index 3ca08b8fb..388433bd2 100644 --- a/pkg/pcloud/pcloud.go +++ b/pkg/pcloud/pcloud.go @@ -199,7 +199,6 @@ func CreateCloudSession(ctx context.Context) error { if err != nil { return err } - fmt.Printf("authinfo: %v\n", authInfo) req, err := makeAuthPostReq(ctx, CreateCloudSessionUrl, authInfo, nil) if err != nil { return err diff --git a/pkg/promptenc/promptenc.go b/pkg/promptenc/promptenc.go index bd74216da..e5ac25b3d 100644 --- a/pkg/promptenc/promptenc.go +++ b/pkg/promptenc/promptenc.go @@ -44,6 +44,7 @@ func MakeRandomEncryptor() (*Encryptor, error) { } func MakeEncryptor(key []byte) (*Encryptor, error) { + var err error rtn := &Encryptor{Key: key} rtn.AEAD, err = ccp.NewX(rtn.Key) if err != nil { diff --git a/pkg/sstore/dbops.go b/pkg/sstore/dbops.go index 30b55a4e3..2135b7f24 100644 --- a/pkg/sstore/dbops.go +++ b/pkg/sstore/dbops.go @@ -65,7 +65,7 @@ func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) { query := `SELECT * FROM remote ORDER BY remoteidx` marr := tx.SelectMaps(query) for _, m := range marr { - rtn = append(rtn, RemoteFromMap(m)) + rtn = append(rtn, FromMap[*RemoteType](m)) } return nil }) @@ -80,7 +80,7 @@ func GetRemoteByAlias(ctx context.Context, alias string) (*RemoteType, error) { err := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote WHERE remotealias = ?` m := tx.GetMap(query, alias) - remote = RemoteFromMap(m) + remote = FromMap[*RemoteType](m) return nil }) if err != nil { @@ -94,7 +94,7 @@ func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) { err := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote WHERE remoteid = ?` m := tx.GetMap(query, remoteId) - remote = RemoteFromMap(m) + remote = FromMap[*RemoteType](m) return nil }) if err != nil { @@ -108,7 +108,7 @@ func GetLocalRemote(ctx context.Context) (*RemoteType, error) { err := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote WHERE local` m := tx.GetMap(query) - remote = RemoteFromMap(m) + remote = FromMap[*RemoteType](m) return nil }) if err != nil { @@ -121,8 +121,7 @@ func GetRemoteByCanonicalName(ctx context.Context, cname string) (*RemoteType, e var remote *RemoteType err := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote WHERE remotecanonicalname = ?` - m := tx.GetMap(query, cname) - remote = RemoteFromMap(m) + remote = GetMapGen[*RemoteType](tx, query, cname) return nil }) if err != nil { @@ -135,8 +134,7 @@ func GetRemoteByPhysicalId(ctx context.Context, physicalId string) (*RemoteType, var remote *RemoteType err := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote WHERE physicalid = ?` - m := tx.GetMap(query, physicalId) - remote = RemoteFromMap(m) + remote = GetMapGen[*RemoteType](tx, query, physicalId) return nil }) if err != nil { @@ -327,7 +325,7 @@ func runHistoryQuery(tx *TxWrap, opts HistoryQueryOpts, realOffset int, itemLimi marr := tx.SelectMaps(query, queryArgs...) rtn := make([]*HistoryItemType, len(marr)) for idx, m := range marr { - hitem := HistoryItemFromMap(m) + hitem := FromMap[*HistoryItemType](m) rtn[idx] = hitem } return rtn, nil @@ -434,9 +432,8 @@ func GetAllSessions(ctx context.Context) (*ModelUpdate, error) { screen.Windows = append(screen.Windows, sw) } query = `SELECT * FROM remote_instance` - riMaps := tx.SelectMaps(query) - for _, m := range riMaps { - ri := RIFromMap(m) + riArr := SelectMapsGen[*RemoteInstance](tx, query) + for _, ri := range riArr { s := sessionMap[ri.SessionId] if s != nil { s.Remotes = append(s.Remotes, ri) @@ -460,14 +457,11 @@ func GetWindowById(ctx context.Context, sessionId string, windowId string) (*Win if m == nil { return nil } - rtnWindow = WindowFromMap(m) + rtnWindow = FromMap[*WindowType](m) query = `SELECT * FROM line WHERE sessionid = ? AND windowid = ? ORDER BY linenum` tx.Select(&rtnWindow.Lines, query, sessionId, windowId) query = `SELECT * FROM cmd WHERE cmdid IN (SELECT cmdid FROM line WHERE sessionid = ? AND windowid = ?)` - cmdMaps := tx.SelectMaps(query, sessionId, windowId) - for _, m := range cmdMaps { - rtnWindow.Cmds = append(rtnWindow.Cmds, CmdFromMap(m)) - } + rtnWindow.Cmds = SelectMapsGen[*CmdType](tx, query, sessionId, windowId) return nil }) return rtnWindow, err @@ -519,9 +513,27 @@ func GetSessionByName(ctx context.Context, name string) (*SessionType, error) { return session, nil } +func InsertCloudSession(ctx context.Context, sessionName string, shareMode string, activate bool) (*ModelUpdate, error) { + var updateRtn *ModelUpdate + txErr := WithTx(ctx, func(tx *TxWrap) error { + var err error + updateRtn, err = InsertSessionWithName(tx.Context(), sessionName, shareMode, activate) + if err != nil { + return err + } + sessionId := updateRtn.Sessions[0].SessionId + fmt.Printf("sessionid: %v\n", sessionId) + return nil + }) + if txErr != nil { + return nil, txErr + } + return updateRtn, nil +} + // also creates default window, returns sessionId // if sessionName == "", it will be generated -func InsertSessionWithName(ctx context.Context, sessionName string, activate bool) (UpdatePacket, error) { +func InsertSessionWithName(ctx context.Context, sessionName string, shareMode string, activate bool) (*ModelUpdate, error) { newSessionId := scbase.GenPromptUUID() txErr := WithTx(ctx, func(tx *TxWrap) error { names := tx.SelectStrings(`SELECT name FROM session`) @@ -553,7 +565,7 @@ func InsertSessionWithName(ctx context.Context, sessionName string, activate boo if activate { update.ActiveSessionId = newSessionId } - return update, nil + return &update, nil } func SetActiveSessionId(ctx context.Context, sessionId string) error { @@ -757,8 +769,7 @@ func GetLineCmdByLineId(ctx context.Context, sessionId string, windowId string, lineRtn = &lineVal if lineVal.CmdId != "" { query = `SELECT * FROM cmd WHERE sessionid = ? AND cmdid = ?` - m := tx.GetMap(query, sessionId, lineVal.CmdId) - cmdRtn = CmdFromMap(m) + cmdRtn = GetMapGen[*CmdType](tx, query, sessionId, lineVal.CmdId) } return nil }) @@ -784,8 +795,7 @@ func GetLineCmdByCmdId(ctx context.Context, sessionId string, windowId string, c } lineRtn = &lineVal query = `SELECT * FROM cmd WHERE sessionid = ? AND cmdid = ?` - m := tx.GetMap(query, sessionId, cmdId) - cmdRtn = CmdFromMap(m) + cmdRtn = GetMapGen[*CmdType](tx, query, sessionId, cmdId) return nil }) if txErr != nil { @@ -834,8 +844,7 @@ func GetCmdById(ctx context.Context, sessionId string, cmdId string) (*CmdType, var cmd *CmdType err := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM cmd WHERE sessionid = ? AND cmdid = ?` - m := tx.GetMap(query, sessionId, cmdId) - cmd = CmdFromMap(m) + cmd = GetMapGen[*CmdType](tx, query, sessionId, cmdId) return nil }) if err != nil { @@ -1176,8 +1185,7 @@ func GetRemoteInstance(ctx context.Context, sessionId string, windowId string, r var ri *RemoteInstance txErr := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote_instance WHERE sessionid = ? AND windowid = ? AND remoteownerid = ? AND remoteid = ? AND name = ?` - m := tx.GetMap(query, sessionId, windowId, remotePtr.OwnerId, remotePtr.RemoteId, remotePtr.Name) - ri = RIFromMap(m) + ri = GetMapGen[*RemoteInstance](tx, query, sessionId, windowId, remotePtr.OwnerId, remotePtr.RemoteId, remotePtr.Name) return nil }) if txErr != nil { @@ -1223,8 +1231,7 @@ func UpdateRemoteState(ctx context.Context, sessionId string, windowId string, r return fmt.Errorf("cannot update remote instance state: %w", err) } query := `SELECT * FROM remote_instance WHERE sessionid = ? AND windowid = ? AND remoteownerid = ? AND remoteid = ? AND name = ?` - m := tx.GetMap(query, sessionId, windowId, remotePtr.OwnerId, remotePtr.RemoteId, remotePtr.Name) - ri = RIFromMap(m) + ri = GetMapGen[*RemoteInstance](tx, query, sessionId, windowId, remotePtr.OwnerId, remotePtr.RemoteId, remotePtr.Name) if ri == nil { ri = &RemoteInstance{ RIId: scbase.GenPromptUUID(), @@ -1423,10 +1430,7 @@ func GetRunningWindowCmds(ctx context.Context, sessionId string, windowId string var rtn []*CmdType txErr := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * from cmd WHERE cmdid IN (SELECT cmdid FROM line WHERE sessionid = ? AND windowid = ?) AND status = ?` - cmdMaps := tx.SelectMaps(query, sessionId, windowId, CmdStatusRunning) - for _, m := range cmdMaps { - rtn = append(rtn, CmdFromMap(m)) - } + rtn = SelectMapsGen[*CmdType](tx, query, sessionId, windowId, CmdStatusRunning) return nil }) if txErr != nil { @@ -1870,8 +1874,7 @@ func GetFullState(ctx context.Context, ssPtr ShellStatePtr) (*packet.ShellState, } for idx, diffHash := range ssPtr.DiffHashArr { query = `SELECT * FROM state_diff WHERE diffhash = ?` - m := tx.GetMap(query, diffHash) - stateDiff := StateDiffFromMap(m) + stateDiff := GetMapGen[*StateDiff](tx, query, diffHash) if stateDiff == nil { return fmt.Errorf("ShellStateDiff %s not found", diffHash) } @@ -1986,13 +1989,7 @@ func GetRIsForWindow(ctx context.Context, sessionId string, windowId string) ([] var rtn []*RemoteInstance txErr := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM remote_instance WHERE sessionid = ? AND (windowid = '' OR windowid = ?)` - riMaps := tx.SelectMaps(query, sessionId, windowId) - for _, m := range riMaps { - ri := RIFromMap(m) - if ri != nil { - rtn = append(rtn, ri) - } - } + rtn = SelectMapsGen[*RemoteInstance](tx, query, sessionId, windowId) return nil }) if txErr != nil { @@ -2155,20 +2152,14 @@ func GetBookmarks(ctx context.Context, tag string) ([]*BookmarkType, error) { var bms []*BookmarkType txErr := WithTx(ctx, func(tx *TxWrap) error { var query string - var marr []map[string]interface{} if tag == "" { query = `SELECT * FROM bookmark` - marr = tx.SelectMaps(query) + bms = SelectMapsGen[*BookmarkType](tx, query) } else { query = `SELECT * FROM bookmark WHERE EXISTS (SELECT 1 FROM json_each(tags) WHERE value = ?)` - marr = tx.SelectMaps(query, tag) - } - bmMap := make(map[string]*BookmarkType) - for _, m := range marr { - bm := BookmarkFromMap(m) - bms = append(bms, bm) - bmMap[bm.BookmarkId] = bm + bms = SelectMapsGen[*BookmarkType](tx, query, tag) } + bmMap := MakeGenMap(bms) var orders []bookmarkOrderType query = `SELECT bookmarkid, orderidx FROM bookmark_order WHERE tag = ?` tx.Select(&orders, query, tag) @@ -2199,8 +2190,7 @@ func GetBookmarkById(ctx context.Context, bookmarkId string, tag string) (*Bookm var rtn *BookmarkType txErr := WithTx(ctx, func(tx *TxWrap) error { query := `SELECT * FROM bookmark WHERE bookmarkid = ?` - m := tx.GetMap(query, bookmarkId) - rtn = BookmarkFromMap(m) + rtn = GetMapGen[*BookmarkType](tx, query, bookmarkId) if rtn == nil { return nil } @@ -2329,30 +2319,24 @@ func DeleteBookmark(ctx context.Context, bookmarkId string) error { } func CreatePlaybook(ctx context.Context, name string) (*PlaybookType, error) { - var rtn *PlaybookType - txErr := WithTx(ctx, func(tx *TxWrap) error { + return WithTxRtn(ctx, func(tx *TxWrap) (*PlaybookType, error) { query := `SELECT playbookid FROM playbook WHERE name = ?` if tx.Exists(query, name) { - return fmt.Errorf("playbook %q already exists", name) + return nil, fmt.Errorf("playbook %q already exists", name) } - rtn = &PlaybookType{} + rtn := &PlaybookType{} rtn.PlaybookId = uuid.New().String() rtn.PlaybookName = name query = `INSERT INTO playbook ( playbookid, playbookname, description, entryids) VALUES (:playbookid,:playbookname,:description,:entryids)` tx.Exec(query, rtn.ToMap()) - return nil + return rtn, nil }) - if txErr != nil { - return nil, txErr - } - return rtn, nil } func selectPlaybook(tx *TxWrap, playbookId string) *PlaybookType { query := `SELECT * FROM playbook where playbookid = ?` - m := tx.GetMap(query, playbookId) - playbook := PlaybookFromMap(m) + playbook := GetMapGen[*PlaybookType](tx, query, playbookId) return playbook } @@ -2360,7 +2344,7 @@ func AddPlaybookEntry(ctx context.Context, entry *PlaybookEntry) error { if entry.EntryId == "" { return fmt.Errorf("invalid entryid") } - txErr := WithTx(ctx, func(tx *TxWrap) error { + return WithTx(ctx, func(tx *TxWrap) error { playbook := selectPlaybook(tx, entry.PlaybookId) if playbook == nil { return fmt.Errorf("cannot add entry, playbook does not exist") @@ -2377,11 +2361,10 @@ func AddPlaybookEntry(ctx context.Context, entry *PlaybookEntry) error { tx.Exec(query, quickJsonArr(playbook.EntryIds), entry.PlaybookId) return nil }) - return txErr } func RemovePlaybookEntry(ctx context.Context, playbookId string, entryId string) error { - txErr := WithTx(ctx, func(tx *TxWrap) error { + return WithTx(ctx, func(tx *TxWrap) error { playbook := selectPlaybook(tx, playbookId) if playbook == nil { return fmt.Errorf("cannot remove playbook entry, playbook does not exist") @@ -2397,25 +2380,19 @@ func RemovePlaybookEntry(ctx context.Context, playbookId string, entryId string) tx.Exec(query, quickJsonArr(playbook.EntryIds), playbookId) return nil }) - return txErr } func GetPlaybookById(ctx context.Context, playbookId string) (*PlaybookType, error) { - var rtn *PlaybookType - txErr := WithTx(ctx, func(tx *TxWrap) error { - rtn = selectPlaybook(tx, playbookId) + return WithTxRtn(ctx, func(tx *TxWrap) (*PlaybookType, error) { + rtn := selectPlaybook(tx, playbookId) if rtn == nil { - return nil + return nil, nil } query := `SELECT * FROM playbook_entry WHERE playbookid = ?` tx.Select(&rtn.Entries, query, playbookId) rtn.OrderEntries() - return nil + return rtn, nil }) - if txErr != nil { - return nil, txErr - } - return rtn, nil } func getLineIdsFromHistoryItems(historyItems []*HistoryItemType) []string { @@ -2439,55 +2416,33 @@ func getCmdIdsFromHistoryItems(historyItems []*HistoryItemType) []string { } func GetLineCmdsFromHistoryItems(ctx context.Context, historyItems []*HistoryItemType) ([]*LineType, []*CmdType, error) { - var lineArr []*LineType - var cmdArr []*CmdType if len(historyItems) == 0 { return nil, nil, nil } - txErr := WithTx(ctx, func(tx *TxWrap) error { + return WithTxRtn3(ctx, func(tx *TxWrap) ([]*LineType, []*CmdType, error) { + var lineArr []*LineType query := `SELECT * FROM line WHERE lineid IN (SELECT value FROM json_each(?))` tx.Select(&lineArr, query, quickJsonArr(getLineIdsFromHistoryItems(historyItems))) query = `SELECT * FROM cmd WHERE cmdid IN (SELECT value FROM json_each(?))` - marr := tx.SelectMaps(query, quickJsonArr(getCmdIdsFromHistoryItems(historyItems))) - for _, m := range marr { - cmd := CmdFromMap(m) - if cmd != nil { - cmdArr = append(cmdArr, cmd) - } - } - return nil + cmdArr := SelectMapsGen[*CmdType](tx, query, quickJsonArr(getCmdIdsFromHistoryItems(historyItems))) + return lineArr, cmdArr, nil }) - if txErr != nil { - return nil, nil, txErr - } - return lineArr, cmdArr, nil } func PurgeHistoryByIds(ctx context.Context, historyIds []string) ([]*HistoryItemType, error) { - var rtn []*HistoryItemType - txErr := WithTx(ctx, func(tx *TxWrap) error { + return WithTxRtn(ctx, func(tx *TxWrap) ([]*HistoryItemType, error) { query := `SELECT * FROM history WHERE historyid IN (SELECT value FROM json_each(?))` - marr := tx.SelectMaps(query, quickJsonArr(historyIds)) - for _, m := range marr { - hitem := HistoryItemFromMap(m) - if hitem != nil { - rtn = append(rtn, hitem) - } - } + rtn := SelectMapsGen[*HistoryItemType](tx, query, quickJsonArr(historyIds)) query = `DELETE FROM history WHERE historyid IN (SELECT value FROM json_each(?))` tx.Exec(query, quickJsonArr(historyIds)) for _, hitem := range rtn { if hitem.LineId != "" { err := PurgeLinesByIds(tx.Context(), hitem.SessionId, []string{hitem.LineId}) if err != nil { - return err + return nil, err } } } - return nil + return rtn, nil }) - if txErr != nil { - return nil, txErr - } - return rtn, nil } diff --git a/pkg/sstore/map.go b/pkg/sstore/map.go new file mode 100644 index 000000000..136445dc4 --- /dev/null +++ b/pkg/sstore/map.go @@ -0,0 +1,200 @@ +package sstore + +import ( + "context" + "fmt" + "reflect" + "strings" +) + +type DBMappable interface { + UseDBMap() +} + +type MapConverter interface { + ToMap() map[string]interface{} + FromMap(map[string]interface{}) bool +} + +type HasSimpleKey interface { + GetSimpleKey() string +} + +type MapConverterPtr[T any] interface { + MapConverter + *T +} + +type DBMappablePtr[T any] interface { + DBMappable + *T +} + +func FromMap[PT MapConverterPtr[T], T any](m map[string]any) PT { + if len(m) == 0 { + return nil + } + rtn := PT(new(T)) + ok := rtn.FromMap(m) + if !ok { + return nil + } + return rtn +} + +func GetMapGen[PT MapConverterPtr[T], T any](tx *TxWrap, query string, args ...interface{}) PT { + m := tx.GetMap(query, args...) + return FromMap[PT](m) +} + +func GetMappable[PT DBMappablePtr[T], T any](tx *TxWrap, query string, args ...interface{}) PT { + rtn := PT(new(T)) + m := tx.GetMap(query, args...) + if len(m) == 0 { + return nil + } + FromDBMap(rtn, m) + return rtn +} + +func SelectMapsGen[PT MapConverterPtr[T], T any](tx *TxWrap, query string, args ...interface{}) []PT { + var rtn []PT + marr := tx.SelectMaps(query, args...) + for _, m := range marr { + val := FromMap[PT](m) + if val != nil { + rtn = append(rtn, val) + } + } + return rtn +} + +func MakeGenMap[T HasSimpleKey](arr []T) map[string]T { + rtn := make(map[string]T) + for _, val := range arr { + rtn[val.GetSimpleKey()] = val + } + return rtn +} + +func WithTxRtn[RT any](ctx context.Context, fn func(tx *TxWrap) (RT, error)) (RT, error) { + var rtn RT + txErr := WithTx(ctx, func(tx *TxWrap) error { + temp, err := fn(tx) + if err != nil { + return err + } + rtn = temp + return nil + }) + return rtn, txErr +} + +func WithTxRtn3[RT1 any, RT2 any](ctx context.Context, fn func(tx *TxWrap) (RT1, RT2, error)) (RT1, RT2, error) { + var rtn1 RT1 + var rtn2 RT2 + txErr := WithTx(ctx, func(tx *TxWrap) error { + temp1, temp2, err := fn(tx) + if err != nil { + return err + } + rtn1 = temp1 + rtn2 = temp2 + return nil + }) + return rtn1, rtn2, txErr +} + +func isStructType(rt reflect.Type) bool { + if rt.Kind() == reflect.Struct { + return true + } + if rt.Kind() == reflect.Pointer && rt.Elem().Kind() == reflect.Struct { + return true + } + return false +} + +func isByteArrayType(t reflect.Type) bool { + return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 +} + +func ToDBMap(v DBMappable) map[string]interface{} { + if v == nil { + return nil + } + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructToDBMap", v)) + } + rt := rv.Type() + m := make(map[string]interface{}) + numFields := rt.NumField() + for i := 0; i < numFields; i++ { + field := rt.Field(i) + fieldVal := rv.FieldByIndex(field.Index) + dbName := field.Tag.Get("dbmap") + if dbName == "" { + dbName = strings.ToLower(field.Name) + } + if dbName == "-" { + continue + } + if field.Type.Kind() == reflect.Slice { + m[dbName] = quickJsonArr(fieldVal.Interface()) + } else if isStructType(field.Type) { + m[dbName] = quickJson(fieldVal.Interface()) + } else { + m[dbName] = fieldVal.Interface() + } + } + return m +} + +func FromDBMap(v DBMappable, m map[string]interface{}) { + if v == nil { + panic("StructFromDBMap, v cannot be nil") + } + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructFromDBMap", v)) + } + rt := rv.Type() + numFields := rt.NumField() + for i := 0; i < numFields; i++ { + field := rt.Field(i) + fieldVal := rv.FieldByIndex(field.Index) + dbName := field.Tag.Get("dbmap") + if dbName == "" { + dbName = strings.ToLower(field.Name) + } + if dbName == "-" { + continue + } + if isByteArrayType(field.Type) { + barrVal := fieldVal.Addr().Interface() + quickSetBytes(barrVal.(*[]byte), m, dbName) + } else if field.Type.Kind() == reflect.Slice { + quickSetJsonArr(fieldVal.Addr().Interface(), m, dbName) + } else if isStructType(field.Type) { + quickSetJson(fieldVal.Addr().Interface(), m, dbName) + } else if field.Type.Kind() == reflect.String { + strVal := fieldVal.Addr().Interface() + quickSetStr(strVal.(*string), m, dbName) + } else if field.Type.Kind() == reflect.Int64 { + intVal := fieldVal.Addr().Interface() + quickSetInt64(intVal.(*int64), m, dbName) + } else if field.Type.Kind() == reflect.Bool { + boolVal := fieldVal.Addr().Interface() + quickSetBool(boolVal.(*bool), m, dbName) + } else { + panic(fmt.Sprintf("StructFromDBMap invalid field type %v in %T", fieldVal.Type(), v)) + } + } +} diff --git a/pkg/sstore/sstore.go b/pkg/sstore/sstore.go index 45e39db50..76d056633 100644 --- a/pkg/sstore/sstore.go +++ b/pkg/sstore/sstore.go @@ -175,42 +175,15 @@ type ClientData struct { UserId string `json:"userid"` UserPrivateKeyBytes []byte `json:"-"` UserPublicKeyBytes []byte `json:"-"` - UserPrivateKey *ecdsa.PrivateKey `json:"-"` - UserPublicKey *ecdsa.PublicKey `json:"-"` + UserPrivateKey *ecdsa.PrivateKey `json:"-" dbmap:"-"` + UserPublicKey *ecdsa.PublicKey `json:"-" dbmap:"-"` ActiveSessionId string `json:"activesessionid"` WinSize ClientWinSizeType `json:"winsize"` ClientOpts ClientOptsType `json:"clientopts"` FeOpts FeOptsType `json:"feopts"` } -func (c *ClientData) ToMap() map[string]interface{} { - rtn := make(map[string]interface{}) - rtn["clientid"] = c.ClientId - rtn["userid"] = c.UserId - rtn["userprivatekeybytes"] = c.UserPrivateKeyBytes - rtn["userpublickeybytes"] = c.UserPublicKeyBytes - rtn["activesessionid"] = c.ActiveSessionId - rtn["winsize"] = quickJson(c.WinSize) - rtn["clientopts"] = quickJson(c.ClientOpts) - rtn["feopts"] = quickJson(c.FeOpts) - return rtn -} - -func ClientDataFromMap(m map[string]interface{}) *ClientData { - if len(m) == 0 { - return nil - } - var c ClientData - quickSetStr(&c.ClientId, m, "clientid") - quickSetStr(&c.UserId, m, "userid") - quickSetBytes(&c.UserPrivateKeyBytes, m, "userprivatekeybytes") - quickSetBytes(&c.UserPublicKeyBytes, m, "userpublickeybytes") - quickSetStr(&c.ActiveSessionId, m, "activesessionid") - quickSetJson(&c.WinSize, m, "winsize") - quickSetJson(&c.ClientOpts, m, "clientopts") - quickSetJson(&c.FeOpts, m, "feopts") - return &c -} +func (c ClientData) UseDBMap() {} type CloudAclType struct { UserId string `json:"userid"` @@ -244,6 +217,36 @@ type CloudSessionType struct { Acl []*CloudAclType } +func (cs *CloudSessionType) ToMap() map[string]any { + m := make(map[string]any) + m["sessionid"] = cs.SessionId + m["viewkey"] = cs.ViewKey + m["writekey"] = cs.WriteKey + m["enckey"] = cs.EncKey + m["enctype"] = cs.EncType + m["vts"] = cs.Vts + m["acl"] = quickJsonArr(cs.Acl) + return m +} + +func (cs *CloudSessionType) FromMap(m map[string]interface{}) bool { + quickSetStr(&cs.SessionId, m, "sessionid") + quickSetStr(&cs.ViewKey, m, "viewkey") + quickSetStr(&cs.WriteKey, m, "writekey") + quickSetStr(&cs.EncKey, m, "enckey") + quickSetStr(&cs.EncType, m, "enctype") + quickSetInt64(&cs.Vts, m, "vts") + quickSetJsonArr(&cs.Acl, m, "acl") + return true +} + +type CloudUpdate struct { + UpdateId string + Ts int64 + UpdateType string + UpdateKeys []string +} + type SessionStatsType struct { SessionId string `json:"sessionid"` NumScreens int `json:"numscreens"` @@ -370,11 +373,7 @@ func (w *WindowType) ToMap() map[string]interface{} { return rtn } -func WindowFromMap(m map[string]interface{}) *WindowType { - if len(m) == 0 { - return nil - } - var w WindowType +func (w *WindowType) FromMap(m map[string]interface{}) bool { quickSetStr(&w.SessionId, m, "sessionid") quickSetStr(&w.WindowId, m, "windowid") quickSetStr(&w.CurRemote.OwnerId, m, "curremoteownerid") @@ -385,7 +384,7 @@ func WindowFromMap(m map[string]interface{}) *WindowType { quickSetStr(&w.OwnerId, m, "ownerid") quickSetStr(&w.ShareMode, m, "sharemode") quickSetJson(&w.ShareOpts, m, "shareopts") - return &w + return true } func (h *HistoryItemType) ToMap() map[string]interface{} { @@ -408,11 +407,7 @@ func (h *HistoryItemType) ToMap() map[string]interface{} { return rtn } -func HistoryItemFromMap(m map[string]interface{}) *HistoryItemType { - if len(m) == 0 { - return nil - } - var h HistoryItemType +func (h *HistoryItemType) FromMap(m map[string]interface{}) bool { quickSetStr(&h.HistoryId, m, "historyid") quickSetInt64(&h.Ts, m, "ts") quickSetStr(&h.UserId, m, "userid") @@ -429,7 +424,7 @@ func HistoryItemFromMap(m map[string]interface{}) *HistoryItemType { quickSetBool(&h.IsMetaCmd, m, "ismetacmd") quickSetStr(&h.HistoryNum, m, "historynum") quickSetBool(&h.Incognito, m, "incognito") - return &h + return true } type ScreenOptsType struct { @@ -624,17 +619,13 @@ type StateDiff struct { Data []byte } -func StateDiffFromMap(m map[string]interface{}) *StateDiff { - if len(m) == 0 { - return nil - } - var sd StateDiff +func (sd *StateDiff) FromMap(m map[string]interface{}) bool { quickSetStr(&sd.DiffHash, m, "diffhash") quickSetInt64(&sd.Ts, m, "ts") quickSetStr(&sd.BaseHash, m, "basehash") quickSetJsonArr(&sd.DiffHashArr, m, "diffhasharr") quickSetBytes(&sd.Data, m, "data") - return &sd + return true } func (sd *StateDiff) ToMap() map[string]interface{} { @@ -659,11 +650,7 @@ func FeStateFromShellState(state *packet.ShellState) *FeStateType { return &FeStateType{Cwd: state.Cwd} } -func RIFromMap(m map[string]interface{}) *RemoteInstance { - if len(m) == 0 { - return nil - } - var ri RemoteInstance +func (ri *RemoteInstance) FromMap(m map[string]interface{}) bool { quickSetStr(&ri.RIId, m, "riid") quickSetStr(&ri.Name, m, "name") quickSetStr(&ri.SessionId, m, "sessionid") @@ -673,7 +660,7 @@ func RIFromMap(m map[string]interface{}) *RemoteInstance { quickSetJson(&ri.FeState, m, "festate") quickSetStr(&ri.StateBaseHash, m, "statebasehash") quickSetJsonArr(&ri.StateDiffHashArr, m, "statediffhasharr") - return &ri + return true } func (ri *RemoteInstance) ToMap() map[string]interface{} { @@ -731,16 +718,12 @@ func (p *PlaybookType) ToMap() map[string]interface{} { return rtn } -func PlaybookFromMap(m map[string]interface{}) *PlaybookType { - if len(m) == 0 { - return nil - } - var p PlaybookType +func (p *PlaybookType) FromMap(m map[string]interface{}) bool { quickSetStr(&p.PlaybookId, m, "playbookid") quickSetStr(&p.PlaybookName, m, "playbookname") quickSetStr(&p.Description, m, "description") quickSetJsonArr(&p.Entries, m, "entries") - return &p + return true } // reorders p.Entries to match p.EntryIds @@ -800,6 +783,10 @@ type BookmarkType struct { Remove bool `json:"remove,omitempty"` } +func (bm *BookmarkType) GetSimpleKey() string { + return bm.BookmarkId +} + func (bm *BookmarkType) ToMap() map[string]interface{} { rtn := make(map[string]interface{}) rtn["bookmarkid"] = bm.BookmarkId @@ -811,18 +798,14 @@ func (bm *BookmarkType) ToMap() map[string]interface{} { return rtn } -func BookmarkFromMap(m map[string]interface{}) *BookmarkType { - if len(m) == 0 { - return nil - } - var bm BookmarkType +func (bm *BookmarkType) FromMap(m map[string]interface{}) bool { quickSetStr(&bm.BookmarkId, m, "bookmarkid") quickSetInt64(&bm.CreatedTs, m, "createdts") quickSetStr(&bm.Alias, m, "alias") quickSetStr(&bm.CmdStr, m, "cmdstr") quickSetStr(&bm.Description, m, "description") quickSetJsonArr(&bm.Tags, m, "tags") - return &bm + return true } type ResolveItem struct { @@ -925,11 +908,7 @@ func (r *RemoteType) ToMap() map[string]interface{} { return rtn } -func RemoteFromMap(m map[string]interface{}) *RemoteType { - if len(m) == 0 { - return nil - } - var r RemoteType +func (r *RemoteType) FromMap(m map[string]interface{}) bool { quickSetStr(&r.RemoteId, m, "remoteid") quickSetStr(&r.PhysicalId, m, "physicalid") quickSetStr(&r.RemoteType, m, "remotetype") @@ -946,7 +925,7 @@ func RemoteFromMap(m map[string]interface{}) *RemoteType { quickSetBool(&r.Archived, m, "archived") quickSetInt64(&r.RemoteIdx, m, "remoteidx") quickSetBool(&r.Local, m, "local") - return &r + return true } func (cmd *CmdType) ToMap() map[string]interface{} { @@ -972,11 +951,7 @@ func (cmd *CmdType) ToMap() map[string]interface{} { return rtn } -func CmdFromMap(m map[string]interface{}) *CmdType { - if len(m) == 0 { - return nil - } - var cmd CmdType +func (cmd *CmdType) FromMap(m map[string]interface{}) bool { quickSetStr(&cmd.SessionId, m, "sessionid") quickSetStr(&cmd.CmdId, m, "cmdid") quickSetStr(&cmd.Remote.OwnerId, m, "remoteownerid") @@ -995,7 +970,7 @@ func CmdFromMap(m map[string]interface{}) *CmdType { quickSetBool(&cmd.RtnState, m, "rtnstate") quickSetStr(&cmd.RtnStatePtr.BaseHash, m, "rtnbasehash") quickSetJsonArr(&cmd.RtnStatePtr.DiffHashArr, m, "rtndiffhasharr") - return &cmd + return true } func makeNewLineCmd(sessionId string, windowId string, userId string, cmdId string, renderer string) *LineType { @@ -1116,7 +1091,7 @@ func EnsureDefaultSession(ctx context.Context) (*SessionType, error) { if session != nil { return session, nil } - _, err = InsertSessionWithName(ctx, DefaultSessionName, true) + _, err = InsertSessionWithName(ctx, DefaultSessionName, ShareModeLocal, true) if err != nil { return nil, err } @@ -1147,32 +1122,29 @@ func createClientData(tx *TxWrap) error { } query := `INSERT INTO client ( clientid, userid, activesessionid, userpublickeybytes, userprivatekeybytes, winsize) VALUES (:clientid,:userid,:activesessionid,:userpublickeybytes,:userprivatekeybytes,:winsize)` - tx.NamedExec(query, c.ToMap()) + tx.NamedExec(query, ToDBMap(c)) log.Printf("create new clientid[%s] userid[%s] with public/private keypair\n", c.ClientId, c.UserId) return nil } func EnsureClientData(ctx context.Context) (*ClientData, error) { - var rtn ClientData - err := WithTx(ctx, func(tx *TxWrap) error { + rtn, err := WithTxRtn(ctx, func(tx *TxWrap) (*ClientData, error) { query := `SELECT count(*) FROM client` count := tx.GetInt(query) if count > 1 { - return fmt.Errorf("invalid client database, multiple (%d) rows in client table", count) + return nil, fmt.Errorf("invalid client database, multiple (%d) rows in client table", count) } if count == 0 { createErr := createClientData(tx) if createErr != nil { - return createErr + return nil, createErr } } - m := tx.GetMap(`SELECT * FROM client`) - cdata := ClientDataFromMap(m) + cdata := GetMappable[*ClientData](tx, `SELECT * FROM client`) if cdata == nil { - return fmt.Errorf("no client data found") + return nil, fmt.Errorf("no client data found") } - rtn = *cdata - return nil + return cdata, nil }) if err != nil { return nil, err @@ -1196,7 +1168,7 @@ func EnsureClientData(ctx context.Context) (*ClientData, error) { if !ok { return nil, fmt.Errorf("invalid client data, wrong public key type: %T", pubKey) } - return &rtn, nil + return rtn, nil } func SetClientOpts(ctx context.Context, clientOpts ClientOptsType) error {