From d60ab3e953c004f528061befbe994e7fbd6fec02 Mon Sep 17 00:00:00 2001 From: "Kristian S. Stangeland" Date: Thu, 28 Feb 2013 01:39:49 +0100 Subject: [PATCH] Identify player connections by socket address. It's the only thing that will not not be removed when a network manager closes, making it relatively safe to block on. --- .../protocol/concurrency/BlockingHashMap.java | 76 ++++++++--- .../injector/player/NetLoginInjector.java | 80 +++--------- .../player/NetworkServerInjector.java | 5 +- .../injector/player/PlayerInjector.java | 31 +++-- .../player/ProxyPlayerInjectionHandler.java | 45 +++---- .../server/AbstractInputStreamLookup.java | 32 +++-- .../server/InputStreamReflectLookup.java | 119 ++++++++++-------- 7 files changed, 202 insertions(+), 186 deletions(-) diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/concurrency/BlockingHashMap.java b/ProtocolLib/src/main/java/com/comphenix/protocol/concurrency/BlockingHashMap.java index 54295cdb..19d35b51 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/concurrency/BlockingHashMap.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/concurrency/BlockingHashMap.java @@ -19,15 +19,20 @@ package com.comphenix.protocol.concurrency; import java.util.Collection; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; -import com.google.common.collect.MapMaker; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.RemovalListener; +import com.google.common.cache.RemovalNotification; /** * A map that supports blocking on read operations. Null keys are not supported. *

- * Keys are stored as weak references, and will be automatically removed once they've all been dereferenced. + * Values are stored as weak references, and will be automatically removed once they've all been dereferenced. *

* @author Kristian * @@ -35,10 +40,10 @@ import com.google.common.collect.MapMaker; * @param - type of the value. */ public class BlockingHashMap { - // Map of values + private final Cache backingCache; private final ConcurrentMap backingMap; - + // Map of locked objects private final ConcurrentMap locks; @@ -46,8 +51,24 @@ public class BlockingHashMap { * Initialize a new map. */ public BlockingHashMap() { - backingMap = new MapMaker().weakKeys().makeMap(); - locks = new MapMaker().weakKeys().makeMap(); + backingCache = CacheBuilder.newBuilder().weakValues().removalListener( + new RemovalListener() { + @Override + public void onRemoval(RemovalNotification entry) { + // Clean up locks too + locks.remove(entry.getKey()); + } + }).build( + new CacheLoader() { + @Override + public TValue load(TKey key) throws Exception { + throw new IllegalStateException("Illegal use. Access the map directly instead."); + } + }); + backingMap = backingCache.asMap(); + + // Normal concurrent hash map + locks = new ConcurrentHashMap(); } /** @@ -94,34 +115,57 @@ public class BlockingHashMap { * @throws InterruptedException If the current thread got interrupted while waiting. */ public TValue get(TKey key, long timeout, TimeUnit unit) throws InterruptedException { + return get(key, timeout, unit, false); + } + + /** + * Waits until a value has been associated with the given key, and then retrieves that value. + *

+ * If timeout is zero, this method will return immediately if it can't find an socket injector. + * + * @param key - the key whose associated value is to be returned + * @param timeout - the amount of time to wait until an association has been made. + * @param unit - unit of timeout. + * @param ignoreInterrupted - TRUE if we should ignore the thread being interrupted, FALSE otherwise. + * @return The value to which the specified key is mapped, or NULL if the timeout elapsed. + * @throws InterruptedException If the current thread got interrupted while waiting. + */ + public TValue get(TKey key, long timeout, TimeUnit unit, boolean ignoreInterrupted) throws InterruptedException { if (key == null) throw new IllegalArgumentException("key cannot be NULL."); if (unit == null) throw new IllegalArgumentException("Unit cannot be NULL."); + if (timeout < 0) + throw new IllegalArgumentException("Timeout cannot be less than zero."); TValue value = backingMap.get(key); // Only lock if no value is available - if (value == null) { + if (value == null && timeout > 0) { final Object lock = getLock(key); final long stopTimeNS = System.nanoTime() + unit.toNanos(timeout); // Don't exceed the timeout synchronized (lock) { while (value == null) { - long remainingTime = stopTimeNS - System.nanoTime(); - - if (remainingTime > 0) { - TimeUnit.NANOSECONDS.timedWait(lock, remainingTime); - value = backingMap.get(key); - } else { - // Timeout elapsed - break; + try { + long remainingTime = stopTimeNS - System.nanoTime(); + + if (remainingTime > 0) { + TimeUnit.NANOSECONDS.timedWait(lock, remainingTime); + value = backingMap.get(key); + } else { + // Timeout elapsed + break; + } + } catch (InterruptedException e) { + // This is fairly dangerous - but we might HAVE to block the thread + if (!ignoreInterrupted) + throw e; } } } } - return value; } diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetLoginInjector.java b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetLoginInjector.java index f7f669be..fc5052d6 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetLoginInjector.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetLoginInjector.java @@ -17,18 +17,14 @@ package com.comphenix.protocol.injector.player; -import java.lang.reflect.Field; -import java.net.SocketAddress; import java.util.concurrent.ConcurrentMap; +import org.bukkit.Server; import org.bukkit.entity.Player; import com.comphenix.protocol.error.ErrorReporter; import com.comphenix.protocol.injector.GamePhase; -import com.comphenix.protocol.injector.server.AbstractInputStreamLookup; -import com.comphenix.protocol.injector.server.SocketInjector; -import com.comphenix.protocol.reflect.FieldUtils; -import com.comphenix.protocol.reflect.FuzzyReflection; +import com.comphenix.protocol.injector.server.TemporaryPlayerFactory; import com.comphenix.protocol.utility.MinecraftReflection; import com.google.common.collect.Maps; @@ -39,23 +35,21 @@ import com.google.common.collect.Maps; */ class NetLoginInjector { private ConcurrentMap injectedLogins = Maps.newConcurrentMap(); - - private static Field networkManagerField; - private static Field socketAddressField; - + // Handles every hook private ProxyPlayerInjectionHandler injectionHandler; + + // Create temporary players + private TemporaryPlayerFactory playerFactory = new TemporaryPlayerFactory(); - // Associate input streams and injectors - private AbstractInputStreamLookup inputStreamLookup; - - // The current error rerporter + // The current error reporter private ErrorReporter reporter; + private Server server; - public NetLoginInjector(ErrorReporter reporter, ProxyPlayerInjectionHandler injectionHandler, AbstractInputStreamLookup inputStreamLookup) { + public NetLoginInjector(ErrorReporter reporter, Server server, ProxyPlayerInjectionHandler injectionHandler) { this.reporter = reporter; + this.server = server; this.injectionHandler = injectionHandler; - this.inputStreamLookup = inputStreamLookup; } /** @@ -69,22 +63,16 @@ class NetLoginInjector { if (!injectionHandler.isInjectionNecessary(GamePhase.LOGIN)) return inserting; - Object networkManager = getNetworkManager(inserting); - SocketAddress address = getAddress(networkManager); + Player temporary = playerFactory.createTemporaryPlayer(server); + PlayerInjector injector = injectionHandler.injectPlayer(temporary, inserting, GamePhase.LOGIN); - // Get the underlying socket - SocketInjector socketInjector = inputStreamLookup.getSocketInjector(address); - - // This is the case if we're dealing with a connection initiated by the injected server socket - if (socketInjector != null) { - PlayerInjector injector = injectionHandler.injectPlayer(socketInjector.getPlayer(), inserting, GamePhase.LOGIN); - - if (injector != null) { - injector.updateOnLogin = true; - - // Save the login - injectedLogins.putIfAbsent(inserting, injector); - } + if (injector != null) { + // Update injector as well + TemporaryPlayerFactory.setInjectorInPlayer(temporary, injector); + injector.updateOnLogin = true; + + // Save the login + injectedLogins.putIfAbsent(inserting, injector); } // NetServerInjector can never work (currently), so we don't need to replace the NetLoginHandler @@ -98,36 +86,6 @@ class NetLoginInjector { } } - /** - * Retrieve the network manager from a given pending connection. - * @param inserting - the pending connection. - * @return The referenced network manager. - * @throws IllegalAccessException If we are unable to read the network manager. - */ - private Object getNetworkManager(Object inserting) throws IllegalAccessException { - if (networkManagerField == null) { - networkManagerField = FuzzyReflection.fromObject(inserting, true). - getFieldByType("networkManager", MinecraftReflection.getNetworkManagerClass()); - } - - return FieldUtils.readField(networkManagerField, inserting, true); - } - - /** - * Retrieve the socket address stored in a network manager. - * @param networkManager - the network manager. - * @return The associated socket address. - * @throws IllegalAccessException If we are unable to read the address. - */ - private SocketAddress getAddress(Object networkManager) throws IllegalAccessException { - if (socketAddressField == null) { - socketAddressField = FuzzyReflection.fromObject(networkManager, true). - getFieldByType("socketAddress", SocketAddress.class); - } - - return (SocketAddress) FieldUtils.readField(socketAddressField, networkManager, true); - } - /** * Invoked when a NetLoginHandler should be reverted. * @param inserting - the original NetLoginHandler. diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetworkServerInjector.java b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetworkServerInjector.java index f81be360..d9a7f447 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetworkServerInjector.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/NetworkServerInjector.java @@ -180,6 +180,7 @@ class NetworkServerInjector extends PlayerInjector { return; if (!tryInjectManager()) { + Class serverHandlerClass = MinecraftReflection.getNetServerHandlerClass(); // Try to override the proxied object if (proxyServerField != null) { @@ -188,6 +189,8 @@ class NetworkServerInjector extends PlayerInjector { if (serverHandler == null) throw new RuntimeException("Cannot hook player: Inner proxy object is NULL."); + else + serverHandlerClass = serverHandler.getClass(); // Try again if (tryInjectManager()) { @@ -198,7 +201,7 @@ class NetworkServerInjector extends PlayerInjector { throw new RuntimeException( "Cannot hook player: Unable to find a valid constructor for the " - + MinecraftReflection.getNetServerHandlerClass().getName() + " object."); + + serverHandlerClass.getName() + " object."); } } diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/PlayerInjector.java b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/PlayerInjector.java index 5764ac9a..524bd40a 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/PlayerInjector.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/PlayerInjector.java @@ -60,6 +60,7 @@ abstract class PlayerInjector implements SocketInjector { protected static Field networkManagerField; protected static Field netHandlerField; protected static Field socketField; + protected static Field socketAddressField; private static Field inputField; private static Field entityPlayerField; @@ -87,8 +88,9 @@ abstract class PlayerInjector implements SocketInjector { protected Object serverHandler; protected Object netHandler; - // Current socket + // Current socket and address protected Socket socket; + protected SocketAddress socketAddress; // The packet manager and filters protected ListenerInvoker invoker; @@ -250,7 +252,8 @@ abstract class PlayerInjector implements SocketInjector { public Socket getSocket() throws IllegalAccessException { try { if (socketField == null) - socketField = FuzzyReflection.fromObject(networkManager, true).getFieldListByType(Socket.class).get(0); + socketField = FuzzyReflection.fromObject(networkManager, true). + getFieldListByType(Socket.class).get(0); if (socket == null) socket = (Socket) FieldUtils.readField(socketField, networkManager, true); return socket; @@ -261,19 +264,23 @@ abstract class PlayerInjector implements SocketInjector { } /** - * Retrieve the associated address of this player. - * @return The associated address. - * @throws IllegalAccessException If we're unable to read the socket field. + * Retrieve the associated remote address of a player. + * @return The associated remote address.. + * @throws IllegalAccessException If we're unable to read the socket address field. */ @Override public SocketAddress getAddress() throws IllegalAccessException { - Socket socket = getSocket(); - - // Guard against NULL - if (socket != null) - return socket.getRemoteSocketAddress(); - else - return null; + try { + if (socketAddressField == null) + socketAddressField = FuzzyReflection.fromObject(networkManager, true). + getFieldListByType(SocketAddress.class).get(0); + if (socketAddress == null) + socketAddress = (SocketAddress) FieldUtils.readField(socketAddressField, networkManager, true); + return socketAddress; + + } catch (IndexOutOfBoundsException e) { + throw new IllegalAccessException("Unable to read the socket address field."); + } } /** diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/ProxyPlayerInjectionHandler.java b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/ProxyPlayerInjectionHandler.java index 74bcaafd..14091c80 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/ProxyPlayerInjectionHandler.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/player/ProxyPlayerInjectionHandler.java @@ -20,9 +20,12 @@ package com.comphenix.protocol.injector.player; import java.io.DataInputStream; import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; -import java.net.Socket; +import java.net.SocketAddress; import java.util.Map; import java.util.Set; + +import net.sf.cglib.proxy.Factory; + import org.bukkit.Server; import org.bukkit.entity.Player; @@ -39,7 +42,7 @@ import com.comphenix.protocol.injector.PacketFilterManager.PlayerInjectHooks; import com.comphenix.protocol.injector.server.AbstractInputStreamLookup; import com.comphenix.protocol.injector.server.InputStreamLookupBuilder; import com.comphenix.protocol.injector.server.SocketInjector; -import com.comphenix.protocol.injector.server.TemporaryPlayerFactory; + import com.google.common.base.Predicate; import com.google.common.collect.Maps; @@ -105,7 +108,7 @@ class ProxyPlayerInjectionHandler implements PlayerInjectionHandler { build(); // Create net login injectors and the server connection injector - this.netLoginInjector = new NetLoginInjector(reporter, this, inputStreamLookup); + this.netLoginInjector = new NetLoginInjector(reporter, server, this); this.serverInjection = new InjectedServerConnection(reporter, inputStreamLookup, server, netLoginInjector); serverInjection.injectList(); } @@ -216,7 +219,7 @@ class ProxyPlayerInjectionHandler implements PlayerInjectionHandler { @Override public Player getPlayerByConnection(DataInputStream inputStream) { // Wait until the connection owner has been established - SocketInjector injector = inputStreamLookup.getSocketInjector(inputStream); + SocketInjector injector = inputStreamLookup.waitSocketInjector(inputStream); if (injector != null) { return injector.getPlayer(); @@ -309,24 +312,18 @@ class ProxyPlayerInjectionHandler implements PlayerInjectionHandler { injector.initialize(injectionPoint); // Get socket and socket injector - Socket socket = injector.getSocket(); - SocketInjector previous = null; - - // Due to a race condition, the main server "accept connections" thread may - // get a closed network manager with a NULL input stream, - if (socket == null) { - - } - + SocketAddress address = injector.getAddress(); + SocketInjector previous = inputStreamLookup.peekSocketInjector(address); + // Close any previously associated hooks before we proceed - if (previous != null && previous instanceof PlayerInjector) { + if (previous != null && !(player instanceof Factory)) { uninjectPlayer(previous.getPlayer(), true); } injector.injectManager(); // Save injector - inputStreamLookup.setSocketInjector(socket, injector); + inputStreamLookup.setSocketInjector(address, injector); break; } @@ -453,7 +450,7 @@ class ProxyPlayerInjectionHandler implements PlayerInjectionHandler { @Override public boolean uninjectPlayer(InetSocketAddress address) { if (!hasClosed && address != null) { - SocketInjector injector = inputStreamLookup.getSocketInjector(address); + SocketInjector injector = inputStreamLookup.peekSocketInjector(address); // Clean up if (injector != null) @@ -495,7 +492,6 @@ class ProxyPlayerInjectionHandler implements PlayerInjectionHandler { */ @Override public void processPacket(Player player, Object mcPacket) throws IllegalAccessException, InvocationTargetException { - PlayerInjector injector = getInjector(player); // Process the given packet, or simply give up @@ -518,16 +514,13 @@ class ProxyPlayerInjectionHandler implements PlayerInjectionHandler { if (injector == null) { // Try getting it from the player itself - SocketInjector socket = TemporaryPlayerFactory.getInjectorFromPlayer(player); - - // Only accept it if it's a player injector - if (!(socket instanceof PlayerInjector)) { - socket = inputStreamLookup.getSocketInjector(player.getAddress()); - } + SocketAddress address = player.getAddress(); + // Look that up without blocking + SocketInjector result = inputStreamLookup.peekSocketInjector(address); - // Ensure that it is a player injector - if (socket instanceof PlayerInjector) - return (PlayerInjector) socket; + // Ensure that it is non-null and a player injector + if (result instanceof PlayerInjector) + return (PlayerInjector) result; else return null; } else { diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/AbstractInputStreamLookup.java b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/AbstractInputStreamLookup.java index eef3275d..86b13605 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/AbstractInputStreamLookup.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/AbstractInputStreamLookup.java @@ -63,43 +63,40 @@ public abstract class AbstractInputStreamLookup { */ public abstract void postWorldLoaded(); - /** - * Retrieve the associated socket injector for a player. - * @param filtered - the indentifying filtered input stream. - * @return The socket injector we have associated with this player. - * @throws FieldAccessException Unable to access input stream. - */ - public SocketInjector getSocketInjector(FilterInputStream filtered) { - return getSocketInjector(getInputStream(filtered)); - } - /** * Retrieve the associated socket injector for a player. * @param input - the indentifying filtered input stream. * @return The socket injector we have associated with this player. */ - public abstract SocketInjector getSocketInjector(InputStream input); + public abstract SocketInjector waitSocketInjector(InputStream input); /** * Retrieve an injector by its socket. * @param socket - the socket. * @return The socket injector. */ - public abstract SocketInjector getSocketInjector(Socket socket); + public abstract SocketInjector waitSocketInjector(Socket socket); /** * Retrieve a injector by its address. * @param address - the address of the socket. * @return The socket injector, or NULL if not found. */ - public abstract SocketInjector getSocketInjector(SocketAddress address); + public abstract SocketInjector waitSocketInjector(SocketAddress address); /** - * Associate a given socket the provided socket injector. - * @param input - the socket to associate. + * Attempt to get a socket injector without blocking the thread. + * @param address - the address to lookup. + * @return The socket injector, or NULL if not found. + */ + public abstract SocketInjector peekSocketInjector(SocketAddress address); + + /** + * Associate a given socket address to the provided socket injector. + * @param input - the socket address to associate. * @param injector - the injector. */ - public abstract void setSocketInjector(Socket socket, SocketInjector injector); + public abstract void setSocketInjector(SocketAddress address, SocketInjector injector); /** * If a player can hold a reference to its parent injector, this method will update that reference. @@ -111,8 +108,7 @@ public abstract class AbstractInputStreamLookup { // Default implementation if (player instanceof InjectorContainer) { - InjectorContainer container = (InjectorContainer) player; - container.setInjector(current); + TemporaryPlayerFactory.setInjectorInPlayer(player, current); } } diff --git a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/InputStreamReflectLookup.java b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/InputStreamReflectLookup.java index 7aeae97a..58305d49 100644 --- a/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/InputStreamReflectLookup.java +++ b/ProtocolLib/src/main/java/com/comphenix/protocol/injector/server/InputStreamReflectLookup.java @@ -6,10 +6,11 @@ import java.lang.reflect.Field; import java.net.Socket; import java.net.SocketAddress; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; import org.bukkit.Server; -import org.bukkit.entity.Player; +import com.comphenix.protocol.concurrency.BlockingHashMap; import com.comphenix.protocol.error.ErrorReporter; import com.comphenix.protocol.reflect.FieldAccessException; import com.comphenix.protocol.reflect.FieldUtils; @@ -17,18 +18,33 @@ import com.comphenix.protocol.reflect.FuzzyReflection; import com.google.common.collect.MapMaker; class InputStreamReflectLookup extends AbstractInputStreamLookup { - // Using weak keys and values ensures that we will not hold up garbage collection - protected ConcurrentMap ownerSocket = new MapMaker().weakKeys().makeMap(); - protected ConcurrentMap addressLookup = new MapMaker().weakValues().makeMap(); - protected ConcurrentMap inputLookup = new MapMaker().weakValues().makeMap(); + // The default lookup timeout + private static final long DEFAULT_TIMEOUT = 2000; // ms + + // Using weak keys and values ensures that we will not hold up garbage collection + protected BlockingHashMap addressLookup = new BlockingHashMap(); + protected ConcurrentMap inputLookup = new MapMaker().weakValues().makeMap(); + + // The timeout + private final long injectorTimeout; - // Used to create fake players - private TemporaryPlayerFactory tempPlayerFactory = new TemporaryPlayerFactory(); - public InputStreamReflectLookup(ErrorReporter reporter, Server server) { - super(reporter, server); + this(reporter, server, DEFAULT_TIMEOUT); } + /** + * Initialize a reflect lookup with a given default injector timeout. + *

+ * This timeout defines the maximum amount of time to wait until an injector has been discovered. + * @param reporter - the error reporter. + * @param server - the current Bukkit server. + * @param injectorTimeout - the injector timeout. + */ + public InputStreamReflectLookup(ErrorReporter reporter, Server server, long injectorTimeout) { + super(reporter, server); + this.injectorTimeout = injectorTimeout; + } + @Override public void inject(Object container) { // Do nothing @@ -38,34 +54,45 @@ class InputStreamReflectLookup extends AbstractInputStreamLookup { public void postWorldLoaded() { // Nothing again } - + @Override - public SocketInjector getSocketInjector(Socket socket) { - SocketInjector result = ownerSocket.get(socket); - - if (result == null) { - Player player = tempPlayerFactory.createTemporaryPlayer(server); - SocketInjector created = new TemporarySocketInjector(player, socket); - - result = ownerSocket.putIfAbsent(socket, created); - - if (result == null) { - // We won - use our created injector - TemporaryPlayerFactory.setInjectorInPlayer(player, created); - result = created; - } + public SocketInjector peekSocketInjector(SocketAddress address) { + try { + return addressLookup.get(address, 0, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + // Whatever + return null; } - return result; } @Override - public SocketInjector getSocketInjector(InputStream input) { + public SocketInjector waitSocketInjector(SocketAddress address) { try { - Socket socket = getSocket(input); + // Note that we actually SWALLOW interrupts here - this is because Minecraft uses interrupts to + // periodically wake up waiting readers and writers. We have to wait for the dedicated server thread + // to catch up, so we'll swallow these interrupts. + // + // TODO: Consider if we should raise the thread priority of the dedicated server listener thread. + return addressLookup.get(address, injectorTimeout, TimeUnit.MILLISECONDS, true); + } catch (InterruptedException e) { + // This cannot be! + throw new IllegalStateException("Impossible exception occured!", e); + } + } + + @Override + public SocketInjector waitSocketInjector(Socket socket) { + return waitSocketInjector(socket.getRemoteSocketAddress()); + } + + @Override + public SocketInjector waitSocketInjector(InputStream input) { + try { + SocketAddress address = getSocketAddress(input); // Guard against NPE - if (socket != null) - return getSocketInjector(socket); + if (address != null) + return waitSocketInjector(address); else return null; } catch (IllegalAccessException e) { @@ -74,38 +101,36 @@ class InputStreamReflectLookup extends AbstractInputStreamLookup { } /** - * Use reflection to get the underlying socket from an input stream. + * Use reflection to get the underlying socket address from an input stream. * @param stream - the socket stream to lookup. - * @return The underlying socket, or NULL if not found. + * @return The underlying socket address, or NULL if not found. * @throws IllegalAccessException Unable to access socket field. */ - private Socket getSocket(InputStream stream) throws IllegalAccessException { + private SocketAddress getSocketAddress(InputStream stream) throws IllegalAccessException { // Extra check, just in case if (stream instanceof FilterInputStream) - return getSocket(getInputStream((FilterInputStream) stream)); + return getSocketAddress(getInputStream((FilterInputStream) stream)); - Socket result = inputLookup.get(stream); + SocketAddress result = inputLookup.get(stream); if (result == null) { - result = lookupSocket(stream); + Socket socket = lookupSocket(stream); // Save it + result = socket.getRemoteSocketAddress(); inputLookup.put(stream, result); } return result; } @Override - public void setSocketInjector(Socket socket, SocketInjector injector) { - if (socket == null) - throw new IllegalArgumentException("socket cannot be NULL"); + public void setSocketInjector(SocketAddress address, SocketInjector injector) { + if (address == null) + throw new IllegalArgumentException("address cannot be NULL"); if (injector == null) throw new IllegalArgumentException("injector cannot be NULL."); - SocketInjector previous = ownerSocket.put(socket, injector); - - // Save the address lookup too - addressLookup.put(socket.getRemoteSocketAddress(), socket); + SocketInjector previous = addressLookup.put(address, injector); // Any previous temporary players will also be associated if (previous != null) { @@ -113,16 +138,6 @@ class InputStreamReflectLookup extends AbstractInputStreamLookup { onPreviousSocketOverwritten(previous, injector); } } - - @Override - public SocketInjector getSocketInjector(SocketAddress address) { - Socket socket = addressLookup.get(address); - - if (socket != null) - return getSocketInjector(socket); - else - return null; - } @Override public void cleanupAll() {