// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package waveai

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"regexp"
	"runtime/debug"
	"strings"

	openaiapi "github.com/sashabaranov/go-openai"

	"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
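// OpenAIBackend implements the AIBackend interface using the go-openai client.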
type OpenAIBackend struct{}

var _ AIBackend = OpenAIBackend{}
// copied from go-openai/config.go
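// Azure deployment names cannot contain "." or ":", so the default mapper
// strips those characters from the model name (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").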
func defaultAzureMapperFn(model string) string {
	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}
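// setApiType maps the case-insensitive APIType string from the options onto
// the go-openai client config; every Azure variant also gets the default
// Azure API version and model mapper.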
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
	switch strings.ToLower(opts.APIType) {
	case "", strings.ToLower(string(openaiapi.APITypeOpenAI)):
		clientConfig.APIType = openaiapi.APITypeOpenAI
		return nil
	case strings.ToLower(string(openaiapi.APITypeAzure)):
		clientConfig.APIType = openaiapi.APITypeAzure
	case strings.ToLower(string(openaiapi.APITypeAzureAD)):
		clientConfig.APIType = openaiapi.APITypeAzureAD
	case strings.ToLower(string(openaiapi.APITypeCloudflareAzure)):
		clientConfig.APIType = openaiapi.APITypeCloudflareAzure
	default:
		return fmt.Errorf("invalid api type %q", opts.APIType)
	}
	// all Azure variants share the same API version and model-mapper defaults
	clientConfig.APIVersion = DefaultAzureAPIVersion
	clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
	return nil
}
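// convertPrompt converts wshrpc prompt messages into go-openai chat messages.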
func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
	var rtn []openaiapi.ChatCompletionMessage
	for _, p := range prompt {
		msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
		rtn = append(rtn, msg)
	}
	return rtn
}
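// convertUsage extracts token counts from a non-streaming completion
// response, returning nil if the response carries no usage information.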
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
	if resp.Usage.TotalTokens == 0 {
		return nil
	}
	return &wshrpc.OpenAIUsageType{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}
}
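// StreamCompletion runs a chat completion against the configured endpoint and
// streams the result back as OpenAIPacketType packets on the returned channel.
// The channel is always closed when the request completes, fails, or panics,
// so callers can simply range over it.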
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
	go func() {
		defer func() {
			if r := recover(); r != nil {
				// Convert panic to error and send it
				log.Printf("panic: %v\n", r)
				debug.PrintStack()
				err, ok := r.(error)
				if !ok {
					err = fmt.Errorf("openai backend panic: %v", r)
				}
				rtn <- makeAIError(err)
			}
			// Always close the channel
			close(rtn)
		}()
		if request.Opts == nil {
			rtn <- makeAIError(errors.New("no openai opts found"))
			return
		}
		if request.Opts.Model == "" {
			rtn <- makeAIError(errors.New("no openai model specified"))
			return
		}
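		// an API token is only mandatory for the default endpoint; a custom
		// BaseURL (e.g. a local or proxy server) may not require one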
		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
			rtn <- makeAIError(errors.New("no api token"))
			return
		}
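		// assemble the go-openai client config from the request options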
		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
		if request.Opts.BaseURL != "" {
			clientConfig.BaseURL = request.Opts.BaseURL
		}
		err := setApiType(request.Opts, &clientConfig)
		if err != nil {
			rtn <- makeAIError(err)
			return
		}
		if request.Opts.OrgID != "" {
			clientConfig.OrgID = request.Opts.OrgID
		}
		if request.Opts.APIVersion != "" {
			clientConfig.APIVersion = request.Opts.APIVersion
		}

		client := openaiapi.NewClientWithConfig(clientConfig)
		req := openaiapi.ChatCompletionRequest{
			Model:    request.Opts.Model,
			Messages: convertPrompt(request.Prompt),
		}
// Handle o1 models differently - use non-streaming API
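		// o1 models (as of this snapshot) do not support streaming and use
		// max_completion_tokens in place of max_tokens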
		if strings.HasPrefix(request.Opts.Model, "o1-") {
			req.MaxCompletionTokens = request.Opts.MaxTokens
			req.Stream = false

			// Make non-streaming API call
			resp, err := client.CreateChatCompletion(ctx, req)
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
				return
			}

			// Send header packet
			headerPk := MakeOpenAIPacket()
			headerPk.Model = resp.Model
			headerPk.Created = resp.Created
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk}

			// Send content packet(s)
			for i, choice := range resp.Choices {
				pk := MakeOpenAIPacket()
				pk.Index = i
				pk.Text = choice.Message.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
			}
			return
		}

		// Original streaming implementation for non-o1 models
		req.Stream = true
		req.MaxTokens = request.Opts.MaxTokens
		if request.Opts.MaxChoices > 1 {
			req.N = request.Opts.MaxChoices
		}

		apiResp, err := client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
			return
		}
		// close the stream when the goroutine exits so the response body is released
		defer apiResp.Close()
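		// the first chunk that carries a model name becomes a header packet
		// with model/created metadata; all subsequent content is forwarded as
		// one packet per choice delta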
		sentHeader := false
		for {
			streamResp, err := apiResp.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("OpenAI request, error reading message: %v", err))
				break
			}
			if streamResp.Model != "" && !sentHeader {
				pk := MakeOpenAIPacket()
				pk.Model = streamResp.Model
				pk.Created = streamResp.Created
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				sentHeader = true
			}
			for _, choice := range streamResp.Choices {
				pk := MakeOpenAIPacket()
				pk.Index = choice.Index
				pk.Text = choice.Delta.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
			}
		}
	}()
	return rtn
}