mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-22 16:48:23 +01:00
fixes for o1 models (#1269)
This commit is contained in:
parent
29e54c8263
commit
95fd00617e
@ -8,7 +8,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
openaiapi "github.com/sashabaranov/go-openai"
|
openaiapi "github.com/sashabaranov/go-openai"
|
||||||
@ -72,7 +74,20 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType
|
|||||||
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||||
go func() {
|
go func() {
|
||||||
defer close(rtn)
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// Convert panic to error and send it
|
||||||
|
log.Printf("panic: %v\n", r)
|
||||||
|
debug.PrintStack()
|
||||||
|
err, ok := r.(error)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("openai backend panic: %v", r)
|
||||||
|
}
|
||||||
|
rtn <- makeAIError(err)
|
||||||
|
}
|
||||||
|
// Always close the channel
|
||||||
|
close(rtn)
|
||||||
|
}()
|
||||||
if request.Opts == nil {
|
if request.Opts == nil {
|
||||||
rtn <- makeAIError(errors.New("no openai opts found"))
|
rtn <- makeAIError(errors.New("no openai opts found"))
|
||||||
return
|
return
|
||||||
@ -85,6 +100,7 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
|
|||||||
rtn <- makeAIError(errors.New("no api token"))
|
rtn <- makeAIError(errors.New("no api token"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
||||||
if request.Opts.BaseURL != "" {
|
if request.Opts.BaseURL != "" {
|
||||||
clientConfig.BaseURL = request.Opts.BaseURL
|
clientConfig.BaseURL = request.Opts.BaseURL
|
||||||
@ -100,17 +116,49 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
|
|||||||
if request.Opts.APIVersion != "" {
|
if request.Opts.APIVersion != "" {
|
||||||
clientConfig.APIVersion = request.Opts.APIVersion
|
clientConfig.APIVersion = request.Opts.APIVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
client := openaiapi.NewClientWithConfig(clientConfig)
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
||||||
req := openaiapi.ChatCompletionRequest{
|
req := openaiapi.ChatCompletionRequest{
|
||||||
Model: request.Opts.Model,
|
Model: request.Opts.Model,
|
||||||
Messages: convertPrompt(request.Prompt),
|
Messages: convertPrompt(request.Prompt),
|
||||||
MaxTokens: request.Opts.MaxTokens,
|
|
||||||
MaxCompletionTokens: request.Opts.MaxTokens,
|
|
||||||
Stream: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle o1 models differently - use non-streaming API
|
||||||
|
if strings.HasPrefix(request.Opts.Model, "o1-") {
|
||||||
|
req.MaxCompletionTokens = request.Opts.MaxTokens
|
||||||
|
req.Stream = false
|
||||||
|
|
||||||
|
// Make non-streaming API call
|
||||||
|
resp, err := client.CreateChatCompletion(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send header packet
|
||||||
|
headerPk := MakeOpenAIPacket()
|
||||||
|
headerPk.Model = resp.Model
|
||||||
|
headerPk.Created = resp.Created
|
||||||
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk}
|
||||||
|
|
||||||
|
// Send content packet(s)
|
||||||
|
for i, choice := range resp.Choices {
|
||||||
|
pk := MakeOpenAIPacket()
|
||||||
|
pk.Index = i
|
||||||
|
pk.Text = choice.Message.Content
|
||||||
|
pk.FinishReason = string(choice.FinishReason)
|
||||||
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original streaming implementation for non-o1 models
|
||||||
|
req.Stream = true
|
||||||
|
req.MaxTokens = request.Opts.MaxTokens
|
||||||
if request.Opts.MaxChoices > 1 {
|
if request.Opts.MaxChoices > 1 {
|
||||||
req.N = request.Opts.MaxChoices
|
req.N = request.Opts.MaxChoices
|
||||||
}
|
}
|
||||||
|
|
||||||
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
|
||||||
|
@ -15,6 +15,7 @@ const OpenAIPacketStr = "openai"
|
|||||||
const OpenAICloudReqStr = "openai-cloudreq"
|
const OpenAICloudReqStr = "openai-cloudreq"
|
||||||
const PacketEOFStr = "EOF"
|
const PacketEOFStr = "EOF"
|
||||||
const DefaultAzureAPIVersion = "2023-05-15"
|
const DefaultAzureAPIVersion = "2023-05-15"
|
||||||
|
const ApiType_Anthropic = "anthropic"
|
||||||
|
|
||||||
type OpenAICmdInfoPacketOutputType struct {
|
type OpenAICmdInfoPacketOutputType struct {
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
@ -62,7 +63,7 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||||
if request.Opts.APIType == "anthropic" {
|
if request.Opts.APIType == ApiType_Anthropic {
|
||||||
endpoint := request.Opts.BaseURL
|
endpoint := request.Opts.BaseURL
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
endpoint = "default"
|
endpoint = "default"
|
||||||
|
Loading…
Reference in New Issue
Block a user