mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
fixes for o1 models (#1269)
This commit is contained in:
parent
29e54c8263
commit
95fd00617e
@@ -8,7 +8,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
@@ -72,7 +74,20 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType
|
||||
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
go func() {
|
||||
defer close(rtn)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Convert panic to error and send it
|
||||
log.Printf("panic: %v\n", r)
|
||||
debug.PrintStack()
|
||||
err, ok := r.(error)
|
||||
if !ok {
|
||||
err = fmt.Errorf("openai backend panic: %v", r)
|
||||
}
|
||||
rtn <- makeAIError(err)
|
||||
}
|
||||
// Always close the channel
|
||||
close(rtn)
|
||||
}()
|
||||
if request.Opts == nil {
|
||||
rtn <- makeAIError(errors.New("no openai opts found"))
|
||||
return
|
||||
@@ -85,6 +100,7 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
|
||||
rtn <- makeAIError(errors.New("no api token"))
|
||||
return
|
||||
}
|
||||
|
||||
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
||||
if request.Opts.BaseURL != "" {
|
||||
clientConfig.BaseURL = request.Opts.BaseURL
|
||||
@@ -100,17 +116,49 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
|
||||
if request.Opts.APIVersion != "" {
|
||||
clientConfig.APIVersion = request.Opts.APIVersion
|
||||
}
|
||||
|
||||
client := openaiapi.NewClientWithConfig(clientConfig)
|
||||
req := openaiapi.ChatCompletionRequest{
|
||||
Model: request.Opts.Model,
|
||||
Messages: convertPrompt(request.Prompt),
|
||||
MaxTokens: request.Opts.MaxTokens,
|
||||
MaxCompletionTokens: request.Opts.MaxTokens,
|
||||
Stream: true,
|
||||
Model: request.Opts.Model,
|
||||
Messages: convertPrompt(request.Prompt),
|
||||
}
|
||||
|
||||
// Handle o1 models differently - use non-streaming API
|
||||
if strings.HasPrefix(request.Opts.Model, "o1-") {
|
||||
req.MaxCompletionTokens = request.Opts.MaxTokens
|
||||
req.Stream = false
|
||||
|
||||
// Make non-streaming API call
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Send header packet
|
||||
headerPk := MakeOpenAIPacket()
|
||||
headerPk.Model = resp.Model
|
||||
headerPk.Created = resp.Created
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk}
|
||||
|
||||
// Send content packet(s)
|
||||
for i, choice := range resp.Choices {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Index = i
|
||||
pk.Text = choice.Message.Content
|
||||
pk.FinishReason = string(choice.FinishReason)
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Original streaming implementation for non-o1 models
|
||||
req.Stream = true
|
||||
req.MaxTokens = request.Opts.MaxTokens
|
||||
if request.Opts.MaxChoices > 1 {
|
||||
req.N = request.Opts.MaxChoices
|
||||
}
|
||||
|
||||
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
||||
|
@@ -15,6 +15,7 @@ const OpenAIPacketStr = "openai"
|
||||
const OpenAICloudReqStr = "openai-cloudreq"
|
||||
const PacketEOFStr = "EOF"
|
||||
const DefaultAzureAPIVersion = "2023-05-15"
|
||||
const ApiType_Anthropic = "anthropic"
|
||||
|
||||
type OpenAICmdInfoPacketOutputType struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
@@ -62,7 +63,7 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
}
|
||||
|
||||
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
if request.Opts.APIType == "anthropic" {
|
||||
if request.Opts.APIType == ApiType_Anthropic {
|
||||
endpoint := request.Opts.BaseURL
|
||||
if endpoint == "" {
|
||||
endpoint = "default"
|
||||
|
Loading…
Reference in New Issue
Block a user