ai backend refactor + claude/anthropic API support (#1262)

Mike Sawka 2024-11-11 11:39:08 -08:00 committed by GitHub
parent 38eeba5bd2
commit de902ec2b7
5 changed files with 589 additions and 246 deletions

frontend/app/view/waveai/waveai.tsx

@ -180,7 +180,15 @@ export class WaveAiModel implements ViewModel {
         const presetKey = get(this.presetKey);
         const presetName = presets[presetKey]?.["display:name"] ?? "";
         const isCloud = isBlank(aiOpts.apitoken) && isBlank(aiOpts.baseurl);
-        if (isCloud) {
+        if (aiOpts?.apitype == "anthropic") {
+            const modelName = aiOpts.model;
+            viewTextChildren.push({
+                elemtype: "iconbutton",
+                icon: "globe",
+                title: "Using Remote Anthropic API (" + modelName + ")",
+                disabled: true,
+            });
+        } else if (isCloud) {
             viewTextChildren.push({
                 elemtype: "iconbutton",
                 icon: "cloud",

pkg/waveai/anthropicbackend.go (new file)

@ -0,0 +1,300 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package waveai

import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"runtime/debug"
	"strings"

	"github.com/wavetermdev/waveterm/pkg/wshrpc"
)

type AnthropicBackend struct{}

var _ AIBackend = AnthropicBackend{}

// Claude API request types
type anthropicMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type anthropicRequest struct {
	Model       string             `json:"model"`
	Messages    []anthropicMessage `json:"messages"`
	System      string             `json:"system,omitempty"`
	MaxTokens   int                `json:"max_tokens,omitempty"`
	Stream      bool               `json:"stream"`
	Temperature float32            `json:"temperature,omitempty"`
}
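
// For illustration, a request marshaled from anthropicRequest looks roughly
// like the following (field values are hypothetical; shape follows the JSON
// tags above):
//
//	{
//	  "model": "claude-3-sonnet-20240229",
//	  "messages": [{"role": "user", "content": "hello"}],
//	  "system": "you are a terminal assistant",
//	  "max_tokens": 2048,
//	  "stream": true
//	}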
// Claude API response types for SSE events
type anthropicContentBlock struct {
	Type string `json:"type"` // "text" or other content types
	Text string `json:"text,omitempty"`
}

type anthropicUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

type anthropicResponseMessage struct {
	ID           string                  `json:"id"`
	Type         string                  `json:"type"`
	Role         string                  `json:"role"`
	Content      []anthropicContentBlock `json:"content"`
	Model        string                  `json:"model"`
	StopReason   string                  `json:"stop_reason,omitempty"`
	StopSequence string                  `json:"stop_sequence,omitempty"`
	Usage        *anthropicUsage         `json:"usage,omitempty"`
}

type anthropicStreamEventError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

type anthropicStreamEventDelta struct {
	Text string `json:"text"`
}

type anthropicStreamEvent struct {
	Type         string                     `json:"type"`
	Message      *anthropicResponseMessage  `json:"message,omitempty"`
	ContentBlock *anthropicContentBlock     `json:"content_block,omitempty"`
	Delta        *anthropicStreamEventDelta `json:"delta,omitempty"`
	Error        *anthropicStreamEventError `json:"error,omitempty"`
	Usage        *anthropicUsage            `json:"usage,omitempty"`
}

// sseEvent represents a parsed Server-Sent Event
type sseEvent struct {
	Event string // The event type field
	Data  string // The data field
}

// parseSSE reads and parses one SSE event from a bufio.Reader
func parseSSE(reader *bufio.Reader) (*sseEvent, error) {
	var event sseEvent
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		line = strings.TrimSpace(line)
		if line == "" {
			// Empty line signals end of event
			if event.Event != "" || event.Data != "" {
				return &event, nil
			}
			continue
		}
		if strings.HasPrefix(line, "event:") {
			event.Event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
		} else if strings.HasPrefix(line, "data:") {
			event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		}
	}
}
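
// For reference, the wire format this parser consumes is a sequence of
// blank-line-delimited events (payloads illustrative only):
//
//	event: content_block_delta
//	data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hi"}}
//
//	event: message_stop
//	data: {"type":"message_stop"}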
func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])

	go func() {
		defer func() {
			if r := recover(); r != nil {
				// Convert panic to error and send it
				log.Printf("panic: %v\n", r)
				debug.PrintStack()
				err, ok := r.(error)
				if !ok {
					err = fmt.Errorf("anthropic backend panic: %v", r)
				}
				rtn <- makeAIError(err)
			}
			// Always close the channel
			close(rtn)
		}()

		if request.Opts == nil {
			rtn <- makeAIError(errors.New("no anthropic opts found"))
			return
		}

		model := request.Opts.Model
		if model == "" {
			model = "claude-3-sonnet-20240229" // default model
		}

		// Convert messages format
		var messages []anthropicMessage
		var systemPrompt string
		for _, msg := range request.Prompt {
			if msg.Role == "system" {
				if systemPrompt != "" {
					systemPrompt += "\n"
				}
				systemPrompt += msg.Content
				continue
			}
			role := "user"
			if msg.Role == "assistant" {
				role = "assistant"
			}
			messages = append(messages, anthropicMessage{
				Role:    role,
				Content: msg.Content,
			})
		}

		anthropicReq := anthropicRequest{
			Model:     model,
			Messages:  messages,
			System:    systemPrompt,
			Stream:    true,
			MaxTokens: request.Opts.MaxTokens,
		}

		reqBody, err := json.Marshal(anthropicReq)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("failed to marshal anthropic request: %v", err))
			return
		}

		req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages", strings.NewReader(string(reqBody)))
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("failed to create anthropic request: %v", err))
			return
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", "text/event-stream")
		req.Header.Set("x-api-key", request.Opts.APIToken)
		req.Header.Set("anthropic-version", "2023-06-01")

		client := &http.Client{}
		resp, err := client.Do(req)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("failed to send anthropic request: %v", err))
			return
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			bodyBytes, _ := io.ReadAll(resp.Body)
			rtn <- makeAIError(fmt.Errorf("Anthropic API error: %s - %s", resp.Status, string(bodyBytes)))
			return
		}

		reader := bufio.NewReader(resp.Body)
		for {
			// Check for context cancellation
			select {
			case <-ctx.Done():
				rtn <- makeAIError(fmt.Errorf("request cancelled: %v", ctx.Err()))
				return
			default:
			}

			sse, err := parseSSE(reader)
			if err == io.EOF {
				break
			}
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("error reading SSE stream: %v", err))
				break
			}
			if sse.Event == "ping" {
				continue // Ignore ping events
			}

			var event anthropicStreamEvent
			if err := json.Unmarshal([]byte(sse.Data), &event); err != nil {
				rtn <- makeAIError(fmt.Errorf("error unmarshaling event data: %v", err))
				break
			}
			if event.Error != nil {
				rtn <- makeAIError(fmt.Errorf("Anthropic API error: %s - %s", event.Error.Type, event.Error.Message))
				break
			}

			switch sse.Event {
			case "message_start":
				if event.Message != nil {
					pk := MakeOpenAIPacket()
					pk.Model = event.Message.Model
					rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				}
			case "content_block_start":
				if event.ContentBlock != nil && event.ContentBlock.Text != "" {
					pk := MakeOpenAIPacket()
					pk.Text = event.ContentBlock.Text
					rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				}
			case "content_block_delta":
				if event.Delta != nil && event.Delta.Text != "" {
					pk := MakeOpenAIPacket()
					pk.Text = event.Delta.Text
					rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				}
			case "content_block_stop":
				// Note: According to the docs, this just signals the end of a content block.
				// We might want to use this for tracking block boundaries, but for now
				// we don't need to send anything special to match OpenAI's format.
			case "message_delta":
				// Update message metadata, usage stats
				if event.Usage != nil {
					pk := MakeOpenAIPacket()
					pk.Usage = &wshrpc.OpenAIUsageType{
						PromptTokens:     event.Usage.InputTokens,
						CompletionTokens: event.Usage.OutputTokens,
						TotalTokens:      event.Usage.InputTokens + event.Usage.OutputTokens,
					}
					rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				}
			case "message_stop":
				if event.Message != nil {
					pk := MakeOpenAIPacket()
					pk.FinishReason = event.Message.StopReason
					if event.Message.Usage != nil {
						pk.Usage = &wshrpc.OpenAIUsageType{
							PromptTokens:     event.Message.Usage.InputTokens,
							CompletionTokens: event.Message.Usage.OutputTokens,
							TotalTokens:      event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens,
						}
					}
					rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				}
			default:
				rtn <- makeAIError(fmt.Errorf("unknown Anthropic event type: %s", sse.Event))
				return
			}
		}
	}()

	return rtn
}

pkg/waveai/cloudbackend.go (new file)

@ -0,0 +1,115 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package waveai

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"

	"github.com/gorilla/websocket"
	"github.com/wavetermdev/waveterm/pkg/wcloud"
	"github.com/wavetermdev/waveterm/pkg/wshrpc"
)

type WaveAICloudBackend struct{}

var _ AIBackend = WaveAICloudBackend{}

type OpenAICloudReqPacketType struct {
	Type       string                           `json:"type"`
	ClientId   string                           `json:"clientid"`
	Prompt     []wshrpc.OpenAIPromptMessageType `json:"prompt"`
	MaxTokens  int                              `json:"maxtokens,omitempty"`
	MaxChoices int                              `json:"maxchoices,omitempty"`
}
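
// For illustration, the config packet written to the cloud websocket marshals
// to something like the following (values hypothetical; the type string is
// whatever OpenAICloudReqStr is defined as in this package):
//
//	{"type":"<OpenAICloudReqStr>","clientid":"<client-uuid>","prompt":[{"role":"user","content":"hi"}],"maxtokens":2048}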
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
	return &OpenAICloudReqPacketType{
		Type: OpenAICloudReqStr,
	}
}

func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
	wsEndpoint := wcloud.GetWSEndpoint()
	go func() {
		defer close(rtn)
		if wsEndpoint == "" {
			rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
			return
		}
		if request.Opts == nil {
			rtn <- makeAIError(fmt.Errorf("no openai opts found"))
			return
		}
		websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
		defer dialCancelFn()
		conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
		if err == context.DeadlineExceeded {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
			return
		} else if err != nil {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
			return
		}
		defer func() {
			err = conn.Close()
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
			}
		}()
		var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
		for _, promptMsg := range request.Prompt {
			if promptMsg.Role == "error" {
				continue
			}
			sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
		}
		reqPk := MakeOpenAICloudReqPacket()
		reqPk.ClientId = request.ClientId
		reqPk.Prompt = sendablePromptMsgs
		reqPk.MaxTokens = request.Opts.MaxTokens
		reqPk.MaxChoices = request.Opts.MaxChoices
		configMessageBuf, err := json.Marshal(reqPk)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
			return
		}
		err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
			return
		}
		for {
			_, socketMessage, err := conn.ReadMessage()
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Printf("err received: %v", err)
				rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
				break
			}
			var streamResp *wshrpc.OpenAIPacketType
			err = json.Unmarshal(socketMessage, &streamResp)
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
				break
			}
			if streamResp.Error == PacketEOFStr {
				// got eof packet from socket
				break
			} else if streamResp.Error != "" {
				// use error from server directly
				rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
				break
			}
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
		}
	}()
	return rtn
}

pkg/waveai/openaibackend.go (new file)

@ -0,0 +1,146 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package waveai

import (
	"context"
	"errors"
	"fmt"
	"io"
	"regexp"
	"strings"

	openaiapi "github.com/sashabaranov/go-openai"
	"github.com/wavetermdev/waveterm/pkg/wshrpc"
)

type OpenAIBackend struct{}

var _ AIBackend = OpenAIBackend{}

// copied from go-openai/config.go
func defaultAzureMapperFn(model string) string {
	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}

func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
	ourApiType := strings.ToLower(opts.APIType)
	if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
		clientConfig.APIType = openaiapi.APITypeOpenAI
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
		clientConfig.APIType = openaiapi.APITypeAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
		clientConfig.APIType = openaiapi.APITypeAzureAD
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
		clientConfig.APIType = openaiapi.APITypeCloudflareAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else {
		return fmt.Errorf("invalid api type %q", opts.APIType)
	}
}
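
// Assuming go-openai's APIType constants are "OPEN_AI", "AZURE", "AZURE_AD",
// and "CLOUDFLARE_AZURE", the accepted (case-insensitive) apitype settings
// above are: "", "open_ai", "azure", "azure_ad", and "cloudflare_azure";
// anything else falls through to the invalid-api-type error.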
func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
	var rtn []openaiapi.ChatCompletionMessage
	for _, p := range prompt {
		msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
		rtn = append(rtn, msg)
	}
	return rtn
}

func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
	if resp.Usage.TotalTokens == 0 {
		return nil
	}
	return &wshrpc.OpenAIUsageType{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}
}

func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
	go func() {
		defer close(rtn)
		if request.Opts == nil {
			rtn <- makeAIError(errors.New("no openai opts found"))
			return
		}
		if request.Opts.Model == "" {
			rtn <- makeAIError(errors.New("no openai model specified"))
			return
		}
		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
			rtn <- makeAIError(errors.New("no api token"))
			return
		}
		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
		if request.Opts.BaseURL != "" {
			clientConfig.BaseURL = request.Opts.BaseURL
		}
		err := setApiType(request.Opts, &clientConfig)
		if err != nil {
			rtn <- makeAIError(err)
			return
		}
		if request.Opts.OrgID != "" {
			clientConfig.OrgID = request.Opts.OrgID
		}
		if request.Opts.APIVersion != "" {
			clientConfig.APIVersion = request.Opts.APIVersion
		}
		client := openaiapi.NewClientWithConfig(clientConfig)
		req := openaiapi.ChatCompletionRequest{
			Model:               request.Opts.Model,
			Messages:            convertPrompt(request.Prompt),
			MaxTokens:           request.Opts.MaxTokens,
			MaxCompletionTokens: request.Opts.MaxTokens,
			Stream:              true,
		}
		if request.Opts.MaxChoices > 1 {
			req.N = request.Opts.MaxChoices
		}
		apiResp, err := client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
			return
		}
		sentHeader := false
		for {
			streamResp, err := apiResp.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("OpenAI request, error reading message: %v", err))
				break
			}
			if streamResp.Model != "" && !sentHeader {
				pk := MakeOpenAIPacket()
				pk.Model = streamResp.Model
				pk.Created = streamResp.Created
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				sentHeader = true
			}
			for _, choice := range streamResp.Choices {
				pk := MakeOpenAIPacket()
				pk.Index = choice.Index
				pk.Text = choice.Delta.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
			}
		}
	}()
	return rtn
}

pkg/waveai/waveai.go

@ -5,21 +5,10 @@ package waveai

 import (
 	"context"
-	"encoding/json"
-	"fmt"
-	"io"
 	"log"
-	"os"
-	"regexp"
-	"strings"
 	"time"

-	openaiapi "github.com/sashabaranov/go-openai"
-	"github.com/wavetermdev/waveterm/pkg/wavebase"
-	"github.com/wavetermdev/waveterm/pkg/wcloud"
 	"github.com/wavetermdev/waveterm/pkg/wshrpc"
-
-	"github.com/gorilla/websocket"
 )

 const OpenAIPacketStr = "openai"
@ -47,30 +36,11 @@ type OpenAICmdInfoChatMessage struct {
 	UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
 }

-type OpenAICloudReqPacketType struct {
-	Type       string                           `json:"type"`
-	ClientId   string                           `json:"clientid"`
-	Prompt     []wshrpc.OpenAIPromptMessageType `json:"prompt"`
-	MaxTokens  int                              `json:"maxtokens,omitempty"`
-	MaxChoices int                              `json:"maxchoices,omitempty"`
-}
-
-func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
-	return &OpenAICloudReqPacketType{
-		Type: OpenAICloudReqStr,
-	}
-}
-
-func GetWSEndpoint() string {
-	if !wavebase.IsDevMode() {
-		return WCloudWSEndpoint
-	} else {
-		endpoint := os.Getenv(WCloudWSEndpointVarName)
-		if endpoint == "" {
-			panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
-		}
-		return endpoint
-	}
-}
+type AIBackend interface {
+	StreamCompletion(
+		ctx context.Context,
+		request wshrpc.OpenAiStreamRequest,
+	) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]
+}

 const DefaultMaxTokens = 2048
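
Every provider now satisfies this single AIBackend interface, so callers consume the returned channel uniformly. A minimal consumer sketch (hypothetical, not part of this commit; it relies only on the Error/Response fields of RespOrErrorUnion shown in this diff):

	ch := RunAICommand(ctx, request)
	for respUnion := range ch {
		if respUnion.Error != nil {
			log.Printf("ai error: %v", respUnion.Error)
			break
		}
		pk := respUnion.Response
		fmt.Print(pk.Text) // stream text deltas as they arrive
	}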
@ -87,223 +57,27 @@ func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
 	return opts.BaseURL == "" && opts.APIToken == ""
 }

-func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
-	if resp.Usage.TotalTokens == 0 {
-		return nil
-	}
-	return &wshrpc.OpenAIUsageType{
-		PromptTokens:     resp.Usage.PromptTokens,
-		CompletionTokens: resp.Usage.CompletionTokens,
-		TotalTokens:      resp.Usage.TotalTokens,
-	}
-}
-
-func ConvertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
-	var rtn []openaiapi.ChatCompletionMessage
-	for _, p := range prompt {
-		msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
-		rtn = append(rtn, msg)
-	}
-	return rtn
-}
-
 func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
 	return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
 }

 func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
+	if request.Opts.APIType == "anthropic" {
+		endpoint := request.Opts.BaseURL
+		if endpoint == "" {
+			endpoint = "default"
+		}
+		log.Printf("sending ai chat message to anthropic endpoint %q using model %s\n", endpoint, request.Opts.Model)
+		anthropicBackend := AnthropicBackend{}
+		return anthropicBackend.StreamCompletion(ctx, request)
+	}
 	if IsCloudAIRequest(request.Opts) {
 		log.Print("sending ai chat message to default waveterm cloud endpoint\n")
-		return RunCloudCompletionStream(ctx, request)
+		cloudBackend := WaveAICloudBackend{}
+		return cloudBackend.StreamCompletion(ctx, request)
+	} else {
+		log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model)
+		openAIBackend := OpenAIBackend{}
+		return openAIBackend.StreamCompletion(ctx, request)
 	}
-	log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model)
-	return RunLocalCompletionStream(ctx, request)
 }
-
-func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
-	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
-	wsEndpoint := wcloud.GetWSEndpoint()
-	go func() {
-		defer close(rtn)
-		if wsEndpoint == "" {
-			rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
-			return
-		}
-		if request.Opts == nil {
-			rtn <- makeAIError(fmt.Errorf("no openai opts found"))
-			return
-		}
-		websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
-		defer dialCancelFn()
-		conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
-		if err == context.DeadlineExceeded {
-			rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
-			return
-		} else if err != nil {
-			rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
-			return
-		}
-		defer func() {
-			err = conn.Close()
-			if err != nil {
-				rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
-			}
-		}()
-		var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
-		for _, promptMsg := range request.Prompt {
-			if promptMsg.Role == "error" {
-				continue
-			}
-			sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
-		}
-		reqPk := MakeOpenAICloudReqPacket()
-		reqPk.ClientId = request.ClientId
-		reqPk.Prompt = sendablePromptMsgs
-		reqPk.MaxTokens = request.Opts.MaxTokens
-		reqPk.MaxChoices = request.Opts.MaxChoices
-		configMessageBuf, err := json.Marshal(reqPk)
-		if err != nil {
-			rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
-			return
-		}
-		err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
-		if err != nil {
-			rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
-			return
-		}
-		for {
-			_, socketMessage, err := conn.ReadMessage()
-			if err == io.EOF {
-				break
-			}
-			if err != nil {
-				log.Printf("err received: %v", err)
-				rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
-				break
-			}
-			var streamResp *wshrpc.OpenAIPacketType
-			err = json.Unmarshal(socketMessage, &streamResp)
-			if err != nil {
-				rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
-				break
-			}
-			if streamResp.Error == PacketEOFStr {
-				// got eof packet from socket
-				break
-			} else if streamResp.Error != "" {
-				// use error from server directly
-				rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
-				break
-			}
-			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
-		}
-	}()
-	return rtn
-}
-
-// copied from go-openai/config.go
-func defaultAzureMapperFn(model string) string {
-	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
-}
-
-func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
-	ourApiType := strings.ToLower(opts.APIType)
-	if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
-		clientConfig.APIType = openaiapi.APITypeOpenAI
-		return nil
-	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
-		clientConfig.APIType = openaiapi.APITypeAzure
-		clientConfig.APIVersion = DefaultAzureAPIVersion
-		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
-		return nil
-	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
-		clientConfig.APIType = openaiapi.APITypeAzureAD
-		clientConfig.APIVersion = DefaultAzureAPIVersion
-		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
-		return nil
-	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
-		clientConfig.APIType = openaiapi.APITypeCloudflareAzure
-		clientConfig.APIVersion = DefaultAzureAPIVersion
-		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
-		return nil
-	} else {
-		return fmt.Errorf("invalid api type %q", opts.APIType)
-	}
-}
-
-func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
-	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
-	go func() {
-		defer close(rtn)
-		if request.Opts == nil {
-			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
-			return
-		}
-		if request.Opts.Model == "" {
-			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
-			return
-		}
-		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
-			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no api token")}
-			return
-		}
-		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
-		if request.Opts.BaseURL != "" {
-			clientConfig.BaseURL = request.Opts.BaseURL
-		}
-		err := setApiType(request.Opts, &clientConfig)
-		if err != nil {
-			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
-			return
-		}
-		if request.Opts.OrgID != "" {
-			clientConfig.OrgID = request.Opts.OrgID
-		}
-		if request.Opts.APIVersion != "" {
-			clientConfig.APIVersion = request.Opts.APIVersion
-		}
-		client := openaiapi.NewClientWithConfig(clientConfig)
-		req := openaiapi.ChatCompletionRequest{
-			Model:               request.Opts.Model,
-			Messages:            ConvertPrompt(request.Prompt),
-			MaxTokens:           request.Opts.MaxTokens,
-			MaxCompletionTokens: request.Opts.MaxTokens,
-			Stream:              true,
-		}
-		if request.Opts.MaxChoices > 1 {
-			req.N = request.Opts.MaxChoices
-		}
-		apiResp, err := client.CreateChatCompletionStream(ctx, req)
-		if err != nil {
-			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
-			return
-		}
-		sentHeader := false
-		for {
-			streamResp, err := apiResp.Recv()
-			if err == io.EOF {
-				break
-			}
-			if err != nil {
-				log.Printf("err received2: %v", err)
-				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
-				break
-			}
-			if streamResp.Model != "" && !sentHeader {
-				pk := MakeOpenAIPacket()
-				pk.Model = streamResp.Model
-				pk.Created = streamResp.Created
-				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
-				sentHeader = true
-			}
-			for _, choice := range streamResp.Choices {
-				pk := MakeOpenAIPacket()
-				pk.Index = choice.Index
-				pk.Text = choice.Delta.Content
-				pk.FinishReason = string(choice.FinishReason)
-				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
-			}
-		}
-	}()
-	return rtn
-}