mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-02 18:39:05 +01:00
ai backend refactor + claude/anthropic API support (#1262)
This commit is contained in:
parent
38eeba5bd2
commit
de902ec2b7
@ -180,7 +180,15 @@ export class WaveAiModel implements ViewModel {
|
||||
const presetKey = get(this.presetKey);
|
||||
const presetName = presets[presetKey]?.["display:name"] ?? "";
|
||||
const isCloud = isBlank(aiOpts.apitoken) && isBlank(aiOpts.baseurl);
|
||||
if (isCloud) {
|
||||
if (aiOpts?.apitype == "anthropic") {
|
||||
const modelName = aiOpts.model;
|
||||
viewTextChildren.push({
|
||||
elemtype: "iconbutton",
|
||||
icon: "globe",
|
||||
title: "Using Remote Antropic API (" + modelName + ")",
|
||||
disabled: true,
|
||||
});
|
||||
} else if (isCloud) {
|
||||
viewTextChildren.push({
|
||||
elemtype: "iconbutton",
|
||||
icon: "cloud",
|
||||
|
300
pkg/waveai/anthropicbackend.go
Normal file
300
pkg/waveai/anthropicbackend.go
Normal file
@ -0,0 +1,300 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package waveai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
// AnthropicBackend implements AIBackend against the Anthropic (Claude)
// Messages API using SSE streaming. It is stateless; the zero value is ready to use.
type AnthropicBackend struct{}

// Compile-time check that AnthropicBackend satisfies the AIBackend interface.
var _ AIBackend = AnthropicBackend{}
|
||||
|
||||
// Claude API request types

// anthropicMessage is a single conversation turn sent to the Anthropic
// Messages API ("user" or "assistant" role).
type anthropicMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// anthropicRequest is the JSON body POSTed to /v1/messages. Unlike the
// OpenAI API, the system prompt is a top-level field rather than a message.
type anthropicRequest struct {
	Model       string             `json:"model"`
	Messages    []anthropicMessage `json:"messages"`
	System      string             `json:"system,omitempty"`
	MaxTokens   int                `json:"max_tokens,omitempty"`
	Stream      bool               `json:"stream"`
	Temperature float32            `json:"temperature,omitempty"`
}
|
||||
|
||||
// Claude API response types for SSE events

// anthropicContentBlock is one piece of assistant output inside a message.
type anthropicContentBlock struct {
	Type string `json:"type"` // "text" or other content types
	Text string `json:"text,omitempty"`
}

// anthropicUsage reports token counts for a request/response.
type anthropicUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// anthropicResponseMessage is the message object carried by message-level
// stream events (e.g. message_start).
type anthropicResponseMessage struct {
	ID           string                  `json:"id"`
	Type         string                  `json:"type"`
	Role         string                  `json:"role"`
	Content      []anthropicContentBlock `json:"content"`
	Model        string                  `json:"model"`
	StopReason   string                  `json:"stop_reason,omitempty"`
	StopSequence string                  `json:"stop_sequence,omitempty"`
	Usage        *anthropicUsage         `json:"usage,omitempty"`
}

// anthropicStreamEventError is the payload of an "error" stream event.
type anthropicStreamEventError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// anthropicStreamEventDelta carries incremental text for
// content_block_delta events.
type anthropicStreamEventDelta struct {
	Text string `json:"text"`
}

// anthropicStreamEvent is the union of all SSE event payloads; only the
// fields relevant to a given event type are populated.
type anthropicStreamEvent struct {
	Type         string                     `json:"type"`
	Message      *anthropicResponseMessage  `json:"message,omitempty"`
	ContentBlock *anthropicContentBlock     `json:"content_block,omitempty"`
	Delta        *anthropicStreamEventDelta `json:"delta,omitempty"`
	Error        *anthropicStreamEventError `json:"error,omitempty"`
	Usage        *anthropicUsage            `json:"usage,omitempty"`
}
|
||||
|
||||
// SSE event represents a parsed Server-Sent Event
|
||||
type sseEvent struct {
|
||||
Event string // The event type field
|
||||
Data string // The data field
|
||||
}
|
||||
|
||||
// parseSSE reads and parses SSE format from a bufio.Reader
|
||||
func parseSSE(reader *bufio.Reader) (*sseEvent, error) {
|
||||
var event sseEvent
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
// Empty line signals end of event
|
||||
if event.Event != "" || event.Data != "" {
|
||||
return &event, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
event.Event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Convert panic to error and send it
|
||||
log.Printf("panic: %v\n", r)
|
||||
debug.PrintStack()
|
||||
err, ok := r.(error)
|
||||
if !ok {
|
||||
err = fmt.Errorf("anthropic backend panic: %v", r)
|
||||
}
|
||||
rtn <- makeAIError(err)
|
||||
}
|
||||
// Always close the channel
|
||||
close(rtn)
|
||||
}()
|
||||
|
||||
if request.Opts == nil {
|
||||
rtn <- makeAIError(errors.New("no anthropic opts found"))
|
||||
return
|
||||
}
|
||||
|
||||
model := request.Opts.Model
|
||||
if model == "" {
|
||||
model = "claude-3-sonnet-20240229" // default model
|
||||
}
|
||||
|
||||
// Convert messages format
|
||||
var messages []anthropicMessage
|
||||
var systemPrompt string
|
||||
|
||||
for _, msg := range request.Prompt {
|
||||
if msg.Role == "system" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += msg.Content
|
||||
continue
|
||||
}
|
||||
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "assistant"
|
||||
}
|
||||
|
||||
messages = append(messages, anthropicMessage{
|
||||
Role: role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
anthropicReq := anthropicRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
System: systemPrompt,
|
||||
Stream: true,
|
||||
MaxTokens: request.Opts.MaxTokens,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(anthropicReq)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("failed to marshal anthropic request: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages", strings.NewReader(string(reqBody)))
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("failed to create anthropic request: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("x-api-key", request.Opts.APIToken)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("failed to send anthropic request: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
rtn <- makeAIError(fmt.Errorf("Anthropic API error: %s - %s", resp.Status, string(bodyBytes)))
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rtn <- makeAIError(fmt.Errorf("request cancelled: %v", ctx.Err()))
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sse, err := parseSSE(reader)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("error reading SSE stream: %v", err))
|
||||
break
|
||||
}
|
||||
|
||||
if sse.Event == "ping" {
|
||||
continue // Ignore ping events
|
||||
}
|
||||
|
||||
var event anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(sse.Data), &event); err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("error unmarshaling event data: %v", err))
|
||||
break
|
||||
}
|
||||
|
||||
if event.Error != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("Anthropic API error: %s - %s", event.Error.Type, event.Error.Message))
|
||||
break
|
||||
}
|
||||
|
||||
switch sse.Event {
|
||||
case "message_start":
|
||||
if event.Message != nil {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Model = event.Message.Model
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
if event.ContentBlock != nil && event.ContentBlock.Text != "" {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Text = event.ContentBlock.Text
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
if event.Delta != nil && event.Delta.Text != "" {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Text = event.Delta.Text
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Note: According to the docs, this just signals the end of a content block
|
||||
// We might want to use this for tracking block boundaries, but for now
|
||||
// we don't need to send anything special to match OpenAI's format
|
||||
|
||||
case "message_delta":
|
||||
// Update message metadata, usage stats
|
||||
if event.Usage != nil {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Usage = &wshrpc.OpenAIUsageType{
|
||||
PromptTokens: event.Usage.InputTokens,
|
||||
CompletionTokens: event.Usage.OutputTokens,
|
||||
TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens,
|
||||
}
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
if event.Message != nil {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.FinishReason = event.Message.StopReason
|
||||
if event.Message.Usage != nil {
|
||||
pk.Usage = &wshrpc.OpenAIUsageType{
|
||||
PromptTokens: event.Message.Usage.InputTokens,
|
||||
CompletionTokens: event.Message.Usage.OutputTokens,
|
||||
TotalTokens: event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
|
||||
default:
|
||||
rtn <- makeAIError(fmt.Errorf("unknown Anthropic event type: %s", sse.Event))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return rtn
|
||||
}
|
115
pkg/waveai/cloudbackend.go
Normal file
115
pkg/waveai/cloudbackend.go
Normal file
@ -0,0 +1,115 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package waveai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
// WaveAICloudBackend implements AIBackend by relaying requests to the Wave
// cloud proxy over a websocket (used when no API token or base URL is
// configured). It is stateless; the zero value is ready to use.
type WaveAICloudBackend struct{}

// Compile-time check that WaveAICloudBackend satisfies AIBackend.
var _ AIBackend = WaveAICloudBackend{}

// OpenAICloudReqPacketType is the JSON packet sent over the websocket to
// start a completion request against the cloud proxy.
type OpenAICloudReqPacketType struct {
	Type       string                           `json:"type"`
	ClientId   string                           `json:"clientid"`
	Prompt     []wshrpc.OpenAIPromptMessageType `json:"prompt"`
	MaxTokens  int                              `json:"maxtokens,omitempty"`
	MaxChoices int                              `json:"maxchoices,omitempty"`
}

// MakeOpenAICloudReqPacket returns a request packet with its Type field
// pre-set; callers fill in the remaining fields before sending.
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
	return &OpenAICloudReqPacketType{
		Type: OpenAICloudReqStr,
	}
}
|
||||
|
||||
func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
wsEndpoint := wcloud.GetWSEndpoint()
|
||||
go func() {
|
||||
defer close(rtn)
|
||||
if wsEndpoint == "" {
|
||||
rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
|
||||
return
|
||||
}
|
||||
if request.Opts == nil {
|
||||
rtn <- makeAIError(fmt.Errorf("no openai opts found"))
|
||||
return
|
||||
}
|
||||
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
|
||||
defer dialCancelFn()
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
|
||||
if err == context.DeadlineExceeded {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
|
||||
return
|
||||
} else if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
||||
}
|
||||
}()
|
||||
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
|
||||
for _, promptMsg := range request.Prompt {
|
||||
if promptMsg.Role == "error" {
|
||||
continue
|
||||
}
|
||||
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
|
||||
}
|
||||
reqPk := MakeOpenAICloudReqPacket()
|
||||
reqPk.ClientId = request.ClientId
|
||||
reqPk.Prompt = sendablePromptMsgs
|
||||
reqPk.MaxTokens = request.Opts.MaxTokens
|
||||
reqPk.MaxChoices = request.Opts.MaxChoices
|
||||
configMessageBuf, err := json.Marshal(reqPk)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
|
||||
return
|
||||
}
|
||||
for {
|
||||
_, socketMessage, err := conn.ReadMessage()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("err received: %v", err)
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
|
||||
break
|
||||
}
|
||||
var streamResp *wshrpc.OpenAIPacketType
|
||||
err = json.Unmarshal(socketMessage, &streamResp)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
|
||||
break
|
||||
}
|
||||
if streamResp.Error == PacketEOFStr {
|
||||
// got eof packet from socket
|
||||
break
|
||||
} else if streamResp.Error != "" {
|
||||
// use error from server directly
|
||||
rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
|
||||
break
|
||||
}
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
|
||||
}
|
||||
}()
|
||||
return rtn
|
||||
}
|
146
pkg/waveai/openaibackend.go
Normal file
146
pkg/waveai/openaibackend.go
Normal file
@ -0,0 +1,146 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package waveai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
// OpenAIBackend implements AIBackend against OpenAI-compatible chat
// completion APIs (via the go-openai client), including the Azure variants.
type OpenAIBackend struct{}

// Compile-time check that OpenAIBackend satisfies AIBackend.
var _ AIBackend = OpenAIBackend{}
|
||||
|
||||
// copied from go-openai/config.go
//
// azureModelRe matches the characters ('.' and ':') that Azure does not
// allow in deployment names. Compiled once at package scope instead of on
// every call (the original re-ran regexp.MustCompile per invocation).
var azureModelRe = regexp.MustCompile(`[.:]`)

// defaultAzureMapperFn maps an OpenAI model name to a valid Azure deployment
// name by stripping '.' and ':' (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
func defaultAzureMapperFn(model string) string {
	return azureModelRe.ReplaceAllString(model, "")
}
|
||||
|
||||
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
|
||||
ourApiType := strings.ToLower(opts.APIType)
|
||||
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
|
||||
clientConfig.APIType = openaiapi.APITypeOpenAI
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
|
||||
clientConfig.APIType = openaiapi.APITypeAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
|
||||
clientConfig.APIType = openaiapi.APITypeAzureAD
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
|
||||
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("invalid api type %q", opts.APIType)
|
||||
}
|
||||
}
|
||||
|
||||
func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
||||
var rtn []openaiapi.ChatCompletionMessage
|
||||
for _, p := range prompt {
|
||||
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
||||
rtn = append(rtn, msg)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
|
||||
if resp.Usage.TotalTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
return &wshrpc.OpenAIUsageType{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// StreamCompletion implements AIBackend using the go-openai streaming client.
// It validates options, configures the client (base URL, API type/version,
// org), then streams chat-completion deltas as OpenAI packets on the returned
// channel. The channel is closed when the goroutine exits for any reason.
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
	go func() {
		defer close(rtn)
		if request.Opts == nil {
			rtn <- makeAIError(errors.New("no openai opts found"))
			return
		}
		if request.Opts.Model == "" {
			rtn <- makeAIError(errors.New("no openai model specified"))
			return
		}
		// A custom BaseURL (e.g. a local server) is allowed to omit the token.
		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
			rtn <- makeAIError(errors.New("no api token"))
			return
		}
		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
		if request.Opts.BaseURL != "" {
			clientConfig.BaseURL = request.Opts.BaseURL
		}
		err := setApiType(request.Opts, &clientConfig)
		if err != nil {
			rtn <- makeAIError(err)
			return
		}
		if request.Opts.OrgID != "" {
			clientConfig.OrgID = request.Opts.OrgID
		}
		// Explicit APIVersion overrides the default set by setApiType.
		if request.Opts.APIVersion != "" {
			clientConfig.APIVersion = request.Opts.APIVersion
		}
		client := openaiapi.NewClientWithConfig(clientConfig)
		// MaxTokens is set in both fields for compatibility across older and
		// newer OpenAI API versions.
		req := openaiapi.ChatCompletionRequest{
			Model:               request.Opts.Model,
			Messages:            convertPrompt(request.Prompt),
			MaxTokens:           request.Opts.MaxTokens,
			MaxCompletionTokens: request.Opts.MaxTokens,
			Stream:              true,
		}
		if request.Opts.MaxChoices > 1 {
			req.N = request.Opts.MaxChoices
		}
		apiResp, err := client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
			return
		}
		// Send one header packet (model/created) before any content packets.
		sentHeader := false
		for {
			streamResp, err := apiResp.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("OpenAI request, error reading message: %v", err))
				break
			}
			if streamResp.Model != "" && !sentHeader {
				pk := MakeOpenAIPacket()
				pk.Model = streamResp.Model
				pk.Created = streamResp.Created
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
				sentHeader = true
			}
			// One packet per choice delta; FinishReason is empty until the
			// stream's final chunk for that choice.
			for _, choice := range streamResp.Choices {
				pk := MakeOpenAIPacket()
				pk.Index = choice.Index
				pk.Text = choice.Delta.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
			}
		}
	}()
	return rtn
}
|
@ -5,21 +5,10 @@ package waveai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const OpenAIPacketStr = "openai"
|
||||
@ -47,30 +36,11 @@ type OpenAICmdInfoChatMessage struct {
|
||||
UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAICloudReqPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ClientId string `json:"clientid"`
|
||||
Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"`
|
||||
MaxTokens int `json:"maxtokens,omitempty"`
|
||||
MaxChoices int `json:"maxchoices,omitempty"`
|
||||
}
|
||||
|
||||
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
||||
return &OpenAICloudReqPacketType{
|
||||
Type: OpenAICloudReqStr,
|
||||
}
|
||||
}
|
||||
|
||||
func GetWSEndpoint() string {
|
||||
if !wavebase.IsDevMode() {
|
||||
return WCloudWSEndpoint
|
||||
} else {
|
||||
endpoint := os.Getenv(WCloudWSEndpointVarName)
|
||||
if endpoint == "" {
|
||||
panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
type AIBackend interface {
|
||||
StreamCompletion(
|
||||
ctx context.Context,
|
||||
request wshrpc.OpenAiStreamRequest,
|
||||
) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]
|
||||
}
|
||||
|
||||
const DefaultMaxTokens = 2048
|
||||
@ -87,223 +57,27 @@ func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
|
||||
return opts.BaseURL == "" && opts.APIToken == ""
|
||||
}
|
||||
|
||||
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
|
||||
if resp.Usage.TotalTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
return &wshrpc.OpenAIUsageType{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
||||
var rtn []openaiapi.ChatCompletionMessage
|
||||
for _, p := range prompt {
|
||||
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
||||
rtn = append(rtn, msg)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
||||
}
|
||||
|
||||
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
if request.Opts.APIType == "anthropic" {
|
||||
endpoint := request.Opts.BaseURL
|
||||
if endpoint == "" {
|
||||
endpoint = "default"
|
||||
}
|
||||
log.Printf("sending ai chat message to anthropic endpoint %q using model %s\n", endpoint, request.Opts.Model)
|
||||
anthropicBackend := AnthropicBackend{}
|
||||
return anthropicBackend.StreamCompletion(ctx, request)
|
||||
}
|
||||
if IsCloudAIRequest(request.Opts) {
|
||||
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
|
||||
return RunCloudCompletionStream(ctx, request)
|
||||
}
|
||||
log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model)
|
||||
return RunLocalCompletionStream(ctx, request)
|
||||
}
|
||||
|
||||
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
wsEndpoint := wcloud.GetWSEndpoint()
|
||||
go func() {
|
||||
defer close(rtn)
|
||||
if wsEndpoint == "" {
|
||||
rtn <- makeAIError(fmt.Errorf("no cloud ws endpoint found"))
|
||||
return
|
||||
}
|
||||
if request.Opts == nil {
|
||||
rtn <- makeAIError(fmt.Errorf("no openai opts found"))
|
||||
return
|
||||
}
|
||||
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
|
||||
defer dialCancelFn()
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, wsEndpoint, nil)
|
||||
if err == context.DeadlineExceeded {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err))
|
||||
return
|
||||
} else if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket connect error: %v", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
||||
}
|
||||
}()
|
||||
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
|
||||
for _, promptMsg := range request.Prompt {
|
||||
if promptMsg.Role == "error" {
|
||||
continue
|
||||
}
|
||||
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
|
||||
}
|
||||
reqPk := MakeOpenAICloudReqPacket()
|
||||
reqPk.ClientId = request.ClientId
|
||||
reqPk.Prompt = sendablePromptMsgs
|
||||
reqPk.MaxTokens = request.Opts.MaxTokens
|
||||
reqPk.MaxChoices = request.Opts.MaxChoices
|
||||
configMessageBuf, err := json.Marshal(reqPk)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, packet marshal error: %v", err))
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket write config error: %v", err))
|
||||
return
|
||||
}
|
||||
for {
|
||||
_, socketMessage, err := conn.ReadMessage()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("err received: %v", err)
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
|
||||
break
|
||||
}
|
||||
var streamResp *wshrpc.OpenAIPacketType
|
||||
err = json.Unmarshal(socketMessage, &streamResp)
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
|
||||
break
|
||||
}
|
||||
if streamResp.Error == PacketEOFStr {
|
||||
// got eof packet from socket
|
||||
break
|
||||
} else if streamResp.Error != "" {
|
||||
// use error from server directly
|
||||
rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
|
||||
break
|
||||
}
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
|
||||
}
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
||||
// copied from go-openai/config.go
|
||||
func defaultAzureMapperFn(model string) string {
|
||||
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||
}
|
||||
|
||||
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
|
||||
ourApiType := strings.ToLower(opts.APIType)
|
||||
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
|
||||
clientConfig.APIType = openaiapi.APITypeOpenAI
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
|
||||
clientConfig.APIType = openaiapi.APITypeAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
|
||||
clientConfig.APIType = openaiapi.APITypeAzureAD
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
|
||||
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
cloudBackend := WaveAICloudBackend{}
|
||||
return cloudBackend.StreamCompletion(ctx, request)
|
||||
} else {
|
||||
return fmt.Errorf("invalid api type %q", opts.APIType)
|
||||
log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model)
|
||||
openAIBackend := OpenAIBackend{}
|
||||
return openAIBackend.StreamCompletion(ctx, request)
|
||||
}
|
||||
}
|
||||
|
||||
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
go func() {
|
||||
defer close(rtn)
|
||||
if request.Opts == nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
|
||||
return
|
||||
}
|
||||
if request.Opts.Model == "" {
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
|
||||
return
|
||||
}
|
||||
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no api token")}
|
||||
return
|
||||
}
|
||||
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
||||
if request.Opts.BaseURL != "" {
|
||||
clientConfig.BaseURL = request.Opts.BaseURL
|
||||
}
|
||||
err := setApiType(request.Opts, &clientConfig)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
||||
return
|
||||
}
|
||||
if request.Opts.OrgID != "" {
|
||||
clientConfig.OrgID = request.Opts.OrgID
|
||||
}
|
||||
if request.Opts.APIVersion != "" {
|
||||
clientConfig.APIVersion = request.Opts.APIVersion
|
||||
}
|
||||
client := openaiapi.NewClientWithConfig(clientConfig)
|
||||
req := openaiapi.ChatCompletionRequest{
|
||||
Model: request.Opts.Model,
|
||||
Messages: ConvertPrompt(request.Prompt),
|
||||
MaxTokens: request.Opts.MaxTokens,
|
||||
MaxCompletionTokens: request.Opts.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
if request.Opts.MaxChoices > 1 {
|
||||
req.N = request.Opts.MaxChoices
|
||||
}
|
||||
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
|
||||
return
|
||||
}
|
||||
sentHeader := false
|
||||
for {
|
||||
streamResp, err := apiResp.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("err received2: %v", err)
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
|
||||
break
|
||||
}
|
||||
if streamResp.Model != "" && !sentHeader {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Model = streamResp.Model
|
||||
pk.Created = streamResp.Created
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
sentHeader = true
|
||||
}
|
||||
for _, choice := range streamResp.Choices {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Index = choice.Index
|
||||
pk.Text = choice.Delta.Content
|
||||
pk.FinishReason = string(choice.FinishReason)
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user