waveterm/wavesrv/pkg/promptenc/promptenc.go
2023-12-12 16:51:19 -08:00

208 lines
5.2 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package promptenc
import (
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"reflect"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
ccp "golang.org/x/crypto/chacha20poly1305"
)
// Struct tag vocabulary read by metaFromType.
const (
	// EncTagName is the struct tag key inspected on every field.
	EncTagName = "enc"
	// EncFieldIndicator is the tag value marking the single []byte field
	// that holds the encrypted payload.
	EncFieldIndicator = "*"
)
// Encryptor bundles a symmetric key with the AEAD cipher built from it.
// The constructors in this file create the AEAD with ccp.NewX(Key).
type Encryptor struct {
// Key is the raw key material handed to ccp.NewX.
Key []byte
// AEAD performs the Seal/Open calls used by EncryptData and DecryptData.
AEAD cipher.AEAD
}
// HasOData is implemented by values that can supply their own odata string.
// EncryptODS and DecryptODS pass that string on as the odata argument of
// EncryptStructFields / DecryptStructFields, where it ends up as the
// additional-data argument of the AEAD Seal/Open calls.
type HasOData interface {
// GetOData returns the odata string to use for this value.
GetOData() string
}
// readRandBytes allocates an n-byte slice and fills it from crypto/rand.
// The slice is returned together with any read error.
func readRandBytes(n int) ([]byte, error) {
	buf := make([]byte, n)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return buf, err
	}
	return buf, nil
}
// MakeRandomEncryptor creates an Encryptor around a freshly generated random
// key of ccp.KeySize bytes. It fails if the random source cannot be read or
// if ccp.NewX rejects the key.
func MakeRandomEncryptor() (*Encryptor, error) {
	randKey, err := readRandBytes(ccp.KeySize)
	if err != nil {
		return nil, err
	}
	aead, err := ccp.NewX(randKey)
	if err != nil {
		return nil, err
	}
	return &Encryptor{Key: randKey, AEAD: aead}, nil
}
// MakeEncryptor creates an Encryptor for the supplied key. The key slice is
// stored as-is (not copied) and is also passed to ccp.NewX, whose error is
// returned unchanged if it rejects the key.
func MakeEncryptor(key []byte) (*Encryptor, error) {
	aead, err := ccp.NewX(key)
	if err != nil {
		return nil, err
	}
	return &Encryptor{Key: key, AEAD: aead}, nil
}
// MakeEncryptorB64 decodes key64 using unpadded URL-safe base64
// (base64.RawURLEncoding) and builds an Encryptor from the resulting bytes
// via MakeEncryptor.
func MakeEncryptorB64(key64 string) (*Encryptor, error) {
	rawKey, decodeErr := base64.RawURLEncoding.DecodeString(key64)
	if decodeErr != nil {
		return nil, decodeErr
	}
	return MakeEncryptor(rawKey)
}
// EncryptData seals plainText with the Encryptor's AEAD, using odata as the
// additional data, and returns the result laid out as [nonce][ciphertext].
// A fresh random nonce is generated for every call.
func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error) {
	nonceSize := enc.AEAD.NonceSize()
	totalSize, err := utilfn.AddIntSlice(nonceSize, enc.AEAD.Overhead(), len(plainText))
	if err != nil {
		return nil, err
	}
	// One allocation sized for the whole output; the nonce occupies its prefix.
	buf := make([]byte, totalSize)
	nonce := buf[:nonceSize]
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}
	// Seal appends the ciphertext to its first argument. Passing the nonce
	// slice as dst therefore yields nonce+ciphertext, and buf already has
	// the capacity to hold it.
	return enc.AEAD.Seal(nonce, nonce, plainText, []byte(odata)), nil
}
// DecryptData reverses EncryptData: encData must be laid out as
// [nonce][ciphertext]. The ciphertext is opened with odata as the additional
// data and the resulting plaintext is parsed as a JSON object.
// An error is returned when encData is too short, when the AEAD rejects the
// data, or when the plaintext is not valid JSON for a map.
func (enc *Encryptor) DecryptData(encData []byte, odata string) (map[string]interface{}, error) {
	nonceSize := enc.AEAD.NonceSize()
	minLen := nonceSize + enc.AEAD.Overhead()
	if len(encData) < minLen {
		return nil, fmt.Errorf("invalid encdata, len:%d is less than minimum len:%d", len(encData), minLen)
	}
	nonce, sealed := encData[:nonceSize], encData[nonceSize:]
	plainText, err := enc.AEAD.Open(nil, nonce, sealed, []byte(odata))
	if err != nil {
		return nil, err
	}
	decoded := make(map[string]interface{})
	if err := json.Unmarshal(plainText, &decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// EncryptMeta describes how a struct type maps onto the encryption scheme,
// as computed by metaFromType from the struct's `enc` tags.
type EncryptMeta struct {
// EncField is the field tagged enc:"*"; metaFromType requires it to be a byte slice.
EncField *reflect.StructField
// PlainFields holds every other enc-tagged field, keyed by its tag value.
PlainFields map[string]reflect.StructField
}
// isByteArrayType reports whether t is a slice type whose element kind is
// uint8 (for example []byte). Fixed-size arrays do not qualify.
func isByteArrayType(t reflect.Type) bool {
	if t.Kind() != reflect.Slice {
		return false
	}
	return t.Elem().Kind() == reflect.Uint8
}
// metaFromType inspects the `enc` struct tags of v and returns the matching
// EncryptMeta. v must be a pointer to a struct with exactly one field tagged
// enc:"*" (which must be a byte slice); all other non-empty enc tags must be
// unique. Fields without an enc tag are ignored.
func metaFromType(v interface{}) (*EncryptMeta, error) {
	if v == nil {
		return nil, fmt.Errorf("Encryptor cannot encrypt nil")
	}
	ptrType := reflect.TypeOf(v)
	if ptrType.Kind() != reflect.Pointer {
		return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer type", v)
	}
	structType := ptrType.Elem()
	if structType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("Encryptor invalid type %T, not a pointer to struct type", v)
	}
	meta := &EncryptMeta{PlainFields: make(map[string]reflect.StructField)}
	for idx, total := 0, structType.NumField(); idx < total; idx++ {
		// field is a fresh variable on every iteration, so keeping its
		// address below is safe.
		field := structType.Field(idx)
		switch tag := field.Tag.Get(EncTagName); tag {
		case "":
			// not part of the encryption scheme
		case EncFieldIndicator:
			if meta.EncField != nil {
				return nil, fmt.Errorf("Encryptor, type %T has two enc fields set (*)", v)
			}
			if !isByteArrayType(field.Type) {
				return nil, fmt.Errorf("Encryptor, type %T enc field %q is not []byte", v, field.Name)
			}
			meta.EncField = &field
		default:
			if _, dup := meta.PlainFields[tag]; dup {
				return nil, fmt.Errorf("Encryptor, type %T has two enc fields with tag %q", v, tag)
			}
			meta.PlainFields[tag] = field
		}
	}
	if meta.EncField == nil {
		return nil, fmt.Errorf("Encryptor, type %T has no enc (*) field", v)
	}
	return meta, nil
}
// EncryptODS encrypts the tagged fields of v, using the string returned by
// v.GetOData() as the odata for EncryptStructFields.
func (enc *Encryptor) EncryptODS(v HasOData) error {
	return enc.EncryptStructFields(v, v.GetOData())
}
// DecryptODS decrypts the enc field of v back into its tagged fields, using
// the string returned by v.GetOData() as the odata for DecryptStructFields.
func (enc *Encryptor) DecryptODS(v HasOData) error {
	return enc.DecryptStructFields(v, v.GetOData())
}
// EncryptStructFields encrypts the enc-tagged plain fields of v into its
// enc:"*" field.
//
// v must be a non-nil pointer to a struct accepted by metaFromType. Every
// field whose enc tag is a name (not "*") is collected into a map keyed by
// that tag value, marshaled as JSON, and sealed with EncryptData using odata
// as the additional data. The resulting bytes replace the contents of the
// enc:"*" field. The plain fields themselves are left unmodified.
//
// An error (rather than a reflect panic) is returned when v is a typed nil
// pointer or when a tagged field is unexported and therefore cannot be read
// or written through reflection.
func (enc *Encryptor) EncryptStructFields(v interface{}, odata string) error {
	encMeta, err := metaFromType(v)
	if err != nil {
		return err
	}
	rvPtr := reflect.ValueOf(v)
	// metaFromType only checks the static type; a typed nil pointer would
	// make Elem() return the zero Value and FieldByIndex panic.
	if rvPtr.IsNil() {
		return fmt.Errorf("Encryptor cannot encrypt nil pointer of type %T", v)
	}
	rv := rvPtr.Elem()
	encFieldValue := rv.FieldByIndex(encMeta.EncField.Index)
	if !encFieldValue.CanSet() {
		return fmt.Errorf("Encryptor, type %T enc field %q cannot be set (unexported?)", v, encMeta.EncField.Name)
	}
	m := make(map[string]interface{}, len(encMeta.PlainFields))
	for jsonKey, field := range encMeta.PlainFields {
		fieldVal := rv.FieldByIndex(field.Index)
		// Interface() panics on values obtained from unexported fields.
		if !fieldVal.CanInterface() {
			return fmt.Errorf("Encryptor, type %T field %q cannot be read (unexported?)", v, field.Name)
		}
		m[jsonKey] = fieldVal.Interface()
	}
	barr, err := json.Marshal(m)
	if err != nil {
		return err
	}
	cipherText, err := enc.EncryptData(barr, odata)
	if err != nil {
		return err
	}
	encFieldValue.SetBytes(cipherText)
	return nil
}
// DecryptStructFields decrypts the enc:"*" field of v and stores the
// recovered values into the enc-tagged plain fields.
//
// v must be a non-nil pointer to a struct accepted by metaFromType. The
// bytes of the enc:"*" field are opened with DecryptData using odata as the
// additional data. For every plain field, the decoded JSON value under the
// field's tag name is converted to the field's type and assigned:
//   - a missing key or JSON null sets the field to its zero value;
//   - a value directly assignable to the field type is assigned as-is;
//   - anything else (e.g. float64 into an int field, a base64 string into a
//     []byte field, a JSON object into a struct field) is converted by
//     re-encoding it as JSON and unmarshaling into the field type.
//
// All fields are decoded before any is written, so v is left untouched when
// an error is returned. Because DecryptData decodes numbers as float64,
// integers beyond 2^53 may lose precision.
func (enc *Encryptor) DecryptStructFields(v interface{}, odata string) error {
	encMeta, err := metaFromType(v)
	if err != nil {
		return err
	}
	rvPtr := reflect.ValueOf(v)
	// metaFromType only checks the static type; a typed nil pointer would
	// make Elem() return the zero Value and FieldByIndex panic.
	if rvPtr.IsNil() {
		return fmt.Errorf("Encryptor cannot decrypt into nil pointer of type %T", v)
	}
	rv := rvPtr.Elem()
	cipherText := rv.FieldByIndex(encMeta.EncField.Index).Bytes()
	m, err := enc.DecryptData(cipherText, odata)
	if err != nil {
		return err
	}
	// Phase 1: convert every value, failing before anything is mutated.
	decoded := make(map[string]reflect.Value, len(encMeta.PlainFields))
	for jsonKey, field := range encMeta.PlainFields {
		if !rv.FieldByIndex(field.Index).CanSet() {
			return fmt.Errorf("Encryptor, type %T field %q cannot be set (unexported?)", v, field.Name)
		}
		newVal, err := decodePlainField(m[jsonKey], field.Type)
		if err != nil {
			return fmt.Errorf("Encryptor, type %T cannot decode field %q: %w", v, field.Name, err)
		}
		decoded[jsonKey] = newVal
	}
	// Phase 2: apply.
	for jsonKey, field := range encMeta.PlainFields {
		rv.FieldByIndex(field.Index).Set(decoded[jsonKey])
	}
	return nil
}

// decodePlainField converts a value produced by json.Unmarshal into an
// interface{} to a reflect.Value of fieldType. nil maps to the zero value.
func decodePlainField(val interface{}, fieldType reflect.Type) (reflect.Value, error) {
	if val == nil {
		// reflect.ValueOf(nil) is the zero Value, which Set rejects.
		return reflect.Zero(fieldType), nil
	}
	if rval := reflect.ValueOf(val); rval.Type().AssignableTo(fieldType) {
		return rval, nil
	}
	// Let encoding/json perform the type conversion.
	raw, err := json.Marshal(val)
	if err != nil {
		return reflect.Value{}, err
	}
	target := reflect.New(fieldType)
	if err := json.Unmarshal(raw, target.Interface()); err != nil {
		return reflect.Value{}, err
	}
	return target.Elem(), nil
}