package com.easesource.iot.springbootapps.datacenter.test;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;

/**
 * Suppresses duplicate concurrent work for the same key, modeled on Go's
 * <a href="https://pkg.go.dev/golang.org/x/sync/singleflight">golang.org/x/sync/singleflight</a>.
 *
 * <p>Intended to prevent cache stampedes: when many threads miss the cache for the same key at
 * once, only the first caller (the "leader") runs the loader, so the backing database is hit once.
 * The other callers block until the leader finishes and then receive the same result, or the same
 * exception.
 *
 * <p>This class is thread-safe. Results are not cached: once a call completes, the next
 * {@link #Do} for that key runs its supplier again.
 *
 * @param <T> type of value produced by the supplier
 * @author loafer
 */
public class SingleGroup<T> {
    /** Guards {@link #m}; it is never held while a supplier runs. */
    private final ReentrantLock mu = new ReentrantLock();
    /** In-flight calls by key. Only accessed while holding {@link #mu}. */
    private final Map<String, SingleCall<T>> m = new HashMap<>();

    /**
     * Runs {@code func} for {@code key}, or joins a call for the same key that is already in flight.
     *
     * @param key  identifies the work being de-duplicated (e.g. the cache key)
     * @param func loader; invoked only by the caller that finds no in-flight call for {@code key}
     * @return the value produced by {@code func} (may be {@code null} if {@code func} returns null)
     * @throws InterruptedException if a joining caller is interrupted while waiting for the leader
     * @throws Exception            whatever {@code func} threw. Joining callers receive the same
     *                              instance; a checked throwable smuggled out of the supplier is
     *                              wrapped in a {@link RuntimeException} for joining callers
     */
    public T Do(String key, Supplier<T> func) throws Exception {
        SingleCall<T> inFlight;
        SingleCall<T> newCall = null;
        mu.lock();
        try {
            // Look up and register under ONE lock hold. Reading the HashMap after unlocking would
            // race with the leader's remove() and could return null.
            inFlight = m.get(key);
            if (inFlight == null) {
                newCall = new SingleCall<>();
                m.put(key, newCall);
            }
        } finally {
            mu.unlock();
        }
        if (inFlight != null) {
            return inFlight.waitAndCall();
        }
        return runAsLeader(key, func, newCall);
    }

    /** Runs {@code func}, publishes its outcome to waiters on {@code call}, then deregisters it. */
    private T runAsLeader(String key, Supplier<T> func, SingleCall<T> call) {
        try {
            T result = func.get();
            call.setVal(result);
            return result;
        } catch (Throwable t) {
            // Record the failure so waiters rethrow it instead of silently receiving null.
            call.setError(t);
            throw t;
        } finally {
            // Always release waiters, even on failure, so nobody blocks forever.
            call.down();
            mu.lock();
            try {
                // Remove only our own entry.
                m.remove(key, call);
            } finally {
                mu.unlock();
            }
        }
    }

}

/**
 * One in-flight invocation tracked by {@link SingleGroup}: a one-shot latch plus the outcome
 * (a value or a throwable) that the leader records before releasing the waiters.
 *
 * <p>The fields need not be volatile. They are written before {@link #down()} and read after
 * {@link CountDownLatch#await()}, and the latch provides the happens-before edge between the two.
 */
class SingleCall<T> {
    /** Count of 1: the latch opens exactly once, when the leader finishes. */
    private final CountDownLatch wg = new CountDownLatch(1);
    /** Value produced by the leader; meaningful only when {@link #err} is null. */
    private T val;
    /** Failure thrown by the leader's supplier, or null on success. */
    private Throwable err;

    /**
     * Blocks until the leader finishes, then returns its value or rethrows its failure.
     *
     * @return the leader's value
     * @throws InterruptedException if interrupted while waiting
     */
    public T waitAndCall() throws InterruptedException {
        wg.await();
        if (err == null) {
            return this.val;
        }
        if (err instanceof RuntimeException) {
            throw (RuntimeException) err;
        }
        if (err instanceof Error) {
            throw (Error) err;
        }
        // Only reachable if a checked exception was sneaky-thrown from the Supplier.
        throw new RuntimeException(err);
    }

    /** Records the leader's successful result; must be called before {@link #down()}. */
    public void setVal(T val) {
        this.val = val;
    }

    /** Records the leader's failure; must be called before {@link #down()}. */
    public void setError(Throwable err) {
        this.err = err;
    }

    /** Releases all threads blocked in {@link #waitAndCall()}. */
    public void down() {
        wg.countDown();
    }

}
