diff --git a/common/src/main/java/com/discordsrv/common/messageforwarding/discord/DiscordMessageMirroringModule.java b/common/src/main/java/com/discordsrv/common/messageforwarding/discord/DiscordMessageMirroringModule.java index 722b600e..3505994f 100644 --- a/common/src/main/java/com/discordsrv/common/messageforwarding/discord/DiscordMessageMirroringModule.java +++ b/common/src/main/java/com/discordsrv/common/messageforwarding/discord/DiscordMessageMirroringModule.java @@ -49,7 +49,7 @@ import java.util.concurrent.TimeUnit; public class DiscordMessageMirroringModule extends AbstractModule { - private final Cache> mapping; + private final Cache> mapping; public DiscordMessageMirroringModule(DiscordSRV discordSRV) { super(discordSRV, new NamedLogger(discordSRV, "DISCORD_MIRRORING")); @@ -133,7 +133,7 @@ public class DiscordMessageMirroringModule extends AbstractModule { for (Pair> pair : messages) { references.add(getReference(pair.getKey(), pair.getValue())); } - mapping.put(getReference(message, null), references); + mapping.put(getCacheKey(message), references); }); }); } @@ -141,7 +141,7 @@ public class DiscordMessageMirroringModule extends AbstractModule { @Subscribe public void onDiscordMessageUpdate(DiscordMessageUpdateEvent event) { ReceivedDiscordMessage message = event.getMessage(); - Set references = mapping.get(getReference(message, null), k -> null); + Set references = mapping.get(getCacheKey(message), k -> null); if (references == null) { return; } @@ -163,7 +163,7 @@ public class DiscordMessageMirroringModule extends AbstractModule { @Subscribe public void onDiscordMessageDelete(DiscordMessageDeleteEvent event) { - Set references = mapping.get(getReference(event.getChannel(), event.getMessageId(), false, null), k -> null); + Set references = mapping.get(getCacheKey(event.getChannel(), event.getMessageId()), k -> null); if (references == null) { return; } @@ -230,6 +230,26 @@ public class DiscordMessageMirroringModule extends AbstractModule { throw new IllegalStateException("Unexpected channel type: " + channel.getClass().getName()); } + private static String getCacheKey(ReceivedDiscordMessage message) { + return getCacheKey(message.getChannel(), message.getId()); + } + + private static String getCacheKey(DiscordMessageChannel channel, long messageId) { + if (channel instanceof DiscordTextChannel) { + return getCacheKey(channel.getId(), 0L, messageId); + } else if (channel instanceof DiscordThreadChannel) { + long parentId = ((DiscordThreadChannel) channel).getParentChannel().getId(); + return getCacheKey(channel.getId(), parentId, messageId); + } + throw new IllegalStateException("Unexpected channel type: " + channel.getClass().getName()); + } + + private static String getCacheKey(long channelId, long threadId, long messageId) { + return Long.toUnsignedString(channelId) + + (threadId > 0 ? ":" + Long.toUnsignedString(threadId) : "") + + ":" + Long.toUnsignedString(messageId); + } + public static class MessageReference { private final long channelId; @@ -285,20 +305,5 @@ public class DiscordMessageMirroringModule extends AbstractModule { } return null; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - MessageReference that = (MessageReference) o; - // Intentionally ignores webhookMessage - return channelId == that.channelId && threadId == that.threadId && messageId == that.messageId; - } - - @Override - public int hashCode() { - // Intentionally ignores webhookMessage - return Objects.hash(channelId, threadId, messageId); - } } }