上一篇文章笔者解读了 HashMap 的源码,正好趁热打铁,今天笔者抽了些时间通过 TDD 实现了一个精简版的 HashMap,经笔者测试,正常情况下效率略微逊于 HashMap。
public class SimpleHashMap<K, V> { public V put(K key, V value); public V get(K key); public V remove(K key); public boolean containsKey(K key); public int size(); public Iterator<V> values(); public void forEach(Consumer<? super K> action); } 复制代码
** * @author lyning */ public class SimpleHashMapTest { private SimpleHashMap<Integer, Integer> map; @BeforeEach public void setUp() throws Exception { // given this.map = new SimpleHashMap<>(); } /************ size test start **********/ @Test @DisplayName("given empty entries" + "when call size() " + "then return 0") public void size1() { // when int size = map.size(); // then assertThat(size).isZero(); } @Test @DisplayName("given multiple entries(contains duplicate key) " + "when call size() " + "then return correct size") public void size2() { // given SimpleHashMap<Integer, Integer> map = new SimpleHashMap<>(); map.put(1, 1); map.put(2, 2); map.put(3, 3); map.put(3, 4); map.put(3, 5); map.put(4, 4); map.put(5, 5); map.remove(1); map.remove(2); // when int size = map.size(); // then assertThat(size).isEqualTo(3); } @Test @DisplayName("given multiple entries(hash conflict) " + "when call size() " + "then return correct size") public void size3() { // given SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>(); map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); map.remove(new HashConflict(5)); map.remove(new HashConflict(3)); // when int size = map.size(); // then assertThat(size).isEqualTo(3); } /************ size test end **********/ /************ put test start **********/ @Test @DisplayName("given empty entries " + "when put one entry " + "then return size 1") public void put1() { // when map.put(1, 1); // then assertThat(map.size()).isOne(); } @Test @DisplayName("given empty entries " + "when put two entries(duplicate key) " + "then return size 1") public void put2() { // when map.put(1, 1); map.put(1, 2); // then assertThat(map.size()).isEqualTo(1); } @Test @DisplayName("given empty entries " + "when put three entries " + "then return size 3") public void put3() { // when map.put(1, 1); map.put(2, 2); map.put(3, 3); // then 
assertThat(map.size()).isEqualTo(3); } @Test @DisplayName("should return value " + "when call put") public void put4() { // when Integer value = map.put(1, 1); // then assertThat(value).isEqualTo(1); } @Test @DisplayName("given empty entries " + "when put multiples entries(hash conflict) " + "then") public void put5() { // given SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>(); // when map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(3), 4); map.put(new HashConflict(3), 5); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); // then assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 5, 4, 5)); } @Test @DisplayName("should auto grow " + "when capacity exceed threshold") public void put6() { // given default threshold = 8 // when for (int i = 1; i <= 20; i++) { map.put(i, i); } // then assertThat(map.size()).isEqualTo(20); assertThat(map.get(20)).isEqualTo(20); } /************ put test end **********/ /************ get test start **********/ @Test @DisplayName("given empty entries" + "when get by null key" + "then return null") public void get1() { // when Integer value = map.get(null); // then assertThat(value).isNull(); } @Test @DisplayName("given empty entries" + "when get value by not exist key" + "then return null") public void get2() { // when Integer value = map.get(2); // then assertThat(value).isNull(); } @Test @DisplayName("given entry" + "when get value by not exist key" + "then return null") public void get3() { // given map.put(1, 1); // when Integer value = map.get(2); // then assertThat(value).isNull(); } @Test @DisplayName("given entry" + "when get value" + "then return value") public void get4() { // given map.put(1, 1); // when Integer value = map.get(1); // then assertThat(value).isEqualTo(1); } @Test @DisplayName("given multiple entries(hash conflict)" + "when get value by hash conflict key" + "then return value") public 
void get5() { // given SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>(); map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(3), 4); map.put(new HashConflict(3), 5); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); // when Integer value = map.get(new HashConflict(3)); // then assertThat(value).isEqualTo(5); } @Test @DisplayName("given multiple entries(hash conflict)" + "when get value by not exist hash conflict key" + "then return null") public void get6() { // given SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>(); map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); // when Integer value = map.get(new HashConflict(6)); // then assertThat(value).isNull(); } /************ get test end **********/ /************ remove test start **********/ @Test @DisplayName("given empty entries" + "when remove by null key" + "then return null") public void remove1() { // when Integer value = map.remove(null); // then assertThat(value).isNull(); } @Test @DisplayName("given entry" + "when remove by null key" + "then return null") public void remove2() { // given map.put(1, 1); // when Integer value = map.remove(null); // then assertThat(value).isNull(); } @Test @DisplayName("given entry" + "when remove by key" + "then return value") public void remove3() { // given map.put(1, 1); // when int value = map.remove(1); // then assertThat(value).isEqualTo(1); } @Test @DisplayName("given entry" + "when remove by not exist key" + "then return null") public void remove4() { // given map.put(1, 1); // when Integer value = map.remove(2); // then assertThat(value).isNull(); } @Test @DisplayName("given multiple entries(hash conflict)" + "when remove by hash conflict key" + "then return value") public void remove5() { // given SimpleHashMap<HashConflict, Integer> map 
= new SimpleHashMap<>(); map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); // when Integer value = map.remove(new HashConflict(3)); // then assertThat(value).isEqualTo(3); assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 4, 5)); } /************ remove test end **********/ /************ values test start **********/ @Test @DisplayName("given empty entries" + "when call values" + "then return empty values") public void values1() { // when Iterable<Integer> values = map.values(); // then assertThat(values).isEmpty(); } @Test @DisplayName("given multiple entries" + "when call values" + "then return all values") public void values2() { // given map.put(1, 1); map.put(2, 2); map.put(3, 3); map.put(3, 4); map.put(4, 4); map.remove(4); // when Iterable<Integer> values = map.values(); // then assertThat(values.spliterator().estimateSize()).isEqualTo(3); assertThat(Lists.newArrayList(values)).isEqualTo(Lists.list(1, 2, 4)); } /************ values test end **********/ /************ containsKey test start **********/ @Test @DisplayName("given entry" + "when key exist" + "then return true") public void contains_key1() { // given map.put(1, 1); // when boolean result = map.containsKey(1); // then assertThat(result).isTrue(); } @Test @DisplayName("given entry" + "when key not exist" + "then return false") public void containsKey2() { // given map.put(1, 1); // when boolean result = map.containsKey(2); // then assertThat(result).isFalse(); } @Test @DisplayName("given multiple entries(hash conflict)" + "when call containsKey" + "then return correct result") public void containsKey3() { // given SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>(); map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); // then 
assertThat(map.containsKey(new HashConflict(3))).isTrue(); assertThat(map.containsKey(new HashConflict(5))).isTrue(); assertThat(map.containsKey(new HashConflict(6))).isFalse(); } /************ containsKey test end **********/ /************ forEach test start **********/ @Test @DisplayName("given multiple entries" + "when call forEach" + "then pass") public void forEach1() { // given map.put(1, 1); map.put(2, 2); map.put(3, 3); map.put(4, 4); // when List<Integer> results = new ArrayList<>(); map.forEach((key) -> results.add(map.get(key))); // then assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4)); } @Test @DisplayName("given multiple entries(hash conflict)" + "when call forEach" + "then pass") public void forEach2() { // given SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>(); map.put(new HashConflict(1), 1); map.put(new HashConflict(2), 2); map.put(new HashConflict(3), 3); map.put(new HashConflict(4), 4); map.put(new HashConflict(5), 5); // when List<Integer> results = new ArrayList<>(); map.forEach((key) -> results.add(map.get(key))); // then assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4, 5)); } /************ forEach test end **********/ class HashConflict { private int field; HashConflict(int field) { this.field = field; } @Override public int hashCode() { return this.field <= 8 ? 1 : this.field; } @Override public boolean equals(Object obj) { return ((HashConflict) obj).field == this.field; } } } 复制代码
/**
 * A minimal hash map that resolves collisions with singly linked chains.
 *
 * <p>The table is allocated lazily by the first {@link #put(Object, Object)} with 16 slots and
 * is doubled whenever the next insertion would reach the threshold (capacity * 0.75). A
 * {@code null} key is accepted and always hashes to 0. The class performs no synchronization.
 *
 * @param <K> key type
 * @param <V> value type
 * @author lyning
 */
public class SimpleHashMap<K, V> {

    private static final int DEFAULT_INITIAL_CAPACITY = 16;
    private static final float DEFAULT_LOAD_FACTOR = 0.75f;
    /** Largest power-of-two table length that can still be expressed as a positive int. */
    private static final int MAXIMUM_CAPACITY = 1 << 30;

    private int size;
    private Bucket<K, V>[] table;
    private int threshold;

    /**
     * Tells whether an entry exists for {@code key}.
     *
     * @param key key to look for, may be {@code null}
     * @return {@code true} if an entry with an equal key is stored; {@code false} otherwise,
     *     including when nothing has been put yet
     */
    public boolean containsKey(K key) {
        // The table does not exist before the first put; nothing can be contained then.
        if (this.tableEmpty()) {
            return false;
        }
        Bucket<K, V> head = this.table[this.index(this.hash(key))];
        return head != null && head.lookup(key) != null;
    }

    /**
     * Passes every stored key to {@code action}, slot by slot and in chain order within a slot.
     *
     * @param action callback invoked once per key
     */
    public void forEach(Consumer<? super K> action) {
        if (this.tableEmpty()) {
            return;
        }
        for (Bucket<K, V> head : this.table) {
            for (Bucket<K, V> node = head; node != null; node = node.next) {
                action.accept(node.key);
            }
        }
    }

    /**
     * Returns the value stored under {@code key}.
     *
     * @param key key to look up, may be {@code null}
     * @return the stored value, or {@code null} when there is no entry for the key
     */
    public V get(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        return this.getVal(this.index(this.hash(key)), key);
    }

    /**
     * Stores {@code value} under {@code key}, replacing the value of an existing equal key.
     *
     * @param key key of the entry, may be {@code null}
     * @param value value to store
     * @return the value that was just stored (not the previous one)
     */
    public V put(K key, V value) {
        if (this.tableEmpty() || this.nearByThreshold()) {
            this.resize();
        }
        return this.putVal(key, value, this.hash(key));
    }

    /**
     * Removes the entry stored under {@code key}.
     *
     * @param key key of the entry to remove, may be {@code null}
     * @return the removed value, or {@code null} when there was no entry for the key
     */
    public V remove(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        return this.removeVal(this.index(this.hash(key)), key);
    }

    /**
     * Returns the number of entries currently stored.
     *
     * @return entry count
     */
    public int size() {
        return this.size;
    }

    /**
     * Collects all stored values into a new list, slot by slot and in chain order within a slot.
     *
     * @return a fresh list of the values; empty when the map holds nothing
     */
    public Iterable<V> values() {
        List<V> collected = new ArrayList<>();
        if (!this.tableEmpty()) {
            this.collectValues(collected);
        }
        return collected;
    }

    /** Appends the value of every node in the table to {@code collections}. */
    private void collectValues(List<V> collections) {
        for (Bucket<K, V> head : this.table) {
            for (Bucket<K, V> node = head; node != null; node = node.next) {
                collections.add(node.value);
            }
        }
    }

    /** Returns the chain head stored in slot {@code index}, or {@code null} for an empty slot. */
    private Bucket<K, V> findBucket(int index) {
        return this.table[index];
    }

    /** Looks {@code key} up in the chain of slot {@code index}. */
    private V getVal(int index, K key) {
        Bucket<K, V> head = this.findBucket(index);
        if (head == null) {
            return null;
        }
        Bucket<K, V> match = head.lookup(key);
        return match == null ? null : match.value;
    }

    /** Spreads the high bits of the key's hash code into the low bits; {@code null} maps to 0. */
    private int hash(K key) {
        if (key == null) {
            return 0;
        }
        int hashcode = key.hashCode();
        return hashcode ^ (hashcode >>> 16);
    }

    /** Maps a hash to a slot of the current table. */
    private int index(int hash) {
        return indexFor(hash, this.table.length);
    }

    /** Maps a hash to a slot of a table with {@code capacity} slots (a power of two). */
    private static int indexFor(int hash, int capacity) {
        return hash & (capacity - 1);
    }

    /** Allocates an empty table with {@code capacity} slots. */
    @SuppressWarnings("unchecked") // generic array creation; only Bucket<K, V> is ever stored
    private Bucket<K, V>[] newTable(int capacity) {
        return (Bucket<K, V>[]) new Bucket<?, ?>[capacity];
    }

    /** True when one more entry would reach the resize threshold. */
    private boolean nearByThreshold() {
        return this.size + 1 >= this.threshold;
    }

    /** Inserts or overwrites the entry for {@code key}; the table must already exist. */
    private V putVal(K key, V value, int hash) {
        int index = this.index(hash);
        Bucket<K, V> head = this.table[index];
        if (head == null) {
            this.table[index] = new Bucket<>(hash, key, value);
        } else {
            Bucket<K, V> existing = head.lookup(key);
            if (existing != null) {
                // Overwrite in place: the entry count does not change.
                existing.value = value;
                return value;
            }
            head.putLast(new Bucket<>(hash, key, value));
        }
        this.size += 1;
        return value;
    }

    /**
     * Builds a table with {@code newCap} slots and re-links every node of the current table into
     * the slot its hash selects under the new capacity. Nodes that stay together keep their
     * relative order.
     */
    private Bucket<K, V>[] rebuildTable(int newCap) {
        Bucket<K, V>[] rebuilt = this.newTable(newCap);
        // Tail of each new chain, so appending does not have to walk the chain.
        Bucket<K, V>[] tails = this.newTable(newCap);
        for (Bucket<K, V> head : this.table) {
            Bucket<K, V> node = head;
            while (node != null) {
                Bucket<K, V> following = node.next;
                node.next = null; // detach: the old neighbours may land in another slot
                int index = indexFor(node.hash, newCap);
                if (tails[index] == null) {
                    rebuilt[index] = node;
                } else {
                    tails[index].next = node;
                }
                tails[index] = node;
                node = following;
            }
        }
        return rebuilt;
    }

    /** Unlinks the node for {@code key} from the chain of slot {@code index}. */
    private V removeVal(int index, K key) {
        Bucket<K, V> prev = null;
        Bucket<K, V> node = this.findBucket(index);
        while (node != null) {
            if (node.matchKey(key)) {
                if (prev == null) {
                    // Removing the head: its successor becomes the new head so the rest of the
                    // chain stays reachable.
                    this.table[index] = node.next;
                } else {
                    prev.next = node.next;
                }
                this.size -= 1;
                return node.value;
            }
            prev = node;
            node = node.next;
        }
        return null;
    }

    /** Creates the table on first use, otherwise doubles its capacity and threshold. */
    private void resize() {
        int oldCap = this.tableCapacity();
        if (oldCap == 0) {
            this.threshold = (int) (DEFAULT_INITIAL_CAPACITY * DEFAULT_LOAD_FACTOR);
            this.table = this.newTable(DEFAULT_INITIAL_CAPACITY);
            return;
        }
        if (oldCap >= MAXIMUM_CAPACITY) {
            // Doubling would overflow int; keep the table and stop asking for more growth.
            this.threshold = Integer.MAX_VALUE;
            return;
        }
        this.threshold = this.threshold << 1;
        this.table = this.rebuildTable(oldCap << 1);
    }

    /** Current number of slots, or 0 while the table has not been created. */
    private int tableCapacity() {
        return this.tableEmpty() ? 0 : this.table.length;
    }

    /** True until the first put has created the table. */
    private boolean tableEmpty() {
        return this.table == null;
    }

    /** Node of a collision chain; also used as the chain head stored in the table. */
    static class Bucket<K, V> {

        Bucket<K, V> next;
        int hash;
        K key;
        V value;

        public Bucket(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }

        /** Returns the first node from this one onwards whose key equals {@code key}, else null. */
        public Bucket<K, V> lookup(K key) {
            for (Bucket<K, V> node = this; node != null; node = node.next) {
                if (node.matchKey(key)) {
                    return node;
                }
            }
            return null;
        }

        /** Null-safe key comparison (either side may be {@code null}). */
        public boolean matchKey(K key) {
            return Objects.equals(this.key, key);
        }

        /** Appends {@code bucket} after the last node reachable from this one. */
        public void putLast(Bucket<K, V> bucket) {
            this.last().next = bucket;
        }

        /** Walks to the node that has no successor. */
        private Bucket<K, V> last() {
            Bucket<K, V> node = this;
            while (node.next != null) {
                node = node.next;
            }
            return node;
        }
    }
}
其中最难的应属红黑树,真的是极其复杂,笔者用了一个小时还没能理解其中要领,索性使用链表替代了,等有时间再静下心来把未完成的任务消灭掉。
理解问题,Tasking,TDD(包含重构),这是笔者最近一直在遵守的规则,希望可以给您带来一点感悟。