package com.mm.live.player.catchup.proxy;

import android.annotation.SuppressLint;
import android.os.SystemClock;

import androidx.collection.CircularArray;

import com.mm.live.player.catchup.util.CpsHelper;
import com.mm.live.player.catchup.util.FixedSizeCircularArray;
import com.mm.live.player.catchup.util.PureStringUtils;
import com.stream.core.net.HttpHeaders;
import com.stream.core.proxy.proxycommon.ForkedStream;
import com.stream.core.proxy.proxycommon.StreamResponse;
import com.stream.core.proxy.proxycommon.httpconn.HttpConnectionRepeatController;
import com.stream.mrt.engine.MrtEngineConfig;
import com.stream.mrt.engine.MrtUrlParser;
import com.stream.mrt.engine.MrtValueProvider;
import com.stream.mrt.engine.model.MrtSlice;
import com.stream.tool.log.Logger;
import com.stream.tool.log.LoggerFactory;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;

import rx.Subscriber;
import rx.functions.Func0;

import java.io.*;
import java.nio.charset.Charset;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Serves a CPS playlist as an HLS VOD stream.
 *
 * <p>The CPS text is converted into an m3u8 playlist whose segments are named {@code <index>.ts}.
 * Each TS request is answered from an in-memory cache of downloaded blocks. On a cache miss the
 * block is downloaded through {@link CpsMrtClient}. The next {@code preload_size} blocks are also
 * preloaded.
 *
 * <p>Locking:
 * <ul>
 *   <li>TS requests are serialized by {@code mPostDownloadLock}.</li>
 *   <li>The block cache, the block list and playlist parsing are guarded by this object's monitor.</li>
 * </ul>
 */
public class CpsHandler {
    public static final String KEY_CACHE_SIZE = "cache_size";
    public static final String KEY_MAX_DOWNLOADER = "max_downloader";
    public static final String KEY_DID = "did";
    public static final String KEY_APP_VER = "app_ver";
    public static final String KEY_PLATFORM = "platform";
    public static final String KEY_PRELOAD_SIZE = "preload_size";
    /**
     * How long (ms, as passed to {@link CountDownLatch#await}) a TS request waits for its block.
     * Computed from the MRT login timeout times the login tries, plus the group-receive timeout.
     */
    public static final int TIME_OUT = MrtEngineConfig.DFT_LOGIN_TIMEOUT * MrtEngineConfig.MRT_LOG_IN_TRIES + MrtEngineConfig.DFT_RECEIVE_GROUP_TIMEOUT;

    private static final String M3U_CONTENT_TYPE = "application/vnd.apple.mpegurl";
    private static final String TS_CONTENT_TYPE = "video/mpegts";
    private static final Charset UTF_8 = Charset.forName("UTF-8");
    /** Poll interval (ms) while waiting for another thread's download of the same block. */
    private static final long IN_PROCESS_POLL_MS = 5;

    private final Logger mLogger = LoggerFactory.getLogger(CpsHandler.class.getName());
    private volatile boolean mIsShutdown;
    /** Blocks parsed from the CPS content; list position equals the block index. Guarded by {@code this}. */
    private final List<CpsBlock> mCpsBlocks = new ArrayList<>();
    private final String mCpsContent;
    private String mM3uContent;
    private boolean mM3uParsed;
    /** Downloaded blocks, oldest first, at most {@link #mCacheSize} entries. Guarded by {@code this}. */
    private final CircularArray<CacheItem> mCache;
    private final int mCacheSize;
    private final int mPreloadSize;
    private final FixedSizeCircularArray<String> mTsRequestLog = new FixedSizeCircularArray<>(5);
    /** Block index of the previous TS request; read and written while holding {@link #mPostDownloadLock}. */
    private int mLastBlockSeq = 0;

    private final HttpConnectionRepeatController<StreamResponse> mRepeatController = new HttpConnectionRepeatController<>();
    private volatile CpsOnProxyInfoListener mOnProxyInfoListener;

    private final ExecutorService mExecutorService = Executors.newCachedThreadPool(
            new BasicThreadFactory.Builder().namingPattern("CpsDownloader-%d").build());
    private final CpsMrtClient mMrtClient;
    private final CatchUpStreamProxy.PlayPositionProvider mPlayPositionProvider;
    /** Serializes TS requests: cache lookup, download scheduling, preload and the wait for the block. */
    private final ReentrantLock mPostDownloadLock = new ReentrantLock();

    /**
     * @param cpsContent       raw CPS playlist text
     * @param mrtValueProvider configuration source for cache/preload sizes and MRT client identity
     * @param provider         player position source used to cancel obsolete downloads
     */
    public CpsHandler(String cpsContent, MrtValueProvider mrtValueProvider, CatchUpStreamProxy.PlayPositionProvider provider) {
        mCpsContent = cpsContent;
        int cacheSize = mrtValueProvider.val(KEY_CACHE_SIZE, 16);
        // CircularArray rejects a capacity < 1, and saveToCache would pop from an empty array at 0.
        mCacheSize = Math.max(1, cacheSize);
        mCache = new CircularArray<>(mCacheSize);
        int preloadSize = mrtValueProvider.val(KEY_PRELOAD_SIZE, 3);
        mPreloadSize = Math.max(0, preloadSize);
        mMrtClient = new CpsMrtClient(mrtValueProvider.val(KEY_DID, ""), mrtValueProvider.val(KEY_APP_VER, ""), mrtValueProvider.val(KEY_PLATFORM, ""), mrtValueProvider);
        mPlayPositionProvider = provider;
    }

    /**
     * Builds the HTTP response carrying the m3u8 playlist generated from the CPS content.
     * The CPS content is parsed on the first call only.
     *
     * @return a 200 response whose body is the playlist
     * @throws Exception if the CPS content cannot be parsed
     */
    public StreamResponse buildM3uStreamResponse() throws Exception {
        initCps();
        // The playlist does not change once parsed, so encode it once for both length and body.
        final byte[] m3uBytes = mM3uContent.getBytes(UTF_8);
        return new StreamResponse() {
            @Override
            public String getProtocol() {
                return "HTTP/1.1";
            }

            @Override
            public long getContentLength() throws IOException {
                return m3uBytes.length;
            }

            @Override
            public int getCode() throws IOException {
                return 200;
            }

            @Override
            public String getMessage() throws IOException {
                return "OK";
            }

            @Override
            public InputStream getInputStream() throws IOException {
                return new ByteArrayInputStream(m3uBytes);
            }

            @Override
            public Map<String, List<String>> getHeaders() {
                return buildHeaders(M3U_CONTENT_TYPE, m3uBytes.length);
            }

            @Override
            public String getContentType() {
                return M3U_CONTENT_TYPE;
            }

            @Override
            public boolean isChunked() {
                return false;
            }

            @Override
            public void close() throws IOException {
            }
        };
    }

    /**
     * Parses the CPS content into {@link #mCpsBlocks} and the m3u8 text in {@link #mM3uContent}.
     * This runs once; later calls return immediately.
     *
     * <p>Each {@code #EXTINF} line with a positive duration opens a pending block. The next
     * {@code mrt://} line completes it and emits a {@code <index>.ts} segment. An {@code #EXTINF}
     * that arrives while a block is pending discards the pending block. {@code mrt://} lines with
     * no pending block are skipped.
     */
    private synchronized void initCps() throws Exception {
        if (mM3uParsed) return;
        mCpsBlocks.clear();
        // The CPS text is already a String; read it directly instead of round-tripping through bytes.
        BufferedReader reader = new BufferedReader(new StringReader(mCpsContent));
        StringBuilder m3u = new StringBuilder();
        m3u.append("#EXTM3U\n")
                .append("#EXT-X-VERSION:3\n")
                .append("#EXT-X-MEDIA-SEQUENCE:0\n")
                .append("#EXT-X-PLAYLIST-TYPE:VOD\n")
                .append("#EXT-X-INDEPENDENT-SEGMENTS\n");
        CpsBlock pendingBlock = null;
        String pendingExtinf = "";
        int index = 0;
        String line;
        while ((line = reader.readLine()) != null) {
            line = StringUtils.trimToEmpty(line);
            if (line.isEmpty()) {
                continue;
            }
            if (line.startsWith("#EXT-X-TARGETDURATION")) {
                m3u.append(line).append("\n");
            } else if (line.startsWith("#EXTINF")) {
                if (pendingBlock != null) {
                    mLogger.warn("still in new block matching, but found new #EXTINF, ignored previous one!");
                    // Drop the previous block even if this #EXTINF turns out to be invalid,
                    // so a following mrt:// line is never bound to a stale #EXTINF.
                    pendingBlock = null;
                }
                long duration = CpsHelper.getCpsDuration(line);
                if (duration <= 0) {
                    mLogger.warn("get duration 0 from line [%s], ignored!", line);
                } else {
                    pendingBlock = new CpsBlock();
                    pendingBlock.setDuration(duration);
                    pendingExtinf = line;
                }
            } else if (line.startsWith("mrt://")) {
                if (pendingBlock != null) {
                    pendingBlock.setUrl(line);
                    pendingBlock.setIndex(index);
                    pendingBlock.setMrtUrlInfo(MrtUrlParser.parse(line));
                    mCpsBlocks.add(pendingBlock);
                    pendingBlock = null;
                    m3u.append(pendingExtinf).append("\n").append(index).append(".ts\n");
                    index++;
                } else {
                    mLogger.warn("found new mrt line without #EXTINF, ignored!");
                }
            }
        }
        m3u.append("#EXT-X-ENDLIST\n");
        mM3uContent = m3u.toString();
        mM3uParsed = true;
    }

    /**
     * Registers the listener that receives proxy info. The same listener is forwarded to the
     * repeat controller and to the MRT client.
     */
    public void setOnProxyInfoListener(CpsOnProxyInfoListener l) {
        mOnProxyInfoListener = l;
        mRepeatController.setRequestFailedHandlerForUnavailableWhenException(l);
        mMrtClient.setOnProxyInfoListener(l);
    }

    /**
     * Serves one {@code <index>.ts} request through the repeat controller.
     * The request is also recorded in the rolling request log.
     *
     * @param sessionId          id used for session logging
     * @param targetUrl          requested URL; the block index is parsed from it
     * @param httpRequestBegin   first requested byte; values {@code <= 0} serve the whole block
     * @param httpRequestEnd     last requested byte (passed to the repeat controller only)
     * @param forkedStream       stream handed to the MRT downloader
     * @param isDisconnectedFunc reports whether the client went away
     */
    @SuppressLint("SimpleDateFormat")
    public StreamResponse getTsStreamResponse(final int sessionId,
                                              final String targetUrl,
                                              final long httpRequestBegin,
                                              final long httpRequestEnd,
                                              final ForkedStream forkedStream,
                                              final Func0<Boolean> isDisconnectedFunc) throws IOException {
        int blockSeq = CpsHelper.getCpsBlockSeq(targetUrl);
        // A fresh formatter per call: SimpleDateFormat is not thread-safe.
        mTsRequestLog.add(blockSeq + ".ts@" + new SimpleDateFormat("mm:ss").format(new Date()));
        addCpsInfo(sessionId, "request_ts", getTsRequestDesc());
        HttpConnectionRepeatController.RequestCtrlInfo requestCtrlInfo = new HttpConnectionRepeatController.RequestCtrlInfo(targetUrl, httpRequestBegin, httpRequestEnd);
        return mRepeatController.exec(
                requestCtrlInfo,
                HttpConnectionRepeatController.CtrlLevel.REPORT_REPEAT_ONLY,
                createExecutor(sessionId, targetUrl, httpRequestBegin, httpRequestEnd, forkedStream, isDisconnectedFunc));
    }

    /** Adapts {@link #createTsStreamResponse} to the repeat controller's executor interface. */
    private HttpConnectionRepeatController.Executor<StreamResponse> createExecutor(final int sessionId,
                                                                                   final String targetUrl,
                                                                                   final long httpRequestBegin,
                                                                                   final long httpRequestEnd,
                                                                                   final ForkedStream forkedStream,
                                                                                   final Func0<Boolean> isDisconnectedFunc) {
        return new HttpConnectionRepeatController.Executor<StreamResponse>() {

            @Override
            public StreamResponse exec() throws IOException {
                return createTsStreamResponse(sessionId, targetUrl, httpRequestBegin, httpRequestEnd, forkedStream);
            }

            @Override
            public boolean isDisconnected() throws IOException {
                return isDisconnectedFunc.call();
            }

            @Override
            public int getCode(StreamResponse streamResponse) throws IOException {
                return streamResponse.getCode();
            }

            @Override
            public String getMsg(StreamResponse streamResponse) throws IOException {
                return streamResponse.getMessage();
            }
        };
    }

    /**
     * Resolves the requested block from the cache, or by downloading it, and wraps it in a TS response.
     *
     * @throws IOException if the URL names no valid block, the MRT client is shut down,
     *                     the download fails, or the wait exceeds {@link #TIME_OUT}
     */
    private StreamResponse createTsStreamResponse(final int sessionId,
                                                  final String targetUrl,
                                                  final long httpRequestBegin,
                                                  final long httpRequestEnd,
                                                  final ForkedStream forkedStream) throws IOException {
        final int blockSeq = CpsHelper.getCpsBlockSeq(targetUrl);
        if (blockSeq < 0 || blockSeq >= mCpsBlocks.size()) {
            mLogger.error("request url[%s] out of block index", targetUrl);
            throw new IOException("block invalid!");
        }
        fireProxyInfo();
        final CacheItem cacheItem;
        mPostDownloadLock.lock();
        try {
            cacheItem = loadBlock(sessionId, blockSeq, forkedStream);
        } finally {
            // Written under the lock because processPreload reads it under the lock.
            mLastBlockSeq = blockSeq;
            mPostDownloadLock.unlock();
        }
        addCpsInfo(sessionId, "process_ts", String.format(Locale.ENGLISH, "process complete, feed the client, block[%d]", blockSeq));
        return buildTsResponse(cacheItem, httpRequestBegin);
    }

    /**
     * Returns block {@code blockSeq}. It comes from the cache if present. Otherwise the method
     * starts a download, or waits for an in-flight one, for at most {@link #TIME_OUT} ms.
     * Preloading of the following blocks is triggered in both cases.
     * The caller must hold {@link #mPostDownloadLock}.
     */
    private CacheItem loadBlock(int sessionId, int blockSeq, ForkedStream forkedStream) throws IOException {
        CacheItem cached = getFromCache(blockSeq);
        if (cached != null) {
            mLogger.info("block[%d] found in cache", blockSeq);
            addCpsInfo(sessionId, "process_ts", String.format(Locale.ENGLISH, "cache hit block[%d]", blockSeq));
            processPreload(sessionId, blockSeq, forkedStream);
            return cached;
        }

        mLogger.info("block[%d] not in cache, begin download", blockSeq);
        if (mMrtClient.isShutdown()) {
            mLogger.error("mrt client is shutdown, can't serve any request!");
            throw new IOException("mrt client is shutdown");
        }
        CpsBlock cpsBlock = mCpsBlocks.get(blockSeq);
        CacheItem[] result = new CacheItem[1];
        CountDownLatch done = new CountDownLatch(1);
        if (!cpsBlock.isInProcess()) {
            addCpsInfo(sessionId, "process_ts", String.format(Locale.ENGLISH, "not in progress commit, block[%d]", blockSeq));
            submitDownload(sessionId, blockSeq, cpsBlock, forkedStream, result, done);
        } else {
            mLogger.info("block[%d] is in process already, wait for finish!", blockSeq);
            addCpsInfo(sessionId, "process_ts", String.format(Locale.ENGLISH, "in progress wait, block[%d]", blockSeq));
            submitWaitForDownload(sessionId, blockSeq, cpsBlock, result, done);
        }
        processPreload(sessionId, blockSeq, forkedStream);

        boolean finished;
        try {
            finished = done.await(TIME_OUT, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException(e);
        }
        if (!finished) {
            mLogger.debug("wait download timeout block[%d]", blockSeq);
            addCpsInfo(sessionId, "process_ts", String.format(Locale.ENGLISH, "wait download timeout, block[%d]", blockSeq));
            throw new IOException(new TimeoutException());
        }
        // Safe to read: the latch countDown happens-before a successful await.
        if (result[0] == null) {
            throw new IOException(String.format(Locale.ENGLISH, "session[%d] download block[%d] error", sessionId, blockSeq));
        }
        return result[0];
    }

    /**
     * Marks {@code block} as in process and downloads it on the executor.
     *
     * <p>On success the slices are cached, and stored in {@code result[0]} when a holder is given.
     * The cache is written before the in-process flag is cleared, so a thread polling the flag
     * always finds the block in the cache. On any failure the flag is cleared.
     * {@code done}, when non-null, is counted down once however the attempt ends.
     */
    private void submitDownload(final long sessionId,
                                final int blockSeq,
                                final CpsBlock block,
                                final ForkedStream forkedStream,
                                final CacheItem[] result,
                                final CountDownLatch done) {
        block.setInProcess(true);
        Runnable task = new Runnable() {
            @Override
            public void run() {
                try {
                    mMrtClient.download(sessionId, System.currentTimeMillis(), block, forkedStream, new Subscriber<MrtSlice[]>() {
                        @Override
                        public void onCompleted() {
                            // ignored: the outcome is reported through onNext / onError
                        }

                        @Override
                        public void onError(Throwable e) {
                            block.setInProcess(false);
                            mLogger.error(e, "download error block[%d]", blockSeq);
                            addCpsInfo(sessionId, "download_ts", String.format(Locale.ENGLISH, "download error, block[%d]", blockSeq));
                            release(done);
                        }

                        @Override
                        public void onNext(MrtSlice[] mrtSlices) {
                            CacheItem item = new CacheItem(blockSeq, mrtSlices);
                            if (result != null) {
                                result[0] = item;
                            }
                            saveToCache(item);
                            block.setInProcess(false);
                            addCpsInfo(sessionId, "download_ts", String.format(Locale.ENGLISH, "download complete, block[%d]", blockSeq));
                            release(done);
                            mLogger.info("cache updated: current [%s]", getCacheDesc());
                            addCpsInfo(sessionId, "ts_cache", "Cache Info:\n" + getCacheDesc());
                        }
                    });
                } catch (Exception e) {
                    mLogger.error(e, "download commit error block[%d]", blockSeq);
                    block.setInProcess(false);
                    addCpsInfo(sessionId, "download_ts", String.format(Locale.ENGLISH, "post download task error, block[%d]", blockSeq));
                    release(done);
                }
            }
        };
        try {
            mExecutorService.submit(task);
        } catch (RejectedExecutionException e) {
            // The executor is shut down: do not leave the block stuck "in process".
            mLogger.error(e, "download commit error block[%d]", blockSeq);
            block.setInProcess(false);
            addCpsInfo(sessionId, "download_ts", String.format(Locale.ENGLISH, "post download task error, block[%d]", blockSeq));
            release(done);
        }
    }

    /**
     * Waits on the executor for another thread's download of {@code block} to finish.
     * It then stores the cached block, or null, in {@code result[0]} and counts down {@code done}.
     */
    private void submitWaitForDownload(final long sessionId,
                                       final int blockSeq,
                                       final CpsBlock block,
                                       final CacheItem[] result,
                                       final CountDownLatch done) {
        Runnable task = new Runnable() {
            @Override
            public void run() {
                // Bounded so this thread cannot spin forever if the in-flight download never reports
                // back. The bound is twice the caller's own wait, so the caller normally times out first.
                long deadline = SystemClock.elapsedRealtime() + 2L * TIME_OUT;
                while (block.isInProcess() && !mIsShutdown && SystemClock.elapsedRealtime() < deadline) {
                    SystemClock.sleep(IN_PROCESS_POLL_MS);
                }
                result[0] = getFromCache(blockSeq);
                addCpsInfo(sessionId, "download_ts", String.format(Locale.ENGLISH, "in progress wait finished, block[%d]", blockSeq));
                mLogger.info("wait block[%d] process finished, block download success[%s]", blockSeq, result[0] != null);
                done.countDown();
            }
        };
        try {
            mExecutorService.submit(task);
        } catch (RejectedExecutionException e) {
            mLogger.error(e, "wait task commit error block[%d]", blockSeq);
            done.countDown();
        }
    }

    /** Counts down {@code latch} if it is non-null. */
    private static void release(CountDownLatch latch) {
        if (latch != null) {
            latch.countDown();
        }
    }

    /**
     * Wraps a cached block in an HTTP response. The body starts at {@code httpRequestBegin}.
     * Negative offsets are treated as 0 so that the body, Content-Length and Content-Range agree.
     */
    private StreamResponse buildTsResponse(final CacheItem item, final long httpRequestBegin) {
        final long rangeStart = Math.max(httpRequestBegin, 0);
        final long totalLength = item.dataSize();
        final long contentLength = totalLength - rangeStart;
        return new StreamResponse() {
            @Override
            public String getProtocol() {
                return "HTTP/1.1";
            }

            @Override
            public long getContentLength() throws IOException {
                return contentLength;
            }

            @Override
            public int getCode() throws IOException {
                return rangeStart > 0 ? 206 : 200;
            }

            @Override
            public String getMessage() throws IOException {
                return rangeStart > 0 ? "Partial Content" : "OK";
            }

            @Override
            public InputStream getInputStream() throws IOException {
                return CpsHelper.getInputStreamFromMrtSlice(item.data, rangeStart);
            }

            @Override
            public Map<String, List<String>> getHeaders() {
                Map<String, List<String>> headers = buildHeaders(TS_CONTENT_TYPE, contentLength);
                headers.put(HttpHeaders.CONTENT_RANGE, Collections.singletonList(
                        String.format(Locale.ENGLISH, "bytes %d-%d/%d", rangeStart, totalLength - 1, totalLength)));
                return headers;
            }

            @Override
            public String getContentType() {
                return TS_CONTENT_TYPE;
            }

            @Override
            public boolean isChunked() {
                return false;
            }

            @Override
            public void close() throws IOException {
            }
        };
    }

    /**
     * Returns a mutable header map with Content-Type, Content-Length and "Connection: close".
     * Every proxied response closes its connection.
     */
    private static Map<String, List<String>> buildHeaders(String contentType, long contentLength) {
        Map<String, List<String>> headers = new HashMap<>();
        headers.put(HttpHeaders.CONTENT_TYPE, Collections.singletonList(contentType));
        headers.put(HttpHeaders.CONTENT_LENGTH, Collections.singletonList(String.valueOf(contentLength)));
        headers.put(HttpHeaders.CONNECTION, Collections.singletonList("close"));
        return headers;
    }

    /**
     * Cancels obsolete downloads and starts preloading the blocks after {@code currentSeq}.
     * The caller must hold {@link #mPostDownloadLock}, because this reads {@link #mLastBlockSeq}.
     */
    private void processPreload(final long sessionId, final int currentSeq, final ForkedStream forkedStream) {
        if (mIsShutdown) {
            return;
        }
        // Cancelling only happens when the play position is trusted. On rk3229 the play position
        // does not match the requested position, and cancelling would abort the block being requested.
        if (mPlayPositionProvider != null && mPlayPositionProvider.isTrusted()) {
            // Forward: cancel downloads of blocks before the one currently playing.
            long playPosition = mPlayPositionProvider.getPlayPosition();
            if (playPosition > 0) {
                int playBlock = getBlockContainsPosition(playPosition);
                if (playBlock > 0) {
                    mLogger.info("try to cancel block before %d", playBlock);
                    mMrtClient.cancelBeforeBlocks(mCpsBlocks.get(playBlock));
                }
            }
            // Rewind: cancel downloads beyond the new preload window. If the window reaches past
            // the last block there is presumably nothing after it to cancel.
            int windowEnd = currentSeq + mPreloadSize;
            if (currentSeq < mLastBlockSeq && windowEnd < mCpsBlocks.size()) {
                mMrtClient.cancelAfterBlocks(mCpsBlocks.get(windowEnd));
            }
        }
        // Preload up to mPreloadSize following blocks that are neither cached nor downloading.
        int lastSeq = Math.min(currentSeq + mPreloadSize, mCpsBlocks.size() - 1);
        for (int seq = currentSeq + 1; seq <= lastSeq; seq++) {
            CpsBlock preloadBlock = mCpsBlocks.get(seq);
            int preloadBlockSeq = preloadBlock.getIndex();
            if (!preloadBlock.isInProcess() && getFromCache(preloadBlockSeq) == null) {
                submitDownload(sessionId, preloadBlockSeq, preloadBlock, forkedStream, null, null);
            }
        }
    }

    //++++++++++++++++++++++++++++
    //cache

    /** Returns the cached block with index {@code blockSeq}, or null if it is not cached. */
    private synchronized CacheItem getFromCache(int blockSeq) {
        for (int i = 0; i < mCache.size(); i++) {
            CacheItem item = mCache.get(i);
            if (item.blockSeq == blockSeq) {
                return item;
            }
        }
        return null;
    }

    /** Appends {@code cacheItem}, evicting the oldest entry once {@link #mCacheSize} is reached. */
    private synchronized void saveToCache(CacheItem cacheItem) {
        if (mCache.size() >= mCacheSize) {
            mCache.popFirst();
        }
        mCache.addLast(cacheItem);
    }

    /** One downloaded block: its index and the MRT slices that make up its bytes. */
    private static final class CacheItem {
        final int blockSeq;
        final MrtSlice[] data;

        CacheItem(int blockSeq, MrtSlice[] data) {
            this.blockSeq = blockSeq;
            this.data = data;
        }

        /** Total byte count across all slices. */
        int dataSize() {
            int size = 0;
            for (MrtSlice mrtSlice : data) {
                size += mrtSlice.data.length;
            }
            return size;
        }
    }

    /** Human-readable summary of the cache contents. Synchronized because download threads call it. */
    private synchronized String getCacheDesc() {
        StringBuilder sb = new StringBuilder();
        sb.append("size: ").append(mCache.size()).append("; ").append("data:");
        for (int i = 0; i < mCache.size(); i++) {
            CacheItem item = mCache.get(i);
            sb.append(" ").append(item.blockSeq).append("/").append(PureStringUtils.getDisplayBytesSize(item.dataSize())).append(";");
        }
        return sb.toString();
    }
    //----------------------------

    //+++++++++++++++++++++++++++++
    // TODO: 2018/4/11 show the message to front end
    private void fireProxyInfo() {
    }
    //-----------------------------

    //+++++++++++++++++++++++++++++
    //life cycle

    public boolean isShutdown() {
        return mIsShutdown;
    }

    /**
     * Stops the handler. It clears the blocks and cache, shuts down the MRT client and stops
     * accepting new executor tasks. Errors are logged, not thrown.
     */
    public void shutdown() {
        mIsShutdown = true;
        try {
            synchronized (this) {
                mCpsBlocks.clear();
                mCache.clear();
            }
            mLogger.debug("close mrt client begin");
            mMrtClient.shutdown();
            mLogger.debug("close mrt client end");
        } catch (Throwable throwable) {
            mLogger.error(throwable, "close cps handler error");
        } finally {
            // The field stays non-null, so concurrent callers get a RejectedExecutionException
            // (handled in submitDownload / submitWaitForDownload) instead of an NPE.
            mExecutorService.shutdown();
        }
    }
    //-----------------------------


    //++++++++++++++++++++++++++++++
    // logger
    private final LogContainer mLogContainer = new LogContainer(CpsOnProxyInfoListener.CPS_HANDLER_INFO, new LogContainer.Logger() {
        @Override
        public void log(int what, int extra1, long extra2, long extra3, String info) {
            CpsOnProxyInfoListener listener = mOnProxyInfoListener;
            if (listener != null) {
                listener.onProxyInfo(what, extra1, extra2, extra3, info);
            }
        }
    });

    private void addCpsInfo(long sessionId, String event, String msg) {
        mLogContainer.addSessionLog(sessionId, event, msg);
    }

    /** Formats the recent TS request log as "TS Request:\n" followed by space-separated entries. */
    private String getTsRequestDesc() {
        StringBuilder sb = new StringBuilder();
        sb.append("TS Request:\n");
        for (int i = 0; i < mTsRequestLog.size(); i++) {
            sb.append(mTsRequestLog.get(i)).append(" ");
        }
        return sb.toString();
    }
    //------------------------------

    //+++++++++++++++++++++++++++++
    // some interface for others

    /**
     * Returns the index of the first block whose cumulative end duration is {@code >= position}.
     * Returns -1 if there are no blocks or the position lies past the end.
     * {@code position} is compared directly against the summed block durations, in the same unit.
     */
    public int getBlockContainsPosition(long position) {
        if (mCpsBlocks.isEmpty()) return -1;
        long duration = 0L;
        for (int i = 0; i < mCpsBlocks.size(); i++) {
            duration += mCpsBlocks.get(i).getDuration();
            if (duration >= position) return i;
        }
        return -1;
    }
    //-----------------------------
}
