mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-02 18:39:05 +01:00
01b5d71709
Lots of changes: new wshrpc implementation. Unify websocket, web, blockcontroller, domain sockets, and terminal inputs to all use the new rpc system. Lots of moving files around to deal with circular dependencies. Use new wshrpc as a client in wsh cmd.
270 lines
7.6 KiB
Go
270 lines
7.6 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package wstore
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
|
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
|
|
)
|
|
|
|
// ErrNotFound is the sentinel error returned by lookups in this file
// (DBGetSingletonByType, DBMustGet, DBResolveEasyOID) when no matching row exists.
var ErrNotFound = fmt.Errorf("not found")
|
|
|
|
func waveObjTableName(w waveobj.WaveObj) string {
|
|
return "db_" + w.GetOType()
|
|
}
|
|
|
|
// tableNameFromOType returns the table name for the given object type string,
// formed by prepending "db_" to otype.
func tableNameFromOType(otype string) string {
	const tablePrefix = "db_"
	return tablePrefix + otype
}
|
|
|
|
func tableNameGen[T waveobj.WaveObj]() string {
|
|
var zeroObj T
|
|
return tableNameFromOType(zeroObj.GetOType())
|
|
}
|
|
|
|
func getOTypeGen[T waveobj.WaveObj]() string {
|
|
var zeroObj T
|
|
return zeroObj.GetOType()
|
|
}
|
|
|
|
func DBGetCount[T waveobj.WaveObj](ctx context.Context) (int, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) (int, error) {
|
|
table := tableNameGen[T]()
|
|
query := fmt.Sprintf("SELECT count(*) FROM %s", table)
|
|
return tx.GetInt(query), nil
|
|
})
|
|
}
|
|
|
|
// idDataType is the scan target for the "oid, version, data" column triple
// selected by the object queries in this file.
type idDataType struct {
	// OId receives the oid column.
	OId string
	// Version receives the version column.
	Version int
	// Data receives the data column; callers in this file pass it to waveobj.FromJson.
	Data []byte
}
|
|
|
|
// genericCastWithErr converts an untyped (value, error) pair into a typed
// (T, error) pair.
//
//   - If err is non-nil, the zero value of T is returned together with err.
//   - If v is nil, the zero value of T is returned with a nil error.
//   - Otherwise v is asserted to T. When v holds a different dynamic type an
//     error is returned instead of panicking, so a mismatched row cannot crash
//     the caller.
func genericCastWithErr[T any](v any, err error) (T, error) {
	var zeroVal T
	if err != nil {
		return zeroVal, err
	}
	if v == nil {
		return zeroVal, nil
	}
	// Checked assertion: the unchecked form v.(T) panics on a type mismatch.
	typed, ok := v.(T)
	if !ok {
		return zeroVal, fmt.Errorf("unexpected value type %T (expected %T)", v, zeroVal)
	}
	return typed, nil
}
|
|
|
|
func DBGetSingleton[T waveobj.WaveObj](ctx context.Context) (T, error) {
|
|
rtn, err := DBGetSingletonByType(ctx, getOTypeGen[T]())
|
|
return genericCastWithErr[T](rtn, err)
|
|
}
|
|
|
|
func DBGetSingletonByType(ctx context.Context, otype string) (waveobj.WaveObj, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) (waveobj.WaveObj, error) {
|
|
table := tableNameFromOType(otype)
|
|
query := fmt.Sprintf("SELECT oid, version, data FROM %s LIMIT 1", table)
|
|
var row idDataType
|
|
found := tx.Get(&row, query)
|
|
if !found {
|
|
return nil, ErrNotFound
|
|
}
|
|
rtn, err := waveobj.FromJson(row.Data)
|
|
if err != nil {
|
|
return rtn, err
|
|
}
|
|
waveobj.SetVersion(rtn, row.Version)
|
|
return rtn, nil
|
|
})
|
|
}
|
|
|
|
func DBExistsORef(ctx context.Context, oref waveobj.ORef) (bool, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) (bool, error) {
|
|
table := tableNameFromOType(oref.OType)
|
|
query := fmt.Sprintf("SELECT oid FROM %s WHERE oid = ?", table)
|
|
return tx.Exists(query, oref.OID), nil
|
|
})
|
|
}
|
|
|
|
func DBGet[T waveobj.WaveObj](ctx context.Context, id string) (T, error) {
|
|
rtn, err := DBGetORef(ctx, waveobj.ORef{OType: getOTypeGen[T](), OID: id})
|
|
return genericCastWithErr[T](rtn, err)
|
|
}
|
|
|
|
func DBMustGet[T waveobj.WaveObj](ctx context.Context, id string) (T, error) {
|
|
rtn, err := DBGetORef(ctx, waveobj.ORef{OType: getOTypeGen[T](), OID: id})
|
|
if err != nil {
|
|
var zeroVal T
|
|
return zeroVal, err
|
|
}
|
|
if rtn == nil {
|
|
var zeroVal T
|
|
return zeroVal, ErrNotFound
|
|
}
|
|
return rtn.(T), nil
|
|
}
|
|
|
|
func DBGetORef(ctx context.Context, oref waveobj.ORef) (waveobj.WaveObj, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) (waveobj.WaveObj, error) {
|
|
table := tableNameFromOType(oref.OType)
|
|
query := fmt.Sprintf("SELECT oid, version, data FROM %s WHERE oid = ?", table)
|
|
var row idDataType
|
|
found := tx.Get(&row, query, oref.OID)
|
|
if !found {
|
|
return nil, nil
|
|
}
|
|
rtn, err := waveobj.FromJson(row.Data)
|
|
if err != nil {
|
|
return rtn, err
|
|
}
|
|
waveobj.SetVersion(rtn, row.Version)
|
|
return rtn, nil
|
|
})
|
|
}
|
|
|
|
func dbSelectOIDs(ctx context.Context, otype string, oids []string) ([]waveobj.WaveObj, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) ([]waveobj.WaveObj, error) {
|
|
table := tableNameFromOType(otype)
|
|
query := fmt.Sprintf("SELECT oid, version, data FROM %s WHERE oid IN (SELECT value FROM json_each(?))", table)
|
|
var rows []idDataType
|
|
tx.Select(&rows, query, dbutil.QuickJson(oids))
|
|
rtn := make([]waveobj.WaveObj, 0, len(rows))
|
|
for _, row := range rows {
|
|
waveObj, err := waveobj.FromJson(row.Data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
waveobj.SetVersion(waveObj, row.Version)
|
|
rtn = append(rtn, waveObj)
|
|
}
|
|
return rtn, nil
|
|
})
|
|
}
|
|
|
|
func DBSelectORefs(ctx context.Context, orefs []waveobj.ORef) ([]waveobj.WaveObj, error) {
|
|
oidsByType := make(map[string][]string)
|
|
for _, oref := range orefs {
|
|
oidsByType[oref.OType] = append(oidsByType[oref.OType], oref.OID)
|
|
}
|
|
return WithTxRtn(ctx, func(tx *TxWrap) ([]waveobj.WaveObj, error) {
|
|
rtn := make([]waveobj.WaveObj, 0, len(orefs))
|
|
for otype, oids := range oidsByType {
|
|
rtnArr, err := dbSelectOIDs(tx.Context(), otype, oids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rtn = append(rtn, rtnArr...)
|
|
}
|
|
return rtn, nil
|
|
})
|
|
}
|
|
|
|
func DBResolveEasyOID(ctx context.Context, oid string) (*waveobj.ORef, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) (*waveobj.ORef, error) {
|
|
for _, rtype := range AllWaveObjTypes() {
|
|
otype := reflect.Zero(rtype).Interface().(waveobj.WaveObj).GetOType()
|
|
table := tableNameFromOType(otype)
|
|
var fullOID string
|
|
if len(oid) == 8 {
|
|
query := fmt.Sprintf("SELECT oid FROM %s WHERE oid LIKE ?", table)
|
|
fullOID = tx.GetString(query, oid+"%")
|
|
} else {
|
|
query := fmt.Sprintf("SELECT oid FROM %s WHERE oid = ?", table)
|
|
fullOID = tx.GetString(query, oid)
|
|
}
|
|
if fullOID != "" {
|
|
oref := waveobj.MakeORef(otype, fullOID)
|
|
return &oref, nil
|
|
}
|
|
}
|
|
return nil, ErrNotFound
|
|
})
|
|
}
|
|
|
|
func DBSelectMap[T waveobj.WaveObj](ctx context.Context, ids []string) (map[string]T, error) {
|
|
rtnArr, err := dbSelectOIDs(ctx, getOTypeGen[T](), ids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rtnMap := make(map[string]T)
|
|
for _, obj := range rtnArr {
|
|
rtnMap[waveobj.GetOID(obj)] = obj.(T)
|
|
}
|
|
return rtnMap, nil
|
|
}
|
|
|
|
func DBDelete(ctx context.Context, otype string, id string) error {
|
|
err := WithTx(ctx, func(tx *TxWrap) error {
|
|
table := tableNameFromOType(otype)
|
|
query := fmt.Sprintf("DELETE FROM %s WHERE oid = ?", table)
|
|
tx.Exec(query, id)
|
|
ContextAddUpdate(ctx, WaveObjUpdate{UpdateType: UpdateType_Delete, OType: otype, OID: id})
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go func() {
|
|
// we spawn a go routine here because we don't want to reuse the DB connection
|
|
// since DBDelete is called in a transaction from DeleteTab
|
|
deleteCtx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancelFn()
|
|
err := filestore.WFS.DeleteZone(deleteCtx, id)
|
|
if err != nil {
|
|
log.Printf("error deleting filestore zone (after deleting block): %v", err)
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func DBUpdate(ctx context.Context, val waveobj.WaveObj) error {
|
|
oid := waveobj.GetOID(val)
|
|
if oid == "" {
|
|
return fmt.Errorf("cannot update %T value with empty id", val)
|
|
}
|
|
jsonData, err := waveobj.ToJson(val)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return WithTx(ctx, func(tx *TxWrap) error {
|
|
table := waveObjTableName(val)
|
|
query := fmt.Sprintf("UPDATE %s SET data = ?, version = version+1 WHERE oid = ? RETURNING version", table)
|
|
newVersion := tx.GetInt(query, jsonData, oid)
|
|
waveobj.SetVersion(val, newVersion)
|
|
ContextAddUpdate(ctx, WaveObjUpdate{UpdateType: UpdateType_Update, OType: val.GetOType(), OID: oid, Obj: val})
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func DBInsert(ctx context.Context, val waveobj.WaveObj) error {
|
|
oid := waveobj.GetOID(val)
|
|
if oid == "" {
|
|
return fmt.Errorf("cannot insert %T value with empty id", val)
|
|
}
|
|
jsonData, err := waveobj.ToJson(val)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return WithTx(ctx, func(tx *TxWrap) error {
|
|
table := waveObjTableName(val)
|
|
waveobj.SetVersion(val, 1)
|
|
query := fmt.Sprintf("INSERT INTO %s (oid, version, data) VALUES (?, ?, ?)", table)
|
|
tx.Exec(query, oid, 1, jsonData)
|
|
ContextAddUpdate(ctx, WaveObjUpdate{UpdateType: UpdateType_Update, OType: val.GetOType(), OID: oid, Obj: val})
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func DBFindWindowForTabId(ctx context.Context, tabId string) (string, error) {
|
|
return WithTxRtn(ctx, func(tx *TxWrap) (string, error) {
|
|
query := "SELECT oid FROM db_window WHERE data->>'activetabid' = ?"
|
|
return tx.GetString(query, tabId), nil
|
|
})
|
|
}
|