waveterm/pkg/waveobj/ctxupdate.go
2024-08-29 18:23:12 -07:00

154 lines
3.9 KiB
Go

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package waveobj
import (
"bytes"
"context"
"fmt"
"log"
)
var waveObjUpdateKey = struct{}{}
type contextUpdatesType struct {
UpdatesStack []map[ORef]WaveObjUpdate
}
func dumpUpdateStack(updates *contextUpdatesType) {
log.Printf("dumpUpdateStack len:%d\n", len(updates.UpdatesStack))
for idx, update := range updates.UpdatesStack {
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf(" [%d]:", idx))
for k := range update {
buf.WriteString(fmt.Sprintf(" %s:%s", k.OType, k.OID))
}
buf.WriteString("\n")
log.Print(buf.String())
}
}
func ContextWithUpdates(ctx context.Context) context.Context {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal != nil {
return ctx
}
return context.WithValue(ctx, waveObjUpdateKey, &contextUpdatesType{
UpdatesStack: []map[ORef]WaveObjUpdate{make(map[ORef]WaveObjUpdate)},
})
}
func ContextGetUpdates(ctx context.Context) map[ORef]WaveObjUpdate {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
return nil
}
updates := updatesVal.(*contextUpdatesType)
if len(updates.UpdatesStack) == 1 {
return updates.UpdatesStack[0]
}
rtn := make(map[ORef]WaveObjUpdate)
for _, update := range updates.UpdatesStack {
for k, v := range update {
rtn[k] = v
}
}
return rtn
}
func ContextGetUpdate(ctx context.Context, oref ORef) *WaveObjUpdate {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
return nil
}
updates := updatesVal.(*contextUpdatesType)
for idx := len(updates.UpdatesStack) - 1; idx >= 0; idx-- {
if obj, ok := updates.UpdatesStack[idx][oref]; ok {
return &obj
}
}
return nil
}
func ContextAddUpdate(ctx context.Context, update WaveObjUpdate) {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
return
}
updates := updatesVal.(*contextUpdatesType)
oref := ORef{
OType: update.OType,
OID: update.OID,
}
updates.UpdatesStack[len(updates.UpdatesStack)-1][oref] = update
}
func ContextUpdatesBeginTx(ctx context.Context) context.Context {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
return ctx
}
updates := updatesVal.(*contextUpdatesType)
updates.UpdatesStack = append(updates.UpdatesStack, make(map[ORef]WaveObjUpdate))
return ctx
}
func ContextUpdatesCommitTx(ctx context.Context) {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
return
}
updates := updatesVal.(*contextUpdatesType)
if len(updates.UpdatesStack) <= 1 {
panic(fmt.Errorf("no updates transaction to commit"))
}
// merge the last two updates
curUpdateMap := updates.UpdatesStack[len(updates.UpdatesStack)-1]
prevUpdateMap := updates.UpdatesStack[len(updates.UpdatesStack)-2]
for k, v := range curUpdateMap {
prevUpdateMap[k] = v
}
updates.UpdatesStack = updates.UpdatesStack[:len(updates.UpdatesStack)-1]
}
func ContextUpdatesRollbackTx(ctx context.Context) {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
return
}
updates := updatesVal.(*contextUpdatesType)
if len(updates.UpdatesStack) <= 1 {
panic(fmt.Errorf("no updates transaction to rollback"))
}
updates.UpdatesStack = updates.UpdatesStack[:len(updates.UpdatesStack)-1]
}
func ContextGetUpdatesRtn(ctx context.Context) UpdatesRtnType {
updatesMap := ContextGetUpdates(ctx)
if updatesMap == nil {
return nil
}
rtn := make(UpdatesRtnType, 0, len(updatesMap))
for _, v := range updatesMap {
rtn = append(rtn, v)
}
return rtn
}
func ContextPrintUpdates(ctx context.Context) {
updatesVal := ctx.Value(waveObjUpdateKey)
if updatesVal == nil {
log.Print("no updates\n")
return
}
updates := updatesVal.(*contextUpdatesType)
log.Printf("updates len:%d\n", len(updates.UpdatesStack))
for idx, update := range updates.UpdatesStack {
log.Printf(" update[%d]:\n", idx)
for k, v := range update {
log.Printf(" %s:%s %s\n", k.OType, k.OID, v.UpdateType)
}
}
}