Add support for proxying open AI chat completion through cloud (#148)

* wrote client code for communicating with lambda cloud
* added timeout functionality, added check for telemetry enabled for cloud completion, added capability to unset token, other small fixes
* removed stale prints and comments, re-added non-stream completion for now
* changed json encode to json marshal, also testing my new commit author
* added no-telemetry error message and removed check for model in cloud completion
* added defer conn.Close() to doOpenAIStreamCompletion, so the websocket is always closed
* made a constant for the long telemetry error message
* added endpoint getter, made errors better
* updated scripthaus file to include dev ws endpoint
* added error check for open ai errors
* changed bool condition for better readability
* updated some error messages (use error message from server if returned)
* don't blow up the whole response if the server times out; just write a timeout message
* render streaming errors with a new prompt in openai.tsx (show content and error); render cmd status 'error' with a red x as well; show exitcode in tooltip of the 'x'
* set hadError for errors; update timeout error to work with new frontend code
* bump client timeout to 5 minutes (longer than server timeout)

Co-authored-by: sawka

parent b733724c7d
commit 4ccd62f12a
@@ -27,7 +27,7 @@ node_modules/.bin/electron-rebuild
 ```bash
 # @scripthaus command electron
 # @scripthaus cd :playbook
-WAVETERM_DEV=1 PCLOUD_ENDPOINT="https://ot2e112zx5.execute-api.us-west-2.amazonaws.com/dev" node_modules/.bin/electron dist-dev/emain.js
+WAVETERM_DEV=1 PCLOUD_ENDPOINT="https://ot2e112zx5.execute-api.us-west-2.amazonaws.com/dev" PCLOUD_WS_ENDPOINT="wss://5lfzlg5crl.execute-api.us-west-2.amazonaws.com/dev/" node_modules/.bin/electron dist-dev/emain.js
 ```

 ```bash
@@ -76,11 +76,14 @@ class SmallLineAvatar extends React.Component<{ line: LineType; cmd: Cmd; onRigh
             iconTitle = "success";
         } else {
             icon = <XmarkIcon className="fail" />;
-            iconTitle = "fail";
+            iconTitle = "exitcode " + exitcode;
         }
-    } else if (status == "hangup" || status == "error") {
+    } else if (status == "hangup") {
         icon = <WarningIcon className="warning" />;
         iconTitle = status;
+    } else if (status == "error") {
+        icon = <XmarkIcon className="fail" />;
+        iconTitle = "error";
     } else if (status == "running" || "detached") {
         icon = <RotateIcon className="warning spin" />;
         iconTitle = "running";
@@ -29,6 +29,7 @@ class OpenAIRendererModel {
     savedHeight: number;
     loading: OV<boolean>;
     loadError: OV<string> = mobx.observable.box(null, { name: "renderer-loadError" });
+    chatError: OV<string> = mobx.observable.box(null, { name: "renderer-chatError" });
     updateHeight_debounced: (newHeight: number) => void;
     ptyDataSource: (termContext: T.TermContextUnion) => Promise<T.PtyDataType>;
     packetData: PacketDataBuffer;
@@ -64,7 +65,7 @@ class OpenAIRendererModel {
         // console.log("got packet", packet);
         if (packet.error != null) {
             mobx.action(() => {
-                this.loadError.set(packet.error);
+                this.chatError.set(packet.error);
                 this.version.set(this.version.get() + 1);
             })();
             return;
@@ -131,6 +132,7 @@ class OpenAIRendererModel {
         mobx.action(() => {
             this.loading.set(true);
             this.loadError.set(null);
+            this.chatError.set(null);
         })();
         let rtnp = this.ptyDataSource(this.context);
         if (rtnp == null) {
@@ -186,13 +188,13 @@ class OpenAIRenderer extends React.Component<{ model: OpenAIRendererModel }> {
         );
     }

-    renderOutput(cmd: T.WebCmd) {
-        let output = this.props.model.output.get();
-        let message = "";
-        if (output != null) {
-            message = output.message ?? "";
-        }
+    renderOutput() {
+        let model = this.props.model;
+        let output = model.output.get();
+        if (output == null || output.message == null || output.message == "") {
+            return null;
+        }
+        let message = output.message;
         let opts = model.opts;
         let maxWidth = opts.maxSize.width;
         let minWidth = opts.maxSize.width;
@@ -219,6 +221,20 @@ class OpenAIRenderer extends React.Component<{ model: OpenAIRendererModel }> {
         );
     }

+    renderChatError() {
+        let model = this.props.model;
+        let chatError = model.chatError.get();
+        if (chatError == null) {
+            return null;
+        }
+        return (
+            <div className="openai-message">
+                <div className="openai-role openai-role-error">[error]</div>
+                <div className="openai-content-error">{chatError}</div>
+            </div>
+        );
+    }
+
     render() {
         let model: OpenAIRendererModel = this.props.model;
         let cmd = model.rawCmd;
@@ -239,7 +255,8 @@ class OpenAIRenderer extends React.Component<{ model: OpenAIRendererModel }> {
         return (
             <div className="openai-renderer" style={styleVal}>
                 {this.renderPrompt(cmd)}
-                {this.renderOutput(cmd)}
+                {this.renderOutput()}
+                {this.renderChatError()}
             </div>
         );
     }
@@ -65,6 +65,8 @@ const (

 const PacketSenderQueueSize = 20

+const PacketEOFStr = "EOF"
+
 var TypeStrToFactory map[string]reflect.Type

 func init() {
@@ -22,6 +22,7 @@ import (
 	"unicode"

 	"github.com/google/uuid"
+	"github.com/gorilla/websocket"
 	"github.com/wavetermdev/waveterm/waveshell/pkg/base"
 	"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
 	"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
@@ -68,6 +69,11 @@ const TermFontSizeMax = 24

 const TsFormatStr = "2006-01-02 15:04:05"

+const OpenAIPacketTimeout = 10 * time.Second
+const OpenAIStreamTimeout = 5 * time.Minute
+
+const OpenAICloudCompletionTelemetryOffErrorMsg = "In order to protect against abuse, you must have telemetry turned on in order to use Wave's free AI features. If you do not want to turn telemetry on, you can still use Wave's AI features by adding your own OpenAI key in Settings. Note that when you use your own key, requests are not proxied through Wave's servers and will be sent directly to the OpenAI API."
+
 const (
 	KwArgRenderer = "renderer"
 	KwArgView     = "view"
@@ -1490,7 +1496,10 @@ func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt
 		}
 		sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
 	}()
-	respPks, err := openai.RunCompletion(ctx, opts, prompt)
+	var respPks []*packet.OpenAIPacketType
+	var err error
+	// run open ai completion locally
+	respPks, err = openai.RunCompletion(ctx, opts, prompt)
 	if err != nil {
 		writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
 		return
@@ -1509,7 +1518,7 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType,
 	var outputPos int64
 	var hadError bool
 	startTime := time.Now()
-	ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
+	ctx, cancelFn := context.WithTimeout(context.Background(), OpenAIStreamTimeout)
 	defer cancelFn()
 	defer func() {
 		r := recover()
@@ -1539,16 +1548,52 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType,
 		}
 		sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
 	}()
-	ch, err := openai.RunCompletionStream(ctx, opts, prompt)
+	var ch chan *packet.OpenAIPacketType
+	var err error
+	if opts.APIToken == "" {
+		var conn *websocket.Conn
+		ch, conn, err = openai.RunCloudCompletionStream(ctx, opts, prompt)
+		if conn != nil {
+			defer conn.Close()
+		}
+	} else {
+		ch, err = openai.RunCompletionStream(ctx, opts, prompt)
+	}
 	if err != nil {
 		writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
 		return
 	}
-	for pk := range ch {
-		err = writePacketToPty(ctx, cmd, pk, &outputPos)
-		if err != nil {
-			writeErrorToPty(cmd, fmt.Sprintf("error writing response to ptybuffer: %v", err), outputPos)
-			return
+	doneWaitingForPackets := false
+	for !doneWaitingForPackets {
+		select {
+		case <-time.After(OpenAIPacketTimeout):
+			// timeout reading from channel
+			hadError = true
+			pk := openai.CreateErrorPacket(fmt.Sprintf("timeout waiting for server response"))
+			err = writePacketToPty(ctx, cmd, pk, &outputPos)
+			if err != nil {
+				log.Printf("error writing response to ptybuffer: %v", err)
+				return
+			}
+			doneWaitingForPackets = true
+			break
+		case pk, ok := <-ch:
+			if ok {
+				// got a packet
+				if pk.Error != "" {
+					hadError = true
+				}
+				err = writePacketToPty(ctx, cmd, pk, &outputPos)
+				if err != nil {
+					hadError = true
+					log.Printf("error writing response to ptybuffer: %v", err)
+					return
+				}
+			} else {
+				// channel closed
+				doneWaitingForPackets = true
+				break
+			}
 		}
 	}
 	return
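A note on the pattern above: the plain `for pk := range ch` is replaced by a `select` so a stalled server can no longer hang the command, and `OpenAIPacketTimeout` is re-armed for each packet read rather than bounding the whole stream. A minimal, self-contained sketch of the same idea (string packets and `readWithTimeout` are invented for illustration):

```go
package main

import (
	"fmt"
	"time"
)

// readWithTimeout drains ch the way the loop above does: the timeout is
// armed fresh for every packet read, so a healthy stream can run for
// minutes while a stalled one fails after one quiet interval.
func readWithTimeout(ch <-chan string, perPacketTimeout time.Duration) {
	for {
		select {
		case <-time.After(perPacketTimeout):
			fmt.Println("[error] timeout waiting for server response")
			return
		case pk, ok := <-ch:
			if !ok {
				// channel closed: stream finished normally
				return
			}
			fmt.Println("got packet:", pk)
		}
	}
}

func main() {
	ch := make(chan string, 2)
	ch <- "hello"
	ch <- "world"
	close(ch)
	readWithTimeout(ch, 10*time.Second)
}
```

In the diff itself, `break` inside a `case` only exits the `select`, which is why the `doneWaitingForPackets` flag drives the outer loop; returning (as here) or using a labeled break would avoid the flag.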
@@ -1563,10 +1608,15 @@ func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstor
 	if err != nil {
 		return nil, fmt.Errorf("cannot retrieve client data: %v", err)
 	}
-	if clientData.OpenAIOpts == nil || clientData.OpenAIOpts.APIToken == "" {
-		return nil, fmt.Errorf("no openai API token found, configure in client settings")
+	if clientData.OpenAIOpts == nil {
+		return nil, fmt.Errorf("error retrieving client open ai options")
 	}
+	opts := clientData.OpenAIOpts
+	if opts.APIToken == "" {
+		if clientData.ClientOpts.NoTelemetry {
+			return nil, fmt.Errorf(OpenAICloudCompletionTelemetryOffErrorMsg)
+		}
+	}
 	if opts.Model == "" {
 		opts.Model = openai.DefaultModel
 	}
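The rewritten checks amount to a three-way gate: a user-supplied token goes straight to the OpenAI API, no token falls back to the cloud proxy, and the proxy is refused when telemetry is off. A small sketch of just that decision (`mode` and its signature are invented for illustration; the error text is abbreviated):

```go
package main

import (
	"errors"
	"fmt"
)

// mode summarizes the gate added to OpenAICommand: a user-supplied token
// is used directly against the OpenAI API; without a token the request is
// proxied through the cloud, but only if telemetry is enabled.
func mode(apiToken string, noTelemetry bool) (string, error) {
	if apiToken != "" {
		return "direct", nil
	}
	if noTelemetry {
		return "", errors.New("telemetry must be enabled to use the free AI features")
	}
	return "cloud", nil
}

func main() {
	fmt.Println(mode("sk-xxx", false)) // direct <nil>
	fmt.Println(mode("", false))       // cloud <nil>
	fmt.Println(mode("", true))        // rejected with an error
}
```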
@@ -3550,9 +3600,6 @@ func ClientAcceptTosCommand(ctx context.Context, pk *scpacket.FeCommandPacketTyp
 }

 func validateOpenAIAPIToken(key string) error {
-	if len(key) == 0 {
-		return fmt.Errorf("invalid openai token, zero length")
-	}
 	if len(key) > MaxOpenAIAPITokenLen {
 		return fmt.Errorf("invalid openai token, too long")
 	}
@@ -24,7 +24,7 @@ import (
 	"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
 )

-const PCloudEndpoint = "https://api.getprompt.dev/central"
+const PCloudEndpoint = "https://api.waveterm.dev/central"
 const PCloudEndpointVarName = "PCLOUD_ENDPOINT"
 const APIVersion = 1
 const MaxPtyUpdateSize = (128 * 1024)
@@ -34,6 +34,9 @@ const MaxUpdateWriterErrors = 3
 const PCloudDefaultTimeout = 5 * time.Second
 const PCloudWebShareUpdateTimeout = 15 * time.Second

+const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
+const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
+
 // setting to 1M to be safe (max is 6M for API-GW + Lambda, but there is base64 encoding and upload time)
 // we allow one extra update past this estimated size
 const MaxUpdatePayloadSize = 1 * (1024 * 1024)
@@ -63,6 +66,18 @@ func GetEndpoint() string {
 	return endpoint
 }

+func GetWSEndpoint() string {
+	if !scbase.IsDevMode() {
+		return PCloudWSEndpoint
+	} else {
+		endpoint := os.Getenv(PCloudWSEndpointVarName)
+		if endpoint == "" {
+			panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
+		}
+		return endpoint
+	}
+}
+
 func makeAuthPostReq(ctx context.Context, apiUrl string, authInfo AuthInfo, data interface{}) (*http.Request, error) {
 	var dataReader io.Reader
 	if data != nil {
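`GetWSEndpoint` honors the environment override only in dev mode, and panics early rather than silently dialing the wrong host. A runnable sketch of the same shape (the `isDevMode` parameter stands in for `scbase.IsDevMode()`; the endpoint values are copied from this commit):

```go
package main

import (
	"fmt"
	"os"
)

const prodWSEndpoint = "wss://wsapi.waveterm.dev/"

// getWSEndpoint mirrors the shape of pcloud.GetWSEndpoint: production
// builds always use the fixed endpoint; dev builds must set the env var.
func getWSEndpoint(isDevMode bool) string {
	if !isDevMode {
		return prodWSEndpoint
	}
	endpoint := os.Getenv("PCLOUD_WS_ENDPOINT")
	if endpoint == "" {
		panic("PCLOUD_WS_ENDPOINT not set")
	}
	return endpoint
}

func main() {
	os.Setenv("PCLOUD_WS_ENDPOINT", "wss://5lfzlg5crl.execute-api.us-west-2.amazonaws.com/dev/")
	fmt.Println(getWSEndpoint(false)) // prod: fixed endpoint
	fmt.Println(getWSEndpoint(true))  // dev: env override
}
```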
@@ -5,11 +5,16 @@ package openai

 import (
 	"context"
+	"encoding/json"
 	"fmt"
+	"io"
+	"time"
+
+	"github.com/gorilla/websocket"

 	openaiapi "github.com/sashabaranov/go-openai"
 	"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
+	"github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud"
 	"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
 )
@@ -19,6 +24,8 @@ const DefaultMaxTokens = 1000
 const DefaultModel = "gpt-3.5-turbo"
 const DefaultStreamChanSize = 10

+const CloudWebsocketConnectTimeout = 5 * time.Second
+
 func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
 	if resp.Usage.TotalTokens == 0 {
 		return nil
@@ -30,7 +37,7 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType
 	}
 }

-func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
+func ConvertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
 	var rtn []openaiapi.ChatCompletionMessage
 	for _, p := range prompt {
 		msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
@@ -56,7 +63,7 @@ func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []ss
 	client := openaiapi.NewClientWithConfig(clientConfig)
 	req := openaiapi.ChatCompletionRequest{
 		Model:     opts.Model,
-		Messages:  convertPrompt(prompt),
+		Messages:  ConvertPrompt(prompt),
 		MaxTokens: opts.MaxTokens,
 	}
 	if opts.MaxChoices > 1 {
@@ -72,6 +79,61 @@ func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []ss
 	return marshalResponse(apiResp), nil
 }

+func RunCloudCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, *websocket.Conn, error) {
+	if opts == nil {
+		return nil, nil, fmt.Errorf("no openai opts found")
+	}
+	websocketContext, _ := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
+	conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, pcloud.GetWSEndpoint(), nil)
+	if err != nil {
+		return nil, nil, fmt.Errorf("OpenAI request, websocket connect error: %v", err)
+	}
+	cloudCompletionRequestConfig := sstore.OpenAICloudCompletionRequest{
+		Prompt:     prompt,
+		MaxTokens:  opts.MaxTokens,
+		MaxChoices: opts.MaxChoices,
+	}
+	configMessageBuf, err := json.Marshal(cloudCompletionRequestConfig)
+	err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
+	if err != nil {
+		return nil, nil, fmt.Errorf("OpenAI request, websocket write config error: %v", err)
+	}
+	rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
+	go func() {
+		defer close(rtn)
+		defer conn.Close()
+		for {
+			_, socketMessage, err := conn.ReadMessage()
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket error reading message: %v", err))
+				rtn <- errPk
+				break
+			}
+			var streamResp *packet.OpenAIPacketType
+			err = json.Unmarshal(socketMessage, &streamResp)
+			if err != nil {
+				errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket response json decode error: %v", err))
+				rtn <- errPk
+				break
+			}
+			if streamResp.Error == packet.PacketEOFStr {
+				// got eof packet from socket
+				break
+			} else if streamResp.Error != "" {
+				// use error from server directly
+				errPk := CreateErrorPacket(streamResp.Error)
+				rtn <- errPk
+				break
+			}
+			rtn <- streamResp
+		}
+	}()
+	return rtn, conn, err
+}
+
 func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
 	if opts == nil {
 		return nil, fmt.Errorf("no openai opts found")
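For orientation, this is roughly how a caller drains such a stream. `Packet` and `consume` are stand-ins invented for this sketch (only the `Text`/`Error` fields of `packet.OpenAIPacketType` are mirrored); the real caller in `doOpenAIStreamCompletion` also keeps the returned `*websocket.Conn` so it can `defer conn.Close()` on paths where the reader goroutine never takes over:

```go
package main

import "fmt"

// Packet stands in for packet.OpenAIPacketType; only the two fields
// used here are mirrored.
type Packet struct {
	Text  string
	Error string
}

// consume drains a completion channel: every packet is written out,
// error packets are surfaced inline, and a closed channel ends the
// stream (the producer closes it on EOF or error).
func consume(ch <-chan *Packet) {
	for pk := range ch {
		if pk.Error != "" {
			fmt.Println("[error]", pk.Error)
			continue
		}
		fmt.Print(pk.Text)
	}
	fmt.Println()
}

func main() {
	ch := make(chan *Packet, 3)
	ch <- &Packet{Text: "Hello"}
	ch <- &Packet{Text: ", world"}
	close(ch)
	consume(ch)
}
```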
@@ -89,7 +151,7 @@ func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, promp
 	client := openaiapi.NewClientWithConfig(clientConfig)
 	req := openaiapi.ChatCompletionRequest{
 		Model:     opts.Model,
-		Messages:  convertPrompt(prompt),
+		Messages:  ConvertPrompt(prompt),
 		MaxTokens: opts.MaxTokens,
 		Stream:    true,
 	}
@@ -156,3 +218,9 @@ func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
 	errPk.Error = errStr
 	return errPk
 }
+
+func CreateTextPacket(text string) *packet.OpenAIPacketType {
+	pk := packet.MakeOpenAIPacket()
+	pk.Text = text
+	return pk
+}
@@ -784,6 +784,12 @@ type OpenAIPromptMessageType struct {
 	Name string `json:"name,omitempty"`
 }

+type OpenAICloudCompletionRequest struct {
+	Prompt     []OpenAIPromptMessageType `json:"prompt"`
+	MaxTokens  int                       `json:"maxtokens,omitempty"`
+	MaxChoices int                       `json:"maxchoices,omitempty"`
+}
+
 type PlaybookType struct {
 	PlaybookId   string `json:"playbookid"`
 	PlaybookName string `json:"playbookname"`
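Given the struct tags above, the config frame that `RunCloudCompletionStream` marshals and sends as its first websocket message would look roughly as follows. This is a sketch: the `role`/`content` tags on the prompt messages are assumed, since only the `name` tag is visible in this hunk.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// promptMsg mirrors OpenAIPromptMessageType (role/content tags assumed).
type promptMsg struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}

// cloudCompletionRequest mirrors OpenAICloudCompletionRequest above.
type cloudCompletionRequest struct {
	Prompt     []promptMsg `json:"prompt"`
	MaxTokens  int         `json:"maxtokens,omitempty"`
	MaxChoices int         `json:"maxchoices,omitempty"`
}

func main() {
	req := cloudCompletionRequest{
		Prompt:    []promptMsg{{Role: "user", Content: "hello"}},
		MaxTokens: 1000,
	}
	buf, _ := json.Marshal(req)
	fmt.Println(string(buf))
	// {"prompt":[{"role":"user","content":"hello"}],"maxtokens":1000}
}
```

Note how `omitempty` keeps unset fields (here `MaxChoices` and `Name`) off the wire entirely.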