// waveterm/pkg/sstore/map.go (Go, 203 lines, 4.5 KiB)

package sstore
import (
"context"
"fmt"
"reflect"
"strings"
)
type (
	// DBMappable is the parameter type accepted by ToDBMap and FromDBMap.
	// Nothing in this file calls UseDBMap; implementing it only opts a type
	// into the reflection-based conversion.
	DBMappable interface {
		UseDBMap()
	}

	// MapConverter is implemented by types that convert themselves to and
	// from a generic map. FromMap reports whether the map was accepted.
	MapConverter interface {
		ToMap() map[string]any
		FromMap(map[string]any) bool
	}

	// HasSimpleKey is implemented by values that expose a single string key,
	// used by MakeGenMap to index a slice.
	HasSimpleKey interface {
		GetSimpleKey() string
	}

	// MapConverterPtr constrains PT to be *T where *T implements
	// MapConverter. This lets generic helpers allocate a T and call the
	// pointer-receiver methods.
	MapConverterPtr[T any] interface {
		MapConverter
		*T
	}

	// DBMappablePtr constrains PT to be *T where *T implements DBMappable.
	DBMappablePtr[T any] interface {
		DBMappable
		*T
	}
)
// FromMap allocates a new T and fills it from m via (*T).FromMap.
// It returns nil when m is empty (including a nil map) or when the
// FromMap method rejects the input by returning false.
func FromMap[PT MapConverterPtr[T], T any](m map[string]any) PT {
	if len(m) > 0 {
		obj := PT(new(T))
		if obj.FromMap(m) {
			return obj
		}
	}
	return nil
}
// GetMapGen runs query (with args) through tx.GetMap and converts the
// resulting map into a PT with FromMap. The result is nil whenever FromMap
// yields nil, i.e. for an empty map or a map the type's FromMap rejects.
func GetMapGen[PT MapConverterPtr[T], T any](tx *TxWrap, query string, args ...interface{}) PT {
	return FromMap[PT](tx.GetMap(query, args...))
}
// GetMappable runs query (with args) through tx.GetMap and decodes the
// returned map into a freshly allocated T using FromDBMap.
// It returns nil when tx.GetMap yields an empty map.
func GetMappable[PT DBMappablePtr[T], T any](tx *TxWrap, query string, args ...interface{}) PT {
	row := tx.GetMap(query, args...)
	if len(row) == 0 {
		return nil
	}
	obj := PT(new(T))
	FromDBMap(obj, row)
	return obj
}
// SelectMapsGen runs query (with args) through tx.SelectMaps and converts
// each returned map with FromMap. Maps that FromMap turns into nil (empty or
// rejected) are dropped. The result is a nil slice when nothing converts.
func SelectMapsGen[PT MapConverterPtr[T], T any](tx *TxWrap, query string, args ...interface{}) []PT {
	var out []PT
	for _, row := range tx.SelectMaps(query, args...) {
		if conv := FromMap[PT](row); conv != nil {
			out = append(out, conv)
		}
	}
	return out
}
// MakeGenMap indexes arr by each element's GetSimpleKey. When several
// elements share a key, the last one in arr wins. The returned map is
// always non-nil, even for an empty or nil arr.
func MakeGenMap[T HasSimpleKey](arr []T) map[string]T {
	byKey := make(map[string]T, len(arr))
	for i := range arr {
		byKey[arr[i].GetSimpleKey()] = arr[i]
	}
	return byKey
}
// WithTxRtn runs fn inside WithTx and hands back fn's result along with the
// error reported by WithTx.
//
// The result is captured only when fn itself succeeds. If fn fails, the zero
// RT is returned together with WithTx's error. If fn succeeds but WithTx still
// reports an error afterwards, the captured value is returned alongside that
// error.
func WithTxRtn[RT any](ctx context.Context, fn func(tx *TxWrap) (RT, error)) (RT, error) {
	var result RT
	err := WithTx(ctx, func(tx *TxWrap) error {
		val, fnErr := fn(tx)
		if fnErr == nil {
			result = val
		}
		return fnErr
	})
	return result, err
}
// WithTxRtn3 is the two-result variant of WithTxRtn. It runs fn inside WithTx
// and returns fn's two values together with the error reported by WithTx.
//
// Both values are captured only when fn succeeds. Otherwise they stay at
// their zero values.
func WithTxRtn3[RT1 any, RT2 any](ctx context.Context, fn func(tx *TxWrap) (RT1, RT2, error)) (RT1, RT2, error) {
	var (
		first  RT1
		second RT2
	)
	err := WithTx(ctx, func(tx *TxWrap) error {
		a, b, fnErr := fn(tx)
		if fnErr == nil {
			first, second = a, b
		}
		return fnErr
	})
	return first, second, err
}
// isStructType reports whether rt is a struct type or a pointer directly to a
// struct type. Pointers to pointers and slices of structs do not count.
func isStructType(rt reflect.Type) bool {
	switch rt.Kind() {
	case reflect.Struct:
		return true
	case reflect.Pointer:
		return rt.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}
// isByteArrayType reports whether t is a slice whose element kind is uint8
// (e.g. []byte, or a named type built on it). Fixed-size byte arrays are not
// matched; only slices are.
func isByteArrayType(t reflect.Type) bool {
	if t.Kind() != reflect.Slice {
		return false
	}
	return t.Elem().Kind() == reflect.Uint8
}
// ToDBMap converts the struct behind v into a map keyed by DB column name.
//
// Column names come from the `dbmap` struct tag, falling back to the
// lower-cased Go field name. Fields tagged `dbmap:"-"` are skipped. Values are
// encoded by field type:
//   - byte slices: stored as-is
//   - any other slice: encoded with quickJsonArr
//   - struct or pointer-to-struct: encoded with quickJson
//   - everything else: stored as-is
//
// A nil v, or a v holding a nil pointer, yields a nil map. ToDBMap panics if v
// is not a struct or a pointer to a struct. Fields are read with
// reflect.Value.Interface, which panics on unexported fields, so mapped
// fields must be exported.
func ToDBMap(v DBMappable) map[string]interface{} {
	if v == nil {
		return nil
	}
	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Pointer {
		// An interface holding a typed nil pointer is non-nil, so it gets past
		// the check above. Treat it like a nil v instead of panicking on the
		// invalid Elem().
		if rv.IsNil() {
			return nil
		}
		rv = rv.Elem()
	}
	if rv.Kind() != reflect.Struct {
		panic(fmt.Sprintf("invalid type %T (non-struct) passed to ToDBMap", v))
	}
	rt := rv.Type()
	numFields := rt.NumField()
	m := make(map[string]interface{}, numFields)
	for i := 0; i < numFields; i++ {
		field := rt.Field(i)
		dbName := field.Tag.Get("dbmap")
		if dbName == "" {
			dbName = strings.ToLower(field.Name)
		}
		if dbName == "-" {
			continue
		}
		fieldVal := rv.Field(i)
		// Byte slices must be tested before the generic slice case so they are
		// stored raw rather than JSON-encoded.
		switch {
		case isByteArrayType(field.Type):
			m[dbName] = fieldVal.Interface()
		case field.Type.Kind() == reflect.Slice:
			m[dbName] = quickJsonArr(fieldVal.Interface())
		case isStructType(field.Type):
			m[dbName] = quickJson(fieldVal.Interface())
		default:
			m[dbName] = fieldVal.Interface()
		}
	}
	return m
}
// FromDBMap populates the struct pointed to by v from m, a map keyed by DB
// column name. It is the inverse of ToDBMap.
//
// Column names are resolved as in ToDBMap: the `dbmap` tag, else the
// lower-cased Go field name. Fields tagged `dbmap:"-"` are left untouched.
// Each field is decoded by its kind:
//   - byte slices via quickSetBytes
//   - other slices via quickSetJsonArr
//   - structs and pointers to structs via quickSetJson
//   - string, int64 and bool kinds via quickSetStr, quickSetInt64, quickSetBool
//
// Decoding is kind-based, so fields of named types (e.g. `type Status string`
// or time.Duration) are handled the same as the built-in types.
//
// FromDBMap panics if v is nil, is not a non-nil pointer to a struct, or has a
// field of any other kind.
func FromDBMap(v DBMappable, m map[string]interface{}) {
	if v == nil {
		panic("FromDBMap, v cannot be nil")
	}
	rv := reflect.ValueOf(v)
	// Fields can only be set through an addressable value, which requires a
	// non-nil pointer. Reject anything else with a clear message instead of
	// letting reflect panic later on an invalid or unaddressable value.
	if rv.Kind() != reflect.Pointer || rv.IsNil() {
		panic(fmt.Sprintf("invalid type %T (must be a non-nil pointer to a struct) passed to FromDBMap", v))
	}
	rv = rv.Elem()
	if rv.Kind() != reflect.Struct {
		panic(fmt.Sprintf("invalid type %T (non-struct) passed to FromDBMap", v))
	}
	rt := rv.Type()
	numFields := rt.NumField()
	for i := 0; i < numFields; i++ {
		field := rt.Field(i)
		dbName := field.Tag.Get("dbmap")
		if dbName == "" {
			dbName = strings.ToLower(field.Name)
		}
		if dbName == "-" {
			continue
		}
		fieldVal := rv.Field(i)
		// The scalar and byte-slice cases decode into a temporary of the
		// built-in type, seeded with the current field value, and write it back
		// with a reflect setter. A direct pointer assertion such as .(*string)
		// would panic for fields declared with a named type of that kind.
		switch {
		case isByteArrayType(field.Type):
			barr := fieldVal.Bytes()
			quickSetBytes(&barr, m, dbName)
			fieldVal.SetBytes(barr)
		case field.Type.Kind() == reflect.Slice:
			quickSetJsonArr(fieldVal.Addr().Interface(), m, dbName)
		case isStructType(field.Type):
			quickSetJson(fieldVal.Addr().Interface(), m, dbName)
		case field.Type.Kind() == reflect.String:
			str := fieldVal.String()
			quickSetStr(&str, m, dbName)
			fieldVal.SetString(str)
		case field.Type.Kind() == reflect.Int64:
			num := fieldVal.Int()
			quickSetInt64(&num, m, dbName)
			fieldVal.SetInt(num)
		case field.Type.Kind() == reflect.Bool:
			flag := fieldVal.Bool()
			quickSetBool(&flag, m, dbName)
			fieldVal.SetBool(flag)
		default:
			panic(fmt.Sprintf("FromDBMap invalid field type %v in %T", fieldVal.Type(), v))
		}
	}
}