waveterm/pkg/waveai/openaibackend.go
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package waveai

import (
	"context"
	"errors"
	"fmt"
	"io"
	"regexp"
	"strings"

	openaiapi "github.com/sashabaranov/go-openai"
	"github.com/wavetermdev/waveterm/pkg/panichandler"
	"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
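
// OpenAIBackend implements AIBackend for OpenAI-compatible chat completion
// APIs, including the Azure, Azure AD, and Cloudflare Azure variants.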
type OpenAIBackend struct{}

var _ AIBackend = OpenAIBackend{}

// defaultAzureMapperFn strips "." and ":" from a model name so it can be used
// as an Azure deployment name (copied from go-openai/config.go).
func defaultAzureMapperFn(model string) string {
	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}
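
// setApiType maps the (case-insensitive) APIType string from the options onto
// the go-openai client config. An empty APIType defaults to OpenAI; the
// Azure-style API types also get a default API version and model-mapper
// function.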
func setApiType(opts *wshrpc.WaveAIOptsType, clientConfig *openaiapi.ClientConfig) error {
	ourApiType := strings.ToLower(opts.APIType)
	if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
		clientConfig.APIType = openaiapi.APITypeOpenAI
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
		clientConfig.APIType = openaiapi.APITypeAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
		clientConfig.APIType = openaiapi.APITypeAzureAD
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
		clientConfig.APIType = openaiapi.APITypeCloudflareAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
	} else {
		return fmt.Errorf("invalid api type %q", opts.APIType)
	}
}
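
// convertPrompt converts Wave AI prompt messages into go-openai chat messages,
// carrying over the role, content, and optional name fields.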
func convertPrompt(prompt []wshrpc.WaveAIPromptMessageType) []openaiapi.ChatCompletionMessage {
	var rtn []openaiapi.ChatCompletionMessage
	for _, p := range prompt {
		msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
		rtn = append(rtn, msg)
	}
	return rtn
}
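
// StreamCompletion runs a chat completion against an OpenAI-compatible API and
// streams the results as WaveAIPacketType values on the returned channel. The
// first packet is a header carrying the model and creation time; subsequent
// packets carry text deltas. Errors are delivered on the same channel, and the
// channel is closed when the request finishes.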
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
	go func() {
		defer func() {
			// Recover from panics, forward them as errors, and always close
			// the channel so readers are not left blocked.
			panicErr := panichandler.PanicHandler("OpenAIBackend.StreamCompletion")
			if panicErr != nil {
				rtn <- makeAIError(panicErr)
			}
			close(rtn)
		}()
		if request.Opts == nil {
			rtn <- makeAIError(errors.New("no openai opts found"))
			return
		}
		if request.Opts.Model == "" {
			rtn <- makeAIError(errors.New("no openai model specified"))
			return
		}
		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
			rtn <- makeAIError(errors.New("no api token"))
			return
		}
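
		// A custom BaseURL lets the request target any OpenAI-compatible
		// endpoint (for example, a locally hosted server), in which case an
		// API token is optional.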
		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
		if request.Opts.BaseURL != "" {
			clientConfig.BaseURL = request.Opts.BaseURL
		}
		err := setApiType(request.Opts, &clientConfig)
		if err != nil {
			rtn <- makeAIError(err)
			return
		}
		if request.Opts.OrgID != "" {
			clientConfig.OrgID = request.Opts.OrgID
		}
		if request.Opts.APIVersion != "" {
			clientConfig.APIVersion = request.Opts.APIVersion
		}
		client := openaiapi.NewClientWithConfig(clientConfig)
		req := openaiapi.ChatCompletionRequest{
			Model:    request.Opts.Model,
			Messages: convertPrompt(request.Prompt),
		}
		// o1 models don't support streaming and take max_completion_tokens
		// instead of max_tokens, so handle them with a non-streaming call.
		if strings.HasPrefix(request.Opts.Model, "o1-") {
			req.MaxCompletionTokens = request.Opts.MaxTokens
			req.Stream = false

			// Make the non-streaming API call.
			resp, err := client.CreateChatCompletion(ctx, req)
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
				return
			}

			// Send the header packet with the model and creation time.
			headerPk := MakeWaveAIPacket()
			headerPk.Model = resp.Model
			headerPk.Created = resp.Created
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *headerPk}

			// Send one content packet per choice.
			for i, choice := range resp.Choices {
				pk := MakeWaveAIPacket()
				pk.Index = i
				pk.Text = choice.Message.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
			}
			return
		}
		// Original streaming implementation for non-o1 models.
		req.Stream = true
		req.MaxTokens = request.Opts.MaxTokens
		if request.Opts.MaxChoices > 1 {
			req.N = request.Opts.MaxChoices
		}
		apiResp, err := client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
			return
		}
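
		// The first stream response that names a model is forwarded once as a
		// header packet; every choice delta in each response then becomes a
		// content packet.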
		sentHeader := false
		for {
			streamResp, err := apiResp.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("OpenAI request, error reading message: %v", err))
				break
			}
			if streamResp.Model != "" && !sentHeader {
				pk := MakeWaveAIPacket()
				pk.Model = streamResp.Model
				pk.Created = streamResp.Created
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
				sentHeader = true
			}
			for _, choice := range streamResp.Choices {
				pk := MakeWaveAIPacket()
				pk.Index = choice.Index
				pk.Text = choice.Delta.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
			}
		}
	}()
	return rtn
}