
Implementing Token Bucket rate limiting with redis+lua

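While building rate limiting for our gateway, we did some research and decided to implement the Token Bucket algorithm with redis+lua as a distributed rate limiter.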

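The Token Bucket algorithm requires tokens to be added at a constant rate, which can be done with a timer. But the gateway sees a very large number of source IPs, so there are many buckets, and keeping all of them topped up in real time would take many threads and timers. That is not a good approach.
A different idea: tokens are produced at a constant rate and there is a time gap between two consecutive acquisitions, so the number of tokens that should have been produced during the gap can be computed from the time difference. Existing tokens + newly produced tokens (capped at the bucket capacity) - tokens requested is what remains in the bucket. Each bucket therefore needs to record two things: the number of remaining tokens and the timestamp of the last acquisition. An acquisition passes in the current timestamp and the number of tokens wanted (1 by default). Updating these two fields must be atomic, and the acquisition itself involves computation.
redis processes requests serially on a single thread, and an EVAL command executes atomically. So redis plus a lua script meets exactly these requirements.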
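To make the time-difference idea concrete, here is a minimal in-memory sketch in Java. It is not the gateway's code: the class name `LazyTokenBucket` and its method are hypothetical, and `synchronized` merely stands in for the atomicity that redis provides through EVAL.

```java
/** In-memory sketch of lazy refill: tokens are derived from elapsed time, no timer thread. */
final class LazyTokenBucket {
    private final long capacity;          // maximum number of tokens the bucket holds
    private final long intervalPerPermit; // milliseconds needed to produce one token
    private long tokensRemaining;         // tokens left after the last acquisition
    private long lastRefillTime;          // timestamp (ms) up to which tokens were credited

    LazyTokenBucket(long capacity, long intervalPerPermit, long now) {
        if (capacity <= 0 || intervalPerPermit <= 0) {
            throw new IllegalArgumentException("capacity and intervalPerPermit must be positive");
        }
        this.capacity = capacity;
        this.intervalPerPermit = intervalPerPermit;
        this.tokensRemaining = capacity; // a new bucket starts full
        this.lastRefillTime = now;
    }

    /** Tries to take {@code permits} tokens at time {@code now} (ms); returns true on success. */
    synchronized boolean tryAcquire(long now, long permits) {
        if (now > lastRefillTime) {
            long produced = (now - lastRefillTime) / intervalPerPermit;
            if (produced > 0) {
                tokensRemaining = Math.min(capacity, tokensRemaining + produced);
                // advance only by whole token intervals so the leftover
                // milliseconds still count towards the next token
                lastRefillTime += produced * intervalPerPermit;
            }
        }
        if (tokensRemaining < permits) {
            return false;
        }
        tokensRemaining -= permits;
        return true;
    }

    public static void main(String[] args) {
        // capacity 2, one token every 500 ms, created at t = 0
        LazyTokenBucket bucket = new LazyTokenBucket(2, 500, 0);
        for (long t : new long[] {0, 1, 2, 500, 999, 1000}) {
            System.out.println("t=" + t + " -> " + bucket.tryAcquire(t, 1));
        }
    }
}
```

With these inputs the calls at t = 0, 1, 500 and 1000 succeed, while those at t = 2 and 999 are rejected because no new token has been produced yet.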

The script is as follows:

-- bucket name
local key = KEYS[1]
-- token generate interval
local intervalPerPermit = tonumber(ARGV[1])
-- grant timestamp
local refillTime = tonumber(ARGV[2])
-- limit token count
local limit = tonumber(ARGV[3])
-- ratelimit time period
local interval = tonumber(ARGV[4])

-- NOTE: the positional reads below assume HGETALL returns lastRefillTime before
-- tokensRemaining; HMGET with explicit field names would be more robust.
local counter = redis.call('hgetall', key)

if #counter == 0 then
    -- the bucket does not exist yet: create it with full capacity and take one token
    redis.call('hmset', key, 'lastRefillTime', refillTime, 'tokensRemaining', limit - 1)
    -- interval is in milliseconds, so use PEXPIRE; expiring idle buckets saves memory
    redis.call('pexpire', key, interval)
    return 1
elseif #counter == 4 then
    -- the bucket exists: refill it first
    local lastRefillTime, tokensRemaining = tonumber(counter[2]), tonumber(counter[4])
    local currentTokens
    if refillTime > lastRefillTime then
        local intervalSinceLast = refillTime - lastRefillTime
        if intervalSinceLast > interval then
            -- idle for longer than a whole period: the bucket is full again
            currentTokens = limit
            redis.call('hset', key, 'lastRefillTime', refillTime)
        else
            local grantedTokens = math.floor(intervalSinceLast / intervalPerPermit)
            if grantedTokens > 0 then
                -- adjust lastRefillTime: shift it back by the remainder so the
                -- partial interval still counts towards the next token
                local padMillis = math.fmod(intervalSinceLast, intervalPerPermit)
                redis.call('hset', key, 'lastRefillTime', refillTime - padMillis)
            end
            currentTokens = math.min(grantedTokens + tokensRemaining, limit)
        end
    else
        -- a call carrying a later timestamp was processed first,
        -- so there is nothing to refill
        currentTokens = tokensRemaining
    end

    assert(currentTokens >= 0)

    -- refresh the TTL on every access so that only idle buckets expire
    redis.call('pexpire', key, interval)

    if currentTokens == 0 then
        -- no token was consumed
        redis.call('hset', key, 'tokensRemaining', currentTokens)
        return 0
    else
        -- take 1 token from the bucket
        redis.call('hset', key, 'tokensRemaining', currentTokens - 1)
        return 1
    end
else
    error("Size of counter is " .. #counter .. ", should be 0 or 4.")
end

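Calling the script takes one key and four arguments: key is the bucket name, intervalPerPermit is the time between two generated tokens, refillTime is the timestamp at the moment of the call, limit is the maximum number of requests allowed per interval (the bucket capacity), and interval is the rate-limit period. All times are in milliseconds. A possible key scheme is "ratelimit:$IP:$PATH:$INTERVAL". If the bucket does not exist, one is created holding the maximum number of tokens. The key also gets an expiry, so buckets that are no longer used are eventually evicted and memory use goes down. A return value of 1 means a token was acquired; 0 means no token was available (the request should be rejected).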

Suppose a given IP may call a given endpoint at most 2 times per second. Then:
interval = 1000 (milliseconds)
limit = 2
intervalPerPermit = interval/limit = 500
key = "ratelimit:127.0.0.1:/user/get:1000"
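The derivation above can be sketched as a small Java helper. The class `RateLimitParams` and its methods are hypothetical names used only for illustration. The guard against a zero intervalPerPermit is an addition of this sketch: with integer division, a limit larger than the interval in milliseconds would yield 0, and the script's refill arithmetic would then divide by zero.

```java
/** Sketch: derive the script's key and arguments from a "limit per interval" rule. */
final class RateLimitParams {
    final String key;
    final long intervalPerPermit;
    final long limit;
    final long intervalMillis;

    private RateLimitParams(String key, long intervalPerPermit, long limit, long intervalMillis) {
        this.key = key;
        this.intervalPerPermit = intervalPerPermit;
        this.limit = limit;
        this.intervalMillis = intervalMillis;
    }

    static RateLimitParams of(String ip, String path, long intervalMillis, long limit) {
        if (intervalMillis <= 0 || limit <= 0) {
            throw new IllegalArgumentException("interval and limit must be positive");
        }
        long intervalPerPermit = intervalMillis / limit; // integer division
        if (intervalPerPermit == 0) {
            // more than one token per millisecond cannot be expressed in whole milliseconds
            throw new IllegalArgumentException("limit is too large for this interval");
        }
        String key = String.format("ratelimit:%s:%s:%d", ip, path, intervalMillis);
        return new RateLimitParams(key, intervalPerPermit, limit, intervalMillis);
    }

    /** Arguments in the order the script reads them: intervalPerPermit, refillTime, limit, interval. */
    String[] argv(long nowMillis) {
        return new String[] {
            String.valueOf(intervalPerPermit),
            String.valueOf(nowMillis),
            String.valueOf(limit),
            String.valueOf(intervalMillis)
        };
    }

    public static void main(String[] args) {
        RateLimitParams p = RateLimitParams.of("127.0.0.1", "/user/get", 1000, 2);
        System.out.println(p.key + " , " + String.join(" ", p.argv(1520402606803L)));
    }
}
```

For the example rule this prints `ratelimit:127.0.0.1:/user/get:1000 , 500 1520402606803 2 1000`, that is, the key, a comma, and the four arguments in the order the script expects.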

Calling it from the command line:

redis-cli -h localhost -p 6379 -a wingyiu --ldb --eval ~/Git/api-gateway/src/main/resources/ratelimit_token_bucket.lua ratelimit:127.0.0.1:/user/get:1000 , 500 1520402606803 2 1000

The gateway uses spring cloud netflix zuul, so rate limiting can be implemented as a filter that calls redis and executes the Lua script above:

package com.yunzhijia.platform.api.gateway.zuul.filter;

import com.google.common.net.InetAddresses;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import com.xxx.gateway.vo.SimpleRouteVo;
import com.xxx.gateway.common.constants.CacheConstants;
import com.xxx.gateway.common.constants.ErrorMsgConstants;
import com.xxx.gateway.exception.GatewayException;
import com.xxx.gateway.utils.RedisTemplateUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.util.Collections;

/**
 * Currently only IP + URI rate limiting is implemented.
 * Based on the token bucket algorithm;
 * the bucket is updated atomically through redis+lua.
 */
@Component
public class RateLimitPreFilter extends ZuulFilter {

	protected static final Logger LOG = LoggerFactory.getLogger(RateLimitPreFilter.class);

	@Override
	public String filterType() {
		return "pre";
	}

	@Override
	public int filterOrder() {
		return 4;
	}

	@Override
	public boolean shouldFilter() {
		return true;
	}

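	// the Lua script above, registered as a bean elsewhere (for example as a DefaultRedisScript<Long>)
	@Autowired
	RedisScript<Long> redisScript;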

	@Override
	public Object run() {
		RequestContext ctx = RequestContext.getCurrentContext();
		HttpServletRequest request = ctx.getRequest();

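		String uri = request.getRequestURI();
		LOG.info("rate limit check started: {}", uri);
		// Resolve the client IP. Note that X-Forwarded-For is client-supplied
		// unless a trusted proxy in front of the gateway overwrites it.
		String ip = request.getHeader("X-Forwarded-For");
		if (StringUtils.isBlank(ip)) {
			ip = request.getRemoteAddr();
		}
		// the header may hold a comma-separated chain; the first entry is the original client
		ip = StringUtils.substringBefore(ip, ",").trim();
		LOG.info("remote ip: {}",  ip);

		if (isSiteLocalAddress(ip)) {
			LOG.info("skipping private address: {}", ip);
			return null;
		}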

		SimpleRouteVo routeVo = (SimpleRouteVo) ctx.get("methodDefinition");

		if (routeVo.getReqPerSec() <= 0 && routeVo.getReqPerMin() <= 0) {
			// no rate limit configured for this route
			LOG.info("endpoint {} is not rate limited", uri);
			return null;
		}

		if (routeVo.getReqPerSec() > 0) {
			boolean canAccess = access(ip, uri, 1000, routeVo.getReqPerSec());
			if (!canAccess) {
				LOG.info("rate limited: ip={}, uri={}, interval=sec, limit={}", ip, uri, routeVo.getReqPerSec());
				ctx.setSendZuulResponse(false);
				// raise exception to jump to error filter
				throw new GatewayException(ErrorMsgConstants.TOO_MANY_REQUESTS);
			}
		}

		if (routeVo.getReqPerMin() > 0) {
			boolean canAccess = access(ip, uri, 60000, routeVo.getReqPerMin());
			if (!canAccess) {
				LOG.info("rate limited: ip={}, uri={}, interval=min, limit={}", ip, uri, routeVo.getReqPerMin());
				ctx.setSendZuulResponse(false);
				// raise exception to jump to error filter
				throw new GatewayException(ErrorMsgConstants.TOO_MANY_REQUESTS);
			}
		}

		return null;
	}

	private boolean isSiteLocalAddress(String ip) {
		// refer to RFC 1918
		// 10/8 prefix
		// 172.16/12 prefix
		// 192.168/16 prefix
		int address = InetAddresses.coerceToInteger(InetAddresses.forString(ip));
		return (((address >>> 24) & 0xFF) == 10)
				|| ((((address >>> 24) & 0xFF) == 172)
				&& ((address >>> 16) & 0xFF) >= 16
				&& ((address >>> 16) & 0xFF) <= 31)
				|| ((((address >>> 24) & 0xFF) == 192)
				&& (((address >>> 16) & 0xFF) == 168));
	}

	public boolean access(String ip, String uri, long intervalInMills, long limit) {
		LOG.info(redisScript.getSha1());
		String key = genKey(ip, uri, intervalInMills, limit);
		key = CacheConstants.genKey(key);
		long intervalPerPermit = intervalInMills / limit;

		try {
			RedisTemplate redisTemplate = RedisTemplateUtils.getRedisTemplate();
			// read the clock once so the logged timestamp is the one the script receives
			String refillTime = String.valueOf(System.currentTimeMillis());
			LOG.info("executing lua script on redis, {} {} {} {} {}", key, intervalPerPermit, refillTime,
					limit, intervalInMills);
			Long res = (Long) redisTemplate.execute(redisScript, Collections.singletonList(key),
					String.valueOf(intervalPerPermit), refillTime,
					String.valueOf(limit), String.valueOf(intervalInMills));
			LOG.info("lua script returned: {}", res);
			// a missing result is treated like an error below: fail open
			return res == null || res == 1L;

		} catch (Exception e) {
			LOG.error("failed to execute lua script on redis", e);
			return true; // fail open: on an internal error let the request through rather than block traffic
		}
	}

	private String genKey(String ip, String uri, long intervalInMills, long limit) {
		return String.format("ratelimit:%s:%s:%s", ip, uri, intervalInMills);
	}
}
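The filter above picks the client IP from the first X-Forwarded-For entry and skips RFC 1918 addresses. The following standalone sketch shows the same two steps using only the JDK. The class `ClientIp` and its methods are hypothetical; the dotted-quad pre-check is an assumption of this sketch, added so that `InetAddress.getByName` is only handed address literals and never hostnames taken from a request header.

```java
import java.net.InetAddress;
import java.net.UnknownHostException;

/** Sketch: resolve the client IP and detect RFC 1918 private addresses. */
final class ClientIp {

    /** Returns the first X-Forwarded-For entry, or remoteAddr if the header is missing or empty. */
    static String resolve(String xForwardedFor, String remoteAddr) {
        if (xForwardedFor == null) {
            return remoteAddr;
        }
        int comma = xForwardedFor.indexOf(',');
        String first = (comma >= 0 ? xForwardedFor.substring(0, comma) : xForwardedFor).trim();
        return first.isEmpty() ? remoteAddr : first;
    }

    /** True for IPv4 literals in 10/8, 172.16/12 or 192.168/16; false for anything else. */
    static boolean isPrivate(String ip) {
        // only accept dotted-quad literals so that no hostname lookup is attempted
        if (ip == null || !ip.matches("\\d{1,3}(\\.\\d{1,3}){3}")) {
            return false;
        }
        try {
            return InetAddress.getByName(ip).isSiteLocalAddress();
        } catch (UnknownHostException e) {
            return false; // not a valid address literal
        }
    }

    public static void main(String[] args) {
        String ip = resolve("203.0.113.7, 10.0.0.1", "10.0.0.2");
        System.out.println(ip + " private=" + isPrivate(ip));
    }
}
```

Because the header can be set by the caller, a deployment would normally honor it only when a trusted proxy in front of the gateway overwrites it; otherwise a client could claim a private address and bypass the limit.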

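This implements two kinds of limits, [IP, second, endpoint] and [IP, minute, endpoint]; the actual limit values are configured in the admin console.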

A question to leave you with:
Does the code above work correctly in Redis Cluster mode?

References:
https://zhuanlan.zhihu.com/p/20872901
https://github.com/YigWoo/toys/tree/master/src/main/java/com/yichao/woo/ratelimiter
