并发类工具库的线程安全问题

Posted by Haiming on April 5, 2022

发现一个非常好的教程:https://learn.lianglianglee.com/%E4%B8%93%E6%A0%8F/Java%20%E4%B8%9A%E5%8A%A1%E5%BC%80%E5%8F%91%E5%B8%B8%E8%A7%81%E9%94%99%E8%AF%AF%20100%20%E4%BE%8B

里面全是在业务相关时候的实操。

今天先讲一个并发类的使用方式:

忽略了 ThreadLocal 变量所在线程是线程池之中对象

问题:使用 ThreadLocal 来获取信息,有时候获取的信息是错误的。

原因:在使用 ThreadLocal 之后没有将其值清空,而程序是运行在 Tomcat 里面的,执行程序的线程是 Tomcat 的工作线程,其基于线程池。那么重用固定的几个线程,获取的可能是之前其他用户的遗留值。这时候所取得的信息就是错误的。

在写业务代码的时候,要弄清楚代码会跑在什么线程上。Web 服务器之中往往会使用线程池来处理请求,这意味线程会被重用。那么使用 ThreadLocal 来存放一些数据的时候,要注意在代码运行结束之后去显式的清除对应数据。

问题代码

private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);

@GetMapping("wrong")

public Map wrong(@RequestParam("userId") Integer userId) {

    //设置用户信息之前先查询一次ThreadLocal中的用户信息

    String before  = Thread.currentThread().getName() + ":" + currentUser.get();

    //设置用户信息到ThreadLocal

    currentUser.set(userId);

    //设置用户信息之后再查询一次ThreadLocal中的用户信息

    String after  = Thread.currentThread().getName() + ":" + currentUser.get();

    //汇总输出两次查询结果

    Map result = new HashMap();

    result.put("before", before);

    result.put("after", after);

    return result;

}

修正方案:在代码的 finally 之中,显式的清除 ThreadLocal 之中的数据。这样新的请求过来也不会获取到错误的用户信息。

修正代码:

@GetMapping("right")

public Map right(@RequestParam("userId") Integer userId) {

    String before  = Thread.currentThread().getName() + ":" + currentUser.get();

    currentUser.set(userId);

    try {

        String after = Thread.currentThread().getName() + ":" + currentUser.get();

        Map result = new HashMap();

        result.put("before", before);

        result.put("after", after);

        return result;

    } finally {

        //在finally代码块中删除ThreadLocal中的数据,确保数据不串

        currentUser.remove();

    }

}

线程安全的并发工具使用错误

比如 ConcurrentHashMap,其线程安全只能保证原子性的读写操作是线程安全的。如果操作不是原子性的,或者更过分的操作之中插入了一些其他操作,那么就会出现线程安全的问题。

下面这段代码之中,就是先通过 size()拿到当前的元素数量,计算需要补充多少元素之后对其 log,最后在 putAll 之中把缺少的元素添加进去。

//线程个数

private static int THREAD_COUNT = 10;

//总元素数量

private static int ITEM_COUNT = 1000;

//帮助方法,用来获得一个指定元素数量模拟数据的ConcurrentHashMap

private ConcurrentHashMap<String, Long> getData(int count) {

    return LongStream.rangeClosed(1, count)

            .boxed()

            .collect(Collectors.toConcurrentMap(i -> UUID.randomUUID().toString(), Function.identity(),

                    (o1, o2) -> o1, ConcurrentHashMap::new));

}

@GetMapping("wrong")

public String wrong() throws InterruptedException {

    ConcurrentHashMap<String, Long> concurrentHashMap = getData(ITEM_COUNT - 100);

    //初始900个元素

    log.info("init size:{}", concurrentHashMap.size());

    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);

    //使用线程池并发处理逻辑

    forkJoinPool.execute(() -> IntStream.rangeClosed(1, 10).parallel().forEach(i -> {

        //查询还需要补充多少个元素

        int gap = ITEM_COUNT - concurrentHashMap.size();

        log.info("gap size:{}", gap);

        //补充元素

        concurrentHashMap.putAll(getData(gap));

    }));

    //等待所有任务完成

    forkJoinPool.shutdown();

    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);

    //最后元素个数会是1000吗?

    log.info("finish size:{}", concurrentHashMap.size());

    return "OK";

}

输出日志为:

img

可以看到好几个线程都往里面怼了100个元素,还有的直接在 gap 里面看到了-236,意味着这个时候这个13号线程里面的元素数量已经是1236个,远远超出了范围。

首先,我们要分清,concurrentHashMap 这种线程安全是怎么个安全法:

其线程安全,指的是不同的线程使用这个类的时候不会有相互干扰的问题。比如 put(),其会先锁住对应的节点然后做 CAS,这样就导致两个线程的更改不会同时产生影响。本文之中的例子,使用的是size(),我们去源码之中看一下。

    /**
     * Returns the number of key-value mappings in this map.  If the
     * map contains more than {@code Integer.MAX_VALUE} elements, returns
     * {@code Integer.MAX_VALUE}.
     *
     * @return the number of key-value mappings in this map
     */
    public int size() {
        long n = sumCount();
        return ((n < 0L) ? 0 :
                (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
                (int)n);
    }

其中主要的方法是 sumCount():

    final long sumCount() {
        CounterCell[] cs = counterCells;
        long sum = baseCount;
        if (cs != null) {
            for (CounterCell c : cs)
                if (c != null)
                    sum += c.value;
        }
        return sum;
    }

把类之中本身存在的 counterCells 进行一个指针的 copy,然后把其中的值进行相加。这个 counterCell 本身是一个分布式的相加,不是我们的重点,所以先放在这源码:

    /**
     * A padded cell for distributing counts.  Adapted from LongAdder
     * and Striped64.  See their internal docs for explanation.
     */
    @jdk.internal.vm.annotation.Contended static final class CounterCell {
        volatile long value;
        CounterCell(long x) { value = x; }
    }

这一段代码的逻辑,并不能保证在线程池之中的每个线程调用这个服务的时候,都能够拿到一样的值。(只是对 CounterCell 里面的值进行一定的处理)。

而且代码之中,是先拿到对应的 size,打了一个 log,再进行相应的填充,很有可能在线程 A 拿值的时候,值是900,然后线程 B 拿值的时候,值也是900,A 直接做一个100的填充,B 并不知道,也填100,导致最后的结果错误。

修正代码:

效率低:加锁

把从查看gap 到最后填充值的部分全都加锁来保证线程安全

@GetMapping("right")

public String right() throws InterruptedException {

    ConcurrentHashMap<String, Long> concurrentHashMap = getData(ITEM_COUNT - 100);

    log.info("init size:{}", concurrentHashMap.size());



    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);

    forkJoinPool.execute(() -> IntStream.rangeClosed(1, 10).parallel().forEach(i -> {

        //下面的这段复合逻辑需要锁一下这个ConcurrentHashMap

        synchronized (concurrentHashMap) {

            int gap = ITEM_COUNT - concurrentHashMap.size();

            log.info("gap size:{}", gap);

            concurrentHashMap.putAll(getData(gap));

        }

    }));

    forkJoinPool.shutdown();

    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);

    log.info("finish size:{}", concurrentHashMap.size());

    return "OK";

}

结果:img

效率高:使用自带的工具类

下面是一个新的统计需求:

使用 ConcurrentHashMap 来统计,Key 的范围是 10。

使用最多 10 个并发,循环操作 1000 万次,每次操作累加随机的 Key。

如果 Key 不存在的话,首次设置值为 1。

优化之后的代码:

private Map<String, Long> gooduse() throws InterruptedException {

    ConcurrentHashMap<String, LongAdder> freqs = new ConcurrentHashMap<>(ITEM_COUNT);

    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);

    forkJoinPool.execute(() -> IntStream.rangeClosed(1, LOOP_COUNT).parallel().forEach(i -> {

        String key = "item" + ThreadLocalRandom.current().nextInt(ITEM_COUNT);

                //利用computeIfAbsent()方法来实例化LongAdder,然后利用LongAdder来进行线程安全计数

                freqs.computeIfAbsent(key, k -> new LongAdder()).increment();

            }

    ));

    forkJoinPool.shutdown();

    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);

    //因为我们的Value是LongAdder而不是Long,所以需要做一次转换才能返回

    return freqs.entrySet().stream()

            .collect(Collectors.toMap(

                    e -> e.getKey(),

                    e -> e.getValue().longValue())

            );

}

我们使用了 concurrentHashMap 的原子性方法 computeIfAbsent 来做复合的逻辑操作:

注:这个地方下面有一个思考题,就是 computeIfAbsent() 和 putIfAbsent() 有什么区别。其中 putIfAbsent()是如果本身是空那么进行填充值,但是 computeIfAbsent()则是可以先做一段计算,再将值放到 map 之中。

先上可以用 concurrentHashMap 做统计的官方文档注释:

 * <p>A ConcurrentHashMap can be used as a scalable frequency map (a
 * form of histogram or multiset) by using {@link
 * java.util.concurrent.atomic.LongAdder} values and initializing via
 * {@link #computeIfAbsent computeIfAbsent}. For example, to add a count
 * to a {@code ConcurrentHashMap<String,LongAdder> freqs}, you can use
 * {@code freqs.computeIfAbsent(key, k -> new LongAdder()).increment();}

这个函数本身返回的是 V:

public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {}

那么对于上面这一段:

freqs.computeIfAbsent(key, k -> new LongAdder()).increment();

本身 computeIfAbsent 返回的就是一个 LongAdder,那么就可以调用其中的 increment() 方法。

而文档之中说了,computeIfAbsent 是原子的,原理是 Unsafe 自带的 CAS。在 LongAdder 我们也可以看到其是一个线程安全的累加器,那么可以直接调用 increment 方法进行累加。

ForkJoinPool

FirkJpinPool 本身是一个线程池,用来将对应的任务进行分解再处理。很像算法之中的递归解法。

参考:https://www.liaoxuefeng.com/wiki/1252599548343744/1306581226487842

那么逻辑要分成三部分,第一部分是对任务进行分解,第二部分是判断任务是否已经小到足够执行,并且执行。最后一部分对结果进行一个汇总即可。

public class Main {
    public static void main(String[] args) throws Exception {
        // 创建2000个随机数组成的数组:
        long[] array = new long[2000];
        long expectedSum = 0;
        for (int i = 0; i < array.length; i++) {
            array[i] = random();
            expectedSum += array[i];
        }
        System.out.println("Expected sum: " + expectedSum);
        // fork/join:
        ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
        long startTime = System.currentTimeMillis();
        Long result = ForkJoinPool.commonPool().invoke(task);
        long endTime = System.currentTimeMillis();
        System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
    }

    static Random random = new Random(0);

    static long random() {
        return random.nextInt(10000);
    }
}

class SumTask extends RecursiveTask<Long> {
    static final int THRESHOLD = 500;
    long[] array;
    int start;
    int end;

    SumTask(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // 如果任务足够小,直接计算:
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += this.array[i];
                // 故意放慢计算速度:
                try {
                    Thread.sleep(1);
                } catch (InterruptedException e) {
                }
            }
            return sum;
        }
        // 任务太大,一分为二:
        int middle = (end + start) / 2;
        System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
        SumTask subtask1 = new SumTask(this.array, start, middle);
        SumTask subtask2 = new SumTask(this.array, middle, end);
        invokeAll(subtask1, subtask2);
        Long subresult1 = subtask1.join();
        Long subresult2 = subtask2.join();
        Long result = subresult1 + subresult2;
        System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
        return result;
    }
}

工作窃取算法(work-stealing)

参考:https://godiscoder.cn/%E6%8E%80%E4%BD%A0%E7%9F%AD%E8%A3%99/article/5d085d64a70b1d53db964e99

ForkJoinPool 之中使用的是工作窃取算法, 指的是某个线程从其他线程之中窃取任务并且执行。其主要是为了效率(一个线程如果执行完自己的任务,可以去接着执行其他任务而不是干等着)。那么如何保证多个线程不相互干扰呢?通常使用双端队列, 窃取任务和被窃取任务的线程可以从队列的两端来拿任务执行。

没认清并发工具的使用场景,导致性能问题

问题

没有搞清楚其原理就盲目的使用,比如 copyOnWriteArrayList,其本身的线程安全是因为在每次修改数据的时候,会先复制一份数据出来,所以其更适用于读多写少的场景。如果对于写比较多的场景,应该使用加锁(甚至这种情况下乐观锁都不是好方法)的方式来控制其读写。

CopyOnWriteArrayList 和普通 ArrayList 加锁的读写性能对比

//测试并发写的性能

@GetMapping("write")

public Map testWrite() {

    List<Integer> copyOnWriteArrayList = new CopyOnWriteArrayList<>();

    List<Integer> synchronizedList = Collections.synchronizedList(new ArrayList<>());

    StopWatch stopWatch = new StopWatch();

    int loopCount = 100000;

    stopWatch.start("Write:copyOnWriteArrayList");

    //循环100000次并发往CopyOnWriteArrayList写入随机元素

    IntStream.rangeClosed(1, loopCount).parallel().forEach(__ -> copyOnWriteArrayList.add(ThreadLocalRandom.current().nextInt(loopCount)));

    stopWatch.stop();

    stopWatch.start("Write:synchronizedList");

    //循环100000次并发往加锁的ArrayList写入随机元素

    IntStream.rangeClosed(1, loopCount).parallel().forEach(__ -> synchronizedList.add(ThreadLocalRandom.current().nextInt(loopCount)));

    stopWatch.stop();

    log.info(stopWatch.prettyPrint());

    Map result = new HashMap();

    result.put("copyOnWriteArrayList", copyOnWriteArrayList.size());

    result.put("synchronizedList", synchronizedList.size());

    return result;

}

//帮助方法用来填充List

private void addAll(List<Integer> list) {

    list.addAll(IntStream.rangeClosed(1, 1000000).boxed().collect(Collectors.toList()));

}

//测试并发读的性能

@GetMapping("read")

public Map testRead() {

    //创建两个测试对象

    List<Integer> copyOnWriteArrayList = new CopyOnWriteArrayList<>();

    List<Integer> synchronizedList = Collections.synchronizedList(new ArrayList<>());

    //填充数据   

    addAll(copyOnWriteArrayList);

    addAll(synchronizedList);

    StopWatch stopWatch = new StopWatch();

    int loopCount = 1000000;

    int count = copyOnWriteArrayList.size();

    stopWatch.start("Read:copyOnWriteArrayList");

    //循环1000000次并发从CopyOnWriteArrayList随机查询元素

    IntStream.rangeClosed(1, loopCount).parallel().forEach(__ -> copyOnWriteArrayList.get(ThreadLocalRandom.current().nextInt(count)));

    stopWatch.stop();

    stopWatch.start("Read:synchronizedList");

    //循环1000000次并发从加锁的ArrayList随机查询元素

    IntStream.range(0, loopCount).parallel().forEach(__ -> synchronizedList.get(ThreadLocalRandom.current().nextInt(count)));

    stopWatch.stop();

    log.info(stopWatch.prettyPrint());

    Map result = new HashMap();

    result.put("copyOnWriteArrayList", copyOnWriteArrayList.size());

    result.put("synchronizedList", synchronizedList.size());

    return result;

}

上面是十万次写操作,下面是一百万次读操作,区别如下:

十万次 add:

img

一百万次读:

img

个人后记

为什么一般 ThreadLocal 要用 static 修饰?

参考:将ThreadLocal变量设置为private static的好处是啥? - Viscent大千的回答 - 知乎 https://www.zhihu.com/question/35250439/answer/101676937

可以避免重复创建 TSO(thread specific object,也就是 ThreadLocal 所关联的对象)导致的浪费。

一个 HtreadLocal 实例对应当前线程之中的一个 TSO实例,因此,如果将 ThreadLocal 声明为某个累的实例变量,那么每一个实例都会导致一个新的 TSO 被创建。这些被创建的 TSO 是同一个类的实例,因此同一个线程可能访问到不同实例,这样会导致错误