From 95fd00617eb388e0042eae12aac9b17ddfa54a98 Mon Sep 17 00:00:00 2001 From: Mike Sawka Date: Mon, 11 Nov 2024 17:11:09 -0800 Subject: [PATCH] fixes for o1 models (#1269) --- pkg/waveai/openaibackend.go | 60 +++++++++++++++++++++++++++++++++---- pkg/waveai/waveai.go | 3 +- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/pkg/waveai/openaibackend.go b/pkg/waveai/openaibackend.go index 12a0f6790..a7c7b4388 100644 --- a/pkg/waveai/openaibackend.go +++ b/pkg/waveai/openaibackend.go @@ -8,7 +8,9 @@ import ( "errors" "fmt" "io" + "log" "regexp" + "runtime/debug" "strings" openaiapi "github.com/sashabaranov/go-openai" @@ -72,7 +74,20 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) go func() { - defer close(rtn) + defer func() { + if r := recover(); r != nil { + // Convert panic to error and send it + log.Printf("panic: %v\n", r) + debug.PrintStack() + err, ok := r.(error) + if !ok { + err = fmt.Errorf("openai backend panic: %v", r) + } + rtn <- makeAIError(err) + } + // Always close the channel + close(rtn) + }() if request.Opts == nil { rtn <- makeAIError(errors.New("no openai opts found")) return @@ -85,6 +100,7 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi rtn <- makeAIError(errors.New("no api token")) return } + clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken) if request.Opts.BaseURL != "" { clientConfig.BaseURL = request.Opts.BaseURL @@ -100,17 +116,49 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi if request.Opts.APIVersion != "" { clientConfig.APIVersion = request.Opts.APIVersion } + client := openaiapi.NewClientWithConfig(clientConfig) req := openaiapi.ChatCompletionRequest{ - Model: request.Opts.Model, - Messages: convertPrompt(request.Prompt), - MaxTokens: request.Opts.MaxTokens, - MaxCompletionTokens: request.Opts.MaxTokens, - Stream: true, + Model: request.Opts.Model, + Messages: convertPrompt(request.Prompt), } + + // Handle o1 models differently - use non-streaming API + if strings.HasPrefix(request.Opts.Model, "o1-") { + req.MaxCompletionTokens = request.Opts.MaxTokens + req.Stream = false + + // Make non-streaming API call + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err)) + return + } + + // Send header packet + headerPk := MakeOpenAIPacket() + headerPk.Model = resp.Model + headerPk.Created = resp.Created + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk} + + // Send content packet(s) + for i, choice := range resp.Choices { + pk := MakeOpenAIPacket() + pk.Index = i + pk.Text = choice.Message.Content + pk.FinishReason = string(choice.FinishReason) + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + } + return + } + + // Original streaming implementation for non-o1 models + req.Stream = true + req.MaxTokens = request.Opts.MaxTokens if request.Opts.MaxChoices > 1 { req.N = request.Opts.MaxChoices } + apiResp, err := client.CreateChatCompletionStream(ctx, req) if err != nil { rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err)) diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go index 3fb1d31a7..94f0a4ca1 100644 --- a/pkg/waveai/waveai.go +++ b/pkg/waveai/waveai.go @@ -15,6 +15,7 @@ const OpenAIPacketStr = "openai" const OpenAICloudReqStr = "openai-cloudreq" const PacketEOFStr = "EOF" const DefaultAzureAPIVersion = "2023-05-15" +const ApiType_Anthropic = "anthropic" type OpenAICmdInfoPacketOutputType struct { Model string `json:"model,omitempty"` @@ -62,7 +63,7 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { } func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - if request.Opts.APIType == "anthropic" { + if request.Opts.APIType == ApiType_Anthropic { endpoint := request.Opts.BaseURL if endpoint == "" { endpoint = "default"