Required linking: negation to additional requirements, state changes to trigger rechecks. Module manager logic changes

This commit is contained in:
Vankka 2024-06-29 02:52:26 +03:00
parent 6c88174e64
commit 5ee6a9365a
No known key found for this signature in database
GPG Key ID: 62E48025ED4E7EBB
23 changed files with 723 additions and 325 deletions

View File

@ -41,6 +41,14 @@ import java.util.function.Consumer;
public interface Module {
/**
* Determined if this {@link Module} can be enabled before {@link DiscordSRVApi#isReady()}.
* @return {@code true} to allow this {@link Module} to be enabled before DiscordSRV is ready
*/
default boolean canEnableBeforeReady() {
return false;
}
/**
* Determines if this {@link Module} should be enabled at the instant this method is called, this will be used
* to determine when modules should be enabled or disabled when DiscordSRV enabled, disables and reloads.

View File

@ -21,6 +21,7 @@ package com.discordsrv.bukkit.requiredlinking;
import com.discordsrv.bukkit.BukkitDiscordSRV;
import com.discordsrv.bukkit.config.main.BukkitRequiredLinkingConfig;
import com.discordsrv.common.linking.requirelinking.ServerRequireLinkingModule;
import com.discordsrv.common.player.IPlayer;
import org.bukkit.event.Listener;
public class BukkitRequiredLinkingModule extends ServerRequireLinkingModule<BukkitDiscordSRV> implements Listener {
@ -33,4 +34,13 @@ public class BukkitRequiredLinkingModule extends ServerRequireLinkingModule<Bukk
public BukkitRequiredLinkingConfig config() {
return discordSRV.config().requiredLinking;
}
@Override
public void recheck(IPlayer player) {
getBlockReason(player.uniqueId(), player.username(), false).whenComplete((component, throwable) -> {
if (component != null) {
// TODO: handle
}
});
}
}

View File

@ -667,6 +667,12 @@ public abstract class AbstractDiscordSRV<
}
}
List<ReloadResult> results = new ArrayList<>();
// Reload any modules that can be enabled before DiscordSRV is ready
if (initial) {
results.addAll(moduleManager().reload());
}
// Update check
UpdateConfig updateConfig = connectionConfig().update;
if (updateConfig.security.enabled) {
@ -788,8 +794,8 @@ public abstract class AbstractDiscordSRV<
}
}
List<ReloadResult> results = new ArrayList<>();
if (flags.contains(ReloadFlag.MODULES)) {
// Modules are reloaded upon DiscordSRV being ready, thus not needed at initial
if (!initial && flags.contains(ReloadFlag.MODULES)) {
results.addAll(moduleManager.reload());
}

View File

@ -82,7 +82,7 @@ public abstract class ServerDiscordSRV<
@OverridingMethodsMustInvokeSuper
protected void serverStarted() {
serverStarted = true;
moduleManager().reload();
moduleManager().enableModules();
startedMessage();
}

View File

@ -18,6 +18,7 @@
package com.discordsrv.common.config.main.linking;
import com.discordsrv.common.config.configurate.annotation.Constants;
import com.discordsrv.common.config.configurate.annotation.DefaultOnly;
import com.discordsrv.common.config.connection.ConnectionConfig;
import org.spongepowered.configurate.objectmapping.ConfigSerializable;
@ -45,7 +46,7 @@ public class RequirementsConfig {
+ "DiscordBoosting(Server ID)\n"
+ "DiscordRole(Role ID)\n"
+ "\n"
+ "The following are available if you're using MinecraftAuth.me for linked accounts and a MinecraftAuth.me token is specified in the " + ConnectionConfig.FILE_NAME + ":\n"
+ "The following are available if you're using MinecraftAuth.me for linked accounts and a MinecraftAuth.me token is specified in the %1:\n"
+ "PatreonSubscriber() or PatreonSubscriber(Tier Title)\n"
+ "GlimpseSubscriber() or GlimpseSubscriber(Level Name)\n"
+ "TwitchFollower()\n"
@ -58,5 +59,6 @@ public class RequirementsConfig {
+ "|| = or, for example \"DiscordBoosting(...) || YouTubeMember()\"\n"
+ "You can also use brackets () to clear ambiguity, for example: \"DiscordServer(...) && (TwitchSubscriber() || PatreonSubscriber())\"\n"
+ "allows a member of the specified Discord server that is also a twitch or patreon subscriber to join the server")
public List<String> requirements = new ArrayList<>();
@Constants.Comment({ConnectionConfig.FILE_NAME})
public List<String> additionalRequirements = new ArrayList<>();
}

View File

@ -26,7 +26,7 @@ import com.discordsrv.common.linking.LinkProvider;
import com.discordsrv.common.linking.LinkStore;
import com.discordsrv.common.linking.LinkingModule;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.MinecraftAuthRequirement;
import com.discordsrv.common.linking.requirelinking.requirement.type.MinecraftAuthRequirementType;
import com.discordsrv.common.logging.Logger;
import com.discordsrv.common.logging.NamedLogger;
import com.discordsrv.common.player.IPlayer;
@ -115,8 +115,8 @@ public class MinecraftAuthenticationLinker extends CachedLinkProvider implements
StringBuilder additionalParam = new StringBuilder();
RequiredLinkingModule<?> requiredLinkingModule = discordSRV.getModule(RequiredLinkingModule.class);
if (requiredLinkingModule != null && requiredLinkingModule.isEnabled()) {
for (MinecraftAuthRequirement.Type requirementType : requiredLinkingModule.getActiveRequirementTypes()) {
additionalParam.append(requirementType.character());
for (MinecraftAuthRequirementType.Provider requirementProvider : requiredLinkingModule.getActiveMinecraftAuthProviders()) {
additionalParam.append(requirementProvider.character());
}
}
@ -146,7 +146,7 @@ public class MinecraftAuthenticationLinker extends CachedLinkProvider implements
private void unlinked(UUID playerUUID, long userId) {
logger.debug("Unlink: " + playerUUID + " & " + Long.toUnsignedString(userId));
linkStore.createLink(playerUUID, userId).whenComplete((v, t) -> {
linkStore.removeLink(playerUUID, userId).whenComplete((v, t) -> {
if (t != null) {
logger.error("Failed to unlink player in persistent storage", t);
return;

View File

@ -19,41 +19,64 @@
package com.discordsrv.common.linking.requirelinking;
import com.discordsrv.api.DiscordSRVApi;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.api.event.events.linking.AccountUnlinkedEvent;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.component.util.ComponentUtil;
import com.discordsrv.common.config.main.linking.RequiredLinkingConfig;
import com.discordsrv.common.config.main.linking.RequirementsConfig;
import com.discordsrv.common.future.util.CompletableFutureUtil;
import com.discordsrv.common.linking.LinkProvider;
import com.discordsrv.common.linking.impl.MinecraftAuthenticationLinker;
import com.discordsrv.common.linking.requirelinking.requirement.*;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.parser.ParsedRequirements;
import com.discordsrv.common.linking.requirelinking.requirement.parser.RequirementParser;
import com.discordsrv.common.linking.requirelinking.requirement.type.DiscordBoostingRequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.type.DiscordRoleRequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.type.DiscordServerRequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.type.MinecraftAuthRequirementType;
import com.discordsrv.common.module.type.AbstractModule;
import com.discordsrv.common.player.IPlayer;
import com.discordsrv.common.scheduler.Scheduler;
import com.discordsrv.common.scheduler.executor.DynamicCachingThreadPoolExecutor;
import com.discordsrv.common.scheduler.threadfactory.CountingThreadFactory;
import com.discordsrv.common.someone.Someone;
import net.kyori.adventure.text.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Consumer;
public abstract class RequiredLinkingModule<T extends DiscordSRV> extends AbstractModule<T> {
private final List<Requirement<?>> availableRequirements = new ArrayList<>();
protected final List<MinecraftAuthRequirement.Type> activeRequirementTypes = new ArrayList<>();
private final List<RequirementType<?>> availableRequirementTypes = new ArrayList<>();
private ThreadPoolExecutor executor;
public RequiredLinkingModule(T discordSRV) {
super(discordSRV);
}
public DiscordSRV discordSRV() {
return discordSRV;
}
public abstract RequiredLinkingConfig config();
@Override
public boolean canEnableBeforeReady() {
return true;
}
@Override
public boolean isEnabled() {
return config().enabled;
return discordSRV.config() == null || config().enabled;
}
@Override
@ -78,54 +101,165 @@ public abstract class RequiredLinkingModule<T extends DiscordSRV> extends Abstra
}
@Override
public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) {
List<Requirement<?>> requirements = new ArrayList<>();
public final void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) {
List<RequirementType<?>> requirementTypes = new ArrayList<>();
requirements.add(new DiscordRoleRequirement(discordSRV));
requirements.add(new DiscordServerRequirement(discordSRV));
requirements.add(new DiscordBoostingRequirement(discordSRV));
requirementTypes.add(new DiscordRoleRequirementType(this));
requirementTypes.add(new DiscordServerRequirementType(this));
requirementTypes.add(new DiscordBoostingRequirementType(this));
if (discordSRV.linkProvider() instanceof MinecraftAuthenticationLinker) {
requirements.addAll(MinecraftAuthRequirement.createRequirements(discordSRV));
requirementTypes.addAll(MinecraftAuthRequirementType.createRequirements(this));
}
synchronized (availableRequirements) {
availableRequirements.clear();
availableRequirements.addAll(requirements);
synchronized (availableRequirementTypes) {
for (RequirementType<?> requirementType : availableRequirementTypes) {
discordSRV.moduleManager().unregister(requirementType);
}
availableRequirementTypes.clear();
for (RequirementType<?> requirementType : requirementTypes) {
discordSRV.moduleManager().register(requirementType);
}
availableRequirementTypes.addAll(requirementTypes);
}
if (discordSRV.config() != null) {
reload();
}
}
public List<MinecraftAuthRequirement.Type> getActiveRequirementTypes() {
return activeRequirementTypes;
public abstract void reload();
public abstract List<ParsedRequirements> getAllActiveRequirements();
public abstract void recheck(IPlayer player);
private void recheck(Someone someone) {
someone.withPlayerUUID(discordSRV).thenApply(uuid -> {
if (uuid == null) {
return null;
}
return discordSRV.playerProvider().player(uuid);
}).whenComplete((onlinePlayer, t) -> {
if (t != null) {
logger().error("Failed to get linked account for " + someone, t);
}
if (onlinePlayer != null) {
recheck(onlinePlayer);
}
});
}
protected List<CompiledRequirement> compile(List<String> requirements) {
List<CompiledRequirement> checks = new ArrayList<>();
for (String requirement : requirements) {
BiFunction<UUID, Long, CompletableFuture<Boolean>> function = RequirementParser.getInstance().parse(requirement, availableRequirements,
activeRequirementTypes);
checks.add(new CompiledRequirement(requirement, function));
}
return checks;
}
public <RT> void stateChanged(Someone someone, RequirementType<RT> requirementType, RT value, boolean newState) {
for (ParsedRequirements activeRequirement : getAllActiveRequirements()) {
for (Requirement<?> requirement : activeRequirement.usedRequirements()) {
if (requirement.type() != requirementType
|| !Objects.equals(requirement.value(), value)
|| newState == requirement.negated()) {
continue;
}
public static class CompiledRequirement {
private final String input;
private final BiFunction<UUID, Long, CompletableFuture<Boolean>> function;
protected CompiledRequirement(String input, BiFunction<UUID, Long, CompletableFuture<Boolean>> function) {
this.input = input;
this.function = function;
}
public String input() {
return input;
}
public BiFunction<UUID, Long, CompletableFuture<Boolean>> function() {
return function;
// One of the checks now fails
recheck(someone);
break;
}
}
}
@Subscribe
public void onAccountUnlinked(AccountUnlinkedEvent event) {
recheck(Someone.of(event.getPlayerUUID()));
}
protected List<ParsedRequirements> compile(List<String> additionalRequirements) {
List<ParsedRequirements> parsed = new ArrayList<>();
for (String input : additionalRequirements) {
ParsedRequirements parsedRequirement = RequirementParser.getInstance()
.parse(input, availableRequirementTypes);
parsed.add(parsedRequirement);
}
return parsed;
}
public List<MinecraftAuthRequirementType.Provider> getActiveMinecraftAuthProviders() {
List<MinecraftAuthRequirementType.Provider> providers = new ArrayList<>();
for (ParsedRequirements parsedRequirements : getAllActiveRequirements()) {
for (Requirement<?> requirement : parsedRequirements.usedRequirements()) {
RequirementType<?> requirementType = requirement.type();
if (requirementType instanceof MinecraftAuthRequirementType) {
providers.add(((MinecraftAuthRequirementType<?>) requirementType).getProvider());
}
}
}
return providers;
}
public CompletableFuture<Component> getBlockReason(
RequirementsConfig config,
List<ParsedRequirements> additionalRequirements,
UUID playerUUID,
String playerName,
boolean join
) {
if (config.bypassUUIDs.contains(playerUUID.toString())) {
// Bypasses: let them through
logger().debug("Player " + playerName + " is bypassing required linking requirements");
return CompletableFuture.completedFuture(null);
}
LinkProvider linkProvider = discordSRV.linkProvider();
if (linkProvider == null) {
// Link provider unavailable but required linking enabled: error message
Component message = ComponentUtil.fromAPI(
discordSRV.messagesConfig().minecraft.unableToCheckLinkingStatus.textBuilder().build()
);
return CompletableFuture.completedFuture(message);
}
return linkProvider.queryUserId(playerUUID, true).thenCompose(opt -> {
if (!opt.isPresent()) {
// User is not linked
return linkProvider.getLinkingInstructions(playerName, playerUUID, null, join ? "join" : "freeze")
.thenApply(ComponentUtil::fromAPI);
}
long userId = opt.get();
if (additionalRequirements.isEmpty()) {
// No additional requirements: let them through
return CompletableFuture.completedFuture(null);
}
CompletableFuture<Void> pass = new CompletableFuture<>();
List<CompletableFuture<Boolean>> all = new ArrayList<>();
for (ParsedRequirements requirement : additionalRequirements) {
CompletableFuture<Boolean> future = requirement.predicate().apply(Someone.of(playerUUID, userId));
all.add(future.thenApply(val -> {
if (val) {
pass.complete(null);
}
return val;
}).exceptionally(t -> {
logger().debug("Check \"" + requirement.input() + "\" failed for "
+ playerName + " / " + Long.toUnsignedString(userId), t);
return null;
}));
}
// Complete when at least one passes or all of them completed
return CompletableFuture.anyOf(pass, CompletableFutureUtil.combine(all)).thenApply(v -> {
if (pass.isDone()) {
// One of the futures passed: let them through
return null;
}
// None of the futures passed: additional requirements not met
return Component.text("You did not pass additionalRequirements");
});
});
}
}

View File

@ -18,26 +18,19 @@
package com.discordsrv.common.linking.requirelinking;
import com.discordsrv.api.DiscordSRVApi;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.component.util.ComponentUtil;
import com.discordsrv.common.config.main.linking.RequirementsConfig;
import com.discordsrv.common.config.main.linking.ServerRequiredLinkingConfig;
import com.discordsrv.common.future.util.CompletableFutureUtil;
import com.discordsrv.common.linking.LinkProvider;
import com.discordsrv.common.linking.requirelinking.requirement.MinecraftAuthRequirement;
import com.discordsrv.common.linking.requirelinking.requirement.parser.ParsedRequirements;
import net.kyori.adventure.text.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
public abstract class ServerRequireLinkingModule<T extends DiscordSRV> extends RequiredLinkingModule<T> {
private final List<CompiledRequirement> compiledRequirements = new CopyOnWriteArrayList<>();
private final List<ParsedRequirements> additionalRequirements = new CopyOnWriteArrayList<>();
public ServerRequireLinkingModule(T discordSRV) {
super(discordSRV);
@ -47,85 +40,24 @@ public abstract class ServerRequireLinkingModule<T extends DiscordSRV> extends R
public abstract ServerRequiredLinkingConfig config();
@Override
public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) {
super.reload(resultConsumer);
synchronized (compiledRequirements) {
activeRequirementTypes.clear();
compiledRequirements.clear();
compiledRequirements.addAll(compile(config().requirements.requirements));
public void reload() {
synchronized (additionalRequirements) {
additionalRequirements.clear();
additionalRequirements.addAll(compile(config().requirements.additionalRequirements));
}
}
public List<MinecraftAuthRequirement.Type> getRequirementTypes() {
return activeRequirementTypes;
@Override
public List<ParsedRequirements> getAllActiveRequirements() {
return additionalRequirements;
}
public CompletableFuture<Component> getBlockReason(UUID playerUUID, String playerName, boolean join) {
RequirementsConfig config = config().requirements;
if (config.bypassUUIDs.contains(playerUUID.toString())) {
// Bypasses: let them through
logger().debug("Player " + playerName + " is bypassing required linking requirements");
return CompletableFuture.completedFuture(null);
List<ParsedRequirements> additionalRequirements;
synchronized (this.additionalRequirements) {
additionalRequirements = this.additionalRequirements;
}
LinkProvider linkProvider = discordSRV.linkProvider();
if (linkProvider == null) {
// Link provider unavailable but required linking enabled: error message
Component message = ComponentUtil.fromAPI(
discordSRV.messagesConfig().minecraft.unableToCheckLinkingStatus.textBuilder().build()
);
return CompletableFuture.completedFuture(message);
}
return linkProvider.queryUserId(playerUUID, true)
.thenCompose(opt -> {
if (!opt.isPresent()) {
// User is not linked
return linkProvider.getLinkingInstructions(playerName, playerUUID, null, join ? "join" : "freeze")
.thenApply(ComponentUtil::fromAPI);
}
List<CompiledRequirement> requirements;
synchronized (compiledRequirements) {
requirements = compiledRequirements;
}
if (requirements.isEmpty()) {
// No additional requirements: let them through
return CompletableFuture.completedFuture(null);
}
CompletableFuture<Void> pass = new CompletableFuture<>();
List<CompletableFuture<Boolean>> all = new ArrayList<>();
long userId = opt.get();
for (CompiledRequirement requirement : requirements) {
CompletableFuture<Boolean> future = requirement.function().apply(playerUUID, userId);
all.add(future);
future.whenComplete((val, t) -> {
if (val != null && val) {
pass.complete(null);
}
}).exceptionally(t -> {
logger().debug("Check \"" + requirement.input() + "\" failed for " + playerName + " / " + Long.toUnsignedString(userId), t);
return null;
});
}
// Complete when at least one passes or all of them completed
return CompletableFuture.anyOf(pass, CompletableFutureUtil.combine(all))
.thenApply(v -> {
if (pass.isDone()) {
// One of the futures passed: let them through
return null;
}
// None of the futures passed: requirements not met
return Component.text("You did not pass requirements");
});
});
return getBlockReason(config().requirements, additionalRequirements, playerUUID, playerName, join);
}
}

View File

@ -18,15 +18,27 @@
package com.discordsrv.common.linking.requirelinking.requirement;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
public class Requirement<T> {
public interface Requirement<T> {
private final RequirementType<T> type;
private final T value;
private final boolean negated;
String name();
public Requirement(RequirementType<T> type, T value, boolean negated) {
this.type = type;
this.value = value;
this.negated = negated;
}
T parse(String input);
public RequirementType<T> type() {
return type;
}
CompletableFuture<Boolean> isMet(T value, UUID player, long userId);
public T value() {
return value;
}
public boolean negated() {
return negated;
}
}

View File

@ -18,34 +18,28 @@
package com.discordsrv.common.linking.requirelinking.requirement;
import com.discordsrv.api.discord.entity.guild.DiscordRole;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.module.type.AbstractModule;
import com.discordsrv.common.someone.Someone;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
public class DiscordRoleRequirement extends LongRequirement {
public abstract class RequirementType<T> extends AbstractModule<DiscordSRV> {
private final DiscordSRV discordSRV;
protected final RequiredLinkingModule<? extends DiscordSRV> module;
public DiscordRoleRequirement(DiscordSRV discordSRV) {
this.discordSRV = discordSRV;
public RequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module.discordSRV());
this.module = module;
}
@Override
public String name() {
return "DiscordRole";
public final void stateChanged(Someone someone, T value, boolean newState) {
module.stateChanged(someone, this, value, newState);
}
@Override
public CompletableFuture<Boolean> isMet(Long value, UUID player, long userId) {
DiscordRole role = discordSRV.discordAPI().getRoleById(value);
if (role == null) {
return CompletableFuture.completedFuture(false);
}
public abstract String name();
public abstract T parse(String input);
public abstract CompletableFuture<Boolean> isMet(T value, Someone.Resolved someone);
return role.getGuild()
.retrieveMemberById(userId)
.thenApply(member -> member.getRoles().contains(role));
}
}

View File

@ -0,0 +1,55 @@
/*
* This file is part of DiscordSRV, licensed under the GPLv3 License
* Copyright (c) 2016-2024 Austin "Scarsz" Shapiro, Henri "Vankka" Schubin and DiscordSRV contributors
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement.parser;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.someone.Someone;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
public class ParsedRequirements {
private final String input;
private final Function<Someone.Resolved, CompletableFuture<Boolean>> predicate;
private final List<Requirement<?>> usedRequirements;
public ParsedRequirements(
String input,
Function<Someone.Resolved, CompletableFuture<Boolean>> predicate,
List<Requirement<?>> usedRequirements
) {
this.input = input;
this.predicate = predicate;
this.usedRequirements = usedRequirements;
}
public String input() {
return input;
}
public Function<Someone.Resolved, CompletableFuture<Boolean>> predicate() {
return predicate;
}
public List<Requirement<?>> usedRequirements() {
return usedRequirements;
}
}

View File

@ -19,13 +19,13 @@
package com.discordsrv.common.linking.requirelinking.requirement.parser;
import com.discordsrv.common.future.util.CompletableFutureUtil;
import com.discordsrv.common.linking.requirelinking.requirement.MinecraftAuthRequirement;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.someone.Someone;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
@ -42,15 +42,24 @@ public class RequirementParser {
private RequirementParser() {}
@SuppressWarnings("unchecked")
public <T> BiFunction<UUID, Long, CompletableFuture<Boolean>> parse(String input, List<Requirement<?>> requirements, List<MinecraftAuthRequirement.Type> types) {
List<Requirement<T>> reqs = new ArrayList<>(requirements.size());
requirements.forEach(r -> reqs.add((Requirement<T>) r));
public <T> ParsedRequirements parse(
String input,
List<RequirementType<?>> availableRequirementTypes
) {
List<RequirementType<T>> reqs = new ArrayList<>(availableRequirementTypes.size());
availableRequirementTypes.forEach(r -> reqs.add((RequirementType<T>) r));
Func func = parse(input, new AtomicInteger(0), reqs, types);
return func::test;
List<Requirement<?>> usedRequirements = new ArrayList<>();
Func func = parse(input, new AtomicInteger(0), reqs, usedRequirements);
return new ParsedRequirements(input, func::test, usedRequirements);
}
private <T> Func parse(String input, AtomicInteger iterator, List<Requirement<T>> requirements, List<MinecraftAuthRequirement.Type> types) {
private <T> Func parse(
String input,
AtomicInteger iterator,
List<RequirementType<T>> availableRequirementTypes,
List<Requirement<?>> parsedRequirements
) {
StringBuilder functionNameBuffer = new StringBuilder();
StringBuilder functionValueBuffer = new StringBuilder();
boolean isFunctionValue = false;
@ -58,6 +67,7 @@ public class RequirementParser {
Func func = null;
Operator operator = null;
boolean operatorSecond = false;
boolean negated = false;
Function<String, RuntimeException> error = text -> {
int i = iterator.get();
@ -70,7 +80,7 @@ public class RequirementParser {
char c = chars[i];
if (c == '(' && functionNameBuffer.length() == 0) {
iterator.incrementAndGet();
Func function = parse(input, iterator, requirements, types);
Func function = parse(input, iterator, availableRequirementTypes, parsedRequirements);
if (function == null) {
throw error.apply("Empty brackets");
}
@ -103,18 +113,20 @@ public class RequirementParser {
String functionName = functionNameBuffer.toString();
String value = functionValueBuffer.toString();
for (Requirement<T> requirement : requirements) {
if (requirement.name().equalsIgnoreCase(functionName)) {
if (requirement instanceof MinecraftAuthRequirement) {
types.add(((MinecraftAuthRequirement<?>) requirement).getType());
}
T requirementValue = requirement.parse(value);
for (RequirementType<T> requirementType : availableRequirementTypes) {
if (requirementType.name().equalsIgnoreCase(functionName)) {
T requirementValue = requirementType.parse(value);
if (requirementValue == null) {
throw error.apply("Unacceptable function value for " + functionName);
}
Func function = (player, user) -> requirement.isMet(requirementValue, player, user);
boolean isNegated = negated;
negated = false;
parsedRequirements.add(new Requirement<>(requirementType, requirementValue, isNegated));
Func function = someone -> requirementType.isMet(requirementValue, someone)
.thenApply(val -> isNegated != val);
if (func != null) {
if (operator == null) {
throw error.apply("No operator");
@ -163,12 +175,23 @@ public class RequirementParser {
throw error.apply("Operators must be exactly two of the same character");
}
if (!Character.isSpaceChar(c)) {
if (isFunctionValue) {
functionValueBuffer.append(c);
} else {
functionNameBuffer.append(c);
if (Character.isSpaceChar(c)) {
continue;
}
if (isFunctionValue) {
functionValueBuffer.append(c);
} else {
if (c == '!') {
if (functionNameBuffer.length() > 0) {
throw error.apply("Negation must be before function name");
}
negated = !negated;
continue;
}
functionNameBuffer.append(c);
}
}
@ -180,7 +203,7 @@ public class RequirementParser {
@FunctionalInterface
private interface Func {
CompletableFuture<Boolean> test(UUID player, long user);
CompletableFuture<Boolean> test(Someone.Resolved someone);
}
private enum Operator {
@ -197,7 +220,7 @@ public class RequirementParser {
}
private static Func apply(Func one, Func two, BiFunction<Boolean, Boolean, Boolean> function) {
return (player, user) -> CompletableFutureUtil.combine(one.test(player, user), two.test(player, user))
return someone -> CompletableFutureUtil.combine(one.test(someone), two.test(someone))
.thenApply(bools -> function.apply(bools.get(0), bools.get(1)));
}
}

View File

@ -16,20 +16,21 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement;
package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.api.discord.entity.guild.DiscordGuild;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.someone.Someone;
import net.dv8tion.jda.api.events.guild.member.update.GuildMemberUpdateBoostTimeEvent;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
public class DiscordBoostingRequirement extends LongRequirement {
public class DiscordBoostingRequirementType extends LongRequirementType {
private final DiscordSRV discordSRV;
public DiscordBoostingRequirement(DiscordSRV discordSRV) {
this.discordSRV = discordSRV;
public DiscordBoostingRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
}
@Override
@ -38,13 +39,18 @@ public class DiscordBoostingRequirement extends LongRequirement {
}
@Override
public CompletableFuture<Boolean> isMet(Long value, UUID player, long userId) {
DiscordGuild guild = discordSRV.discordAPI().getGuildById(value);
public CompletableFuture<Boolean> isMet(Long value, Someone.Resolved someone) {
DiscordGuild guild = module.discordSRV().discordAPI().getGuildById(value);
if (guild == null) {
return CompletableFuture.completedFuture(false);
}
return guild.retrieveMemberById(userId)
return guild.retrieveMemberById(someone.userId())
.thenApply(member -> member != null && member.isBoosting());
}
@Subscribe
public void onGuildMemberUpdateBoostTime(GuildMemberUpdateBoostTimeEvent event) {
stateChanged(Someone.of(event.getMember().getIdLong()), null, event.getNewTimeBoosted() != null);
}
}

View File

@ -0,0 +1,63 @@
/*
* This file is part of DiscordSRV, licensed under the GPLv3 License
* Copyright (c) 2016-2024 Austin "Scarsz" Shapiro, Henri "Vankka" Schubin and DiscordSRV contributors
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.api.discord.entity.guild.DiscordRole;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.api.event.events.discord.member.role.AbstractDiscordMemberRoleChangeEvent;
import com.discordsrv.api.event.events.discord.member.role.DiscordMemberRoleAddEvent;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.someone.Someone;
import java.util.concurrent.CompletableFuture;
public class DiscordRoleRequirementType extends LongRequirementType {
public DiscordRoleRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
}
@Override
public String name() {
return "DiscordRole";
}
@Override
public CompletableFuture<Boolean> isMet(Long value, Someone.Resolved someone) {
DiscordRole role = module.discordSRV().discordAPI().getRoleById(value);
if (role == null) {
return CompletableFuture.completedFuture(false);
}
return role.getGuild()
.retrieveMemberById(someone.userId())
.thenApply(member -> member.getRoles().contains(role));
}
@Subscribe
public void onDiscordMemberRoleAdd(AbstractDiscordMemberRoleChangeEvent<?> event) {
boolean add = event instanceof DiscordMemberRoleAddEvent;
Someone someone = Someone.of(event.getMember().getUser().getId());
for (DiscordRole role : event.getRoles()) {
stateChanged(someone, role.getId(), add);
}
}
}

View File

@ -16,21 +16,23 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement;
package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.api.discord.entity.guild.DiscordGuild;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.someone.Someone;
import net.dv8tion.jda.api.events.guild.member.GuildMemberJoinEvent;
import net.dv8tion.jda.api.events.guild.member.GuildMemberRemoveEvent;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
public class DiscordServerRequirement extends LongRequirement {
public class DiscordServerRequirementType extends LongRequirementType {
private final DiscordSRV discordSRV;
public DiscordServerRequirement(DiscordSRV discordSRV) {
this.discordSRV = discordSRV;
public DiscordServerRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
}
@Override
@ -39,13 +41,23 @@ public class DiscordServerRequirement extends LongRequirement {
}
@Override
public CompletableFuture<Boolean> isMet(Long value, UUID player, long userId) {
DiscordGuild guild = discordSRV.discordAPI().getGuildById(value);
public CompletableFuture<Boolean> isMet(Long value, Someone.Resolved someone) {
DiscordGuild guild = module.discordSRV().discordAPI().getGuildById(value);
if (guild == null) {
return CompletableFuture.completedFuture(false);
}
return guild.retrieveMemberById(userId)
return guild.retrieveMemberById(someone.userId())
.thenApply(Objects::nonNull);
}
@Subscribe
public void onGuildMemberJoin(GuildMemberJoinEvent event) {
stateChanged(Someone.of(event.getUser().getIdLong()), event.getGuild().getIdLong(), true);
}
@Subscribe
public void onGuildMemberRemove(GuildMemberRemoveEvent event) {
stateChanged(Someone.of(event.getUser().getIdLong()), event.getGuild().getIdLong(), false);
}
}

View File

@ -16,9 +16,17 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement;
package com.discordsrv.common.linking.requirelinking.requirement.type;
public abstract class LongRequirement implements Requirement<Long> {
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
public abstract class LongRequirementType extends RequirementType<Long> {
public LongRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
}
@Override
public Long parse(String input) {

View File

@ -16,9 +16,12 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement;
package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.someone.Someone;
import me.minecraftauth.lib.AuthService;
import me.minecraftauth.lib.account.platform.twitch.SubTier;
import me.minecraftauth.lib.exception.LookupException;
@ -30,41 +33,41 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthRequirement.Reference<T>> {
public class MinecraftAuthRequirementType<T> extends RequirementType<MinecraftAuthRequirementType.Reference<T>> {
private static final Reference<?> NULL_VALUE = new Reference<>(null);
public static List<Requirement<?>> createRequirements(DiscordSRV discordSRV) {
List<Requirement<?>> requirements = new ArrayList<>();
public static List<RequirementType<?>> createRequirements(RequiredLinkingModule<?> module) {
List<RequirementType<?>> requirementTypes = new ArrayList<>();
// Patreon
requirements.add(new MinecraftAuthRequirement<>(
discordSRV,
Type.PATREON,
requirementTypes.add(new MinecraftAuthRequirementType<>(
module,
Provider.PATREON,
"PatreonSubscriber",
AuthService::isSubscribedPatreon,
AuthService::isSubscribedPatreon
));
// Glimpse
requirements.add(new MinecraftAuthRequirement<>(
discordSRV,
Type.GLIMPSE,
requirementTypes.add(new MinecraftAuthRequirementType<>(
module,
Provider.GLIMPSE,
"GlimpseSubscriber",
AuthService::isSubscribedGlimpse,
AuthService::isSubscribedGlimpse
));
// Twitch
requirements.add(new MinecraftAuthRequirement<>(
discordSRV,
Type.TWITCH,
requirementTypes.add(new MinecraftAuthRequirementType<>(
module,
Provider.TWITCH,
"TwitchFollower",
AuthService::isFollowingTwitch
));
requirements.add(new MinecraftAuthRequirement<>(
discordSRV,
Type.TWITCH,
requirementTypes.add(new MinecraftAuthRequirementType<>(
module,
Provider.TWITCH,
"TwitchSubscriber",
AuthService::isSubscribedTwitch,
AuthService::isSubscribedTwitch,
@ -79,60 +82,59 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
));
// YouTube
requirements.add(new MinecraftAuthRequirement<>(
discordSRV,
Type.YOUTUBE,
requirementTypes.add(new MinecraftAuthRequirementType<>(
module,
Provider.YOUTUBE,
"YouTubeSubscriber",
AuthService::isSubscribedYouTube
));
requirements.add(new MinecraftAuthRequirement<>(
discordSRV,
Type.YOUTUBE,
requirementTypes.add(new MinecraftAuthRequirementType<>(
module,
Provider.YOUTUBE,
"YouTubeMember",
AuthService::isMemberYouTube,
AuthService::isMemberYouTube
));
return requirements;
return requirementTypes;
}
private final DiscordSRV discordSRV;
private final Type type;
private final Provider provider;
private final String name;
private final Test test;
private final TestSpecific<T> testSpecific;
private final Function<String, T> parse;
public MinecraftAuthRequirement(
DiscordSRV discordSRV,
Type type,
public MinecraftAuthRequirementType(
RequiredLinkingModule<? extends DiscordSRV> module,
Provider provider,
String name,
Test test
) {
this(discordSRV, type, name, test, null, null);
this(module, provider, name, test, null, null);
}
@SuppressWarnings("unchecked")
public MinecraftAuthRequirement(
DiscordSRV discordSRV,
Type type,
public MinecraftAuthRequirementType(
RequiredLinkingModule<? extends DiscordSRV> module,
Provider provider,
String name,
Test test,
TestSpecific<String> testSpecific
) {
this(discordSRV, type, name, test, (TestSpecific<T>) testSpecific, t -> (T) t);
this(module, provider, name, test, (TestSpecific<T>) testSpecific, t -> (T) t);
}
public MinecraftAuthRequirement(
DiscordSRV discordSRV,
Type type,
public MinecraftAuthRequirementType(
RequiredLinkingModule<? extends DiscordSRV> module,
Provider provider,
String name,
Test test,
TestSpecific<T> testSpecific,
Function<String, T> parse
) {
this.discordSRV = discordSRV;
this.type = type;
super(module);
this.provider = provider;
this.name = name;
this.test = test;
this.testSpecific = testSpecific;
@ -144,8 +146,8 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
return name;
}
public Type getType() {
return type;
public Provider getProvider() {
return provider;
}
@SuppressWarnings("unchecked")
@ -161,14 +163,14 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
}
@Override
public CompletableFuture<Boolean> isMet(Reference<T> atomicReference, UUID player, long userId) {
String token = discordSRV.connectionConfig().minecraftAuth.token;
public CompletableFuture<Boolean> isMet(Reference<T> atomicReference, Someone.Resolved someone) {
String token = module.discordSRV().connectionConfig().minecraftAuth.token;
T value = atomicReference.getValue();
return discordSRV.scheduler().supply(() -> {
return module.discordSRV().scheduler().supply(() -> {
if (value == null) {
return test.test(token, player);
return test.test(token, someone.playerUUID());
} else {
return testSpecific.test(token, player, value);
return testSpecific.test(token, someone.playerUUID(), value);
}
});
}
@ -196,7 +198,7 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
}
}
public enum Type {
public enum Provider {
PATREON('p'),
GLIMPSE('g'),
TWITCH('t'),
@ -204,7 +206,7 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
private final char character;
Type(char character) {
Provider(char character) {
this.character = character;
}

View File

@ -118,6 +118,12 @@ public class ModuleManager {
: delegates.computeIfAbsent(module, mod -> new ModuleDelegate(discordSRV, mod));
}
private String getName(AbstractModule<?> abstractModule) {
return abstractModule instanceof ModuleDelegate
? ((ModuleDelegate) abstractModule).getBase().getClass().getName()
: abstractModule.getClass().getSimpleName();
}
public <DT extends DiscordSRV> void registerModule(DT discordSRV, CheckedFunction<DT, AbstractModule<?>> function) {
try {
register(function.apply(discordSRV));
@ -134,12 +140,10 @@ public class ModuleManager {
this.modules.add(module);
this.moduleLookupTable.put(module.getClass().getName(), module);
logger.debug(module + " registered");
logger.debug(module.getClass().getName() + " registered");
if (discordSRV.isReady()) {
// Check if Discord connection is ready, if it is already we'll enable the module
enable(getAbstract(module));
}
// Enable the module if we're already ready
enableOrDisableAsNeeded(getAbstract(module), discordSRV.isReady(), true);
}
public void unregister(Module module) {
@ -154,17 +158,35 @@ public class ModuleManager {
this.moduleLookupTable.values().removeIf(mod -> mod == module);
this.delegates.remove(module);
logger.debug(module + " unregistered");
logger.debug(module.getClass().getName() + " unregistered");
}
private void enable(AbstractModule<?> module) {
private List<DiscordSRVApi.ReloadResult> enable(AbstractModule<?> module) {
try {
if (module.enableModule()) {
logger.debug(module + " enabled");
return reload(module);
}
} catch (Throwable t) {
discordSRV.logger().error("Failed to enable " + module.getClass().getSimpleName(), t);
discordSRV.logger().error("Failed to enable " + getName(module), t);
return Collections.emptyList();
}
return null;
}
private List<DiscordSRVApi.ReloadResult> reload(AbstractModule<?> module) {
List<DiscordSRVApi.ReloadResult> reloadResults = new ArrayList<>();
try {
module.reload(result -> {
if (result == null) {
throw new NullPointerException("null result supplied to resultConsumer");
}
reloadResults.add(result);
});
} catch (Throwable t) {
discordSRV.logger().error("Failed to reload " + getName(module), t);
}
return reloadResults;
}
private void disable(AbstractModule<?> module) {
@ -173,7 +195,7 @@ public class ModuleManager {
logger.debug(module + " disabled");
}
} catch (Throwable t) {
discordSRV.logger().error("Failed to disable " + module.getClass().getSimpleName(), t);
discordSRV.logger().error("Failed to disable " + getName(module), t);
}
}
@ -189,52 +211,21 @@ public class ModuleManager {
reload();
}
public synchronized List<DiscordSRV.ReloadResult> reload() {
JDAConnectionManager connectionManager = discordSRV.discordConnectionManager();
public List<DiscordSRVApi.ReloadResult> reload() {
return reloadAndEnableModules(true);
}
public void enableModules() {
reloadAndEnableModules(false);
}
private synchronized List<DiscordSRV.ReloadResult> reloadAndEnableModules(boolean reload) {
boolean isReady = discordSRV.isReady();
logger().debug((reload ? "Reloading" : "Enabling") + " modules (DiscordSRV ready = " + isReady + ")");
Set<DiscordSRVApi.ReloadResult> reloadResults = new HashSet<>();
for (Module module : modules) {
AbstractModule<?> abstractModule = getAbstract(module);
boolean fail = false;
if (abstractModule.isEnabled()) {
for (DiscordGatewayIntent requiredIntent : abstractModule.getRequestedIntents()) {
if (!connectionManager.getIntents().contains(requiredIntent)) {
fail = true;
logger().warning("Missing gateway intent " + requiredIntent.name() + " for module " + module.getClass().getSimpleName());
}
}
for (DiscordCacheFlag requiredCacheFlag : abstractModule.getRequestedCacheFlags()) {
if (!connectionManager.getCacheFlags().contains(requiredCacheFlag)) {
fail = true;
logger().warning("Missing cache flag " + requiredCacheFlag.name() + " for module " + module.getClass().getSimpleName());
}
}
}
if (fail) {
reloadResults.add(ReloadResults.DISCORD_CONNECTION_RELOAD_REQUIRED);
}
// Check if the module needs to be enabled or disabled
if (!fail) {
enable(abstractModule);
}
if (!abstractModule.isEnabled()) {
disable(abstractModule);
continue;
}
try {
abstractModule.reload(result -> {
if (result == null) {
throw new NullPointerException("null result supplied to resultConsumer");
}
reloadResults.add(result);
});
} catch (Throwable t) {
discordSRV.logger().error("Failed to reload " + module.getClass().getSimpleName(), t);
}
reloadResults.addAll(enableOrDisableAsNeeded(getAbstract(module), isReady, reload));
}
List<DiscordSRVApi.ReloadResult> results = new ArrayList<>();
@ -249,6 +240,52 @@ public class ModuleManager {
return results;
}
private List<DiscordSRVApi.ReloadResult> enableOrDisableAsNeeded(AbstractModule<?> module, boolean isReady, boolean reload) {
boolean canBeEnabled = isReady || module.canEnableBeforeReady();
if (!canBeEnabled) {
return Collections.emptyList();
}
boolean enabled = module.isEnabled();
if (!enabled) {
disable(module);
return Collections.emptyList();
}
JDAConnectionManager connectionManager = discordSRV.discordConnectionManager();
boolean fail = false;
for (DiscordGatewayIntent requiredIntent : module.getRequestedIntents()) {
if (!connectionManager.getIntents().contains(requiredIntent)) {
fail = true;
logger().warning("Missing gateway intent " + requiredIntent.name() + " for module " + getName(module));
}
}
for (DiscordCacheFlag requiredCacheFlag : module.getRequestedCacheFlags()) {
if (!connectionManager.getCacheFlags().contains(requiredCacheFlag)) {
fail = true;
logger().warning("Missing cache flag " + requiredCacheFlag.name() + " for module " + getName(module));
}
}
List<DiscordSRVApi.ReloadResult> reloadResults = new ArrayList<>();
if (fail) {
reloadResults.add(ReloadResults.DISCORD_CONNECTION_RELOAD_REQUIRED);
}
// Enable the module if reload passed
if (!fail) {
List<DiscordSRVApi.ReloadResult> results = enable(module);
if (results != null) {
reloadResults.addAll(results);
} else if (reload) {
reloadResults.addAll(reload(module));
}
}
return reloadResults;
}
@Subscribe
public void onDebugGenerate(DebugGenerateEvent event) {
StringBuilder builder = new StringBuilder();

View File

@ -35,7 +35,7 @@ public abstract class AbstractModule<DT extends DiscordSRV> implements Module {
protected final DT discordSRV;
private final Logger logger;
private boolean hasBeenEnabled = false;
private boolean isCurrentlyEnabled = false;
private final List<DiscordGatewayIntent> requestedIntents = new ArrayList<>();
private final List<DiscordCacheFlag> requestedCacheFlags = new ArrayList<>();
@ -72,11 +72,11 @@ public abstract class AbstractModule<DT extends DiscordSRV> implements Module {
// Internal
public final boolean enableModule() {
if (hasBeenEnabled || !isEnabled()) {
if (isCurrentlyEnabled) {
return false;
}
hasBeenEnabled = true;
isCurrentlyEnabled = true;
enable();
try {
@ -87,12 +87,12 @@ public abstract class AbstractModule<DT extends DiscordSRV> implements Module {
}
public final boolean disableModule() {
if (!hasBeenEnabled) {
if (!isCurrentlyEnabled) {
return false;
}
disable();
hasBeenEnabled = false;
isCurrentlyEnabled = false;
try {
discordSRV.eventBus().unsubscribe(this);

View File

@ -39,6 +39,10 @@ public class ModuleDelegate extends AbstractModule<DiscordSRV> {
this.module = module;
}
public Module getBase() {
return module;
}
@Override
public boolean isEnabled() {
return module.isEnabled();

View File

@ -27,6 +27,7 @@ import com.discordsrv.common.config.main.PresenceUpdaterConfig;
import com.discordsrv.common.logging.NamedLogger;
import com.discordsrv.common.module.type.AbstractModule;
import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.events.StatusChangeEvent;
import java.time.Duration;
import java.util.List;
@ -45,6 +46,11 @@ public class PresenceUpdaterModule extends AbstractModule<DiscordSRV> {
super(discordSRV, new NamedLogger(discordSRV, "PRESENCE_UPDATER"));
}
@Override
public boolean canEnableBeforeReady() {
return true;
}
public void serverStarted() {
serverState.set(ServerState.STARTED);
setPresenceOrSchedule();
@ -56,8 +62,19 @@ public class PresenceUpdaterModule extends AbstractModule<DiscordSRV> {
setPresenceOrSchedule();
}
@Subscribe
public void onStatusChange(StatusChangeEvent event) {
if (event.getNewStatus() == JDA.Status.IDENTIFYING_SESSION) {
setPresenceOrSchedule();
}
}
@Override
public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) {
if (discordSRV.jda() == null) {
return;
}
setPresenceOrSchedule();
// Log problems with presences

View File

@ -63,6 +63,10 @@ public class Someone {
this.userId = userId;
}
private <T> T throwIllegal() {
throw new IllegalStateException("Cannot have Someone instance without either a Player UUID or User Id");
}
@NotNull
public CompletableFuture<@NotNull Profile> profile(DiscordSRV discordSRV) {
if (playerUUID != null) {
@ -70,7 +74,7 @@ public class Someone {
} else if (userId != null) {
return discordSRV.profileManager().lookupProfile(userId);
} else {
throw new IllegalStateException("Cannot have Someone instance without either a Player UUID or User Id");
return throwIllegal();
}
}
@ -80,14 +84,31 @@ public class Someone {
return CompletableFuture.completedFuture(of(playerUUID, userId));
}
return profile(discordSRV).thenApply(profile -> {
UUID playerUUID = profile.playerUUID();
Long userId = profile.userId();
if (playerUUID == null || userId == null) {
return null;
}
return of(playerUUID, userId);
});
if (playerUUID != null) {
return withUserId(discordSRV).thenApply(userId -> userId != null ? of(playerUUID, userId) : null);
} else if (userId != null) {
return withPlayerUUID(discordSRV).thenApply(playerUUID -> playerUUID != null ? of(playerUUID, userId) : null);
} else {
return throwIllegal();
}
}
public CompletableFuture<@Nullable Long> withUserId(DiscordSRV discordSRV) {
if (userId != null) {
return CompletableFuture.completedFuture(userId);
} else if (playerUUID == null) {
return throwIllegal();
}
return discordSRV.linkProvider().getUserId(playerUUID).thenApply(opt -> opt.orElse(null));
}
public CompletableFuture<@Nullable UUID> withPlayerUUID(DiscordSRV discordSRV) {
if (playerUUID != null) {
return CompletableFuture.completedFuture(playerUUID);
} else if (userId == null) {
return throwIllegal();
}
return discordSRV.linkProvider().getPlayerUUID(userId).thenApply(opt -> opt.orElse(null));
}
@Nullable
@ -102,7 +123,7 @@ public class Someone {
@Override
public String toString() {
return playerUUID != null ? playerUUID.toString() : Objects.requireNonNull(userId).toString();
return playerUUID != null ? playerUUID.toString() : Long.toUnsignedString(Objects.requireNonNull(userId));
}
@SuppressWarnings("DataFlowIssue")

View File

@ -18,24 +18,48 @@
package com.discordsrv.common.linking.requirement.parser;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.MockDiscordSRV;
import com.discordsrv.common.config.main.linking.RequiredLinkingConfig;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.parser.ParsedRequirements;
import com.discordsrv.common.linking.requirelinking.requirement.parser.RequirementParser;
import com.discordsrv.common.player.IPlayer;
import com.discordsrv.common.someone.Someone;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import static org.junit.jupiter.api.Assertions.*;
public class RequirementParserTest {
public class RequirementTypeParserTest {
private final RequirementParser requirementParser = RequirementParser.getInstance();
private final List<Requirement<?>> requirements = Arrays.asList(
new Requirement<Boolean>() {
private final RequiredLinkingModule<?> module = new RequiredLinkingModule<DiscordSRV>(MockDiscordSRV.INSTANCE) {
@Override
public RequiredLinkingConfig config() {
return null;
}
@Override
public void reload() {}
@Override
public List<ParsedRequirements> getAllActiveRequirements() {
return Collections.emptyList();
}
@Override
public void recheck(IPlayer player) {}
};
private final List<RequirementType<?>> requirementTypes = Arrays.asList(
new RequirementType<Boolean>(module) {
@Override
public String name() {
return "F";
@ -47,11 +71,11 @@ public class RequirementParserTest {
}
@Override
public CompletableFuture<Boolean> isMet(Boolean value, UUID player, long userId) {
public CompletableFuture<Boolean> isMet(Boolean value, Someone.Resolved someone) {
return CompletableFuture.completedFuture(value);
}
},
new Requirement<Object>() {
new RequirementType<Object>(module) {
@Override
public String name() {
return "AlwaysError";
@ -63,14 +87,17 @@ public class RequirementParserTest {
}
@Override
public CompletableFuture<Boolean> isMet(Object value, UUID player, long userId) {
public CompletableFuture<Boolean> isMet(Object value, Someone.Resolved someone) {
return null;
}
}
);
private boolean parse(String input) {
return requirementParser.parse(input, requirements, new ArrayList<>()).apply(null, 0L).join();
return requirementParser.parse(input, requirementTypes)
.predicate()
.apply(Someone.of(UUID.randomUUID(), 0L))
.join();
}
@Test
@ -78,6 +105,21 @@ public class RequirementParserTest {
assertFalse(parse("f(false) || F(false)"));
}
@Test
public void negate() {
assertTrue(parse("!F(false)"));
}
@Test
public void negateReverse() {
assertFalse(parse("!F(true)"));
}
@Test
public void doubleNegate() {
assertTrue(parse("!!F(true)"));
}
@Test
public void orFail() {
assertFalse(parse("F(false) || F(false)"));
@ -98,6 +140,11 @@ public class RequirementParserTest {
assertTrue(parse("F(true) && F(true)"));
}
@Test
public void andNegate() {
assertTrue(parse("F(true) && !F(false)"));
}
@Test
public void complexFail() {
assertFalse(parse("F(true) && (F(false) && F(true))"));
@ -143,4 +190,9 @@ public class RequirementParserTest {
assertExceptionMessageStartsWith("Unacceptable function value for", () -> parse("AlwaysError()"));
}
@Test
public void negateBeforeFunctionNameError() {
assertExceptionMessageStartsWith("Negation must be before function name", () -> parse("F!(false)"));
}
}