2024-11-11 20:39:08 +01:00
|
|
|
// Copyright 2024, Command Line Inc.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
|
|
package waveai
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
|
|
|
|
"github.com/gorilla/websocket"
|
2024-11-21 03:05:13 +01:00
|
|
|
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
2024-11-11 20:39:08 +01:00
|
|
|
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
|
|
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
|
|
|
)
|
|
|
|
|
|
|
|
type WaveAICloudBackend struct{}
|
|
|
|
|
|
|
|
var _ AIBackend = WaveAICloudBackend{}
|
|
|
|
|
|
|
|
type OpenAICloudReqPacketType struct {
|
|
|
|
Type string `json:"type"`
|
|
|
|
ClientId string `json:"clientid"`
|
|
|
|
Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"`
|
|
|
|
MaxTokens int `json:"maxtokens,omitempty"`
|
|
|
|
MaxChoices int `json:"maxchoices,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
|
|
|
return &OpenAICloudReqPacketType{
|
|
|
|
Type: OpenAICloudReqStr,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|
|
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
|
|
|
wsEndpoint := wcloud.GetWSEndpoint()
|
|
|
|
go func() {
|
2024-11-21 03:05:13 +01:00
|
|
|
defer func() {
|
|
|
|
panicErr := panichandler.PanicHandler("WaveAICloudBackend.StreamCompletion")
|
|
|
|
if panicErr != nil {
|
|
|
|
rtn <- makeAIError(panicErr)
|
|
|
|
}
|
|
|
|
close(rtn)
|
|
|
|
}()
|
2024-11-11 20:39:08 +01:00
|
|
|
if wsEndpoint == "" {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if request.Opts == nil {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("no openai opts found"))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
|
|
|
|
defer dialCancelFn()
|
|
|
|
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
|
|
|
|
if err == context.DeadlineExceeded {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
|
|
|
|
return
|
|
|
|
} else if err != nil {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer func() {
|
|
|
|
err = conn.Close()
|
|
|
|
if err != nil {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
|
|
|
|
for _, promptMsg := range request.Prompt {
|
|
|
|
if promptMsg.Role == "error" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
|
|
|
|
}
|
|
|
|
reqPk := MakeOpenAICloudReqPacket()
|
|
|
|
reqPk.ClientId = request.ClientId
|
|
|
|
reqPk.Prompt = sendablePromptMsgs
|
|
|
|
reqPk.MaxTokens = request.Opts.MaxTokens
|
|
|
|
reqPk.MaxChoices = request.Opts.MaxChoices
|
|
|
|
configMessageBuf, err := json.Marshal(reqPk)
|
|
|
|
if err != nil {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
|
|
|
|
if err != nil {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
for {
|
|
|
|
_, socketMessage, err := conn.ReadMessage()
|
|
|
|
if err == io.EOF {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("err received: %v", err)
|
|
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
|
|
|
|
break
|
|
|
|
}
|
|
|
|
var streamResp *wshrpc.OpenAIPacketType
|
|
|
|
err = json.Unmarshal(socketMessage, &streamResp)
|
|
|
|
if err != nil {
|
|
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
|
|
|
|
break
|
|
|
|
}
|
|
|
|
if streamResp.Error == PacketEOFStr {
|
|
|
|
// got eof packet from socket
|
|
|
|
break
|
|
|
|
} else if streamResp.Error != "" {
|
|
|
|
// use error from server directly
|
|
|
|
rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
|
|
|
|
break
|
|
|
|
}
|
|
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
return rtn
|
|
|
|
}
|