Refactor Log4JFilter and improve branch coverage

This commit is contained in:
ljacqu 2015-11-20 23:17:50 +01:00
parent 019390dfe0
commit 84de22c9c0
3 changed files with 112 additions and 51 deletions

View File

@ -6,80 +6,50 @@ import org.apache.logging.log4j.core.LogEvent;
import org.apache.logging.log4j.core.Logger; import org.apache.logging.log4j.core.Logger;
import org.apache.logging.log4j.message.Message; import org.apache.logging.log4j.message.Message;
import fr.xephi.authme.util.StringUtils;
/** /**
* * Implements a filter for Log4j to skip sensitive AuthMe commands.
* @author Xephi59 * @author Xephi59
*/ */
public class Log4JFilter implements org.apache.logging.log4j.core.Filter { public class Log4JFilter implements org.apache.logging.log4j.core.Filter {
/** List of commands (lower-case) to skip. */
private static final String[] COMMANDS_TO_SKIP = { "/login ", "/l ", "/reg ", "/changepassword ",
"/unregister ", "/authme register ", "/authme changepassword ", "/authme reg ", "/authme cp ",
"/register " };
/** Constructor. */
public Log4JFilter() { public Log4JFilter() {
} }
@Override @Override
public Result filter(LogEvent record) { public Result filter(LogEvent record) {
try { if (record == null) {
if (record == null || record.getMessage() == null)
return Result.NEUTRAL;
String logM = record.getMessage().getFormattedMessage().toLowerCase();
if (!logM.contains("issued server command:"))
return Result.NEUTRAL;
if (!logM.contains("/login ") && !logM.contains("/l ") && !logM.contains("/reg ") && !logM.contains("/changepassword ") && !logM.contains("/unregister ") && !logM.contains("/authme register ") && !logM.contains("/authme changepassword ") && !logM.contains("/authme reg ") && !logM.contains("/authme cp ") && !logM.contains("/register "))
return Result.NEUTRAL;
return Result.DENY;
} catch (NullPointerException npe) {
return Result.NEUTRAL; return Result.NEUTRAL;
} }
return validateMessage(record.getMessage());
} }
@Override @Override
public Result filter(Logger arg0, Level arg1, Marker arg2, String message, public Result filter(Logger arg0, Level arg1, Marker arg2, String message,
Object... arg4) { Object... arg4) {
try { return validateMessage(message);
if (message == null)
return Result.NEUTRAL;
String logM = message.toLowerCase();
if (!logM.contains("issued server command:"))
return Result.NEUTRAL;
if (!logM.contains("/login ") && !logM.contains("/l ") && !logM.contains("/reg ") && !logM.contains("/changepassword ") && !logM.contains("/unregister ") && !logM.contains("/authme register ") && !logM.contains("/authme changepassword ") && !logM.contains("/authme reg ") && !logM.contains("/authme cp ") && !logM.contains("/register "))
return Result.NEUTRAL;
return Result.DENY;
} catch (NullPointerException npe) {
return Result.NEUTRAL;
}
} }
@Override @Override
public Result filter(Logger arg0, Level arg1, Marker arg2, Object message, public Result filter(Logger arg0, Level arg1, Marker arg2, Object message,
Throwable arg4) { Throwable arg4) {
try { if (message == null) {
if (message == null)
return Result.NEUTRAL;
String logM = message.toString().toLowerCase();
if (!logM.contains("issued server command:"))
return Result.NEUTRAL;
if (!logM.contains("/login ") && !logM.contains("/l ") && !logM.contains("/reg ") && !logM.contains("/changepassword ") && !logM.contains("/unregister ") && !logM.contains("/authme register ") && !logM.contains("/authme changepassword ") && !logM.contains("/authme reg ") && !logM.contains("/authme cp ") && !logM.contains("/register "))
return Result.NEUTRAL;
return Result.DENY;
} catch (NullPointerException npe) {
return Result.NEUTRAL; return Result.NEUTRAL;
} }
return validateMessage(message.toString());
} }
@Override @Override
public Result filter(Logger arg0, Level arg1, Marker arg2, Message message, public Result filter(Logger arg0, Level arg1, Marker arg2, Message message,
Throwable arg4) { Throwable arg4) {
try { return validateMessage(message);
if (message == null)
return Result.NEUTRAL;
String logM = message.getFormattedMessage().toLowerCase();
if (!logM.contains("issued server command:"))
return Result.NEUTRAL;
if (!logM.contains("/login ") && !logM.contains("/l ") && !logM.contains("/reg ") && !logM.contains("/changepassword ") && !logM.contains("/unregister ") && !logM.contains("/authme register ") && !logM.contains("/authme changepassword ") && !logM.contains("/authme reg ") && !logM.contains("/authme cp ") && !logM.contains("/register "))
return Result.NEUTRAL;
return Result.DENY;
} catch (NullPointerException npe) {
return Result.NEUTRAL;
}
} }
@Override @Override
@ -92,4 +62,39 @@ public class Log4JFilter implements org.apache.logging.log4j.core.Filter {
return Result.NEUTRAL; return Result.NEUTRAL;
} }
/**
* Validates a Message instance and returns the {@link Result} value
* depending depending on whether the message contains sensitive AuthMe
* data.
*
* @param message the Message object to verify
* @return the Result value
*/
private static Result validateMessage(Message message) {
if (message == null) {
return Result.NEUTRAL;
}
return validateMessage(message.getFormattedMessage());
}
/**
* Validates a message and returns the {@link Result} value depending
* depending on whether the message contains sensitive AuthMe data.
*
* @param message the message to verify
* @return the Result value
*/
private static Result validateMessage(String message) {
if (message == null) {
return Result.NEUTRAL;
}
String lowerMessage = message.toLowerCase();
if (lowerMessage.contains("issued server command:")
&& StringUtils.containsAny(lowerMessage, COMMANDS_TO_SKIP)) {
return Result.DENY;
}
return Result.NEUTRAL;
}
} }

View File

@ -25,4 +25,16 @@ public class StringUtils {
// Determine the difference value, return the result // Determine the difference value, return the result
return Math.abs(service.score(first, second) - 1.0); return Math.abs(service.score(first, second) - 1.0);
} }
public static boolean containsAny(String str, String... pieces) {
if (str == null) {
return false;
}
for (String piece : pieces) {
if (str.contains(piece)) {
return true;
}
}
return false;
}
} }

View File

@ -52,6 +52,20 @@ public class Log4JFilterTest {
assertThat(result, equalTo(Result.NEUTRAL)); assertThat(result, equalTo(Result.NEUTRAL));
} }
@Test
public void shouldNotFilterNonCommandLogEvent() {
// given
Message message = mockMessage(OTHER_COMMAND);
LogEvent event = Mockito.mock(LogEvent.class);
when(event.getMessage()).thenReturn(message);
// when
Result result = log4JFilter.filter(event);
// then
assertThat(result, equalTo(Result.NEUTRAL));
}
@Test @Test
public void shouldNotFilterLogEventWithNullMessage() { public void shouldNotFilterLogEventWithNullMessage() {
// given // given
@ -96,6 +110,15 @@ public class Log4JFilterTest {
assertThat(result, equalTo(Result.NEUTRAL)); assertThat(result, equalTo(Result.NEUTRAL));
} }
@Test
public void shouldNotFilterNonCommandStringMessage() {
// given / when
Result result = log4JFilter.filter(null, null, null, OTHER_COMMAND, new Object[0]);
// then
assertThat(result, equalTo(Result.NEUTRAL));
}
@Test @Test
public void shouldReturnNeutralForNullMessage() { public void shouldReturnNeutralForNullMessage() {
// given / when // given / when
@ -120,7 +143,7 @@ public class Log4JFilterTest {
@Test @Test
public void shouldNotFilterNullObjectParam() { public void shouldNotFilterNullObjectParam() {
// given / when // given / when
Result result = log4JFilter.filter(null, null, null, null, new Exception()); Result result = log4JFilter.filter(null, null, null, (Object) null, new Exception());
// then // then
assertThat(result, equalTo(Result.NEUTRAL)); assertThat(result, equalTo(Result.NEUTRAL));
@ -135,6 +158,15 @@ public class Log4JFilterTest {
assertThat(result, equalTo(Result.NEUTRAL)); assertThat(result, equalTo(Result.NEUTRAL));
} }
@Test
public void shouldNotFilterNonSensitiveCommand() {
// given / when
Result result = log4JFilter.filter(null, null, null, NORMAL_COMMAND, new Exception());
// then
assertThat(result, equalTo(Result.NEUTRAL));
}
// -------- // --------
// Test filter(Logger, Level, Marker, Message, Throwable) // Test filter(Logger, Level, Marker, Message, Throwable)
// -------- // --------
@ -162,10 +194,22 @@ public class Log4JFilterTest {
assertThat(result, equalTo(Result.NEUTRAL)); assertThat(result, equalTo(Result.NEUTRAL));
} }
@Test
public void shouldNotFilterNonCommandMessage() {
// given
Message message = mockMessage(OTHER_COMMAND);
// when
Result result = log4JFilter.filter(null, null, null, message, new Exception());
// then
assertThat(result, equalTo(Result.NEUTRAL));
}
@Test @Test
public void shouldNotFilterNullMessage() { public void shouldNotFilterNullMessage() {
// given / when // given / when
Result result = log4JFilter.filter(null, null, null, null, new Exception()); Result result = log4JFilter.filter(null, null, null, (Message) null, new Exception());
// then // then
assertThat(result, equalTo(Result.NEUTRAL)); assertThat(result, equalTo(Result.NEUTRAL));