public class Semaphore implements java.io.Serializable 复制代码
//序列化版本号 private static final long serialVersionUID = -3222578661600680210L; //同步器,AbstractQueuedSynchronizer的子类 private final Sync sync; 复制代码
从字段属性中可以看出
//传入信号数 public Semaphore(int permits) { //默认使用非公平锁 sync = new NonfairSync(permits); } //传入信号数和锁的类型 public Semaphore(int permits, boolean fair) { sync = fair ? new FairSync(permits) : new NonfairSync(permits); } 复制代码
从构造方法中可以看出
//可中断的获取信号量 public void acquire() throws InterruptedException { //调用sync的acquireSharedInterruptibly方法 sync.acquireSharedInterruptibly(1); } //获取指定数量的信号量 public void acquire(int permits) throws InterruptedException { if (permits < 0) throw new IllegalArgumentException(); //调用sync的acquireSharedInterruptibly方法 sync.acquireSharedInterruptibly(permits); } 复制代码
//不可中断的获取信号量 public void acquireUninterruptibly() { //调用sync的acquireShared方法 sync.acquireShared(1); } //不可中断的获取指定数量的信号量 public void acquireUninterruptibly(int permits) { if (permits < 0) throw new IllegalArgumentException(); sync.acquireShared(permits); } 复制代码
//尝试获取信号量 public boolean tryAcquire() { //调用sync的nonfairTryAcquireShared方法 return sync.nonfairTryAcquireShared(1) >= 0; } //尝试获取指定数量的信号量 public boolean tryAcquire(int permits) { if (permits < 0) throw new IllegalArgumentException(); //尝试获取指定数量的信号量 return sync.nonfairTryAcquireShared(permits) >= 0; } //设置超时时间的尝试获取信号量 public boolean tryAcquire(long timeout, TimeUnit unit) throws InterruptedException { //调用sync的tryAcquireSharedNanos方法 return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout)); } //设置超时时间的尝试获取指定数量的信号量 public boolean tryAcquire(int permits, long timeout, TimeUnit unit) throws InterruptedException { if (permits < 0) throw new IllegalArgumentException(); //调用sync的tryAcquireSharedNanos方法 return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout)); } 复制代码
//释放一个信号量 public void release() { //调用sync的releaseShared方法 sync.releaseShared(1); } //释放指定数量的信号量 public void release(int permits) { if (permits < 0) throw new IllegalArgumentException(); sync.releaseShared(permits); } 复制代码
//获取当前可用的通道数 public int availablePermits() { return sync.getPermits(); } 复制代码
//获取立即可用的通道数 public int drainPermits() { return sync.drainPermits(); } 复制代码
//减少信号数 protected void reducePermits(int reduction) { if (reduction < 0) throw new IllegalArgumentException(); sync.reducePermits(reduction); } 复制代码
//获取锁的类型, true 公平锁, false 非公平锁 public boolean isFair() { return sync instanceof FairSync; } 复制代码
//队列中是否有正在等信号的线程 public final boolean hasQueuedThreads() { return sync.hasQueuedThreads(); } 复制代码
//获取队列中等待信号的线程数 public final int getQueueLength() { return sync.getQueueLength(); } 复制代码
//获取队列中的线程,以集合的方式返回 protected Collection<Thread> getQueuedThreads() { return sync.getQueuedThreads(); } 复制代码
abstract static class Sync extends AbstractQueuedSynchronizer 复制代码
从类的定义中可以看出
//序列化版本号 private static final long serialVersionUID = 1192457210091910933L; 复制代码
//传入的信号数就是AQS中的state Sync(int permits) { setState(permits); } 复制代码
//获取信号数 final int getPermits() { //实质上就是获取state的值 return getState(); } 复制代码
//非公平方式获取共享锁,返回剩余可用信号数 final int nonfairTryAcquireShared(int acquires) { //for无限循环,自旋CAS for (;;) { //获取当前的state int available = getState(); //可用信号-获取数量=剩余可用数量 int remaining = available - acquires; if (remaining < 0 || compareAndSetState(available, remaining)) return remaining; } } 复制代码
//尝试释放共享锁 protected final boolean tryReleaseShared(int releases) { //for无限循环,自旋CAS for (;;) { //获取当前的state int current = getState(); //可用信号+释放数量=新的可用数量 int next = current + releases; if (next < current) // overflow //releases小于0 抛出异常 throw new Error("Maximum permit count exceeded"); //CAS,设置新值 if (compareAndSetState(current, next)) return true; } } 复制代码
//减少信号数量 final void reducePermits(int reductions) { //for无限循环,自旋CAS for (;;) { //获取当前的state int current = getState(); //可用信号-减少的数量=新的可用数量 int next = current - reductions; if (next > current) // underflow //reductions小于0 抛出异常 throw new Error("Permit count underflow"); //CAS,设置新值 if (compareAndSetState(current, next)) return; } } 复制代码
//清空信号 final int drainPermits() { //CAS 自旋把state置为0 for (;;) { int current = getState(); if (current == 0 || compareAndSetState(current, 0)) return current; } } 复制代码
static final class NonfairSync extends Sync 复制代码
从类的定义可以看出
//序列化版本号 private static final long serialVersionUID = -2694183684443567898L; 复制代码
//设置state NonfairSync(int permits) { super(permits); } 复制代码
//尝试获取共享锁 protected int tryAcquireShared(int acquires) { //调用父类的方法nonfairTryAcquireShared,直接抢锁 return nonfairTryAcquireShared(acquires); } 复制代码
static final class FairSync extends Sync 复制代码
从类的定义中可以看出
//序列化版本号 private static final long serialVersionUID = 2014338818796000944L; 复制代码
//设置state FairSync(int permits) { super(permits); } 复制代码
//尝试获取共享锁 protected int tryAcquireShared(int acquires) { for (;;) { //先看队列中是否有线程排队 if (hasQueuedPredecessors()) return -1; //获取state的值 int available = getState(); //可用信号-获取数量=剩余可用数量 int remaining = available - acquires; if (remaining < 0 || compareAndSetState(available, remaining)) return remaining; } } 复制代码