Fix removing intervals in the interval tree.

This commit is contained in:
Kristian S. Stangeland 2012-11-21 05:50:23 +01:00
parent 4b19f8498b
commit 4b2f69c3c8

View File

@ -6,6 +6,7 @@ import java.util.NavigableMap;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import com.google.common.base.Objects;
import com.google.common.collect.Range; import com.google.common.collect.Range;
import com.google.common.collect.Ranges; import com.google.common.collect.Ranges;
@ -32,24 +33,25 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
* Represents a range and a value in this interval tree. * Represents a range and a value in this interval tree.
*/ */
public class Entry implements Map.Entry<Range<TKey>, TValue> { public class Entry implements Map.Entry<Range<TKey>, TValue> {
private final Range<TKey> key;
private EndPoint left; private EndPoint left;
private EndPoint right; private EndPoint right;
Entry(Range<TKey> key, EndPoint left, EndPoint right) { Entry(EndPoint left, EndPoint right) {
if (left == null) if (left == null)
throw new IllegalAccessError("left cannot be NUll"); throw new IllegalAccessError("left cannot be NUll");
if (right == null) if (right == null)
throw new IllegalAccessError("right cannot be NUll"); throw new IllegalAccessError("right cannot be NUll");
if (left.key.compareTo(right.key) > 0)
throw new IllegalArgumentException(
"Left key (" + left.key + ") cannot be greater than the right key (" + right.key + ")");
this.key = key;
this.left = left; this.left = left;
this.right = right; this.right = right;
} }
@Override @Override
public Range<TKey> getKey() { public Range<TKey> getKey() {
return key; return Ranges.closed(left.key, right.key);
} }
@Override @Override
@ -66,6 +68,31 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
right.value = value; right.value = value;
return old; return old;
} }
@SuppressWarnings("rawtypes")
@Override
public boolean equals(Object obj) {
// Quick equality check
if (obj == this) {
return true;
} else if (obj instanceof AbstractIntervalTree.Entry) {
return Objects.equal(left.key, ((AbstractIntervalTree.Entry) obj).left.key) &&
Objects.equal(right.key, ((AbstractIntervalTree.Entry) obj).right.key) &&
Objects.equal(left.value, ((AbstractIntervalTree.Entry) obj).left.value);
} else {
return false;
}
}
@Override
public int hashCode() {
return Objects.hashCode(left.key, right.key, left.value);
}
@Override
public String toString() {
return String.format("Value %s at [%s, %s]", left.value, left.key, right.key);
}
} }
/** /**
@ -79,8 +106,12 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
// The value this range contains // The value this range contains
public TValue value; public TValue value;
public EndPoint(State state, TValue value) { // The key of this end point
public TKey key;
public EndPoint(State state, TKey key, TValue value) {
this.state = state; this.state = state;
this.key = key;
this.value = value; this.value = value;
} }
} }
@ -107,31 +138,46 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
checkBounds(lowerBound, upperBound); checkBounds(lowerBound, upperBound);
NavigableMap<TKey, EndPoint> range = bounds.subMap(lowerBound, true, upperBound, true); NavigableMap<TKey, EndPoint> range = bounds.subMap(lowerBound, true, upperBound, true);
boolean emptyRange = range.isEmpty(); EndPoint first = getNextEndPoint(lowerBound, true);
TKey first = !emptyRange ? range.firstKey() : null; EndPoint last = getPreviousEndPoint(upperBound, true);
TKey last = !emptyRange ? range.lastKey() : null;
// Used while resizing intervals
EndPoint previous = null;
EndPoint next = null;
Set<Entry> resized = new HashSet<Entry>(); Set<Entry> resized = new HashSet<Entry>();
Set<Entry> removed = new HashSet<Entry>(); Set<Entry> removed = new HashSet<Entry>();
// Remove the previous element too. A close end-point must be preceded by an OPEN end-point. // Remove the previous element too. A close end-point must be preceded by an OPEN end-point.
if (first != null && range.get(first).state == State.CLOSE) { if (first != null && first.state == State.CLOSE) {
TKey key = bounds.floorKey(first); previous = getPreviousEndPoint(first.key, false);
EndPoint removedPoint = removeIfNonNull(key);
// Add the interval back // Add the interval back
if (removedPoint != null && preserveDifference) { if (previous != null) {
resized.add(putUnsafe(key, decrementKey(lowerBound), removedPoint.value)); removed.add(getEntry(previous, first));
} }
} }
// Get the closing element too. // Get the closing element too.
if (last != null && range.get(last).state == State.OPEN) { if (last != null && last.state == State.OPEN) {
TKey key = bounds.ceilingKey(last); next = getNextEndPoint(last.key, false);
EndPoint removedPoint = removeIfNonNull(key);
if (removedPoint != null && preserveDifference) { if (next != null) {
resized.add(putUnsafe(incrementKey(upperBound), key, removedPoint.value)); removed.add(getEntry(last, next));
}
}
// Now remove both ranges
removeEntrySafely(previous, first);
removeEntrySafely(last, next);
// Add new resized intervals
if (preserveDifference) {
if (previous != null) {
resized.add(putUnsafe(previous.key, decrementKey(lowerBound), previous.value));
}
if (next != null) {
resized.add(putUnsafe(incrementKey(upperBound), next.key, next.value));
} }
} }
@ -140,7 +186,6 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
invokeEntryRemoved(removed); invokeEntryRemoved(removed);
if (preserveDifference) { if (preserveDifference) {
invokeEntryRemoved(resized);
invokeEntryAdded(resized); invokeEntryAdded(resized);
} }
@ -149,12 +194,30 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
return removed; return removed;
} }
// Helper /**
private EndPoint removeIfNonNull(TKey key) { * Retrieve the entry from a given set of end points.
if (key != null) { * @param left - leftmost end point.
return bounds.remove(key); * @param right - rightmost end point.
* @return The associated entry.
*/
protected Entry getEntry(EndPoint left, EndPoint right) {
if (left == null)
throw new IllegalArgumentException("left endpoint cannot be NULL.");
if (right == null)
throw new IllegalArgumentException("right endpoint cannot be NULL.");
// Make sure the order is correct
if (right.key.compareTo(left.key) < 0) {
return getEntry(right, left);
} else { } else {
return null; return new Entry(left, right);
}
}
private void removeEntrySafely(EndPoint left, EndPoint right) {
if (left != null && right != null) {
bounds.remove(left.key);
bounds.remove(right.key);
} }
} }
@ -165,7 +228,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
if (endPoint != null) { if (endPoint != null) {
endPoint.state = State.BOTH; endPoint.state = State.BOTH;
} else { } else {
endPoint = new EndPoint(state, value); endPoint = new EndPoint(state, key, value);
bounds.put(key, endPoint); bounds.put(key, endPoint);
} }
return endPoint; return endPoint;
@ -199,8 +262,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
EndPoint left = addEndPoint(lowerBound, value, State.OPEN); EndPoint left = addEndPoint(lowerBound, value, State.OPEN);
EndPoint right = addEndPoint(upperBound, value, State.CLOSE); EndPoint right = addEndPoint(upperBound, value, State.CLOSE);
Range<TKey> range = Ranges.closed(lowerBound, upperBound); return new Entry(left, right);
return new Entry(range, left, right);
} else { } else {
return null; return null;
} }
@ -261,11 +323,12 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
switch (entry.getValue().state) { switch (entry.getValue().state) {
case BOTH: case BOTH:
EndPoint point = entry.getValue(); EndPoint point = entry.getValue();
destination.add(new Entry(Ranges.singleton(entry.getKey()), point, point)); destination.add(new Entry(point, point));
break; break;
case CLOSE: case CLOSE:
Range<TKey> range = Ranges.closed(last.getKey(), entry.getKey()); if (last != null) {
destination.add(new Entry(range, last.getValue(), entry.getValue())); destination.add(new Entry(last.getValue(), entry.getValue()));
}
break; break;
case OPEN: case OPEN:
// We don't know the full range yet // We don't know the full range yet
@ -284,7 +347,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
public void putAll(AbstractIntervalTree<TKey, TValue> other) { public void putAll(AbstractIntervalTree<TKey, TValue> other) {
// Naively copy every range. // Naively copy every range.
for (Entry entry : other.entrySet()) { for (Entry entry : other.entrySet()) {
put(entry.key.lowerEndpoint(), entry.key.upperEndpoint(), entry.getValue()); put(entry.left.key, entry.right.key, entry.getValue());
} }
} }
@ -303,7 +366,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
} }
/** /**
* Get the end-point composite associated with this key. * Get the left-most end-point associated with this key.
* @param key - key to search for. * @param key - key to search for.
* @return The end point found, or NULL. * @return The end point found, or NULL.
*/ */
@ -311,22 +374,60 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
EndPoint ends = bounds.get(key); EndPoint ends = bounds.get(key);
if (ends != null) { if (ends != null) {
// This is a piece of cake // Always return the end point to the left
return ends; if (ends.state == State.CLOSE) {
} else { Map.Entry<TKey, EndPoint> left = bounds.floorEntry(decrementKey(key));
return left != null ? left.getValue() : null;
} else {
return ends;
}
} else {
// We need to determine if the point intersects with a range // We need to determine if the point intersects with a range
TKey left = bounds.floorKey(key); Map.Entry<TKey, EndPoint> left = bounds.floorEntry(key);
// We only need to check to the left // We only need to check to the left
if (left != null && bounds.get(left).state == State.OPEN) { if (left != null && left.getValue().state == State.OPEN) {
return bounds.get(left); return left.getValue();
} else { } else {
return null; return null;
} }
} }
} }
/**
* Get the previous end point of a given key.
* @param point - the point to search with.
* @param inclusive - whether or not to include the current point in the search.
* @return The previous end point of a given given key, or NULL if not found.
*/
protected EndPoint getPreviousEndPoint(TKey point, boolean inclusive) {
if (point != null) {
Map.Entry<TKey, EndPoint> previous = bounds.floorEntry(inclusive ? point : decrementKey(point));
if (previous != null)
return previous.getValue();
}
return null;
}
/**
* Get the next end point of a given key.
* @param point - the point to search with.
* @param inclusive - whether or not to include the current point in the search.
* @return The next end point of a given given key, or NULL if not found.
*/
protected EndPoint getNextEndPoint(TKey point, boolean inclusive) {
if (point != null) {
Map.Entry<TKey, EndPoint> next = bounds.ceilingEntry(inclusive ? point : incrementKey(point));
if (next != null)
return next.getValue();
}
return null;
}
private void invokeEntryAdded(Entry added) { private void invokeEntryAdded(Entry added) {
if (added != null) { if (added != null) {
onEntryAdded(added); onEntryAdded(added);