mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-22 21:42:49 +01:00
159 lines
4.3 KiB
Go
159 lines
4.3 KiB
Go
// Copyright 2023, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package openai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
|
|
openaiapi "github.com/sashabaranov/go-openai"
|
|
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
|
)
|
|
|
|
// https://github.com/tiktoken-go/tokenizer
|
|
|
|
const DefaultMaxTokens = 1000
|
|
const DefaultModel = "gpt-3.5-turbo"
|
|
const DefaultStreamChanSize = 10
|
|
|
|
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
|
|
if resp.Usage.TotalTokens == 0 {
|
|
return nil
|
|
}
|
|
return &packet.OpenAIUsageType{
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
TotalTokens: resp.Usage.TotalTokens,
|
|
}
|
|
}
|
|
|
|
func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
|
var rtn []openaiapi.ChatCompletionMessage
|
|
for _, p := range prompt {
|
|
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
|
rtn = append(rtn, msg)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
|
|
if opts == nil {
|
|
return nil, fmt.Errorf("no openai opts found")
|
|
}
|
|
if opts.Model == "" {
|
|
return nil, fmt.Errorf("no openai model specified")
|
|
}
|
|
if opts.APIToken == "" {
|
|
return nil, fmt.Errorf("no api token")
|
|
}
|
|
clientConfig := openaiapi.DefaultConfig(opts.APIToken)
|
|
if opts.BaseURL != "" {
|
|
clientConfig.BaseURL = opts.BaseURL
|
|
}
|
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
|
req := openaiapi.ChatCompletionRequest{
|
|
Model: opts.Model,
|
|
Messages: convertPrompt(prompt),
|
|
MaxTokens: opts.MaxTokens,
|
|
}
|
|
if opts.MaxChoices > 1 {
|
|
req.N = opts.MaxChoices
|
|
}
|
|
apiResp, err := client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error calling openai API: %v", err)
|
|
}
|
|
if len(apiResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no response received")
|
|
}
|
|
return marshalResponse(apiResp), nil
|
|
}
|
|
|
|
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
|
|
if opts == nil {
|
|
return nil, fmt.Errorf("no openai opts found")
|
|
}
|
|
if opts.Model == "" {
|
|
return nil, fmt.Errorf("no openai model specified")
|
|
}
|
|
if opts.APIToken == "" {
|
|
return nil, fmt.Errorf("no api token")
|
|
}
|
|
clientConfig := openaiapi.DefaultConfig(opts.APIToken)
|
|
if opts.BaseURL != "" {
|
|
clientConfig.BaseURL = opts.BaseURL
|
|
}
|
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
|
req := openaiapi.ChatCompletionRequest{
|
|
Model: opts.Model,
|
|
Messages: convertPrompt(prompt),
|
|
MaxTokens: opts.MaxTokens,
|
|
Stream: true,
|
|
}
|
|
if opts.MaxChoices > 1 {
|
|
req.N = opts.MaxChoices
|
|
}
|
|
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error calling openai API: %v", err)
|
|
}
|
|
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
|
|
go func() {
|
|
sentHeader := false
|
|
defer close(rtn)
|
|
for {
|
|
streamResp, err := apiResp.Recv()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
|
|
rtn <- errPk
|
|
break
|
|
}
|
|
if streamResp.Model != "" && !sentHeader {
|
|
pk := packet.MakeOpenAIPacket()
|
|
pk.Model = streamResp.Model
|
|
pk.Created = streamResp.Created
|
|
rtn <- pk
|
|
sentHeader = true
|
|
}
|
|
for _, choice := range streamResp.Choices {
|
|
pk := packet.MakeOpenAIPacket()
|
|
pk.Index = choice.Index
|
|
pk.Text = choice.Delta.Content
|
|
pk.FinishReason = choice.FinishReason
|
|
rtn <- pk
|
|
}
|
|
}
|
|
}()
|
|
return rtn, err
|
|
}
|
|
|
|
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
|
|
var rtn []*packet.OpenAIPacketType
|
|
headerPk := packet.MakeOpenAIPacket()
|
|
headerPk.Model = resp.Model
|
|
headerPk.Created = resp.Created
|
|
headerPk.Usage = convertUsage(resp)
|
|
rtn = append(rtn, headerPk)
|
|
for _, choice := range resp.Choices {
|
|
choicePk := packet.MakeOpenAIPacket()
|
|
choicePk.Index = choice.Index
|
|
choicePk.Text = choice.Message.Content
|
|
choicePk.FinishReason = choice.FinishReason
|
|
rtn = append(rtn, choicePk)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
|
|
errPk := packet.MakeOpenAIPacket()
|
|
errPk.FinishReason = "error"
|
|
errPk.Error = errStr
|
|
return errPk
|
|
}
|