mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-04-10 19:26:45 +02:00
handle simple authkey authentication for local-server
This commit is contained in:
parent
21bbab88c8
commit
1697010d55
@ -28,6 +28,8 @@ import (
|
|||||||
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
|
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WebFnType = func(http.ResponseWriter, *http.Request)
|
||||||
|
|
||||||
const HttpReadTimeout = 5 * time.Second
|
const HttpReadTimeout = 5 * time.Second
|
||||||
const HttpWriteTimeout = 21 * time.Second
|
const HttpWriteTimeout = 21 * time.Second
|
||||||
const HttpMaxHeaderBytes = 60000
|
const HttpMaxHeaderBytes = 60000
|
||||||
@ -40,6 +42,7 @@ const WSStatePacketChSize = 20
|
|||||||
|
|
||||||
var GlobalLock = &sync.Mutex{}
|
var GlobalLock = &sync.Mutex{}
|
||||||
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
|
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
|
||||||
|
var GlobalAuthKey string
|
||||||
|
|
||||||
func setWSState(state *scws.WSState) {
|
func setWSState(state *scws.WSState) {
|
||||||
GlobalLock.Lock()
|
GlobalLock.Lock()
|
||||||
@ -81,7 +84,7 @@ func HandleWs(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
state := getWSState(clientId)
|
state := getWSState(clientId)
|
||||||
if state == nil {
|
if state == nil {
|
||||||
state = scws.MakeWSState(clientId)
|
state = scws.MakeWSState(clientId, GlobalAuthKey)
|
||||||
state.ReplaceShell(shell)
|
state.ReplaceShell(shell)
|
||||||
setWSState(state)
|
setWSState(state)
|
||||||
} else {
|
} else {
|
||||||
@ -119,10 +122,6 @@ func writeToFifo(fifoName string, data []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
|
func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
cdata, err := sstore.EnsureClientData(r.Context())
|
cdata, err := sstore.EnsureClientData(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteJsonError(w, err)
|
WriteJsonError(w, err)
|
||||||
@ -133,15 +132,6 @@ func HandleGetClientData(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
|
func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
if r.Method == "GET" || r.Method == "OPTIONS" {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
decoder := json.NewDecoder(r.Body)
|
decoder := json.NewDecoder(r.Body)
|
||||||
var winSize sstore.ClientWinSizeType
|
var winSize sstore.ClientWinSizeType
|
||||||
err := decoder.Decode(&winSize)
|
err := decoder.Decode(&winSize)
|
||||||
@ -160,10 +150,6 @@ func HandleSetWinSize(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// params: sessionid, windowid
|
// params: sessionid, windowid
|
||||||
func HandleGetWindow(w http.ResponseWriter, r *http.Request) {
|
func HandleGetWindow(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
qvals := r.URL.Query()
|
qvals := r.URL.Query()
|
||||||
sessionId := qvals.Get("sessionid")
|
sessionId := qvals.Get("sessionid")
|
||||||
windowId := qvals.Get("windowid")
|
windowId := qvals.Get("windowid")
|
||||||
@ -226,10 +212,6 @@ func HandleRtnState(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
|
func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
qvals := r.URL.Query()
|
qvals := r.URL.Query()
|
||||||
remoteId := qvals.Get("remoteid")
|
remoteId := qvals.Get("remoteid")
|
||||||
if remoteId == "" {
|
if remoteId == "" {
|
||||||
@ -255,10 +237,6 @@ func HandleRemotePty(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
|
func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
qvals := r.URL.Query()
|
qvals := r.URL.Query()
|
||||||
sessionId := qvals.Get("sessionid")
|
sessionId := qvals.Get("sessionid")
|
||||||
cmdId := qvals.Get("cmdid")
|
cmdId := qvals.Get("cmdid")
|
||||||
@ -330,15 +308,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
|
|||||||
WriteJsonError(w, fmt.Errorf("panic: %v", r))
|
WriteJsonError(w, fmt.Errorf("panic: %v", r))
|
||||||
return
|
return
|
||||||
}()
|
}()
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
if r.Method == "GET" || r.Method == "OPTIONS" {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
decoder := json.NewDecoder(r.Body)
|
decoder := json.NewDecoder(r.Body)
|
||||||
var commandPk scpacket.FeCommandPacketType
|
var commandPk scpacket.FeCommandPacketType
|
||||||
err := decoder.Decode(&commandPk)
|
err := decoder.Decode(&commandPk)
|
||||||
@ -355,6 +325,24 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AuthKeyWrap(fn WebFnType) WebFnType {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reqAuthKey := r.Header.Get("X-AuthKey")
|
||||||
|
if reqAuthKey == "" {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
w.Write([]byte("no x-authkey header"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqAuthKey != GlobalAuthKey {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
w.Write([]byte("x-authkey header is invalid"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
fn(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func runWebSocketServer() {
|
func runWebSocketServer() {
|
||||||
gr := mux.NewRouter()
|
gr := mux.NewRouter()
|
||||||
gr.HandleFunc("/ws", HandleWs)
|
gr.HandleFunc("/ws", HandleWs)
|
||||||
@ -405,10 +393,9 @@ func main() {
|
|||||||
|
|
||||||
scLock, err := scbase.AcquirePromptLock()
|
scLock, err := scbase.AcquirePromptLock()
|
||||||
if err != nil || scLock == nil {
|
if err != nil || scLock == nil {
|
||||||
log.Printf("[error] cannot acquire sh2 lock: %v\n", err)
|
log.Printf("[error] cannot acquire prompt lock: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
|
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
|
||||||
err := sstore.MigrateCommandOpts(os.Args[1:])
|
err := sstore.MigrateCommandOpts(os.Args[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -416,6 +403,12 @@ func main() {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
authKey, err := scbase.ReadPromptAuthKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[error] %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
GlobalAuthKey = authKey
|
||||||
err = sstore.TryMigrateUp()
|
err = sstore.TryMigrateUp()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[error] migrate up: %v\n", err)
|
log.Printf("[error] migrate up: %v\n", err)
|
||||||
@ -451,13 +444,13 @@ func main() {
|
|||||||
go stdinReadWatch()
|
go stdinReadWatch()
|
||||||
go runWebSocketServer()
|
go runWebSocketServer()
|
||||||
gr := mux.NewRouter()
|
gr := mux.NewRouter()
|
||||||
gr.HandleFunc("/api/ptyout", HandleGetPtyOut)
|
gr.HandleFunc("/api/ptyout", AuthKeyWrap(HandleGetPtyOut))
|
||||||
gr.HandleFunc("/api/remote-pty", HandleRemotePty)
|
gr.HandleFunc("/api/remote-pty", AuthKeyWrap(HandleRemotePty))
|
||||||
gr.HandleFunc("/api/rtnstate", HandleRtnState)
|
gr.HandleFunc("/api/rtnstate", AuthKeyWrap(HandleRtnState))
|
||||||
gr.HandleFunc("/api/get-window", HandleGetWindow)
|
gr.HandleFunc("/api/get-window", AuthKeyWrap(HandleGetWindow))
|
||||||
gr.HandleFunc("/api/run-command", HandleRunCommand).Methods("GET", "POST", "OPTIONS")
|
gr.HandleFunc("/api/run-command", AuthKeyWrap(HandleRunCommand)).Methods("POST")
|
||||||
gr.HandleFunc("/api/get-client-data", HandleGetClientData)
|
gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
|
||||||
gr.HandleFunc("/api/set-winsize", HandleSetWinSize)
|
gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: MainServerAddr,
|
Addr: MainServerAddr,
|
||||||
ReadTimeout: HttpReadTimeout,
|
ReadTimeout: HttpReadTimeout,
|
||||||
|
@ -77,6 +77,7 @@ type WatchScreenPacketType struct {
|
|||||||
SessionId string `json:"sessionid"`
|
SessionId string `json:"sessionid"`
|
||||||
ScreenId string `json:"screenid"`
|
ScreenId string `json:"screenid"`
|
||||||
Connect bool `json:"connect"`
|
Connect bool `json:"connect"`
|
||||||
|
AuthKey string `json:"authkey"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -19,25 +19,40 @@ const WSStatePacketChSize = 20
|
|||||||
const MaxInputDataSize = 1000
|
const MaxInputDataSize = 1000
|
||||||
|
|
||||||
type WSState struct {
|
type WSState struct {
|
||||||
Lock *sync.Mutex
|
Lock *sync.Mutex
|
||||||
ClientId string
|
ClientId string
|
||||||
ConnectTime time.Time
|
ConnectTime time.Time
|
||||||
Shell *wsshell.WSShell
|
Shell *wsshell.WSShell
|
||||||
UpdateCh chan interface{}
|
UpdateCh chan interface{}
|
||||||
UpdateQueue []interface{}
|
UpdateQueue []interface{}
|
||||||
|
Authenticated bool
|
||||||
|
AuthKey string
|
||||||
|
|
||||||
SessionId string
|
SessionId string
|
||||||
ScreenId string
|
ScreenId string
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeWSState(clientId string) *WSState {
|
func MakeWSState(clientId string, authKey string) *WSState {
|
||||||
rtn := &WSState{}
|
rtn := &WSState{}
|
||||||
rtn.Lock = &sync.Mutex{}
|
rtn.Lock = &sync.Mutex{}
|
||||||
rtn.ClientId = clientId
|
rtn.ClientId = clientId
|
||||||
rtn.ConnectTime = time.Now()
|
rtn.ConnectTime = time.Now()
|
||||||
|
rtn.AuthKey = authKey
|
||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ws *WSState) SetAuthenticated(authVal bool) {
|
||||||
|
ws.Lock.Lock()
|
||||||
|
defer ws.Lock.Unlock()
|
||||||
|
ws.Authenticated = authVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WSState) IsAuthenticated() bool {
|
||||||
|
ws.Lock.Lock()
|
||||||
|
defer ws.Lock.Unlock()
|
||||||
|
return ws.Authenticated
|
||||||
|
}
|
||||||
|
|
||||||
func (ws *WSState) GetShell() *wsshell.WSShell {
|
func (ws *WSState) GetShell() *wsshell.WSShell {
|
||||||
ws.Lock.Lock()
|
ws.Lock.Lock()
|
||||||
defer ws.Lock.Unlock()
|
defer ws.Lock.Unlock()
|
||||||
@ -151,6 +166,15 @@ func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error
|
|||||||
return fmt.Errorf("invalid watchscreen screenid: %w", err)
|
return fmt.Errorf("invalid watchscreen screenid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if wsPk.AuthKey == "" {
|
||||||
|
ws.SetAuthenticated(false)
|
||||||
|
return fmt.Errorf("invalid watchscreen, no authkey")
|
||||||
|
}
|
||||||
|
if wsPk.AuthKey != ws.AuthKey {
|
||||||
|
ws.SetAuthenticated(false)
|
||||||
|
return fmt.Errorf("invalid watchscreen, invalid authkey")
|
||||||
|
}
|
||||||
|
ws.SetAuthenticated(true)
|
||||||
if wsPk.SessionId == "" || wsPk.ScreenId == "" {
|
if wsPk.SessionId == "" || wsPk.ScreenId == "" {
|
||||||
ws.UnWatchScreen()
|
ws.UnWatchScreen()
|
||||||
} else {
|
} else {
|
||||||
@ -179,6 +203,20 @@ func (ws *WSState) RunWSRead() {
|
|||||||
log.Printf("error unmarshalling ws message: %v\n", err)
|
log.Printf("error unmarshalling ws message: %v\n", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if pk.GetType() == scpacket.WatchScreenPacketStr {
|
||||||
|
wsPk := pk.(*scpacket.WatchScreenPacketType)
|
||||||
|
err := ws.handleWatchScreen(wsPk)
|
||||||
|
if err != nil {
|
||||||
|
// TODO send errors back to client, likely unrecoverable
|
||||||
|
log.Printf("[ws %s] error %v\n", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
isAuth := ws.IsAuthenticated()
|
||||||
|
if !isAuth {
|
||||||
|
log.Printf("[error] cannot process ws-packet[%s], not authenticated\n", pk.GetType())
|
||||||
|
continue
|
||||||
|
}
|
||||||
if pk.GetType() == scpacket.FeInputPacketStr {
|
if pk.GetType() == scpacket.FeInputPacketStr {
|
||||||
feInputPk := pk.(*scpacket.FeInputPacketType)
|
feInputPk := pk.(*scpacket.FeInputPacketType)
|
||||||
if feInputPk.Remote.OwnerId != "" {
|
if feInputPk.Remote.OwnerId != "" {
|
||||||
@ -198,15 +236,6 @@ func (ws *WSState) RunWSRead() {
|
|||||||
}()
|
}()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if pk.GetType() == scpacket.WatchScreenPacketStr {
|
|
||||||
wsPk := pk.(*scpacket.WatchScreenPacketType)
|
|
||||||
err := ws.handleWatchScreen(wsPk)
|
|
||||||
if err != nil {
|
|
||||||
// TODO send errors back to client, likely unrecoverable
|
|
||||||
log.Printf("[ws %s] error %v\n", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if pk.GetType() == scpacket.RemoteInputPacketStr {
|
if pk.GetType() == scpacket.RemoteInputPacketStr {
|
||||||
inputPk := pk.(*scpacket.RemoteInputPacketType)
|
inputPk := pk.(*scpacket.RemoteInputPacketType)
|
||||||
if inputPk.RemoteId == "" {
|
if inputPk.RemoteId == "" {
|
||||||
|
Loading…
Reference in New Issue
Block a user