Required linking: negation to additional requirements, state changes to trigger rechecks. Module manager logic changes

This commit is contained in:
Vankka 2024-06-29 02:52:26 +03:00
parent 6c88174e64
commit 5ee6a9365a
No known key found for this signature in database
GPG Key ID: 62E48025ED4E7EBB
23 changed files with 723 additions and 325 deletions

View File

@ -41,6 +41,14 @@ import java.util.function.Consumer;
public interface Module { public interface Module {
/**
* Determined if this {@link Module} can be enabled before {@link DiscordSRVApi#isReady()}.
* @return {@code true} to allow this {@link Module} to be enabled before DiscordSRV is ready
*/
default boolean canEnableBeforeReady() {
return false;
}
/** /**
* Determines if this {@link Module} should be enabled at the instant this method is called, this will be used * Determines if this {@link Module} should be enabled at the instant this method is called, this will be used
* to determine when modules should be enabled or disabled when DiscordSRV enabled, disables and reloads. * to determine when modules should be enabled or disabled when DiscordSRV enabled, disables and reloads.

View File

@ -21,6 +21,7 @@ package com.discordsrv.bukkit.requiredlinking;
import com.discordsrv.bukkit.BukkitDiscordSRV; import com.discordsrv.bukkit.BukkitDiscordSRV;
import com.discordsrv.bukkit.config.main.BukkitRequiredLinkingConfig; import com.discordsrv.bukkit.config.main.BukkitRequiredLinkingConfig;
import com.discordsrv.common.linking.requirelinking.ServerRequireLinkingModule; import com.discordsrv.common.linking.requirelinking.ServerRequireLinkingModule;
import com.discordsrv.common.player.IPlayer;
import org.bukkit.event.Listener; import org.bukkit.event.Listener;
public class BukkitRequiredLinkingModule extends ServerRequireLinkingModule<BukkitDiscordSRV> implements Listener { public class BukkitRequiredLinkingModule extends ServerRequireLinkingModule<BukkitDiscordSRV> implements Listener {
@ -33,4 +34,13 @@ public class BukkitRequiredLinkingModule extends ServerRequireLinkingModule<Bukk
public BukkitRequiredLinkingConfig config() { public BukkitRequiredLinkingConfig config() {
return discordSRV.config().requiredLinking; return discordSRV.config().requiredLinking;
} }
@Override
public void recheck(IPlayer player) {
getBlockReason(player.uniqueId(), player.username(), false).whenComplete((component, throwable) -> {
if (component != null) {
// TODO: handle
}
});
}
} }

View File

@ -667,6 +667,12 @@ public abstract class AbstractDiscordSRV<
} }
} }
List<ReloadResult> results = new ArrayList<>();
// Reload any modules that can be enabled before DiscordSRV is ready
if (initial) {
results.addAll(moduleManager().reload());
}
// Update check // Update check
UpdateConfig updateConfig = connectionConfig().update; UpdateConfig updateConfig = connectionConfig().update;
if (updateConfig.security.enabled) { if (updateConfig.security.enabled) {
@ -788,8 +794,8 @@ public abstract class AbstractDiscordSRV<
} }
} }
List<ReloadResult> results = new ArrayList<>(); // Modules are reloaded upon DiscordSRV being ready, thus not needed at initial
if (flags.contains(ReloadFlag.MODULES)) { if (!initial && flags.contains(ReloadFlag.MODULES)) {
results.addAll(moduleManager.reload()); results.addAll(moduleManager.reload());
} }

View File

@ -82,7 +82,7 @@ public abstract class ServerDiscordSRV<
@OverridingMethodsMustInvokeSuper @OverridingMethodsMustInvokeSuper
protected void serverStarted() { protected void serverStarted() {
serverStarted = true; serverStarted = true;
moduleManager().reload(); moduleManager().enableModules();
startedMessage(); startedMessage();
} }

View File

@ -18,6 +18,7 @@
package com.discordsrv.common.config.main.linking; package com.discordsrv.common.config.main.linking;
import com.discordsrv.common.config.configurate.annotation.Constants;
import com.discordsrv.common.config.configurate.annotation.DefaultOnly; import com.discordsrv.common.config.configurate.annotation.DefaultOnly;
import com.discordsrv.common.config.connection.ConnectionConfig; import com.discordsrv.common.config.connection.ConnectionConfig;
import org.spongepowered.configurate.objectmapping.ConfigSerializable; import org.spongepowered.configurate.objectmapping.ConfigSerializable;
@ -45,7 +46,7 @@ public class RequirementsConfig {
+ "DiscordBoosting(Server ID)\n" + "DiscordBoosting(Server ID)\n"
+ "DiscordRole(Role ID)\n" + "DiscordRole(Role ID)\n"
+ "\n" + "\n"
+ "The following are available if you're using MinecraftAuth.me for linked accounts and a MinecraftAuth.me token is specified in the " + ConnectionConfig.FILE_NAME + ":\n" + "The following are available if you're using MinecraftAuth.me for linked accounts and a MinecraftAuth.me token is specified in the %1:\n"
+ "PatreonSubscriber() or PatreonSubscriber(Tier Title)\n" + "PatreonSubscriber() or PatreonSubscriber(Tier Title)\n"
+ "GlimpseSubscriber() or GlimpseSubscriber(Level Name)\n" + "GlimpseSubscriber() or GlimpseSubscriber(Level Name)\n"
+ "TwitchFollower()\n" + "TwitchFollower()\n"
@ -58,5 +59,6 @@ public class RequirementsConfig {
+ "|| = or, for example \"DiscordBoosting(...) || YouTubeMember()\"\n" + "|| = or, for example \"DiscordBoosting(...) || YouTubeMember()\"\n"
+ "You can also use brackets () to clear ambiguity, for example: \"DiscordServer(...) && (TwitchSubscriber() || PatreonSubscriber())\"\n" + "You can also use brackets () to clear ambiguity, for example: \"DiscordServer(...) && (TwitchSubscriber() || PatreonSubscriber())\"\n"
+ "allows a member of the specified Discord server that is also a twitch or patreon subscriber to join the server") + "allows a member of the specified Discord server that is also a twitch or patreon subscriber to join the server")
public List<String> requirements = new ArrayList<>(); @Constants.Comment({ConnectionConfig.FILE_NAME})
public List<String> additionalRequirements = new ArrayList<>();
} }

View File

@ -26,7 +26,7 @@ import com.discordsrv.common.linking.LinkProvider;
import com.discordsrv.common.linking.LinkStore; import com.discordsrv.common.linking.LinkStore;
import com.discordsrv.common.linking.LinkingModule; import com.discordsrv.common.linking.LinkingModule;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule; import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.MinecraftAuthRequirement; import com.discordsrv.common.linking.requirelinking.requirement.type.MinecraftAuthRequirementType;
import com.discordsrv.common.logging.Logger; import com.discordsrv.common.logging.Logger;
import com.discordsrv.common.logging.NamedLogger; import com.discordsrv.common.logging.NamedLogger;
import com.discordsrv.common.player.IPlayer; import com.discordsrv.common.player.IPlayer;
@ -115,8 +115,8 @@ public class MinecraftAuthenticationLinker extends CachedLinkProvider implements
StringBuilder additionalParam = new StringBuilder(); StringBuilder additionalParam = new StringBuilder();
RequiredLinkingModule<?> requiredLinkingModule = discordSRV.getModule(RequiredLinkingModule.class); RequiredLinkingModule<?> requiredLinkingModule = discordSRV.getModule(RequiredLinkingModule.class);
if (requiredLinkingModule != null && requiredLinkingModule.isEnabled()) { if (requiredLinkingModule != null && requiredLinkingModule.isEnabled()) {
for (MinecraftAuthRequirement.Type requirementType : requiredLinkingModule.getActiveRequirementTypes()) { for (MinecraftAuthRequirementType.Provider requirementProvider : requiredLinkingModule.getActiveMinecraftAuthProviders()) {
additionalParam.append(requirementType.character()); additionalParam.append(requirementProvider.character());
} }
} }
@ -146,7 +146,7 @@ public class MinecraftAuthenticationLinker extends CachedLinkProvider implements
private void unlinked(UUID playerUUID, long userId) { private void unlinked(UUID playerUUID, long userId) {
logger.debug("Unlink: " + playerUUID + " & " + Long.toUnsignedString(userId)); logger.debug("Unlink: " + playerUUID + " & " + Long.toUnsignedString(userId));
linkStore.createLink(playerUUID, userId).whenComplete((v, t) -> { linkStore.removeLink(playerUUID, userId).whenComplete((v, t) -> {
if (t != null) { if (t != null) {
logger.error("Failed to unlink player in persistent storage", t); logger.error("Failed to unlink player in persistent storage", t);
return; return;

View File

@ -19,41 +19,64 @@
package com.discordsrv.common.linking.requirelinking; package com.discordsrv.common.linking.requirelinking;
import com.discordsrv.api.DiscordSRVApi; import com.discordsrv.api.DiscordSRVApi;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.api.event.events.linking.AccountUnlinkedEvent;
import com.discordsrv.common.DiscordSRV; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.component.util.ComponentUtil;
import com.discordsrv.common.config.main.linking.RequiredLinkingConfig; import com.discordsrv.common.config.main.linking.RequiredLinkingConfig;
import com.discordsrv.common.config.main.linking.RequirementsConfig;
import com.discordsrv.common.future.util.CompletableFutureUtil;
import com.discordsrv.common.linking.LinkProvider;
import com.discordsrv.common.linking.impl.MinecraftAuthenticationLinker; import com.discordsrv.common.linking.impl.MinecraftAuthenticationLinker;
import com.discordsrv.common.linking.requirelinking.requirement.*; import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.parser.ParsedRequirements;
import com.discordsrv.common.linking.requirelinking.requirement.parser.RequirementParser; import com.discordsrv.common.linking.requirelinking.requirement.parser.RequirementParser;
import com.discordsrv.common.linking.requirelinking.requirement.type.DiscordBoostingRequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.type.DiscordRoleRequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.type.DiscordServerRequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.type.MinecraftAuthRequirementType;
import com.discordsrv.common.module.type.AbstractModule; import com.discordsrv.common.module.type.AbstractModule;
import com.discordsrv.common.player.IPlayer;
import com.discordsrv.common.scheduler.Scheduler; import com.discordsrv.common.scheduler.Scheduler;
import com.discordsrv.common.scheduler.executor.DynamicCachingThreadPoolExecutor; import com.discordsrv.common.scheduler.executor.DynamicCachingThreadPoolExecutor;
import com.discordsrv.common.scheduler.threadfactory.CountingThreadFactory; import com.discordsrv.common.scheduler.threadfactory.CountingThreadFactory;
import com.discordsrv.common.someone.Someone;
import net.kyori.adventure.text.Component;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.SynchronousQueue; import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Consumer; import java.util.function.Consumer;
public abstract class RequiredLinkingModule<T extends DiscordSRV> extends AbstractModule<T> { public abstract class RequiredLinkingModule<T extends DiscordSRV> extends AbstractModule<T> {
private final List<Requirement<?>> availableRequirements = new ArrayList<>(); private final List<RequirementType<?>> availableRequirementTypes = new ArrayList<>();
protected final List<MinecraftAuthRequirement.Type> activeRequirementTypes = new ArrayList<>();
private ThreadPoolExecutor executor; private ThreadPoolExecutor executor;
public RequiredLinkingModule(T discordSRV) { public RequiredLinkingModule(T discordSRV) {
super(discordSRV); super(discordSRV);
} }
public DiscordSRV discordSRV() {
return discordSRV;
}
public abstract RequiredLinkingConfig config(); public abstract RequiredLinkingConfig config();
@Override
public boolean canEnableBeforeReady() {
return true;
}
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return config().enabled; return discordSRV.config() == null || config().enabled;
} }
@Override @Override
@ -78,54 +101,165 @@ public abstract class RequiredLinkingModule<T extends DiscordSRV> extends Abstra
} }
@Override @Override
public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) { public final void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) {
List<Requirement<?>> requirements = new ArrayList<>(); List<RequirementType<?>> requirementTypes = new ArrayList<>();
requirements.add(new DiscordRoleRequirement(discordSRV)); requirementTypes.add(new DiscordRoleRequirementType(this));
requirements.add(new DiscordServerRequirement(discordSRV)); requirementTypes.add(new DiscordServerRequirementType(this));
requirements.add(new DiscordBoostingRequirement(discordSRV)); requirementTypes.add(new DiscordBoostingRequirementType(this));
if (discordSRV.linkProvider() instanceof MinecraftAuthenticationLinker) { if (discordSRV.linkProvider() instanceof MinecraftAuthenticationLinker) {
requirements.addAll(MinecraftAuthRequirement.createRequirements(discordSRV)); requirementTypes.addAll(MinecraftAuthRequirementType.createRequirements(this));
} }
synchronized (availableRequirements) { synchronized (availableRequirementTypes) {
availableRequirements.clear(); for (RequirementType<?> requirementType : availableRequirementTypes) {
availableRequirements.addAll(requirements); discordSRV.moduleManager().unregister(requirementType);
}
availableRequirementTypes.clear();
for (RequirementType<?> requirementType : requirementTypes) {
discordSRV.moduleManager().register(requirementType);
}
availableRequirementTypes.addAll(requirementTypes);
}
if (discordSRV.config() != null) {
reload();
} }
} }
public List<MinecraftAuthRequirement.Type> getActiveRequirementTypes() { public abstract void reload();
return activeRequirementTypes;
public abstract List<ParsedRequirements> getAllActiveRequirements();
public abstract void recheck(IPlayer player);
private void recheck(Someone someone) {
someone.withPlayerUUID(discordSRV).thenApply(uuid -> {
if (uuid == null) {
return null;
}
return discordSRV.playerProvider().player(uuid);
}).whenComplete((onlinePlayer, t) -> {
if (t != null) {
logger().error("Failed to get linked account for " + someone, t);
}
if (onlinePlayer != null) {
recheck(onlinePlayer);
}
});
} }
protected List<CompiledRequirement> compile(List<String> requirements) { public <RT> void stateChanged(Someone someone, RequirementType<RT> requirementType, RT value, boolean newState) {
List<CompiledRequirement> checks = new ArrayList<>(); for (ParsedRequirements activeRequirement : getAllActiveRequirements()) {
for (String requirement : requirements) { for (Requirement<?> requirement : activeRequirement.usedRequirements()) {
BiFunction<UUID, Long, CompletableFuture<Boolean>> function = RequirementParser.getInstance().parse(requirement, availableRequirements, if (requirement.type() != requirementType
activeRequirementTypes); || !Objects.equals(requirement.value(), value)
checks.add(new CompiledRequirement(requirement, function)); || newState == requirement.negated()) {
} continue;
return checks; }
}
public static class CompiledRequirement { // One of the checks now fails
recheck(someone);
private final String input; break;
private final BiFunction<UUID, Long, CompletableFuture<Boolean>> function; }
protected CompiledRequirement(String input, BiFunction<UUID, Long, CompletableFuture<Boolean>> function) {
this.input = input;
this.function = function;
}
public String input() {
return input;
}
public BiFunction<UUID, Long, CompletableFuture<Boolean>> function() {
return function;
} }
} }
@Subscribe
public void onAccountUnlinked(AccountUnlinkedEvent event) {
recheck(Someone.of(event.getPlayerUUID()));
}
protected List<ParsedRequirements> compile(List<String> additionalRequirements) {
List<ParsedRequirements> parsed = new ArrayList<>();
for (String input : additionalRequirements) {
ParsedRequirements parsedRequirement = RequirementParser.getInstance()
.parse(input, availableRequirementTypes);
parsed.add(parsedRequirement);
}
return parsed;
}
public List<MinecraftAuthRequirementType.Provider> getActiveMinecraftAuthProviders() {
List<MinecraftAuthRequirementType.Provider> providers = new ArrayList<>();
for (ParsedRequirements parsedRequirements : getAllActiveRequirements()) {
for (Requirement<?> requirement : parsedRequirements.usedRequirements()) {
RequirementType<?> requirementType = requirement.type();
if (requirementType instanceof MinecraftAuthRequirementType) {
providers.add(((MinecraftAuthRequirementType<?>) requirementType).getProvider());
}
}
}
return providers;
}
public CompletableFuture<Component> getBlockReason(
RequirementsConfig config,
List<ParsedRequirements> additionalRequirements,
UUID playerUUID,
String playerName,
boolean join
) {
if (config.bypassUUIDs.contains(playerUUID.toString())) {
// Bypasses: let them through
logger().debug("Player " + playerName + " is bypassing required linking requirements");
return CompletableFuture.completedFuture(null);
}
LinkProvider linkProvider = discordSRV.linkProvider();
if (linkProvider == null) {
// Link provider unavailable but required linking enabled: error message
Component message = ComponentUtil.fromAPI(
discordSRV.messagesConfig().minecraft.unableToCheckLinkingStatus.textBuilder().build()
);
return CompletableFuture.completedFuture(message);
}
return linkProvider.queryUserId(playerUUID, true).thenCompose(opt -> {
if (!opt.isPresent()) {
// User is not linked
return linkProvider.getLinkingInstructions(playerName, playerUUID, null, join ? "join" : "freeze")
.thenApply(ComponentUtil::fromAPI);
}
long userId = opt.get();
if (additionalRequirements.isEmpty()) {
// No additional requirements: let them through
return CompletableFuture.completedFuture(null);
}
CompletableFuture<Void> pass = new CompletableFuture<>();
List<CompletableFuture<Boolean>> all = new ArrayList<>();
for (ParsedRequirements requirement : additionalRequirements) {
CompletableFuture<Boolean> future = requirement.predicate().apply(Someone.of(playerUUID, userId));
all.add(future.thenApply(val -> {
if (val) {
pass.complete(null);
}
return val;
}).exceptionally(t -> {
logger().debug("Check \"" + requirement.input() + "\" failed for "
+ playerName + " / " + Long.toUnsignedString(userId), t);
return null;
}));
}
// Complete when at least one passes or all of them completed
return CompletableFuture.anyOf(pass, CompletableFutureUtil.combine(all)).thenApply(v -> {
if (pass.isDone()) {
// One of the futures passed: let them through
return null;
}
// None of the futures passed: additional requirements not met
return Component.text("You did not pass additionalRequirements");
});
});
}
} }

View File

@ -18,26 +18,19 @@
package com.discordsrv.common.linking.requirelinking; package com.discordsrv.common.linking.requirelinking;
import com.discordsrv.api.DiscordSRVApi;
import com.discordsrv.common.DiscordSRV; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.component.util.ComponentUtil;
import com.discordsrv.common.config.main.linking.RequirementsConfig;
import com.discordsrv.common.config.main.linking.ServerRequiredLinkingConfig; import com.discordsrv.common.config.main.linking.ServerRequiredLinkingConfig;
import com.discordsrv.common.future.util.CompletableFutureUtil; import com.discordsrv.common.linking.requirelinking.requirement.parser.ParsedRequirements;
import com.discordsrv.common.linking.LinkProvider;
import com.discordsrv.common.linking.requirelinking.requirement.MinecraftAuthRequirement;
import net.kyori.adventure.text.Component; import net.kyori.adventure.text.Component;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
public abstract class ServerRequireLinkingModule<T extends DiscordSRV> extends RequiredLinkingModule<T> { public abstract class ServerRequireLinkingModule<T extends DiscordSRV> extends RequiredLinkingModule<T> {
private final List<CompiledRequirement> compiledRequirements = new CopyOnWriteArrayList<>(); private final List<ParsedRequirements> additionalRequirements = new CopyOnWriteArrayList<>();
public ServerRequireLinkingModule(T discordSRV) { public ServerRequireLinkingModule(T discordSRV) {
super(discordSRV); super(discordSRV);
@ -47,85 +40,24 @@ public abstract class ServerRequireLinkingModule<T extends DiscordSRV> extends R
public abstract ServerRequiredLinkingConfig config(); public abstract ServerRequiredLinkingConfig config();
@Override @Override
public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) { public void reload() {
super.reload(resultConsumer); synchronized (additionalRequirements) {
additionalRequirements.clear();
synchronized (compiledRequirements) { additionalRequirements.addAll(compile(config().requirements.additionalRequirements));
activeRequirementTypes.clear();
compiledRequirements.clear();
compiledRequirements.addAll(compile(config().requirements.requirements));
} }
} }
public List<MinecraftAuthRequirement.Type> getRequirementTypes() { @Override
return activeRequirementTypes; public List<ParsedRequirements> getAllActiveRequirements() {
return additionalRequirements;
} }
public CompletableFuture<Component> getBlockReason(UUID playerUUID, String playerName, boolean join) { public CompletableFuture<Component> getBlockReason(UUID playerUUID, String playerName, boolean join) {
RequirementsConfig config = config().requirements; List<ParsedRequirements> additionalRequirements;
if (config.bypassUUIDs.contains(playerUUID.toString())) { synchronized (this.additionalRequirements) {
// Bypasses: let them through additionalRequirements = this.additionalRequirements;
logger().debug("Player " + playerName + " is bypassing required linking requirements");
return CompletableFuture.completedFuture(null);
} }
LinkProvider linkProvider = discordSRV.linkProvider(); return getBlockReason(config().requirements, additionalRequirements, playerUUID, playerName, join);
if (linkProvider == null) {
// Link provider unavailable but required linking enabled: error message
Component message = ComponentUtil.fromAPI(
discordSRV.messagesConfig().minecraft.unableToCheckLinkingStatus.textBuilder().build()
);
return CompletableFuture.completedFuture(message);
}
return linkProvider.queryUserId(playerUUID, true)
.thenCompose(opt -> {
if (!opt.isPresent()) {
// User is not linked
return linkProvider.getLinkingInstructions(playerName, playerUUID, null, join ? "join" : "freeze")
.thenApply(ComponentUtil::fromAPI);
}
List<CompiledRequirement> requirements;
synchronized (compiledRequirements) {
requirements = compiledRequirements;
}
if (requirements.isEmpty()) {
// No additional requirements: let them through
return CompletableFuture.completedFuture(null);
}
CompletableFuture<Void> pass = new CompletableFuture<>();
List<CompletableFuture<Boolean>> all = new ArrayList<>();
long userId = opt.get();
for (CompiledRequirement requirement : requirements) {
CompletableFuture<Boolean> future = requirement.function().apply(playerUUID, userId);
all.add(future);
future.whenComplete((val, t) -> {
if (val != null && val) {
pass.complete(null);
}
}).exceptionally(t -> {
logger().debug("Check \"" + requirement.input() + "\" failed for " + playerName + " / " + Long.toUnsignedString(userId), t);
return null;
});
}
// Complete when at least one passes or all of them completed
return CompletableFuture.anyOf(pass, CompletableFutureUtil.combine(all))
.thenApply(v -> {
if (pass.isDone()) {
// One of the futures passed: let them through
return null;
}
// None of the futures passed: requirements not met
return Component.text("You did not pass requirements");
});
});
} }
} }

View File

@ -18,15 +18,27 @@
package com.discordsrv.common.linking.requirelinking.requirement; package com.discordsrv.common.linking.requirelinking.requirement;
import java.util.UUID; public class Requirement<T> {
import java.util.concurrent.CompletableFuture;
public interface Requirement<T> { private final RequirementType<T> type;
private final T value;
private final boolean negated;
String name(); public Requirement(RequirementType<T> type, T value, boolean negated) {
this.type = type;
this.value = value;
this.negated = negated;
}
T parse(String input); public RequirementType<T> type() {
return type;
}
CompletableFuture<Boolean> isMet(T value, UUID player, long userId); public T value() {
return value;
}
public boolean negated() {
return negated;
}
} }

View File

@ -18,34 +18,28 @@
package com.discordsrv.common.linking.requirelinking.requirement; package com.discordsrv.common.linking.requirelinking.requirement;
import com.discordsrv.api.discord.entity.guild.DiscordRole;
import com.discordsrv.common.DiscordSRV; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.module.type.AbstractModule;
import com.discordsrv.common.someone.Someone;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public class DiscordRoleRequirement extends LongRequirement { public abstract class RequirementType<T> extends AbstractModule<DiscordSRV> {
private final DiscordSRV discordSRV; protected final RequiredLinkingModule<? extends DiscordSRV> module;
public DiscordRoleRequirement(DiscordSRV discordSRV) { public RequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
this.discordSRV = discordSRV; super(module.discordSRV());
this.module = module;
} }
@Override public final void stateChanged(Someone someone, T value, boolean newState) {
public String name() { module.stateChanged(someone, this, value, newState);
return "DiscordRole";
} }
@Override public abstract String name();
public CompletableFuture<Boolean> isMet(Long value, UUID player, long userId) { public abstract T parse(String input);
DiscordRole role = discordSRV.discordAPI().getRoleById(value); public abstract CompletableFuture<Boolean> isMet(T value, Someone.Resolved someone);
if (role == null) {
return CompletableFuture.completedFuture(false);
}
return role.getGuild()
.retrieveMemberById(userId)
.thenApply(member -> member.getRoles().contains(role));
}
} }

View File

@ -0,0 +1,55 @@
/*
* This file is part of DiscordSRV, licensed under the GPLv3 License
* Copyright (c) 2016-2024 Austin "Scarsz" Shapiro, Henri "Vankka" Schubin and DiscordSRV contributors
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement.parser;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.someone.Someone;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
public class ParsedRequirements {
private final String input;
private final Function<Someone.Resolved, CompletableFuture<Boolean>> predicate;
private final List<Requirement<?>> usedRequirements;
public ParsedRequirements(
String input,
Function<Someone.Resolved, CompletableFuture<Boolean>> predicate,
List<Requirement<?>> usedRequirements
) {
this.input = input;
this.predicate = predicate;
this.usedRequirements = usedRequirements;
}
public String input() {
return input;
}
public Function<Someone.Resolved, CompletableFuture<Boolean>> predicate() {
return predicate;
}
public List<Requirement<?>> usedRequirements() {
return usedRequirements;
}
}

View File

@ -19,13 +19,13 @@
package com.discordsrv.common.linking.requirelinking.requirement.parser; package com.discordsrv.common.linking.requirelinking.requirement.parser;
import com.discordsrv.common.future.util.CompletableFutureUtil; import com.discordsrv.common.future.util.CompletableFutureUtil;
import com.discordsrv.common.linking.requirelinking.requirement.MinecraftAuthRequirement;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement; import com.discordsrv.common.linking.requirelinking.requirement.Requirement;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.someone.Someone;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction; import java.util.function.BiFunction;
@ -42,15 +42,24 @@ public class RequirementParser {
private RequirementParser() {} private RequirementParser() {}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> BiFunction<UUID, Long, CompletableFuture<Boolean>> parse(String input, List<Requirement<?>> requirements, List<MinecraftAuthRequirement.Type> types) { public <T> ParsedRequirements parse(
List<Requirement<T>> reqs = new ArrayList<>(requirements.size()); String input,
requirements.forEach(r -> reqs.add((Requirement<T>) r)); List<RequirementType<?>> availableRequirementTypes
) {
List<RequirementType<T>> reqs = new ArrayList<>(availableRequirementTypes.size());
availableRequirementTypes.forEach(r -> reqs.add((RequirementType<T>) r));
Func func = parse(input, new AtomicInteger(0), reqs, types); List<Requirement<?>> usedRequirements = new ArrayList<>();
return func::test; Func func = parse(input, new AtomicInteger(0), reqs, usedRequirements);
return new ParsedRequirements(input, func::test, usedRequirements);
} }
private <T> Func parse(String input, AtomicInteger iterator, List<Requirement<T>> requirements, List<MinecraftAuthRequirement.Type> types) { private <T> Func parse(
String input,
AtomicInteger iterator,
List<RequirementType<T>> availableRequirementTypes,
List<Requirement<?>> parsedRequirements
) {
StringBuilder functionNameBuffer = new StringBuilder(); StringBuilder functionNameBuffer = new StringBuilder();
StringBuilder functionValueBuffer = new StringBuilder(); StringBuilder functionValueBuffer = new StringBuilder();
boolean isFunctionValue = false; boolean isFunctionValue = false;
@ -58,6 +67,7 @@ public class RequirementParser {
Func func = null; Func func = null;
Operator operator = null; Operator operator = null;
boolean operatorSecond = false; boolean operatorSecond = false;
boolean negated = false;
Function<String, RuntimeException> error = text -> { Function<String, RuntimeException> error = text -> {
int i = iterator.get(); int i = iterator.get();
@ -70,7 +80,7 @@ public class RequirementParser {
char c = chars[i]; char c = chars[i];
if (c == '(' && functionNameBuffer.length() == 0) { if (c == '(' && functionNameBuffer.length() == 0) {
iterator.incrementAndGet(); iterator.incrementAndGet();
Func function = parse(input, iterator, requirements, types); Func function = parse(input, iterator, availableRequirementTypes, parsedRequirements);
if (function == null) { if (function == null) {
throw error.apply("Empty brackets"); throw error.apply("Empty brackets");
} }
@ -103,18 +113,20 @@ public class RequirementParser {
String functionName = functionNameBuffer.toString(); String functionName = functionNameBuffer.toString();
String value = functionValueBuffer.toString(); String value = functionValueBuffer.toString();
for (Requirement<T> requirement : requirements) { for (RequirementType<T> requirementType : availableRequirementTypes) {
if (requirement.name().equalsIgnoreCase(functionName)) { if (requirementType.name().equalsIgnoreCase(functionName)) {
if (requirement instanceof MinecraftAuthRequirement) { T requirementValue = requirementType.parse(value);
types.add(((MinecraftAuthRequirement<?>) requirement).getType());
}
T requirementValue = requirement.parse(value);
if (requirementValue == null) { if (requirementValue == null) {
throw error.apply("Unacceptable function value for " + functionName); throw error.apply("Unacceptable function value for " + functionName);
} }
Func function = (player, user) -> requirement.isMet(requirementValue, player, user); boolean isNegated = negated;
negated = false;
parsedRequirements.add(new Requirement<>(requirementType, requirementValue, isNegated));
Func function = someone -> requirementType.isMet(requirementValue, someone)
.thenApply(val -> isNegated != val);
if (func != null) { if (func != null) {
if (operator == null) { if (operator == null) {
throw error.apply("No operator"); throw error.apply("No operator");
@ -163,12 +175,23 @@ public class RequirementParser {
throw error.apply("Operators must be exactly two of the same character"); throw error.apply("Operators must be exactly two of the same character");
} }
if (!Character.isSpaceChar(c)) { if (Character.isSpaceChar(c)) {
if (isFunctionValue) { continue;
functionValueBuffer.append(c); }
} else {
functionNameBuffer.append(c); if (isFunctionValue) {
functionValueBuffer.append(c);
} else {
if (c == '!') {
if (functionNameBuffer.length() > 0) {
throw error.apply("Negation must be before function name");
}
negated = !negated;
continue;
} }
functionNameBuffer.append(c);
} }
} }
@ -180,7 +203,7 @@ public class RequirementParser {
@FunctionalInterface @FunctionalInterface
private interface Func { private interface Func {
CompletableFuture<Boolean> test(UUID player, long user); CompletableFuture<Boolean> test(Someone.Resolved someone);
} }
private enum Operator { private enum Operator {
@ -197,7 +220,7 @@ public class RequirementParser {
} }
private static Func apply(Func one, Func two, BiFunction<Boolean, Boolean, Boolean> function) { private static Func apply(Func one, Func two, BiFunction<Boolean, Boolean, Boolean> function) {
return (player, user) -> CompletableFutureUtil.combine(one.test(player, user), two.test(player, user)) return someone -> CompletableFutureUtil.combine(one.test(someone), two.test(someone))
.thenApply(bools -> function.apply(bools.get(0), bools.get(1))); .thenApply(bools -> function.apply(bools.get(0), bools.get(1)));
} }
} }

View File

@ -16,20 +16,21 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package com.discordsrv.common.linking.requirelinking.requirement; package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.api.discord.entity.guild.DiscordGuild; import com.discordsrv.api.discord.entity.guild.DiscordGuild;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.common.DiscordSRV; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.someone.Someone;
import net.dv8tion.jda.api.events.guild.member.update.GuildMemberUpdateBoostTimeEvent;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public class DiscordBoostingRequirement extends LongRequirement { public class DiscordBoostingRequirementType extends LongRequirementType {
private final DiscordSRV discordSRV; public DiscordBoostingRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
public DiscordBoostingRequirement(DiscordSRV discordSRV) {
this.discordSRV = discordSRV;
} }
@Override @Override
@ -38,13 +39,18 @@ public class DiscordBoostingRequirement extends LongRequirement {
} }
@Override @Override
public CompletableFuture<Boolean> isMet(Long value, UUID player, long userId) { public CompletableFuture<Boolean> isMet(Long value, Someone.Resolved someone) {
DiscordGuild guild = discordSRV.discordAPI().getGuildById(value); DiscordGuild guild = module.discordSRV().discordAPI().getGuildById(value);
if (guild == null) { if (guild == null) {
return CompletableFuture.completedFuture(false); return CompletableFuture.completedFuture(false);
} }
return guild.retrieveMemberById(userId) return guild.retrieveMemberById(someone.userId())
.thenApply(member -> member != null && member.isBoosting()); .thenApply(member -> member != null && member.isBoosting());
} }
@Subscribe
public void onGuildMemberUpdateBoostTime(GuildMemberUpdateBoostTimeEvent event) {
stateChanged(Someone.of(event.getMember().getIdLong()), null, event.getNewTimeBoosted() != null);
}
} }

View File

@ -0,0 +1,63 @@
/*
* This file is part of DiscordSRV, licensed under the GPLv3 License
* Copyright (c) 2016-2024 Austin "Scarsz" Shapiro, Henri "Vankka" Schubin and DiscordSRV contributors
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.api.discord.entity.guild.DiscordRole;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.api.event.events.discord.member.role.AbstractDiscordMemberRoleChangeEvent;
import com.discordsrv.api.event.events.discord.member.role.DiscordMemberRoleAddEvent;
import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.someone.Someone;
import java.util.concurrent.CompletableFuture;
public class DiscordRoleRequirementType extends LongRequirementType {
public DiscordRoleRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
}
@Override
public String name() {
return "DiscordRole";
}
@Override
public CompletableFuture<Boolean> isMet(Long value, Someone.Resolved someone) {
DiscordRole role = module.discordSRV().discordAPI().getRoleById(value);
if (role == null) {
return CompletableFuture.completedFuture(false);
}
return role.getGuild()
.retrieveMemberById(someone.userId())
.thenApply(member -> member.getRoles().contains(role));
}
@Subscribe
public void onDiscordMemberRoleAdd(AbstractDiscordMemberRoleChangeEvent<?> event) {
boolean add = event instanceof DiscordMemberRoleAddEvent;
Someone someone = Someone.of(event.getMember().getUser().getId());
for (DiscordRole role : event.getRoles()) {
stateChanged(someone, role.getId(), add);
}
}
}

View File

@ -16,21 +16,23 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package com.discordsrv.common.linking.requirelinking.requirement; package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.api.discord.entity.guild.DiscordGuild; import com.discordsrv.api.discord.entity.guild.DiscordGuild;
import com.discordsrv.api.event.bus.Subscribe;
import com.discordsrv.common.DiscordSRV; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.someone.Someone;
import net.dv8tion.jda.api.events.guild.member.GuildMemberJoinEvent;
import net.dv8tion.jda.api.events.guild.member.GuildMemberRemoveEvent;
import java.util.Objects; import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public class DiscordServerRequirement extends LongRequirement { public class DiscordServerRequirementType extends LongRequirementType {
private final DiscordSRV discordSRV; public DiscordServerRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
public DiscordServerRequirement(DiscordSRV discordSRV) {
this.discordSRV = discordSRV;
} }
@Override @Override
@ -39,13 +41,23 @@ public class DiscordServerRequirement extends LongRequirement {
} }
@Override @Override
public CompletableFuture<Boolean> isMet(Long value, UUID player, long userId) { public CompletableFuture<Boolean> isMet(Long value, Someone.Resolved someone) {
DiscordGuild guild = discordSRV.discordAPI().getGuildById(value); DiscordGuild guild = module.discordSRV().discordAPI().getGuildById(value);
if (guild == null) { if (guild == null) {
return CompletableFuture.completedFuture(false); return CompletableFuture.completedFuture(false);
} }
return guild.retrieveMemberById(userId) return guild.retrieveMemberById(someone.userId())
.thenApply(Objects::nonNull); .thenApply(Objects::nonNull);
} }
@Subscribe
public void onGuildMemberJoin(GuildMemberJoinEvent event) {
stateChanged(Someone.of(event.getUser().getIdLong()), event.getGuild().getIdLong(), true);
}
@Subscribe
public void onGuildMemberRemove(GuildMemberRemoveEvent event) {
stateChanged(Someone.of(event.getUser().getIdLong()), event.getGuild().getIdLong(), false);
}
} }

View File

@ -16,9 +16,17 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package com.discordsrv.common.linking.requirelinking.requirement; package com.discordsrv.common.linking.requirelinking.requirement.type;
public abstract class LongRequirement implements Requirement<Long> { import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
public abstract class LongRequirementType extends RequirementType<Long> {
public LongRequirementType(RequiredLinkingModule<? extends DiscordSRV> module) {
super(module);
}
@Override @Override
public Long parse(String input) { public Long parse(String input) {

View File

@ -16,9 +16,12 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package com.discordsrv.common.linking.requirelinking.requirement; package com.discordsrv.common.linking.requirelinking.requirement.type;
import com.discordsrv.common.DiscordSRV; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.someone.Someone;
import me.minecraftauth.lib.AuthService; import me.minecraftauth.lib.AuthService;
import me.minecraftauth.lib.account.platform.twitch.SubTier; import me.minecraftauth.lib.account.platform.twitch.SubTier;
import me.minecraftauth.lib.exception.LookupException; import me.minecraftauth.lib.exception.LookupException;
@ -30,41 +33,41 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.Function; import java.util.function.Function;
public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthRequirement.Reference<T>> { public class MinecraftAuthRequirementType<T> extends RequirementType<MinecraftAuthRequirementType.Reference<T>> {
private static final Reference<?> NULL_VALUE = new Reference<>(null); private static final Reference<?> NULL_VALUE = new Reference<>(null);
public static List<Requirement<?>> createRequirements(DiscordSRV discordSRV) { public static List<RequirementType<?>> createRequirements(RequiredLinkingModule<?> module) {
List<Requirement<?>> requirements = new ArrayList<>(); List<RequirementType<?>> requirementTypes = new ArrayList<>();
// Patreon // Patreon
requirements.add(new MinecraftAuthRequirement<>( requirementTypes.add(new MinecraftAuthRequirementType<>(
discordSRV, module,
Type.PATREON, Provider.PATREON,
"PatreonSubscriber", "PatreonSubscriber",
AuthService::isSubscribedPatreon, AuthService::isSubscribedPatreon,
AuthService::isSubscribedPatreon AuthService::isSubscribedPatreon
)); ));
// Glimpse // Glimpse
requirements.add(new MinecraftAuthRequirement<>( requirementTypes.add(new MinecraftAuthRequirementType<>(
discordSRV, module,
Type.GLIMPSE, Provider.GLIMPSE,
"GlimpseSubscriber", "GlimpseSubscriber",
AuthService::isSubscribedGlimpse, AuthService::isSubscribedGlimpse,
AuthService::isSubscribedGlimpse AuthService::isSubscribedGlimpse
)); ));
// Twitch // Twitch
requirements.add(new MinecraftAuthRequirement<>( requirementTypes.add(new MinecraftAuthRequirementType<>(
discordSRV, module,
Type.TWITCH, Provider.TWITCH,
"TwitchFollower", "TwitchFollower",
AuthService::isFollowingTwitch AuthService::isFollowingTwitch
)); ));
requirements.add(new MinecraftAuthRequirement<>( requirementTypes.add(new MinecraftAuthRequirementType<>(
discordSRV, module,
Type.TWITCH, Provider.TWITCH,
"TwitchSubscriber", "TwitchSubscriber",
AuthService::isSubscribedTwitch, AuthService::isSubscribedTwitch,
AuthService::isSubscribedTwitch, AuthService::isSubscribedTwitch,
@ -79,60 +82,59 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
)); ));
// YouTube // YouTube
requirements.add(new MinecraftAuthRequirement<>( requirementTypes.add(new MinecraftAuthRequirementType<>(
discordSRV, module,
Type.YOUTUBE, Provider.YOUTUBE,
"YouTubeSubscriber", "YouTubeSubscriber",
AuthService::isSubscribedYouTube AuthService::isSubscribedYouTube
)); ));
requirements.add(new MinecraftAuthRequirement<>( requirementTypes.add(new MinecraftAuthRequirementType<>(
discordSRV, module,
Type.YOUTUBE, Provider.YOUTUBE,
"YouTubeMember", "YouTubeMember",
AuthService::isMemberYouTube, AuthService::isMemberYouTube,
AuthService::isMemberYouTube AuthService::isMemberYouTube
)); ));
return requirements; return requirementTypes;
} }
private final DiscordSRV discordSRV; private final Provider provider;
private final Type type;
private final String name; private final String name;
private final Test test; private final Test test;
private final TestSpecific<T> testSpecific; private final TestSpecific<T> testSpecific;
private final Function<String, T> parse; private final Function<String, T> parse;
public MinecraftAuthRequirement( public MinecraftAuthRequirementType(
DiscordSRV discordSRV, RequiredLinkingModule<? extends DiscordSRV> module,
Type type, Provider provider,
String name, String name,
Test test Test test
) { ) {
this(discordSRV, type, name, test, null, null); this(module, provider, name, test, null, null);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public MinecraftAuthRequirement( public MinecraftAuthRequirementType(
DiscordSRV discordSRV, RequiredLinkingModule<? extends DiscordSRV> module,
Type type, Provider provider,
String name, String name,
Test test, Test test,
TestSpecific<String> testSpecific TestSpecific<String> testSpecific
) { ) {
this(discordSRV, type, name, test, (TestSpecific<T>) testSpecific, t -> (T) t); this(module, provider, name, test, (TestSpecific<T>) testSpecific, t -> (T) t);
} }
public MinecraftAuthRequirement( public MinecraftAuthRequirementType(
DiscordSRV discordSRV, RequiredLinkingModule<? extends DiscordSRV> module,
Type type, Provider provider,
String name, String name,
Test test, Test test,
TestSpecific<T> testSpecific, TestSpecific<T> testSpecific,
Function<String, T> parse Function<String, T> parse
) { ) {
this.discordSRV = discordSRV; super(module);
this.type = type; this.provider = provider;
this.name = name; this.name = name;
this.test = test; this.test = test;
this.testSpecific = testSpecific; this.testSpecific = testSpecific;
@ -144,8 +146,8 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
return name; return name;
} }
public Type getType() { public Provider getProvider() {
return type; return provider;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -161,14 +163,14 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
} }
@Override @Override
public CompletableFuture<Boolean> isMet(Reference<T> atomicReference, UUID player, long userId) { public CompletableFuture<Boolean> isMet(Reference<T> atomicReference, Someone.Resolved someone) {
String token = discordSRV.connectionConfig().minecraftAuth.token; String token = module.discordSRV().connectionConfig().minecraftAuth.token;
T value = atomicReference.getValue(); T value = atomicReference.getValue();
return discordSRV.scheduler().supply(() -> { return module.discordSRV().scheduler().supply(() -> {
if (value == null) { if (value == null) {
return test.test(token, player); return test.test(token, someone.playerUUID());
} else { } else {
return testSpecific.test(token, player, value); return testSpecific.test(token, someone.playerUUID(), value);
} }
}); });
} }
@ -196,7 +198,7 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
} }
} }
public enum Type { public enum Provider {
PATREON('p'), PATREON('p'),
GLIMPSE('g'), GLIMPSE('g'),
TWITCH('t'), TWITCH('t'),
@ -204,7 +206,7 @@ public class MinecraftAuthRequirement<T> implements Requirement<MinecraftAuthReq
private final char character; private final char character;
Type(char character) { Provider(char character) {
this.character = character; this.character = character;
} }

View File

@ -118,6 +118,12 @@ public class ModuleManager {
: delegates.computeIfAbsent(module, mod -> new ModuleDelegate(discordSRV, mod)); : delegates.computeIfAbsent(module, mod -> new ModuleDelegate(discordSRV, mod));
} }
private String getName(AbstractModule<?> abstractModule) {
return abstractModule instanceof ModuleDelegate
? ((ModuleDelegate) abstractModule).getBase().getClass().getName()
: abstractModule.getClass().getSimpleName();
}
public <DT extends DiscordSRV> void registerModule(DT discordSRV, CheckedFunction<DT, AbstractModule<?>> function) { public <DT extends DiscordSRV> void registerModule(DT discordSRV, CheckedFunction<DT, AbstractModule<?>> function) {
try { try {
register(function.apply(discordSRV)); register(function.apply(discordSRV));
@ -134,12 +140,10 @@ public class ModuleManager {
this.modules.add(module); this.modules.add(module);
this.moduleLookupTable.put(module.getClass().getName(), module); this.moduleLookupTable.put(module.getClass().getName(), module);
logger.debug(module + " registered"); logger.debug(module.getClass().getName() + " registered");
if (discordSRV.isReady()) { // Enable the module if we're already ready
// Check if Discord connection is ready, if it is already we'll enable the module enableOrDisableAsNeeded(getAbstract(module), discordSRV.isReady(), true);
enable(getAbstract(module));
}
} }
public void unregister(Module module) { public void unregister(Module module) {
@ -154,17 +158,35 @@ public class ModuleManager {
this.moduleLookupTable.values().removeIf(mod -> mod == module); this.moduleLookupTable.values().removeIf(mod -> mod == module);
this.delegates.remove(module); this.delegates.remove(module);
logger.debug(module + " unregistered"); logger.debug(module.getClass().getName() + " unregistered");
} }
private void enable(AbstractModule<?> module) { private List<DiscordSRVApi.ReloadResult> enable(AbstractModule<?> module) {
try { try {
if (module.enableModule()) { if (module.enableModule()) {
logger.debug(module + " enabled"); logger.debug(module + " enabled");
return reload(module);
} }
} catch (Throwable t) { } catch (Throwable t) {
discordSRV.logger().error("Failed to enable " + module.getClass().getSimpleName(), t); discordSRV.logger().error("Failed to enable " + getName(module), t);
return Collections.emptyList();
} }
return null;
}
private List<DiscordSRVApi.ReloadResult> reload(AbstractModule<?> module) {
List<DiscordSRVApi.ReloadResult> reloadResults = new ArrayList<>();
try {
module.reload(result -> {
if (result == null) {
throw new NullPointerException("null result supplied to resultConsumer");
}
reloadResults.add(result);
});
} catch (Throwable t) {
discordSRV.logger().error("Failed to reload " + getName(module), t);
}
return reloadResults;
} }
private void disable(AbstractModule<?> module) { private void disable(AbstractModule<?> module) {
@ -173,7 +195,7 @@ public class ModuleManager {
logger.debug(module + " disabled"); logger.debug(module + " disabled");
} }
} catch (Throwable t) { } catch (Throwable t) {
discordSRV.logger().error("Failed to disable " + module.getClass().getSimpleName(), t); discordSRV.logger().error("Failed to disable " + getName(module), t);
} }
} }
@ -189,52 +211,21 @@ public class ModuleManager {
reload(); reload();
} }
public synchronized List<DiscordSRV.ReloadResult> reload() { public List<DiscordSRVApi.ReloadResult> reload() {
JDAConnectionManager connectionManager = discordSRV.discordConnectionManager(); return reloadAndEnableModules(true);
}
public void enableModules() {
reloadAndEnableModules(false);
}
private synchronized List<DiscordSRV.ReloadResult> reloadAndEnableModules(boolean reload) {
boolean isReady = discordSRV.isReady();
logger().debug((reload ? "Reloading" : "Enabling") + " modules (DiscordSRV ready = " + isReady + ")");
Set<DiscordSRVApi.ReloadResult> reloadResults = new HashSet<>(); Set<DiscordSRVApi.ReloadResult> reloadResults = new HashSet<>();
for (Module module : modules) { for (Module module : modules) {
AbstractModule<?> abstractModule = getAbstract(module); reloadResults.addAll(enableOrDisableAsNeeded(getAbstract(module), isReady, reload));
boolean fail = false;
if (abstractModule.isEnabled()) {
for (DiscordGatewayIntent requiredIntent : abstractModule.getRequestedIntents()) {
if (!connectionManager.getIntents().contains(requiredIntent)) {
fail = true;
logger().warning("Missing gateway intent " + requiredIntent.name() + " for module " + module.getClass().getSimpleName());
}
}
for (DiscordCacheFlag requiredCacheFlag : abstractModule.getRequestedCacheFlags()) {
if (!connectionManager.getCacheFlags().contains(requiredCacheFlag)) {
fail = true;
logger().warning("Missing cache flag " + requiredCacheFlag.name() + " for module " + module.getClass().getSimpleName());
}
}
}
if (fail) {
reloadResults.add(ReloadResults.DISCORD_CONNECTION_RELOAD_REQUIRED);
}
// Check if the module needs to be enabled or disabled
if (!fail) {
enable(abstractModule);
}
if (!abstractModule.isEnabled()) {
disable(abstractModule);
continue;
}
try {
abstractModule.reload(result -> {
if (result == null) {
throw new NullPointerException("null result supplied to resultConsumer");
}
reloadResults.add(result);
});
} catch (Throwable t) {
discordSRV.logger().error("Failed to reload " + module.getClass().getSimpleName(), t);
}
} }
List<DiscordSRVApi.ReloadResult> results = new ArrayList<>(); List<DiscordSRVApi.ReloadResult> results = new ArrayList<>();
@ -249,6 +240,52 @@ public class ModuleManager {
return results; return results;
} }
private List<DiscordSRVApi.ReloadResult> enableOrDisableAsNeeded(AbstractModule<?> module, boolean isReady, boolean reload) {
boolean canBeEnabled = isReady || module.canEnableBeforeReady();
if (!canBeEnabled) {
return Collections.emptyList();
}
boolean enabled = module.isEnabled();
if (!enabled) {
disable(module);
return Collections.emptyList();
}
JDAConnectionManager connectionManager = discordSRV.discordConnectionManager();
boolean fail = false;
for (DiscordGatewayIntent requiredIntent : module.getRequestedIntents()) {
if (!connectionManager.getIntents().contains(requiredIntent)) {
fail = true;
logger().warning("Missing gateway intent " + requiredIntent.name() + " for module " + getName(module));
}
}
for (DiscordCacheFlag requiredCacheFlag : module.getRequestedCacheFlags()) {
if (!connectionManager.getCacheFlags().contains(requiredCacheFlag)) {
fail = true;
logger().warning("Missing cache flag " + requiredCacheFlag.name() + " for module " + getName(module));
}
}
List<DiscordSRVApi.ReloadResult> reloadResults = new ArrayList<>();
if (fail) {
reloadResults.add(ReloadResults.DISCORD_CONNECTION_RELOAD_REQUIRED);
}
// Enable the module if reload passed
if (!fail) {
List<DiscordSRVApi.ReloadResult> results = enable(module);
if (results != null) {
reloadResults.addAll(results);
} else if (reload) {
reloadResults.addAll(reload(module));
}
}
return reloadResults;
}
@Subscribe @Subscribe
public void onDebugGenerate(DebugGenerateEvent event) { public void onDebugGenerate(DebugGenerateEvent event) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();

View File

@ -35,7 +35,7 @@ public abstract class AbstractModule<DT extends DiscordSRV> implements Module {
protected final DT discordSRV; protected final DT discordSRV;
private final Logger logger; private final Logger logger;
private boolean hasBeenEnabled = false; private boolean isCurrentlyEnabled = false;
private final List<DiscordGatewayIntent> requestedIntents = new ArrayList<>(); private final List<DiscordGatewayIntent> requestedIntents = new ArrayList<>();
private final List<DiscordCacheFlag> requestedCacheFlags = new ArrayList<>(); private final List<DiscordCacheFlag> requestedCacheFlags = new ArrayList<>();
@ -72,11 +72,11 @@ public abstract class AbstractModule<DT extends DiscordSRV> implements Module {
// Internal // Internal
public final boolean enableModule() { public final boolean enableModule() {
if (hasBeenEnabled || !isEnabled()) { if (isCurrentlyEnabled) {
return false; return false;
} }
hasBeenEnabled = true; isCurrentlyEnabled = true;
enable(); enable();
try { try {
@ -87,12 +87,12 @@ public abstract class AbstractModule<DT extends DiscordSRV> implements Module {
} }
public final boolean disableModule() { public final boolean disableModule() {
if (!hasBeenEnabled) { if (!isCurrentlyEnabled) {
return false; return false;
} }
disable(); disable();
hasBeenEnabled = false; isCurrentlyEnabled = false;
try { try {
discordSRV.eventBus().unsubscribe(this); discordSRV.eventBus().unsubscribe(this);

View File

@ -39,6 +39,10 @@ public class ModuleDelegate extends AbstractModule<DiscordSRV> {
this.module = module; this.module = module;
} }
public Module getBase() {
return module;
}
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return module.isEnabled(); return module.isEnabled();

View File

@ -27,6 +27,7 @@ import com.discordsrv.common.config.main.PresenceUpdaterConfig;
import com.discordsrv.common.logging.NamedLogger; import com.discordsrv.common.logging.NamedLogger;
import com.discordsrv.common.module.type.AbstractModule; import com.discordsrv.common.module.type.AbstractModule;
import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.JDA;
import net.dv8tion.jda.api.events.StatusChangeEvent;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
@ -45,6 +46,11 @@ public class PresenceUpdaterModule extends AbstractModule<DiscordSRV> {
super(discordSRV, new NamedLogger(discordSRV, "PRESENCE_UPDATER")); super(discordSRV, new NamedLogger(discordSRV, "PRESENCE_UPDATER"));
} }
@Override
public boolean canEnableBeforeReady() {
return true;
}
public void serverStarted() { public void serverStarted() {
serverState.set(ServerState.STARTED); serverState.set(ServerState.STARTED);
setPresenceOrSchedule(); setPresenceOrSchedule();
@ -56,8 +62,19 @@ public class PresenceUpdaterModule extends AbstractModule<DiscordSRV> {
setPresenceOrSchedule(); setPresenceOrSchedule();
} }
@Subscribe
public void onStatusChange(StatusChangeEvent event) {
if (event.getNewStatus() == JDA.Status.IDENTIFYING_SESSION) {
setPresenceOrSchedule();
}
}
@Override @Override
public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) { public void reload(Consumer<DiscordSRVApi.ReloadResult> resultConsumer) {
if (discordSRV.jda() == null) {
return;
}
setPresenceOrSchedule(); setPresenceOrSchedule();
// Log problems with presences // Log problems with presences

View File

@ -63,6 +63,10 @@ public class Someone {
this.userId = userId; this.userId = userId;
} }
private <T> T throwIllegal() {
throw new IllegalStateException("Cannot have Someone instance without either a Player UUID or User Id");
}
@NotNull @NotNull
public CompletableFuture<@NotNull Profile> profile(DiscordSRV discordSRV) { public CompletableFuture<@NotNull Profile> profile(DiscordSRV discordSRV) {
if (playerUUID != null) { if (playerUUID != null) {
@ -70,7 +74,7 @@ public class Someone {
} else if (userId != null) { } else if (userId != null) {
return discordSRV.profileManager().lookupProfile(userId); return discordSRV.profileManager().lookupProfile(userId);
} else { } else {
throw new IllegalStateException("Cannot have Someone instance without either a Player UUID or User Id"); return throwIllegal();
} }
} }
@ -80,14 +84,31 @@ public class Someone {
return CompletableFuture.completedFuture(of(playerUUID, userId)); return CompletableFuture.completedFuture(of(playerUUID, userId));
} }
return profile(discordSRV).thenApply(profile -> { if (playerUUID != null) {
UUID playerUUID = profile.playerUUID(); return withUserId(discordSRV).thenApply(userId -> userId != null ? of(playerUUID, userId) : null);
Long userId = profile.userId(); } else if (userId != null) {
if (playerUUID == null || userId == null) { return withPlayerUUID(discordSRV).thenApply(playerUUID -> playerUUID != null ? of(playerUUID, userId) : null);
return null; } else {
} return throwIllegal();
return of(playerUUID, userId); }
}); }
public CompletableFuture<@Nullable Long> withUserId(DiscordSRV discordSRV) {
if (userId != null) {
return CompletableFuture.completedFuture(userId);
} else if (playerUUID == null) {
return throwIllegal();
}
return discordSRV.linkProvider().getUserId(playerUUID).thenApply(opt -> opt.orElse(null));
}
public CompletableFuture<@Nullable UUID> withPlayerUUID(DiscordSRV discordSRV) {
if (playerUUID != null) {
return CompletableFuture.completedFuture(playerUUID);
} else if (userId == null) {
return throwIllegal();
}
return discordSRV.linkProvider().getPlayerUUID(userId).thenApply(opt -> opt.orElse(null));
} }
@Nullable @Nullable
@ -102,7 +123,7 @@ public class Someone {
@Override @Override
public String toString() { public String toString() {
return playerUUID != null ? playerUUID.toString() : Objects.requireNonNull(userId).toString(); return playerUUID != null ? playerUUID.toString() : Long.toUnsignedString(Objects.requireNonNull(userId));
} }
@SuppressWarnings("DataFlowIssue") @SuppressWarnings("DataFlowIssue")

View File

@ -18,24 +18,48 @@
package com.discordsrv.common.linking.requirement.parser; package com.discordsrv.common.linking.requirement.parser;
import com.discordsrv.common.linking.requirelinking.requirement.Requirement; import com.discordsrv.common.DiscordSRV;
import com.discordsrv.common.MockDiscordSRV;
import com.discordsrv.common.config.main.linking.RequiredLinkingConfig;
import com.discordsrv.common.linking.requirelinking.RequiredLinkingModule;
import com.discordsrv.common.linking.requirelinking.requirement.RequirementType;
import com.discordsrv.common.linking.requirelinking.requirement.parser.ParsedRequirements;
import com.discordsrv.common.linking.requirelinking.requirement.parser.RequirementParser; import com.discordsrv.common.linking.requirelinking.requirement.parser.RequirementParser;
import com.discordsrv.common.player.IPlayer;
import com.discordsrv.common.someone.Someone;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.api.function.Executable;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class RequirementParserTest { public class RequirementTypeParserTest {
private final RequirementParser requirementParser = RequirementParser.getInstance(); private final RequirementParser requirementParser = RequirementParser.getInstance();
private final List<Requirement<?>> requirements = Arrays.asList( private final RequiredLinkingModule<?> module = new RequiredLinkingModule<DiscordSRV>(MockDiscordSRV.INSTANCE) {
new Requirement<Boolean>() { @Override
public RequiredLinkingConfig config() {
return null;
}
@Override
public void reload() {}
@Override
public List<ParsedRequirements> getAllActiveRequirements() {
return Collections.emptyList();
}
@Override
public void recheck(IPlayer player) {}
};
private final List<RequirementType<?>> requirementTypes = Arrays.asList(
new RequirementType<Boolean>(module) {
@Override @Override
public String name() { public String name() {
return "F"; return "F";
@ -47,11 +71,11 @@ public class RequirementParserTest {
} }
@Override @Override
public CompletableFuture<Boolean> isMet(Boolean value, UUID player, long userId) { public CompletableFuture<Boolean> isMet(Boolean value, Someone.Resolved someone) {
return CompletableFuture.completedFuture(value); return CompletableFuture.completedFuture(value);
} }
}, },
new Requirement<Object>() { new RequirementType<Object>(module) {
@Override @Override
public String name() { public String name() {
return "AlwaysError"; return "AlwaysError";
@ -63,14 +87,17 @@ public class RequirementParserTest {
} }
@Override @Override
public CompletableFuture<Boolean> isMet(Object value, UUID player, long userId) { public CompletableFuture<Boolean> isMet(Object value, Someone.Resolved someone) {
return null; return null;
} }
} }
); );
private boolean parse(String input) { private boolean parse(String input) {
return requirementParser.parse(input, requirements, new ArrayList<>()).apply(null, 0L).join(); return requirementParser.parse(input, requirementTypes)
.predicate()
.apply(Someone.of(UUID.randomUUID(), 0L))
.join();
} }
@Test @Test
@ -78,6 +105,21 @@ public class RequirementParserTest {
assertFalse(parse("f(false) || F(false)")); assertFalse(parse("f(false) || F(false)"));
} }
@Test
public void negate() {
assertTrue(parse("!F(false)"));
}
@Test
public void negateReverse() {
assertFalse(parse("!F(true)"));
}
@Test
public void doubleNegate() {
assertTrue(parse("!!F(true)"));
}
@Test @Test
public void orFail() { public void orFail() {
assertFalse(parse("F(false) || F(false)")); assertFalse(parse("F(false) || F(false)"));
@ -98,6 +140,11 @@ public class RequirementParserTest {
assertTrue(parse("F(true) && F(true)")); assertTrue(parse("F(true) && F(true)"));
} }
@Test
public void andNegate() {
assertTrue(parse("F(true) && !F(false)"));
}
@Test @Test
public void complexFail() { public void complexFail() {
assertFalse(parse("F(true) && (F(false) && F(true))")); assertFalse(parse("F(true) && (F(false) && F(true))"));
@ -143,4 +190,9 @@ public class RequirementParserTest {
assertExceptionMessageStartsWith("Unacceptable function value for", () -> parse("AlwaysError()")); assertExceptionMessageStartsWith("Unacceptable function value for", () -> parse("AlwaysError()"));
} }
@Test
public void negateBeforeFunctionNameError() {
assertExceptionMessageStartsWith("Negation must be before function name", () -> parse("F!(false)"));
}
} }