package common.web.tools.filter.token;

import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import common.web.tools.http.HttpTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Set;

import static common.web.tools.filter.token.AuthTokenConstants.*;

/**
 * Base filter that authenticates a request by token before letting it continue down the chain.
 *
 * <p>Requests whose API path matches an entry of the exclusion set bypass verification. Every
 * other request is verified through {@link #verify(HttpServletRequest)}; a non-zero result is
 * treated as an error code and routed to {@link #onError(HttpServletRequest, HttpServletResponse)}.
 *
 * <p>Thread-safety: servlet containers normally share one filter instance between concurrent
 * requests. The error code of the request being processed is therefore carried in a request
 * attribute. The {@link #errCode} field and the {@link #getHost()} / {@link #getUid()} values are
 * still updated for backward compatibility, but they only reflect the most recently processed
 * request and must not be relied upon under concurrent load.
 *
 * <p>Created by Frank.Huang on 2016/7/13.
 */
public abstract class AbstractTokenFilter extends OncePerRequestFilter {
    private static final Logger logger = LoggerFactory.getLogger(AbstractTokenFilter.class);

    /** Request attribute holding the verification error code of the current request. */
    private static final String REQUEST_ERR_CODE_ATTR = AbstractTokenFilter.class.getName() + ".errCode";

    /** Path the request is forwarded to on error; when null or empty, {@link #errorHandler} is used. */
    protected String errorUrl = "";
    /** Exclusion patterns; see {@link #isExclude(String, String)} for the matching rules. */
    protected Set<String> excludedSet = Sets.newHashSet("api/ping2", "api/ns");
    /**
     * Error code of the most recently verified request. Kept for subclasses that read it;
     * shared between threads, so the per-request attribute is the authoritative value.
     */
    protected int errCode = 0;
    private String host;
    private String uid;

    /**
     * Lets excluded and successfully verified requests continue down the chain; any other
     * request is handed to {@link #onError(HttpServletRequest, HttpServletResponse)}.
     *
     * @param httpServletRequest  the current request
     * @param httpServletResponse the current response
     * @param filterChain         the remaining filter chain
     * @throws ServletException propagated from the chain or the error handling
     * @throws IOException      propagated from the chain or the error handling
     */
    @Override
    protected void doFilterInternal(HttpServletRequest httpServletRequest,
                                    HttpServletResponse httpServletResponse,
                                    FilterChain filterChain) throws ServletException, IOException {

        String apiPath = HttpTools.getApiPath(httpServletRequest);
        if (isExclude(apiPath)) {
            filterChain.doFilter(httpServletRequest, httpServletResponse);
            return;
        }

        // Keep the result in a local: the filter instance is shared, the local is not.
        int code = verify(httpServletRequest);
        // Legacy mirror for subclasses that still read the field directly.
        errCode = code;
        if (code == 0) {
            filterChain.doFilter(httpServletRequest, httpServletResponse);
            return;
        }

        // Bind the code to this request so onError cannot pick up another request's code.
        httpServletRequest.setAttribute(REQUEST_ERR_CODE_ATTR, code);
        onError(httpServletRequest, httpServletResponse);
    }


    /**
     * @return the module code passed to {@link #errorHandler} as {@code retCode}
     */
    public int getModuleCode() {
        return TOKEN_MODEL_CODE;
    }

    /**
     * Extracts the token from the request.
     *
     * @param httpRequest the current request
     * @return the token; a null or empty value makes verification fail with
     *         {@code ERR_TOKEN_IS_EMPTY}
     */
    public abstract String getToken(HttpServletRequest httpRequest);

    /**
     * Extracts the user id from the request.
     *
     * @param httpRequest the current request
     * @return the user id; null is treated as an empty string
     */
    public abstract String getUid(HttpServletRequest httpRequest);

    /**
     * Checks whether the token is valid.
     *
     * @param httpRequest the current request
     * @param token       the non-empty token returned by {@link #getToken(HttpServletRequest)}
     * @return 0 when the token is accepted, otherwise an error code
     */
    public abstract int verifyToken(HttpServletRequest httpRequest, String token);

    /**
     * Writes the error response when no error URL is configured.
     *
     * @param httpRequest  the current request
     * @param httpResponse the current response
     * @param retCode      the value of {@link #getModuleCode()}
     * @param errCode      the verification error code of the request
     */
    public abstract void errorHandler(HttpServletRequest httpRequest,
                                      HttpServletResponse httpResponse,
                                      int retCode, int errCode);

    /**
     * Error handling for a request that failed token verification.
     *
     * <p>Without an error URL the response is produced by {@link #errorHandler}; otherwise the
     * error code and the API path are stored as request attributes and the request is forwarded
     * to the error URL.
     *
     * @param httpServletRequest  the current request
     * @param httpServletResponse the current response
     * @throws ServletException propagated from the forward
     * @throws IOException      propagated from the forward
     */
    public void onError(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws ServletException, IOException {
        int code = resolveErrCode(httpServletRequest);
        if (Strings.isNullOrEmpty(errorUrl)) {
            // No error URL configured: let the subclass write the error response directly.
            errorHandler(httpServletRequest, httpServletResponse, getModuleCode(), code);
        } else {
            // Forward to the error controller.
            httpServletRequest.setAttribute(TOKEN_ATTR_ERR_CODE, code);
            httpServletRequest.setAttribute(TOKEN_ATTR_ERR_REFER, HttpTools.getApiPath(httpServletRequest));
            httpServletRequest.getServletContext().getRequestDispatcher(errorUrl).forward(httpServletRequest, httpServletResponse);
        }
    }

    /**
     * Returns the error code bound to the request by {@code doFilterInternal}, falling back to
     * the {@link #errCode} field when the attribute is absent (e.g. a direct call to onError).
     */
    private int resolveErrCode(HttpServletRequest httpServletRequest) {
        Object bound = httpServletRequest.getAttribute(REQUEST_ERR_CODE_ATTR);
        if (bound instanceof Integer) {
            return (Integer) bound;
        }
        return errCode;
    }


    /**
     * Verifies the token of the request.
     *
     * <p>Also records the remote host and user id through {@link #setHost(String)} and
     * {@link #setUid(String)}.
     *
     * @param httpRequest the current request
     * @return 0 on success, {@code ERR_TOKEN_IS_EMPTY} when no token is present, otherwise the
     *         code returned by {@link #verifyToken(HttpServletRequest, String)}
     */
    public int verify(HttpServletRequest httpRequest) {
        // Log from locals: the host/uid fields can be overwritten by a concurrent request.
        String remoteHost = HttpTools.getRemoteHost(httpRequest);
        setHost(remoteHost);
        String userId = Strings.nullToEmpty(getUid(httpRequest));
        setUid(userId);

        // Get the token.
        String token = getToken(httpRequest);
        if (Strings.isNullOrEmpty(token)) {
            logger.error("[AUTH-TOKEN][{}][{}]ERR_TOKEN_IS_EMPTY", remoteHost, userId);
            return ERR_TOKEN_IS_EMPTY;
        }

        // Verify the token.
        int ret = verifyToken(httpRequest, token);
        if (ret != 0) {
            // Same argument order (host, uid) as the empty-token message above.
            logger.error("[AUTH-TOKEN][{}][{}]ERR_TOKEN_VERIFY_FAILED", remoteHost, userId);
        }

        return ret;
    }


    public String getErrorUrl() {
        return errorUrl;
    }

    public void setErrorUrl(String errorUrl) {
        this.errorUrl = errorUrl;
    }


    /**
     * Adds an exclusion pattern; null or empty values are ignored.
     *
     * @param excluded the pattern to add
     */
    public void addExcluded(String excluded) {
        if (!Strings.isNullOrEmpty(excluded)) {
            excludedSet.add(excluded);
        }
    }

    public Set<String> getExcludedSet() {
        return excludedSet;
    }

    /**
     * Replaces the exclusion set.
     *
     * @param excludedSet the new patterns; null is stored as an empty set so that later
     *                    lookups and {@link #addExcluded(String)} calls do not fail
     */
    public void setExcludedSet(Set<String> excludedSet) {
        if (excludedSet == null) {
            this.excludedSet = Sets.newHashSet();
        } else {
            this.excludedSet = excludedSet;
        }
    }

    /**
     * @return the remote host recorded by the most recent {@link #verify(HttpServletRequest)} call
     */
    public String getHost() {
        return host;
    }

    public void setHost(String host) {
        this.host = host;
    }

    /**
     * @return the user id recorded by the most recent {@link #verify(HttpServletRequest)} call
     */
    public String getUid() {
        return uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    /**
     * Checks whether the API path matches any configured exclusion pattern.
     *
     * @param apiPath the API path of the current request
     * @return true when at least one pattern matches
     */
    private boolean isExclude(String apiPath) {
        Set<String> patterns = getExcludedSet();
        if (patterns == null) {
            // An overriding getExcludedSet() may return null: nothing is excluded then.
            return false;
        }

        for (String e : patterns) {
            if (isExclude(e, apiPath)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Checks whether the API path matches a single exclusion pattern.
     *
     * <p>The part of the pattern before the first {@code *} (or the whole pattern when it has
     * no {@code *}) is used as a prefix of the API path. A pattern with an empty prefix matches
     * every path.
     *
     * @param excludedPath the exclusion pattern; null matches nothing
     * @param apiPath      the API path to test; null only matches a pattern with an empty prefix
     * @return true when the path is excluded by the pattern
     */
    public boolean isExclude(String excludedPath, String apiPath) {
        if (excludedPath == null) {
            return false;
        }

        // With a wildcard "*", the part before the asterisk is the prefix.
        int pos = excludedPath.indexOf("*");
        String excludePrefix = excludedPath;
        if (pos >= 0) {
            excludePrefix = excludedPath.substring(0, pos);
        }

        // An empty prefix means everything is allowed.
        if (Strings.isNullOrEmpty(excludePrefix)) {
            return true;
        }

        return apiPath != null && apiPath.startsWith(excludePrefix);
    }
}
