waveterm/pkg/util/dbutil/dbmappable.go
2024-08-08 18:24:54 -07:00

238 lines
5.4 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package dbutil
import (
"fmt"
"reflect"
"strings"
"github.com/sawka/txwrap"
)
type DBMappable interface {
UseDBMap()
}
type MapEntry[T any] struct {
Key string
Val T
}
type MapConverter interface {
ToMap() map[string]interface{}
FromMap(map[string]interface{}) bool
}
type HasSimpleKey interface {
GetSimpleKey() string
}
type HasSimpleInt64Key interface {
GetSimpleKey() int64
}
type MapConverterPtr[T any] interface {
MapConverter
*T
}
type DBMappablePtr[T any] interface {
DBMappable
*T
}
func FromMap[PT MapConverterPtr[T], T any](m map[string]any) PT {
if len(m) == 0 {
return nil
}
rtn := PT(new(T))
ok := rtn.FromMap(m)
if !ok {
return nil
}
return rtn
}
func GetMapGen[PT MapConverterPtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) PT {
m := tx.GetMap(query, args...)
return FromMap[PT](m)
}
func GetMappable[PT DBMappablePtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) PT {
m := tx.GetMap(query, args...)
if len(m) == 0 {
return nil
}
rtn := PT(new(T))
FromDBMap(rtn, m)
return rtn
}
func SelectMappable[PT DBMappablePtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) []PT {
var rtn []PT
marr := tx.SelectMaps(query, args...)
for _, m := range marr {
if len(m) == 0 {
continue
}
val := PT(new(T))
FromDBMap(val, m)
rtn = append(rtn, val)
}
return rtn
}
func SelectMapsGen[PT MapConverterPtr[T], T any](tx *txwrap.TxWrap, query string, args ...interface{}) []PT {
var rtn []PT
marr := tx.SelectMaps(query, args...)
for _, m := range marr {
val := FromMap[PT](m)
if val != nil {
rtn = append(rtn, val)
}
}
return rtn
}
func SelectSimpleMap[T any](tx *txwrap.TxWrap, query string, args ...interface{}) map[string]T {
var rtn []MapEntry[T]
tx.Select(&rtn, query, args...)
if len(rtn) == 0 {
return nil
}
rtnMap := make(map[string]T)
for _, entry := range rtn {
rtnMap[entry.Key] = entry.Val
}
return rtnMap
}
func MakeGenMap[T HasSimpleKey](arr []T) map[string]T {
rtn := make(map[string]T)
for _, val := range arr {
rtn[val.GetSimpleKey()] = val
}
return rtn
}
func MakeGenMapInt64[T HasSimpleInt64Key](arr []T) map[int64]T {
rtn := make(map[int64]T)
for _, val := range arr {
rtn[val.GetSimpleKey()] = val
}
return rtn
}
func isStructType(rt reflect.Type) bool {
if rt.Kind() == reflect.Struct {
return true
}
if rt.Kind() == reflect.Pointer && rt.Elem().Kind() == reflect.Struct {
return true
}
return false
}
func isByteArrayType(t reflect.Type) bool {
return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8
}
func isStringMapType(t reflect.Type) bool {
return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String
}
func ToDBMap(v DBMappable, useBytes bool) map[string]interface{} {
if CheckNil(v) {
return nil
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructToDBMap", v))
}
rt := rv.Type()
m := make(map[string]interface{})
numFields := rt.NumField()
for i := 0; i < numFields; i++ {
field := rt.Field(i)
fieldVal := rv.FieldByIndex(field.Index)
dbName := field.Tag.Get("dbmap")
if dbName == "" {
dbName = strings.ToLower(field.Name)
}
if dbName == "-" {
continue
}
if isByteArrayType(field.Type) {
m[dbName] = fieldVal.Interface()
} else if field.Type.Kind() == reflect.Slice {
if useBytes {
m[dbName] = QuickJsonArrBytes(fieldVal.Interface())
} else {
m[dbName] = QuickJsonArr(fieldVal.Interface())
}
} else if isStructType(field.Type) || isStringMapType(field.Type) {
if useBytes {
m[dbName] = QuickJsonBytes(fieldVal.Interface())
} else {
m[dbName] = QuickJson(fieldVal.Interface())
}
} else {
m[dbName] = fieldVal.Interface()
}
}
return m
}
func FromDBMap(v DBMappable, m map[string]interface{}) {
if CheckNil(v) {
panic("StructFromDBMap, v cannot be nil")
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
panic(fmt.Sprintf("invalid type %T (non-struct) passed to StructFromDBMap", v))
}
rt := rv.Type()
numFields := rt.NumField()
for i := 0; i < numFields; i++ {
field := rt.Field(i)
fieldVal := rv.FieldByIndex(field.Index)
dbName := field.Tag.Get("dbmap")
if dbName == "" {
dbName = strings.ToLower(field.Name)
}
if dbName == "-" {
continue
}
if isByteArrayType(field.Type) {
barrVal := fieldVal.Addr().Interface()
QuickSetBytes(barrVal.(*[]byte), m, dbName)
} else if field.Type.Kind() == reflect.Slice {
QuickSetJsonArr(fieldVal.Addr().Interface(), m, dbName)
} else if isStructType(field.Type) || isStringMapType(field.Type) {
QuickSetJson(fieldVal.Addr().Interface(), m, dbName)
} else if field.Type.Kind() == reflect.String {
strVal := fieldVal.Addr().Interface()
QuickSetStr(strVal.(*string), m, dbName)
} else if field.Type.Kind() == reflect.Int64 {
intVal := fieldVal.Addr().Interface()
QuickSetInt64(intVal.(*int64), m, dbName)
} else if field.Type.Kind() == reflect.Int {
intVal := fieldVal.Addr().Interface()
QuickSetInt(intVal.(*int), m, dbName)
} else if field.Type.Kind() == reflect.Bool {
boolVal := fieldVal.Addr().Interface()
QuickSetBool(boolVal.(*bool), m, dbName)
} else {
panic(fmt.Sprintf("StructFromDBMap invalid field type %v in %T", fieldVal.Type(), v))
}
}
}