azure ai support (#997)

This commit is contained in:
Mike Sawka 2024-10-09 13:36:02 -07:00 committed by GitHub
parent b81ab63ddc
commit f33028af1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 100 additions and 68 deletions

View File

@ -194,12 +194,15 @@ export class WaveAiModel implements ViewModel {
}; };
addMessage(newMessage); addMessage(newMessage);
// send message to backend and get response // send message to backend and get response
const settings = globalStore.get(atoms.settingsAtom); const settings = globalStore.get(atoms.settingsAtom) ?? {};
const opts: OpenAIOptsType = { const opts: OpenAIOptsType = {
model: settings["ai:model"], model: settings["ai:model"],
apitype: settings["ai:apitype"],
orgid: settings["ai:orgid"],
apitoken: settings["ai:apitoken"], apitoken: settings["ai:apitoken"],
apiversion: settings["ai:apiversion"],
maxtokens: settings["ai:maxtokens"], maxtokens: settings["ai:maxtokens"],
timeout: settings["ai:timeoutms"] / 1000, timeoutms: settings["ai:timeoutms"] ?? 60000,
baseurl: settings["ai:baseurl"], baseurl: settings["ai:baseurl"],
}; };
const newPrompt: OpenAIPromptMessageType = { const newPrompt: OpenAIPromptMessageType = {
@ -222,7 +225,7 @@ export class WaveAiModel implements ViewModel {
opts: opts, opts: opts,
prompt: [...history, newPrompt], prompt: [...history, newPrompt],
}; };
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: 60000 }); const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
let fullMsg = ""; let fullMsg = "";
for await (const msg of aiGen) { for await (const msg of aiGen) {
fullMsg += msg.text ?? ""; fullMsg += msg.text ?? "";

View File

@ -329,11 +329,14 @@ declare global {
// wshrpc.OpenAIOptsType // wshrpc.OpenAIOptsType
type OpenAIOptsType = { type OpenAIOptsType = {
model: string; model: string;
apitype?: string;
apitoken: string; apitoken: string;
orgid?: string;
apiversion?: string;
baseurl?: string; baseurl?: string;
maxtokens?: number; maxtokens?: number;
maxchoices?: number; maxchoices?: number;
timeout?: number; timeoutms?: number;
}; };
// wshrpc.OpenAIPacketType // wshrpc.OpenAIPacketType
@ -413,10 +416,13 @@ declare global {
// wconfig.SettingsType // wconfig.SettingsType
type SettingsType = { type SettingsType = {
"ai:*"?: boolean; "ai:*"?: boolean;
"ai:apitype"?: string;
"ai:baseurl"?: string; "ai:baseurl"?: string;
"ai:apitoken"?: string; "ai:apitoken"?: string;
"ai:name"?: string; "ai:name"?: string;
"ai:model"?: string; "ai:model"?: string;
"ai:orgid"?: string;
"ai:apiversion"?: string;
"ai:maxtokens"?: number; "ai:maxtokens"?: number;
"ai:timeoutms"?: number; "ai:timeoutms"?: number;
"term:*"?: boolean; "term:*"?: boolean;

View File

@ -10,8 +10,11 @@ import (
"io" "io"
"log" "log"
"os" "os"
"regexp"
"strings"
"time" "time"
"github.com/sashabaranov/go-openai"
openaiapi "github.com/sashabaranov/go-openai" openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wcloud" "github.com/wavetermdev/waveterm/pkg/wcloud"
@ -23,6 +26,7 @@ import (
const OpenAIPacketStr = "openai" const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq" const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF" const PacketEOFStr = "EOF"
const DefaultAzureAPIVersion = "2023-05-15"
type OpenAICmdInfoPacketOutputType struct { type OpenAICmdInfoPacketOutputType struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
@ -52,16 +56,6 @@ type OpenAICloudReqPacketType struct {
MaxChoices int `json:"maxchoices,omitempty"` MaxChoices int `json:"maxchoices,omitempty"`
} }
// OpenAIOptsType holds the per-request AI client configuration carried in
// packets (model selection, credentials, endpoint, and generation limits).
type OpenAIOptsType struct {
Model string `json:"model"` // model name sent to the completion API
APIToken string `json:"apitoken"` // API key; empty means no user-configured auth
BaseURL string `json:"baseurl,omitempty"` // custom endpoint; empty means provider default
MaxTokens int `json:"maxtokens,omitempty"` // cap on generated tokens (0 = provider default)
MaxChoices int `json:"maxchoices,omitempty"` // number of completion choices requested
Timeout int `json:"timeout,omitempty"` // request timeout (units set by caller — presumably seconds; TODO confirm)
BlockId string `json:"blockid"` // id of the UI block this request belongs to
}
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
return &OpenAICloudReqPacketType{ return &OpenAICloudReqPacketType{
Type: OpenAICloudReqStr, Type: OpenAICloudReqStr,
@ -69,26 +63,31 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
} }
func GetWSEndpoint() string { func GetWSEndpoint() string {
return PCloudWSEndpoint
if !wavebase.IsDevMode() { if !wavebase.IsDevMode() {
return PCloudWSEndpoint return WCloudWSEndpoint
} else { } else {
endpoint := os.Getenv(PCloudWSEndpointVarName) endpoint := os.Getenv(WCloudWSEndpointVarName)
if endpoint == "" { if endpoint == "" {
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid") panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
} }
return endpoint return endpoint
} }
} }
const DefaultMaxTokens = 1000 const DefaultMaxTokens = 2048
const DefaultModel = "gpt-4o-mini" const DefaultModel = "gpt-4o-mini"
const DefaultStreamChanSize = 10 const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/" const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
const CloudWebsocketConnectTimeout = 1 * time.Minute const CloudWebsocketConnectTimeout = 1 * time.Minute
// IsCloudAIRequest reports whether the request should be routed to the default
// Wave cloud endpoint. That is the case when opts is nil, or when the user has
// configured neither a base URL nor an API token.
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
	if opts != nil && (opts.BaseURL != "" || opts.APIToken != "") {
		// A user-configured endpoint or credential is present: not a cloud request.
		return false
	}
	return true
}
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType { func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
if resp.Usage.TotalTokens == 0 { if resp.Usage.TotalTokens == 0 {
return nil return nil
@ -113,6 +112,15 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err} return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
} }
// RunAICommand dispatches an AI chat request to the appropriate backend: the
// user-configured OpenAI-compatible endpoint when one is set, otherwise the
// default Wave cloud endpoint. It returns the streaming response channel.
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
	if !IsCloudAIRequest(request.Opts) {
		log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
		return RunLocalCompletionStream(ctx, request)
	}
	log.Print("sending ai chat message to default waveterm cloud endpoint\n")
	return RunCloudCompletionStream(ctx, request)
}
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
wsEndpoint := wcloud.GetWSEndpoint() wsEndpoint := wcloud.GetWSEndpoint()
@ -187,6 +195,36 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
return rtn return rtn
} }
// azureModelCharRe matches the characters that must be stripped from a model
// name to form an Azure deployment name. Compiled once at package scope so the
// mapper does not recompile the pattern on every call.
var azureModelCharRe = regexp.MustCompile(`[.:]`)

// defaultAzureMapperFn converts an OpenAI model name into an Azure deployment
// name by removing '.' and ':' (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
// copied from go-openai/config.go
func defaultAzureMapperFn(model string) string {
	return azureModelCharRe.ReplaceAllString(model, "")
}
// setApiType validates opts.APIType (matched case-insensitively) and applies
// the corresponding go-openai client configuration. An empty value defaults to
// the plain OpenAI API. All Azure-family API types additionally receive a
// default API version and the Azure model-name mapper. An unrecognized value
// is rejected with an error.
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
	// configureAzure applies the settings shared by every Azure-style API type.
	configureAzure := func(apiType openai.APIType) {
		clientConfig.APIType = apiType
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
	}
	switch strings.ToLower(opts.APIType) {
	case "", strings.ToLower(string(openai.APITypeOpenAI)):
		clientConfig.APIType = openai.APITypeOpenAI
	case strings.ToLower(string(openai.APITypeAzure)):
		configureAzure(openai.APITypeAzure)
	case strings.ToLower(string(openai.APITypeAzureAD)):
		configureAzure(openai.APITypeAzureAD)
	case strings.ToLower(string(openai.APITypeCloudflareAzure)):
		configureAzure(openai.APITypeCloudflareAzure)
	default:
		return fmt.Errorf("invalid api type %q", opts.APIType)
	}
	return nil
}
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
go func() { go func() {
@ -207,6 +245,17 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
if request.Opts.BaseURL != "" { if request.Opts.BaseURL != "" {
clientConfig.BaseURL = request.Opts.BaseURL clientConfig.BaseURL = request.Opts.BaseURL
} }
err := setApiType(request.Opts, &clientConfig)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
return
}
if request.Opts.OrgID != "" {
clientConfig.OrgID = request.Opts.OrgID
}
if request.Opts.APIVersion != "" {
clientConfig.APIVersion = request.Opts.APIVersion
}
client := openaiapi.NewClientWithConfig(clientConfig) client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{ req := openaiapi.ChatCompletionRequest{
Model: request.Opts.Model, Model: request.Opts.Model,
@ -251,33 +300,3 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
}() }()
return rtn return rtn
} }
// marshalResponse converts a non-streaming chat completion response into the
// packet sequence used by the wsh RPC layer: one header packet carrying the
// model, creation time, and token usage, followed by one packet per choice.
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType {
	packets := make([]*wshrpc.OpenAIPacketType, 0, len(resp.Choices)+1)
	header := MakeOpenAIPacket()
	header.Model = resp.Model
	header.Created = resp.Created
	header.Usage = convertUsage(resp)
	packets = append(packets, header)
	for idx := range resp.Choices {
		choice := &resp.Choices[idx]
		pk := MakeOpenAIPacket()
		pk.Index = choice.Index
		pk.Text = choice.Message.Content
		pk.FinishReason = string(choice.FinishReason)
		packets = append(packets, pk)
	}
	return packets
}
// CreateErrorPacket builds an OpenAI packet carrying the given error message,
// with the finish reason set to "error" to terminate the stream.
func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType {
	pk := MakeOpenAIPacket()
	pk.FinishReason = "error"
	pk.Error = errStr
	return pk
}
// CreateTextPacket builds an OpenAI packet containing a single chunk of
// response text.
func CreateTextPacket(text string) *wshrpc.OpenAIPacketType {
	packet := MakeOpenAIPacket()
	packet.Text = text
	return packet
}

View File

@ -1,7 +1,7 @@
{ {
"ai:model": "gpt-4o-mini", "ai:model": "gpt-4o-mini",
"ai:maxtokens": 1000, "ai:maxtokens": 2048,
"ai:timeoutms": 10000, "ai:timeoutms": 60000,
"autoupdate:enabled": true, "autoupdate:enabled": true,
"autoupdate:installonquit": true, "autoupdate:installonquit": true,
"autoupdate:intervalms": 3600000, "autoupdate:intervalms": 3600000,

View File

@ -7,10 +7,13 @@ package wconfig
const ( const (
ConfigKey_AiClear = "ai:*" ConfigKey_AiClear = "ai:*"
ConfigKey_AiApiType = "ai:apitype"
ConfigKey_AiBaseURL = "ai:baseurl" ConfigKey_AiBaseURL = "ai:baseurl"
ConfigKey_AiApiToken = "ai:apitoken" ConfigKey_AiApiToken = "ai:apitoken"
ConfigKey_AiName = "ai:name" ConfigKey_AiName = "ai:name"
ConfigKey_AiModel = "ai:model" ConfigKey_AiModel = "ai:model"
ConfigKey_AiOrgID = "ai:orgid"
ConfigKey_AIApiVersion = "ai:apiversion"
ConfigKey_AiMaxTokens = "ai:maxtokens" ConfigKey_AiMaxTokens = "ai:maxtokens"
ConfigKey_AiTimeoutMs = "ai:timeoutms" ConfigKey_AiTimeoutMs = "ai:timeoutms"

View File

@ -41,10 +41,13 @@ func (m MetaSettingsType) MarshalJSON() ([]byte, error) {
type SettingsType struct { type SettingsType struct {
AiClear bool `json:"ai:*,omitempty"` AiClear bool `json:"ai:*,omitempty"`
AiApiType string `json:"ai:apitype,omitempty"`
AiBaseURL string `json:"ai:baseurl,omitempty"` AiBaseURL string `json:"ai:baseurl,omitempty"`
AiApiToken string `json:"ai:apitoken,omitempty"` AiApiToken string `json:"ai:apitoken,omitempty"`
AiName string `json:"ai:name,omitempty"` AiName string `json:"ai:name,omitempty"`
AiModel string `json:"ai:model,omitempty"` AiModel string `json:"ai:model,omitempty"`
AiOrgID string `json:"ai:orgid,omitempty"`
AIApiVersion string `json:"ai:apiversion,omitempty"`
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"` AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"` AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`

View File

@ -275,11 +275,14 @@ type OpenAIPromptMessageType struct {
type OpenAIOptsType struct { type OpenAIOptsType struct {
Model string `json:"model"` Model string `json:"model"`
APIType string `json:"apitype,omitempty"`
APIToken string `json:"apitoken"` APIToken string `json:"apitoken"`
OrgID string `json:"orgid,omitempty"`
APIVersion string `json:"apiversion,omitempty"`
BaseURL string `json:"baseurl,omitempty"` BaseURL string `json:"baseurl,omitempty"`
MaxTokens int `json:"maxtokens,omitempty"` MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"` MaxChoices int `json:"maxchoices,omitempty"`
Timeout int `json:"timeout,omitempty"` TimeoutMs int `json:"timeoutms,omitempty"`
} }
type OpenAIPacketType struct { type OpenAIPacketType struct {

View File

@ -74,12 +74,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr
} }
func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { return waveai.RunAICommand(ctx, request)
log.Print("sending ai chat message to waveterm default endpoint with openai\n")
return waveai.RunCloudCompletionStream(ctx, request)
}
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
return waveai.RunLocalCompletionStream(ctx, request)
} }
func MakePlotData(ctx context.Context, blockId string) error { func MakePlotData(ctx context.Context, blockId string) error {