package common.base.tools.limiter;

import com.google.common.cache.Cache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.atomic.AtomicLong;

/**
 * Rate limiter that combines the per-key count limit of {@link RateCountLimiter} with an optional
 * escalation to a block.
 *
 * <p>When blocking is enabled (both {@code blockTimeSeconds} and {@code blockLimit} are positive),
 * the limiter tracks, per key, how many consecutive periods the key has been reported out of limit
 * by the count limiter. Once that streak reaches {@code blockLimit}, the key is added to a
 * {@link RateBlockLimiter}, and while it is blocked every request for it is rejected without
 * consulting the count limiter.
 */
public class RateCountAndBlockLimiter extends RateCountLimiter {
    private static final Logger logger = LoggerFactory.getLogger(RateCountAndBlockLimiter.class);

    /** Per-key streak of consecutive out-of-limit periods; {@code null} when blocking is disabled. */
    private final Cache<String, RateLimiterContinuousStat> lastOOLCache;
    /** Limiter holding blocked keys; {@code null} when blocking is disabled. */
    private final RateBlockLimiter blockLimiter;
    /** Whether block escalation is active; always equal to {@code blockLimiter != null}. */
    private final boolean isSupportBlock;
    /** Streak length (in consecutive out-of-limit periods) that triggers a block. */
    private final int blockLimit;

    /**
     * Creates a count limiter with optional block escalation.
     *
     * @param limiterName      base name; suffixed with "-COUNT", "-STAT" and "-BLOCK" for the
     *                         underlying count limiter, streak cache and block limiter
     * @param timeUnitSeconds  length of one counting period, passed to {@link RateCountLimiter}
     * @param countLimit       count limit passed to {@link RateCountLimiter}
     * @param blockTimeSeconds passed to {@link RateBlockLimiter} (presumably the block duration);
     *                         blocking is disabled unless this is positive
     * @param blockLimit       number of consecutive out-of-limit periods that triggers a block;
     *                         blocking is disabled unless this is positive
     */
    public RateCountAndBlockLimiter(String limiterName,
                                    int timeUnitSeconds, int countLimit,
                                    int blockTimeSeconds, int blockLimit) {
        super(limiterName + "-COUNT", timeUnitSeconds, countLimit);
        this.blockLimit = blockLimit;

        if (blockTimeSeconds > 0 && blockLimit > 0) {
            // NOTE(review): second argument looks like an expiry in seconds covering blockLimit + 1
            // periods, so a streak entry outlives the window it must span -- confirm in
            // RateLimiterCacheTools.
            lastOOLCache = RateLimiterCacheTools.createCache(limiterName + "-STAT", (blockLimit + 1) * timeUnitSeconds);
            blockLimiter = new RateBlockLimiter(limiterName + "-BLOCK", blockTimeSeconds);
        } else {
            lastOOLCache = null;
            blockLimiter = null;
        }
        isSupportBlock = blockLimiter != null;
    }

    /**
     * Checks whether a request for {@code key} must be rejected.
     *
     * <p>The request is rejected if the key is currently blocked, or if the parent count limiter
     * reports it out of limit. In the latter case, when blocking is enabled, the key's streak of
     * consecutive out-of-limit periods is updated, and the key is blocked once the streak reaches
     * {@code blockLimit}.
     *
     * @param key the rate-limited key
     * @return {@code true} if the request must be rejected, {@code false} if it is allowed
     */
    public boolean isOutOfLimit(String key) {
        // A blocked key is rejected without consulting the count limiter at all.
        if (isSupportBlock && blockLimiter.isOutOfLimit(key)) {
            return true;
        }

        if (!super.isOutOfLimit(key)) {
            return false;
        }

        if (isSupportBlock) {
            int count = getContinuousCount(key);
            // This runs for every rejected request; keep it below INFO so logs are not flooded
            // precisely when the limiter is under heavy load.
            logger.debug("count:{} limit:{}", count, blockLimit);
            if (count >= blockLimit) {
                blockLimiter.addBlock(key, new AtomicLong(0L));
            }
        }
        return true;
    }

    /**
     * Records that {@code key} is out of limit in the current period and returns the length of
     * its streak of consecutive out-of-limit periods, including the current one.
     *
     * <p>Repeated calls within the same period do not lengthen the streak. The streak grows by one
     * when the previously recorded period immediately precedes the current one, and restarts at 1
     * otherwise. Because this method updates the recorded streak, it should only be called for
     * requests that actually exceeded the count limit.
     *
     * @param key the rate-limited key
     * @return the streak length (at least 1), or 0 if blocking is disabled
     */
    public int getContinuousCount(String key) {
        if (!isSupportBlock) {
            return 0;
        }

        long currentPeriod = RateLimiterTools.getPeriodStartTime(timeUnitSeconds);
        RateLimiterContinuousStat previous = lastOOLCache.getIfPresent(key);
        if (previous == null) {
            lastOOLCache.put(key, new RateLimiterContinuousStat(currentPeriod, 1));
            return 1;
        }

        // Snapshot the cached entry once so the decision is made on a single consistent view.
        long lastCycle = previous.getLastCycle();
        int lastCount = previous.getCount();
        if (lastCycle == currentPeriod) {
            return lastCount;
        }

        // Assumes getPeriodStartTime returns seconds, so adjacent periods differ by exactly
        // timeUnitSeconds -- TODO confirm against RateLimiterTools.
        int count = (currentPeriod - lastCycle == timeUnitSeconds) ? lastCount + 1 : 1;

        // Publish a fresh entry rather than mutating the shared one in place. A concurrent caller
        // that already fetched the old entry must never observe a half-updated (lastCycle, count)
        // pair, which could otherwise reset a live streak to 1.
        lastOOLCache.put(key, new RateLimiterContinuousStat(currentPeriod, count));
        return count;
    }
}
