我们平时开发中,应该遇到过这样的需求:一个功能需要几个线程一起合作完成,然后要等这些线程都处理完成了,才能继续后续的操作。这时我们就可以选择使用CountDownLatch这个并发工具包。
package com.demo; import java.text.SimpleDateFormat; import java.util.Date; import java.util.concurrent.CountDownLatch; public class CountDownLatchDemo { public static void main(String[] args) throws InterruptedException { // 参数5表示我们需要等待的线程数 CountDownLatch cdl = new CountDownLatch(5); // 启动5个子线程 for (int i = 0; i < 5; i++) { new Thread(() -> { try { // ... Thread.sleep(3000); }catch (Exception e){ e.printStackTrace(); }finally { // 子线程完成以后,调用CountDownLatch.countDown()方法 System.out.println(Thread.currentThread().getName() + "执行完成"); cdl.countDown(); } }, "线程-" + (i+1)).start(); } // 调用调用await方法后,主线程阻塞,并等待所有子线程执行完成 cdl.await(); System.out.println("子线程全部执行完成:" + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date())); } } 复制代码
执行结果如下:
线程-1执行完成 线程-4执行完成 线程-5执行完成 线程-2执行完成 线程-3执行完成 子线程全部执行完成:2020-03-12 12:42:41 复制代码
我们看到,主线程开启了5个子线程,调用了CountDownLatch.await()方法后主线程会进入阻塞状态,当所有子线程执行完成后,主线程才可以继续执行。
我们通过追踪源码来研究一下CountDownLatch的内部原理。首先我们先是new了一个CountDownLatch对象,进入它的构造方法:
public class CountDownLatch { public CountDownLatch(int count) { if (count < 0) throw new IllegalArgumentException("count < 0"); // 新建一个同步器 this.sync = new Sync(count); } } 复制代码
这边新建了一个Sync对象,Sync是它的内部类,我们进去看下:
private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981922014374L; Sync(int count) { // 将同步状态设置为指定的值 setState(count); } int getCount() { return getState(); } // 省略一些代码 ... } 复制代码
我们看到Sync继承了AbstractQueuedSynchronizer,说明CountDownLatch是基于AQS( AQS实现原理 )实现的,进入setState方法,这边只是把同步状态state赋值为我们传入的线程数。
public abstract class AbstractQueuedSynchronizer ... // 将同步状态设置为5 protected final void setState(int newState) { state = newState; } } 复制代码
上述就是新建一个CountDownLatch对象的逻辑,下面看下调用await方法的做了什么,进入await方法:
public class CountDownLatch { public void await() throws InterruptedException { // 父类AbstractQueuedSynchronizer中实现 sync.acquireSharedInterruptibly(1); } } 复制代码
acquireSharedInterruptibly这个方法在Sync的父类AbstractQueuedSynchronizer中实现:
public abstract class AbstractQueuedSynchronizer ... public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); // 获取共享锁 if (tryAcquireShared(arg) < 0) doAcquireSharedInterruptibly(arg); } // 这边空实现,其它是交给子类去实现,我们这边是Sync protected int tryAcquireShared(int arg) {throw new UnsupportedOperationException(); } } 复制代码
我们再回到Sync中,看下tryAcquireShared这个方法的实现:
protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } 复制代码
这边很简单,比较同步状态state,等于0就返回1,反则返回-1,一开始我们初始化成了5,所以这边返回-1。返回至AbstractQueuedSynchronizer中:
public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); // 这边返回确实小于0 if (tryAcquireShared(arg) < 0) doAcquireSharedInterruptibly(arg); } 复制代码
因为前面tryAcquireShared返回的是-1,所以这边if条件成立,进入doAcquireSharedInterruptibly方法:
public abstract class AbstractQueuedSynchronizer ... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { // 当前线程包装成node放入阻塞队列 final Node node = addWaiter(Node.SHARED); boolean failed = true; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0) { setHeadAndPropagate(node, r); p.next = null; failed = false; return; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException(); } } finally { if (failed) cancelAcquire(node); } } } 复制代码
这边我们只关注下parkAndCheckInterrupt方法:
private final boolean parkAndCheckInterrupt() { // 阻塞当前线程 LockSupport.park(this); return Thread.interrupted(); } 复制代码
分析到目前,我们知道主线程调用完await方法后会一直阻塞在这里。 那么它是如何被唤醒的呢?接着我们继续追踪一下CountDownLatch.countDown的实现,我们进入这个方法:
public class CountDownLatch { ... public void countDown() { // 调用父类AbstractQueuedSynchronizer的方法 sync.releaseShared(1); } } 复制代码
同样releaseShared这个方法是在父类AbstractQueuedSynchronizer中实现的:
public abstract class AbstractQueuedSynchronizer public final boolean releaseShared(int arg) { // 模板方法,调用子类Sync的tryReleaseShared方法 if (tryReleaseShared(arg)) { doReleaseShared(); return true; } return false; } } 复制代码
tryReleaseShared是在Sync中实现的:
protected boolean tryReleaseShared(int releases) { // Decrement count; signal when transition to zero for (;;) { int c = getState(); // 同步状态已经是0了就直接返回false if (c == 0) return false; int nextc = c-1; // 使用CAS操作:同步状态减1 if (compareAndSetState(c, nextc)) return nextc == 0; } } 复制代码
这段代码的逻辑是,每当调用一次countDown方法,同步状态state就会减1,当同步状态state减至0后就会返回true,所以当我们的最后一个子线程执行完countDown方法后,就会返回true。我们回到releaseShared方法:
public abstract class AbstractQueuedSynchronizer public final boolean releaseShared(int arg) { // 当最后一个线程执行完,tryReleaseShared返回true if (tryReleaseShared(arg)) { doReleaseShared(); return true; } return false; } } 复制代码
我们知道最后一个线程执行完成后,tryReleaseShared返回true,所以会进入doReleaseShared方法:
private void doReleaseShared() { for (;;) { Node h = head; if (h != null && h != tail) { int ws = h.waitStatus; if (ws == Node.SIGNAL) { if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) continue; // 唤醒阻塞的线程 unparkSuccessor(h); } else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) continue; } if (h == head) break; } } 复制代码
这边我们只关注下unparkSuccessor方法:
private void unparkSuccessor(Node node) { ...// 省略一些代码 if (s != null) // 唤醒线程 LockSupport.unpark(s.thread); } 复制代码
最终我们看到当最后子线程执行完毕后,主线程会被唤醒。