From 93b9432f9907c850477da0c9e792f845de71c7d5 Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Fri, 27 Dec 2024 14:52:38 -0500 Subject: [PATCH] Add support for Google AI models (#1612) --- docs/docs/ai-presets.mdx | 17 ++++++ frontend/app/view/waveai/waveai.tsx | 2 +- go.mod | 26 +++++++++ go.sum | 53 +++++++++++++++++ pkg/waveai/cloudbackend.go | 5 ++ pkg/waveai/googlebackend.go | 91 +++++++++++++++++++++++++++++ pkg/waveai/openaibackend.go | 2 + pkg/waveai/waveai.go | 66 +++++++++------------ 8 files changed, 224 insertions(+), 38 deletions(-) create mode 100644 pkg/waveai/googlebackend.go diff --git a/docs/docs/ai-presets.mdx b/docs/docs/ai-presets.mdx index e822553b7..de117a86c 100644 --- a/docs/docs/ai-presets.mdx +++ b/docs/docs/ai-presets.mdx @@ -127,6 +127,23 @@ To use Perplexity's models: } ``` +### Google (Gemini) + +To use Google's Gemini models from [Google AI Studio](https://aistudio.google.com): + +```json +{ + "ai@gemini-2.0": { + "display:name": "Gemini 2.0", + "display:order": 5, + "ai:*": true, + "ai:apitype": "google", + "ai:model": "gemini-2.0-flash-exp", + "ai:apitoken": "" + } +} +``` + ## Multiple Presets Example You can define multiple presets in your `ai.json` file: diff --git a/frontend/app/view/waveai/waveai.tsx b/frontend/app/view/waveai/waveai.tsx index 6dbd3a03b..d20a2c920 100644 --- a/frontend/app/view/waveai/waveai.tsx +++ b/frontend/app/view/waveai/waveai.tsx @@ -347,7 +347,7 @@ export class WaveAiModel implements ViewModel { // Add a typing indicator globalStore.set(this.addMessageAtom, typingMessage); const history = await this.fetchAiData(); - const beMsg: OpenAiStreamRequest = { + const beMsg: WaveAIStreamRequest = { clientid: clientId, opts: opts, prompt: [...history, newPrompt], diff --git a/go.mod b/go.mod index 5ffce7682..de542ae69 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/fsnotify/fsnotify v1.8.0 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-migrate/migrate/v4 v4.18.1 + github.com/google/generative-ai-go v0.19.0 github.com/google/uuid v1.6.0 github.com/gorilla/handlers v1.5.2 github.com/gorilla/mux v1.8.1 @@ -27,12 +28,24 @@ require ( golang.org/x/crypto v0.31.0 golang.org/x/sys v0.28.0 golang.org/x/term v0.27.0 + google.golang.org/api v0.214.0 ) require ( + cloud.google.com/go v0.115.0 // indirect + cloud.google.com/go/ai v0.8.0 // indirect + cloud.google.com/go/auth v0.13.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go/longrunning v0.5.7 // indirect github.com/ebitengine/purego v0.8.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/google/s2a-go v0.1.8 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/googleapis/gax-go/v2 v2.14.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -44,8 +57,21 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect + go.opentelemetry.io/otel v1.29.0 // indirect + go.opentelemetry.io/otel/metric v1.29.0 // indirect + go.opentelemetry.io/otel/trace v1.29.0 // indirect go.uber.org/atomic v1.7.0 // indirect golang.org/x/net v0.33.0 // indirect + golang.org/x/oauth2 v0.24.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/text v0.21.0 // indirect + golang.org/x/time v0.8.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect + google.golang.org/grpc v1.67.1 // indirect + google.golang.org/protobuf v1.35.2 // indirect ) replace github.com/kevinburke/ssh_config => github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34 diff --git a/go.sum b/go.sum index 8490d1292..7e137603e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,15 @@ +cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14= +cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU= +cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w= +cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE= +cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs= +cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q= +cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU= +cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= +cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg= @@ -14,6 +26,11 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= @@ -22,11 +39,19 @@ github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17w github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y= github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= +github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= +github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= +github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o= +github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= @@ -94,12 +119,26 @@ github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34 h1:I8VZVTZE github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -111,7 +150,21 @@ golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.214.0 h1:h2Gkq07OYi6kusGOaT/9rnNljuXmqPnaig7WGPmKbwA= +google.golang.org/api v0.214.0/go.mod h1:bYPpLG8AyeMWwDU6NXoB00xC0DFkikVvd5MfwoxjLqE= +google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g= +google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= +google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= +google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/waveai/cloudbackend.go b/pkg/waveai/cloudbackend.go index 710730590..fc5245f79 100644 --- a/pkg/waveai/cloudbackend.go +++ b/pkg/waveai/cloudbackend.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log" + "time" "github.com/gorilla/websocket" "github.com/wavetermdev/waveterm/pkg/panichandler" @@ -20,6 +21,10 @@ type WaveAICloudBackend struct{} var _ AIBackend = WaveAICloudBackend{} +const CloudWebsocketConnectTimeout = 1 * time.Minute +const OpenAICloudReqStr = "openai-cloudreq" +const PacketEOFStr = "EOF" + type WaveAICloudReqPacketType struct { Type string `json:"type"` ClientId string `json:"clientid"` diff --git a/pkg/waveai/googlebackend.go b/pkg/waveai/googlebackend.go new file mode 100644 index 000000000..7f19237c1 --- /dev/null +++ b/pkg/waveai/googlebackend.go @@ -0,0 +1,91 @@ +package waveai + +import ( + "context" + "fmt" + "log" + + "github.com/google/generative-ai-go/genai" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "google.golang.org/api/iterator" + "google.golang.org/api/option" +) + +type GoogleBackend struct{} + +var _ AIBackend = GoogleBackend{} + +func (GoogleBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + client, err := genai.NewClient(ctx, option.WithAPIKey(request.Opts.APIToken)) + if err != nil { + log.Fatalf("failed to create client: %v", err) + return nil + } + + model := client.GenerativeModel(request.Opts.Model) + if model == nil { + log.Fatal("model not found") + client.Close() + return nil + } + + cs := model.StartChat() + cs.History = extractHistory(request.Prompt) + iter := cs.SendMessageStream(ctx, extractPrompt(request.Prompt)) + + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]) + + go func() { + defer client.Close() + defer close(rtn) + for { + // Check for context cancellation + select { + case <-ctx.Done(): + rtn <- makeAIError(fmt.Errorf("request cancelled: %v", ctx.Err())) + break + default: + } + + resp, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + rtn <- makeAIError(fmt.Errorf("Google API error: %v", err)) + break + } + + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: wshrpc.WaveAIPacketType{Text: convertCandidatesToText(resp.Candidates)}} + } + }() + return rtn +} + +func extractHistory(history []wshrpc.WaveAIPromptMessageType) []*genai.Content { + var rtn []*genai.Content + for _, h := range history[:len(history)-1] { + if h.Role == "user" || h.Role == "model" { + rtn = append(rtn, &genai.Content{ + Role: h.Role, + Parts: []genai.Part{genai.Text(h.Content)}, + }) + } + } + return rtn +} + +func extractPrompt(prompt []wshrpc.WaveAIPromptMessageType) genai.Part { + p := prompt[len(prompt)-1] + return genai.Text(p.Content) +} + +func convertCandidatesToText(candidates []*genai.Candidate) string { + var rtn string + for _, c := range candidates { + for _, p := range c.Content.Parts { + rtn += fmt.Sprintf("%v", p) + } + } + return rtn +} diff --git a/pkg/waveai/openaibackend.go b/pkg/waveai/openaibackend.go index a334fb523..a33bf9f47 100644 --- a/pkg/waveai/openaibackend.go +++ b/pkg/waveai/openaibackend.go @@ -20,6 +20,8 @@ type OpenAIBackend struct{} var _ AIBackend = OpenAIBackend{} +const DefaultAzureAPIVersion = "2023-05-15" + // copied from go-openai/config.go func defaultAzureMapperFn(model string) string { return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go index 4ffa56f96..48847d3af 100644 --- a/pkg/waveai/waveai.go +++ b/pkg/waveai/waveai.go @@ -6,18 +6,16 @@ package waveai import ( "context" "log" - "time" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) -const OpenAIPacketStr = "openai" -const OpenAICloudReqStr = "openai-cloudreq" -const PacketEOFStr = "EOF" -const DefaultAzureAPIVersion = "2023-05-15" +const WaveAIPacketstr = "waveai" const ApiType_Anthropic = "anthropic" const ApiType_Perplexity = "perplexity" +const APIType_Google = "google" +const APIType_OpenAI = "openai" type WaveAICmdInfoPacketOutputType struct { Model string `json:"model,omitempty"` @@ -28,7 +26,7 @@ type WaveAICmdInfoPacketOutputType struct { } func MakeWaveAIPacket() *wshrpc.WaveAIPacketType { - return &wshrpc.WaveAIPacketType{Type: OpenAIPacketStr} + return &wshrpc.WaveAIPacketType{Type: WaveAIPacketstr} } type WaveAICmdInfoChatMessage struct { @@ -46,13 +44,6 @@ type AIBackend interface { ) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] } -const DefaultMaxTokens = 2048 -const DefaultModel = "gpt-4o-mini" -const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/" -const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT" - -const CloudWebsocketConnectTimeout = 1 * time.Minute - func IsCloudAIRequest(opts *wshrpc.WaveAIOptsType) bool { if opts == nil { return true @@ -66,31 +57,32 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { func RunAICommand(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{NumAIReqs: 1}, "RunAICommand") + + endpoint := request.Opts.BaseURL + if endpoint == "" { + endpoint = "default" + } + var backend AIBackend if request.Opts.APIType == ApiType_Anthropic { - endpoint := request.Opts.BaseURL - if endpoint == "" { - endpoint = "default" - } - log.Printf("sending ai chat message to anthropic endpoint %q using model %s\n", endpoint, request.Opts.Model) - anthropicBackend := AnthropicBackend{} - return anthropicBackend.StreamCompletion(ctx, request) - } - if request.Opts.APIType == ApiType_Perplexity { - endpoint := request.Opts.BaseURL - if endpoint == "" { - endpoint = "default" - } - log.Printf("sending ai chat message to perplexity endpoint %q using model %s\n", endpoint, request.Opts.Model) - perplexityBackend := PerplexityBackend{} - return perplexityBackend.StreamCompletion(ctx, request) - } - if IsCloudAIRequest(request.Opts) { - log.Print("sending ai chat message to default waveterm cloud endpoint\n") - cloudBackend := WaveAICloudBackend{} - return cloudBackend.StreamCompletion(ctx, request) + backend = AnthropicBackend{} + } else if request.Opts.APIType == ApiType_Perplexity { + backend = PerplexityBackend{} + } else if request.Opts.APIType == APIType_Google { + backend = GoogleBackend{} + } else if IsCloudAIRequest(request.Opts) { + endpoint = "waveterm cloud" + request.Opts.APIType = APIType_OpenAI + request.Opts.Model = "default" + backend = WaveAICloudBackend{} } else { - log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model) - openAIBackend := OpenAIBackend{} - return openAIBackend.StreamCompletion(ctx, request) + request.Opts.APIType = APIType_OpenAI + backend = OpenAIBackend{} } + if backend == nil { + log.Printf("no backend found for %s\n", request.Opts.APIType) + return nil + } + + log.Printf("sending ai chat message to %s endpoint %q using model %s\n", request.Opts.APIType, endpoint, request.Opts.Model) + return backend.StreamCompletion(ctx, request) }