handle simple authkey authentication for local-server

This commit is contained in:
sawka 2022-12-20 16:16:46 -08:00
parent 21bbab88c8
commit 1697010d55
3 changed files with 82 additions and 59 deletions

View File

@ -28,6 +28,8 @@ import (
"github.com/scripthaus-dev/sh2-server/pkg/wsshell" "github.com/scripthaus-dev/sh2-server/pkg/wsshell"
) )
type WebFnType = func(http.ResponseWriter, *http.Request)
const HttpReadTimeout = 5 * time.Second const HttpReadTimeout = 5 * time.Second
const HttpWriteTimeout = 21 * time.Second const HttpWriteTimeout = 21 * time.Second
const HttpMaxHeaderBytes = 60000 const HttpMaxHeaderBytes = 60000
@ -40,6 +42,7 @@ const WSStatePacketChSize = 20
var GlobalLock = &sync.Mutex{} var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
var GlobalAuthKey string
func setWSState(state *scws.WSState) { func setWSState(state *scws.WSState) {
GlobalLock.Lock() GlobalLock.Lock()
@ -81,7 +84,7 @@ func HandleWs(w http.ResponseWriter, r *http.Request) {
} }
state := getWSState(clientId) state := getWSState(clientId)
if state == nil { if state == nil {
state = scws.MakeWSState(clientId) state = scws.MakeWSState(clientId, GlobalAuthKey)
state.ReplaceShell(shell) state.ReplaceShell(shell)
setWSState(state) setWSState(state)
} else { } else {
@ -119,10 +122,6 @@ func writeToFifo(fifoName string, data []byte) error {
} }
func HandleGetClientData(w http.ResponseWriter, r *http.Request) { func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
cdata, err := sstore.EnsureClientData(r.Context()) cdata, err := sstore.EnsureClientData(r.Context())
if err != nil { if err != nil {
WriteJsonError(w, err) WriteJsonError(w, err)
@ -133,15 +132,6 @@ func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
} }
func HandleSetWinSize(w http.ResponseWriter, r *http.Request) { func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
if r.Method == "GET" || r.Method == "OPTIONS" {
w.WriteHeader(200)
return
}
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
var winSize sstore.ClientWinSizeType var winSize sstore.ClientWinSizeType
err := decoder.Decode(&winSize) err := decoder.Decode(&winSize)
@ -160,10 +150,6 @@ func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
// params: sessionid, windowid // params: sessionid, windowid
func HandleGetWindow(w http.ResponseWriter, r *http.Request) { func HandleGetWindow(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
qvals := r.URL.Query() qvals := r.URL.Query()
sessionId := qvals.Get("sessionid") sessionId := qvals.Get("sessionid")
windowId := qvals.Get("windowid") windowId := qvals.Get("windowid")
@ -226,10 +212,6 @@ func HandleRtnState(w http.ResponseWriter, r *http.Request) {
} }
func HandleRemotePty(w http.ResponseWriter, r *http.Request) { func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
qvals := r.URL.Query() qvals := r.URL.Query()
remoteId := qvals.Get("remoteid") remoteId := qvals.Get("remoteid")
if remoteId == "" { if remoteId == "" {
@ -255,10 +237,6 @@ func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
} }
func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) { func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
qvals := r.URL.Query() qvals := r.URL.Query()
sessionId := qvals.Get("sessionid") sessionId := qvals.Get("sessionid")
cmdId := qvals.Get("cmdid") cmdId := qvals.Get("cmdid")
@ -330,15 +308,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
WriteJsonError(w, fmt.Errorf("panic: %v", r)) WriteJsonError(w, fmt.Errorf("panic: %v", r))
return return
}() }()
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
if r.Method == "GET" || r.Method == "OPTIONS" {
w.WriteHeader(200)
return
}
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
var commandPk scpacket.FeCommandPacketType var commandPk scpacket.FeCommandPacketType
err := decoder.Decode(&commandPk) err := decoder.Decode(&commandPk)
@ -355,6 +325,24 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
return return
} }
func AuthKeyWrap(fn WebFnType) WebFnType {
return func(w http.ResponseWriter, r *http.Request) {
reqAuthKey := r.Header.Get("X-AuthKey")
if reqAuthKey == "" {
w.WriteHeader(500)
w.Write([]byte("no x-authkey header"))
return
}
if reqAuthKey != GlobalAuthKey {
w.WriteHeader(500)
w.Write([]byte("x-authkey header is invalid"))
return
}
w.Header().Set("Cache-Control", "no-cache")
fn(w, r)
}
}
func runWebSocketServer() { func runWebSocketServer() {
gr := mux.NewRouter() gr := mux.NewRouter()
gr.HandleFunc("/ws", HandleWs) gr.HandleFunc("/ws", HandleWs)
@ -405,10 +393,9 @@ func main() {
scLock, err := scbase.AcquirePromptLock() scLock, err := scbase.AcquirePromptLock()
if err != nil || scLock == nil { if err != nil || scLock == nil {
log.Printf("[error] cannot acquire sh2 lock: %v\n", err) log.Printf("[error] cannot acquire prompt lock: %v\n", err)
return return
} }
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") { if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
err := sstore.MigrateCommandOpts(os.Args[1:]) err := sstore.MigrateCommandOpts(os.Args[1:])
if err != nil { if err != nil {
@ -416,6 +403,12 @@ func main() {
} }
return return
} }
authKey, err := scbase.ReadPromptAuthKey()
if err != nil {
log.Printf("[error] %v\n", err)
return
}
GlobalAuthKey = authKey
err = sstore.TryMigrateUp() err = sstore.TryMigrateUp()
if err != nil { if err != nil {
log.Printf("[error] migrate up: %v\n", err) log.Printf("[error] migrate up: %v\n", err)
@ -451,13 +444,13 @@ func main() {
go stdinReadWatch() go stdinReadWatch()
go runWebSocketServer() go runWebSocketServer()
gr := mux.NewRouter() gr := mux.NewRouter()
gr.HandleFunc("/api/ptyout", HandleGetPtyOut) gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut))
gr.HandleFunc("/api/remote-pty", HandleRemotePty) gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty))
gr.HandleFunc("/api/rtnstate", HandleRtnState) gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState))
gr.HandleFunc("/api/get-window", HandleGetWindow) gr.HandleFunc("/api/get-window", AuthKeyWrap(HandleGetWindow))
gr.HandleFunc("/api/run-command", HandleRunCommand).Methods("GET", "POST", "OPTIONS") gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST")
gr.HandleFunc("/api/get-client-data", HandleGetClientData) gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
gr.HandleFunc("/api/set-winsize", HandleSetWinSize) gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
server := &http.Server{ server := &http.Server{
Addr: MainServerAddr, Addr: MainServerAddr,
ReadTimeout: HttpReadTimeout, ReadTimeout: HttpReadTimeout,

View File

@ -77,6 +77,7 @@ type WatchScreenPacketType struct {
SessionId string `json:"sessionid"` SessionId string `json:"sessionid"`
ScreenId string `json:"screenid"` ScreenId string `json:"screenid"`
Connect bool `json:"connect"` Connect bool `json:"connect"`
AuthKey string `json:"authkey"`
} }
func init() { func init() {

View File

@ -19,25 +19,40 @@ const WSStatePacketChSize = 20
const MaxInputDataSize = 1000 const MaxInputDataSize = 1000
type WSState struct { type WSState struct {
Lock *sync.Mutex Lock *sync.Mutex
ClientId string ClientId string
ConnectTime time.Time ConnectTime time.Time
Shell *wsshell.WSShell Shell *wsshell.WSShell
UpdateCh chan interface{} UpdateCh chan interface{}
UpdateQueue []interface{} UpdateQueue []interface{}
Authenticated bool
AuthKey string
SessionId string SessionId string
ScreenId string ScreenId string
} }
func MakeWSState(clientId string) *WSState { func MakeWSState(clientId string, authKey string) *WSState {
rtn := &WSState{} rtn := &WSState{}
rtn.Lock = &sync.Mutex{} rtn.Lock = &sync.Mutex{}
rtn.ClientId = clientId rtn.ClientId = clientId
rtn.ConnectTime = time.Now() rtn.ConnectTime = time.Now()
rtn.AuthKey = authKey
return rtn return rtn
} }
func (ws *WSState) SetAuthenticated(authVal bool) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.Authenticated = authVal
}
func (ws *WSState) IsAuthenticated() bool {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Authenticated
}
func (ws *WSState) GetShell() *wsshell.WSShell { func (ws *WSState) GetShell() *wsshell.WSShell {
ws.Lock.Lock() ws.Lock.Lock()
defer ws.Lock.Unlock() defer ws.Lock.Unlock()
@ -151,6 +166,15 @@ func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error
return fmt.Errorf("invalid watchscreen screenid: %w", err) return fmt.Errorf("invalid watchscreen screenid: %w", err)
} }
} }
if wsPk.AuthKey == "" {
ws.SetAuthenticated(false)
return fmt.Errorf("invalid watchscreen, no authkey")
}
if wsPk.AuthKey != ws.AuthKey {
ws.SetAuthenticated(false)
return fmt.Errorf("invalid watchscreen, invalid authkey")
}
ws.SetAuthenticated(true)
if wsPk.SessionId == "" || wsPk.ScreenId == "" { if wsPk.SessionId == "" || wsPk.ScreenId == "" {
ws.UnWatchScreen() ws.UnWatchScreen()
} else { } else {
@ -179,6 +203,20 @@ func (ws *WSState) RunWSRead() {
log.Printf("error unmarshalling ws message: %v\n", err) log.Printf("error unmarshalling ws message: %v\n", err)
continue continue
} }
if pk.GetType() == scpacket.WatchScreenPacketStr {
wsPk := pk.(*scpacket.WatchScreenPacketType)
err := ws.handleWatchScreen(wsPk)
if err != nil {
// TODO send errors back to client, likely unrecoverable
log.Printf("[ws %s] error %v\n", err)
}
continue
}
isAuth := ws.IsAuthenticated()
if !isAuth {
log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType())
continue
}
if pk.GetType() == scpacket.FeInputPacketStr { if pk.GetType() == scpacket.FeInputPacketStr {
feInputPk := pk.(*scpacket.FeInputPacketType) feInputPk := pk.(*scpacket.FeInputPacketType)
if feInputPk.Remote.OwnerId != "" { if feInputPk.Remote.OwnerId != "" {
@ -198,15 +236,6 @@ func (ws *WSState) RunWSRead() {
}() }()
continue continue
} }
if pk.GetType() == scpacket.WatchScreenPacketStr {
wsPk := pk.(*scpacket.WatchScreenPacketType)
err := ws.handleWatchScreen(wsPk)
if err != nil {
// TODO send errors back to client, likely unrecoverable
log.Printf("[ws %s] error %v\n", err)
}
continue
}
if pk.GetType() == scpacket.RemoteInputPacketStr { if pk.GetType() == scpacket.RemoteInputPacketStr {
inputPk := pk.(*scpacket.RemoteInputPacketType) inputPk := pk.(*scpacket.RemoteInputPacketType)
if inputPk.RemoteId == "" { if inputPk.RemoteId == "" {