2023-03-08 23:38:46 +01:00
|
|
|
package promptenc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/cipher"
|
|
|
|
"crypto/rand"
|
2023-03-09 02:16:06 +01:00
|
|
|
"encoding/base64"
|
2023-03-08 23:38:46 +01:00
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"reflect"
|
|
|
|
|
|
|
|
ccp "golang.org/x/crypto/chacha20poly1305"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Struct-tag vocabulary used when scanning a struct for encryptable fields.
const (
	// EncTagName is the struct tag key that is looked up on every field.
	EncTagName = "enc"

	// EncFieldIndicator is the tag value marking the single []byte field
	// that stores the encrypted payload; any other non-empty tag value is
	// used as the JSON key of a plaintext field.
	EncFieldIndicator = "*"
)
|
|
|
|
|
|
|
|
// Encryptor pairs a symmetric key with the AEAD cipher that was built from
// it. All encrypt/decrypt methods go through the AEAD field.
type Encryptor struct {
	// Key is the raw key material handed to the AEAD constructor.
	Key []byte

	// AEAD seals and opens data; it must be non-nil before the Encryptor is used.
	AEAD cipher.AEAD
}
|
|
|
|
|
|
|
|
// HasOData is implemented by values that can supply their own "odata"
// string. EncryptODS and DecryptODS pass that string along as the additional
// authenticated data of the AEAD operation.
type HasOData interface {
	// GetOData returns the string to authenticate alongside the ciphertext.
	GetOData() string
}
|
|
|
|
|
|
|
|
// readRandBytes returns n bytes read from crypto/rand.
//
// A negative n yields an error instead of a panic, and on any failure the
// returned slice is nil so callers never see a partially filled buffer.
func readRandBytes(n int) ([]byte, error) {
	if n < 0 {
		return nil, fmt.Errorf("invalid random byte count %d", n)
	}
	buf := make([]byte, n)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return nil, err
	}
	return buf, nil
}
|
|
|
|
|
|
|
|
func MakeRandomEncryptor() (*Encryptor, error) {
|
|
|
|
key, err := readRandBytes(ccp.KeySize)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
rtn := &Encryptor{Key: key}
|
|
|
|
rtn.AEAD, err = ccp.NewX(rtn.Key)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return rtn, nil
|
|
|
|
}
|
|
|
|
|
2023-03-09 02:16:06 +01:00
|
|
|
func MakeEncryptor(key []byte) (*Encryptor, error) {
|
2023-03-10 03:44:01 +01:00
|
|
|
var err error
|
2023-03-09 02:16:06 +01:00
|
|
|
rtn := &Encryptor{Key: key}
|
|
|
|
rtn.AEAD, err = ccp.NewX(rtn.Key)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return rtn, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func MakeEncryptorB64(key64 string) (*Encryptor, error) {
|
|
|
|
keyBytes, err := base64.RawURLEncoding.DecodeString(key64)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return MakeEncryptor(keyBytes)
|
|
|
|
}
|
|
|
|
|
2023-03-08 23:38:46 +01:00
|
|
|
func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error) {
|
|
|
|
outputBuf := make([]byte, enc.AEAD.NonceSize()+enc.AEAD.Overhead()+len(plainText))
|
|
|
|
nonce := outputBuf[0:enc.AEAD.NonceSize()]
|
|
|
|
_, err := io.ReadFull(rand.Reader, nonce)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
// we're going to append the cipherText to nonce. so the encrypted data is [nonce][ciphertext]
|
|
|
|
// note that outputbuf should be the correct size to hold the rtn value
|
|
|
|
rtn := enc.AEAD.Seal(nonce, nonce, plainText, []byte(odata))
|
|
|
|
return rtn, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (enc *Encryptor) DecryptData(encData []byte, odata string) (map[string]interface{}, error) {
|
|
|
|
minLen := enc.AEAD.NonceSize() + enc.AEAD.Overhead()
|
|
|
|
if len(encData) < minLen {
|
|
|
|
return nil, fmt.Errorf("invalid encdata, len:%d is less than minimum len:%d", len(encData), minLen)
|
|
|
|
}
|
|
|
|
m := make(map[string]interface{})
|
|
|
|
nonce := encData[0:enc.AEAD.NonceSize()]
|
|
|
|
cipherText := encData[enc.AEAD.NonceSize():]
|
|
|
|
plainText, err := enc.AEAD.Open(nil, nonce, cipherText, []byte(odata))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
err = json.Unmarshal(plainText, &m)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return m, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// EncryptMeta describes which fields of a struct type take part in
// encryption, as discovered from their "enc" struct tags.
type EncryptMeta struct {
	// EncField is the field tagged with EncFieldIndicator; it holds the ciphertext.
	EncField *reflect.StructField

	// PlainFields maps each remaining non-empty enc tag value to its field.
	PlainFields map[string]reflect.StructField
}
|
|
|
|
|
|
|
|
// isByteArrayType reports whether t is a slice whose element kind is uint8
// (i.e. []byte or a type with that underlying shape). Fixed-size arrays are
// not accepted. A nil type, as returned by reflect.TypeOf(nil), yields false.
func isByteArrayType(t reflect.Type) bool {
	if t == nil {
		return false
	}
	if t.Kind() != reflect.Slice {
		return false
	}
	return t.Elem().Kind() == reflect.Uint8
}
|
|
|
|
|
|
|
|
func metaFromType(v interface{}) (*EncryptMeta, error) {
|
|
|
|
if v == nil {
|
|
|
|
return nil, fmt.Errorf("Encryptor cannot encrypt nil")
|
|
|
|
}
|
|
|
|
rt := reflect.TypeOf(v)
|
|
|
|
if rt.Kind() != reflect.Pointer {
|
|
|
|
return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer type", v)
|
|
|
|
}
|
|
|
|
rtElem := rt.Elem()
|
|
|
|
if rtElem.Kind() != reflect.Struct {
|
|
|
|
return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer to struct type", v)
|
|
|
|
}
|
|
|
|
meta := &EncryptMeta{}
|
|
|
|
meta.PlainFields = make(map[string]reflect.StructField)
|
|
|
|
numFields := rtElem.NumField()
|
|
|
|
for i := 0; i < numFields; i++ {
|
|
|
|
field := rtElem.Field(i)
|
|
|
|
encTag := field.Tag.Get(EncTagName)
|
|
|
|
if encTag == "" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if encTag == EncFieldIndicator {
|
|
|
|
if meta.EncField != nil {
|
|
|
|
return nil, fmt.Errorf("Encryptor, type %T has two enc fields set (*)", v)
|
|
|
|
}
|
|
|
|
if !isByteArrayType(field.Type) {
|
|
|
|
return nil, fmt.Errorf("Encryptor, type %T enc field %q is not []byte", v, field.Name)
|
|
|
|
}
|
|
|
|
meta.EncField = &field
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if _, found := meta.PlainFields[encTag]; found {
|
|
|
|
return nil, fmt.Errorf("Encryptor, type %T has two enc fields with tag %q", v, encTag)
|
|
|
|
}
|
|
|
|
meta.PlainFields[encTag] = field
|
|
|
|
}
|
|
|
|
if meta.EncField == nil {
|
|
|
|
return nil, fmt.Errorf("Encryptor, type %T has no enc (*) field", v)
|
|
|
|
}
|
|
|
|
return meta, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (enc *Encryptor) EncryptODS(v HasOData) error {
|
|
|
|
odata := v.GetOData()
|
|
|
|
return enc.EncryptStructFields(v, odata)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (enc *Encryptor) DecryptODS(v HasOData) error {
|
|
|
|
odata := v.GetOData()
|
|
|
|
return enc.DecryptStructFields(v, odata)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (enc *Encryptor) EncryptStructFields(v interface{}, odata string) error {
|
|
|
|
encMeta, err := metaFromType(v)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
rvPtr := reflect.ValueOf(v)
|
|
|
|
rv := rvPtr.Elem()
|
|
|
|
m := make(map[string]interface{})
|
|
|
|
for jsonKey, field := range encMeta.PlainFields {
|
|
|
|
fieldVal := rv.FieldByIndex(field.Index)
|
|
|
|
m[jsonKey] = fieldVal.Interface()
|
|
|
|
}
|
|
|
|
barr, err := json.Marshal(m)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
cipherText, err := enc.EncryptData(barr, odata)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
encFieldValue := rv.FieldByIndex(encMeta.EncField.Index)
|
|
|
|
encFieldValue.SetBytes(cipherText)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (enc *Encryptor) DecryptStructFields(v interface{}, odata string) error {
|
|
|
|
encMeta, err := metaFromType(v)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
rvPtr := reflect.ValueOf(v)
|
|
|
|
rv := rvPtr.Elem()
|
|
|
|
cipherText := rv.FieldByIndex(encMeta.EncField.Index).Bytes()
|
|
|
|
m, err := enc.DecryptData(cipherText, odata)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
for jsonKey, field := range encMeta.PlainFields {
|
|
|
|
val := m[jsonKey]
|
|
|
|
rv.FieldByIndex(field.Index).Set(reflect.ValueOf(val))
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|