From ed82f642f19e9f079285a31f5acea6b9b7dd8c06 Mon Sep 17 00:00:00 2001 From: Vankka Date: Sat, 30 Nov 2024 15:21:44 +0200 Subject: [PATCH] Fix for non-text channels in ChannelConfigHelper, resolution by full channel atom. Fix forwarding messages from non-text, non-thread channels to Minecraft. Add test for voice channel messages --- .../DiscordChatMessageProcessEvent.java | 5 ++ .../DiscordChatMessageReceiveEvent.java | 19 ++---- .../commands/subcommand/BroadcastCommand.java | 2 +- .../discord/DiscordChatMessageModule.java | 15 ++++- .../common/helper/ChannelConfigHelper.java | 60 ++++++++++++------- .../discordsrv/common/FullBootExtension.java | 8 ++- .../com/discordsrv/common/MockDiscordSRV.java | 10 ++-- .../MinecraftToDiscordChatMessageTest.java | 6 +- 8 files changed, 78 insertions(+), 47 deletions(-) diff --git a/api/src/main/java/com/discordsrv/api/events/message/process/discord/DiscordChatMessageProcessEvent.java b/api/src/main/java/com/discordsrv/api/events/message/process/discord/DiscordChatMessageProcessEvent.java index cee04d4a..ff5a8cf2 100644 --- a/api/src/main/java/com/discordsrv/api/events/message/process/discord/DiscordChatMessageProcessEvent.java +++ b/api/src/main/java/com/discordsrv/api/events/message/process/discord/DiscordChatMessageProcessEvent.java @@ -24,6 +24,7 @@ package com.discordsrv.api.events.message.process.discord; import com.discordsrv.api.channel.GameChannel; +import com.discordsrv.api.discord.entity.channel.DiscordMessageChannel; import com.discordsrv.api.discord.entity.message.ReceivedDiscordMessage; import com.discordsrv.api.events.Cancellable; import com.discordsrv.api.events.Processable; @@ -51,6 +52,10 @@ public class DiscordChatMessageProcessEvent implements Cancellable, Processable. this.destinationChannel = destinationChannel; } + public DiscordMessageChannel getChannel() { + return message.getChannel(); + } + public ReceivedDiscordMessage getMessage() { return message; } diff --git a/api/src/main/java/com/discordsrv/api/events/message/receive/discord/DiscordChatMessageReceiveEvent.java b/api/src/main/java/com/discordsrv/api/events/message/receive/discord/DiscordChatMessageReceiveEvent.java index 3a2c3e9c..cf355a61 100644 --- a/api/src/main/java/com/discordsrv/api/events/message/receive/discord/DiscordChatMessageReceiveEvent.java +++ b/api/src/main/java/com/discordsrv/api/events/message/receive/discord/DiscordChatMessageReceiveEvent.java @@ -23,9 +23,7 @@ package com.discordsrv.api.events.message.receive.discord; -import com.discordsrv.api.discord.entity.channel.DiscordMessageChannel; -import com.discordsrv.api.discord.entity.channel.DiscordTextChannel; -import com.discordsrv.api.discord.entity.channel.DiscordThreadChannel; +import com.discordsrv.api.discord.entity.channel.*; import com.discordsrv.api.discord.entity.guild.DiscordGuild; import com.discordsrv.api.discord.entity.message.ReceivedDiscordMessage; import com.discordsrv.api.events.Cancellable; @@ -39,15 +37,12 @@ import org.jetbrains.annotations.NotNull; public class DiscordChatMessageReceiveEvent implements Cancellable { private final ReceivedDiscordMessage message; - private final DiscordMessageChannel channel; + private final DiscordGuildMessageChannel channel; private boolean cancelled; - public DiscordChatMessageReceiveEvent(@NotNull ReceivedDiscordMessage discordMessage, @NotNull DiscordMessageChannel channel) { + public DiscordChatMessageReceiveEvent(@NotNull ReceivedDiscordMessage discordMessage, @NotNull DiscordGuildMessageChannel channel) { this.message = discordMessage; this.channel = channel; - if (!(channel instanceof DiscordTextChannel) && !(channel instanceof DiscordThreadChannel)) { - throw new IllegalStateException("Cannot process messages that aren't from a text channel or thread"); - } } public ReceivedDiscordMessage getMessage() { @@ -59,13 +54,7 @@ public class DiscordChatMessageReceiveEvent implements Cancellable { } public DiscordGuild getGuild() { - if (channel instanceof DiscordTextChannel) { - return ((DiscordTextChannel) channel).getGuild(); - } else if (channel instanceof DiscordThreadChannel) { - return ((DiscordThreadChannel) channel).getParentChannel().getGuild(); - } else { - throw new IllegalStateException("Message isn't from a text channel or thread"); - } + return channel.getGuild(); } @Override diff --git a/common/src/main/java/com/discordsrv/common/command/game/commands/subcommand/BroadcastCommand.java b/common/src/main/java/com/discordsrv/common/command/game/commands/subcommand/BroadcastCommand.java index b6dca127..15d8dc39 100644 --- a/common/src/main/java/com/discordsrv/common/command/game/commands/subcommand/BroadcastCommand.java +++ b/common/src/main/java/com/discordsrv/common/command/game/commands/subcommand/BroadcastCommand.java @@ -112,7 +112,7 @@ public abstract class BroadcastCommand implements GameCommandExecutor, GameComma channels.add(messageChannel); } } catch (IllegalArgumentException ignored) { - BaseChannelConfig channelConfig = discordSRV.channelConfig().resolve(null, channel); + BaseChannelConfig channelConfig = discordSRV.channelConfig().resolve(channel); CC config = channelConfig != null ? (CC) channelConfig : null; if (config != null) { diff --git a/common/src/main/java/com/discordsrv/common/feature/messageforwarding/discord/DiscordChatMessageModule.java b/common/src/main/java/com/discordsrv/common/feature/messageforwarding/discord/DiscordChatMessageModule.java index cf4bb33c..01ebaf76 100644 --- a/common/src/main/java/com/discordsrv/common/feature/messageforwarding/discord/DiscordChatMessageModule.java +++ b/common/src/main/java/com/discordsrv/common/feature/messageforwarding/discord/DiscordChatMessageModule.java @@ -22,6 +22,8 @@ import com.discordsrv.api.channel.GameChannel; import com.discordsrv.api.component.MinecraftComponent; import com.discordsrv.api.discord.connection.details.DiscordGatewayIntent; import com.discordsrv.api.discord.entity.DiscordUser; +import com.discordsrv.api.discord.entity.channel.DiscordChannel; +import com.discordsrv.api.discord.entity.channel.DiscordGuildMessageChannel; import com.discordsrv.api.discord.entity.channel.DiscordMessageChannel; import com.discordsrv.api.discord.entity.guild.DiscordGuild; import com.discordsrv.api.discord.entity.guild.DiscordGuildMember; @@ -95,12 +97,19 @@ public class DiscordChatMessageModule extends AbstractModule { @Subscribe public void onDiscordMessageReceived(DiscordMessageReceiveEvent event) { - if (!discordSRV.isReady() || event.getMessage().isFromSelf() - || !(event.getTextChannel() != null || event.getThreadChannel() != null)) { + if (!discordSRV.isReady() || event.getMessage().isFromSelf()) { return; } - discordSRV.eventBus().publish(new DiscordChatMessageReceiveEvent(event.getMessage(), event.getChannel())); + DiscordChannel channel = event.getChannel(); + if (!(channel instanceof DiscordGuildMessageChannel)) { + return; + } + + discordSRV.eventBus().publish(new DiscordChatMessageReceiveEvent( + event.getMessage(), + (DiscordGuildMessageChannel) channel + )); } @Subscribe diff --git a/common/src/main/java/com/discordsrv/common/helper/ChannelConfigHelper.java b/common/src/main/java/com/discordsrv/common/helper/ChannelConfigHelper.java index caa27604..6e56d5ff 100644 --- a/common/src/main/java/com/discordsrv/common/helper/ChannelConfigHelper.java +++ b/common/src/main/java/com/discordsrv/common/helper/ChannelConfigHelper.java @@ -20,7 +20,6 @@ package com.discordsrv.common.helper; import com.discordsrv.api.channel.GameChannel; import com.discordsrv.api.discord.entity.channel.DiscordMessageChannel; -import com.discordsrv.api.discord.entity.channel.DiscordTextChannel; import com.discordsrv.api.discord.entity.channel.DiscordThreadChannel; import com.discordsrv.api.events.channel.GameChannelLookupEvent; import com.discordsrv.common.DiscordSRV; @@ -30,6 +29,8 @@ import com.discordsrv.common.config.main.channels.base.ChannelConfig; import com.discordsrv.common.config.main.channels.base.IChannelConfig; import com.discordsrv.common.config.main.generic.DestinationConfig; import com.discordsrv.common.config.main.generic.ThreadConfig; +import com.discordsrv.common.core.logging.Logger; +import com.discordsrv.common.core.logging.NamedLogger; import com.github.benmanes.caffeine.cache.CacheLoader; import com.github.benmanes.caffeine.cache.LoadingCache; import org.apache.commons.lang3.tuple.Pair; @@ -45,6 +46,7 @@ import java.util.concurrent.TimeUnit; public class ChannelConfigHelper { private final DiscordSRV discordSRV; + private final Logger logger; // game channel name eg. "global" -> game channel ("discordsrv:global") private final LoadingCache nameToChannelCache; @@ -53,11 +55,13 @@ public class ChannelConfigHelper { private final Map configs; // caches for Discord channel -> config - private final Map> textChannelToConfigMap; + private final Map> messageChannelToConfigMap; private final Map, Map> threadToConfigMap; public ChannelConfigHelper(DiscordSRV discordSRV) { this.discordSRV = discordSRV; + this.logger = new NamedLogger(discordSRV, "CHANNEL_CONFIG_HELPER"); + this.nameToChannelCache = discordSRV.caffeineBuilder() .expireAfterWrite(60, TimeUnit.SECONDS) .expireAfterAccess(30, TimeUnit.SECONDS) @@ -65,21 +69,33 @@ public class ChannelConfigHelper { .build(new CacheLoader() { @Override - public @Nullable GameChannel load(@NotNull String channelName) { - GameChannelLookupEvent event = new GameChannelLookupEvent(null, channelName); + public @Nullable GameChannel load(@NotNull String channelAtom) { + Pair channelPair = parseOwnerAndChannel(channelAtom); + + GameChannelLookupEvent event = new GameChannelLookupEvent(channelPair.getKey(), channelPair.getValue()); discordSRV.eventBus().publish(event); if (!event.isProcessed()) { return null; } - return event.getChannelFromProcessing(); + GameChannel channel = event.getChannelFromProcessing(); + logger.trace(channelAtom + " looked up to " + GameChannel.toString(channel)); + return channel; } }); this.configs = new HashMap<>(); - this.textChannelToConfigMap = new HashMap<>(); + this.messageChannelToConfigMap = new HashMap<>(); this.threadToConfigMap = new LinkedHashMap<>(); } + private Pair parseOwnerAndChannel(String channelAtom) { + String[] split = channelAtom.split(":", 2); + String channelName = split[split.length - 1]; + String ownerName = split.length == 2 ? split[0] : null; + + return Pair.of(ownerName, channelName); + } + @SuppressWarnings("unchecked") private BaseChannelConfig map(BaseChannelConfig defaultConfig, BaseChannelConfig config) throws SerializationException { @@ -120,7 +136,7 @@ public class ChannelConfigHelper { this.configs.putAll(configs); } - Map> text = new HashMap<>(); + Map> messageChannel = new HashMap<>(); Map, Map> thread = new HashMap<>(); for (Map.Entry entry : channels().entrySet()) { @@ -132,7 +148,7 @@ public class ChannelConfigHelper { List channelIds = destination.channelIds; if (channelIds != null) { for (long channelId : channelIds) { - text.computeIfAbsent(channelId, key -> new LinkedHashMap<>()) + messageChannel.computeIfAbsent(channelId, key -> new LinkedHashMap<>()) .put(channelName, value); } } @@ -151,9 +167,9 @@ public class ChannelConfigHelper { } } - synchronized (textChannelToConfigMap) { - textChannelToConfigMap.clear(); - textChannelToConfigMap.putAll(text); + synchronized (messageChannelToConfigMap) { + messageChannelToConfigMap.clear(); + messageChannelToConfigMap.putAll(messageChannel); } synchronized (threadToConfigMap) { threadToConfigMap.clear(); @@ -204,6 +220,12 @@ public class ChannelConfigHelper { return resolve(gameChannel.getOwnerName(), gameChannel.getChannelName()); } + @Nullable + public BaseChannelConfig resolve(@NotNull String channelAtom) { + Pair channelPair = parseOwnerAndChannel(channelAtom); + return resolve(channelPair.getKey(), channelPair.getValue()); + } + @Nullable public BaseChannelConfig resolve(@Nullable String ownerName, @NotNull String channelName) { if (ownerName != null) { @@ -216,6 +238,7 @@ public class ChannelConfigHelper { } // Check if this owner has the highest priority for this channel name + // in case they are, we can also use "channel" config directly GameChannel gameChannel = nameToChannelCache.get(channelName); if (gameChannel != null && gameChannel.getOwnerName().equalsIgnoreCase(ownerName)) { config = findChannel(channelName); @@ -249,19 +272,16 @@ public class ChannelConfigHelper { } private Map get(DiscordMessageChannel channel) { - Map pairs = null; - if (channel instanceof DiscordTextChannel) { - pairs = getByTextChannel((DiscordTextChannel) channel); - } else if (channel instanceof DiscordThreadChannel) { - pairs = getByThreadChannel((DiscordThreadChannel) channel); + if (channel instanceof DiscordThreadChannel) { + return getByThreadChannel((DiscordThreadChannel) channel); } - return pairs; + return getByMessageChannel(channel); } - private Map getByTextChannel(DiscordTextChannel channel) { - synchronized (textChannelToConfigMap) { - return textChannelToConfigMap.get(channel.getId()); + private Map getByMessageChannel(DiscordMessageChannel channel) { + synchronized (messageChannelToConfigMap) { + return messageChannelToConfigMap.get(channel.getId()); } } diff --git a/common/src/test/java/com/discordsrv/common/FullBootExtension.java b/common/src/test/java/com/discordsrv/common/FullBootExtension.java index 09a9e979..84063f03 100644 --- a/common/src/test/java/com/discordsrv/common/FullBootExtension.java +++ b/common/src/test/java/com/discordsrv/common/FullBootExtension.java @@ -26,16 +26,18 @@ import org.junit.jupiter.api.extension.ExtensionContext; public class FullBootExtension implements BeforeAllCallback, ExtensionContext.Store.CloseableResource { public static String BOT_TOKEN = System.getenv("DISCORDSRV_AUTOTEST_BOT_TOKEN"); - public static String TEST_CHANNEL_ID = System.getenv("DISCORDSRV_AUTOTEST_CHANNEL_ID"); + public static String TEXT_CHANNEL_ID = System.getenv("DISCORDSRV_AUTOTEST_CHANNEL_ID"); public static String FORUM_CHANNEL_ID = System.getenv("DISCORDSRV_AUTOTEST_FORUM_ID"); + public static String VOICE_CHANNEL_ID = System.getenv("DISCORDSRV_AUTOTEST_VOICE_ID"); public boolean started = false; @Override public void beforeAll(ExtensionContext context) { Assumptions.assumeTrue(BOT_TOKEN != null, "Automated testing bot token"); - Assumptions.assumeTrue(TEST_CHANNEL_ID != null, "Automated testing channel id"); - Assumptions.assumeTrue(FORUM_CHANNEL_ID != null, "Automated testing forum id"); + Assumptions.assumeTrue(TEXT_CHANNEL_ID != null, "Automated testing text channel id"); + Assumptions.assumeTrue(FORUM_CHANNEL_ID != null, "Automated testing forum channel id"); + Assumptions.assumeTrue(VOICE_CHANNEL_ID != null, "Automated testing voice channel id"); if (started) return; started = true; diff --git a/common/src/test/java/com/discordsrv/common/MockDiscordSRV.java b/common/src/test/java/com/discordsrv/common/MockDiscordSRV.java index 9cc0ff48..e93e7cd2 100644 --- a/common/src/test/java/com/discordsrv/common/MockDiscordSRV.java +++ b/common/src/test/java/com/discordsrv/common/MockDiscordSRV.java @@ -241,22 +241,24 @@ public class MockDiscordSRV extends AbstractDiscordSRV channelIds = destination.channelIds; channelIds.clear(); - channelIds.add(channelId); + channelIds.add(textChannelId); + channelIds.add(voiceChannelId); List threadConfigs = destination.threads; threadConfigs.clear(); ThreadConfig thread = new ThreadConfig(); - thread.channelId = channelId; + thread.channelId = textChannelId; threadConfigs.add(thread); ThreadConfig forumThread = new ThreadConfig(); diff --git a/common/src/test/java/com/discordsrv/common/messageforwarding/game/MinecraftToDiscordChatMessageTest.java b/common/src/test/java/com/discordsrv/common/messageforwarding/game/MinecraftToDiscordChatMessageTest.java index 099fa156..99a45efe 100644 --- a/common/src/test/java/com/discordsrv/common/messageforwarding/game/MinecraftToDiscordChatMessageTest.java +++ b/common/src/test/java/com/discordsrv/common/messageforwarding/game/MinecraftToDiscordChatMessageTest.java @@ -21,6 +21,7 @@ package com.discordsrv.common.messageforwarding.game; import com.discordsrv.api.discord.entity.channel.DiscordMessageChannel; import com.discordsrv.api.discord.entity.channel.DiscordTextChannel; import com.discordsrv.api.discord.entity.channel.DiscordThreadChannel; +import com.discordsrv.api.discord.entity.channel.DiscordVoiceChannel; import com.discordsrv.api.discord.entity.message.ReceivedDiscordMessage; import com.discordsrv.api.eventbus.EventBus; import com.discordsrv.api.eventbus.Subscribe; @@ -163,6 +164,7 @@ public class MinecraftToDiscordChatMessageTest { @Subscribe public void onForwarded(GameChatMessageForwardedEvent event) { int text = 0; + int voice = 0; int thread = 0; for (ReceivedDiscordMessage message : event.getDiscordMessage().getMessages()) { String content = message.getContent(); @@ -170,13 +172,15 @@ public class MinecraftToDiscordChatMessageTest { DiscordMessageChannel channel = message.getChannel(); if (channel instanceof DiscordTextChannel) { text++; + } else if (channel instanceof DiscordVoiceChannel) { + voice++; } else if (channel instanceof DiscordThreadChannel) { thread++; } } } - success.complete(text == 1 && thread == 2); + success.complete(text == 1 && voice == 1 && thread == 2); } } }