Java的并发API已经提供了大量的接口和类来帮助我们编写并发程序,但有时这些类仍无法满足我们的需求。这时我们就可以定制属于自己的并发类,通常来讲我们可以通过继承已有的并发类并对某些方法进行修改、拓展即可达到这一目的
我们可以通过继承 ThreadPoolExecutor
类并覆盖父类的某些方法来定制我们自己的执行器
MyExecutor(定制的执行器):
package day08.code_01; import java.util.Date; import java.util.List; import java.util.concurrent.*; public class MyExecutor extends ThreadPoolExecutor { //存储任务开始时间的map private ConcurrentHashMap<String, Date> startTime; //覆盖构造方法 public MyExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) { super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue); startTime = new ConcurrentHashMap<>(); } @Override public void shutdown() { //在shutdown方法中输出线程池相关信息 System.out.printf("MyExecutor: Going to shutdown/n"); //执行完毕的任务数量 System.out.printf("MyExecutor: Executed tasks: %d/n", getCompletedTaskCount()); //正在执行的任务数量 System.out.printf("MyExecutor: Running tasks: %d/n", getActiveCount()); //等待执行的任务数量 System.out.printf("MyExecutor: Pending tasks: %d/n", getQueue().size()); //调用父类的shutdown方法 super.shutdown(); } @Override public List<Runnable> shutdownNow() { //在shutdownNow方法中输出线程池相关信息 System.out.printf("MyExecutor: Going to immediately shutdown/n"); //执行完毕的任务数量 System.out.printf("MyExecutor: Executed tasks: %d/n", getCompletedTaskCount()); //正在执行的任务数量 System.out.printf("MyExecutor: Running tasks: %d/n", getActiveCount()); //等待执行的任务数量 System.out.printf("MyExecutor: Pending tasks: %d/n", getQueue().size()); //调用父类的shutdownNow方法 return super.shutdownNow(); } @Override protected void beforeExecute(Thread t, Runnable r) { //在任务开始执行前打印线程名称和任务的哈希码 System.out.printf("MyExecutor: A task is beginning: %s: %s/n", t.getName(), r.hashCode()); //以任务哈希码为键,日期为值,装入map中 startTime.put(String.valueOf(r.hashCode()), new Date()); //调用父类的beforeExecute方法 super.beforeExecute(t, r); } @Override protected void afterExecute(Runnable r, Throwable t) { //对任务进行类型强转 Future<?> result = (Future<?>) r; try { //打印任务结束提示语 System.out.println("*****************************"); System.out.println("MyExecutor: A task is finishing"); //打印结果 System.out.printf("MyExecutor: Result: %s/n", result.get()); //计算执行所花费的时间 Date startDate = startTime.remove(String.valueOf(r.hashCode())); Date finishDate = new Date(); long diff = finishDate.getTime() - startDate.getTime(); //打印执行所花费的时间 System.out.printf("MyExecutor: Duration: %d/n", diff); System.out.println("*****************************"); } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); } //调用父类方法 super.afterExecute(r, t); } } 复制代码
任务类:
package day08.code_01; import java.util.Date; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; public class SleepTwoSecondsTask implements Callable<String> { @Override public String call() throws Exception { //休眠2秒 TimeUnit.SECONDS.sleep(2); //返回时间字符串 return new Date().toString(); } } 复制代码
main方法:
package day08.code_01; import java.util.ArrayList; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) { //创建自定义的ThreadPoolExecutor对象 MyExecutor myExecutor = new MyExecutor (2, 4, 1000, TimeUnit.MILLISECONDS, new LinkedBlockingDeque<>()); //创建装载Future对象的集合 ArrayList<Future<String>> results = new ArrayList<>(); //发送十个任务 for (int i = 0; i < 10; i++) { //创建任务 SleepTwoSecondsTask task = new SleepTwoSecondsTask(); //将任务发送给执行器 Future<String> result = myExecutor.submit(task); //将得到的Future对象装入集合 results.add(result); } //尝试获取前5个任务的结果 for (int i = 0; i < 5; i++) { try { //得到任务执行结束后返回的结果 String result = results.get(i).get(); //打印任务编号及结果 System.out.printf("Main: Result for Task %d : %s/n", i, result); } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); } } //关闭执行器 myExecutor.shutdown(); //尝试获取后5个任务的结果 for (int i = 5; i < 10; i++) { try { //得到任务执行结束后返回的结果 String result = results.get(i).get(); //打印任务编号及结果 System.out.printf("Main: Result for Task %d : %s/n", i, result); } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); } } //等待执行器关闭 try { myExecutor.awaitTermination(1, TimeUnit.DAYS); } catch (InterruptedException e) { e.printStackTrace(); } //打印程序结束提示语 System.out.printf("Main: End of the program/n"); } } 复制代码
执行器内部使用一个阻塞队列来装载等待执行的任务,我们可以通过 ThreadPoolExecutor
类的构造函数传入一个实现了 BlockingQueue<E>
接口的对象引用。Java为我们提供了具有不同特点的阻塞队列实现类,例如我们之前使用过的优先级队列 PriorityBlockingQueue
类。将此类的对象作为执行器中装载任务的阻塞队列可以实现按照优先级执行任务的效果,需要注意的是,在这种情况下我们的任务类不止要实现 Runnable
接口还需要实现 Comparable
接口,具体原因已经在day07中做过记录,就不在此赘述了
任务类:
package day08.code_02; import java.util.concurrent.TimeUnit; public class MyPriorityTask implements Runnable, Comparable<MyPriorityTask> { //优先级 private int priority; //任务名称 private String name; public MyPriorityTask(String name, int priority) { this.priority = priority; this.name = name; } public int getPriority() { return priority; } @Override public int compareTo(MyPriorityTask o) { //如果优先级较高,排在队列靠前位置 if (this.getPriority() > o.getPriority()) { return -1; //优先级较低排在靠后位置 } else if (this.getPriority() < o.getPriority()) { return 1; } //优先级相同则没有明确顺序 return 0; } @Override public void run() { //打印任务名称和优先级 System.out.printf("MyPriorityTask: %s Priority : %d/n", name, priority); //休眠两秒 try { TimeUnit.SECONDS.sleep(2); } catch (InterruptedException e) { e.printStackTrace(); } } } 复制代码
main方法:
package day08.code_02; import java.util.concurrent.PriorityBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) { //创建执行器,任务队列使用优先级队列 ThreadPoolExecutor executor = new ThreadPoolExecutor( 2, 2, 1, TimeUnit.SECONDS, new PriorityBlockingQueue<Runnable>()); //创建四个任务 for (int i = 0; i < 4; i++) { //通过构造方法设置任务名称和优先级 MyPriorityTask task = new MyPriorityTask("Task" + i, i); //将任务发送到执行器 executor.execute(task); } //休眠1秒 try { TimeUnit.SECONDS.sleep(1); } catch (InterruptedException e) { e.printStackTrace(); } //再次创建四个任务 for (int i = 4; i < 8; i++) { //通过构造方法设置任务名称和优先级 MyPriorityTask task = new MyPriorityTask("Task" + i, i); //将任务发送到执行器 executor.execute(task); } //关闭执行器 executor.shutdown(); //等待执行器将所有任务执行完毕 try { executor.awaitTermination(1, TimeUnit.DAYS); } catch (InterruptedException e) { e.printStackTrace(); } //打印程序结束信息 System.out.printf("Main: End of the program/n"); } } 复制代码
我们可以通过实现 ThreadFactory
接口来定制特殊的线程工厂类,使用线程工厂创建线程对象较为简单,并且可以线程创建线程的数量。当然,如果我们仅仅是把定制的线程工厂类作为一个独立的类来使用,那么完全可以不实现 ThreadFactory
接口;但是如果打算与其他并发API组合使用,例如将线程工厂的引用作为一个参数传入其他方法中,那么就必须实现这一接口。另外,我们也可以使用 Executors
类的 defaultThreadFactory()
方法获取到一个最基本的线程工厂,这个工厂会生成同属于一个线程组的基本线程对象
在这个范例中,我们将使用定制线程工厂去创建定制线程
定制线程工厂类:
package day08.code_03; import java.util.concurrent.ThreadFactory; public class MyThreadFactory implements ThreadFactory { //计数器 private int counter; //名称前缀 private String prefix; public MyThreadFactory(String prefix) { this.prefix = prefix; counter = 1; } @Override public Thread newThread(Runnable r) { //创建线程,名字为前缀加计数器数字 MyThread myThread = new MyThread(r, prefix + "-" + counter); //计数器自增 counter++; //返回创建好的线程 return myThread; } } 复制代码
定制线程类:
package day08.code_03; import java.util.Date; public class MyThread extends Thread { //线程创建时间 private Date creationDate; //线程开始执行时间 private Date startDate; //线程执行结束时间 private Date finishDate; //重写构造函数 public MyThread(Runnable target, String name) { super(target, name); setCreationDate(); } @Override public void run() { //设置开始时间 setStartDate(); //执行任务 super.run(); //设置结束时间 setFinishDate(); } //设置线程被创建的时间 public void setCreationDate() { creationDate = new Date(); } //设置线程开始执行的时间 public void setStartDate() { startDate = new Date(); } //设置线程执行结束的时间 public void setFinishDate() { finishDate = new Date(); } //获取线程执行任务所消耗的时间 public long getExecutionTime() { return finishDate.getTime() - startDate.getTime(); } //重写toString方法 @Override public String toString() { StringBuilder builder = new StringBuilder(); //线程名称 builder.append(getName()); builder.append(" : "); //创建时间 builder.append("Creation Date: "); builder.append(creationDate); //运行时长 builder.append(" Running time: "); builder.append(getExecutionTime()); builder.append(" Milliseconds"); return builder.toString(); } } 复制代码
任务类:
package day08.code_03; import java.util.concurrent.TimeUnit; public class MyTask implements Runnable { @Override public void run() { //休眠2秒 try { TimeUnit.SECONDS.sleep(2); } catch (InterruptedException e) { e.printStackTrace(); } } } 复制代码
main方法:
package day08.code_03; public class Main { public static void main(String[] args) { //创建定制的线程工厂 MyThreadFactory myTactory = new MyThreadFactory("MyThreadFactory"); //创建任务 MyTask myTask = new MyTask(); //创建定制的线程对象 Thread thread = myTactory.newThread(myTask); //开启线程 thread.start(); //等待线程执行结束 try { thread.join(); } catch (InterruptedException e) { e.printStackTrace(); } //打印线程执行结束信息 System.out.printf("Main: Thread information/n"); System.out.printf("%s/n", thread); System.out.printf("Main: End of the example/n"); } } 复制代码
在第三小节中,我们编写了自己的线程工厂类和线程类。因为我们实现了 ThreadFactory
接口,因此在创建执行器时,可以将定制线程工厂对象作为参数传入。这样一来执行器在创建线程时便会使用我们自定义的线程工厂
在这个范例中,使用了第三小节的 MyThread
、 MyThreadFactory
、 MyTask
类,代码是完全一样的,这里就只给出main方法。
main方法:
package day08.code_04; import day08.code_03.MyTask; import day08.code_03.MyThreadFactory; import java.util.concurrent.*; public class Main { public static void main(String[] args) throws InterruptedException { //创建定制工厂对象 MyThreadFactory myTactory = new MyThreadFactory("MyThreadFactory"); //创建执行器并将定制工厂对象作为参数传入 ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newCachedThreadPool(myTactory); //创建任务 MyTask myTask = new MyTask(); //提交任务 executor.submit(myTask); //关闭执行器 executor.shutdown(); //等待执行器将所有任务执行完毕 executor.awaitTermination(1, TimeUnit.DAYS); //打印程序结束信息 System.out.printf("Main: End of the program/n"); } } 复制代码
定时线程池(Scheduled Thread Pool)可以执行延迟任务和周期性任务。其中延迟任务可以执行实现了 Callable
或 Runnable
接口的对象,但是周期性任务只能执行实现了 Runnable
接口的。另外,尽管我们向定时线程池发送的任务对象均是实现了 Callable
或 Runnable
接口的,但实际上任务想要在定时线程池中运行,就必须实现 RunnableScheduledFuture
接口,只不过这一工作由线程池内部的方法帮助我们完成了。在完成以下范例前,我们需要对定时线程池的机制有一定的了解,否则在定制类中重写方法时会无从下手
我们先来大概分析一下 ScheduledThreadPoolExecutor
类的部分源码以及运行流程:
ScheduledThreadPoolExecutor
类中有两个内部类,其中一个是 DelayedWorkQueue
类,它是一个装载任务的队列,每当有任务准备装入队列时,任务的 compareTo
方法会被调用以此来决定任务在队列中所处的位置;另一个是 ScheduledFutureTask
类,它实现了 RunnableScheduledFuture
接口并继承了 FutureTask
类。此内部类有如下几个变量和方法和此范例有较重要的关联:
private final long period;
:变量period用于保存当前任务的执行周期 RunnableScheduledFuture<V> outerTask = this;
:变量outerTask用于保存需要在下一次放入任务队列中的任务,默认指向当前对象 private long time;
:变量time表示周期任务下一次执行的时间(纳秒) getDelay(TimeUnit unit)
:此方法会根据time变量和当前时间的纳秒值来返回以给定时间参数为单位的距离任务下一次开始执行的时间,源码如下,其中 now()
方法返回了当前时间的纳秒值 public long getDelay(TimeUnit unit) { return unit.convert(time - now(), NANOSECONDS); } 复制代码
compareTo(Delayed other)
:此方法在任务被装入任务队列中会被调用。等待队列实际上是使用数组维护的最小堆,待进入队列的元素会和队列中的元素根据下一次任务开始执行的时间进行比较,时间较短的将排在队列较前方。通过源码我们可以看出,内部类的 compareTo
方法会先判断传入的元素是否为当前对象自身,如果是则不进行排序;再判断传入元素是否为内部类对象,如果是则根据time变量进行比较。如果前两个判断均不生效则意味着定时线程池中的任务类是用户的定制任务类,这种情况下调用传入参数自身的 compareTo
方法 public int compareTo(Delayed other) { //判断是否是自身 if (other == this) // compare zero if same object return 0; //判断是否为ScheduledFutureTask类 if (other instanceof ScheduledFutureTask) { ScheduledFutureTask<?> x = (ScheduledFutureTask<?>)other; //根据内部类的成员变量进行比较 long diff = time - x.time; if (diff < 0) return -1; else if (diff > 0) return 1; else if (sequenceNumber < x.sequenceNumber) return -1; else return 1; } long diff = getDelay(NANOSECONDS) - other.getDelay(NANOSECONDS); return (diff < 0) ? -1 : (diff > 0) ? 1 : 0; } 复制代码
run()
:在这个run方法中,值得注意的是在执行完FutureTask的run方法后,我们需要重新设置任务的执行时间,这一操作在这里等价于更新 ScheduledFutureTask
类的成员变量time。另外,由于是周期任务,我们必须将任务重新放入队列中 public void run() { //判断当前任务是否是周期任务 boolean periodic = isPeriodic(); //判断当前状态是否可以执行 if (!canRunInCurrentRunState(periodic)) cancel(false); //如果不是周期任务 else if (!periodic) //调用FutureTask的run方法执行Runnable或Callable实例的run方法 ScheduledFutureTask.super.run(); //否则调用runAndReset方法执行并初始化状态 else if (ScheduledFutureTask.super.runAndReset()) { //设置周期任务下一次的执行时间 setNextRunTime(); //将当前任务重新放入队列中并开启线程执行 reExecutePeriodic(outerTask); } } 复制代码
接下来是 ScheduledThreadPoolExecutor
类,其父类是 ThreadPoolExecutor
。通过查看源码我们可以发现 ScheduledThreadPoolExecutor
类的构造方法通过super调用 ThreadPoolExecutor
类的构造方法,传入一个内部类 DelayedWorkQueue
对象作为线程池的任务队列,原来定时线程池底层使用的仍然是 ThreadPoolExecutor
类。定时线程池主要有如下方法:
scheduleAtFixedRate(Runnable command,long initialDelay, long period,TimeUnit unit)
:此方法在这个范例中需要重写。通过源码我们可以看到,在进行了两次校验后,在方法中创建了一个 ScheduledFutureTask
对象,并将其作为 decorateTask
这一方法的参数。 decorateTask
方法是我们在线程池中使用定制任务类的关键,我们需要重写此方法使其返回一个定制任务类实例。 public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { if (command == null || unit == null) throw new NullPointerException(); if (period <= 0) throw new IllegalArgumentException(); //创建内部类对象 ScheduledFutureTask<Void> sft = new ScheduledFutureTask<Void>(command, null, triggerTime(initialDelay, unit), unit.toNanos(period)); //此方法需要重写,默认是返回sft RunnableScheduledFuture<Void> t = decorateTask(command, sft); //保存任务,方便下一次向队列中提交 sft.outerTask = t; //向队列中添加任务并执行 delayedExecute(t); return t; } 复制代码
decorateTask(Runnable runnable, RunnableScheduledFuture<V> task)
:此方法是留给用户进行拓展的,在这里只是返回了传入的内部类对象,并没有实现什么功能,我们可以重写这个方法,使其返回我们的定制任务类实例 protected <V> RunnableScheduledFuture<V> decorateTask( Runnable runnable, RunnableScheduledFuture<V> task) { return task; } 复制代码
delayedExecute(RunnableScheduledFuture<?> task)
:向队列中延迟提交任务, super.getQueue()
方法在这里得到的是上面记录过的内部队列类 private void delayedExecute(RunnableScheduledFuture<?> task) { //校验 if (isShutdown()) reject(task); else { //调用父类方法向队列中添加任务 super.getQueue().add(task); //校验 if (isShutdown() && !canRunInCurrentRunState(task.isPeriodic()) && remove(task)) task.cancel(false); else //启动一个线程去等待任务 ensurePrestart(); } } 复制代码
在这个范例中,我们将定制一个类并使其在定时线程池中运行。我们需要做一系列的工作,例如实现接口、继承现有类、重写父类方法。
首先我们需要定义一个类并实现 RunnableScheduledFuture
接口来作为在定时线程池中运行的定制任务类,先查看此接口的继承关系。
RunnableScheduledFuture
接口又继承了多个接口,如果直接实现接口就需要重写大量的方法。然而发现
FutureTask
类已经实现了
RunnableFuture
接口,我们只需要继承这个类就可以降低工作量。在定制任务类的构造方法中,我们将调用
FutureTask
类的构造方法传入实现了
Runnable
或
Callable
接口的对象和对应的返回值类型,并为当前类的成员变量赋值。定制类中有一个实现了
RunnableScheduledFuture
接口的成员变量。它将保存定时线程池所返回的内部类,这样做将进一步减轻工作量因为可以在后续的方法中直接调用此对象的方法如
isPeriodic()
、
getDelay()
。需要注意的是,原书在这里还直接调用了此对象的
compareTo()
方法,这是不正确的。因为在定时线程池中执行任务、刷新执行时间的是定制任务类而不是内部任务类,也就是说time变量在第一次被赋值后就不会再改变,然而内部类的
compareTo()
方法却不可避免的使用了time这一变量,这显然是错误的。当我们向定时线程池中添加一个以上的周期任务时就会出现难以预测的问题。
MyScheduledTask(自定义的运行在定时线程池中的任务类):
package day08.code_05; import java.util.Date; import java.util.concurrent.*; public class MyScheduledTask<V> extends FutureTask<V> implements RunnableScheduledFuture<V> { //保存ScheduledFutureTask对象 private RunnableScheduledFuture<V> task; //定时执行器 private ScheduledThreadPoolExecutor executor; //执行周期 private long period; //开始时间 private long startDate; public MyScheduledTask(Runnable runnable, V result, RunnableScheduledFuture<V> task, ScheduledThreadPoolExecutor executor) { //调用父类FutureTask的构造方法 super(runnable, result); this.task = task; this.executor = executor; } public void setPeriod(long period) { this.period = period; } @Override public boolean isPeriodic() { return task.isPeriodic(); } @Override public long getDelay(TimeUnit unit) { //非周期任务直接调用ScheduledFutureTask对象的方法 if (!isPeriodic()) { return task.getDelay(unit); } else { //周期性任务但是还未执行过,直接调用ScheduledFutureTask对象的方法 if (startDate == 0) { return task.getDelay(unit); } else { //周期性任务并且之前执行过 //根据自定义的属性计算出距离下一次运行的时间 Date now = new Date(); long delay = startDate - now.getTime(); return unit.convert(delay, TimeUnit.MILLISECONDS); } } } @Override public int compareTo(Delayed o) { //使用定制任务类自己的方法获取时间差 long diff = this.getDelay(TimeUnit.NANOSECONDS) - o.getDelay(TimeUnit.NANOSECONDS); //较早执行的任务排在队列前面 if (diff < 0) { return -1; //较晚执行的任务排在队列后面 } else if (diff > 0) { return 1; } return 0; } @Override public void run() { //判断任务是否是周期性执行且执行器未关闭 if (isPeriodic() && (!executor.isShutdown())) { //获取当前时间 Date now = new Date(); //计算出下一次任务的执行时间 startDate = now.getTime() + period; //将任务再次加入 executor.getQueue().add(this); } //打印任务开始执行时的日期 System.out.printf("Pre-MyScheduledTask: %s/n", new Date()); //打印任务执行的周期 System.out.printf("MyScheduledTask: Is Periodic: %s/n", isPeriodic()); //调用FutureTask的方法来执行传入的任务并重置 super.runAndReset(); //打印任务结束的时间 System.out.printf("Post-MyScheduledTask: %s/n", new Date()); } } 复制代码
想要使用定制任务类,我们需要与其配套的定制定时线程池,直接继承 ScheduledThreadPoolExecutor
类并重写其中的方法即可。这个范例中定制线程池只重写了2个方法,实际运用中可根据不同需求进行改变:
MyScheduledThreadPoolExecutor
(定制定时线程池类):
package day08.code_05; import java.util.concurrent.RunnableScheduledFuture; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; public class MyScheduledThreadPoolExecutor extends ScheduledThreadPoolExecutor { //指定核心线程数量的构造方法 public MyScheduledThreadPoolExecutor(int corePoolSize) { super(corePoolSize); } @Override public ScheduledFuture<?> scheduleAtFixedRate (Runnable command, long initialDelay, long period, TimeUnit unit) { //调用父类的方法执行传入的任务 ScheduledFuture<?> task = super.scheduleAtFixedRate (command, initialDelay, period, unit); //将返回值强转为定制的任务 MyScheduledTask myTask = (MyScheduledTask) task; //设置任务执行周期 myTask.setPeriod(TimeUnit.MILLISECONDS.convert(period, unit)); //返回任务 return myTask; } //装饰任务方法,此方法会在父类的scheduleAtFixedRate方法中被调用 @Override protected <V> RunnableScheduledFuture<V> decorateTask( Runnable runnable, RunnableScheduledFuture<V> task) { //创建我们自己的定制任务类 //第三个参数task在这里传入的是ScheduledThreadPoolExecutor的内部类ScheduledFutureTask MyScheduledTask<V> myTask = new MyScheduledTask<>(runnable, null, task, this); //返回定制任务类 return myTask; } } 复制代码
Task(向线程池提交的任务类,只进行简单的休眠工作):
package day08.code_05; import java.util.concurrent.TimeUnit; public class Task implements Runnable { @Override public void run() { //打印任务开始提示语 System.out.printf("Task: Begin/n"); //休眠2秒 try { TimeUnit.SECONDS.sleep(2); } catch (InterruptedException e) { e.printStackTrace(); } //打印任务结束提示语 System.out.printf("Task: End/n"); } } 复制代码
main方法:
package day08.code_05; import java.util.Date; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) throws InterruptedException { //创建我们自己的定时执行器 MyScheduledThreadPoolExecutor executor = new MyScheduledThreadPoolExecutor(2); //创建一个任务 Task task = new Task(); //打印程序开始的时间 System.out.printf("Main: %s/n", new Date()); //按照指定时间执行任务 executor.scheduleAtFixedRate(task, 1, 3, TimeUnit.SECONDS); //当前线程休眠10秒 TimeUnit.SECONDS.sleep(10); //关闭执行器 executor.shutdown(); //等待执行器关闭 executor.awaitTermination(1, TimeUnit.DAYS); //打印程序结束提示语 System.out.printf("Main: End of the program/n"); } } 复制代码
之前我们通过实现 ThreadFactory
接口创造了线程工厂以此来生成定制线程。我们同样可以通过实现 ForkJoinWorkerThreadFactory
接口来创造线程工厂以此来为Fork/Join框架生成定制线程。
创建Fork/Join框架中的定制线程,我们可以继承 ForkJoinWorkerThread
类并提供相应的构造方法。在这个范例中,定制线程类还重写了 onStart()
和 onTermination()
方法:
onStart()
:该方法会在线程被创建后,第一个任务开始执行前被自动执行,我们可以重写此方法去初始化线程内部状态或打印日志。根据源码中的注释,如果要重写该方法,需要将 super.onStart()
这条代码放在最开始 onTermination()
:该方法会在线程关闭前执行,我们可以重写此方法在线程关闭之前释放资源或打印日志。根据源码中的注释,如果要重写该方法,需要将 super.onTermination()
这条代码放在末尾 package day08.code_06; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinWorkerThread; public class MyWorkerThread extends ForkJoinWorkerThread { //线程级别的计数器 private static ThreadLocal<Integer> taskCounter = new ThreadLocal<>(); //构造方法 protected MyWorkerThread(ForkJoinPool pool) { super(pool); } @Override protected void onStart() { //必须先调用父类的onStart方法 super.onStart(); //打印线程信息 System.out.printf("MyWorkerThread %d: Initializing task counter/n", getId()); //初始化任务计数器 taskCounter.set(0); } @Override protected void onTermination(Throwable exception) { //打印线程信息和执行的任务数 System.out.printf("MyWorkerThread %d: %d/n", getId(), taskCounter.get()); //必须在最后调用父类的onTermination方法 super.onTermination(exception); } //调用此方法可以改变任务计数器的值 public void addTask() { //得到计数器的值 int counter = taskCounter.get().intValue(); //自增 counter++; //更新计数器的值 taskCounter.set(counter); } } 复制代码
MyWorkerThreadFactory(定制工厂类,使用工厂方法返回定制线程)
package day08.code_06; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinWorkerThread; public class MyWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory { @Override public ForkJoinWorkerThread newThread(ForkJoinPool pool) { return new MyWorkerThread(pool); } } 复制代码
MyRecursiveTask(带有返回值的任务类,在这个范例中将对超大数组求和,不是重点)
package day08.code_06; import java.util.concurrent.ExecutionException; import java.util.concurrent.RecursiveTask; import java.util.concurrent.TimeUnit; public class MyRecursiveTask extends RecursiveTask<Integer> { //超大数组 private int array[]; //任务起始、终止位置 private int start, end; //构造方法 public MyRecursiveTask(int[] array, int start, int end) { this.array = array; this.start = start; this.end = end; } @Override protected Integer compute() { //初始化结果 int ret = 0; //获取当前线程 MyWorkerThread thread = (MyWorkerThread) Thread.currentThread(); //调用线程的addTask方法增加任务计数器的值 thread.addTask(); //如果任务过大则分解 if (end - start > 10000) { int middle = (start + end) / 2; MyRecursiveTask task1 = new MyRecursiveTask(array, start, middle); MyRecursiveTask task2 = new MyRecursiveTask(array, middle, end); //异步执行任务 task1.fork(); task2.fork(); //合并结果 return addResults(task1, task2); } //求出范围内数组的和 for (int i = start; i < end; i++) { ret += array[i]; } //返回结果 return ret; } private Integer addResults(MyRecursiveTask task1, MyRecursiveTask task2) { int value; //尝试获取两个任务的返回值 try { value = task1.get().intValue() + task2.get().intValue(); } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); value = 0; } //休眠1秒 try { TimeUnit.SECONDS.sleep(1); } catch (InterruptedException e) { e.printStackTrace(); } //返回结果值 return value; } } 复制代码
main方法:
package day08.code_06; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) throws ExecutionException, InterruptedException { //创建定制线程工厂 MyWorkerThreadFactory factory = new MyWorkerThreadFactory(); //创建线程池并将定制线程工厂作为参数传入 ForkJoinPool pool = new ForkJoinPool(4, factory, null, false); //创建超大数组并初始化 int[] array = new int[100000]; for (int i = 0; i < array.length; i++) { array[i] = i; } //创建任务 MyRecursiveTask task = new MyRecursiveTask(array, 0, array.length); //异步执行 pool.execute(task); //等待任务执行结束 task.join(); //关闭线程池 pool.shutdown(); //等待线程池中的任务执行结束 pool.awaitTermination(1, TimeUnit.DAYS); //打印任务返回的结果和程序执行结束提示语 System.out.printf("Main: Result: %d/n", task.get()); System.out.println("Main: End of the program"); } } 复制代码
之前我们创建可以使用Fork/Join框架执行的任务通常都是继承 RecursiveAciton
或 RecursiveTask
这两个抽象类并重写其中的方法。实际上,我们也可以根据这两个抽象类的构造去创建定制的任务抽象类。先查看 RecursiveAciton
、 RecursiveTask
这两个抽象类的源码
RecursiveAciton
类:该任务无返回值 public abstract class RecursiveAction extends ForkJoinTask<Void> { private static final long serialVersionUID = 5232453952276485070L; //抽象方法,主要用于重写任务逻辑 protected abstract void compute(); //获取任务结果,当前任务类无返回值,所以此方法必须返回null public final Void getRawResult() { return null; } //设置任务结果,当前任务类无返回值,所以此方法为空 protected final void setRawResult(Void mustBeNull) { } //线程池调用任务的此方法执行任务 //此方法又调用compute方法,我们可以重写此方法做拓展 protected final boolean exec() { compute(); return true; } } 复制代码
RecursiveTask
类:该方法有返回值 public abstract class RecursiveTask<V> extends ForkJoinTask<V> { private static final long serialVersionUID = 5232453952276485270L; //任务结果 V result; //抽象方法,主要用于重写任务逻辑 protected abstract V compute(); //获取任务结果,直接返回成员变量result public final V getRawResult() { return result; } //设置任务结果,为成员变量result赋值 protected final void setRawResult(V value) { result = value; } //线程池调用任务的此方法执行任务 //此方法又调用compute方法,我们可以重写此方法做拓展 protected final boolean exec() { result = compute(); return true; } } 复制代码
根据以上两个类的源码,我们在此范例中将创建自己的无返回值抽象任务类,并使任务类继承此定制抽象类而不是 RecursiveAciton
或 RecursiveTask
这两个类
MyWorkerTask
(定制抽象任务类):
package day08.code_07; import java.util.Date; import java.util.concurrent.ForkJoinTask; public abstract class MyWorkerTask extends ForkJoinTask<Void> { //任务名称 private String name; //构造方法 public MyWorkerTask(String name) { this.name = name; } public String getName() { return name; } @Override public Void getRawResult() { //因为当前任务无返回值,所以返回null return null; } @Override protected void setRawResult(Void value) { //因为无返回值,所以方法为空 } @Override protected boolean exec() { //创建任务开始时间 Date startDate = new Date(); //开始执行任务 compute(); //创建任务执行结束时间 Date finishDate = new Date(); //计算时间差 long diff = finishDate.getTime() - startDate.getTime(); //打印任务执行所花费的时间 System.out.printf("MyWorkerTask: %s : %d Milliseconds to complete/n", name, diff); return true; } protected abstract void compute(); } 复制代码
Task(任务类,继承了定制抽象任务类):
package day08.code_07; public class Task extends MyWorkerTask { //必备元素 private static final long serialVersionUID = 1L; //数组 private int array[]; //任务起始、终止位置 private int start, end; //构造方法 public Task(String name, int[] array, int start, int end) { super(name); this.array = array; this.start = start; this.end = end; } @Override protected void compute() { //如果任务过大,进行拆分 if (end - start > 100) { int mid = (start + end) / 2; Task task1 = new Task(this.getName() + "1", array, start, mid); Task task2 = new Task(this.getName() + "2", array, mid, end); //同步执行 invokeAll(task1, task2); } else { //将范围内的数组元素自增 for (int i = start; i < end; i++) { array[i]++; } } //修庙50毫秒 try { Thread.sleep(50); } catch (InterruptedException e) { e.printStackTrace(); } } } 复制代码
main方法:
package day08.code_07; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) throws InterruptedException { //创建数组,元素默认都为0 int[] array = new int[10000]; //创建线程池 ForkJoinPool pool = new ForkJoinPool(); //创建任务 Task task = new Task("Task", array, 0, array.length); //同步执行任务 pool.invoke(task); //关闭线程池 pool.shutdown(); //等待线程池执行完所有任务后关闭 pool.awaitTermination(1, TimeUnit.DAYS); //检查任务是否正常完成了 //因为最初元素都为0,正确自增之后应该为1 for (int i = 0; i < array.length; i++) { if (array[i] != 1) { System.out.println("Error!"); } } //打印程序结束提示语 System.out.printf("Main: End of the program/n"); } } 复制代码
之前使用过 ReentrantLock
类作为锁,在这一小节中我们会定制自己的Lock类。以 ReentrantLock
类为例,通过查看源码可以看到锁的底层是靠 AbstractOwnableSynchronizer
这一抽象类(之后简称为AQS类)的子类实现的。查看AQS类的源码,发现此类内部有一个计数器(state)和若干操作此计数器的方法,原来AQS类才是那个真正的‘锁’,之前使用过的Lock类只是在真正的锁上又进行了一层封装。当我们尝试获取锁时,其实是当前线程在尝试改变AQS类内部计数器的值,计数器的值将会以CAS操作来进行更新。如果更新失败则表示当前线程获取锁失败,这时线程会被装入CAS类内部维护的一个队列(链表实现)并不断尝试更改计数器的值,这便是我们在使用锁时看到的线程阻塞直到得到锁这一现象。另外,如果希望定制的锁具有可重入性,我们可以调用AQS类的父类 AbstractOwnableSynchronizer
中的 setExclusiveOwnerThread()
和 getExclusiveOwnerThread()
方法来设置和获取当前持有锁的线程,这样一来在线程尝试修改计数器值时,我们可以判断当前线程是否已经持有了锁并进行对应的操作。在继承AQS抽象类后,我们必须要重写 tryAcquire()
和 tryRelease()
这两个方法因为抽象类中并没有给出这两个方法的正确实现而是直接抛出了异常。
在这个范例中,我们会继承AQS类并重写其中的部分方法来实现定制的AQS类,并以此类为基础实现定制Lock类。最后我们将使用定制Lock类对象来同步代码
MyLock(定制Lock类,需要实现Lock接口):
package day08.code_08; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.AbstractQueuedSynchronizer; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class MyLock implements Lock { //定制AQS类对象 private AbstractQueuedSynchronizer sync; public MyLock() { //通过构造方法为AQS对象赋值 sync = new MyAbstractQueuedSynchronizer(); } @Override public void lock() { //调用AQS类的方法尝试修改计数器的值 //此方法内部会调用定制AQS类中的tryAcquire方法 sync.acquire(1); } @Override public void lockInterruptibly() throws InterruptedException { //调用AQS类的方法尝试修改计数器的值(可中断) //此方法内部会调用定制AQS类中的tryAcquire方法 sync.acquireInterruptibly(1); } @Override public boolean tryLock() { //尝试获取锁,如果失败直接返回不阻塞 try { return sync.tryAcquireNanos(1, 1000); } catch (InterruptedException e) { e.printStackTrace(); return false; } } @Override public boolean tryLock(long time, TimeUnit unit) throws InterruptedException { //尝试在指定时间内获取锁,如果失败直接返回不阻塞 return sync.tryAcquireNanos(1, TimeUnit.NANOSECONDS.convert(time, unit)); } @Override public void unlock() { //调用AQS类的方法尝试减少计数器的值 //此方法内部会调用定制AQS类的tryRelease方法 sync.release(1); } @Override public Condition newCondition() { //创建AQS内部类对象并返回 return sync.new ConditionObject(); } } 复制代码
MyAbstractQueuedSynchronizer(定制AQS类):
package day08.code_08; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.AbstractQueuedSynchronizer; public class MyAbstractQueuedSynchronizer extends AbstractQueuedSynchronizer { //采用原子变量作为内部计数器 private volatile AtomicInteger state; public MyAbstractQueuedSynchronizer() { //在构造方法中初始化计数器 state = new AtomicInteger(0); } @Override protected boolean tryAcquire(int arg) { //获得当前线程 Thread now = Thread.currentThread(); //判断当前线程是否为持锁线程 if (getExclusiveOwnerThread() == now) { //增加计数器的值 state.set(state.get() + arg); return true; //否则尝试增大计数器的值 } else if (state.compareAndSet(0, arg)) { //修改成功,设置当前线程为持锁线程 setExclusiveOwnerThread(now); return true; } //修改失败返回false return false; } @Override protected boolean tryRelease(int arg) { //获得当前线程 Thread now = Thread.currentThread(); //当前线程不是持锁线程就直接抛异常 if (now != getExclusiveOwnerThread()) { throw new RuntimeException("Error!"); } //得到计数器当前值 int number = state.get(); //判断减少指定参数后是否为0 if (number - arg == 0) { //为0则表示线程释放了锁,将持锁线程设置为null setExclusiveOwnerThread(null); } //减少计数器的值 return state.compareAndSet(number, number - arg); } } 复制代码
Task(任务类):
package day08.code_08; import java.util.concurrent.TimeUnit; public class Task implements Runnable { //定制锁 private MyLock lock; //任务名曾 private String name; public Task(MyLock lock, String name) { this.lock = lock; this.name = name; } @Override public void run() { //获取锁 lock.lock(); //打印获取锁的提示信息 System.out.printf("Task: %s: Take the lock/n", name); //调用hello方法,主要为了测试定制锁的可重入性 hello(); //休眠两秒 try { TimeUnit.SECONDS.sleep(2); //打印释放锁的提示信息 System.out.printf("Task: %s: Free the lock/n", name); } catch (InterruptedException e) { e.printStackTrace(); } finally { //释放锁 lock.unlock(); } } private void hello() { //获取锁 lock.lock(); //打印Hello System.out.println("Hello!"); //释放锁 lock.unlock(); } } 复制代码
main方法:
package day08.code_08; import java.util.concurrent.TimeUnit; public class Main { public static void main(String[] args) { //创建定制锁对象 MyLock lock = new MyLock(); //创建十个任务并分别开启线程执行 for (int i = 0; i < 10; i++) { Task task = new Task(lock, "Task-" + i); Thread thread = new Thread(task); thread.start(); } //主线程休眠两秒 try { TimeUnit.SECONDS.sleep(2); } catch (InterruptedException e) { e.printStackTrace(); } boolean value; //不断自旋尝试获取锁 do { try { value = lock.tryLock(1, TimeUnit.SECONDS); //获取锁失败打印相关信息 if (!value) { System.out.printf("Main: Trying to get the Lock/n"); } } catch (InterruptedException e) { e.printStackTrace(); value = false; } } while (!value); //打印成功获取锁信息 System.out.println("Main: Got the lock"); //释放锁 lock.unlock(); //打印程序结束信息 System.out.println("Main: End of the program"); } } 复制代码