Fix removing intervals in the interval tree.

This commit is contained in:
Kristian S. Stangeland 2012-11-21 05:50:23 +01:00
parent 4b19f8498b
commit 4b2f69c3c8

View File

@ -6,6 +6,7 @@ import java.util.NavigableMap;
import java.util.Set;
import java.util.TreeMap;
import com.google.common.base.Objects;
import com.google.common.collect.Range;
import com.google.common.collect.Ranges;
@ -32,24 +33,25 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
* Represents a range and a value in this interval tree.
*/
public class Entry implements Map.Entry<Range<TKey>, TValue> {
private final Range<TKey> key;
private EndPoint left;
private EndPoint right;
Entry(Range<TKey> key, EndPoint left, EndPoint right) {
Entry(EndPoint left, EndPoint right) {
if (left == null)
throw new IllegalAccessError("left cannot be NUll");
if (right == null)
throw new IllegalAccessError("right cannot be NUll");
if (left.key.compareTo(right.key) > 0)
throw new IllegalArgumentException(
"Left key (" + left.key + ") cannot be greater than the right key (" + right.key + ")");
this.key = key;
this.left = left;
this.right = right;
}
@Override
public Range<TKey> getKey() {
return key;
return Ranges.closed(left.key, right.key);
}
@Override
@ -66,6 +68,31 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
right.value = value;
return old;
}
@SuppressWarnings("rawtypes")
@Override
public boolean equals(Object obj) {
// Quick equality check
if (obj == this) {
return true;
} else if (obj instanceof AbstractIntervalTree.Entry) {
return Objects.equal(left.key, ((AbstractIntervalTree.Entry) obj).left.key) &&
Objects.equal(right.key, ((AbstractIntervalTree.Entry) obj).right.key) &&
Objects.equal(left.value, ((AbstractIntervalTree.Entry) obj).left.value);
} else {
return false;
}
}
@Override
public int hashCode() {
return Objects.hashCode(left.key, right.key, left.value);
}
@Override
public String toString() {
return String.format("Value %s at [%s, %s]", left.value, left.key, right.key);
}
}
/**
@ -79,8 +106,12 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
// The value this range contains
public TValue value;
public EndPoint(State state, TValue value) {
// The key of this end point
public TKey key;
public EndPoint(State state, TKey key, TValue value) {
this.state = state;
this.key = key;
this.value = value;
}
}
@ -107,31 +138,46 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
checkBounds(lowerBound, upperBound);
NavigableMap<TKey, EndPoint> range = bounds.subMap(lowerBound, true, upperBound, true);
boolean emptyRange = range.isEmpty();
TKey first = !emptyRange ? range.firstKey() : null;
TKey last = !emptyRange ? range.lastKey() : null;
EndPoint first = getNextEndPoint(lowerBound, true);
EndPoint last = getPreviousEndPoint(upperBound, true);
// Used while resizing intervals
EndPoint previous = null;
EndPoint next = null;
Set<Entry> resized = new HashSet<Entry>();
Set<Entry> removed = new HashSet<Entry>();
// Remove the previous element too. A close end-point must be preceded by an OPEN end-point.
if (first != null && range.get(first).state == State.CLOSE) {
TKey key = bounds.floorKey(first);
EndPoint removedPoint = removeIfNonNull(key);
if (first != null && first.state == State.CLOSE) {
previous = getPreviousEndPoint(first.key, false);
// Add the interval back
if (removedPoint != null && preserveDifference) {
resized.add(putUnsafe(key, decrementKey(lowerBound), removedPoint.value));
if (previous != null) {
removed.add(getEntry(previous, first));
}
}
// Get the closing element too.
if (last != null && range.get(last).state == State.OPEN) {
TKey key = bounds.ceilingKey(last);
EndPoint removedPoint = removeIfNonNull(key);
if (last != null && last.state == State.OPEN) {
next = getNextEndPoint(last.key, false);
if (removedPoint != null && preserveDifference) {
resized.add(putUnsafe(incrementKey(upperBound), key, removedPoint.value));
if (next != null) {
removed.add(getEntry(last, next));
}
}
// Now remove both ranges
removeEntrySafely(previous, first);
removeEntrySafely(last, next);
// Add new resized intervals
if (preserveDifference) {
if (previous != null) {
resized.add(putUnsafe(previous.key, decrementKey(lowerBound), previous.value));
}
if (next != null) {
resized.add(putUnsafe(incrementKey(upperBound), next.key, next.value));
}
}
@ -140,7 +186,6 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
invokeEntryRemoved(removed);
if (preserveDifference) {
invokeEntryRemoved(resized);
invokeEntryAdded(resized);
}
@ -149,12 +194,30 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
return removed;
}
// Helper
private EndPoint removeIfNonNull(TKey key) {
if (key != null) {
return bounds.remove(key);
/**
* Retrieve the entry from a given set of end points.
* @param left - leftmost end point.
* @param right - rightmost end point.
* @return The associated entry.
*/
protected Entry getEntry(EndPoint left, EndPoint right) {
if (left == null)
throw new IllegalArgumentException("left endpoint cannot be NULL.");
if (right == null)
throw new IllegalArgumentException("right endpoint cannot be NULL.");
// Make sure the order is correct
if (right.key.compareTo(left.key) < 0) {
return getEntry(right, left);
} else {
return null;
return new Entry(left, right);
}
}
private void removeEntrySafely(EndPoint left, EndPoint right) {
if (left != null && right != null) {
bounds.remove(left.key);
bounds.remove(right.key);
}
}
@ -165,7 +228,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
if (endPoint != null) {
endPoint.state = State.BOTH;
} else {
endPoint = new EndPoint(state, value);
endPoint = new EndPoint(state, key, value);
bounds.put(key, endPoint);
}
return endPoint;
@ -199,8 +262,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
EndPoint left = addEndPoint(lowerBound, value, State.OPEN);
EndPoint right = addEndPoint(upperBound, value, State.CLOSE);
Range<TKey> range = Ranges.closed(lowerBound, upperBound);
return new Entry(range, left, right);
return new Entry(left, right);
} else {
return null;
}
@ -261,11 +323,12 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
switch (entry.getValue().state) {
case BOTH:
EndPoint point = entry.getValue();
destination.add(new Entry(Ranges.singleton(entry.getKey()), point, point));
destination.add(new Entry(point, point));
break;
case CLOSE:
Range<TKey> range = Ranges.closed(last.getKey(), entry.getKey());
destination.add(new Entry(range, last.getValue(), entry.getValue()));
if (last != null) {
destination.add(new Entry(last.getValue(), entry.getValue()));
}
break;
case OPEN:
// We don't know the full range yet
@ -284,7 +347,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
public void putAll(AbstractIntervalTree<TKey, TValue> other) {
// Naively copy every range.
for (Entry entry : other.entrySet()) {
put(entry.key.lowerEndpoint(), entry.key.upperEndpoint(), entry.getValue());
put(entry.left.key, entry.right.key, entry.getValue());
}
}
@ -303,7 +366,7 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
}
/**
* Get the end-point composite associated with this key.
* Get the left-most end-point associated with this key.
* @param key - key to search for.
* @return The end point found, or NULL.
*/
@ -311,22 +374,60 @@ public abstract class AbstractIntervalTree<TKey extends Comparable<TKey>, TValue
EndPoint ends = bounds.get(key);
if (ends != null) {
// This is a piece of cake
return ends;
} else {
// Always return the end point to the left
if (ends.state == State.CLOSE) {
Map.Entry<TKey, EndPoint> left = bounds.floorEntry(decrementKey(key));
return left != null ? left.getValue() : null;
} else {
return ends;
}
} else {
// We need to determine if the point intersects with a range
TKey left = bounds.floorKey(key);
Map.Entry<TKey, EndPoint> left = bounds.floorEntry(key);
// We only need to check to the left
if (left != null && bounds.get(left).state == State.OPEN) {
return bounds.get(left);
if (left != null && left.getValue().state == State.OPEN) {
return left.getValue();
} else {
return null;
}
}
}
/**
* Get the previous end point of a given key.
* @param point - the point to search with.
* @param inclusive - whether or not to include the current point in the search.
* @return The previous end point of a given given key, or NULL if not found.
*/
protected EndPoint getPreviousEndPoint(TKey point, boolean inclusive) {
if (point != null) {
Map.Entry<TKey, EndPoint> previous = bounds.floorEntry(inclusive ? point : decrementKey(point));
if (previous != null)
return previous.getValue();
}
return null;
}
/**
* Get the next end point of a given key.
* @param point - the point to search with.
* @param inclusive - whether or not to include the current point in the search.
* @return The next end point of a given given key, or NULL if not found.
*/
protected EndPoint getNextEndPoint(TKey point, boolean inclusive) {
if (point != null) {
Map.Entry<TKey, EndPoint> next = bounds.ceilingEntry(inclusive ? point : incrementKey(point));
if (next != null)
return next.getValue();
}
return null;
}
private void invokeEntryAdded(Entry added) {
if (added != null) {
onEntryAdded(added);