mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-08 19:38:51 +01:00
230 lines
6.6 KiB
Go
230 lines
6.6 KiB
Go
// Copyright 2023, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package openai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
openaiapi "github.com/sashabaranov/go-openai"
|
|
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
|
)
|
|
|
|
// https://github.com/tiktoken-go/tokenizer
|
|
|
|
// Package-level defaults for OpenAI requests made from this package.
const (
	// DefaultMaxTokens is a default token limit; it is not referenced in this
	// file — presumably callers use it when opts.MaxTokens is unset (verify).
	DefaultMaxTokens = 1000

	// DefaultModel is a default model name; not referenced in this file —
	// presumably used by callers when no model is configured (verify).
	DefaultModel = "gpt-3.5-turbo"

	// DefaultStreamChanSize is the buffer size of the packet channels
	// returned by the streaming functions.
	DefaultStreamChanSize = 10

	// CloudWebsocketConnectTimeout caps how long RunCloudCompletionStream
	// waits while dialing the cloud websocket endpoint.
	CloudWebsocketConnectTimeout = 1 * time.Minute
)
|
|
|
|
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
|
|
if resp.Usage.TotalTokens == 0 {
|
|
return nil
|
|
}
|
|
return &packet.OpenAIUsageType{
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
TotalTokens: resp.Usage.TotalTokens,
|
|
}
|
|
}
|
|
|
|
func ConvertPrompt(prompt []packet.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
|
var rtn []openaiapi.ChatCompletionMessage
|
|
for _, p := range prompt {
|
|
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
|
rtn = append(rtn, msg)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
|
|
if opts == nil {
|
|
return nil, fmt.Errorf("no openai opts found")
|
|
}
|
|
if opts.Model == "" {
|
|
return nil, fmt.Errorf("no openai model specified")
|
|
}
|
|
if opts.APIToken == "" {
|
|
return nil, fmt.Errorf("no api token")
|
|
}
|
|
clientConfig := openaiapi.DefaultConfig(opts.APIToken)
|
|
if opts.BaseURL != "" {
|
|
clientConfig.BaseURL = opts.BaseURL
|
|
}
|
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
|
req := openaiapi.ChatCompletionRequest{
|
|
Model: opts.Model,
|
|
Messages: ConvertPrompt(prompt),
|
|
MaxTokens: opts.MaxTokens,
|
|
}
|
|
if opts.MaxChoices > 1 {
|
|
req.N = opts.MaxChoices
|
|
}
|
|
apiResp, err := client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error calling openai API: %v", err)
|
|
}
|
|
if len(apiResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no response received")
|
|
}
|
|
return marshalResponse(apiResp), nil
|
|
}
|
|
|
|
func RunCloudCompletionStream(ctx context.Context, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, *websocket.Conn, error) {
|
|
if opts == nil {
|
|
return nil, nil, fmt.Errorf("no openai opts found")
|
|
}
|
|
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
|
|
defer dialCancelFn()
|
|
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, pcloud.GetWSEndpoint(), nil)
|
|
if err == context.DeadlineExceeded {
|
|
return nil, nil, fmt.Errorf("OpenAI request, timed out connected to cloud server: %v", err)
|
|
} else if err != nil {
|
|
return nil, nil, fmt.Errorf("OpenAI request, websocket connect error: %v", err)
|
|
}
|
|
reqPk := packet.MakeOpenAICloudReqPacket()
|
|
reqPk.ClientId = clientId
|
|
reqPk.Prompt = prompt
|
|
reqPk.MaxTokens = opts.MaxTokens
|
|
reqPk.MaxChoices = opts.MaxChoices
|
|
configMessageBuf, err := json.Marshal(reqPk)
|
|
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("OpenAI request, websocket write config error: %v", err)
|
|
}
|
|
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
|
|
go func() {
|
|
defer close(rtn)
|
|
defer conn.Close()
|
|
for {
|
|
_, socketMessage, err := conn.ReadMessage()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket error reading message: %v", err))
|
|
rtn <- errPk
|
|
break
|
|
}
|
|
var streamResp *packet.OpenAIPacketType
|
|
err = json.Unmarshal(socketMessage, &streamResp)
|
|
if err != nil {
|
|
errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket response json decode error: %v", err))
|
|
rtn <- errPk
|
|
break
|
|
}
|
|
if streamResp.Error == packet.PacketEOFStr {
|
|
// got eof packet from socket
|
|
break
|
|
} else if streamResp.Error != "" {
|
|
// use error from server directly
|
|
errPk := CreateErrorPacket(streamResp.Error)
|
|
rtn <- errPk
|
|
break
|
|
}
|
|
rtn <- streamResp
|
|
}
|
|
}()
|
|
return rtn, conn, err
|
|
}
|
|
|
|
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
|
|
if opts == nil {
|
|
return nil, fmt.Errorf("no openai opts found")
|
|
}
|
|
if opts.Model == "" {
|
|
return nil, fmt.Errorf("no openai model specified")
|
|
}
|
|
if opts.BaseURL == "" && opts.APIToken == "" {
|
|
return nil, fmt.Errorf("no api token")
|
|
}
|
|
clientConfig := openaiapi.DefaultConfig(opts.APIToken)
|
|
if opts.BaseURL != "" {
|
|
clientConfig.BaseURL = opts.BaseURL
|
|
}
|
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
|
req := openaiapi.ChatCompletionRequest{
|
|
Model: opts.Model,
|
|
Messages: ConvertPrompt(prompt),
|
|
MaxTokens: opts.MaxTokens,
|
|
Stream: true,
|
|
}
|
|
if opts.MaxChoices > 1 {
|
|
req.N = opts.MaxChoices
|
|
}
|
|
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error calling openai API: %v", err)
|
|
}
|
|
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
|
|
go func() {
|
|
sentHeader := false
|
|
defer close(rtn)
|
|
for {
|
|
streamResp, err := apiResp.Recv()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
|
|
rtn <- errPk
|
|
break
|
|
}
|
|
if streamResp.Model != "" && !sentHeader {
|
|
pk := packet.MakeOpenAIPacket()
|
|
pk.Model = streamResp.Model
|
|
pk.Created = streamResp.Created
|
|
rtn <- pk
|
|
sentHeader = true
|
|
}
|
|
for _, choice := range streamResp.Choices {
|
|
pk := packet.MakeOpenAIPacket()
|
|
pk.Index = choice.Index
|
|
pk.Text = choice.Delta.Content
|
|
pk.FinishReason = choice.FinishReason
|
|
rtn <- pk
|
|
}
|
|
}
|
|
}()
|
|
return rtn, err
|
|
}
|
|
|
|
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
|
|
var rtn []*packet.OpenAIPacketType
|
|
headerPk := packet.MakeOpenAIPacket()
|
|
headerPk.Model = resp.Model
|
|
headerPk.Created = resp.Created
|
|
headerPk.Usage = convertUsage(resp)
|
|
rtn = append(rtn, headerPk)
|
|
for _, choice := range resp.Choices {
|
|
choicePk := packet.MakeOpenAIPacket()
|
|
choicePk.Index = choice.Index
|
|
choicePk.Text = choice.Message.Content
|
|
choicePk.FinishReason = choice.FinishReason
|
|
rtn = append(rtn, choicePk)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
|
|
errPk := packet.MakeOpenAIPacket()
|
|
errPk.FinishReason = "error"
|
|
errPk.Error = errStr
|
|
return errPk
|
|
}
|
|
|
|
func CreateTextPacket(text string) *packet.OpenAIPacketType {
|
|
pk := packet.MakeOpenAIPacket()
|
|
pk.Text = text
|
|
return pk
|
|
}
|