Add support for proxying OpenAI chat completions through the cloud (#148)

* wrote client code for communicating with the Lambda cloud

* added timeout functionality, added a check that telemetry is enabled for cloud completion, added the capability to unset the token, other small fixes

* removed stale prints and comments, re-added non-stream completion for now

* changed JSON encode to JSON marshal, also testing my new commit author

* added a no-telemetry error message and removed the check for model in cloud completion

* added defer conn.Close() to doOpenAIStreamCompletion, so the websocket is always closed

* made a constant for the long telemetry error message

* added an endpoint getter, improved error messages

* updated the scripthaus file to include the dev WS endpoint

* added an error check for OpenAI errors

* changed a boolean condition for better readability

* updated some error messages (use the error message from the server if one is returned)

* don't blow up the whole response if the server times out; just write a timeout message

* render streaming errors with a new prompt in openai.tsx (show both content and error); render cmd status 'error' with a red 'x' as well; show the exit code in the tooltip of the 'x'

* set hadError for errors; updated the timeout error to work with the new frontend code

* bumped the client timeout to 5 minutes (longer than the server timeout)

---------

Co-authored-by: sawka
Cole Lashley 2023-12-15 22:20:03 -08:00 committed by GitHub
parent b733724c7d
commit 4ccd62f12a
8 changed files with 186 additions and 28 deletions
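
At a glance, the change routes chat completions through Wave's cloud when the user has not configured an OpenAI API key, and directly to the OpenAI API otherwise. A minimal sketch of that dispatch, condensed from doOpenAIStreamCompletion in the diff below (the helper name is illustrative; error handling and pty writes are elided):

```go
// sketch: condensed from doOpenAIStreamCompletion below; startStream is illustrative
func startStream(ctx context.Context, opts *sstore.OpenAIOptsType,
	prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
	if opts.APIToken == "" {
		// no user key: proxy through Wave's cloud over a websocket
		// (requires telemetry to be enabled; enforced in OpenAICommand)
		ch, conn, err := openai.RunCloudCompletionStream(ctx, opts, prompt)
		if err != nil {
			return nil, err
		}
		_ = conn // the real caller does `defer conn.Close()` so the socket is always released
		return ch, nil
	}
	// user key present: call the OpenAI API directly
	return openai.RunCompletionStream(ctx, opts, prompt)
}
```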

View File

@@ -27,7 +27,7 @@ node_modules/.bin/electron-rebuild
 ```bash
 # @scripthaus command electron
 # @scripthaus cd :playbook
-WAVETERM_DEV=1 PCLOUD_ENDPOINT="https://ot2e112zx5.execute-api.us-west-2.amazonaws.com/dev" node_modules/.bin/electron dist-dev/emain.js
+WAVETERM_DEV=1 PCLOUD_ENDPOINT="https://ot2e112zx5.execute-api.us-west-2.amazonaws.com/dev" PCLOUD_WS_ENDPOINT="wss://5lfzlg5crl.execute-api.us-west-2.amazonaws.com/dev/" node_modules/.bin/electron dist-dev/emain.js
 ```
 ```bash

View File

@@ -76,11 +76,14 @@ class SmallLineAvatar extends React.Component<{ line: LineType; cmd: Cmd; onRigh
             iconTitle = "success";
         } else {
             icon = <XmarkIcon className="fail" />;
-            iconTitle = "fail";
+            iconTitle = "exitcode " + exitcode;
         }
-    } else if (status == "hangup" || status == "error") {
+    } else if (status == "hangup") {
         icon = <WarningIcon className="warning" />;
         iconTitle = status;
+    } else if (status == "error") {
+        icon = <XmarkIcon className="fail" />;
+        iconTitle = "error";
     } else if (status == "running" || status == "detached") {
         icon = <RotateIcon className="warning spin" />;
         iconTitle = "running";

View File

@@ -29,6 +29,7 @@ class OpenAIRendererModel {
     savedHeight: number;
     loading: OV<boolean>;
     loadError: OV<string> = mobx.observable.box(null, { name: "renderer-loadError" });
+    chatError: OV<string> = mobx.observable.box(null, { name: "renderer-chatError" });
     updateHeight_debounced: (newHeight: number) => void;
     ptyDataSource: (termContext: T.TermContextUnion) => Promise<T.PtyDataType>;
     packetData: PacketDataBuffer;
@@ -64,7 +65,7 @@ class OpenAIRendererModel {
         // console.log("got packet", packet);
         if (packet.error != null) {
             mobx.action(() => {
-                this.loadError.set(packet.error);
+                this.chatError.set(packet.error);
                 this.version.set(this.version.get() + 1);
             })();
             return;
@@ -131,6 +132,7 @@ class OpenAIRendererModel {
         mobx.action(() => {
             this.loading.set(true);
             this.loadError.set(null);
+            this.chatError.set(null);
         })();
         let rtnp = this.ptyDataSource(this.context);
         if (rtnp == null) {
@@ -186,13 +188,13 @@ class OpenAIRenderer extends React.Component<{ model: OpenAIRendererModel }> {
         );
     }
-    renderOutput(cmd: T.WebCmd) {
-        let output = this.props.model.output.get();
-        let message = "";
-        if (output != null) {
-            message = output.message ?? "";
-        }
+    renderOutput() {
+        let model = this.props.model;
+        let output = model.output.get();
+        if (output == null || output.message == null || output.message == "") {
+            return null;
+        }
+        let message = output.message;
         let opts = model.opts;
         let maxWidth = opts.maxSize.width;
         let minWidth = opts.maxSize.width;
@@ -219,6 +221,20 @@ class OpenAIRenderer extends React.Component<{ model: OpenAIRendererModel }> {
         );
     }
+    renderChatError() {
+        let model = this.props.model;
+        let chatError = model.chatError.get();
+        if (chatError == null) {
+            return null;
+        }
+        return (
+            <div className="openai-message">
+                <div className="openai-role openai-role-error">[error]</div>
+                <div className="openai-content-error">{chatError}</div>
+            </div>
+        );
+    }
     render() {
         let model: OpenAIRendererModel = this.props.model;
         let cmd = model.rawCmd;
@@ -239,7 +255,8 @@ class OpenAIRenderer extends React.Component<{ model: OpenAIRendererModel }> {
         return (
             <div className="openai-renderer" style={styleVal}>
                 {this.renderPrompt(cmd)}
-                {this.renderOutput(cmd)}
+                {this.renderOutput()}
+                {this.renderChatError()}
             </div>
         );
     }

View File

@@ -65,6 +65,8 @@
 const PacketSenderQueueSize = 20

+const PacketEOFStr = "EOF"
+
 var TypeStrToFactory map[string]reflect.Type

 func init() {

View File

@@ -22,6 +22,7 @@ import (
     "unicode"

     "github.com/google/uuid"
+    "github.com/gorilla/websocket"
     "github.com/wavetermdev/waveterm/waveshell/pkg/base"
     "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
     "github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
@@ -68,6 +69,11 @@ const TermFontSizeMax = 24
 const TsFormatStr = "2006-01-02 15:04:05"

+const OpenAIPacketTimeout = 10 * time.Second
+const OpenAIStreamTimeout = 5 * time.Minute
+
+const OpenAICloudCompletionTelemetryOffErrorMsg = "In order to protect against abuse, you must have telemetry turned on in order to use Wave's free AI features. If you do not want to turn telemetry on, you can still use Wave's AI features by adding your own OpenAI key in Settings. Note that when you use your own key, requests are not proxied through Wave's servers and will be sent directly to the OpenAI API."
+
 const (
     KwArgRenderer = "renderer"
     KwArgView     = "view"
@@ -1490,7 +1496,10 @@ func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt
         }
         sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
     }()
-    respPks, err := openai.RunCompletion(ctx, opts, prompt)
+    var respPks []*packet.OpenAIPacketType
+    var err error
+    // run open ai completion locally
+    respPks, err = openai.RunCompletion(ctx, opts, prompt)
     if err != nil {
         writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
         return
@@ -1509,7 +1518,7 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType,
     var outputPos int64
     var hadError bool
     startTime := time.Now()
-    ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
+    ctx, cancelFn := context.WithTimeout(context.Background(), OpenAIStreamTimeout)
     defer cancelFn()
     defer func() {
         r := recover()
@@ -1539,16 +1548,52 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType,
         }
         sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
     }()
-    ch, err := openai.RunCompletionStream(ctx, opts, prompt)
+    var ch chan *packet.OpenAIPacketType
+    var err error
+    if opts.APIToken == "" {
+        var conn *websocket.Conn
+        ch, conn, err = openai.RunCloudCompletionStream(ctx, opts, prompt)
+        if conn != nil {
+            defer conn.Close()
+        }
+    } else {
+        ch, err = openai.RunCompletionStream(ctx, opts, prompt)
+    }
     if err != nil {
         writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
         return
     }
-    for pk := range ch {
-        err = writePacketToPty(ctx, cmd, pk, &outputPos)
-        if err != nil {
-            writeErrorToPty(cmd, fmt.Sprintf("error writing response to ptybuffer: %v", err), outputPos)
-            return
+    doneWaitingForPackets := false
+    for !doneWaitingForPackets {
+        select {
+        case <-time.After(OpenAIPacketTimeout):
+            // timeout reading from channel
+            hadError = true
+            pk := openai.CreateErrorPacket(fmt.Sprintf("timeout waiting for server response"))
+            err = writePacketToPty(ctx, cmd, pk, &outputPos)
+            if err != nil {
+                log.Printf("error writing response to ptybuffer: %v", err)
+                return
+            }
+            doneWaitingForPackets = true
+            break
+        case pk, ok := <-ch:
+            if ok {
+                // got a packet
+                if pk.Error != "" {
+                    hadError = true
+                }
+                err = writePacketToPty(ctx, cmd, pk, &outputPos)
+                if err != nil {
+                    hadError = true
+                    log.Printf("error writing response to ptybuffer: %v", err)
+                    return
+                }
+            } else {
+                // channel closed
+                doneWaitingForPackets = true
+                break
+            }
         }
     }
     return
@@ -1563,10 +1608,15 @@ func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstor
     if err != nil {
         return nil, fmt.Errorf("cannot retrieve client data: %v", err)
     }
-    if clientData.OpenAIOpts == nil || clientData.OpenAIOpts.APIToken == "" {
-        return nil, fmt.Errorf("no openai API token found, configure in client settings")
+    if clientData.OpenAIOpts == nil {
+        return nil, fmt.Errorf("error retrieving client open ai options")
     }
     opts := clientData.OpenAIOpts
+    if opts.APIToken == "" {
+        if clientData.ClientOpts.NoTelemetry {
+            return nil, fmt.Errorf(OpenAICloudCompletionTelemetryOffErrorMsg)
+        }
+    }
     if opts.Model == "" {
         opts.Model = openai.DefaultModel
     }
@@ -3550,9 +3600,6 @@ func ClientAcceptTosCommand(ctx context.Context, pk *scpacket.FeCommandPacketTyp
 }

 func validateOpenAIAPIToken(key string) error {
-    if len(key) == 0 {
-        return fmt.Errorf("invalid openai token, zero length")
-    }
     if len(key) > MaxOpenAIAPITokenLen {
         return fmt.Errorf("invalid openai token, too long")
     }
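
The stream loop above races each channel receive against a fresh 10-second timer, so a stalled server yields a timeout packet instead of hanging (or discarding) the rest of the response. A self-contained sketch of the same pattern (names and payload type are illustrative):

```go
package main

import (
	"fmt"
	"time"
)

// drainWithTimeout mirrors the select/time.After pattern in
// doOpenAIStreamCompletion: every iteration arms a new timer, so the
// timeout applies per packet, not to the stream as a whole.
func drainWithTimeout(ch <-chan string, perPacket time.Duration) {
	for {
		select {
		case <-time.After(perPacket):
			fmt.Println("timeout waiting for server response")
			return
		case msg, ok := <-ch:
			if !ok {
				return // channel closed: stream finished normally
			}
			fmt.Println(msg)
		}
	}
}

func main() {
	ch := make(chan string, 2)
	ch <- "hello"
	ch <- "world"
	close(ch)
	drainWithTimeout(ch, 10*time.Second)
}
```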

View File

@@ -24,7 +24,7 @@ import (
     "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
 )

-const PCloudEndpoint = "https://api.getprompt.dev/central"
+const PCloudEndpoint = "https://api.waveterm.dev/central"
 const PCloudEndpointVarName = "PCLOUD_ENDPOINT"
 const APIVersion = 1
 const MaxPtyUpdateSize = (128 * 1024)
@@ -34,6 +34,9 @@ const MaxUpdateWriterErrors = 3
 const PCloudDefaultTimeout = 5 * time.Second
 const PCloudWebShareUpdateTimeout = 15 * time.Second

+const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
+const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
+
 // setting to 1M to be safe (max is 6M for API-GW + Lambda, but there is base64 encoding and upload time)
 // we allow one extra update past this estimated size
 const MaxUpdatePayloadSize = 1 * (1024 * 1024)
@@ -63,6 +66,18 @@ func GetEndpoint() string {
     return endpoint
 }

+func GetWSEndpoint() string {
+    if !scbase.IsDevMode() {
+        return PCloudWSEndpoint
+    } else {
+        endpoint := os.Getenv(PCloudWSEndpointVarName)
+        if endpoint == "" {
+            panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
+        }
+        return endpoint
+    }
+}
+
 func makeAuthPostReq(ctx context.Context, apiUrl string, authInfo AuthInfo, data interface{}) (*http.Request, error) {
     var dataReader io.Reader
     if data != nil {
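
For the MaxUpdatePayloadSize comment above, the margin works out roughly as follows (a sketch; the 4/3 factor is standard base64 overhead and the 6M cap comes from the comment itself):

```go
package main

import "fmt"

func main() {
	const limit = 6 * 1024 * 1024 // API-GW + Lambda payload cap cited in the comment
	raw := 2 * 1024 * 1024        // 1M estimate plus "one extra update past this estimated size"
	encoded := raw * 4 / 3        // base64 inflates payloads by roughly 4/3
	fmt.Println(encoded, encoded < limit) // ~2.7M, comfortably under the cap
}
```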

View File

@@ -5,11 +5,16 @@ package openai
 import (
     "context"
     "encoding/json"
+    "fmt"
+    "io"
+    "time"

+    "github.com/gorilla/websocket"
     openaiapi "github.com/sashabaranov/go-openai"
     "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
+    "github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud"
     "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
 )
@@ -19,6 +24,8 @@ const DefaultMaxTokens = 1000
 const DefaultModel = "gpt-3.5-turbo"
 const DefaultStreamChanSize = 10

+const CloudWebsocketConnectTimeout = 5 * time.Second
+
 func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
     if resp.Usage.TotalTokens == 0 {
         return nil
@@ -30,7 +37,7 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType
     }
 }

-func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
+func ConvertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
     var rtn []openaiapi.ChatCompletionMessage
     for _, p := range prompt {
         msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
@@ -56,7 +63,7 @@ func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []ss
     client := openaiapi.NewClientWithConfig(clientConfig)
     req := openaiapi.ChatCompletionRequest{
         Model:     opts.Model,
-        Messages:  convertPrompt(prompt),
+        Messages:  ConvertPrompt(prompt),
         MaxTokens: opts.MaxTokens,
     }
     if opts.MaxChoices > 1 {
@@ -72,6 +79,67 @@ func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []ss
     return marshalResponse(apiResp), nil
 }

+func RunCloudCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, *websocket.Conn, error) {
+    if opts == nil {
+        return nil, nil, fmt.Errorf("no openai opts found")
+    }
+    websocketContext, cancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
+    defer cancelFn()
+    conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, pcloud.GetWSEndpoint(), nil)
+    if err != nil {
+        return nil, nil, fmt.Errorf("OpenAI request, websocket connect error: %v", err)
+    }
+    cloudCompletionRequestConfig := sstore.OpenAICloudCompletionRequest{
+        Prompt:     prompt,
+        MaxTokens:  opts.MaxTokens,
+        MaxChoices: opts.MaxChoices,
+    }
+    configMessageBuf, err := json.Marshal(cloudCompletionRequestConfig)
+    if err != nil {
+        conn.Close()
+        return nil, nil, fmt.Errorf("OpenAI request, error marshaling config: %v", err)
+    }
+    err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
+    if err != nil {
+        conn.Close()
+        return nil, nil, fmt.Errorf("OpenAI request, websocket write config error: %v", err)
+    }
+    rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
+    go func() {
+        defer close(rtn)
+        defer conn.Close()
+        for {
+            _, socketMessage, err := conn.ReadMessage()
+            if err == io.EOF {
+                break
+            }
+            if err != nil {
+                errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket error reading message: %v", err))
+                rtn <- errPk
+                break
+            }
+            var streamResp *packet.OpenAIPacketType
+            err = json.Unmarshal(socketMessage, &streamResp)
+            if err != nil {
+                errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket response json decode error: %v", err))
+                rtn <- errPk
+                break
+            }
+            if streamResp.Error == packet.PacketEOFStr {
+                // got eof packet from socket
+                break
+            } else if streamResp.Error != "" {
+                // use error from server directly
+                errPk := CreateErrorPacket(streamResp.Error)
+                rtn <- errPk
+                break
+            }
+            rtn <- streamResp
+        }
+    }()
+    return rtn, conn, err
+}
+
 func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
     if opts == nil {
         return nil, fmt.Errorf("no openai opts found")
@@ -89,7 +157,7 @@ func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, promp
     client := openaiapi.NewClientWithConfig(clientConfig)
     req := openaiapi.ChatCompletionRequest{
         Model:     opts.Model,
-        Messages:  convertPrompt(prompt),
+        Messages:  ConvertPrompt(prompt),
         MaxTokens: opts.MaxTokens,
         Stream:    true,
     }
@@ -156,3 +224,9 @@ func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
     errPk.Error = errStr
     return errPk
 }
+
+func CreateTextPacket(text string) *packet.OpenAIPacketType {
+    pk := packet.MakeOpenAIPacket()
+    pk.Text = text
+    return pk
+}
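
A hedged usage sketch of the new RunCloudCompletionStream (consumeStream is illustrative; the real call site in doOpenAIStreamCompletion adds the per-packet timeout shown earlier): the caller owns the returned websocket and must close it, and the channel closes once the server sends its EOF packet.

```go
// sketch: illustrative consumer of RunCloudCompletionStream
func consumeStream(ctx context.Context, opts *sstore.OpenAIOptsType,
	prompt []sstore.OpenAIPromptMessageType) error {
	ch, conn, err := RunCloudCompletionStream(ctx, opts, prompt)
	if err != nil {
		return err
	}
	if conn != nil {
		defer conn.Close() // guarantee the websocket is released
	}
	for pk := range ch {
		if pk.Error != "" {
			// error packet from the server; the reader goroutine closes ch next
			return fmt.Errorf("cloud completion error: %s", pk.Error)
		}
		fmt.Print(pk.Text) // text arrives incrementally
	}
	return nil
}
```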

View File

@@ -784,6 +784,12 @@ type OpenAIPromptMessageType struct {
     Name    string `json:"name,omitempty"`
 }

+type OpenAICloudCompletionRequest struct {
+    Prompt     []OpenAIPromptMessageType `json:"prompt"`
+    MaxTokens  int                       `json:"maxtokens,omitempty"`
+    MaxChoices int                       `json:"maxchoices,omitempty"`
+}
+
 type PlaybookType struct {
     PlaybookId   string `json:"playbookid"`
     PlaybookName string `json:"playbookname"`
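
For reference, the request that RunCloudCompletionStream writes as the websocket's first text message marshals as below (a sketch: the prompt values are made up, and the role/content JSON tags on OpenAIPromptMessageType are assumed from its usage; only the name tag appears in this hunk):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// local copies of the sstore types for a runnable sketch; role/content
// json tags are assumed, name matches the hunk above
type OpenAIPromptMessageType struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}

type OpenAICloudCompletionRequest struct {
	Prompt     []OpenAIPromptMessageType `json:"prompt"`
	MaxTokens  int                       `json:"maxtokens,omitempty"`
	MaxChoices int                       `json:"maxchoices,omitempty"`
}

func main() {
	req := OpenAICloudCompletionRequest{
		Prompt:    []OpenAIPromptMessageType{{Role: "user", Content: "hello"}},
		MaxTokens: 1000,
	}
	buf, _ := json.Marshal(req)
	fmt.Println(string(buf))
	// {"prompt":[{"role":"user","content":"hello"}],"maxtokens":1000}
}
```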