Add permission checks to destination lookup

This commit is contained in:
Vankka 2024-06-21 13:50:41 +03:00
parent 6e7552b347
commit 73bd92fc86
No known key found for this signature in database
GPG Key ID: 62E48025ED4E7EBB
4 changed files with 53 additions and 8 deletions

View File

@ -23,6 +23,8 @@
package com.discordsrv.api.discord.entity.channel; package com.discordsrv.api.discord.entity.channel;
import com.discordsrv.api.DiscordSRVApi;
import net.dv8tion.jda.api.entities.channel.attribute.IThreadContainer;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import java.util.List; import java.util.List;
@ -42,4 +44,11 @@ public interface DiscordThreadContainer extends DiscordGuildChannel {
CompletableFuture<DiscordThreadChannel> createThread(String name, boolean privateThread); CompletableFuture<DiscordThreadChannel> createThread(String name, boolean privateThread);
CompletableFuture<DiscordThreadChannel> createThread(String name, long messageId); CompletableFuture<DiscordThreadChannel> createThread(String name, long messageId);
/**
* Returns the JDA representation of this object. This should not be used if it can be avoided.
* @return the JDA representation of this object
* @see DiscordSRVApi#jda()
*/
IThreadContainer getAsJDAThreadContainer();
} }

View File

@ -7,6 +7,9 @@ import com.discordsrv.common.config.main.generic.DestinationConfig;
import com.discordsrv.common.config.main.generic.ThreadConfig; import com.discordsrv.common.config.main.generic.ThreadConfig;
import com.discordsrv.common.logging.Logger; import com.discordsrv.common.logging.Logger;
import com.discordsrv.common.logging.NamedLogger; import com.discordsrv.common.logging.NamedLogger;
import net.dv8tion.jda.api.Permission;
import net.dv8tion.jda.api.entities.channel.attribute.IThreadContainer;
import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.*; import java.util.*;
@ -91,7 +94,7 @@ public class DestinationLookupHelper {
future = createThread(threadContainer, threadConfig, logFailures); future = createThread(threadContainer, threadConfig, logFailures);
} else if (existingThread != null) { } else if (existingThread != null) {
// Unarchive existing thread // Unarchive existing thread
future = unarchiveThread(existingThread); future = unarchiveThread(existingThread, logFailures);
} else { } else {
// Lookup threads // Lookup threads
CompletableFuture<List<DiscordThreadChannel>> threads = CompletableFuture<List<DiscordThreadChannel>> threads =
@ -103,7 +106,7 @@ public class DestinationLookupHelper {
DiscordThreadChannel archivedThread = findThread(archivedThreads, threadConfig); DiscordThreadChannel archivedThread = findThread(archivedThreads, threadConfig);
if (archivedThread != null) { if (archivedThread != null) {
// Unarchive existing thread // Unarchive existing thread
return unarchiveThread(archivedThread); return unarchiveThread(archivedThread, logFailures);
} }
// Create thread // Create thread
@ -157,14 +160,27 @@ public class DestinationLookupHelper {
ThreadConfig threadConfig, ThreadConfig threadConfig,
boolean logFailures boolean logFailures
) { ) {
boolean forum = threadContainer instanceof DiscordForumChannel;
boolean privateThread = !forum && threadConfig.privateThread;
IThreadContainer container = threadContainer.getAsJDAThreadContainer();
if (!container.getGuild().getSelfMember().hasPermission(container, privateThread ? Permission.CREATE_PRIVATE_THREADS : Permission.CREATE_PUBLIC_THREADS)) {
if (logFailures) {
logger.error("Failed to create thread \"" + threadConfig.threadName + "\" "
+ "in channel ID " + Long.toUnsignedString(threadContainer.getId())
+ ": lacking \"Create " + (privateThread ? "Private" : "Public") + " Threads\" permission");
}
return CompletableFuture.completedFuture(null);
}
CompletableFuture<DiscordThreadChannel> future; CompletableFuture<DiscordThreadChannel> future;
if (threadContainer instanceof DiscordForumChannel) { if (forum) {
future = ((DiscordForumChannel) threadContainer).createPost( future = ((DiscordForumChannel) threadContainer).createPost(
threadConfig.threadName, threadConfig.threadName,
SendableDiscordMessage.builder().setContent("\u200B").build() // zero-width-space SendableDiscordMessage.builder().setContent("\u200B").build() // zero-width-space
); );
} else { } else {
future = threadContainer.createThread(threadConfig.threadName, threadConfig.privateThread); future = threadContainer.createThread(threadConfig.threadName, privateThread);
} }
return future.exceptionally(t -> { return future.exceptionally(t -> {
if (logFailures) { if (logFailures) {
@ -175,15 +191,27 @@ public class DestinationLookupHelper {
}); });
} }
private CompletableFuture<DiscordThreadChannel> unarchiveThread(DiscordThreadChannel channel) { private CompletableFuture<DiscordThreadChannel> unarchiveThread(DiscordThreadChannel channel, boolean logFailures) {
ThreadChannel jdaChannel = channel.asJDA();
if ((jdaChannel.isLocked() || !jdaChannel.isOwner()) && !jdaChannel.getGuild().getSelfMember().hasPermission(jdaChannel, Permission.MANAGE_THREADS)) {
if (logFailures) {
logger.error("Cannot unarchive thread \"" + channel.getName() + "\" "
+ "in channel ID " + Long.toUnsignedString(channel.getParentChannel().getId())
+ ": lacking \"Manage Threads\" permission");
}
return CompletableFuture.completedFuture(null);
}
return discordSRV.discordAPI().mapExceptions( return discordSRV.discordAPI().mapExceptions(
channel.asJDA().getManager() channel.asJDA().getManager()
.setArchived(false) .setArchived(false)
.reason("DiscordSRV destination lookup") .reason("DiscordSRV destination lookup")
.submit() .submit()
).thenApply(v -> channel).exceptionally(t -> { ).thenApply(v -> channel).exceptionally(t -> {
logger.error("Failed to unarchive thread \"" + channel.getName() + "\" " if (logFailures) {
+ "in channel ID " + Long.toUnsignedString(channel.getParentChannel().getId()), t); logger.error("Failed to unarchive thread \"" + channel.getName() + "\" "
+ "in channel ID " + Long.toUnsignedString(channel.getParentChannel().getId()), t);
}
return null; return null;
}); });
} }

View File

@ -100,4 +100,8 @@ public abstract class AbstractDiscordThreadedGuildMessageChannel<T extends Guild
}); });
} }
@Override
public IThreadContainer getAsJDAThreadContainer() {
return channel;
}
} }

View File

@ -12,7 +12,6 @@ import net.dv8tion.jda.api.entities.channel.concrete.ForumChannel;
import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel; import net.dv8tion.jda.api.entities.channel.concrete.ThreadChannel;
import net.dv8tion.jda.api.entities.channel.forums.ForumPost; import net.dv8tion.jda.api.entities.channel.forums.ForumPost;
import net.dv8tion.jda.api.requests.restaction.AbstractThreadCreateAction; import net.dv8tion.jda.api.requests.restaction.AbstractThreadCreateAction;
import net.dv8tion.jda.api.requests.restaction.ThreadChannelAction;
import net.dv8tion.jda.api.requests.restaction.pagination.ThreadChannelPaginationAction; import net.dv8tion.jda.api.requests.restaction.pagination.ThreadChannelPaginationAction;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@ -107,6 +106,11 @@ public class DiscordForumChannelImpl implements DiscordForumChannel {
return thread(channel -> channel.createThreadChannel(name, messageId), result -> result); return thread(channel -> channel.createThreadChannel(name, messageId), result -> result);
} }
@Override
public IThreadContainer getAsJDAThreadContainer() {
return channel;
}
@Override @Override
public CompletableFuture<DiscordThreadChannel> createPost(String name, SendableDiscordMessage message) { public CompletableFuture<DiscordThreadChannel> createPost(String name, SendableDiscordMessage message) {
return thread( return thread(