mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-09 19:48:45 +01:00
dbacae8a99
A bunch of the Wave AI types still mentioned OpenAI. Now that most of them are being used for multiple AI backends, we need to update the names to be more generic.
176 lines
5.2 KiB
Go
176 lines
5.2 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package waveai
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"strings"
|
|
|
|
openaiapi "github.com/sashabaranov/go-openai"
|
|
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
|
)
|
|
|
|
// OpenAIBackend is a stateless AIBackend implementation that issues chat
// completion requests through the go-openai client.
type OpenAIBackend struct{}
|
|
|
|
// Compile-time assertion that OpenAIBackend satisfies the AIBackend interface.
var _ AIBackend = OpenAIBackend{}
|
|
|
|
// azureModelStripRe matches the characters ('.' and ':') that are removed from
// a model name by defaultAzureMapperFn. It is compiled once at package scope
// so the mapper does not recompile the pattern on every call.
var azureModelStripRe = regexp.MustCompile(`[.:]`)

// defaultAzureMapperFn returns model with every '.' and ':' character removed
// (for example "gpt-3.5-turbo" becomes "gpt-35-turbo"). It is installed as the
// client's AzureModelMapperFunc for the Azure API types.
//
// copied from go-openai/config.go
func defaultAzureMapperFn(model string) string {
	return azureModelStripRe.ReplaceAllString(model, "")
}
|
|
|
|
func setApiType(opts *wshrpc.WaveAIOptsType, clientConfig *openaiapi.ClientConfig) error {
|
|
ourApiType := strings.ToLower(opts.APIType)
|
|
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
|
|
clientConfig.APIType = openaiapi.APITypeOpenAI
|
|
return nil
|
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
|
|
clientConfig.APIType = openaiapi.APITypeAzure
|
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
|
return nil
|
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
|
|
clientConfig.APIType = openaiapi.APITypeAzureAD
|
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
|
return nil
|
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
|
|
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
|
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
|
return nil
|
|
} else {
|
|
return fmt.Errorf("invalid api type %q", opts.APIType)
|
|
}
|
|
}
|
|
|
|
func convertPrompt(prompt []wshrpc.WaveAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
|
var rtn []openaiapi.ChatCompletionMessage
|
|
for _, p := range prompt {
|
|
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
|
rtn = append(rtn, msg)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
|
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
|
|
go func() {
|
|
defer func() {
|
|
panicErr := panichandler.PanicHandler("OpenAIBackend.StreamCompletion")
|
|
if panicErr != nil {
|
|
rtn <- makeAIError(panicErr)
|
|
}
|
|
close(rtn)
|
|
}()
|
|
if request.Opts == nil {
|
|
rtn <- makeAIError(errors.New("no openai opts found"))
|
|
return
|
|
}
|
|
if request.Opts.Model == "" {
|
|
rtn <- makeAIError(errors.New("no openai model specified"))
|
|
return
|
|
}
|
|
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
|
rtn <- makeAIError(errors.New("no api token"))
|
|
return
|
|
}
|
|
|
|
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
|
if request.Opts.BaseURL != "" {
|
|
clientConfig.BaseURL = request.Opts.BaseURL
|
|
}
|
|
err := setApiType(request.Opts, &clientConfig)
|
|
if err != nil {
|
|
rtn <- makeAIError(err)
|
|
return
|
|
}
|
|
if request.Opts.OrgID != "" {
|
|
clientConfig.OrgID = request.Opts.OrgID
|
|
}
|
|
if request.Opts.APIVersion != "" {
|
|
clientConfig.APIVersion = request.Opts.APIVersion
|
|
}
|
|
|
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
|
req := openaiapi.ChatCompletionRequest{
|
|
Model: request.Opts.Model,
|
|
Messages: convertPrompt(request.Prompt),
|
|
}
|
|
|
|
// Handle o1 models differently - use non-streaming API
|
|
if strings.HasPrefix(request.Opts.Model, "o1-") {
|
|
req.MaxCompletionTokens = request.Opts.MaxTokens
|
|
req.Stream = false
|
|
|
|
// Make non-streaming API call
|
|
resp, err := client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
|
return
|
|
}
|
|
|
|
// Send header packet
|
|
headerPk := MakeWaveAIPacket()
|
|
headerPk.Model = resp.Model
|
|
headerPk.Created = resp.Created
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *headerPk}
|
|
|
|
// Send content packet(s)
|
|
for i, choice := range resp.Choices {
|
|
pk := MakeWaveAIPacket()
|
|
pk.Index = i
|
|
pk.Text = choice.Message.Content
|
|
pk.FinishReason = string(choice.FinishReason)
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Original streaming implementation for non-o1 models
|
|
req.Stream = true
|
|
req.MaxTokens = request.Opts.MaxTokens
|
|
if request.Opts.MaxChoices > 1 {
|
|
req.N = request.Opts.MaxChoices
|
|
}
|
|
|
|
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
|
return
|
|
}
|
|
sentHeader := false
|
|
for {
|
|
streamResp, err := apiResp.Recv()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
rtn <- makeAIError(fmt.Errorf("OpenAI request, error reading message: %v", err))
|
|
break
|
|
}
|
|
if streamResp.Model != "" && !sentHeader {
|
|
pk := MakeWaveAIPacket()
|
|
pk.Model = streamResp.Model
|
|
pk.Created = streamResp.Created
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
|
|
sentHeader = true
|
|
}
|
|
for _, choice := range streamResp.Choices {
|
|
pk := MakeWaveAIPacket()
|
|
pk.Index = choice.Index
|
|
pk.Text = choice.Delta.Content
|
|
pk.FinishReason = string(choice.FinishReason)
|
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
|
|
}
|
|
}
|
|
}()
|
|
return rtn
|
|
}
|