fix openai packets to be PacketType. better compatibility with backend (#154)

also pass clientid to openapi cloud service
This commit is contained in:
Mike Sawka 2023-12-15 22:43:59 -08:00 committed by GitHub
parent 4ccd62f12a
commit 633cb8fbd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 29 deletions

View File

@ -61,6 +61,7 @@ const (
FileDataPacketStr = "filedata" FileDataPacketStr = "filedata"
OpenAIPacketStr = "openai" // other OpenAIPacketStr = "openai" // other
OpenAICloudReqStr = "openai-cloudreq"
) )
const PacketSenderQueueSize = 20 const PacketSenderQueueSize = 20
@ -842,6 +843,30 @@ func MakeWriteFileDonePacket(reqId string) *WriteFileDonePacketType {
} }
} }
type OpenAIPromptMessageType struct {
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
type OpenAICloudReqPacketType struct {
Type string `json:"type"`
ClientId string `json:"clientid"`
Prompt []OpenAIPromptMessageType `json:"prompt"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
}
func (*OpenAICloudReqPacketType) GetType() string {
return OpenAICloudReqStr
}
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
return &OpenAICloudReqPacketType{
Type: OpenAICloudReqStr,
}
}
type PacketType interface { type PacketType interface {
GetType() string GetType() string
} }

View File

@ -1462,7 +1462,7 @@ func writePacketToPty(ctx context.Context, cmd *sstore.CmdType, pk packet.Packet
return nil return nil
} }
func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) { func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) {
var outputPos int64 var outputPos int64
var hadError bool var hadError bool
startTime := time.Now() startTime := time.Now()
@ -1514,7 +1514,7 @@ func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt
return return
} }
func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) { func doOpenAIStreamCompletion(cmd *sstore.CmdType, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) {
var outputPos int64 var outputPos int64
var hadError bool var hadError bool
startTime := time.Now() startTime := time.Now()
@ -1552,7 +1552,7 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType,
var err error var err error
if opts.APIToken == "" { if opts.APIToken == "" {
var conn *websocket.Conn var conn *websocket.Conn
ch, conn, err = openai.RunCloudCompletionStream(ctx, opts, prompt) ch, conn, err = openai.RunCloudCompletionStream(ctx, clientId, opts, prompt)
if conn != nil { if conn != nil {
defer conn.Close() defer conn.Close()
} }
@ -1641,9 +1641,9 @@ func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstor
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot add new line: %v", err) return nil, fmt.Errorf("cannot add new line: %v", err)
} }
prompt := []sstore.OpenAIPromptMessageType{{Role: sstore.OpenAIRoleUser, Content: promptStr}} prompt := []packet.OpenAIPromptMessageType{{Role: sstore.OpenAIRoleUser, Content: promptStr}}
if resolveBool(pk.Kwargs["stream"], true) { if resolveBool(pk.Kwargs["stream"], true) {
go doOpenAIStreamCompletion(cmd, opts, prompt) go doOpenAIStreamCompletion(cmd, clientData.ClientId, opts, prompt)
} else { } else {
go doOpenAICompletion(cmd, opts, prompt) go doOpenAICompletion(cmd, opts, prompt)
} }

View File

@ -37,7 +37,7 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType
} }
} }
func ConvertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { func ConvertPrompt(prompt []packet.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
var rtn []openaiapi.ChatCompletionMessage var rtn []openaiapi.ChatCompletionMessage
for _, p := range prompt { for _, p := range prompt {
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
@ -46,7 +46,7 @@ func ConvertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatComp
return rtn return rtn
} }
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) { func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
if opts == nil { if opts == nil {
return nil, fmt.Errorf("no openai opts found") return nil, fmt.Errorf("no openai opts found")
} }
@ -79,21 +79,22 @@ func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []ss
return marshalResponse(apiResp), nil return marshalResponse(apiResp), nil
} }
func RunCloudCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, *websocket.Conn, error) { func RunCloudCompletionStream(ctx context.Context, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, *websocket.Conn, error) {
if opts == nil { if opts == nil {
return nil, nil, fmt.Errorf("no openai opts found") return nil, nil, fmt.Errorf("no openai opts found")
} }
websocketContext, _ := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout) websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
defer dialCancelFn()
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, pcloud.GetWSEndpoint(), nil) conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, pcloud.GetWSEndpoint(), nil)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("OpenAI request, websocket connect error: %v", err) return nil, nil, fmt.Errorf("OpenAI request, websocket connect error: %v", err)
} }
cloudCompletionRequestConfig := sstore.OpenAICloudCompletionRequest{ reqPk := packet.MakeOpenAICloudReqPacket()
Prompt: prompt, reqPk.ClientId = clientId
MaxTokens: opts.MaxTokens, reqPk.Prompt = prompt
MaxChoices: opts.MaxChoices, reqPk.MaxTokens = opts.MaxTokens
} reqPk.MaxChoices = opts.MaxChoices
configMessageBuf, err := json.Marshal(cloudCompletionRequestConfig) configMessageBuf, err := json.Marshal(reqPk)
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf) err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("OpenAI request, websocket write config error: %v", err) return nil, nil, fmt.Errorf("OpenAI request, websocket write config error: %v", err)
@ -134,7 +135,7 @@ func RunCloudCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType,
return rtn, conn, err return rtn, conn, err
} }
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) { func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
if opts == nil { if opts == nil {
return nil, fmt.Errorf("no openai opts found") return nil, fmt.Errorf("no openai opts found")
} }

View File

@ -778,18 +778,6 @@ type OpenAIResponse struct {
Choices []OpenAIChoiceType `json:"choices,omitempty"` Choices []OpenAIChoiceType `json:"choices,omitempty"`
} }
type OpenAIPromptMessageType struct {
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
type OpenAICloudCompletionRequest struct {
Prompt []OpenAIPromptMessageType `json:"prompt"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
}
type PlaybookType struct { type PlaybookType struct {
PlaybookId string `json:"playbookid"` PlaybookId string `json:"playbookid"`
PlaybookName string `json:"playbookname"` PlaybookName string `json:"playbookname"`