diff --git a/server/routes.go b/server/routes.go index 4f64db2..2cac66b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -5,6 +5,7 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" "net/http" + "strings" "sync" ) @@ -95,6 +96,14 @@ type IRoutes interface { var Routes IRoutes = &routesImpl{} +func NewRoutes() IRoutes { + r := &routesImpl{ + mappings: make(map[string]string), + } + + return r +} + func (r *routesImpl) RegisterAll(mappings map[string]string) { r.Lock() defer r.Unlock() @@ -120,11 +129,13 @@ func (r *routesImpl) FindBackendForServerAddress(serverAddress string) string { r.RLock() defer r.RUnlock() + addressParts := strings.Split(serverAddress, `\x00`) + if r.mappings == nil { return r.defaultRoute } else { - if route, exists := r.mappings[serverAddress]; exists { + if route, exists := r.mappings[addressParts[0]]; exists { return route } else { return r.defaultRoute diff --git a/server/routes_test.go b/server/routes_test.go new file mode 100644 index 0000000..8b7fae9 --- /dev/null +++ b/server/routes_test.go @@ -0,0 +1,53 @@ +package server + +import ( + "testing" +) + +func Test_routesImpl_FindBackendForServerAddress(t *testing.T) { + type args struct { + serverAddress string + } + type mapping struct { + serverAddress string + backend string + } + tests := []struct { + name string + mapping mapping + args args + want string + }{ + { + name: "typical", + mapping: mapping{ + serverAddress: "typical.my.domain", backend: "backend:25565", + }, + args: args{ + serverAddress: `typical.my.domain`, + }, + want: "backend:25565", + }, + { + name: "forge", + mapping: mapping{ + serverAddress: "forge.my.domain", backend: "backend:25566", + }, + args: args{ + serverAddress: `forge.my.domain\x00FML2\x00`, + }, + want: "backend:25566", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewRoutes() + + r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend) + + if got := r.FindBackendForServerAddress(tt.args.serverAddress); got != tt.want { + t.Errorf("routesImpl.FindBackendForServerAddress() = %v, want %v", got, tt.want) + } + }) + } +}