mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
azure ai support (#997)
This commit is contained in:
parent
b81ab63ddc
commit
f33028af1d
@ -194,12 +194,15 @@ export class WaveAiModel implements ViewModel {
|
|||||||
};
|
};
|
||||||
addMessage(newMessage);
|
addMessage(newMessage);
|
||||||
// send message to backend and get response
|
// send message to backend and get response
|
||||||
const settings = globalStore.get(atoms.settingsAtom);
|
const settings = globalStore.get(atoms.settingsAtom) ?? {};
|
||||||
const opts: OpenAIOptsType = {
|
const opts: OpenAIOptsType = {
|
||||||
model: settings["ai:model"],
|
model: settings["ai:model"],
|
||||||
|
apitype: settings["ai:apitype"],
|
||||||
|
orgid: settings["ai:orgid"],
|
||||||
apitoken: settings["ai:apitoken"],
|
apitoken: settings["ai:apitoken"],
|
||||||
|
apiversion: settings["ai:apiversion"],
|
||||||
maxtokens: settings["ai:maxtokens"],
|
maxtokens: settings["ai:maxtokens"],
|
||||||
timeout: settings["ai:timeoutms"] / 1000,
|
timeoutms: settings["ai:timeoutms"] ?? 60000,
|
||||||
baseurl: settings["ai:baseurl"],
|
baseurl: settings["ai:baseurl"],
|
||||||
};
|
};
|
||||||
const newPrompt: OpenAIPromptMessageType = {
|
const newPrompt: OpenAIPromptMessageType = {
|
||||||
@ -222,7 +225,7 @@ export class WaveAiModel implements ViewModel {
|
|||||||
opts: opts,
|
opts: opts,
|
||||||
prompt: [...history, newPrompt],
|
prompt: [...history, newPrompt],
|
||||||
};
|
};
|
||||||
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: 60000 });
|
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
|
||||||
let fullMsg = "";
|
let fullMsg = "";
|
||||||
for await (const msg of aiGen) {
|
for await (const msg of aiGen) {
|
||||||
fullMsg += msg.text ?? "";
|
fullMsg += msg.text ?? "";
|
||||||
|
8
frontend/types/gotypes.d.ts
vendored
8
frontend/types/gotypes.d.ts
vendored
@ -329,11 +329,14 @@ declare global {
|
|||||||
// wshrpc.OpenAIOptsType
|
// wshrpc.OpenAIOptsType
|
||||||
type OpenAIOptsType = {
|
type OpenAIOptsType = {
|
||||||
model: string;
|
model: string;
|
||||||
|
apitype?: string;
|
||||||
apitoken: string;
|
apitoken: string;
|
||||||
|
orgid?: string;
|
||||||
|
apiversion?: string;
|
||||||
baseurl?: string;
|
baseurl?: string;
|
||||||
maxtokens?: number;
|
maxtokens?: number;
|
||||||
maxchoices?: number;
|
maxchoices?: number;
|
||||||
timeout?: number;
|
timeoutms?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
// wshrpc.OpenAIPacketType
|
// wshrpc.OpenAIPacketType
|
||||||
@ -413,10 +416,13 @@ declare global {
|
|||||||
// wconfig.SettingsType
|
// wconfig.SettingsType
|
||||||
type SettingsType = {
|
type SettingsType = {
|
||||||
"ai:*"?: boolean;
|
"ai:*"?: boolean;
|
||||||
|
"ai:apitype"?: string;
|
||||||
"ai:baseurl"?: string;
|
"ai:baseurl"?: string;
|
||||||
"ai:apitoken"?: string;
|
"ai:apitoken"?: string;
|
||||||
"ai:name"?: string;
|
"ai:name"?: string;
|
||||||
"ai:model"?: string;
|
"ai:model"?: string;
|
||||||
|
"ai:orgid"?: string;
|
||||||
|
"ai:apiversion"?: string;
|
||||||
"ai:maxtokens"?: number;
|
"ai:maxtokens"?: number;
|
||||||
"ai:timeoutms"?: number;
|
"ai:timeoutms"?: number;
|
||||||
"term:*"?: boolean;
|
"term:*"?: boolean;
|
||||||
|
@ -10,8 +10,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
openaiapi "github.com/sashabaranov/go-openai"
|
openaiapi "github.com/sashabaranov/go-openai"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
||||||
@ -23,6 +26,7 @@ import (
|
|||||||
const OpenAIPacketStr = "openai"
|
const OpenAIPacketStr = "openai"
|
||||||
const OpenAICloudReqStr = "openai-cloudreq"
|
const OpenAICloudReqStr = "openai-cloudreq"
|
||||||
const PacketEOFStr = "EOF"
|
const PacketEOFStr = "EOF"
|
||||||
|
const DefaultAzureAPIVersion = "2023-05-15"
|
||||||
|
|
||||||
type OpenAICmdInfoPacketOutputType struct {
|
type OpenAICmdInfoPacketOutputType struct {
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
@ -52,16 +56,6 @@ type OpenAICloudReqPacketType struct {
|
|||||||
MaxChoices int `json:"maxchoices,omitempty"`
|
MaxChoices int `json:"maxchoices,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIOptsType struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
APIToken string `json:"apitoken"`
|
|
||||||
BaseURL string `json:"baseurl,omitempty"`
|
|
||||||
MaxTokens int `json:"maxtokens,omitempty"`
|
|
||||||
MaxChoices int `json:"maxchoices,omitempty"`
|
|
||||||
Timeout int `json:"timeout,omitempty"`
|
|
||||||
BlockId string `json:"blockid"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
||||||
return &OpenAICloudReqPacketType{
|
return &OpenAICloudReqPacketType{
|
||||||
Type: OpenAICloudReqStr,
|
Type: OpenAICloudReqStr,
|
||||||
@ -69,26 +63,31 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetWSEndpoint() string {
|
func GetWSEndpoint() string {
|
||||||
return PCloudWSEndpoint
|
|
||||||
if !wavebase.IsDevMode() {
|
if !wavebase.IsDevMode() {
|
||||||
return PCloudWSEndpoint
|
return WCloudWSEndpoint
|
||||||
} else {
|
} else {
|
||||||
endpoint := os.Getenv(PCloudWSEndpointVarName)
|
endpoint := os.Getenv(WCloudWSEndpointVarName)
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
|
panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
|
||||||
}
|
}
|
||||||
return endpoint
|
return endpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const DefaultMaxTokens = 1000
|
const DefaultMaxTokens = 2048
|
||||||
const DefaultModel = "gpt-4o-mini"
|
const DefaultModel = "gpt-4o-mini"
|
||||||
const DefaultStreamChanSize = 10
|
const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
|
||||||
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
|
const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"
|
||||||
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
|
|
||||||
|
|
||||||
const CloudWebsocketConnectTimeout = 1 * time.Minute
|
const CloudWebsocketConnectTimeout = 1 * time.Minute
|
||||||
|
|
||||||
|
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
|
||||||
|
if opts == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return opts.BaseURL == "" && opts.APIToken == ""
|
||||||
|
}
|
||||||
|
|
||||||
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
|
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
|
||||||
if resp.Usage.TotalTokens == 0 {
|
if resp.Usage.TotalTokens == 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -113,6 +112,15 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
|||||||
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||||
|
if IsCloudAIRequest(request.Opts) {
|
||||||
|
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
|
||||||
|
return RunCloudCompletionStream(ctx, request)
|
||||||
|
}
|
||||||
|
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
|
||||||
|
return RunLocalCompletionStream(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||||
wsEndpoint := wcloud.GetWSEndpoint()
|
wsEndpoint := wcloud.GetWSEndpoint()
|
||||||
@ -187,6 +195,36 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
|||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copied from go-openai/config.go
|
||||||
|
func defaultAzureMapperFn(model string) string {
|
||||||
|
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
|
||||||
|
ourApiType := strings.ToLower(opts.APIType)
|
||||||
|
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
|
||||||
|
clientConfig.APIType = openai.APITypeOpenAI
|
||||||
|
return nil
|
||||||
|
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
|
||||||
|
clientConfig.APIType = openai.APITypeAzure
|
||||||
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||||
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||||
|
return nil
|
||||||
|
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
|
||||||
|
clientConfig.APIType = openai.APITypeAzureAD
|
||||||
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||||
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||||
|
return nil
|
||||||
|
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
|
||||||
|
clientConfig.APIType = openai.APITypeCloudflareAzure
|
||||||
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||||
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("invalid api type %q", opts.APIType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||||
go func() {
|
go func() {
|
||||||
@ -207,6 +245,17 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
|||||||
if request.Opts.BaseURL != "" {
|
if request.Opts.BaseURL != "" {
|
||||||
clientConfig.BaseURL = request.Opts.BaseURL
|
clientConfig.BaseURL = request.Opts.BaseURL
|
||||||
}
|
}
|
||||||
|
err := setApiType(request.Opts, &clientConfig)
|
||||||
|
if err != nil {
|
||||||
|
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.Opts.OrgID != "" {
|
||||||
|
clientConfig.OrgID = request.Opts.OrgID
|
||||||
|
}
|
||||||
|
if request.Opts.APIVersion != "" {
|
||||||
|
clientConfig.APIVersion = request.Opts.APIVersion
|
||||||
|
}
|
||||||
client := openaiapi.NewClientWithConfig(clientConfig)
|
client := openaiapi.NewClientWithConfig(clientConfig)
|
||||||
req := openaiapi.ChatCompletionRequest{
|
req := openaiapi.ChatCompletionRequest{
|
||||||
Model: request.Opts.Model,
|
Model: request.Opts.Model,
|
||||||
@ -251,33 +300,3 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
|||||||
}()
|
}()
|
||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType {
|
|
||||||
var rtn []*wshrpc.OpenAIPacketType
|
|
||||||
headerPk := MakeOpenAIPacket()
|
|
||||||
headerPk.Model = resp.Model
|
|
||||||
headerPk.Created = resp.Created
|
|
||||||
headerPk.Usage = convertUsage(resp)
|
|
||||||
rtn = append(rtn, headerPk)
|
|
||||||
for _, choice := range resp.Choices {
|
|
||||||
choicePk := MakeOpenAIPacket()
|
|
||||||
choicePk.Index = choice.Index
|
|
||||||
choicePk.Text = choice.Message.Content
|
|
||||||
choicePk.FinishReason = string(choice.FinishReason)
|
|
||||||
rtn = append(rtn, choicePk)
|
|
||||||
}
|
|
||||||
return rtn
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType {
|
|
||||||
errPk := MakeOpenAIPacket()
|
|
||||||
errPk.FinishReason = "error"
|
|
||||||
errPk.Error = errStr
|
|
||||||
return errPk
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateTextPacket(text string) *wshrpc.OpenAIPacketType {
|
|
||||||
pk := MakeOpenAIPacket()
|
|
||||||
pk.Text = text
|
|
||||||
return pk
|
|
||||||
}
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"ai:model": "gpt-4o-mini",
|
"ai:model": "gpt-4o-mini",
|
||||||
"ai:maxtokens": 1000,
|
"ai:maxtokens": 2048,
|
||||||
"ai:timeoutms": 10000,
|
"ai:timeoutms": 60000,
|
||||||
"autoupdate:enabled": true,
|
"autoupdate:enabled": true,
|
||||||
"autoupdate:installonquit": true,
|
"autoupdate:installonquit": true,
|
||||||
"autoupdate:intervalms": 3600000,
|
"autoupdate:intervalms": 3600000,
|
||||||
|
@ -7,10 +7,13 @@ package wconfig
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
ConfigKey_AiClear = "ai:*"
|
ConfigKey_AiClear = "ai:*"
|
||||||
|
ConfigKey_AiApiType = "ai:apitype"
|
||||||
ConfigKey_AiBaseURL = "ai:baseurl"
|
ConfigKey_AiBaseURL = "ai:baseurl"
|
||||||
ConfigKey_AiApiToken = "ai:apitoken"
|
ConfigKey_AiApiToken = "ai:apitoken"
|
||||||
ConfigKey_AiName = "ai:name"
|
ConfigKey_AiName = "ai:name"
|
||||||
ConfigKey_AiModel = "ai:model"
|
ConfigKey_AiModel = "ai:model"
|
||||||
|
ConfigKey_AiOrgID = "ai:orgid"
|
||||||
|
ConfigKey_AIApiVersion = "ai:apiversion"
|
||||||
ConfigKey_AiMaxTokens = "ai:maxtokens"
|
ConfigKey_AiMaxTokens = "ai:maxtokens"
|
||||||
ConfigKey_AiTimeoutMs = "ai:timeoutms"
|
ConfigKey_AiTimeoutMs = "ai:timeoutms"
|
||||||
|
|
||||||
|
@ -40,13 +40,16 @@ func (m MetaSettingsType) MarshalJSON() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SettingsType struct {
|
type SettingsType struct {
|
||||||
AiClear bool `json:"ai:*,omitempty"`
|
AiClear bool `json:"ai:*,omitempty"`
|
||||||
AiBaseURL string `json:"ai:baseurl,omitempty"`
|
AiApiType string `json:"ai:apitype,omitempty"`
|
||||||
AiApiToken string `json:"ai:apitoken,omitempty"`
|
AiBaseURL string `json:"ai:baseurl,omitempty"`
|
||||||
AiName string `json:"ai:name,omitempty"`
|
AiApiToken string `json:"ai:apitoken,omitempty"`
|
||||||
AiModel string `json:"ai:model,omitempty"`
|
AiName string `json:"ai:name,omitempty"`
|
||||||
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
|
AiModel string `json:"ai:model,omitempty"`
|
||||||
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`
|
AiOrgID string `json:"ai:orgid,omitempty"`
|
||||||
|
AIApiVersion string `json:"ai:apiversion,omitempty"`
|
||||||
|
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
|
||||||
|
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`
|
||||||
|
|
||||||
TermClear bool `json:"term:*,omitempty"`
|
TermClear bool `json:"term:*,omitempty"`
|
||||||
TermFontSize float64 `json:"term:fontsize,omitempty"`
|
TermFontSize float64 `json:"term:fontsize,omitempty"`
|
||||||
|
@ -275,11 +275,14 @@ type OpenAIPromptMessageType struct {
|
|||||||
|
|
||||||
type OpenAIOptsType struct {
|
type OpenAIOptsType struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
APIType string `json:"apitype,omitempty"`
|
||||||
APIToken string `json:"apitoken"`
|
APIToken string `json:"apitoken"`
|
||||||
|
OrgID string `json:"orgid,omitempty"`
|
||||||
|
APIVersion string `json:"apiversion,omitempty"`
|
||||||
BaseURL string `json:"baseurl,omitempty"`
|
BaseURL string `json:"baseurl,omitempty"`
|
||||||
MaxTokens int `json:"maxtokens,omitempty"`
|
MaxTokens int `json:"maxtokens,omitempty"`
|
||||||
MaxChoices int `json:"maxchoices,omitempty"`
|
MaxChoices int `json:"maxchoices,omitempty"`
|
||||||
Timeout int `json:"timeout,omitempty"`
|
TimeoutMs int `json:"timeoutms,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIPacketType struct {
|
type OpenAIPacketType struct {
|
||||||
|
@ -74,12 +74,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||||
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
return waveai.RunAICommand(ctx, request)
|
||||||
log.Print("sending ai chat message to waveterm default endpoint with openai\n")
|
|
||||||
return waveai.RunCloudCompletionStream(ctx, request)
|
|
||||||
}
|
|
||||||
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
|
|
||||||
return waveai.RunLocalCompletionStream(ctx, request)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakePlotData(ctx context.Context, blockId string) error {
|
func MakePlotData(ctx context.Context, blockId string) error {
|
||||||
|
Loading…
Reference in New Issue
Block a user