fixes for o1 models (#1269)

This commit is contained in:
Mike Sawka 2024-11-11 17:11:09 -08:00 committed by GitHub
parent 29e54c8263
commit 95fd00617e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 7 deletions

View File

@ -8,7 +8,9 @@ import (
"errors"
"fmt"
"io"
"log"
"regexp"
"runtime/debug"
"strings"
openaiapi "github.com/sashabaranov/go-openai"
@ -72,7 +74,20 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
go func() {
defer close(rtn)
defer func() {
if r := recover(); r != nil {
// Convert panic to error and send it
log.Printf("panic: %v\n", r)
debug.PrintStack()
err, ok := r.(error)
if !ok {
err = fmt.Errorf("openai backend panic: %v", r)
}
rtn <- makeAIError(err)
}
// Always close the channel
close(rtn)
}()
if request.Opts == nil {
rtn <- makeAIError(errors.New("no openai opts found"))
return
@ -85,6 +100,7 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
rtn <- makeAIError(errors.New("no api token"))
return
}
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
if request.Opts.BaseURL != "" {
clientConfig.BaseURL = request.Opts.BaseURL
@ -100,17 +116,49 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
if request.Opts.APIVersion != "" {
clientConfig.APIVersion = request.Opts.APIVersion
}
client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{
Model: request.Opts.Model,
Messages: convertPrompt(request.Prompt),
MaxTokens: request.Opts.MaxTokens,
MaxCompletionTokens: request.Opts.MaxTokens,
Stream: true,
Model: request.Opts.Model,
Messages: convertPrompt(request.Prompt),
}
// Handle o1 models differently - use non-streaming API
if strings.HasPrefix(request.Opts.Model, "o1-") {
req.MaxCompletionTokens = request.Opts.MaxTokens
req.Stream = false
// Make non-streaming API call
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
return
}
// Send header packet
headerPk := MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk}
// Send content packet(s)
for i, choice := range resp.Choices {
pk := MakeOpenAIPacket()
pk.Index = i
pk.Text = choice.Message.Content
pk.FinishReason = string(choice.FinishReason)
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
}
return
}
// Original streaming implementation for non-o1 models
req.Stream = true
req.MaxTokens = request.Opts.MaxTokens
if request.Opts.MaxChoices > 1 {
req.N = request.Opts.MaxChoices
}
apiResp, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))

View File

@ -15,6 +15,7 @@ const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF"
const DefaultAzureAPIVersion = "2023-05-15"
const ApiType_Anthropic = "anthropic"
type OpenAICmdInfoPacketOutputType struct {
Model string `json:"model,omitempty"`
@ -62,7 +63,7 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
}
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
if request.Opts.APIType == "anthropic" {
if request.Opts.APIType == ApiType_Anthropic {
endpoint := request.Opts.BaseURL
if endpoint == "" {
endpoint = "default"