handle simple authkey authentication for local-server

This commit is contained in:
sawka 2022-12-20 16:16:46 -08:00
parent 21bbab88c8
commit 1697010d55
3 changed files with 82 additions and 59 deletions

View File

@ -28,6 +28,8 @@ import (
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
)
type WebFnType = func(http.ResponseWriter, *http.Request)
const HttpReadTimeout = 5 * time.Second
const HttpWriteTimeout = 21 * time.Second
const HttpMaxHeaderBytes = 60000
@ -40,6 +42,7 @@ const WSStatePacketChSize = 20
var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
var GlobalAuthKey string
func setWSState(state *scws.WSState) {
GlobalLock.Lock()
@ -81,7 +84,7 @@ func HandleWs(w http.ResponseWriter, r *http.Request) {
}
state := getWSState(clientId)
if state == nil {
state = scws.MakeWSState(clientId)
state = scws.MakeWSState(clientId, GlobalAuthKey)
state.ReplaceShell(shell)
setWSState(state)
} else {
@ -119,10 +122,6 @@ func writeToFifo(fifoName string, data []byte) error {
}
func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
cdata, err := sstore.EnsureClientData(r.Context())
if err != nil {
WriteJsonError(w, err)
@ -133,15 +132,6 @@ func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
}
func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
if r.Method == "GET" || r.Method == "OPTIONS" {
w.WriteHeader(200)
return
}
decoder := json.NewDecoder(r.Body)
var winSize sstore.ClientWinSizeType
err := decoder.Decode(&winSize)
@ -160,10 +150,6 @@ func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
// params: sessionid, windowid
func HandleGetWindow(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
qvals := r.URL.Query()
sessionId := qvals.Get("sessionid")
windowId := qvals.Get("windowid")
@ -226,10 +212,6 @@ func HandleRtnState(w http.ResponseWriter, r *http.Request) {
}
func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
qvals := r.URL.Query()
remoteId := qvals.Get("remoteid")
if remoteId == "" {
@ -255,10 +237,6 @@ func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
}
func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
qvals := r.URL.Query()
sessionId := qvals.Get("sessionid")
cmdId := qvals.Get("cmdid")
@ -330,15 +308,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
WriteJsonError(w, fmt.Errorf("panic: %v", r))
return
}()
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Vary", "Origin")
w.Header().Set("Cache-Control", "no-cache")
if r.Method == "GET" || r.Method == "OPTIONS" {
w.WriteHeader(200)
return
}
decoder := json.NewDecoder(r.Body)
var commandPk scpacket.FeCommandPacketType
err := decoder.Decode(&commandPk)
@ -355,6 +325,24 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
return
}
func AuthKeyWrap(fn WebFnType) WebFnType {
return func(w http.ResponseWriter, r *http.Request) {
reqAuthKey := r.Header.Get("X-AuthKey")
if reqAuthKey == "" {
w.WriteHeader(500)
w.Write([]byte("no x-authkey header"))
return
}
if reqAuthKey != GlobalAuthKey {
w.WriteHeader(500)
w.Write([]byte("x-authkey header is invalid"))
return
}
w.Header().Set("Cache-Control", "no-cache")
fn(w, r)
}
}
func runWebSocketServer() {
gr := mux.NewRouter()
gr.HandleFunc("/ws", HandleWs)
@ -405,10 +393,9 @@ func main() {
scLock, err := scbase.AcquirePromptLock()
if err != nil || scLock == nil {
log.Printf("[error] cannot acquire sh2 lock: %v\n", err)
log.Printf("[error] cannot acquire prompt lock: %v\n", err)
return
}
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
err := sstore.MigrateCommandOpts(os.Args[1:])
if err != nil {
@ -416,6 +403,12 @@ func main() {
}
return
}
authKey, err := scbase.ReadPromptAuthKey()
if err != nil {
log.Printf("[error] %v\n", err)
return
}
GlobalAuthKey = authKey
err = sstore.TryMigrateUp()
if err != nil {
log.Printf("[error] migrate up: %v\n", err)
@ -451,13 +444,13 @@ func main() {
go stdinReadWatch()
go runWebSocketServer()
gr := mux.NewRouter()
gr.HandleFunc("/api/ptyout", HandleGetPtyOut)
gr.HandleFunc("/api/remote-pty", HandleRemotePty)
gr.HandleFunc("/api/rtnstate", HandleRtnState)
gr.HandleFunc("/api/get-window", HandleGetWindow)
gr.HandleFunc("/api/run-command", HandleRunCommand).Methods("GET", "POST", "OPTIONS")
gr.HandleFunc("/api/get-client-data", HandleGetClientData)
gr.HandleFunc("/api/set-winsize", HandleSetWinSize)
gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut))
gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty))
gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState))
gr.HandleFunc("/api/get-window", AuthKeyWrap(HandleGetWindow))
gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST")
gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
server := &http.Server{
Addr: MainServerAddr,
ReadTimeout: HttpReadTimeout,

View File

@ -77,6 +77,7 @@ type WatchScreenPacketType struct {
SessionId string `json:"sessionid"`
ScreenId string `json:"screenid"`
Connect bool `json:"connect"`
AuthKey string `json:"authkey"`
}
func init() {

View File

@ -19,25 +19,40 @@ const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
type WSState struct {
Lock *sync.Mutex
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
UpdateCh chan interface{}
UpdateQueue []interface{}
Lock *sync.Mutex
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
UpdateCh chan interface{}
UpdateQueue []interface{}
Authenticated bool
AuthKey string
SessionId string
ScreenId string
}
func MakeWSState(clientId string) *WSState {
func MakeWSState(clientId string, authKey string) *WSState {
rtn := &WSState{}
rtn.Lock = &sync.Mutex{}
rtn.ClientId = clientId
rtn.ConnectTime = time.Now()
rtn.AuthKey = authKey
return rtn
}
func (ws *WSState) SetAuthenticated(authVal bool) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.Authenticated = authVal
}
func (ws *WSState) IsAuthenticated() bool {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Authenticated
}
func (ws *WSState) GetShell() *wsshell.WSShell {
ws.Lock.Lock()
defer ws.Lock.Unlock()
@ -151,6 +166,15 @@ func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error
return fmt.Errorf("invalid watchscreen screenid: %w", err)
}
}
if wsPk.AuthKey == "" {
ws.SetAuthenticated(false)
return fmt.Errorf("invalid watchscreen, no authkey")
}
if wsPk.AuthKey != ws.AuthKey {
ws.SetAuthenticated(false)
return fmt.Errorf("invalid watchscreen, invalid authkey")
}
ws.SetAuthenticated(true)
if wsPk.SessionId == "" || wsPk.ScreenId == "" {
ws.UnWatchScreen()
} else {
@ -179,6 +203,20 @@ func (ws *WSState) RunWSRead() {
log.Printf("error unmarshalling ws message: %v\n", err)
continue
}
if pk.GetType() == scpacket.WatchScreenPacketStr {
wsPk := pk.(*scpacket.WatchScreenPacketType)
err := ws.handleWatchScreen(wsPk)
if err != nil {
// TODO send errors back to client, likely unrecoverable
log.Printf("[ws %s] error %v\n", err)
}
continue
}
isAuth := ws.IsAuthenticated()
if !isAuth {
log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType())
continue
}
if pk.GetType() == scpacket.FeInputPacketStr {
feInputPk := pk.(*scpacket.FeInputPacketType)
if feInputPk.Remote.OwnerId != "" {
@ -198,15 +236,6 @@ func (ws *WSState) RunWSRead() {
}()
continue
}
if pk.GetType() == scpacket.WatchScreenPacketStr {
wsPk := pk.(*scpacket.WatchScreenPacketType)
err := ws.handleWatchScreen(wsPk)
if err != nil {
// TODO send errors back to client, likely unrecoverable
log.Printf("[ws %s] error %v\n", err)
}
continue
}
if pk.GetType() == scpacket.RemoteInputPacketStr {
inputPk := pk.(*scpacket.RemoteInputPacketType)
if inputPk.RemoteId == "" {