diff --git a/go.mod b/go.mod index 8ae012fe5..e0c9d07f3 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/alexflint/go-filemutex v1.3.0 github.com/creack/pty v1.1.21 github.com/fsnotify/fsnotify v1.7.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/uuid v1.6.0 github.com/gorilla/handlers v1.5.2 diff --git a/go.sum b/go.sum index e4a82752c..fffe8b001 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4= github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index f2e779e1b..58e8e386d 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -27,6 +27,7 @@ const WaveHomeVarName = "WAVETERM_HOME" const WaveDevVarName = "WAVETERM_DEV" const WaveLockFile = "waveterm.lock" const DomainSocketBaseName = "wave.sock" +const JwtSecret = "waveterm" // TODO generate and store this var baseLock = &sync.Mutex{} var ensureDirCache = map[string]bool{} @@ -175,5 +176,4 @@ func AcquireWaveLock() (*filemutex.FileMutex, error) { err = m.TryLock() return m, err - } diff --git a/pkg/wshrpc/wshserver/wshserverutil.go b/pkg/wshrpc/wshserver/wshserverutil.go index ece0324ef..461c1a33b 100644 --- a/pkg/wshrpc/wshserver/wshserverutil.go +++ b/pkg/wshrpc/wshserver/wshserverutil.go @@ -10,7 +10,9 @@ import ( "net" "os" "reflect" + "time" + "github.com/golang-jwt/jwt/v5" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wshrpc" @@ -177,6 +179,82 @@ func RunWshRpcOverListener(listener net.Listener) { }() } +func MakeClientJWTToken(rpcCtx wshutil.RpcContext, sockName string) (string, error) { + claims := jwt.MapClaims{} + claims["iat"] = time.Now().Unix() + claims["iss"] = "waveterm" + claims["sock"] = sockName + claims["exp"] = time.Now().Add(time.Hour * 24 * 365).Unix() + if rpcCtx.BlockId != "" { + claims["blockid"] = rpcCtx.BlockId + } + if rpcCtx.TabId != "" { + claims["tabid"] = rpcCtx.TabId + } + if rpcCtx.WindowId != "" { + claims["windowid"] = rpcCtx.WindowId + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) + if err != nil { + return "", fmt.Errorf("error signing token: %w", err) + } + return tokenStr, nil +} + +func ValidateAndExtractRpcContextFromToken(tokenStr string) (wshutil.RpcContext, error) { + parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + token, err := parser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + return []byte(wavebase.JwtSecret), nil + }) + if err != nil { + return wshutil.RpcContext{}, fmt.Errorf("error parsing token: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return wshutil.RpcContext{}, fmt.Errorf("error getting claims from token") + } + // validate "exp" claim + if exp, ok := claims["exp"].(float64); ok { + if int64(exp) < time.Now().Unix() { + return wshutil.RpcContext{}, fmt.Errorf("token has expired") + } + } else { + return wshutil.RpcContext{}, fmt.Errorf("exp claim is missing or invalid") + } + // validate "iss" claim + if iss, ok := claims["iss"].(string); ok { + if iss != "waveterm" { + return wshutil.RpcContext{}, fmt.Errorf("unexpected issuer: %s", iss) + } + } else { + return wshutil.RpcContext{}, fmt.Errorf("iss claim is missing or invalid") + } + rpcCtx := wshutil.RpcContext{} + rpcCtx.BlockId = claims["blockid"].(string) + rpcCtx.TabId = claims["tabid"].(string) + rpcCtx.WindowId = claims["windowid"].(string) + return rpcCtx, nil +} + +func ExtractUnverifiedSocketName(tokenStr string) (string, error) { + // this happens on the client who does not have access to the secret key + // we want to read the claims without validating the signature + token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", fmt.Errorf("error parsing token: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return "", fmt.Errorf("error getting claims from token") + } + sockName, ok := claims["sock"].(string) + if !ok { + return "", fmt.Errorf("sock claim is missing or invalid") + } + return sockName, nil +} + func RunDomainSocketWshServer() error { sockName := wavebase.GetDomainSocketName() listener, err := MakeUnixListener(sockName)