package com.valor.vod.common.web.filter.sso;

import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.valor.vod.api.model.constant.response.HttpCode2;
import com.valor.vod.common.tools.http.HttpConstant;
import com.valor.vod.common.tools.http.HttpParameterTools;
import com.valor.vod.common.tools.http.HttpTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Set;

import static com.google.common.collect.Sets.newHashSet;

/**
 * Base filter that authenticates requests with a token before passing them down the chain.
 *
 * <p>Requests whose api path is excluded (built-in paths plus {@link #getExcludedSet()}) are let
 * through unchecked. For every other request, {@link #verify} obtains the token with
 * {@link #getToken}, decrypts it with {@link #decryptToken}, converts it with
 * {@link #getTokenObj} and validates it with {@link #verifyToken}. A request is allowed only when
 * the resulting code equals {@code HttpCode2.OK}; otherwise {@link #onError} is invoked.
 *
 * <p>NOTE(review): a servlet container typically uses one filter instance for all request threads,
 * while {@link #errcode} is a mutable instance field written on every request. The allow/deny
 * decision therefore relies only on local values. The field is still written for backward
 * compatibility ({@link #onError} and subclasses read it), but the value read there may belong to
 * a concurrent request.
 *
 * Created by Frank.Huang on 2016/7/13.
 */

public abstract class AbstractSSOFilter extends OncePerRequestFilter {
    private static final Logger logger = LoggerFactory.getLogger(AbstractSSOFilter.class);
    /** Path the request is forwarded to on failure; when empty, {@link #onError} throws instead. */
    protected String errorUrl = "";
    /** Raw comma-separated exclusion list, exactly as last passed to {@link #setExcluded}. */
    protected String excluded = "";

    /** Return code reported by {@link #onError}. */
    protected int retcode = HttpCode2.RET_UNAUTHORIZED;
    /** Error code of the most recent verification (shared across requests; see class note). */
    protected int errcode = HttpCode2.OK;
    /** Message reported by {@link #onError}. */
    protected String message = "Unauthorized";
    /**
     * Excluded path patterns. An entry ending in {@code *} is a prefix match; any other entry must
     * equal the api path, ignoring case.
     */
    protected Set<String> excludedSet = newHashSet();

    public String getErrorUrl() {
        return errorUrl;
    }

    public void setErrorUrl(String errorUrl) {
        this.errorUrl = errorUrl;
    }

    public String getExcluded() {
        return excluded;
    }

    /**
     * Stores the raw exclusion list and adds each comma-separated, trimmed, non-empty entry to
     * {@link #excludedSet}. Existing entries are kept.
     *
     * @param excluded comma-separated path patterns; {@code null} or empty adds nothing
     */
    public void setExcluded(String excluded) {
        this.excluded = excluded;
        if (Strings.isNullOrEmpty(excluded)) {
            return;
        }
        try {
            // omitEmptyStrings(): an empty entry would otherwise exempt an empty api path from auth.
            excludedSet.addAll(Splitter.on(",").trimResults().omitEmptyStrings().splitToList(excluded));
        } catch (RuntimeException e) {
            // e.g. a subclass installed an unmodifiable excludedSet; keep the best-effort behavior.
            logger.error("Failed to add excluded paths: [{}]", excluded, e);
        }
    }

    /**
     * Adds a single exclusion pattern; {@code null} or empty values are ignored.
     *
     * @param excluded path pattern, optionally ending in {@code *} for a prefix match
     */
    public void addExcluded(String excluded) {
        if (!Strings.isNullOrEmpty(excluded)) {
            excludedSet.add(excluded);
        }
    }

    /**
     * Returns the live (mutable) set of configured exclusion patterns.
     *
     * @return the exclusion set backing this filter
     */
    public Set<String> getExcludedSet() {
        return excludedSet;
    }

    @Override
    protected void doFilterInternal(HttpServletRequest httpServletRequest,
                                    HttpServletResponse httpServletResponse,
                                    FilterChain filterChain) throws ServletException, IOException {
        String apiPath = HttpTools.getApiPath(httpServletRequest);
        HeaderMapRequestWrapper requestWrapper = new HeaderMapRequestWrapper(httpServletRequest);

        boolean isAllowRequest;
        if (isExclude(apiPath)) {
            isAllowRequest = true;
        } else {
            // Decide on the local result: re-reading the shared errcode field could observe a
            // concurrent request's OK and let a failed verification through.
            int resultCode = verify(requestWrapper);
            errcode = resultCode;
            isAllowRequest = resultCode == HttpCode2.OK;
        }

        if (isAllowRequest) {
            filterChain.doFilter(requestWrapper, httpServletResponse);
        } else {
            onError(httpServletRequest, httpServletResponse);
        }
    }

    /**
     * Obtains the token from the request.
     *
     * @param httpRequest the incoming request
     * @return the token; {@code null} or empty is treated as a missing token
     */
    public abstract String getToken(HttpServletRequest httpRequest);

    /**
     * Decrypts the token.
     *
     * @param token the non-empty token returned by {@link #getToken}
     * @return the decrypted bytes; {@code null} or an empty array is treated as an invalid token
     */
    public abstract byte[] decryptToken(String token);

    /**
     * Converts the decrypted token bytes into a token object.
     *
     * @param httpRequest the incoming request
     * @param bytes       the non-empty decrypted token
     * @return the token object; {@code null} (or an exception) marks the token as invalid
     */
    public abstract Object getTokenObj(HttpServletRequest httpRequest, byte[] bytes);

    /**
     * Checks whether the token object is valid.
     *
     * @param httpRequest the wrapped request
     * @param object      the non-null token object returned by {@link #getTokenObj}
     * @return {@code HttpCode2.OK} to allow the request, any other code to reject it
     */
    public abstract int verifyToken(HeaderMapRequestWrapper httpRequest, Object object);

    /**
     * Handles a request whose token verification failed.
     *
     * <p>When {@link #errorUrl} is empty, throws an {@link IllegalArgumentException} whose message
     * is {@code message(retcode-errcode)}. Otherwise, stores the return code, error code, message
     * and api path as request attributes and forwards the request to {@link #errorUrl}.
     *
     * @param httpServletRequest  the original request
     * @param httpServletResponse the response
     * @throws ServletException if the forward fails
     * @throws IOException      if the forward fails
     */
    public void onError(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws ServletException, IOException {
        if (Strings.isNullOrEmpty(errorUrl)) {
            throw new IllegalArgumentException(new StringBuilder()
                .append(message).append("(").append(retcode).append("-").append(errcode).append(")")
                .toString());
        } else {
            httpServletRequest.setAttribute(HttpConstant.HTTP_REQ_ATTR_RET_CODE, retcode);
            httpServletRequest.setAttribute(HttpConstant.HTTP_REQ_ATTR_ERR_CODE, errcode);
            httpServletRequest.setAttribute(HttpConstant.HTTP_REQ_ATTR_RET_MSG, message);
            httpServletRequest.setAttribute(HttpConstant.HTTP_REQ_ATTR_REFER, HttpTools.getApiPath(httpServletRequest));

            httpServletRequest.getServletContext().getRequestDispatcher(errorUrl).forward(httpServletRequest, httpServletResponse);
        }
    }


    /**
     * Verifies the request's token: get, decrypt, convert to an object, then validate.
     *
     * @param httpRequest the wrapped request
     * @return an {@code HttpCode2} error code for a missing, undecryptable or unconvertible
     *         token, otherwise the result of {@link #verifyToken}
     */
    public int verify(HeaderMapRequestWrapper httpRequest) {
        String host = HttpTools.getRemoteHost(httpRequest);
        String did = Strings.nullToEmpty(HttpParameterTools.getParameter(httpRequest, HttpConstant.HTTP_ARG_DID));

        String token = getToken(httpRequest);
        if (Strings.isNullOrEmpty(token)) {
            logger.error("==========HOST:[{}] DID:[{}] Invalid token(empty)", host, did);
            return fail(HttpCode2.ERR_AUTH_TOKEN_IS_EMPTY);
        }

        byte[] bytes = decryptToken(token);
        // A null result previously caused a NullPointerException; treat it like an empty one.
        if (bytes == null || bytes.length == 0) {
            logger.error("==========HOST:[{}] DID:[{}] Invalid token:[{}]", host, did, token);
            return fail(HttpCode2.ERR_AUTH_INVALID_TOKEN);
        }

        Object tokenObj = null;
        try {
            tokenObj = getTokenObj(httpRequest, bytes);
        } catch (Exception e) {
            logger.error("exception:", e);
        }

        if (tokenObj == null) {
            logger.error("==========HOST:[{}] DID:[{}] Invalid token(convert token object exception):[{}]", host, did, token);
            return fail(HttpCode2.ERR_AUTH_INVALID_TOKEN_TO_OBJECT);
        }
        return verifyToken(httpRequest, tokenObj);
    }

    /**
     * Records {@code code} in {@link #errcode} (kept for compatibility) and returns it. The
     * argument is returned directly, so a concurrent write to the shared field cannot change
     * the result.
     *
     * @param code the failure code
     * @return {@code code}
     */
    private int fail(int code) {
        errcode = code;
        return code;
    }


    /**
     * Checks whether the api path is exempt from token verification, either by a built-in rule
     * or by an entry of {@link #getExcludedSet()}.
     *
     * @param apiPath the request's api path
     * @return {@code true} if the request must not be verified
     */
    private boolean isExclude(String apiPath) {
        // ping2 is not token-checked by default
        if (apiPath.endsWith("api/ping2")) {
            return true;
        }
        if (apiPath.endsWith("api/ping/v1")) {
            return true;
        }

        // ns is not token-checked by default
        if (apiPath.endsWith("api/ns")) {
            return true;
        }
        // Original note (translated): "the map would become invalid" — reason for the exemption unclear
        if (apiPath.endsWith("loadPlaylistConfigs/v1")) {
            logger.info("ignore loadPlaylistConfigs authorize");
            return true;
        }

        if (apiPath.endsWith("api/uac/activate/v1")) {
            logger.info("api/uac/activate/v1");
            return true;
        }
        if (apiPath.endsWith("loadRedisData/v1")) {
            logger.info("ignore loadRedisData authorize");
            return true;
        }
        if (apiPath.endsWith("delOldData/v1")) {
            logger.info("ignore delOldData authorize");
            return true;
        }
        if (apiPath.endsWith("rating/add/v2") || apiPath.endsWith("rating/query")) {
            logger.info("ignore auth rating authorize");
            return true;
        }

        // Endpoints of the VOD black/grey-list configuration admin console
        if (apiPath.startsWith("/api/cloudAcctBg")) {
            return true;
        }

        for (String e : getExcludedSet()) {
            if (isExclude(e, apiPath)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Checks whether the api path matches a single exclusion pattern.
     *
     * <p>A pattern ending in {@code *} matches any path starting with the text before its first
     * {@code *}. A bare {@code *} matches nothing, because the prefix must be non-empty. Any
     * other pattern must equal the path, ignoring case.
     *
     * @param excludedPath the exclusion pattern
     * @param apiPath      the request's api path
     * @return {@code true} if the pattern matches
     */
    private boolean isExclude(String excludedPath, String apiPath) {
        if (excludedPath.endsWith("*")) {
            int pos = excludedPath.indexOf("*");
            if (pos > 0) {
                String excludePrefix = excludedPath.substring(0, pos);
                if (apiPath.startsWith(excludePrefix)) {
                    return true;
                }
            }
        } else {
            if (excludedPath.equalsIgnoreCase(apiPath)) {
                return true;
            }
        }

        return false;
    }

}
