waveterm/wavesrv/pkg/remote/openai/openai.go
2024-02-26 22:21:28 -08:00

230 lines
6.6 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package openai
import (
"context"
"encoding/json"
"fmt"
"io"
"time"
"github.com/gorilla/websocket"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
)
// https://github.com/tiktoken-go/tokenizer
const DefaultMaxTokens = 1000
const DefaultModel = "gpt-3.5-turbo"
const DefaultStreamChanSize = 10
const CloudWebsocketConnectTimeout = 1 * time.Minute
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
if resp.Usage.TotalTokens == 0 {
return nil
}
return &packet.OpenAIUsageType{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
func ConvertPrompt(prompt []packet.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
var rtn []openaiapi.ChatCompletionMessage
for _, p := range prompt {
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
rtn = append(rtn, msg)
}
return rtn
}
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
if opts == nil {
return nil, fmt.Errorf("no openai opts found")
}
if opts.Model == "" {
return nil, fmt.Errorf("no openai model specified")
}
if opts.APIToken == "" {
return nil, fmt.Errorf("no api token")
}
clientConfig := openaiapi.DefaultConfig(opts.APIToken)
if opts.BaseURL != "" {
clientConfig.BaseURL = opts.BaseURL
}
client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{
Model: opts.Model,
Messages: ConvertPrompt(prompt),
MaxTokens: opts.MaxTokens,
}
if opts.MaxChoices > 1 {
req.N = opts.MaxChoices
}
apiResp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
return nil, fmt.Errorf("error calling openai API: %v", err)
}
if len(apiResp.Choices) == 0 {
return nil, fmt.Errorf("no response received")
}
return marshalResponse(apiResp), nil
}
func RunCloudCompletionStream(ctx context.Context, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, *websocket.Conn, error) {
if opts == nil {
return nil, nil, fmt.Errorf("no openai opts found")
}
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
defer dialCancelFn()
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, pcloud.GetWSEndpoint(), nil)
if err == context.DeadlineExceeded {
return nil, nil, fmt.Errorf("OpenAI request, timed out connected to cloud server: %v", err)
} else if err != nil {
return nil, nil, fmt.Errorf("OpenAI request, websocket connect error: %v", err)
}
reqPk := packet.MakeOpenAICloudReqPacket()
reqPk.ClientId = clientId
reqPk.Prompt = prompt
reqPk.MaxTokens = opts.MaxTokens
reqPk.MaxChoices = opts.MaxChoices
configMessageBuf, err := json.Marshal(reqPk)
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
if err != nil {
return nil, nil, fmt.Errorf("OpenAI request, websocket write config error: %v", err)
}
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
go func() {
defer close(rtn)
defer conn.Close()
for {
_, socketMessage, err := conn.ReadMessage()
if err == io.EOF {
break
}
if err != nil {
errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket error reading message: %v", err))
rtn <- errPk
break
}
var streamResp *packet.OpenAIPacketType
err = json.Unmarshal(socketMessage, &streamResp)
if err != nil {
errPk := CreateErrorPacket(fmt.Sprintf("OpenAI request, websocket response json decode error: %v", err))
rtn <- errPk
break
}
if streamResp.Error == packet.PacketEOFStr {
// got eof packet from socket
break
} else if streamResp.Error != "" {
// use error from server directly
errPk := CreateErrorPacket(streamResp.Error)
rtn <- errPk
break
}
rtn <- streamResp
}
}()
return rtn, conn, err
}
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
if opts == nil {
return nil, fmt.Errorf("no openai opts found")
}
if opts.Model == "" {
return nil, fmt.Errorf("no openai model specified")
}
if opts.APIToken == "" {
return nil, fmt.Errorf("no api token")
}
clientConfig := openaiapi.DefaultConfig(opts.APIToken)
if opts.BaseURL != "" {
clientConfig.BaseURL = opts.BaseURL
}
client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{
Model: opts.Model,
Messages: ConvertPrompt(prompt),
MaxTokens: opts.MaxTokens,
Stream: true,
}
if opts.MaxChoices > 1 {
req.N = opts.MaxChoices
}
apiResp, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
return nil, fmt.Errorf("error calling openai API: %v", err)
}
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
go func() {
sentHeader := false
defer close(rtn)
for {
streamResp, err := apiResp.Recv()
if err == io.EOF {
break
}
if err != nil {
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
rtn <- errPk
break
}
if streamResp.Model != "" && !sentHeader {
pk := packet.MakeOpenAIPacket()
pk.Model = streamResp.Model
pk.Created = streamResp.Created
rtn <- pk
sentHeader = true
}
for _, choice := range streamResp.Choices {
pk := packet.MakeOpenAIPacket()
pk.Index = choice.Index
pk.Text = choice.Delta.Content
pk.FinishReason = choice.FinishReason
rtn <- pk
}
}
}()
return rtn, err
}
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
var rtn []*packet.OpenAIPacketType
headerPk := packet.MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
headerPk.Usage = convertUsage(resp)
rtn = append(rtn, headerPk)
for _, choice := range resp.Choices {
choicePk := packet.MakeOpenAIPacket()
choicePk.Index = choice.Index
choicePk.Text = choice.Message.Content
choicePk.FinishReason = choice.FinishReason
rtn = append(rtn, choicePk)
}
return rtn
}
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
errPk := packet.MakeOpenAIPacket()
errPk.FinishReason = "error"
errPk.Error = errStr
return errPk
}
func CreateTextPacket(text string) *packet.OpenAIPacketType {
pk := packet.MakeOpenAIPacket()
pk.Text = text
return pk
}