waveterm/pkg/waveai/waveai.go
Sylvie Crowe 0e46b79c22
Open Ai Port (#154)
This brings over a simplified version of the open ai feature from the
previous app but in widget form. It still needs some work to reach
parity with that version, but this includes all of the basic building
blocks to get that working.
2024-07-25 02:30:49 -07:00

307 lines
9.8 KiB
Go

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package waveai
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"time"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/gorilla/websocket"
)
const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF"
type OpenAIUsageType struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type OpenAICmdInfoPacketOutputType struct {
Model string `json:"model,omitempty"`
Created int64 `json:"created,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
}
type OpenAIPacketType struct {
Type string `json:"type"`
Model string `json:"model,omitempty"`
Created int64 `json:"created,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Usage *OpenAIUsageType `json:"usage,omitempty"`
Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"`
Error string `json:"error,omitempty"`
}
func MakeOpenAIPacket() *OpenAIPacketType {
return &OpenAIPacketType{Type: OpenAIPacketStr}
}
type OpenAICmdInfoChatMessage struct {
MessageID int `json:"messageid"`
IsAssistantResponse bool `json:"isassistantresponse,omitempty"`
AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
UserQuery string `json:"userquery,omitempty"`
UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
}
type OpenAIPromptMessageType struct {
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
type OpenAICloudReqPacketType struct {
Type string `json:"type"`
ClientId string `json:"clientid"`
Prompt []OpenAIPromptMessageType `json:"prompt"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
}
type OpenAIOptsType struct {
Model string `json:"model"`
APIToken string `json:"apitoken"`
BaseURL string `json:"baseurl,omitempty"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
Timeout int `json:"timeout,omitempty"`
}
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
return &OpenAICloudReqPacketType{
Type: OpenAICloudReqStr,
}
}
type OpenAiStreamRequest struct {
ClientId string `json:"clientid,omitempty"`
Opts *OpenAIOptsType `json:"opts"`
Prompt []OpenAIPromptMessageType `json:"prompt"`
}
func GetWSEndpoint() string {
return PCloudWSEndpoint
if !wavebase.IsDevMode() {
return PCloudWSEndpoint
} else {
endpoint := os.Getenv(PCloudWSEndpointVarName)
if endpoint == "" {
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
}
return endpoint
}
}
const DefaultMaxTokens = 1000
const DefaultModel = "gpt-3.5-turbo"
const DefaultStreamChanSize = 10
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
const CloudWebsocketConnectTimeout = 1 * time.Minute
func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType {
if resp.Usage.TotalTokens == 0 {
return nil
}
return &OpenAIUsageType{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
var rtn []openaiapi.ChatCompletionMessage
for _, p := range prompt {
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
rtn = append(rtn, msg)
}
return rtn
}
func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
go func() {
log.Printf("start: %v", request)
defer close(rtn)
if request.Opts == nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
return
}
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
defer dialCancelFn()
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil)
defer func() {
err = conn.Close()
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", err)}
}
}()
if err == context.DeadlineExceeded {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)}
return
} else if err != nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)}
return
}
reqPk := MakeOpenAICloudReqPacket()
reqPk.ClientId = request.ClientId
reqPk.Prompt = request.Prompt
reqPk.MaxTokens = request.Opts.MaxTokens
reqPk.MaxChoices = request.Opts.MaxChoices
configMessageBuf, err := json.Marshal(reqPk)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)}
return
}
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)}
return
}
for {
log.Printf("loop")
_, socketMessage, err := conn.ReadMessage()
if err == io.EOF {
break
}
if err != nil {
log.Printf("err received: %v", err)
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
break
}
var streamResp *OpenAIPacketType
err = json.Unmarshal(socketMessage, &streamResp)
log.Printf("ai resp: %v", streamResp)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)}
break
}
if streamResp.Error == PacketEOFStr {
// got eof packet from socket
break
} else if streamResp.Error != "" {
// use error from server directly
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)}
break
}
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp}
}
}()
return rtn
}
func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
go func() {
log.Printf("start2: %v", request)
defer close(rtn)
if request.Opts == nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
return
}
if request.Opts.Model == "" {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
return
}
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")}
return
}
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
if request.Opts.BaseURL != "" {
clientConfig.BaseURL = request.Opts.BaseURL
}
client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{
Model: request.Opts.Model,
Messages: ConvertPrompt(request.Prompt),
MaxTokens: request.Opts.MaxTokens,
Stream: true,
}
if request.Opts.MaxChoices > 1 {
req.N = request.Opts.MaxChoices
}
apiResp, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
return
}
sentHeader := false
for {
log.Printf("loop2")
streamResp, err := apiResp.Recv()
if err == io.EOF {
break
}
if err != nil {
log.Printf("err received2: %v", err)
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
break
}
if streamResp.Model != "" && !sentHeader {
pk := MakeOpenAIPacket()
pk.Model = streamResp.Model
pk.Created = streamResp.Created
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
sentHeader = true
}
for _, choice := range streamResp.Choices {
pk := MakeOpenAIPacket()
pk.Index = choice.Index
pk.Text = choice.Delta.Content
pk.FinishReason = string(choice.FinishReason)
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
}
}
}()
return rtn
}
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType {
var rtn []*OpenAIPacketType
headerPk := MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
headerPk.Usage = convertUsage(resp)
rtn = append(rtn, headerPk)
for _, choice := range resp.Choices {
choicePk := MakeOpenAIPacket()
choicePk.Index = choice.Index
choicePk.Text = choice.Message.Content
choicePk.FinishReason = string(choice.FinishReason)
rtn = append(rtn, choicePk)
}
return rtn
}
func CreateErrorPacket(errStr string) *OpenAIPacketType {
errPk := MakeOpenAIPacket()
errPk.FinishReason = "error"
errPk.Error = errStr
return errPk
}
func CreateTextPacket(text string) *OpenAIPacketType {
pk := MakeOpenAIPacket()
pk.Text = text
return pk
}