我们Threadlocal类的作用是提供一个线程间隔离,线程内部共享的数据。今天我们一起看看TreadLocal是怎么做到线程隔离的。
例子同样可以在 github 中找到
public static void testThreadLocal() { ThreadLocal<Integer> threadLocal = new ThreadLocal<>(); System.out.println(Thread.currentThread().getName() + ".set: " + -1); threadLocal.set(-1); ExecutorService executorService = Executors.newCachedThreadPool(); for (int i=1; i< 5; i++) { final Integer setValue = i; executorService.submit(() -> { System.out.println(Thread.currentThread().getName() + ".set: " + setValue); threadLocal.set(setValue); System.out.println(Thread.currentThread().getName() + ".get: " + threadLocal.get()); threadLocal.remove(); }); } System.out.println(Thread.currentThread().getName() + ".get: " + threadLocal.get()); threadLocal.remove(); } 复制代码
运行结果:
main.set: -1 pool-1-thread-1.set: 1 pool-1-thread-2.set: 2 pool-1-thread-2.get: 2 pool-1-thread-3.set: 3 pool-1-thread-1.get: 1 pool-1-thread-3.get: 3 pool-1-thread-4.set: 4 main.get: -1 pool-1-thread-4.get: 4 复制代码
代码中threadLocal对象看着也是被多线程竞争写入的,多个线程同时对他进行写入,但每个线程get到的都是正确的结果,为什么可以做到线程隔离呢?
我们先大致看看set方法
public void set(T value) { //得到当前线程 Thread t = Thread.currentThread(); //获取线程的ThreadLocalMap属性 ThreadLocalMap map = getMap(t); //map不为空时,set threadlocal 和value if (map != null) map.set(this, value); else createMap(t, value); //为空时创建一个map并将threadlocal 和value放入 } ThreadLocalMap getMap(Thread t) { return t.threadLocals; } 复制代码
原来threadLocal set value的时候,首先获得当前的线程对象,然后得到线程对象的ThreadLocalMap属性,然后将 threadlocal自身作为key , set到map中。图解一下Thread类和ThreadLocal类的关系。
原来Thread对象中有个ThreadLocalMap属性,ThreadLocalMap顾名思义就是存放ThreadLocal的map。所以虽然例子中看着threadLocal是竞争的写入,其实不是,都是在自己的线程对象中维护了一个threadLocal。
get方法也清晰了,就是从Thread对象里拿key为这个threadLocal对象的 value值呗!
public T get() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) { //map中当前threadlocal作为key,拿到value的值,并返回 ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } return setInitialValue(); } 复制代码
看到这ThreadLocal类的原理就说完了。
但Doug Lea 哪是一般人,源码中围绕着减少内存泄漏做的很多努力。下面我们就看看为什么会发生内存泄漏,以及怎么防止内存泄漏。
名词解释:
什么是内存泄漏?
为什么会发生内存泄漏?
怎么解决?
强软弱虚四种引用的定义及使用场景
为什么弱引用可以帮我们解决key上的内存泄漏呢?
static class ThreadLocalMap { //这里的源码可以看到map中Entry类继承了WeakReference类,key弱弱的引用ThreadLocal对象 static class Entry extends WeakReference<ThreadLocal<?>> { /** The value associated with this ThreadLocal. */ Object value; Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } } 。。。 } 复制代码
上面说了防止value对象的内存泄漏,使用过期检查和清理,以及提供remove方法,这里是ThreadLocal最复杂的一部分,我们详细看看吧。再看set方法。
public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value); } 复制代码
调用map的set方法,并不是常用的put方法,看来有不是简单的存值啊
private void set(ThreadLocal<?> key, Object value) { Entry[] tab = table; int len = tab.length; //根据hashcode和Entry数组的长度,计算下标值 int i = key.threadLocalHashCode & (len-1); //根据得到的下标值找,遇到hash冲突就向后移动一个,直到找到entry是空的节点 for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); //遇到key相等,说明key之前存过,替换value值就行了 if (k == key) { e.value = value; return; } //如果k是空,说明这里存的是一个过期数据,进行替换 //这里会进行过期数据的清理 if (k == null) { replaceStaleEntry(key, value, i); return; } } //前面的位置都被占用着,新建一个Entry放在i上 tab[i] = new Entry(key, value); //将map中的size加1 int sz = ++size; //扫描清理一次过期数据,如果还是达到扩容的阈值了,进行扩容 //这里也会进行过期数据的清理 if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash();//先进行一次全量的扫描清理过期数据,还是快接近阈值就扩容 } 复制代码
replaceStaleEntry方法
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) { Entry[] tab = table; int len = tab.length; Entry e; // 向前扫描第一个过期的节点 int slotToExpunge = staleSlot; for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) if (e.get() == null) slotToExpunge = i; //标识第一个需要清除的位置 // 向后遍历 for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); // 向后找到了key,把value进行替换 if (k == key) { e.value = value; //i节点设为过期数据 tab[i] = tab[staleSlot]; //之前的过期节点赋值为key的Entry数据 tab[staleSlot] = e; // 如果staleSlot就是第一个过期数据(上面的for进行了一次向前扫描),把过期下标设为i if (slotToExpunge == staleSlot) slotToExpunge = i; //expungeStaleEntry方法清理过期节点,并进行整理(因为存在hash冲突后移,可能某些节点的hash位置空出来了,放入对应的自己的位置,后面会有图解说明) //cleanSomeSlots会清理Log n次,为了效率不能每次都全量扫描 cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return; } // staleSlot是第一个过期数据,把slotToExpunge标记为i 说明有其他过期节点 if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } // 过期位置赋值为用key value构建的新Entry tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value); // 如果slotToExpunge != staleSlot说明有其他节点也过期了,继续清理一些其他过期节点 //和for循环中slotToExpunge = i 呼应 if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); } 复制代码
replaceStaleEntry方法顾名思义用当前的key value构造一个entry替换这个过期的Entry节点。但因为存在hash冲突后移,并不能单纯的直接替换,所以做了上面的这么多事情
//清理下标为staleSlot的过期节点 private int expungeStaleEntry(int staleSlot) { Entry[] tab = table; int len = tab.length; // 过期节点设为空 tab[staleSlot].value = null; //help gc tab[staleSlot] = null; size--; // 清理的过程中可能之前因为存在hash冲突后移的节点,位置恰好是staleSlot,staleSlot空出来了,节点应该放在正确的位置。 Entry e; int i; //向后扫描 for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); //节点key为空,说明已过期,直接干掉 if (k == null) { e.value = null; tab[i] = null; size--; } else { //计算节点的hash值,确定在数组中的位置 int h = k.threadLocalHashCode & (len - 1); //如果节点不应该放在i位置上,则可能放在h到i中间的位置上 if (h != i) { tab[i] = null; // 从h位置一直后移,找到第一个为空的位置,放在正确的位置上(hash冲突后移的逻辑) while (tab[h] != null) h = nextIndex(h, len); tab[h] = e; } } } return i; } 复制代码
cleanSomeSlots 清理部分过期Entry
//进行log n次扫描 //如果没有发现过期节点返回false(没有节点移动) //如果发现了过期节点,清理过期节点,n重置为table数组的length,再次扫描log n次 private boolean cleanSomeSlots(int i, int n) { boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; } 复制代码
上面是set方法中对防止内存泄漏的一些努力,每次set都会对一些过期节点进行清除整理,这一部分也是较难理解的。我们放一张图,方便大家理解。
我们看看get方法,会发现也对防止内存泄漏做了一些努力
public T get() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) { ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } //当map为空时,创建一个map并存入key:this value:null 返回null return setInitialValue(); } private Entry getEntry(ThreadLocal<?> key) { int i = key.threadLocalHashCode & (table.length - 1); Entry e = table[i]; if (e != null && e.get() == key) return e; else //这里是重点,当hash的位置被其他节点占用了,可能是冲突后移了,可能就是没有 return getEntryAfterMiss(key, i, e); } 复制代码
getEntryAfterMiss 方法
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; //向后找,直到找到entry节点是空时,返回null while (e != null) { ThreadLocal<?> k = e.get(); //k正好是我们要找的数据,返回节点entry if (k == key) return e; //如果k是空,说明是过期节点,清除该过期节点 if (k == null) expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } return null; } 复制代码
remove方法
//手动清理threadLocal private void remove(ThreadLocal<?> key) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { //得到key对应的entry if (e.get() == key) { e.clear(); //将referent赋值为null expungeStaleEntry(i); //清理该节点 return; } } } 复制代码
你可能有疑问,既然set和get方法都会移除过期节点,还要我们remove吗?
强烈建议大家使用完threadlocal后一定要调用remove方法。
填坑记
我们曾经一个项目中使用了threadlocal,业务上是这样的