From f33028af1d0d02359ae52f982fb7917ca32c4e37 Mon Sep 17 00:00:00 2001 From: Mike Sawka Date: Wed, 9 Oct 2024 13:36:02 -0700 Subject: [PATCH] azure ai support (#997) --- frontend/app/view/waveai/waveai.tsx | 9 +- frontend/types/gotypes.d.ts | 8 +- pkg/waveai/waveai.go | 115 ++++++++++++++---------- pkg/wconfig/defaultconfig/settings.json | 4 +- pkg/wconfig/metaconsts.go | 3 + pkg/wconfig/settingsconfig.go | 17 ++-- pkg/wshrpc/wshrpctypes.go | 5 +- pkg/wshrpc/wshserver/wshserver.go | 7 +- 8 files changed, 100 insertions(+), 68 deletions(-) diff --git a/frontend/app/view/waveai/waveai.tsx b/frontend/app/view/waveai/waveai.tsx index 401995486..c29458bc2 100644 --- a/frontend/app/view/waveai/waveai.tsx +++ b/frontend/app/view/waveai/waveai.tsx @@ -194,12 +194,15 @@ export class WaveAiModel implements ViewModel { }; addMessage(newMessage); // send message to backend and get response - const settings = globalStore.get(atoms.settingsAtom); + const settings = globalStore.get(atoms.settingsAtom) ?? {}; const opts: OpenAIOptsType = { model: settings["ai:model"], + apitype: settings["ai:apitype"], + orgid: settings["ai:orgid"], apitoken: settings["ai:apitoken"], + apiversion: settings["ai:apiversion"], maxtokens: settings["ai:maxtokens"], - timeout: settings["ai:timeoutms"] / 1000, + timeoutms: settings["ai:timeoutms"] ?? 60000, baseurl: settings["ai:baseurl"], }; const newPrompt: OpenAIPromptMessageType = { @@ -222,7 +225,7 @@ export class WaveAiModel implements ViewModel { opts: opts, prompt: [...history, newPrompt], }; - const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: 60000 }); + const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms }); let fullMsg = ""; for await (const msg of aiGen) { fullMsg += msg.text ?? ""; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 5076253a3..afe523925 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -329,11 +329,14 @@ declare global { // wshrpc.OpenAIOptsType type OpenAIOptsType = { model: string; + apitype?: string; apitoken: string; + orgid?: string; + apiversion?: string; baseurl?: string; maxtokens?: number; maxchoices?: number; - timeout?: number; + timeoutms?: number; }; // wshrpc.OpenAIPacketType @@ -413,10 +416,13 @@ declare global { // wconfig.SettingsType type SettingsType = { "ai:*"?: boolean; + "ai:apitype"?: string; "ai:baseurl"?: string; "ai:apitoken"?: string; "ai:name"?: string; "ai:model"?: string; + "ai:orgid"?: string; + "ai:apiversion"?: string; "ai:maxtokens"?: number; "ai:timeoutms"?: number; "term:*"?: boolean; diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go index 6a9e39e3d..e9be4de4f 100644 --- a/pkg/waveai/waveai.go +++ b/pkg/waveai/waveai.go @@ -10,8 +10,11 @@ import ( "io" "log" "os" + "regexp" + "strings" "time" + "github.com/sashabaranov/go-openai" openaiapi "github.com/sashabaranov/go-openai" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wcloud" @@ -23,6 +26,7 @@ import ( const OpenAIPacketStr = "openai" const OpenAICloudReqStr = "openai-cloudreq" const PacketEOFStr = "EOF" +const DefaultAzureAPIVersion = "2023-05-15" type OpenAICmdInfoPacketOutputType struct { Model string `json:"model,omitempty"` @@ -52,16 +56,6 @@ type OpenAICloudReqPacketType struct { MaxChoices int `json:"maxchoices,omitempty"` } -type OpenAIOptsType struct { - Model string `json:"model"` - APIToken string `json:"apitoken"` - BaseURL string `json:"baseurl,omitempty"` - MaxTokens int `json:"maxtokens,omitempty"` - MaxChoices int `json:"maxchoices,omitempty"` - Timeout int `json:"timeout,omitempty"` - BlockId string `json:"blockid"` -} - func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { return &OpenAICloudReqPacketType{ Type: OpenAICloudReqStr, @@ -69,26 +63,31 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { } func GetWSEndpoint() string { - return PCloudWSEndpoint if !wavebase.IsDevMode() { - return PCloudWSEndpoint + return WCloudWSEndpoint } else { - endpoint := os.Getenv(PCloudWSEndpointVarName) + endpoint := os.Getenv(WCloudWSEndpointVarName) if endpoint == "" { - panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid") + panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid") } return endpoint } } -const DefaultMaxTokens = 1000 +const DefaultMaxTokens = 2048 const DefaultModel = "gpt-4o-mini" -const DefaultStreamChanSize = 10 -const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/" -const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT" +const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/" +const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT" const CloudWebsocketConnectTimeout = 1 * time.Minute +func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool { + if opts == nil { + return true + } + return opts.BaseURL == "" && opts.APIToken == "" +} + func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType { if resp.Usage.TotalTokens == 0 { return nil @@ -113,6 +112,15 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err} } +func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { + if IsCloudAIRequest(request.Opts) { + log.Print("sending ai chat message to default waveterm cloud endpoint\n") + return RunCloudCompletionStream(ctx, request) + } + log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL) + return RunLocalCompletionStream(ctx, request) +} + func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) wsEndpoint := wcloud.GetWSEndpoint() @@ -187,6 +195,36 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe return rtn } +// copied from go-openai/config.go +func defaultAzureMapperFn(model string) string { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") +} + +func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error { + ourApiType := strings.ToLower(opts.APIType) + if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) { + clientConfig.APIType = openai.APITypeOpenAI + return nil + } else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) { + clientConfig.APIType = openai.APITypeAzure + clientConfig.APIVersion = DefaultAzureAPIVersion + clientConfig.AzureModelMapperFunc = defaultAzureMapperFn + return nil + } else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) { + clientConfig.APIType = openai.APITypeAzureAD + clientConfig.APIVersion = DefaultAzureAPIVersion + clientConfig.AzureModelMapperFunc = defaultAzureMapperFn + return nil + } else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) { + clientConfig.APIType = openai.APITypeCloudflareAzure + clientConfig.APIVersion = DefaultAzureAPIVersion + clientConfig.AzureModelMapperFunc = defaultAzureMapperFn + return nil + } else { + return fmt.Errorf("invalid api type %q", opts.APIType) + } +} + func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) go func() { @@ -207,6 +245,17 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe if request.Opts.BaseURL != "" { clientConfig.BaseURL = request.Opts.BaseURL } + err := setApiType(request.Opts, &clientConfig) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err} + return + } + if request.Opts.OrgID != "" { + clientConfig.OrgID = request.Opts.OrgID + } + if request.Opts.APIVersion != "" { + clientConfig.APIVersion = request.Opts.APIVersion + } client := openaiapi.NewClientWithConfig(clientConfig) req := openaiapi.ChatCompletionRequest{ Model: request.Opts.Model, @@ -251,33 +300,3 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe }() return rtn } - -func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType { - var rtn []*wshrpc.OpenAIPacketType - headerPk := MakeOpenAIPacket() - headerPk.Model = resp.Model - headerPk.Created = resp.Created - headerPk.Usage = convertUsage(resp) - rtn = append(rtn, headerPk) - for _, choice := range resp.Choices { - choicePk := MakeOpenAIPacket() - choicePk.Index = choice.Index - choicePk.Text = choice.Message.Content - choicePk.FinishReason = string(choice.FinishReason) - rtn = append(rtn, choicePk) - } - return rtn -} - -func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType { - errPk := MakeOpenAIPacket() - errPk.FinishReason = "error" - errPk.Error = errStr - return errPk -} - -func CreateTextPacket(text string) *wshrpc.OpenAIPacketType { - pk := MakeOpenAIPacket() - pk.Text = text - return pk -} diff --git a/pkg/wconfig/defaultconfig/settings.json b/pkg/wconfig/defaultconfig/settings.json index c23d78731..3cc4ca932 100644 --- a/pkg/wconfig/defaultconfig/settings.json +++ b/pkg/wconfig/defaultconfig/settings.json @@ -1,7 +1,7 @@ { "ai:model": "gpt-4o-mini", - "ai:maxtokens": 1000, - "ai:timeoutms": 10000, + "ai:maxtokens": 2048, + "ai:timeoutms": 60000, "autoupdate:enabled": true, "autoupdate:installonquit": true, "autoupdate:intervalms": 3600000, diff --git a/pkg/wconfig/metaconsts.go b/pkg/wconfig/metaconsts.go index 02a65f526..efd577f12 100644 --- a/pkg/wconfig/metaconsts.go +++ b/pkg/wconfig/metaconsts.go @@ -7,10 +7,13 @@ package wconfig const ( ConfigKey_AiClear = "ai:*" + ConfigKey_AiApiType = "ai:apitype" ConfigKey_AiBaseURL = "ai:baseurl" ConfigKey_AiApiToken = "ai:apitoken" ConfigKey_AiName = "ai:name" ConfigKey_AiModel = "ai:model" + ConfigKey_AiOrgID = "ai:orgid" + ConfigKey_AIApiVersion = "ai:apiversion" ConfigKey_AiMaxTokens = "ai:maxtokens" ConfigKey_AiTimeoutMs = "ai:timeoutms" diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index 3091425de..a856bd753 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -40,13 +40,16 @@ func (m MetaSettingsType) MarshalJSON() ([]byte, error) { } type SettingsType struct { - AiClear bool `json:"ai:*,omitempty"` - AiBaseURL string `json:"ai:baseurl,omitempty"` - AiApiToken string `json:"ai:apitoken,omitempty"` - AiName string `json:"ai:name,omitempty"` - AiModel string `json:"ai:model,omitempty"` - AiMaxTokens float64 `json:"ai:maxtokens,omitempty"` - AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"` + AiClear bool `json:"ai:*,omitempty"` + AiApiType string `json:"ai:apitype,omitempty"` + AiBaseURL string `json:"ai:baseurl,omitempty"` + AiApiToken string `json:"ai:apitoken,omitempty"` + AiName string `json:"ai:name,omitempty"` + AiModel string `json:"ai:model,omitempty"` + AiOrgID string `json:"ai:orgid,omitempty"` + AIApiVersion string `json:"ai:apiversion,omitempty"` + AiMaxTokens float64 `json:"ai:maxtokens,omitempty"` + AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"` TermClear bool `json:"term:*,omitempty"` TermFontSize float64 `json:"term:fontsize,omitempty"` diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 3c24e1d2b..fbd06a27d 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -275,11 +275,14 @@ type OpenAIPromptMessageType struct { type OpenAIOptsType struct { Model string `json:"model"` + APIType string `json:"apitype,omitempty"` APIToken string `json:"apitoken"` + OrgID string `json:"orgid,omitempty"` + APIVersion string `json:"apiversion,omitempty"` BaseURL string `json:"baseurl,omitempty"` MaxTokens int `json:"maxtokens,omitempty"` MaxChoices int `json:"maxchoices,omitempty"` - Timeout int `json:"timeout,omitempty"` + TimeoutMs int `json:"timeoutms,omitempty"` } type OpenAIPacketType struct { diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 6cac789cd..baed1bd3e 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -74,12 +74,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr } func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { - log.Print("sending ai chat message to waveterm default endpoint with openai\n") - return waveai.RunCloudCompletionStream(ctx, request) - } - log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL) - return waveai.RunLocalCompletionStream(ctx, request) + return waveai.RunAICommand(ctx, request) } func MakePlotData(ctx context.Context, blockId string) error {