转载

ThreadLocal的原理及源码解读

本文来自我的个人网站 欢迎大家来看

需要提前搞懂的一些名词

开放寻址: 可以参考这个地址

关于强引用、弱引用、软引用、虚引用的含义, 请参考这个地址

什么是ThreadLocal

摘自 百度百科

JDK 1.2的版本中就提供java.lang.ThreadLocal,ThreadLocal为解决多线程程序的并发问题提供了一种新的思路。使用这个工具类可以很简洁地编写出优美的多线程程序,ThreadLocal并不是一个 Thread ,而是 Thread局部变量

为什么要使用ThreadLocal

多线程环境下有时候需要调用一些公共资源,例如我们想使用 SimpleDateFormat 这个来格式化我们的 Date 日期,通常的思路是我们在工具类定义一个格式化时间的静态方法:

public class Util{
    static SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
    public static String format(Date date){
        return sdf.format(date); 
    }
}

这样使用时不安全的。当我们使用 Util.format() 方法时会出现线程安全问题,因为 SimpleDateFormat 在多线程环境下是不安全的。

那么如何解决呢?下面有两种能用但是不推荐的解决思路

  1. 每次使用新的对象,将方法修改为:

    public class Util{ 
        public static String format(Date date){
            //每次使用新的对象
            return new SimpleDateFormat("yyyy-MM-dd").format(date); 
        }
    }

    缺点:每次都是重新 new 一个对象出来,浪费内存,失去了抽离公共方法的意义。

  2. 使用锁,将格式化的方法使用锁包起来:

    public class Util{
        static SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
        public static String format(Date date){
            //将线程不安全的代码锁起来
            synchronized (sdf) {
                return sdf.format(date);
            }  
        }
    }

    缺点:降低了并发性,多线程环境下只能串行使用该方法

如何使用ThreadLocal

由于上面两种思路弊端太多于是我们使用 ThreadLocal 来完善我们的格式化日期的方法:

public class Util{
    static final ThreadLocal<SimpleDateFormat> dateFormatThreadLocal = new ThreadLocal(){
        /** 设置线程本地变量的初始化方法 **/
        @Override
        protected Object initialValue() {
            return new SimpleDateFormat("yyyy-MM-dd");
        }
    };
    public static String format(Date date){
        //获取当前线程的 dateFormatThreadLocal 对象。
        SimpleDateFormat simpleDateFormat = dateFormatThreadLocal.get();
        //格式化日期
        return simpleDateFormat.format(date); 
    }
     
    public static void main(String[] args) {
        //使用工具的格式化日期的方法
        String dateStr = Util.format(new Date());
        System.out.println(dateStr);//输出2020-04-10
    }
    
}

ThreadLocal原理

整体思路

开头我们就说到 ThreadLocal 并不是一个 Thread ,而是 Thread局部变量 ,其中的重点是 ThreadLocal线程Thread 的局部变量,也就是说每一个线程中都有 ThreadLocal 这个局部变量。

我们可以看下 Thread 的源码

public class Thread implements Runnable {
    //...省略其他代码
    
    //Thread中维护了一个ThreadLocal.ThreadLocalMap变量
    ThreadLocal.ThreadLocalMap threadLocals = null;
    
    //...省略其他代码
    
}

看到这里也许有个疑问,不是说好的线程中维护的的局部变量是 ThreadLocal 嘛?怎么是一个 ThreadLocal.ThreadLocalMap 呢?,这个 ThreadLocal.ThreadLocalMap 其实是 ThreadLocal 类中的一个静态内部类,

并且 ThreadLocal.ThreadLocalMap 中还有一个静态内部类 Entry 和一个 Entry数组 , Entry数组 是真正存放线程本地变量的地方。

public class ThreadLocal<T> { 
    //静态内部类
    static class ThreadLocalMap { 
        //静态内部类并且继承弱引用,内存不够可以直接回收防止一个线程内的ThreadLocal对象过多。
        static class Entry extends WeakReference<ThreadLocal<?>> {
            //value用于线程的本地变量
            Object value; 
            //Entry的构造方法 K:ThreadLocal对象  v:存储在线程的本地变量
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        } 
        //Entry[] 存放多个该线程的本地变量
        private Entry[] table;
         
    } 
}

看起来有点乱,我来画一张图帮助理解,以上面的我们使用 ThreadLocal 来包装的 SimpleDateFormat 的案例来讲如下图所示:

ThreadLocal的原理及源码解读

我们要理清的思路就是,线程中的 ThreadLocalMap 对象中维护了一个 Entry数组Entry 对象有两个属性比较重要,一个是 referent 属性 这个属性是继承父类的用于存放 ThreadLocal 的对象弱引用,一个是 value 属性 这个属性用于存放线程的本地变量。我们使用线程本地变量时需要通过传入 ThreadLocal 对象来找到 value 的值的。

源码讲解

当我们使用 dateFormatThreadLocal.get() 方法获取当前线程的本地变量时,内部都做了那些呢?这个get()方法是怎么做到不管哪个线程过来都能获取到自己线程内部的本地变量呢?下面是get()的源码:,大家可以结合我写的步骤一步一步走,

public class ThreadLocal<T> { 
    //...
    public T get() {
        //(1) 获取当前调用该方法的线程
        Thread t = Thread.currentThread();
        //(2) 获取当前线程的threadLocals局部变量,也就是 ThreadLocal.ThreadLocalMap 
        ThreadLocalMap map = getMap(t);
        //判断是否为null
        if (map != null) { 
            //(4) 获取Entry数组中referent属性是dateFormatThreadLocal对象的Entry对象。  
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                //Entry对象不为空,则返回Entry对象中的value值
                //也就是我们初始化时定义的SimpleDateFormat对象
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //(25) map为null则说明你是首次调用,并且没有进行过初始化操作,于是进行初始化操作
        return setInitialValue();
    }
    
    ThreadLocalMap getMap(Thread t) {
        //(3) 返回当前线程的 ThreadLocalMap
        return t.threadLocals;
    }
    //静态内部类
    static class ThreadLocalMap {  
        //静态内部类
        static class Entry extends WeakReference<ThreadLocal<?>> { 
        }   
        
        private int threshold; //扩容的阈值默认为2/3
        private void setThreshold(int len) { threshold = len * 2 / 3;}
        //存放线程本地变量的数组,数组长度一定是2的幂
        private Entry[] table; 
        private Entry getEntry(ThreadLocal<?> key) {
            //(5) ThreadLocal自定义的 hash值 和 (Entry的长度-1) 做与运算获得存放的下标位置 
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            /**(8) Entry对象不为null,并且Entry对象的referent引用是我们的dateFormatThreadLocal对象的引用时返回Entry对象。e.get()==key使用的是==比较,比较的是内存地址,因为我们已经知道了,所有线程共同持有dateFormatThreadLocal的弱引用,所以内存地址是相同的 **/
            if (e != null && e.get() == key)
                return e;
            else
                //(9) Entry对象为null或者该位置存放的不是我们dateFormatThreadLocal对象,则需要向后面的位置找,直到遇到下标位置上是null则结束。
                return getEntryAfterMiss(key, i, e);
        }
        //采用线性探测找到对应元素
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length; 
            //(10)Entry对象不为null
            while (e != null) {
                //获取Entry的referent(ThreadLocal对象)
                ThreadLocal<?> k = e.get();
                //(11)判断当前的referent是不是我们dateFormatThreadLocal的引用地址
                if (k == key)
                    return e;
                if (k == null)
                    //(12)Entry对象不为null但是referent==null说明referent这个弱引用已经被内存回收
                    //必须要要执行清理工作,否则会造成内存泄漏
                    expungeStaleEntry(i);
                else
                    //(23)referent既不是我们要找的dateFormatThreadLocal又不等于null,循环下一个下标。
                    //下标+1,如果超过length则为0
                    i = nextIndex(i, len);
                e = tab[i];
            }
            //(24)如果找了一圈都没有找到,则返回null,
            return null;
        }
        
        //清理当前槽位及后面槽位的无效值,并rehash
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // (13)清理value这个强引用。
            tab[staleSlot].value = null;
            // (14)清理Entry对象。
            tab[staleSlot] = null;
            size--;

            Entry e;
            int i;
            //(15)清理后面槽位的ThreadLocal对象,因为他们的referent有可能也已经内存回收了,并且执行rehash操作。  
            for (i = nextIndex(staleSlot, len); 
                     (e = tab[i]) != null; 
                         i = nextIndex(i, len)) {
                //(16)获取下一个槽位的referent引用
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    //(17)referent引用为空则清理,避免内存泄漏
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else { 
                   //(18)referent不为null则重新查看ThreadLocal的下标位置是否是对象申请时定义的位置
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        //(19)两个位置不相等说明将ThreadLocal对象实例化放入到线程的Entry数组时,对应位置已经有另外的ThreadLocal对象了,那么只能使用开放寻址方式向后顺延找一个空的位置存放对象,现在我们在(14)步清理了前面位置的那个另外的ThreadLocal对象了,现在应该将我向前挪,放在我起初应该在的位置。要理解这个下面的步骤就能搞懂了。
                        //(20)当前i位置不要了,因为我要向前挪了
                        tab[i] = null;  
 
                        //(21)判断我原来的位置h的位置如果不为null,说明我起初存放位置时顺延了多次,那么第14步删除的应该是我顺延当中的某一个位置,我要从h位置一个一个往后找直到找到第一个空位置。
                        while (tab[h] != null)
                            h = nextIndex(h, len);//下标+1,如果超过length则为0
                        //(22)rehash成功
                        tab[h] = e;
                    }
                }
            }
            //返回staleSlot下标向后第一个Entry为null的下标
            return i;
        }
        
        
        private void set(ThreadLocal<?> key, Object value) {
 
            Entry[] tab = table;
            int len = tab.length;
            //(29)获取ThreadLocal对象对应的下标
            int i = key.threadLocalHashCode & (len-1);
            //(30)如果e!=null说明当前有对象,要判断是否对当前对象弱引用已经无效,然后进行清理
            for (Entry e = tab[i];
                     e != null;
                         e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                
                //(31)当前的ThreadLocal已经存放过了,则直接覆盖旧值
                if (k == key) {
                    e.value = value;
                    return;
                } 
                //(32)e!=null,但是referent为null,则表明弱引用已经GC回收,直接将value值放在这个已经废弃的槽位上就行了,并把旧值删除
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            //(51)如果初始位置i为null,并且经过第30步循环搜索后面也没有找到,说明该ThreadLoc对象是首次存储,直接放到i位置上
            tab[i] = new Entry(key, value);
            //(52)已存放的ThreadLocal数量+1
            int sz = ++size;
            //(53)如果i下标往后连续 log2(sz)个都是有效数据,并且sz已经达到rehash的阈值则进行rehash
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
        
        //替换已经GC回收的槽位数据
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
             //(33)使用变量slotToExpunge保存清理的起始下标
            int slotToExpunge = staleSlot;
            //(34)为了更多的清理,我们从staleSlot槽位向前搜索如果前面某个槽位的Entry对象不为null并且referent也被GC回收,则将slotToExpunge的清理下标向前挪动一次。
            for (int i = prevIndex(staleSlot, len);
                     (e = tab[i]) != null;
                         i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            //(35)从staleSlot下一个位置开始,连续搜索一段Entry不为null的槽位,这个循环的目的是为了查看之前是否存放过这个ThreadLocal对象,因为存放时有可能进行了位置的顺延,所以要向后搜索看看该对象是不是放在后面的位置了
            for (int i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                         i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                //(36)如果找到了,说明之前存入的时候肯定位置顺延了。
                if (k == key) {
                    //(37) 将Entry对象的旧的value替换
                    e.value = value;
                    //(38)将staleSlot的Entry放到 i 的位置,等着一会被清理。
                    tab[i] = tab[staleSlot];
                    //(39)将替换了新value的Entry对象放到staleSlot位置,因为存放的时候staleSlot位置有数据了所以位置顺延了,现在staleSlot位置没有数据了,那我可以回到staleSlot位置了。
                    tab[staleSlot] = e;

                    // (40)如果第 34 步清理的下标没有向前挪动,说明从staleSlot到i的位置数据都是有效数据,则把清理的下标放到i的位置,从i位置开始清理
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    //(41)清理数据 expungeStaleEntry(slotToExpunge)方法可以看第 13步
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                //(48)如果i下标的referent引用被GC回收,并且清除下标等于staleSlot,那么将清除下标移动到i的位置
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            //(49)如果staleSlot的下一个下标上的Entry对象为null,则第35步骤的循环没有走,则将value填充到当前staleSlot位置上
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // (50)如果staleSlot的前面或者后面有垃圾则进行清理 
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
        
        /**启发式地扫描一些无效Entry
            参数i位置的Entry为null,并且从i开始扫描
            n控制扫描,正常情况是log2(n),除非找到新的无效Entry
            返回下标i后面是否清除过无效数据
        **/
        private boolean cleanSomeSlots(int i, int n) {
            //(42)清除标识
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                //(43)从i的下一个位置开始
                i = nextIndex(i, len);
                Entry e = tab[i];
                //(44)如果Entry不为null,但是referent为null,说明弱引用已经回收
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //(45)移除i及i后面的无效Entry,并对扫描的Entry进行rehash,返回第一个Entry为null的i下标
                    i = expungeStaleEntry(i);
                }
                //(46)无符号右移1位,如果n=16则循环log2(n)=4次后结束
            } while ( (n >>>= 1) != 0);
            //(47)返回清除标识
            return removed;
        }
       
         
    } 
    //(6) ThreadLocal自定义的hash值,改值在对象初始化时通过nextHashCode()得到一个固定值。
    private final int threadLocalHashCode = nextHashCode();  
    private static AtomicInteger nextHashCode = new AtomicInteger(); //初始值为0
    private static final int HASH_INCREMENT = 0x61c88647; //自增步长
    private static int nextHashCode() {
        //(7) 通过AtomicInteger 每次增加 0x61c88647来获取下一次的hash值。它和Entry数组长度做与运算得到存储ThreadLocal对象的下标。这是一个静态方法,那么同一个ThreadLocal对象在不同线程的Entry数组中的下标是一样的。
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
    //初始化方法
    private T setInitialValue() {
        //(26)通过初始化方法获得dateFormatThreadLocal对象
        T value = initialValue();
        Thread t = Thread.currentThread();
        //(27)获得当前线程的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null)
            //(28)设置dateFormatThreadLocal对象到线程的ThreadLocalMap中
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
    
    
 
}

难理解部分解读

Entry数组的长度为啥一定要是2的幂

以下是源码

/**
 * The table, resized as necessary.
 * table.length MUST always be a power of two. 长度必须是2的幂
 */
private Entry[] table;

这就不得不说我们计算下标的方式了,通常我们求下标是用 % 取模方式,为啥这里能用 & 运算来取下标呢?有一下几点原因:

  1. & 要比 % 运算快
  2. 使用 & 运算可以达到和 % 一样的效果,他们之间存在一定的联系。假如我们使用 hash 来求下标,那么使用 % 计算方式为: hash % lentgh ,使用 & 的计算方式为: hash & (length-1) 如果我们想让这两个公式相等,那么 length 必须是 2的幂。这种前提下这两个公式的效果一样,也就是 hash % lentgh == hash & (length-1)
  3. rehash 时数据迁移少,如果 length 是2的幂,在 rehash 以后数据迁移时,移动的数据更少,例如 length = 8 ,假如 A.hash = 4 B.hash = 15 C.hash = 33 那么 A ` B C`存放的下标是 `4%8=0 15%8=7 33%8=1`,如果长度扩大一倍变成 `16`,则 `A B C`的下标变成 `4 %16=0 15%16=15 33%16=1 ,那么我们只需要移动 B 到新的位置就行了,其他的不需要移动。

不光是这里,在 JAVAHashMap 中也有这种体现,你可以看下源码, HashMap 在初始化时的长度也一定是2的幂,以后的扩容也是 length*2 来扩容的

关于ThreadLocal的threadLocalHashCode成员变量

我们需要知道该成员变量的2点内容:

  1. threadLocalHashCode 该成员变量在 ThreadLocal 对象初始化时通过静态方法获得,它的作用是计算 ThreadLocal 对象存放在线程 ThreadLocal.ThreadLocalMap 对象中 Entry数组 的下标。假如计算出的下标为5,那么==所有线程==存放该 ThreadLocal 对象时都会放在下标为5的位置上。
  2. 初始化 threadLocalHashCode 属性的静态方法是通过 AtomicInteger 增加 0x61c88647(魔数--与斐波那契散列有关) 来得到当前 hashCode 的值,在上面的==第(7)步==我们可以看到,然后再通过这个 hashCode 来计算下标。假如同时有 16个 ThreadLocal 对象 放到的 Entry数组 中,他们的下标我在下面都计算了出来

    0x61c88647 & 15    
     1100001110010001000011001000111 & 1111     = 0111 = 7  rehash2倍后下标= 00111 = 7
     0x61c88647*2    & 15   
     11000011100100010000110010001110 &1111     = 1110 = 14 rehash2倍后下标= 01110 = 14
     0x61c88647*3    & 15                  
     100100101010110011001001011010101 & 1111     = 0101 = 5  rehash2倍后下标= 10101 = 21
     0x61c88647*4    & 15                    
     110000111001000100001100100011100 & 1111     = 1100 = 12 rehash2倍后下标= 11100 = 28
     0x61c88647*5    & 15                    
     111101000111010101001111101100011 & 1111     = 0011 = 3  rehash2倍后下标= 00011 = 3
     0x61c88647*6    & 15                    
     1001001010101100110010010110101010 & 1111     = 1010 = 10 rehash2倍后下标= 01010 = 10
     0x61c88647*7    & 15                    
     1010101100011110111010101111110001 & 1111     = 0001 = 1  rehash2倍后下标= 10001 = 17
     0x61c88647*8    & 15                    
     1100001110010001000011001000111000 & 1111     = 1000 = 8  rehash2倍后下标= 11000 = 24
     0x61c88647*9    & 15                    
     1101110000000011001011100001111111 & 1111     = 1111 = 15 rehash2倍后下标= 11111 = 31
     0x61c88647*10 & 15                        
     1111010001110101010011111011000110 & 1111 = 0110 = 6  rehash2倍后下标= 00110 = 6 
      //默认扩容阈值为2/3从这里就开始扩容了
     0x61c88647*11 & 15                     
     10000110011100111011100010100001101 & 1111 = 1101 = 13  rehash2倍后下标= 01101 = 13
     0x61c88647*12 & 15                 
     10010010101011001100100101101010100 & 1111 = 0100 = 4  rehash2倍后下标= 10100 = 20
     0x61c88647*13 & 15             
     10011110111001011101101000110011011 & 1111 = 1011 = 11  rehash2倍后下标= 11011 = 27
     0x61c88647*14 & 15                         
     10101011000111101110101011111100010 & 1111 = 0010 = 2  rehash2倍后下标= 00010 = 2
     0x61c88647*15 & 15                     
     10110111010101111111101111000101001 & 1111= 1001 = 9  rehash2倍后下标= 01001 =  9
     0x61c88647*16 & 15                         
     11000011100100010000110010001110000 & 1111= 0000 = 0 rehash2倍后下标= 10000 = 16

    通过这个魔数累加后和 length-1 做计算所得下标的散列值非常好,基本不会出现冲突,再结合一定的阈值配置,那么就可以快速定位到当前对象的下标了 。但是扩容的阈值控制的再好也会出现下标冲突的时候,如果出现了则依据开放寻址方法寻找一个合适的位置。

关于ThreadLocal对象存储在Entry数组中为位置顺延(开放地址法) 问题

当多个ThreadLocal对象存储到线程的Entry数组时,下标有可能会出现冲突,那么就会依据开放寻址法,找到后面一个空的位置来存放。画个图帮助理解:

ThreadLocal的原理及源码解读

关于弱引用

原文  https://segmentfault.com/a/1190000022347362
正文到此结束
Loading...