ThreadLocal提供了线程安全的数据存储和访问方式,利用不带key的get和set方法,居然能做到线程之间隔离,非常神奇。
比如
ThreadLocal<String> threadLocal = new ThreadLocal<>();
in thread 1
//in thread1 treadLocal.set("value1"); ..... //value的值是value1 String value = threadLocal.get();
in thread 2
//in thread2 treadLocal.set("value2"); ..... //value的值是value2 String value = threadLocal.get();
不论thread1和thread2是不是同时执行,都不会有线程安全问题,我们来测试一下。
开10个线程,每个线程内都对同一个ThreadLocal对象set不同的值,会发现ThreadLocal在每个线程内部get出来的值,只会是自己线程内set进去的值,不会被别的线程影响。
static void testUsage() throws InterruptedException { Utils.println("-------------testUsage-------------------"); ThreadLocal<Long> threadLocal = new ThreadLocal<>(); AtomicBoolean threadSafe = new AtomicBoolean(true); int count = 10; CountDownLatch countDownLatch = new CountDownLatch(count); Random random = new Random(736832); for (int i = 0; i < count; i ++){ new Thread(() -> { try { //生成一个随机数 Long value = System.nanoTime() + random.nextInt(); threadLocal.set(value); Thread.sleep(1000); Long value2 = threadLocal.get(); if (!value.equals(value2)) { //get和set的value不一致,说明被别的线程修改了,但这是不可能出现的 threadSafe.set(false); Utils.println("thread unsafe, this could not be happen!"); } } catch (InterruptedException e) { }finally { countDownLatch.countDown(); } }).start(); } countDownLatch.await(); Utils.println("all thread done, and threadSafe is " + threadSafe.get()); Utils.println("------------------------------------------"); }
输出:
-------------testUsage------------------ all thread done, and threadSafe is true -----------------------------------------
翻开ThreadLocal的源码,会发现ThreadLocal只是一个空壳子,它并不存储具体的value,而是利用当前线程(Thread.currentThread())的threadLocalMap来存储value,key就是这个threadLocal对象本身。
public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value); } ThreadLocalMap getMap(Thread t) { return t.threadLocals; }
Thread的threadLocals字段是ThreadLocalMap类型(你可以简单理解为一个key value的Map),key是ThreadLocal对象,value是我们在外层设置的值
这就相当于:
Thread.currentThread().threadLocals.set(threadLocal1, "value1"); ..... //value的值是value1 String value = Thread.currentThread().threadLocals.get(threadLocal1);
因为每个Thread都是不同的对象,所以他们的threadLocals也是不同的map,threadLocal在不同的线程里工作时,实际上是从不同的map里get/set,这也就是线程安全的原因了,了解到这一点就差不多了。
如果继续翻ThreadLocalMap的源码,会发现它有个字段table,是Entry类型的数组。
我们不妨写段代码,把ThreadLocalMap的结构输出出来。
由于Thread.threadLocals和ThreadLocalMap类不是public的,我们只有通过反射来获取它的值。反射的代码如下(如果嫌长可以不看,直接看输出):
static Object getThreadLocalMap(Thread thread) throws NoSuchFieldException, IllegalAccessException { //get thread.threadLocals Field threadLocals = Thread.class.getDeclaredField("threadLocals"); threadLocals.setAccessible(true); return threadLocals.get(thread); } static void printThreadLocalMap(Object threadLocalMap) throws NoSuchFieldException, IllegalAccessException { String threadName = Thread.currentThread().getName(); if(threadLocalMap == null){ Utils.println("threadMap is null, threadName:" + threadName); return; } Utils.println(threadName); //get threadLocalMap.table Field tableField = threadLocalMap.getClass().getDeclaredField("table"); tableField.setAccessible(true); Object[] table = (Object[])tableField.get(threadLocalMap); Utils.println("----threadLocals (ThreadLocalMap), table.length = " + table.length); for (int i = 0; i < table.length; i ++){ WeakReference<ThreadLocal<?>> entry = (WeakReference<ThreadLocal<?>>)table[i]; printEntry(entry, i); } } static void printEntry(WeakReference<ThreadLocal<?>> entry, int i) throws NoSuchFieldException, IllegalAccessException { if(entry == null){ Utils.println("--------table[" + i + "] -> null"); return; } ThreadLocal key = entry.get(); //get entry.value Field valueField = entry.getClass().getDeclaredField("value"); valueField.setAccessible(true); Object value = valueField.get(entry); Utils.println("--------table[" + i + "] -> entry key = " + key + ", value = " + value); }
测试代码:
static void testStructure() throws InterruptedException { Utils.println("-------------testStructure----------------"); ThreadLocal<String> threadLocal1 = new ThreadLocal<>(); ThreadLocal<String> threadLocal2 = new ThreadLocal<>(); Thread thread1 = new Thread(() -> { threadLocal1.set("threadLocal1-value"); threadLocal2.set("threadLocal2-value"); try { Object threadLocalMap = getThreadLocalMap(Thread.currentThread()); printThreadLocalMap(threadLocalMap); } catch (NoSuchFieldException | IllegalAccessException e) { e.printStackTrace(); } }, "thread1"); thread1.start(); //wait thread1 done thread1.join(); Thread thread2 = new Thread(() -> { threadLocal1.set("threadLocal1-value"); try { Object threadLocalMap = getThreadLocalMap(Thread.currentThread()); printThreadLocalMap(threadLocalMap); } catch (NoSuchFieldException | IllegalAccessException e) { e.printStackTrace(); } }, "thread2"); thread2.start(); thread2.join(); Utils.println("------------------------------------------"); }
我们在创建了两个ThreadLocal的对象threadLocal1和threadLocal2,在线程1里为这两个对象设置值,在线程2里只为threadLocal1设置值。然后分别打印出这两个线程的threadLocalMap。
输出结果为:
-------------testStructure---------------- thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> entry key = java.lang.ThreadLocal@33baa315, value = threadLocal2-value --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> null --------table[9] -> null --------table[10] -> entry key = java.lang.ThreadLocal@4d42db5c, value = threadLocal1-value --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> null thread2 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> null --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> null --------table[9] -> null --------table[10] -> entry key = java.lang.ThreadLocal@4d42db5c, value = threadLocal1-value --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> null ------------------------------------------
从结果上可以看出:
查看Entry的源码,会发现Entry继承自WeakReference:
static class Entry extends WeakReference<ThreadLocal<?>> { /** The value associated with this ThreadLocal. */ Object value; Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } }
构造函数里把key传给了super,也就是说,ThreadLocalMap中对key的引用,是WeakReference的。
Weak reference objects, which do not prevent their referents from being
made finalizable, finalized, and then reclaimed. Weak references are most
often used to implement canonicalizing mappings.
通俗点解释:
当一个对象仅仅被weak reference(弱引用), 而没有任何其他strong reference(强引用)的时候, 不论当前的内存空间是否足够,当GC运行的时候, 这个对象就会被回收。
看不明白没关系,还是写代码测试一下什么是WeakReference吧...
static void testWeakReference(){ Object obj1 = new Object(); Object obj2 = new Object(); WeakReference<Object> obj1WeakRef = new WeakReference<>(obj1); WeakReference<Object> obj2WeakRf = new WeakReference<>(obj2); //obj32StrongRef是强引用 Object obj2StrongRef = obj2; Utils.println("before gc: obj1WeakRef = " + obj1WeakRef.get() + ", obj2WeakRef = " + obj2WeakRf.get() + ", obj2StrongRef = " + obj2StrongRef); //把obj1和obj2设为null obj1 = null; obj2 = null; //强制gc forceGC(); Utils.println("after gc: obj1WeakRef = " + obj1WeakRef.get() + ", obj2WeakRef = " + obj2WeakRf.get() + ", obj2StrongRef = " + obj2StrongRef); }
结果输出:
before gc: obj1WeakRef = java.lang.Object@4554617c, obj2WeakRef = java.lang.Object@74a14482, obj2StrongRef = java.lang.Object@74a14482 after gc: obj1WeakRef = null, obj2WeakRef = java.lang.Object@74a14482, obj2StrongRef = java.lang.Object@74a14482
从结果上可以看出:
那么,ThreadLocalMap中对key的引用,为什么是WeakReference的呢?
大部分情况下,线程不会频繁的创建和销毁,一般都会用线程池。所以线程对象一般不会被清除,线程的threadLocalMap就一直存在。
如果key对ThreadLocal是强引用,那么key永远不会被回收,即使我们程序里再也不用它了。
但是key是弱引用的话,情况就会得到改善:只要没有指向threadLocal的强引用了,这个ThreadLocal对象就会被清理。
我们还是写代码测试一下吧。
/** * 测试ThreadLocal对象什么时候被回收 * @throws InterruptedException */ static void testGC() throws InterruptedException { Utils.println("-----------------testGC-------------------"); Thread thread1 = new Thread(() -> { ThreadLocal<String> threadLocal1 = new ThreadLocal<>(); ThreadLocal<String> threadLocal2 = new ThreadLocal<>(); threadLocal1.set("threadLocal1-value"); threadLocal2.set("threadLocal2-value"); try { Object threadLocalMap = getThreadLocalMap(Thread.currentThread()); Utils.println("print threadLocalMap before gc"); printThreadLocalMap(threadLocalMap); //set threadLocal1 unreachable threadLocal1 = null; forceGC(); Utils.println("print threadLocalMap after gc"); printThreadLocalMap(threadLocalMap); } catch (NoSuchFieldException | IllegalAccessException e) { e.printStackTrace(); } }, "thread1"); thread1.start(); thread1.join(); Utils.println("------------------------------------------"); }
我们在一个线程里为两个ThreadLocal对象赋值,最后把其中一个对象的强引用移除,gc后打印当前线程的threadLocalMap。
输出结果如下:
-----------------testGC------------------- print threadLocalMap before gc thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> entry key = java.lang.ThreadLocal@7bf9cebf, value = threadLocal2-value --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> null --------table[9] -> null --------table[10] -> entry key = java.lang.ThreadLocal@56342d38, value = threadLocal1-value --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> null print threadLocalMap after gc thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> entry key = java.lang.ThreadLocal@7bf9cebf, value = threadLocal2-value --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> null --------table[9] -> null --------table[10] -> entry key = null, value = threadLocal1-value --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> null ------------------------------------------
从输出结果可以看到,当我们把threadLocal1的强引用移除并gc之后,table[10]的key变成了null,说明threadLocal1这个对象被回收了;threadLocal2的强引用还在,所以table[1]的key不是null,没有被回收。
但是你发现没有,table[10]的key虽然是null了,但value还活着! table[10]这个entry对象,也活着!
是的,因为只有key是WeakReference....
通过查看ThreadLocal的源码,发现在ThreadLocal对象的get/set/remove方法执行时,都有机会清除掉map中已经无用的entry。
最容易验证清除无用entry的场景分别是:
set: 当一个新的threadLocal对象(没有set过value)发生set调用时,会在map中加入新的entry,此时有机会清除掉无用的entry,清除的逻辑是:
还有其他场景,但不好验证,这里就不提了。
ThreadLocal源码就不贴了,贴了也讲不明白,相关逻辑在setInitialValue、cleanSomeSlots、expungeStaleEntries、rehash、resize等方法里。
在我们写代码验证entry回收逻辑之前,还需要简单的提一下ThreadLocalMap的hash算法。
每个ThreadLocal对象,都有一个threadLocalHashCode变量,在加入ThreadLocalMap的时候,根据这个threadLocalHashCode的值,对entry数组的长度取余(hash & (len - 1)),余数作为下标。
那么threadLocalHashCode是怎么计算的呢?看源码:
public class ThreadLocal<T>{ private final int threadLocalHashCode = nextHashCode(); private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647; private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); } ... }
ThreadLocal类维护了一个全局静态字段nextHashCode,每new一个ThreadLocal对象,nextHashCode都会递增0x61c88647,作为下一个ThreadLocal对象的threadLocalHashCode。
这个0x61c88647,是个神奇的数字,只要以它为递增值,那么和2的N次方取余时,在有限的次数内不会发生重复。
比如和16取余,那么在16次递增内,不会发生重复。还是写代码验证一下吧。
int hashCode = 0; int HASH_INCREMENT = 0x61c88647; int length = 16; for(int i = 0; i < length ; i ++){ int h = hashCode & (length - 1); hashCode += HASH_INCREMENT; System.out.println("h = " + h + ", i = " + i); }
输出结果为:
h = 0, i = 0 h = 7, i = 1 h = 14, i = 2 h = 5, i = 3 h = 12, i = 4 h = 3, i = 5 h = 10, i = 6 h = 1, i = 7 h = 8, i = 8 h = 15, i = 9 h = 6, i = 10 h = 13, i = 11 h = 4, i = 12 h = 11, i = 13 h = 2, i = 14 h = 9, i = 15
你看,h的值在16次递增内,没有发生重复。 但是要记住,2的N次方作为长度才会有这个效果,这也解释了为什么ThreadLocalMap的entry数组初始长度是16,每次都是2倍的扩容。
为了验证出结果,我们需要先给ThreadLocal的nextHashCode重置一个初始值,这样在测试的时候,每个threadLocal的数组下标才会按照我们设计的思路走。
static void resetNextHashCode() throws NoSuchFieldException, IllegalAccessException { Field nextHashCodeField = ThreadLocal.class.getDeclaredField("nextHashCode"); nextHashCodeField.setAccessible(true); nextHashCodeField.set(null, new AtomicInteger(1253254570)); }
然后在测试代码里,我们先调用resetNextHashCode方法,然后加两个ThreadLocal对象并set值,gc前把强引用去除,gc后再new两个新的theadLocal对象,分别调用他们的get和set方法。
在每个关键点打印出threadLocalMap做比较。
static void testExpungeSomeEntriesWhenGetOrSet() throws InterruptedException { Utils.println("----------testExpungeStaleEntries----------"); Thread thread1 = new Thread(() -> { try { resetNextHashCode(); //注意,这里必须有两个ThreadLocal,才能验证出threadLocal1被清理 ThreadLocal<String> threadLocal1 = new ThreadLocal<>(); ThreadLocal<String> threadLocal2 = new ThreadLocal<>(); threadLocal1.set("threadLocal1-value"); threadLocal2.set("threadLocal2-value"); Object threadLocalMap = getThreadLocalMap(Thread.currentThread()); //set threadLocal1 unreachable threadLocal1 = null; threadLocal2 = null; forceGC(); Utils.println("print threadLocalMap after gc"); printThreadLocalMap(threadLocalMap); ThreadLocal<String> newThreadLocal1 = new ThreadLocal<>(); newThreadLocal1.get(); Utils.println("print threadLocalMap after call a new newThreadLocal1.get"); printThreadLocalMap(threadLocalMap); ThreadLocal<String> newThreadLocal2 = new ThreadLocal<>(); newThreadLocal2.set("newThreadLocal2-value"); Utils.println("print threadLocalMap after call a new newThreadLocal2.set"); printThreadLocalMap(threadLocalMap); } catch (NoSuchFieldException | IllegalAccessException e) { e.printStackTrace(); } }, "thread1"); thread1.start(); thread1.join(); Utils.println("------------------------------------------"); }
程序输出结果为:
----------testExpungeStaleEntries---------- print threadLocalMap after gc thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> entry key = null, value = threadLocal2-value --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> null --------table[9] -> null --------table[10] -> entry key = null, value = threadLocal1-value --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> null print threadLocalMap after call a new newThreadLocal1.get thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> entry key = null, value = threadLocal2-value --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> entry key = java.lang.ThreadLocal@2b63dc81, value = null --------table[9] -> null --------table[10] -> null --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> null print threadLocalMap after call a new newThreadLocal2.set thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> null --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> null --------table[7] -> null --------table[8] -> entry key = java.lang.ThreadLocal@2b63dc81, value = null --------table[9] -> null --------table[10] -> null --------table[11] -> null --------table[12] -> null --------table[13] -> null --------table[14] -> null --------table[15] -> entry key = java.lang.ThreadLocal@2e93c547, value = newThreadLocal2-value ------------------------------------------
从结果上来看,
static void testExpungeAllEntries() throws InterruptedException { Utils.println("----------testExpungeStaleEntries----------"); Thread thread1 = new Thread(() -> { try { resetNextHashCode(); int threshold = 16 * 2 / 3; ThreadLocal[] threadLocals = new ThreadLocal[threshold - 1]; for(int i = 0; i < threshold - 1; i ++){ threadLocals[i] = new ThreadLocal<String>(); threadLocals[i].set("threadLocal" + i + "-value"); } Object threadLocalMap = getThreadLocalMap(Thread.currentThread()); threadLocals[1] = null; threadLocals[8] = null; //threadLocals[6] = null; //threadLocals[4] = null; //threadLocals[2] = null; forceGC(); Utils.println("print threadLocalMap after gc"); printThreadLocalMap(threadLocalMap); ThreadLocal<String> newThreadLocal1 = new ThreadLocal<>(); newThreadLocal1.set("newThreadLocal1-value"); Utils.println("print threadLocalMap after call a new newThreadLocal1.get"); printThreadLocalMap(threadLocalMap); } catch (NoSuchFieldException | IllegalAccessException e) { e.printStackTrace(); } }, "thread1"); thread1.start(); thread1.join(); Utils.println("------------------------------------------"); }
我们先创建了9个threadLocal对象并设置了值,然后去掉了其中2个的强引用(注意这2个可不是随意挑选的)。
gc后再添加一个新的threadLocal,最后打印出最新的map。输出为:
----------testExpungeStaleEntries---------- print threadLocalMap after gc thread1 ----threadLocals (ThreadLocalMap), table.length = 16 --------table[0] -> null --------table[1] -> entry key = null, value = threadLocal1-value --------table[2] -> entry key = null, value = threadLocal8-value --------table[3] -> null --------table[4] -> entry key = java.lang.ThreadLocal@60523912, value = threadLocal6-value --------table[5] -> null --------table[6] -> entry key = java.lang.ThreadLocal@48fccd7a, value = threadLocal4-value --------table[7] -> null --------table[8] -> entry key = java.lang.ThreadLocal@188bbe72, value = threadLocal2-value --------table[9] -> null --------table[10] -> entry key = java.lang.ThreadLocal@19e0ebe8, value = threadLocal0-value --------table[11] -> entry key = java.lang.ThreadLocal@688bcb6f, value = threadLocal7-value --------table[12] -> null --------table[13] -> entry key = java.lang.ThreadLocal@46324c19, value = threadLocal5-value --------table[14] -> null --------table[15] -> entry key = java.lang.ThreadLocal@38f1283, value = threadLocal3-value print threadLocalMap after call a new newThreadLocal1.get thread1 ----threadLocals (ThreadLocalMap), table.length = 32 --------table[0] -> null --------table[1] -> null --------table[2] -> null --------table[3] -> null --------table[4] -> null --------table[5] -> null --------table[6] -> entry key = java.lang.ThreadLocal@48fccd7a, value = threadLocal4-value --------table[7] -> null --------table[8] -> null --------table[9] -> entry key = java.lang.ThreadLocal@1dae16b1, value = newThreadLocal1-value --------table[10] -> entry key = java.lang.ThreadLocal@19e0ebe8, value = threadLocal0-value --------table[11] -> null --------table[12] -> null --------table[13] -> entry key = java.lang.ThreadLocal@46324c19, value = threadLocal5-value --------table[14] -> null --------table[15] -> null --------table[16] -> null --------table[17] -> null --------table[18] -> null --------table[19] -> null --------table[20] -> entry key = java.lang.ThreadLocal@60523912, value = threadLocal6-value --------table[21] -> null --------table[22] -> null --------table[23] -> null --------table[24] -> entry key = java.lang.ThreadLocal@188bbe72, value = threadLocal2-value --------table[25] -> null --------table[26] -> null --------table[27] -> entry key = java.lang.ThreadLocal@688bcb6f, value = threadLocal7-value --------table[28] -> null --------table[29] -> null --------table[30] -> null --------table[31] -> entry key = java.lang.ThreadLocal@38f1283, value = threadLocal3-value ------------------------------------------
从结果上看:
如果在gc前,我们把threadLocals[1、8、6、4、2]都去掉强引用,加入新threadLocal后会发现1、8、6、4、2被清除了,但没有扩容,因为此时size是5,小于10-10/4。这个逻辑就不贴测试结果了,你可以取消注释上面代码中相关的逻辑试试。
回到现实中。
我们用ThreadLocal的目的,无非是在跨方法调用时更方便的线程安全地存储和使用变量。这就意味着ThreadLocal的生命周期很长,甚至和app是一起存活的,强引用一直在。
既然强引用一直存在,那么弱引用就形同虚设了。
所以在确定不再需要ThreadLocal中的值的情况下, 还是老老实实的调用remove方法吧!
https://github.com/kongxiangxin/pine/tree/master/threadlocal