mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-31 18:18:02 +01:00
9dd4188810
This will print error messages to the chat when there is an error getting an AI response. The actual content of the responses is not forwarded to the models in future requests. <img width="389" alt="Screenshot 2024-10-09 at 2 36 13 PM" src="https://github.com/user-attachments/assets/e6c6b1c1-fa19-4456-be3b-596feaeaafed">
309 lines
10 KiB
Go
309 lines
10 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package waveai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
openaiapi "github.com/sashabaranov/go-openai"
|
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
|
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Packet type tags, sentinels and API defaults used in this file.
const (
	// OpenAIPacketStr is the Type value stamped on packets built by
	// MakeOpenAIPacket.
	OpenAIPacketStr = "openai"

	// OpenAICloudReqStr is the Type value stamped on packets built by
	// MakeOpenAICloudReqPacket.
	OpenAICloudReqStr = "openai-cloudreq"

	// PacketEOFStr is the Error value that marks the end of a cloud
	// response stream.
	PacketEOFStr = "EOF"

	// DefaultAzureAPIVersion is the API version applied to the Azure API
	// types in setApiType.
	DefaultAzureAPIVersion = "2023-05-15"
)
|
|
|
|
// OpenAICmdInfoPacketOutputType is a JSON-serializable record of an assistant
// output. Every field is dropped from the encoded JSON when it holds its zero
// value.
type OpenAICmdInfoPacketOutputType struct {
	// Model is encoded under the "model" key.
	Model string `json:"model,omitempty"`

	// Created is encoded under the "created" key.
	// NOTE(review): presumably a creation timestamp; unit not visible here.
	Created int64 `json:"created,omitempty"`

	// FinishReason is encoded under the "finish_reason" key.
	FinishReason string `json:"finish_reason,omitempty"`

	// Message is encoded under the "message" key.
	Message string `json:"message,omitempty"`

	// Error is encoded under the "error" key.
	Error string `json:"error,omitempty"`
}
|
|
|
|
func MakeOpenAIPacket() *wshrpc.OpenAIPacketType {
|
|
return &wshrpc.OpenAIPacketType{Type: OpenAIPacketStr}
|
|
}
|
|
|
|
type OpenAICmdInfoChatMessage struct {
|
|
MessageID int `json:"messageid"`
|
|
IsAssistantResponse bool `json:"isassistantresponse,omitempty"`
|
|
AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
|
|
UserQuery string `json:"userquery,omitempty"`
|
|
UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
|
|
}
|
|
|
|
type OpenAICloudReqPacketType struct {
|
|
Type string `json:"type"`
|
|
ClientId string `json:"clientid"`
|
|
Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"`
|
|
MaxTokens int `json:"maxtokens,omitempty"`
|
|
MaxChoices int `json:"maxchoices,omitempty"`
|
|
}
|
|
|
|
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
|
return &OpenAICloudReqPacketType{
|
|
Type: OpenAICloudReqStr,
|
|
}
|
|
}
|
|
|
|
func GetWSEndpoint() string {
|
|
if !wavebase.IsDevMode() {
|
|
return WCloudWSEndpoint
|
|
} else {
|
|
endpoint := os.Getenv(WCloudWSEndpointVarName)
|
|
if endpoint == "" {
|
|
panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
|
|
}
|
|
return endpoint
|
|
}
|
|
}
|
|
|
|
// Default request settings and cloud websocket configuration.
const (
	// DefaultMaxTokens is the default token limit.
	DefaultMaxTokens = 2048

	// DefaultModel is the default model name.
	DefaultModel = "gpt-4o-mini"

	// WCloudWSEndpoint is the endpoint GetWSEndpoint returns outside dev mode.
	WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"

	// WCloudWSEndpointVarName names the environment variable GetWSEndpoint
	// reads in dev mode.
	WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"

	// CloudWebsocketConnectTimeout bounds how long the cloud websocket dial
	// may take.
	CloudWebsocketConnectTimeout = 1 * time.Minute
)
|
|
|
|
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
|
|
if opts == nil {
|
|
return true
|
|
}
|
|
return opts.BaseURL == "" && opts.APIToken == ""
|
|
}
|
|
|
|
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
|
|
if resp.Usage.TotalTokens == 0 {
|
|
return nil
|
|
}
|
|
return &wshrpc.OpenAIUsageType{
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
TotalTokens: resp.Usage.TotalTokens,
|
|
}
|
|
}
|
|
|
|
func ConvertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
|
var rtn []openaiapi.ChatCompletionMessage
|
|
for _, p := range prompt {
|
|
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
|
rtn = append(rtn, msg)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
|
}
|
|
|
|
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|
if IsCloudAIRequest(request.Opts) {
|
|
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
|
|
return RunCloudCompletionStream(ctx, request)
|
|
}
|
|
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
|
|
return RunLocalCompletionStream(ctx, request)
|
|
}
|
|
|
|
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
|
wsEndpoint := wcloud.GetWSEndpoint()
|
|
go func() {
|
|
defer close(rtn)
|
|
if wsEndpoint == "" {
|
|
rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
|
|
return
|
|
}
|
|
if request.Opts == nil {
|
|
rtn <- makeAIError(fmt.Errorf("no openai opts found"))
|
|
return
|
|
}
|
|
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
|
|
defer dialCancelFn()
|
|
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
|
|
if err == context.DeadlineExceeded {
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
|
|
return
|
|
} else if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
|
|
return
|
|
}
|
|
defer func() {
|
|
err = conn.Close()
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
|
}
|
|
}()
|
|
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
|
|
for _, promptMsg := range request.Prompt {
|
|
if promptMsg.Role == "error" {
|
|
continue
|
|
}
|
|
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
|
|
}
|
|
reqPk := MakeOpenAICloudReqPacket()
|
|
reqPk.ClientId = request.ClientId
|
|
reqPk.Prompt = sendablePromptMsgs
|
|
reqPk.MaxTokens = request.Opts.MaxTokens
|
|
reqPk.MaxChoices = request.Opts.MaxChoices
|
|
configMessageBuf, err := json.Marshal(reqPk)
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
|
|
return
|
|
}
|
|
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
|
|
return
|
|
}
|
|
for {
|
|
_, socketMessage, err := conn.ReadMessage()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
log.Printf("err received: %v", err)
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
|
|
break
|
|
}
|
|
var streamResp *wshrpc.OpenAIPacketType
|
|
err = json.Unmarshal(socketMessage, &streamResp)
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
|
|
break
|
|
}
|
|
if streamResp.Error == PacketEOFStr {
|
|
// got eof packet from socket
|
|
break
|
|
} else if streamResp.Error != "" {
|
|
// use error from server directly
|
|
rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
|
|
break
|
|
}
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
|
|
}
|
|
}()
|
|
return rtn
|
|
}
|
|
|
|
// azureModelNameStripRe matches the characters ('.' and ':') that
// defaultAzureMapperFn removes from model names. It is compiled once at
// package scope so the pattern is not recompiled on every call.
var azureModelNameStripRe = regexp.MustCompile(`[.:]`)

// defaultAzureMapperFn returns model with every '.' and ':' removed, e.g.
// "gpt-3.5-turbo" becomes "gpt-35-turbo". setApiType installs it as the
// AzureModelMapperFunc for the Azure API types.
//
// copied from go-openai/config.go
func defaultAzureMapperFn(model string) string {
	return azureModelNameStripRe.ReplaceAllString(model, "")
}
|
|
|
|
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
|
|
ourApiType := strings.ToLower(opts.APIType)
|
|
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
|
|
clientConfig.APIType = openaiapi.APITypeOpenAI
|
|
return nil
|
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
|
|
clientConfig.APIType = openaiapi.APITypeAzure
|
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
|
return nil
|
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
|
|
clientConfig.APIType = openaiapi.APITypeAzureAD
|
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
|
return nil
|
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
|
|
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
|
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
|
return nil
|
|
} else {
|
|
return fmt.Errorf("invalid api type %q", opts.APIType)
|
|
}
|
|
}
|
|
|
|
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
|
go func() {
|
|
defer close(rtn)
|
|
if request.Opts == nil {
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
|
|
return
|
|
}
|
|
if request.Opts.Model == "" {
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
|
|
return
|
|
}
|
|
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no api token")}
|
|
return
|
|
}
|
|
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
|
if request.Opts.BaseURL != "" {
|
|
clientConfig.BaseURL = request.Opts.BaseURL
|
|
}
|
|
err := setApiType(request.Opts, &clientConfig)
|
|
if err != nil {
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
|
return
|
|
}
|
|
if request.Opts.OrgID != "" {
|
|
clientConfig.OrgID = request.Opts.OrgID
|
|
}
|
|
if request.Opts.APIVersion != "" {
|
|
clientConfig.APIVersion = request.Opts.APIVersion
|
|
}
|
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
|
req := openaiapi.ChatCompletionRequest{
|
|
Model: request.Opts.Model,
|
|
Messages: ConvertPrompt(request.Prompt),
|
|
MaxTokens: request.Opts.MaxTokens,
|
|
Stream: true,
|
|
}
|
|
if request.Opts.MaxChoices > 1 {
|
|
req.N = request.Opts.MaxChoices
|
|
}
|
|
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
|
|
return
|
|
}
|
|
sentHeader := false
|
|
for {
|
|
streamResp, err := apiResp.Recv()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
log.Printf("err received2: %v", err)
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
|
|
break
|
|
}
|
|
if streamResp.Model != "" && !sentHeader {
|
|
pk := MakeOpenAIPacket()
|
|
pk.Model = streamResp.Model
|
|
pk.Created = streamResp.Created
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
|
sentHeader = true
|
|
}
|
|
for _, choice := range streamResp.Choices {
|
|
pk := MakeOpenAIPacket()
|
|
pk.Index = choice.Index
|
|
pk.Text = choice.Delta.Content
|
|
pk.FinishReason = string(choice.FinishReason)
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
|
}
|
|
}
|
|
}()
|
|
return rtn
|
|
}
|