mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-01 18:28:59 +01:00
Add support for Google AI models (#1612)
This commit is contained in:
parent
8dbb4623f9
commit
93b9432f99
@ -127,6 +127,23 @@ To use Perplexity's models:
|
||||
}
|
||||
```
|
||||
|
||||
### Google (Gemini)
|
||||
|
||||
To use Google's Gemini models from [Google AI Studio](https://aistudio.google.com):
|
||||
|
||||
```json
|
||||
{
|
||||
"ai@gemini-2.0": {
|
||||
"display:name": "Gemini 2.0",
|
||||
"display:order": 5,
|
||||
"ai:*": true,
|
||||
"ai:apitype": "google",
|
||||
"ai:model": "gemini-2.0-flash-exp",
|
||||
"ai:apitoken": "<your Google AI API key>"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Multiple Presets Example
|
||||
|
||||
You can define multiple presets in your `ai.json` file:
|
||||
|
@ -347,7 +347,7 @@ export class WaveAiModel implements ViewModel {
|
||||
// Add a typing indicator
|
||||
globalStore.set(this.addMessageAtom, typingMessage);
|
||||
const history = await this.fetchAiData();
|
||||
const beMsg: OpenAiStreamRequest = {
|
||||
const beMsg: WaveAIStreamRequest = {
|
||||
clientid: clientId,
|
||||
opts: opts,
|
||||
prompt: [...history, newPrompt],
|
||||
|
26
go.mod
26
go.mod
@ -8,6 +8,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.8.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/golang-migrate/migrate/v4 v4.18.1
|
||||
github.com/google/generative-ai-go v0.19.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/handlers v1.5.2
|
||||
github.com/gorilla/mux v1.8.1
|
||||
@ -27,12 +28,24 @@ require (
|
||||
golang.org/x/crypto v0.31.0
|
||||
golang.org/x/sys v0.28.0
|
||||
golang.org/x/term v0.27.0
|
||||
google.golang.org/api v0.214.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.115.0 // indirect
|
||||
cloud.google.com/go/ai v0.8.0 // indirect
|
||||
cloud.google.com/go/auth v0.13.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/longrunning v0.5.7 // indirect
|
||||
github.com/ebitengine/purego v0.8.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
@ -44,8 +57,21 @@ require (
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
|
||||
go.opentelemetry.io/otel v1.29.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.29.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.29.0 // indirect
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/oauth2 v0.24.0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/time v0.8.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect
|
||||
google.golang.org/grpc v1.67.1 // indirect
|
||||
google.golang.org/protobuf v1.35.2 // indirect
|
||||
)
|
||||
|
||||
replace github.com/kevinburke/ssh_config => github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34
|
||||
|
53
go.sum
53
go.sum
@ -1,3 +1,15 @@
|
||||
cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14=
|
||||
cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU=
|
||||
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
|
||||
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
|
||||
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
|
||||
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
|
||||
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
|
||||
@ -14,6 +26,11 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
@ -22,11 +39,19 @@ github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17w
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks=
|
||||
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
|
||||
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
|
||||
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||
github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o=
|
||||
github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk=
|
||||
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
||||
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
@ -94,12 +119,26 @@ github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34 h1:I8VZVTZE
|
||||
github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
|
||||
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
|
||||
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
|
||||
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
|
||||
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
|
||||
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
|
||||
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
|
||||
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
||||
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@ -111,7 +150,21 @@ golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
||||
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.214.0 h1:h2Gkq07OYi6kusGOaT/9rnNljuXmqPnaig7WGPmKbwA=
|
||||
google.golang.org/api v0.214.0/go.mod h1:bYPpLG8AyeMWwDU6NXoB00xC0DFkikVvd5MfwoxjLqE=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU=
|
||||
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
|
||||
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
|
||||
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
|
||||
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
@ -20,6 +21,10 @@ type WaveAICloudBackend struct{}
|
||||
|
||||
var _ AIBackend = WaveAICloudBackend{}
|
||||
|
||||
const CloudWebsocketConnectTimeout = 1 * time.Minute
|
||||
const OpenAICloudReqStr = "openai-cloudreq"
|
||||
const PacketEOFStr = "EOF"
|
||||
|
||||
type WaveAICloudReqPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ClientId string `json:"clientid"`
|
||||
|
91
pkg/waveai/googlebackend.go
Normal file
91
pkg/waveai/googlebackend.go
Normal file
@ -0,0 +1,91 @@
|
||||
package waveai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
type GoogleBackend struct{}
|
||||
|
||||
var _ AIBackend = GoogleBackend{}
|
||||
|
||||
func (GoogleBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
|
||||
client, err := genai.NewClient(ctx, option.WithAPIKey(request.Opts.APIToken))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create client: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
model := client.GenerativeModel(request.Opts.Model)
|
||||
if model == nil {
|
||||
log.Fatal("model not found")
|
||||
client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
cs := model.StartChat()
|
||||
cs.History = extractHistory(request.Prompt)
|
||||
iter := cs.SendMessageStream(ctx, extractPrompt(request.Prompt))
|
||||
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
|
||||
|
||||
go func() {
|
||||
defer client.Close()
|
||||
defer close(rtn)
|
||||
for {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rtn <- makeAIError(fmt.Errorf("request cancelled: %v", ctx.Err()))
|
||||
break
|
||||
default:
|
||||
}
|
||||
|
||||
resp, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
rtn <- makeAIError(fmt.Errorf("Google API error: %v", err))
|
||||
break
|
||||
}
|
||||
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: wshrpc.WaveAIPacketType{Text: convertCandidatesToText(resp.Candidates)}}
|
||||
}
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func extractHistory(history []wshrpc.WaveAIPromptMessageType) []*genai.Content {
|
||||
var rtn []*genai.Content
|
||||
for _, h := range history[:len(history)-1] {
|
||||
if h.Role == "user" || h.Role == "model" {
|
||||
rtn = append(rtn, &genai.Content{
|
||||
Role: h.Role,
|
||||
Parts: []genai.Part{genai.Text(h.Content)},
|
||||
})
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func extractPrompt(prompt []wshrpc.WaveAIPromptMessageType) genai.Part {
|
||||
p := prompt[len(prompt)-1]
|
||||
return genai.Text(p.Content)
|
||||
}
|
||||
|
||||
func convertCandidatesToText(candidates []*genai.Candidate) string {
|
||||
var rtn string
|
||||
for _, c := range candidates {
|
||||
for _, p := range c.Content.Parts {
|
||||
rtn += fmt.Sprintf("%v", p)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
@ -20,6 +20,8 @@ type OpenAIBackend struct{}
|
||||
|
||||
var _ AIBackend = OpenAIBackend{}
|
||||
|
||||
const DefaultAzureAPIVersion = "2023-05-15"
|
||||
|
||||
// copied from go-openai/config.go
|
||||
func defaultAzureMapperFn(model string) string {
|
||||
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||
|
@ -6,18 +6,16 @@ package waveai
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/telemetry"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
const OpenAIPacketStr = "openai"
|
||||
const OpenAICloudReqStr = "openai-cloudreq"
|
||||
const PacketEOFStr = "EOF"
|
||||
const DefaultAzureAPIVersion = "2023-05-15"
|
||||
const WaveAIPacketstr = "waveai"
|
||||
const ApiType_Anthropic = "anthropic"
|
||||
const ApiType_Perplexity = "perplexity"
|
||||
const APIType_Google = "google"
|
||||
const APIType_OpenAI = "openai"
|
||||
|
||||
type WaveAICmdInfoPacketOutputType struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
@ -28,7 +26,7 @@ type WaveAICmdInfoPacketOutputType struct {
|
||||
}
|
||||
|
||||
func MakeWaveAIPacket() *wshrpc.WaveAIPacketType {
|
||||
return &wshrpc.WaveAIPacketType{Type: OpenAIPacketStr}
|
||||
return &wshrpc.WaveAIPacketType{Type: WaveAIPacketstr}
|
||||
}
|
||||
|
||||
type WaveAICmdInfoChatMessage struct {
|
||||
@ -46,13 +44,6 @@ type AIBackend interface {
|
||||
) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]
|
||||
}
|
||||
|
||||
const DefaultMaxTokens = 2048
|
||||
const DefaultModel = "gpt-4o-mini"
|
||||
const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
|
||||
const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"
|
||||
|
||||
const CloudWebsocketConnectTimeout = 1 * time.Minute
|
||||
|
||||
func IsCloudAIRequest(opts *wshrpc.WaveAIOptsType) bool {
|
||||
if opts == nil {
|
||||
return true
|
||||
@ -66,31 +57,32 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
|
||||
|
||||
func RunAICommand(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
|
||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{NumAIReqs: 1}, "RunAICommand")
|
||||
|
||||
endpoint := request.Opts.BaseURL
|
||||
if endpoint == "" {
|
||||
endpoint = "default"
|
||||
}
|
||||
var backend AIBackend
|
||||
if request.Opts.APIType == ApiType_Anthropic {
|
||||
endpoint := request.Opts.BaseURL
|
||||
if endpoint == "" {
|
||||
endpoint = "default"
|
||||
}
|
||||
log.Printf("sending ai chat message to anthropic endpoint %q using model %s\n", endpoint, request.Opts.Model)
|
||||
anthropicBackend := AnthropicBackend{}
|
||||
return anthropicBackend.StreamCompletion(ctx, request)
|
||||
}
|
||||
if request.Opts.APIType == ApiType_Perplexity {
|
||||
endpoint := request.Opts.BaseURL
|
||||
if endpoint == "" {
|
||||
endpoint = "default"
|
||||
}
|
||||
log.Printf("sending ai chat message to perplexity endpoint %q using model %s\n", endpoint, request.Opts.Model)
|
||||
perplexityBackend := PerplexityBackend{}
|
||||
return perplexityBackend.StreamCompletion(ctx, request)
|
||||
}
|
||||
if IsCloudAIRequest(request.Opts) {
|
||||
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
|
||||
cloudBackend := WaveAICloudBackend{}
|
||||
return cloudBackend.StreamCompletion(ctx, request)
|
||||
backend = AnthropicBackend{}
|
||||
} else if request.Opts.APIType == ApiType_Perplexity {
|
||||
backend = PerplexityBackend{}
|
||||
} else if request.Opts.APIType == APIType_Google {
|
||||
backend = GoogleBackend{}
|
||||
} else if IsCloudAIRequest(request.Opts) {
|
||||
endpoint = "waveterm cloud"
|
||||
request.Opts.APIType = APIType_OpenAI
|
||||
request.Opts.Model = "default"
|
||||
backend = WaveAICloudBackend{}
|
||||
} else {
|
||||
log.Printf("sending ai chat message to user-configured endpoint %s using model %s\n", request.Opts.BaseURL, request.Opts.Model)
|
||||
openAIBackend := OpenAIBackend{}
|
||||
return openAIBackend.StreamCompletion(ctx, request)
|
||||
request.Opts.APIType = APIType_OpenAI
|
||||
backend = OpenAIBackend{}
|
||||
}
|
||||
if backend == nil {
|
||||
log.Printf("no backend found for %s\n", request.Opts.APIType)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("sending ai chat message to %s endpoint %q using model %s\n", request.Opts.APIType, endpoint, request.Opts.Model)
|
||||
return backend.StreamCompletion(ctx, request)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user