mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-04-01 17:57:31 +02:00
handle simple authkey authentication for local-server
This commit is contained in:
parent
21bbab88c8
commit
1697010d55
@ -28,6 +28,8 @@ import (
|
||||
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
|
||||
)
|
||||
|
||||
type WebFnType = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
const HttpReadTimeout = 5 * time.Second
|
||||
const HttpWriteTimeout = 21 * time.Second
|
||||
const HttpMaxHeaderBytes = 60000
|
||||
@ -40,6 +42,7 @@ const WSStatePacketChSize = 20
|
||||
|
||||
var GlobalLock = &sync.Mutex{}
|
||||
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
|
||||
var GlobalAuthKey string
|
||||
|
||||
func setWSState(state *scws.WSState) {
|
||||
GlobalLock.Lock()
|
||||
@ -81,7 +84,7 @@ func HandleWs(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
state := getWSState(clientId)
|
||||
if state == nil {
|
||||
state = scws.MakeWSState(clientId)
|
||||
state = scws.MakeWSState(clientId, GlobalAuthKey)
|
||||
state.ReplaceShell(shell)
|
||||
setWSState(state)
|
||||
} else {
|
||||
@ -119,10 +122,6 @@ func writeToFifo(fifoName string, data []byte) error {
|
||||
}
|
||||
|
||||
func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
cdata, err := sstore.EnsureClientData(r.Context())
|
||||
if err != nil {
|
||||
WriteJsonError(w, err)
|
||||
@ -133,15 +132,6 @@ func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if r.Method == "GET" || r.Method == "OPTIONS" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
var winSize sstore.ClientWinSizeType
|
||||
err := decoder.Decode(&winSize)
|
||||
@ -160,10 +150,6 @@ func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// params: sessionid, windowid
|
||||
func HandleGetWindow(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
qvals := r.URL.Query()
|
||||
sessionId := qvals.Get("sessionid")
|
||||
windowId := qvals.Get("windowid")
|
||||
@ -226,10 +212,6 @@ func HandleRtnState(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
qvals := r.URL.Query()
|
||||
remoteId := qvals.Get("remoteid")
|
||||
if remoteId == "" {
|
||||
@ -255,10 +237,6 @@ func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
qvals := r.URL.Query()
|
||||
sessionId := qvals.Get("sessionid")
|
||||
cmdId := qvals.Get("cmdid")
|
||||
@ -330,15 +308,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
|
||||
WriteJsonError(w, fmt.Errorf("panic: %v", r))
|
||||
return
|
||||
}()
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if r.Method == "GET" || r.Method == "OPTIONS" {
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
var commandPk scpacket.FeCommandPacketType
|
||||
err := decoder.Decode(&commandPk)
|
||||
@ -355,6 +325,24 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
func AuthKeyWrap(fn WebFnType) WebFnType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reqAuthKey := r.Header.Get("X-AuthKey")
|
||||
if reqAuthKey == "" {
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte("no x-authkey header"))
|
||||
return
|
||||
}
|
||||
if reqAuthKey != GlobalAuthKey {
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte("x-authkey header is invalid"))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
fn(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func runWebSocketServer() {
|
||||
gr := mux.NewRouter()
|
||||
gr.HandleFunc("/ws", HandleWs)
|
||||
@ -405,10 +393,9 @@ func main() {
|
||||
|
||||
scLock, err := scbase.AcquirePromptLock()
|
||||
if err != nil || scLock == nil {
|
||||
log.Printf("[error] cannot acquire sh2 lock: %v\n", err)
|
||||
log.Printf("[error] cannot acquire prompt lock: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
|
||||
err := sstore.MigrateCommandOpts(os.Args[1:])
|
||||
if err != nil {
|
||||
@ -416,6 +403,12 @@ func main() {
|
||||
}
|
||||
return
|
||||
}
|
||||
authKey, err := scbase.ReadPromptAuthKey()
|
||||
if err != nil {
|
||||
log.Printf("[error] %v\n", err)
|
||||
return
|
||||
}
|
||||
GlobalAuthKey = authKey
|
||||
err = sstore.TryMigrateUp()
|
||||
if err != nil {
|
||||
log.Printf("[error] migrate up: %v\n", err)
|
||||
@ -451,13 +444,13 @@ func main() {
|
||||
go stdinReadWatch()
|
||||
go runWebSocketServer()
|
||||
gr := mux.NewRouter()
|
||||
gr.HandleFunc("/api/ptyout", HandleGetPtyOut)
|
||||
gr.HandleFunc("/api/remote-pty", HandleRemotePty)
|
||||
gr.HandleFunc("/api/rtnstate", HandleRtnState)
|
||||
gr.HandleFunc("/api/get-window", HandleGetWindow)
|
||||
gr.HandleFunc("/api/run-command", HandleRunCommand).Methods("GET", "POST", "OPTIONS")
|
||||
gr.HandleFunc("/api/get-client-data", HandleGetClientData)
|
||||
gr.HandleFunc("/api/set-winsize", HandleSetWinSize)
|
||||
gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut))
|
||||
gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty))
|
||||
gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState))
|
||||
gr.HandleFunc("/api/get-window", AuthKeyWrap(HandleGetWindow))
|
||||
gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST")
|
||||
gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
|
||||
gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
|
||||
server := &http.Server{
|
||||
Addr: MainServerAddr,
|
||||
ReadTimeout: HttpReadTimeout,
|
||||
|
@ -77,6 +77,7 @@ type WatchScreenPacketType struct {
|
||||
SessionId string `json:"sessionid"`
|
||||
ScreenId string `json:"screenid"`
|
||||
Connect bool `json:"connect"`
|
||||
AuthKey string `json:"authkey"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -19,25 +19,40 @@ const WSStatePacketChSize = 20
|
||||
const MaxInputDataSize = 1000
|
||||
|
||||
type WSState struct {
|
||||
Lock *sync.Mutex
|
||||
ClientId string
|
||||
ConnectTime time.Time
|
||||
Shell *wsshell.WSShell
|
||||
UpdateCh chan interface{}
|
||||
UpdateQueue []interface{}
|
||||
Lock *sync.Mutex
|
||||
ClientId string
|
||||
ConnectTime time.Time
|
||||
Shell *wsshell.WSShell
|
||||
UpdateCh chan interface{}
|
||||
UpdateQueue []interface{}
|
||||
Authenticated bool
|
||||
AuthKey string
|
||||
|
||||
SessionId string
|
||||
ScreenId string
|
||||
}
|
||||
|
||||
func MakeWSState(clientId string) *WSState {
|
||||
func MakeWSState(clientId string, authKey string) *WSState {
|
||||
rtn := &WSState{}
|
||||
rtn.Lock = &sync.Mutex{}
|
||||
rtn.ClientId = clientId
|
||||
rtn.ConnectTime = time.Now()
|
||||
rtn.AuthKey = authKey
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (ws *WSState) SetAuthenticated(authVal bool) {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
ws.Authenticated = authVal
|
||||
}
|
||||
|
||||
func (ws *WSState) IsAuthenticated() bool {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
return ws.Authenticated
|
||||
}
|
||||
|
||||
func (ws *WSState) GetShell() *wsshell.WSShell {
|
||||
ws.Lock.Lock()
|
||||
defer ws.Lock.Unlock()
|
||||
@ -151,6 +166,15 @@ func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error
|
||||
return fmt.Errorf("invalid watchscreen screenid: %w", err)
|
||||
}
|
||||
}
|
||||
if wsPk.AuthKey == "" {
|
||||
ws.SetAuthenticated(false)
|
||||
return fmt.Errorf("invalid watchscreen, no authkey")
|
||||
}
|
||||
if wsPk.AuthKey != ws.AuthKey {
|
||||
ws.SetAuthenticated(false)
|
||||
return fmt.Errorf("invalid watchscreen, invalid authkey")
|
||||
}
|
||||
ws.SetAuthenticated(true)
|
||||
if wsPk.SessionId == "" || wsPk.ScreenId == "" {
|
||||
ws.UnWatchScreen()
|
||||
} else {
|
||||
@ -179,6 +203,20 @@ func (ws *WSState) RunWSRead() {
|
||||
log.Printf("error unmarshalling ws message: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.WatchScreenPacketStr {
|
||||
wsPk := pk.(*scpacket.WatchScreenPacketType)
|
||||
err := ws.handleWatchScreen(wsPk)
|
||||
if err != nil {
|
||||
// TODO send errors back to client, likely unrecoverable
|
||||
log.Printf("[ws %s] error %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
isAuth := ws.IsAuthenticated()
|
||||
if !isAuth {
|
||||
log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType())
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.FeInputPacketStr {
|
||||
feInputPk := pk.(*scpacket.FeInputPacketType)
|
||||
if feInputPk.Remote.OwnerId != "" {
|
||||
@ -198,15 +236,6 @@ func (ws *WSState) RunWSRead() {
|
||||
}()
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.WatchScreenPacketStr {
|
||||
wsPk := pk.(*scpacket.WatchScreenPacketType)
|
||||
err := ws.handleWatchScreen(wsPk)
|
||||
if err != nil {
|
||||
// TODO send errors back to client, likely unrecoverable
|
||||
log.Printf("[ws %s] error %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == scpacket.RemoteInputPacketStr {
|
||||
inputPk := pk.(*scpacket.RemoteInputPacketType)
|
||||
if inputPk.RemoteId == "" {
|
||||
|
Loading…
Reference in New Issue
Block a user