// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package waveai

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"os"
	"regexp"
	"strings"
	"time"

	openaiapi "github.com/sashabaranov/go-openai"
	"github.com/wavetermdev/waveterm/pkg/wavebase"
	"github.com/wavetermdev/waveterm/pkg/wcloud"
	"github.com/wavetermdev/waveterm/pkg/wshrpc"

	"github.com/gorilla/websocket"
)

const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF"
const DefaultAzureAPIVersion = "2023-05-15"

type OpenAICmdInfoPacketOutputType struct {
	Model        string `json:"model,omitempty"`
	Created      int64  `json:"created,omitempty"`
	FinishReason string `json:"finish_reason,omitempty"`
	Message      string `json:"message,omitempty"`
	Error        string `json:"error,omitempty"`
}

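// MakeOpenAIPacket returns a new OpenAIPacketType with Type preset to OpenAIPacketStr.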
func MakeOpenAIPacket() *wshrpc.OpenAIPacketType {
	return &wshrpc.OpenAIPacketType{Type: OpenAIPacketStr}
}

type OpenAICmdInfoChatMessage struct {
	MessageID           int                            `json:"messageid"`
	IsAssistantResponse bool                           `json:"isassistantresponse,omitempty"`
	AssistantResponse   *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
	UserQuery           string                         `json:"userquery,omitempty"`
	UserEngineeredQuery string                         `json:"userengineeredquery,omitempty"`
}

type OpenAICloudReqPacketType struct {
	Type       string                           `json:"type"`
	ClientId   string                           `json:"clientid"`
	Prompt     []wshrpc.OpenAIPromptMessageType `json:"prompt"`
	MaxTokens  int                              `json:"maxtokens,omitempty"`
	MaxChoices int                              `json:"maxchoices,omitempty"`
}

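// MakeOpenAICloudReqPacket returns a new cloud request packet with Type preset to OpenAICloudReqStr.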
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
	return &OpenAICloudReqPacketType{
		Type: OpenAICloudReqStr,
	}
}

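// GetWSEndpoint returns the WCloud websocket endpoint for cloud AI requests:
// the production endpoint in normal builds, or the WCLOUD_WS_ENDPOINT value in
// dev mode (panicking if that variable is unset).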
func GetWSEndpoint() string {
	if !wavebase.IsDevMode() {
		return WCloudWSEndpoint
	} else {
		endpoint := os.Getenv(WCloudWSEndpointVarName)
		if endpoint == "" {
			panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
		}
		return endpoint
	}
}

const DefaultMaxTokens = 2048
const DefaultModel = "gpt-4o-mini"
const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"

const CloudWebsocketConnectTimeout = 1 * time.Minute

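// IsCloudAIRequest reports whether a request should be sent to the Wave cloud
// endpoint: true when no options are supplied, or when neither a BaseURL nor
// an APIToken is configured.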
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
	if opts == nil {
		return true
	}
	return opts.BaseURL == "" && opts.APIToken == ""
}

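// convertUsage maps go-openai usage counts onto the wshrpc usage type,
// returning nil when the response carries no token counts.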
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
	if resp.Usage.TotalTokens == 0 {
		return nil
	}
	return &wshrpc.OpenAIUsageType{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}
}

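// ConvertPrompt translates wshrpc prompt messages into go-openai chat
// completion messages, preserving role, content, and name.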
func ConvertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
	var rtn []openaiapi.ChatCompletionMessage
	for _, p := range prompt {
		msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
		rtn = append(rtn, msg)
	}
	return rtn
}

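// makeAIError wraps an error in the RespOrErrorUnion envelope used on the
// streaming response channels.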
func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
}

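// RunAICommand dispatches a streaming AI request either to the Wave cloud
// proxy (when IsCloudAIRequest is true) or to the user-configured
// OpenAI-compatible endpoint. Each value on the returned channel carries
// either a response packet or an error; the channel is closed when the
// stream ends.
//
// A minimal consumption sketch (illustrative only; the real callers live
// elsewhere in the waveterm codebase):
//
//	ch := RunAICommand(ctx, request)
//	for item := range ch {
//		if item.Error != nil {
//			// handle the error and stop reading
//			break
//		}
//		fmt.Print(item.Response.Text)
//	}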
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	if IsCloudAIRequest(request.Opts) {
		log.Print("sending ai chat message to default waveterm cloud endpoint\n")
		return RunCloudCompletionStream(ctx, request)
	}
	log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model)
	return RunLocalCompletionStream(ctx, request)
}

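// RunCloudCompletionStream proxies the request over a websocket to the WCloud
// endpoint: it writes a single JSON request packet, then forwards response
// packets onto the returned channel until an EOF packet, a server-reported
// error, or a read error ends the stream.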
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
	wsEndpoint := wcloud.GetWSEndpoint()
	go func() {
		defer close(rtn)
		if wsEndpoint == "" {
			rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
			return
		}
		if request.Opts == nil {
			rtn <- makeAIError(fmt.Errorf("no openai opts found"))
			return
		}
		websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
		defer dialCancelFn()
		conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
		if err == context.DeadlineExceeded {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
			return
		} else if err != nil {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
			return
		}
		defer func() {
			err = conn.Close()
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
			}
		}()
		var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
		for _, promptMsg := range request.Prompt {
			if promptMsg.Role == "error" {
				continue
			}
			sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
		}
		reqPk := MakeOpenAICloudReqPacket()
		reqPk.ClientId = request.ClientId
		reqPk.Prompt = sendablePromptMsgs
		reqPk.MaxTokens = request.Opts.MaxTokens
		reqPk.MaxChoices = request.Opts.MaxChoices
		configMessageBuf, err := json.Marshal(reqPk)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
			return
		}
		err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
			return
		}
		for {
			_, socketMessage, err := conn.ReadMessage()
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Printf("err received: %v", err)
				rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
				break
			}
			var streamResp *wshrpc.OpenAIPacketType
			err = json.Unmarshal(socketMessage, &streamResp)
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
				break
			}
			if streamResp.Error == PacketEOFStr {
				// got eof packet from socket
				break
			} else if streamResp.Error != "" {
				// use error from server directly
				rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
				break
			}
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
		}
	}()
	return rtn
}

// defaultAzureMapperFn strips "." and ":" from a model name; copied from
// go-openai/config.go.
func defaultAzureMapperFn(model string) string {
	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}

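// setApiType maps opts.APIType (case-insensitively) onto the corresponding
// go-openai API type, filling in the default Azure API version and model
// mapper for the Azure variants; unrecognized values are rejected with an
// error.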
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
	ourApiType := strings.ToLower(opts.APIType)
	if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
		clientConfig.APIType = openaiapi.APITypeOpenAI
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
		clientConfig.APIType = openaiapi.APITypeAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
		clientConfig.APIType = openaiapi.APITypeAzureAD
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
		clientConfig.APIType = openaiapi.APITypeCloudflareAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else {
		return fmt.Errorf("invalid api type %q", opts.APIType)
	}
}

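// RunLocalCompletionStream talks directly to an OpenAI-compatible API using
// the go-openai client: it validates the options, issues a streaming chat
// completion request, and emits a header packet (model/created) followed by
// one packet per streamed choice delta on the returned channel.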
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
	go func() {
		defer close(rtn)
		if request.Opts == nil {
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
			return
		}
		if request.Opts.Model == "" {
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
			return
		}
		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no api token")}
			return
		}
		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
		if request.Opts.BaseURL != "" {
			clientConfig.BaseURL = request.Opts.BaseURL
		}
		err := setApiType(request.Opts, &clientConfig)
		if err != nil {
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
			return
		}
		if request.Opts.OrgID != "" {
			clientConfig.OrgID = request.Opts.OrgID
		}
		if request.Opts.APIVersion != "" {
			clientConfig.APIVersion = request.Opts.APIVersion
		}
		client := openaiapi.NewClientWithConfig(clientConfig)
		req := openaiapi.ChatCompletionRequest{
			Model:     request.Opts.Model,
			Messages:  ConvertPrompt(request.Prompt),
			MaxTokens: request.Opts.MaxTokens,
			Stream:    true,
		}
		if request.Opts.MaxChoices > 1 {
			req.N = request.Opts.MaxChoices
		}
		apiResp, err := client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
			return
		}
		// release the stream's underlying response when the goroutine exits
		defer apiResp.Close()
		sentHeader := false
		for {
			streamResp, err := apiResp.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Printf("error reading streaming response: %v", err)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, error reading streaming response: %v", err)}
				break
			}
			if streamResp.Model != "" && !sentHeader {
				// first packet carries the model/created header
				pk := MakeOpenAIPacket()
				pk.Model = streamResp.Model
				pk.Created = streamResp.Created
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				sentHeader = true
			}
			for _, choice := range streamResp.Choices {
				pk := MakeOpenAIPacket()
				pk.Index = choice.Index
				pk.Text = choice.Delta.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
			}
		}
	}()
	return rtn
}