diff --git a/wavesrv/pkg/promptenc/promptenc.go b/wavesrv/pkg/promptenc/promptenc.go index 23bfe1bde..64ff2e407 100644 --- a/wavesrv/pkg/promptenc/promptenc.go +++ b/wavesrv/pkg/promptenc/promptenc.go @@ -12,6 +12,7 @@ import ( "io" "reflect" + "github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn" ccp "golang.org/x/crypto/chacha20poly1305" ) @@ -65,9 +66,13 @@ func MakeEncryptorB64(key64 string) (*Encryptor, error) { } func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error) { - outputBuf := make([]byte, enc.AEAD.NonceSize()+enc.AEAD.Overhead()+len(plainText)) + bufSize, err := utilfn.AddIntSlice(enc.AEAD.NonceSize(), enc.AEAD.Overhead(), len(plainText)) + if err != nil { + return nil, err + } + outputBuf := make([]byte, bufSize) nonce := outputBuf[0:enc.AEAD.NonceSize()] - _, err := io.ReadFull(rand.Reader, nonce) + _, err = io.ReadFull(rand.Reader, nonce) if err != nil { return nil, err } diff --git a/wavesrv/pkg/utilfn/utilfn.go b/wavesrv/pkg/utilfn/utilfn.go index dbab5c29d..8df22fe7a 100644 --- a/wavesrv/pkg/utilfn/utilfn.go +++ b/wavesrv/pkg/utilfn/utilfn.go @@ -6,6 +6,8 @@ package utilfn import ( "crypto/sha1" "encoding/base64" + "errors" + "math" "regexp" "strings" "unicode/utf8" @@ -209,3 +211,32 @@ func ChunkSlice[T any](s []T, chunkSize int) [][]T { } return rtn } + +var ErrOverflow = errors.New("integer overflow") + +// Add two int values, returning an error if the result overflows. +func AddInt(left, right int) (int, error) { + if right > 0 { + if left > math.MaxInt-right { + return 0, ErrOverflow + } + } else { + if left < math.MinInt-right { + return 0, ErrOverflow + } + } + return left + right, nil +} + +// Add a slice of ints, returning an error if the result overflows. +func AddIntSlice(vals ...int) (int, error) { + var rtn int + for _, v := range vals { + var err error + rtn, err = AddInt(rtn, v) + if err != nil { + return 0, err + } + } + return rtn, nil +} diff --git a/wavesrv/pkg/utilfn/utilfn_test.go b/wavesrv/pkg/utilfn/utilfn_test.go index c6956ab9b..a65b46067 100644 --- a/wavesrv/pkg/utilfn/utilfn_test.go +++ b/wavesrv/pkg/utilfn/utilfn_test.go @@ -5,6 +5,7 @@ package utilfn import ( "fmt" + "math" "testing" ) @@ -49,3 +50,54 @@ func TestDiff(t *testing.T) { testDiff(t, Str1, Str3) testDiff(t, Str3, Str1) } + +func testArithmetic(t *testing.T, fn func() (int, error), shouldError bool, expected int) { + retVal, err := fn() + if err != nil { + if !shouldError { + t.Errorf("unexpected error") + } + return + } + if shouldError { + t.Errorf("expected error") + return + } + if retVal != expected { + t.Errorf("wrong return value") + } +} + +func testAddInt(t *testing.T, shouldError bool, expected int, a int, b int) { + testArithmetic(t, func() (int, error) { return AddInt(a, b) }, shouldError, expected) +} + +func TestAddInt(t *testing.T) { + testAddInt(t, false, 3, 1, 2) + testAddInt(t, true, 0, 1, math.MaxInt) + testAddInt(t, true, 0, math.MinInt, -1) + testAddInt(t, false, math.MaxInt-1, math.MaxInt, -1) + testAddInt(t, false, math.MinInt+1, math.MinInt, 1) + testAddInt(t, false, math.MaxInt, math.MaxInt, 0) + testAddInt(t, true, 0, math.MinInt, -1) +} + +func testAddIntSlice(t *testing.T, shouldError bool, expected int, vals ...int) { + testArithmetic(t, func() (int, error) { return AddIntSlice(vals...) }, shouldError, expected) +} + +func TestAddIntSlice(t *testing.T) { + testAddIntSlice(t, false, 0) + testAddIntSlice(t, false, 1, 1) + testAddIntSlice(t, false, 3, 1, 2) + testAddIntSlice(t, false, 6, 1, 2, 3) + testAddIntSlice(t, true, 0, 1, math.MaxInt) + testAddIntSlice(t, true, 0, 1, 2, math.MaxInt) + testAddIntSlice(t, true, 0, math.MaxInt, 2, 1) + testAddIntSlice(t, false, math.MaxInt, 0, 0, math.MaxInt) + testAddIntSlice(t, true, 0, math.MinInt, -1) + testAddIntSlice(t, false, math.MaxInt, math.MaxInt-3, 1, 2) + testAddIntSlice(t, true, 0, math.MaxInt-2, 1, 2) + testAddIntSlice(t, false, math.MinInt, math.MinInt+3, -1, -2) + testAddIntSlice(t, true, 0, math.MinInt+2, -1, -2) +}