Fix potential overflow in promptenc arithmetic

This commit is contained in:
Evan Simkowitz 2023-12-12 16:51:19 -08:00
parent 112d002c2a
commit 96f636c2da
No known key found for this signature in database
3 changed files with 92 additions and 2 deletions

View File

@ -12,6 +12,7 @@ import (
"io"
"reflect"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
ccp "golang.org/x/crypto/chacha20poly1305"
)
@ -65,9 +66,13 @@ func MakeEncryptorB64(key64 string) (*Encryptor, error) {
}
func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error) {
outputBuf := make([]byte, enc.AEAD.NonceSize()+enc.AEAD.Overhead()+len(plainText))
bufSize, err := utilfn.AddIntSlice(enc.AEAD.NonceSize(), enc.AEAD.Overhead(), len(plainText))
if err != nil {
return nil, err
}
outputBuf := make([]byte, bufSize)
nonce := outputBuf[0:enc.AEAD.NonceSize()]
_, err := io.ReadFull(rand.Reader, nonce)
_, err = io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, err
}

View File

@ -6,6 +6,8 @@ package utilfn
import (
"crypto/sha1"
"encoding/base64"
"errors"
"math"
"regexp"
"strings"
"unicode/utf8"
@ -209,3 +211,32 @@ func ChunkSlice[T any](s []T, chunkSize int) [][]T {
}
return rtn
}
var ErrOverflow = errors.New("integer overflow")
// Add two int values, returning an error if the result overflows.
func AddInt(left, right int) (int, error) {
if right > 0 {
if left > math.MaxInt-right {
return 0, ErrOverflow
}
} else {
if left < math.MaxInt-right {
return 0, ErrOverflow
}
}
return left + right, nil
}
// Add a slice of ints, returning an error if the result overflows.
func AddIntSlice(vals ...int) (int, error) {
var rtn int
for _, v := range vals {
var err error
rtn, err = AddInt(rtn, v)
if err != nil {
return 0, err
}
}
return rtn, nil
}

View File

@ -5,6 +5,7 @@ package utilfn
import (
"fmt"
"math"
"testing"
)
@ -49,3 +50,56 @@ func TestDiff(t *testing.T) {
testDiff(t, Str1, Str3)
testDiff(t, Str3, Str1)
}
const unexpectedError = "unexpected error"
const expectedError = "expected error"
const wrongRetVal = "wrong return value"
func testAddInt(t *testing.T, a int, b int, shouldError bool, expected int) {
retVal, err := AddInt(a, b)
if err != nil {
if !shouldError {
t.Errorf(unexpectedError)
}
return
}
if shouldError {
t.Errorf(expectedError)
return
}
if retVal != expected {
t.Errorf(wrongRetVal)
}
}
func TestAddInt(t *testing.T) {
testAddInt(t, 1, 2, false, 3)
testAddInt(t, 1, math.MaxInt, true, 0)
}
func testAddIntSlice(t *testing.T, shouldError bool, expected int, vals ...int) {
retVal, err := AddIntSlice(vals...)
if err != nil {
if !shouldError {
t.Errorf(unexpectedError)
}
return
}
if shouldError {
t.Errorf(expectedError)
return
}
if retVal != expected {
t.Errorf(wrongRetVal)
}
}
func TestAddIntSlice(t *testing.T) {
testAddIntSlice(t, false, 0)
testAddIntSlice(t, false, 1, 1)
testAddIntSlice(t, false, 3, 1, 2)
testAddIntSlice(t, false, 6, 1, 2, 3)
testAddIntSlice(t, true, 0, 1, math.MaxInt)
testAddIntSlice(t, true, 0, 1, 2, math.MaxInt)
testAddIntSlice(t, true, 0, math.MaxInt, 2, 1)
}