diff --git a/db/migrations-wstore/000001_init.up.sql b/db/migrations-wstore/000001_init.up.sql index 957bc803d..7fa0c3bd8 100644 --- a/db/migrations-wstore/000001_init.up.sql +++ b/db/migrations-wstore/000001_init.up.sql @@ -1,24 +1,29 @@ CREATE TABLE db_client ( - clientid varchar(36) PRIMARY KEY, -- unnecessary, but useful to have a PK + oid varchar(36) PRIMARY KEY, + version int NOT NULL, data json NOT NULL ); CREATE TABLE db_window ( - windowid varchar(36) PRIMARY KEY, + oid varchar(36) PRIMARY KEY, + version int NOT NULL, data json NOT NULL ); CREATE TABLE db_workspace ( - workspaceid varchar(36) PRIMARY KEY, + oid varchar(36) PRIMARY KEY, + version int NOT NULL, data json NOT NULL ); CREATE TABLE db_tab ( - tabid varchar(36) PRIMARY KEY, + oid varchar(36) PRIMARY KEY, + version int NOT NULL, data json NOT NULL ); CREATE TABLE db_block ( - blockid varchar(36) PRIMARY KEY, + oid varchar(36) PRIMARY KEY, + version int NOT NULL, data json NOT NULL ); diff --git a/main.go b/main.go index 22e10f59b..9c2532355 100644 --- a/main.go +++ b/main.go @@ -60,7 +60,7 @@ func createWindow(windowData *wstore.Window, app *application.App) { TitleBar: application.MacTitleBarHiddenInset, }, BackgroundColour: application.NewRGB(0, 0, 0), - URL: "/public/index.html?windowid=" + windowData.WindowId, + URL: "/public/index.html?windowid=" + windowData.OID, X: windowData.Pos.X, Y: windowData.Pos.Y, Width: windowData.WinSize.Width, @@ -146,12 +146,12 @@ func main() { setupCtx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - client, err := wstore.DBGetSingleton[wstore.Client](setupCtx) + client, err := wstore.DBGetSingleton[*wstore.Client](setupCtx) if err != nil { log.Printf("error getting client data: %v\n", err) return } - mainWindow, err := wstore.DBGet[wstore.Window](setupCtx, client.MainWindowId) + mainWindow, err := wstore.DBGet[*wstore.Window](setupCtx, client.MainWindowId) if err != nil { log.Printf("error getting main window: %v\n", err) return diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 182e07ef8..bb8ade230 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -64,7 +64,7 @@ func jsonDeepCopy(val map[string]any) (map[string]any, error) { func CreateBlock(ctx context.Context, bdef *wstore.BlockDef, rtOpts *wstore.RuntimeOpts) (*wstore.Block, error) { blockId := uuid.New().String() blockData := &wstore.Block{ - BlockId: blockId, + OID: blockId, BlockDef: bdef, Controller: bdef.Controller, View: bdef.View, @@ -266,19 +266,19 @@ func ProcessStaticCommand(blockId string, cmdGen BlockCommand) error { return nil case *SetViewCommand: log.Printf("SETVIEW: %s | %q\n", blockId, cmd.View) - block, err := wstore.DBGet[wstore.Block](ctx, blockId) + block, err := wstore.DBGet[*wstore.Block](ctx, blockId) if err != nil { return fmt.Errorf("error getting block: %w", err) } block.View = cmd.View - err = wstore.DBUpdate[wstore.Block](ctx, block) + err = wstore.DBUpdate(ctx, block) if err != nil { return fmt.Errorf("error updating block: %w", err) } return nil case *SetMetaCommand: log.Printf("SETMETA: %s | %v\n", blockId, cmd.Meta) - block, err := wstore.DBGet[wstore.Block](ctx, blockId) + block, err := wstore.DBGet[*wstore.Block](ctx, blockId) if err != nil { return fmt.Errorf("error getting block: %w", err) } diff --git a/pkg/service/blockservice/blockservice.go b/pkg/service/blockservice/blockservice.go index 20e08fc54..615cd7f67 100644 --- a/pkg/service/blockservice/blockservice.go +++ b/pkg/service/blockservice/blockservice.go @@ -40,7 +40,7 @@ func (bs *BlockService) CloseBlock(blockId string) { func (bs *BlockService) GetBlockData(blockId string) (*wstore.Block, error) { ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - blockData, err := wstore.DBGet[wstore.Block](ctx, blockId) + blockData, err := wstore.DBGet[*wstore.Block](ctx, blockId) if err != nil { return nil, fmt.Errorf("error getting block data: %w", err) } diff --git a/pkg/service/clientservice/clientservice.go b/pkg/service/clientservice/clientservice.go index e86c5fe47..5d0a0c6db 100644 --- a/pkg/service/clientservice/clientservice.go +++ b/pkg/service/clientservice/clientservice.go @@ -18,7 +18,7 @@ const DefaultTimeout = 2 * time.Second func (cs *ClientService) GetClientData() (*wstore.Client, error) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() - clientData, err := wstore.DBGetSingleton[wstore.Client](ctx) + clientData, err := wstore.DBGetSingleton[*wstore.Client](ctx) if err != nil { return nil, fmt.Errorf("error getting client data: %w", err) } @@ -28,7 +28,7 @@ func (cs *ClientService) GetClientData() (*wstore.Client, error) { func (cs *ClientService) GetWorkspace(workspaceId string) (*wstore.Workspace, error) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() - ws, err := wstore.DBGet[wstore.Workspace](ctx, workspaceId) + ws, err := wstore.DBGet[*wstore.Workspace](ctx, workspaceId) if err != nil { return nil, fmt.Errorf("error getting workspace: %w", err) } @@ -38,7 +38,7 @@ func (cs *ClientService) GetWorkspace(workspaceId string) (*wstore.Workspace, er func (cs *ClientService) GetTab(tabId string) (*wstore.Tab, error) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() - tab, err := wstore.DBGet[wstore.Tab](ctx, tabId) + tab, err := wstore.DBGet[*wstore.Tab](ctx, tabId) if err != nil { return nil, fmt.Errorf("error getting tab: %w", err) } @@ -48,7 +48,7 @@ func (cs *ClientService) GetTab(tabId string) (*wstore.Tab, error) { func (cs *ClientService) GetWindow(windowId string) (*wstore.Window, error) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() - window, err := wstore.DBGet[wstore.Window](ctx, windowId) + window, err := wstore.DBGet[*wstore.Window](ctx, windowId) if err != nil { return nil, fmt.Errorf("error getting window: %w", err) } diff --git a/pkg/waveobj/waveobj.go b/pkg/waveobj/waveobj.go index 9486000cc..fb773fdf6 100644 --- a/pkg/waveobj/waveobj.go +++ b/pkg/waveobj/waveobj.go @@ -15,78 +15,123 @@ import ( ) const ( - OTypeKeyName = "otype" - OIDKeyName = "oid" + OTypeKeyName = "otype" + OIDKeyName = "oid" + VersionKeyName = "version" + + OIDGoFieldName = "OID" + VersionGoFieldName = "Version" ) -type waveObjDesc struct { - RType reflect.Type - OIDField reflect.StructField -} - -var globalLock = &sync.Mutex{} -var waveObjMap = make(map[string]*waveObjDesc) -var waveObj WaveObj -var waveObjRType = reflect.TypeOf(&waveObj).Elem() - -func RegisterType(w WaveObj) { - globalLock.Lock() - defer globalLock.Unlock() - oidType := w.GetOType() - if waveObjMap[oidType] != nil { - panic(fmt.Sprintf("duplicate WaveObj registration: %T", w)) - } - rtype := reflect.TypeOf(w) - field := findOIDField(rtype) - if field == nil { - panic(fmt.Sprintf("cannot register WaveObj without OID field -- mark with tag `waveobj:\"oid\"`")) - } - waveObjMap[oidType] = &waveObjDesc{ - RType: rtype, - OIDField: *field, - } -} - -func findOIDField(rtype reflect.Type) *reflect.StructField { - for idx := 0; idx < rtype.NumField(); idx++ { - field := rtype.Field(idx) - if field.PkgPath != "" { - // private - continue - } - waveObjTag := field.Tag.Get("waveobj") - if waveObjTag == "oid" { - if field.Type.Kind() != reflect.String { - panic(fmt.Sprintf("in %v marked oid field is not type 'string'", rtype)) - } - return &field - } - } - return nil -} - -func getObjDescForOIDType(oidType string) *waveObjDesc { - globalLock.Lock() - defer globalLock.Unlock() - return waveObjMap[oidType] -} - type WaveObj interface { - GetOType() string + GetOType() string // should not depend on object state (should work with nil value) +} + +type waveObjDesc struct { + RType reflect.Type + OIDField reflect.StructField + VersionField reflect.StructField +} + +var waveObjMap = sync.Map{} +var waveObjRType = reflect.TypeOf((*WaveObj)(nil)).Elem() + +func RegisterType[T WaveObj]() { + var waveObj T + otype := waveObj.GetOType() + if otype == "" { + panic(fmt.Sprintf("otype is empty for %T", waveObj)) + } + rtype := reflect.TypeOf(waveObj) + if rtype.Kind() != reflect.Ptr { + panic(fmt.Sprintf("wave object must be a pointer for %T", waveObj)) + } + oidField, found := rtype.Elem().FieldByName(OIDGoFieldName) + if !found { + panic(fmt.Sprintf("missing OID field for %T", waveObj)) + } + if oidField.Type.Kind() != reflect.String { + panic(fmt.Sprintf("OID field must be string for %T", waveObj)) + } + if oidField.Tag.Get("json") != OIDKeyName { + panic(fmt.Sprintf("OID field json tag must be %q for %T", OIDKeyName, waveObj)) + } + versionField, found := rtype.Elem().FieldByName(VersionGoFieldName) + if !found { + panic(fmt.Sprintf("missing Version field for %T", waveObj)) + } + if versionField.Type.Kind() != reflect.Int { + panic(fmt.Sprintf("Version field must be int for %T", waveObj)) + } + if versionField.Tag.Get("json") != VersionKeyName { + panic(fmt.Sprintf("Version field json tag must be %q for %T", VersionKeyName, waveObj)) + } + _, found = waveObjMap.Load(otype) + if found { + panic(fmt.Sprintf("otype %q already registered", otype)) + } + waveObjMap.Store(otype, &waveObjDesc{ + RType: rtype, + OIDField: oidField, + VersionField: versionField, + }) +} + +func getWaveObjDesc(otype string) *waveObjDesc { + desc, _ := waveObjMap.Load(otype) + if desc == nil { + return nil + } + return desc.(*waveObjDesc) +} + +func GetOID(waveObj WaveObj) string { + desc := getWaveObjDesc(waveObj.GetOType()) + if desc == nil { + return "" + } + return reflect.ValueOf(waveObj).Elem().FieldByIndex(desc.OIDField.Index).String() +} + +func SetOID(waveObj WaveObj, oid string) { + desc := getWaveObjDesc(waveObj.GetOType()) + if desc == nil { + return + } + reflect.ValueOf(waveObj).Elem().FieldByIndex(desc.OIDField.Index).SetString(oid) +} + +func GetVersion(waveObj WaveObj) int { + desc := getWaveObjDesc(waveObj.GetOType()) + if desc == nil { + return 0 + } + return int(reflect.ValueOf(waveObj).Elem().FieldByIndex(desc.VersionField.Index).Int()) +} + +func SetVersion(waveObj WaveObj, version int) { + desc := getWaveObjDesc(waveObj.GetOType()) + if desc == nil { + return + } + reflect.ValueOf(waveObj).Elem().FieldByIndex(desc.VersionField.Index).SetInt(int64(version)) } func ToJson(w WaveObj) ([]byte, error) { m := make(map[string]any) - err := mapstructure.Decode(w, &m) + dconfig := &mapstructure.DecoderConfig{ + Result: &m, + TagName: "json", + } + decoder, err := mapstructure.NewDecoder(dconfig) if err != nil { return nil, err } - desc := getObjDescForOIDType(w.GetOType()) - if desc == nil { - return nil, fmt.Errorf("otype %q (%T) not registered", w.GetOType(), w) + err = decoder.Decode(w) + if err != nil { + return nil, err } m[OTypeKeyName] = w.GetOType() - m[OIDKeyName] = reflect.ValueOf(w).FieldByIndex(desc.OIDField.Index).String() return json.Marshal(m) } @@ -100,23 +145,24 @@ func FromJson(data []byte) (WaveObj, error) { if !ok { return nil, fmt.Errorf("missing otype") } - oid, ok := m[OIDKeyName].(string) - if !ok { - return nil, fmt.Errorf("missing oid") - } - desc := getObjDescForOIDType(otype) + desc := getWaveObjDesc(otype) if desc == nil { - return nil, fmt.Errorf("unknown oid type: %s", otype) + return nil, fmt.Errorf("unknown otype: %s", otype) } - objVal := reflect.New(desc.RType) - oidField := objVal.FieldByIndex(desc.OIDField.Index) - oidField.SetString(oid) - obj := objVal.Interface().(WaveObj) - err = mapstructure.Decode(m, obj) + wobj := reflect.Zero(desc.RType).Interface().(WaveObj) + dconfig := &mapstructure.DecoderConfig{ + Result: &wobj, + TagName: "json", + } + decoder, err := mapstructure.NewDecoder(dconfig) if err != nil { return nil, err } - return obj, nil + err = decoder.Decode(m) + if err != nil { + return nil, err + } + return wobj, nil } func FromJsonGen[T WaveObj](data []byte) (T, error) { @@ -204,9 +250,12 @@ func generateTSTypeInternal(rtype reflect.Type) (string, []reflect.Type) { var buf bytes.Buffer waveObjType := reflect.TypeOf((*WaveObj)(nil)).Elem() buf.WriteString(fmt.Sprintf("type %s = {\n", rtype.Name())) + var isWaveObj bool if rtype.Implements(waveObjType) || reflect.PointerTo(rtype).Implements(waveObjType) { + isWaveObj = true buf.WriteString(fmt.Sprintf(" %s: string;\n", OTypeKeyName)) buf.WriteString(fmt.Sprintf(" %s: string;\n", OIDKeyName)) + buf.WriteString(fmt.Sprintf(" %s: number;\n", VersionKeyName)) } var subTypes []reflect.Type for i := 0; i < rtype.NumField(); i++ { @@ -218,6 +267,9 @@ func generateTSTypeInternal(rtype reflect.Type) (string, []reflect.Type) { if fieldName == "" { continue } + if isWaveObj && (fieldName == OTypeKeyName || fieldName == OIDKeyName || fieldName == VersionKeyName) { + continue + } optMarker := "" if isFieldOmitEmpty(field) { optMarker = "?" diff --git a/pkg/waveobj/waveobj_test.go b/pkg/waveobj/waveobj_test.go index e3fd4a7b6..d7db8059e 100644 --- a/pkg/waveobj/waveobj_test.go +++ b/pkg/waveobj/waveobj_test.go @@ -7,16 +7,23 @@ import ( "log" "reflect" "testing" - - "github.com/wavetermdev/thenextwave/pkg/wstore" ) +type TestBlock struct { + BlockId string `json:"blockid" waveobj:"oid"` + Name string `json:"name"` +} + +func (TestBlock) GetOType() string { + return "block" +} + func TestGenerate(t *testing.T) { log.Printf("Testing Generate\n") tsMap := make(map[reflect.Type]string) var waveObj WaveObj GenerateTSType(reflect.TypeOf(&waveObj).Elem(), tsMap) - GenerateTSType(reflect.TypeOf(wstore.Block{}), tsMap) + GenerateTSType(reflect.TypeOf(TestBlock{}), tsMap) for k, v := range tsMap { log.Printf("Type: %v, TS:\n%s\n", k, v) } diff --git a/pkg/wstore/wstore.go b/pkg/wstore/wstore.go index 142104699..d71e92677 100644 --- a/pkg/wstore/wstore.go +++ b/pkg/wstore/wstore.go @@ -6,31 +6,41 @@ package wstore import ( "context" "fmt" - "sync" "time" "github.com/google/uuid" "github.com/wavetermdev/thenextwave/pkg/shellexec" "github.com/wavetermdev/thenextwave/pkg/util/ds" + "github.com/wavetermdev/thenextwave/pkg/waveobj" ) var WorkspaceMap = ds.NewSyncMap[*Workspace]() var TabMap = ds.NewSyncMap[*Tab]() var BlockMap = ds.NewSyncMap[*Block]() +func init() { + waveobj.RegisterType[*Client]() + waveobj.RegisterType[*Window]() + waveobj.RegisterType[*Workspace]() + waveobj.RegisterType[*Tab]() + waveobj.RegisterType[*Block]() +} + type Client struct { - ClientId string `json:"clientid"` + OID string `json:"oid"` + Version int `json:"version"` MainWindowId string `json:"mainwindowid"` } -func (c Client) GetId() string { - return c.ClientId +func (*Client) GetOType() string { + return "client" } // stores the ui-context of the window // workspaceid, active tab, active block within each tab, window size, etc. type Window struct { - WindowId string `json:"windowid"` + OID string `json:"oid"` + Version int `json:"version"` WorkspaceId string `json:"workspaceid"` ActiveTabId string `json:"activetabid"` ActiveBlockMap map[string]string `json:"activeblockmap"` // map from tabid to blockid @@ -39,42 +49,30 @@ type Window struct { LastFocusTs int64 `json:"lastfocusts"` } -func (w Window) GetId() string { - return w.WindowId +func (*Window) GetOType() string { + return "window" } type Workspace struct { - Lock *sync.Mutex `json:"-"` - WorkspaceId string `json:"workspaceid"` - Name string `json:"name"` - TabIds []string `json:"tabids"` + OID string `json:"oid"` + Version int `json:"version"` + Name string `json:"name"` + TabIds []string `json:"tabids"` } -func (ws Workspace) GetId() string { - return ws.WorkspaceId -} - -func (ws *Workspace) WithLock(f func()) { - ws.Lock.Lock() - defer ws.Lock.Unlock() - f() +func (*Workspace) GetOType() string { + return "workspace" } type Tab struct { - Lock *sync.Mutex `json:"-"` - TabId string `json:"tabid"` - Name string `json:"name"` - BlockIds []string `json:"blockids"` + OID string `json:"oid"` + Version int `json:"version"` + Name string `json:"name"` + BlockIds []string `json:"blockids"` } -func (tab Tab) GetId() string { - return tab.TabId -} - -func (tab *Tab) WithLock(f func()) { - tab.Lock.Lock() - defer tab.Lock.Unlock() - f() +func (*Tab) GetOType() string { + return "tab" } type FileDef struct { @@ -108,7 +106,8 @@ type WinSize struct { } type Block struct { - BlockId string `json:"blockid"` + OID string `json:"oid"` + Version int `json:"version"` BlockDef *BlockDef `json:"blockdef"` Controller string `json:"controller"` View string `json:"view"` @@ -116,45 +115,32 @@ type Block struct { RuntimeOpts *RuntimeOpts `json:"runtimeopts,omitempty"` } -func (b *Block) GetOType() string { +func (*Block) GetOType() string { return "block" } -func (b Block) GetId() string { - return b.BlockId -} - -// TODO remove -func (b *Block) WithLock(f func()) { - f() -} - func CreateTab(workspaceId string, name string) (*Tab, error) { tab := &Tab{ - Lock: &sync.Mutex{}, - TabId: uuid.New().String(), + OID: uuid.New().String(), Name: name, BlockIds: []string{}, } - TabMap.Set(tab.TabId, tab) + TabMap.Set(tab.OID, tab) ws := WorkspaceMap.Get(workspaceId) if ws == nil { return nil, fmt.Errorf("workspace not found: %q", workspaceId) } - ws.WithLock(func() { - ws.TabIds = append(ws.TabIds, tab.TabId) - }) + ws.TabIds = append(ws.TabIds, tab.OID) return tab, nil } func CreateWorkspace() (*Workspace, error) { ws := &Workspace{ - Lock: &sync.Mutex{}, - WorkspaceId: uuid.New().String(), - TabIds: []string{}, + OID: uuid.New().String(), + TabIds: []string{}, } - WorkspaceMap.Set(ws.WorkspaceId, ws) - _, err := CreateTab(ws.WorkspaceId, "Tab 1") + WorkspaceMap.Set(ws.OID, ws) + _, err := CreateTab(ws.OID, "Tab 1") if err != nil { return nil, err } @@ -164,7 +150,7 @@ func CreateWorkspace() (*Workspace, error) { func EnsureInitialData() error { ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - clientCount, err := DBGetCount[Client](ctx) + clientCount, err := DBGetCount[*Client](ctx) if err != nil { return fmt.Errorf("error getting client count: %w", err) } @@ -175,7 +161,7 @@ func EnsureInitialData() error { workspaceId := uuid.New().String() tabId := uuid.New().String() client := &Client{ - ClientId: uuid.New().String(), + OID: uuid.New().String(), MainWindowId: windowId, } err = DBInsert(ctx, client) @@ -183,7 +169,7 @@ func EnsureInitialData() error { return fmt.Errorf("error inserting client: %w", err) } window := &Window{ - WindowId: windowId, + OID: windowId, WorkspaceId: workspaceId, ActiveTabId: tabId, ActiveBlockMap: make(map[string]string), @@ -201,16 +187,16 @@ func EnsureInitialData() error { return fmt.Errorf("error inserting window: %w", err) } ws := &Workspace{ - WorkspaceId: workspaceId, - Name: "default", - TabIds: []string{tabId}, + OID: workspaceId, + Name: "default", + TabIds: []string{tabId}, } err = DBInsert(ctx, ws) if err != nil { return fmt.Errorf("error inserting workspace: %w", err) } tab := &Tab{ - TabId: uuid.New().String(), + OID: uuid.New().String(), Name: "Tab 1", BlockIds: []string{}, } diff --git a/pkg/wstore/wstore_dbops.go b/pkg/wstore/wstore_dbops.go index 003e319a5..3ae151239 100644 --- a/pkg/wstore/wstore_dbops.go +++ b/pkg/wstore/wstore_dbops.go @@ -6,155 +6,124 @@ package wstore import ( "context" "fmt" - "reflect" + + "github.com/wavetermdev/thenextwave/pkg/waveobj" ) -const Table_Client = "db_client" -const Table_Workspace = "db_workspace" -const Table_Tab = "db_tab" -const Table_Block = "db_block" -const Table_Window = "db_window" - -// can replace with struct tags in the future -type ObjectWithId interface { - GetId() string +func waveObjTableName(w waveobj.WaveObj) string { + return "db_" + w.GetOType() } -// can replace these with struct tags in the future -var idColumnName = map[string]string{ - Table_Client: "clientid", - Table_Workspace: "workspaceid", - Table_Tab: "tabid", - Table_Block: "blockid", - Table_Window: "windowid", +func tableNameGen[T waveobj.WaveObj]() string { + var zeroObj T + return "db_" + zeroObj.GetOType() } -var tableToType = map[string]reflect.Type{ - Table_Client: reflect.TypeOf(Client{}), - Table_Workspace: reflect.TypeOf(Workspace{}), - Table_Tab: reflect.TypeOf(Tab{}), - Table_Block: reflect.TypeOf(Block{}), - Table_Window: reflect.TypeOf(Window{}), -} - -var typeToTable map[reflect.Type]string - -func init() { - typeToTable = make(map[reflect.Type]string) - for k, v := range tableToType { - typeToTable[v] = k - } -} - -func DBGetCount[T ObjectWithId](ctx context.Context) (int, error) { +func DBGetCount[T waveobj.WaveObj](ctx context.Context) (int, error) { return WithTxRtn(ctx, func(tx *TxWrap) (int, error) { - var valInstance T - table := typeToTable[reflect.TypeOf(valInstance)] - if table == "" { - return 0, fmt.Errorf("unknown table type: %T", valInstance) - } + table := tableNameGen[T]() query := fmt.Sprintf("SELECT count(*) FROM %s", table) return tx.GetInt(query), nil }) } -func DBGetSingleton[T ObjectWithId](ctx context.Context) (*T, error) { - return WithTxRtn(ctx, func(tx *TxWrap) (*T, error) { - var rtn T - query := fmt.Sprintf("SELECT data FROM %s LIMIT 1", typeToTable[reflect.TypeOf(rtn)]) - jsonData := tx.GetString(query) - return TxReadJson[T](tx, jsonData), nil - }) -} - -func DBGet[T ObjectWithId](ctx context.Context, id string) (*T, error) { - return WithTxRtn(ctx, func(tx *TxWrap) (*T, error) { - var rtn T - table := typeToTable[reflect.TypeOf(rtn)] - if table == "" { - return nil, fmt.Errorf("unknown table type: %T", rtn) - } - query := fmt.Sprintf("SELECT data FROM %s WHERE %s = ?", table, idColumnName[table]) - jsonData := tx.GetString(query, id) - return TxReadJson[T](tx, jsonData), nil - }) -} - type idDataType struct { - Id string - Data string + OId string + Version int + Data []byte } -func DBSelectMap[T ObjectWithId](ctx context.Context, ids []string) (map[string]*T, error) { - return WithTxRtn(ctx, func(tx *TxWrap) (map[string]*T, error) { - var valInstance T - table := typeToTable[reflect.TypeOf(valInstance)] - if table == "" { - return nil, fmt.Errorf("unknown table type: %T", &valInstance) +func DBGetSingleton[T waveobj.WaveObj](ctx context.Context) (T, error) { + return WithTxRtn(ctx, func(tx *TxWrap) (T, error) { + table := tableNameGen[T]() + query := fmt.Sprintf("SELECT oid, version, data FROM %s LIMIT 1", table) + var row idDataType + tx.Get(&row, query) + rtn, err := waveobj.FromJsonGen[T](row.Data) + if err != nil { + return rtn, err } + waveobj.SetVersion(rtn, row.Version) + return rtn, nil + }) +} + +func DBGet[T waveobj.WaveObj](ctx context.Context, id string) (T, error) { + return WithTxRtn(ctx, func(tx *TxWrap) (T, error) { + table := tableNameGen[T]() + query := fmt.Sprintf("SELECT oid, version, data FROM %s WHERE oid = ?", table) + var row idDataType + tx.Get(&row, query, id) + rtn, err := waveobj.FromJsonGen[T](row.Data) + if err != nil { + return rtn, err + } + waveobj.SetVersion(rtn, row.Version) + return rtn, nil + }) +} + +func DBSelectMap[T waveobj.WaveObj](ctx context.Context, ids []string) (map[string]T, error) { + return WithTxRtn(ctx, func(tx *TxWrap) (map[string]T, error) { + table := tableNameGen[T]() var rows []idDataType - query := fmt.Sprintf("SELECT %s, data FROM %s WHERE %s IN (SELECT value FROM json_each(?))", idColumnName[table], table, idColumnName[table]) + query := fmt.Sprintf("SELECT oid, version, data FROM %s WHERE oid IN (SELECT value FROM json_each(?))", table) tx.Select(&rows, query, ids) - rtnMap := make(map[string]*T) + rtnMap := make(map[string]T) for _, row := range rows { - if row.Id == "" || row.Data == "" { + if row.OId == "" || len(row.Data) == 0 { continue } - r := TxReadJson[T](tx, row.Data) - if r == nil { - continue + waveObj, err := waveobj.FromJsonGen[T](row.Data) + if err != nil { + return nil, err } - rtnMap[(*r).GetId()] = r + waveobj.SetVersion(waveObj, row.Version) + rtnMap[row.OId] = waveObj } return rtnMap, nil }) } -func DBDelete[T ObjectWithId](ctx context.Context, id string) error { +func DBDelete[T waveobj.WaveObj](ctx context.Context, id string) error { return WithTx(ctx, func(tx *TxWrap) error { - var rtn T - table := typeToTable[reflect.TypeOf(rtn)] - if table == "" { - return fmt.Errorf("unknown table type: %T", rtn) - } - query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, idColumnName[table]) + table := tableNameGen[T]() + query := fmt.Sprintf("DELETE FROM %s WHERE oid = ?", table) tx.Exec(query, id) return nil }) } -func DBUpdate[T ObjectWithId](ctx context.Context, val *T) error { - if val == nil { - return fmt.Errorf("cannot update nil value") - } - if (*val).GetId() == "" { +func DBUpdate(ctx context.Context, val waveobj.WaveObj) error { + oid := waveobj.GetOID(val) + if oid == "" { return fmt.Errorf("cannot update %T value with empty id", val) } + jsonData, err := waveobj.ToJson(val) + if err != nil { + return err + } return WithTx(ctx, func(tx *TxWrap) error { - table := typeToTable[reflect.TypeOf(*val)] - if table == "" { - return fmt.Errorf("unknown table type: %T", *val) - } - query := fmt.Sprintf("UPDATE %s SET data = ? WHERE %s = ?", table, idColumnName[table]) - tx.Exec(query, TxJson(tx, val), (*val).GetId()) + table := waveObjTableName(val) + query := fmt.Sprintf("UPDATE %s SET data = ?, version = version+1 WHERE oid = ?", table) + tx.Exec(query, jsonData, oid) return nil }) } -func DBInsert[T ObjectWithId](ctx context.Context, val *T) error { - if val == nil { - return fmt.Errorf("cannot insert nil value") - } - if (*val).GetId() == "" { +func DBInsert[T waveobj.WaveObj](ctx context.Context, val T) error { + oid := waveobj.GetOID(val) + if oid == "" { return fmt.Errorf("cannot insert %T value with empty id", val) } + jsonData, err := waveobj.ToJson(val) + if err != nil { + return err + } return WithTx(ctx, func(tx *TxWrap) error { - table := typeToTable[reflect.TypeOf(*val)] - if table == "" { - return fmt.Errorf("unknown table type: %T", *val) - } - query := fmt.Sprintf("INSERT INTO %s (%s, data) VALUES (?, ?)", table, idColumnName[table]) - tx.Exec(query, (*val).GetId(), TxJson(tx, val)) + table := waveObjTableName(val) + query := fmt.Sprintf("INSERT INTO %s (oid, version, data) VALUES (?, ?, ?)", table) + tx.Exec(query, oid, 1, jsonData) return nil }) } diff --git a/pkg/wstore/wstore_dbsetup.go b/pkg/wstore/wstore_dbsetup.go index 40ae85300..d11414909 100644 --- a/pkg/wstore/wstore_dbsetup.go +++ b/pkg/wstore/wstore_dbsetup.go @@ -5,7 +5,6 @@ package wstore import ( "context" - "encoding/json" "fmt" "log" "path" @@ -63,24 +62,3 @@ func WithTx(ctx context.Context, fn func(tx *TxWrap) error) error { func WithTxRtn[RT any](ctx context.Context, fn func(tx *TxWrap) (RT, error)) (RT, error) { return txwrap.WithTxRtn(ctx, globalDB, fn) } - -func TxJson(tx *TxWrap, v any) string { - barr, err := json.Marshal(v) - if err != nil { - tx.SetErr(fmt.Errorf("json marshal (%T): %w", v, err)) - return "" - } - return string(barr) -} - -func TxReadJson[T any](tx *TxWrap, jsonData string) *T { - if jsonData == "" { - return nil - } - var v T - err := json.Unmarshal([]byte(jsonData), &v) - if err != nil { - tx.SetErr(fmt.Errorf("json unmarshal (%T): %w", v, err)) - } - return &v -}