【java】Java 1.8 fork / join

描述 java 1.8 的 fork / join 功能

概述

Fork/Join 框架是一个实现了 ExecutorService 接口的多线程处理器。它可以把一个大的任务划分为若干个小的任务并发执行,小任务执行完成后,再合并成最终结果。框架使用了工作窃取(work-stealing)算法,空闲的线程可以从满负荷的线程中窃取任务来帮忙执行。

在 JDK 1.8 中,实现 Fork/Join 的关键类:

  • ForkJoinPool:用于执行任务的线程池
  • WorkQueue:ForkJoinPool 的内部类,实现一个双向队列,用于支持工作窃取算法(内部使用)
  • ForkJoinWorkThread:执行任务的线程
  • ForkJoinTask:表示 Fork/Join 任务,实现了 Future 接口。主要使用两个子类 RecursiveTask 和 RecursiveAction

示例

要计算一个数组的所有元素的数,可以使用 fork/join。 如果数组的长度小于阈值,则直接进行计算;否则将数组分成两部分,分别计算两部分的结果,然后合并。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
public class FJTask {

public static void main(String[] args) {
long [] array = new long[400];
for (int i = 0 ; i < array.length ; ++i) {
array[i] = i+1;
}

// 创建 fork/join 线程池
ForkJoinPool fjp = new ForkJoinPool(2); // 最大并发数2
ForkJoinTask<Long> task = new SumTask(array, 0, array.length);

// 计算并记录时间
long startTime = System.currentTimeMillis();
Long result = fjp.invoke(task);
long endTime = System.currentTimeMillis();

System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}

// 求和任务
static class SumTask extends RecursiveTask<Long> {
private static final long serialVersionUID = 1L;

static final int THRESHOLD = 50;

long[] array;
int start;
int end;

SumTask(long[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}

@Override
protected Long compute() {
// 如果任务足够小,直接计算
if (end - start <= THRESHOLD) {
return sum(start, end);
} else {
// 将任务分解成多个小任务
int middle = (end + start) / 2;
SumTask subtask1 = new SumTask(array, start, middle);
SumTask subtask2 = new SumTask(array, middle, end);

// 计算
invokeAll(subtask1, subtask2);

// 合并结果
Long subresult1 = subtask1.join();
Long subresult2 = subtask2.join();

return subresult1 + subresult2;
}
}

// 求和
private long sum(int start, int end) {
long sum = 0;
for (int i = start; i < end; ++i) {
sum += array[i];
}
return sum;
}
}
}

执行过程

  1. 主线程执行 fjp.invoke(task); 后就阻塞,等待计算完成
    invoke 内部会调用 externalPush 将任务 SumTask(0-400) 放到任务队列,然后生成一个 ForkJoinWorkerThread(假设是 Thread-1)执行任务代码

  2. Thread-1 执行任务 SumTask(0-400)
    对于大于阈值的任务, SumTask 分解成两个子任务 SumTask(0-200)、SumTask(200-400)。
    在 invokeAll 方法中,会执行 SumTask(200-400) 的 fork 方法,产生线程 Thread-2 执行任务,然后执行 SumTask(0-200) 任务,待 T2 执行完成后再返回

  3. Thread-2 执行任务 SumTask(200-400)
    这时任务分配图如下, Thread-1 执行 SumTask(0-200),Thread-2 执行 SumTask(200-400)
    setp-1.png

  4. 假设在 sum 方法打断点,两个线程都执行到断点
    这时任务分配图如下:
    setp-2.png
    由于线程数已经是 2,所以 fork 方法不会产生新线程,而是将任务 push 到线程各自的工作队列。因此,到断点位置,各线程的状态是

  • Thread-1 执行 SumTask(0-50) 的求和计算, SumTask(0-400)、SumTask(0-200) 和 SumTask(0-100) 在调用栈中,SumTask(100-200) 和 SumTask(50-100) 在工作队列中
  • Thread-2 执行 SumTask(200-250) 的求和计算,SumTask(200-400) 和 SumTask(200-300) 在调用栈中,SumTask(300-400) 和 SumTask(250-300) 在工作队列中
  1. 让 Thread-2 执行,Thread-2 会依次计算
  • SumTask(200-250)、SumTask(250-300)
  • SumTask(200-300) 执行合并
  • SumTask(300-350) 、SumTask(350-400)(从 SumTask(300-400) 分解出来)
  • SumTask(300-400) 执行合并
  • SumTask(200-400) 执行合并
    到这里, Thread-2 的所有计算任务执行完成,但 Thread-1 的工作队列中还有 SumTask(50-100) 和 SumTask(100-200) 。Thread-2 会调用 ForkJoinPool 的 scan 方法从 Thread-1 的工作队列中“窃取”任务来执行,于是 Thread-2 继续计算
  • SumTask(100-150)、SumTask(150-200)(从 SumTask(100-250) 分解出来)
  • SumTask(100-250) 执行合并
  • SumTask(50-100)
    到这里 Thread-2 都没有任务可执行,Thread-2 进入等待状态
  1. 让 Thread-1 执行,Thread-1 会依次计算
  • SumTask(0-50)
  • SumTask(0-100) 执行合,其中 SumTask(50-100) 是 Thread-2 的计算结果
  • SumTask(0-200) 执行合,其中 SumTask(100-200) 是 Thread-2 的计算结果
  • SumTask(0-400) 执行合,其中 SumTask(200-400) 是 Thread-2 的计算结果
  1. 至此所有计算任务执行完成,fjp.invoke(task); 返回最后计算结果

JDK 1.8 API 参考

ForkJoinPool

ForkJoinPool 既是用于执行任务的线程池,也是用户提交任务的入口

构造方法

1
public ForkJoinPool()

构造一个并行度等于 CPU 核心个数的 ForkJoinPool,CPU 核心个数由 Runtime.getRuntime().availableProcessors() 取得,其余参数采用默认值

1
public ForkJoinPool(int parallelism)

构造一个并行度是 parallelism 的 ForkJoinPool,其余参数采用默认值

1
public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode)

构造一个并行度是 parallelism 的 ForkJoinPool,指定了 factory, handler 和 asyncMode

1
private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, int mode, String workerNamePrefix)

上面所有公共构造方法都用这个实现最终的构造

还有一个静态实例通过 ForkJoinPools.commonPool() 获取,大部分场景都可以通过这个静态实例使用 Fork/Join。

主要 public 成员方法

1
public <T> T invoke(ForkJoinTask<T> task)

提交一个 ForkJoinTask, 并等待计算完成,返回计算结果

1
public void execute(ForkJoinTask<?> task)

提交一个 ForkJoinTask,并立即返回(无结果)

1
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task)

提交一个 ForkJoinTask,返回入参 task

1
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)

将容器里的 Callable 转换为 ForkJoinTask.AdaptedCallable 并提交,返回一个 ForkJoinTask.AdaptedCallable (实现了 Future 接口)的 List

1
public boolean awaitTermination(long timeout, TimeUnit unit)

等待线程池终止。由于 commonPool() 不会终止,对 commonPool() 调用 awaitTermination 等同于调用 awaitQuiescence

1
public boolean awaitQuiescence(long timeout, TimeUnit unit)

如果在 ForkJoinTask 中调用(当前线程是 ForkJoinWorkerThread),等价于调用 helpQuiescePool,否则等待线程池静止

ForkJoinTask

表示在 ForkJoinPool 中运行的 task 的抽象基类。ForkJoinTask 是类似于线程的实体,但比线程要轻很多。大量的 task 可以被少量线程管理。

主要 public 成员方法

1
public final ForkJoinTask<V> fork()

调度任务,如果当前线程是 ForkJoinWorkerThread,则加到线程的工作队列,否则加到 ForkJoinPool.common

条件允许的话会创建新的线程来处理这个任务

1
public final V join()

返回任务的计算结果,如果计算未完成,会阻塞等待。

1
public final V invoke()

执行任务计算,如果需要的话,会阻塞等待。

1
2
3
4
5
public static void invokeAll(ForkJoinTask<?> t1, ForkJoinTask<?> t2)

public static void invokeAll(ForkJoinTask<?>... tasks)

public static <T extends ForkJoinTask<?>> Collection<T> invokeAll(Collection<T> tasks)

这几个方法都是执行多个任务