diff --git a/cmd/main-server.go b/cmd/main-server.go index 79309fc54..b6c27b0e1 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -28,6 +28,8 @@ import ( "github.com/scripthaus-dev/sh2-server/pkg/wsshell" ) +type WebFnType = func(http.ResponseWriter, *http.Request) + const HttpReadTimeout = 5 * time.Second const HttpWriteTimeout = 21 * time.Second const HttpMaxHeaderBytes = 60000 @@ -40,6 +42,7 @@ const WSStatePacketChSize = 20 var GlobalLock = &sync.Mutex{} var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState +var GlobalAuthKey string func setWSState(state *scws.WSState) { GlobalLock.Lock() @@ -81,7 +84,7 @@ func HandleWs(w http.ResponseWriter, r *http.Request) { } state := getWSState(clientId) if state == nil { - state = scws.MakeWSState(clientId) + state = scws.MakeWSState(clientId, GlobalAuthKey) state.ReplaceShell(shell) setWSState(state) } else { @@ -119,10 +122,6 @@ func writeToFifo(fifoName string, data []byte) error { } func HandleGetClientData(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Vary", "Origin") - w.Header().Set("Cache-Control", "no-cache") cdata, err := sstore.EnsureClientData(r.Context()) if err != nil { WriteJsonError(w, err) @@ -133,15 +132,6 @@ func HandleGetClientData(w http.ResponseWriter, r *http.Request) { } func HandleSetWinSize(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") - w.Header().Set("Vary", "Origin") - w.Header().Set("Cache-Control", "no-cache") - if r.Method == "GET" || r.Method == "OPTIONS" { - w.WriteHeader(200) - return - } decoder := json.NewDecoder(r.Body) var winSize sstore.ClientWinSizeType err := decoder.Decode(&winSize) @@ -160,10 +150,6 @@ func HandleSetWinSize(w http.ResponseWriter, r *http.Request) { // params: sessionid, windowid func HandleGetWindow(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Vary", "Origin") - w.Header().Set("Cache-Control", "no-cache") qvals := r.URL.Query() sessionId := qvals.Get("sessionid") windowId := qvals.Get("windowid") @@ -226,10 +212,6 @@ func HandleRtnState(w http.ResponseWriter, r *http.Request) { } func HandleRemotePty(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Vary", "Origin") - w.Header().Set("Cache-Control", "no-cache") qvals := r.URL.Query() remoteId := qvals.Get("remoteid") if remoteId == "" { @@ -255,10 +237,6 @@ func HandleRemotePty(w http.ResponseWriter, r *http.Request) { } func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Vary", "Origin") - w.Header().Set("Cache-Control", "no-cache") qvals := r.URL.Query() sessionId := qvals.Get("sessionid") cmdId := qvals.Get("cmdid") @@ -330,15 +308,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { WriteJsonError(w, fmt.Errorf("panic: %v", r)) return }() - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") - w.Header().Set("Vary", "Origin") w.Header().Set("Cache-Control", "no-cache") - if r.Method == "GET" || r.Method == "OPTIONS" { - w.WriteHeader(200) - return - } decoder := json.NewDecoder(r.Body) var commandPk scpacket.FeCommandPacketType err := decoder.Decode(&commandPk) @@ -355,6 +325,24 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { return } +func AuthKeyWrap(fn WebFnType) WebFnType { + return func(w http.ResponseWriter, r *http.Request) { + reqAuthKey := r.Header.Get("X-AuthKey") + if reqAuthKey == "" { + w.WriteHeader(500) + w.Write([]byte("no x-authkey header")) + return + } + if reqAuthKey != GlobalAuthKey { + w.WriteHeader(500) + w.Write([]byte("x-authkey header is invalid")) + return + } + w.Header().Set("Cache-Control", "no-cache") + fn(w, r) + } +} + func runWebSocketServer() { gr := mux.NewRouter() gr.HandleFunc("/ws", HandleWs) @@ -405,10 +393,9 @@ func main() { scLock, err := scbase.AcquirePromptLock() if err != nil || scLock == nil { - log.Printf("[error] cannot acquire sh2 lock: %v\n", err) + log.Printf("[error] cannot acquire prompt lock: %v\n", err) return } - if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") { err := sstore.MigrateCommandOpts(os.Args[1:]) if err != nil { @@ -416,6 +403,12 @@ func main() { } return } + authKey, err := scbase.ReadPromptAuthKey() + if err != nil { + log.Printf("[error] %v\n", err) + return + } + GlobalAuthKey = authKey err = sstore.TryMigrateUp() if err != nil { log.Printf("[error] migrate up: %v\n", err) @@ -451,13 +444,13 @@ func main() { go stdinReadWatch() go runWebSocketServer() gr := mux.NewRouter() - gr.HandleFunc("/api/ptyout", HandleGetPtyOut) - gr.HandleFunc("/api/remote-pty", HandleRemotePty) - gr.HandleFunc("/api/rtnstate", HandleRtnState) - gr.HandleFunc("/api/get-window", HandleGetWindow) - gr.HandleFunc("/api/run-command", HandleRunCommand).Methods("GET", "POST", "OPTIONS") - gr.HandleFunc("/api/get-client-data", HandleGetClientData) - gr.HandleFunc("/api/set-winsize", HandleSetWinSize) + gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut)) + gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty)) + gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState)) + gr.HandleFunc("/api/get-window", AuthKeyWrap(HandleGetWindow)) + gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST") + gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData)) + gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize)) server := &http.Server{ Addr: MainServerAddr, ReadTimeout: HttpReadTimeout, diff --git a/pkg/scpacket/scpacket.go b/pkg/scpacket/scpacket.go index 7ea1bb396..bb88338d1 100644 --- a/pkg/scpacket/scpacket.go +++ b/pkg/scpacket/scpacket.go @@ -77,6 +77,7 @@ type WatchScreenPacketType struct { SessionId string `json:"sessionid"` ScreenId string `json:"screenid"` Connect bool `json:"connect"` + AuthKey string `json:"authkey"` } func init() { diff --git a/pkg/scws/scws.go b/pkg/scws/scws.go index 045e39b2e..34a1cd5a8 100644 --- a/pkg/scws/scws.go +++ b/pkg/scws/scws.go @@ -19,25 +19,40 @@ const WSStatePacketChSize = 20 const MaxInputDataSize = 1000 type WSState struct { - Lock *sync.Mutex - ClientId string - ConnectTime time.Time - Shell *wsshell.WSShell - UpdateCh chan interface{} - UpdateQueue []interface{} + Lock *sync.Mutex + ClientId string + ConnectTime time.Time + Shell *wsshell.WSShell + UpdateCh chan interface{} + UpdateQueue []interface{} + Authenticated bool + AuthKey string SessionId string ScreenId string } -func MakeWSState(clientId string) *WSState { +func MakeWSState(clientId string, authKey string) *WSState { rtn := &WSState{} rtn.Lock = &sync.Mutex{} rtn.ClientId = clientId rtn.ConnectTime = time.Now() + rtn.AuthKey = authKey return rtn } +func (ws *WSState) SetAuthenticated(authVal bool) { + ws.Lock.Lock() + defer ws.Lock.Unlock() + ws.Authenticated = authVal +} + +func (ws *WSState) IsAuthenticated() bool { + ws.Lock.Lock() + defer ws.Lock.Unlock() + return ws.Authenticated +} + func (ws *WSState) GetShell() *wsshell.WSShell { ws.Lock.Lock() defer ws.Lock.Unlock() @@ -151,6 +166,15 @@ func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error return fmt.Errorf("invalid watchscreen screenid: %w", err) } } + if wsPk.AuthKey == "" { + ws.SetAuthenticated(false) + return fmt.Errorf("invalid watchscreen, no authkey") + } + if wsPk.AuthKey != ws.AuthKey { + ws.SetAuthenticated(false) + return fmt.Errorf("invalid watchscreen, invalid authkey") + } + ws.SetAuthenticated(true) if wsPk.SessionId == "" || wsPk.ScreenId == "" { ws.UnWatchScreen() } else { @@ -179,6 +203,20 @@ func (ws *WSState) RunWSRead() { log.Printf("error unmarshalling ws message: %v\n", err) continue } + if pk.GetType() == scpacket.WatchScreenPacketStr { + wsPk := pk.(*scpacket.WatchScreenPacketType) + err := ws.handleWatchScreen(wsPk) + if err != nil { + // TODO send errors back to client, likely unrecoverable + log.Printf("[ws %s] error %v\n", err) + } + continue + } + isAuth := ws.IsAuthenticated() + if !isAuth { + log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType()) + continue + } if pk.GetType() == scpacket.FeInputPacketStr { feInputPk := pk.(*scpacket.FeInputPacketType) if feInputPk.Remote.OwnerId != "" { @@ -198,15 +236,6 @@ func (ws *WSState) RunWSRead() { }() continue } - if pk.GetType() == scpacket.WatchScreenPacketStr { - wsPk := pk.(*scpacket.WatchScreenPacketType) - err := ws.handleWatchScreen(wsPk) - if err != nil { - // TODO send errors back to client, likely unrecoverable - log.Printf("[ws %s] error %v\n", err) - } - continue - } if pk.GetType() == scpacket.RemoteInputPacketStr { inputPk := pk.(*scpacket.RemoteInputPacketType) if inputPk.RemoteId == "" {