/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.search.aggregations.metrics.PercentilesAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

public class AucRoc
implements EvaluationMetric {
    /** Name under which this metric is parsed and reported ({@code auc_roc}). */
    public static final ParseField NAME = new ParseField("auc_roc", new String[0]);
    /** Optional boolean field ({@code include_curve}); the constructor treats a missing value as {@code false}. */
    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve", new String[0]);
    /** Parser that builds an {@link AucRoc} from its single constructor argument, the parsed {@code include_curve} value. */
    public static final ConstructingObjectParser<AucRoc, Void> PARSER = new ConstructingObjectParser(NAME.getPreferredName(), a -> new AucRoc((Boolean)a[0]));
    /** Name of the percentiles sub-aggregation nested under each of the two filter aggregations. */
    private static final String PERCENTILES = "percentiles";
    /** Name of the filter aggregation built from the "actual is true" query; assigned in the static initializer. */
    private static final String TRUE_AGG_NAME;
    /** Name of the filter aggregation built from the negated "actual is true" query; assigned in the static initializer. */
    private static final String NON_TRUE_AGG_NAME;
    /** Whether the curve points are put into the result in addition to the score. */
    private final boolean includeCurve;
    /** Result stored by {@link #process(Aggregations)}; {@code null} until that method has completed. */
    private EvaluationMetricResult result;

    /**
     * Reads an {@code auc_roc} metric from the given parser by delegating to {@link #PARSER}.
     *
     * @param parser the parser to read the metric from
     * @return the metric produced by {@link #PARSER}
     */
    public static AucRoc fromXContent(XContentParser parser) {
        Object parsedMetric = PARSER.apply(parser, null);
        return (AucRoc) parsedMetric;
    }

    /**
     * Creates the metric.
     *
     * @param includeCurve whether the curve points should be part of the result; {@code null} means {@code false}
     */
    public AucRoc(Boolean includeCurve) {
        // A null flag (option not specified) falls back to "do not include the curve".
        this.includeCurve = includeCurve != null && includeCurve;
    }

    public AucRoc(StreamInput in) throws IOException {
        this.includeCurve = in.readBoolean();
    }

    /**
     * Returns the registered name of this metric, derived from the evaluation name
     * ({@code BinarySoftClassification.NAME}) and the metric name ({@link #NAME}).
     */
    public String getWriteableName() {
        String registeredName = MlEvaluationNamedXContentProvider.registeredMetricName(BinarySoftClassification.NAME, NAME);
        return registeredName;
    }

    /**
     * Serializes the metric as a single boolean, the {@code includeCurve} flag.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        final boolean curveRequested = this.includeCurve;
        out.writeBoolean(curveRequested);
    }

    /**
     * Writes this metric as an object with a single {@code include_curve} field.
     *
     * @param builder the builder to write to
     * @param params  not used by this implementation
     * @return the same builder that was passed in
     * @throws IOException if the builder fails to write
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        final String curveFlagField = INCLUDE_CURVE.getPreferredName();
        builder.startObject();
        builder.field(curveFlagField, this.includeCurve);
        builder.endObject();
        return builder;
    }

    /** Returns the preferred name of {@link #NAME}. */
    @Override
    public String getName() {
        final String metricName = NAME.getPreferredName();
        return metricName;
    }

    /**
     * Two metrics are equal when they are of the same class and agree on the {@code includeCurve} flag.
     * The computed result is not part of the comparison.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (o.getClass() != this.getClass()) {
            return false;
        }
        AucRoc other = (AucRoc) o;
        return this.includeCurve == other.includeCurve;
    }

    /** Hash code derived from the {@code includeCurve} flag only, matching {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        // Same value as Objects.hash(includeCurve): 31 * 1 + hash of the single element.
        return 31 + Boolean.hashCode(this.includeCurve);
    }

    /**
     * Builds the aggregations this metric needs: two filter aggregations (documents matching the
     * "actual is true" query, and documents not matching it), each with a percentiles sub-aggregation
     * over {@code predictedProbabilityField} at the percents 1 to 99.
     *
     * @param parameters                not used by this implementation
     * @param actualField               field passed to {@code BinarySoftClassification.actualIsTrueQuery}
     * @param predictedProbabilityField field the percentiles are computed on
     * @return the two filter aggregations and no pipeline aggregations, or two empty lists once a
     *         result has already been computed
     */
    @Override
    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters, String actualField, String predictedProbabilityField) {
        if (this.result != null) {
            return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
        }
        double[] percents = IntStream.range(1, 100).mapToDouble(percent -> percent).toArray();

        QueryBuilder actualIsTrue = (QueryBuilder)BinarySoftClassification.actualIsTrueQuery(actualField);
        AbstractAggregationBuilder trueClassAgg = AggregationBuilders.filter((String)TRUE_AGG_NAME, actualIsTrue)
            .subAggregation(AucRoc.percentilesOver(predictedProbabilityField, percents));

        QueryBuilder actualIsNotTrue = (QueryBuilder)QueryBuilders.boolQuery().mustNot(BinarySoftClassification.actualIsTrueQuery(actualField));
        AbstractAggregationBuilder otherClassesAgg = AggregationBuilders.filter((String)NON_TRUE_AGG_NAME, actualIsNotTrue)
            .subAggregation(AucRoc.percentilesOver(predictedProbabilityField, percents));

        return Tuple.tuple(Arrays.asList(trueClassAgg, otherClassesAgg), Collections.emptyList());
    }

    /** Creates a new percentiles aggregation named {@link #PERCENTILES} over {@code field} for the given percents. */
    private static AggregationBuilder percentilesOver(String field, double[] percents) {
        return (AggregationBuilder)((PercentilesAggregationBuilder)AggregationBuilders.percentiles((String)PERCENTILES).field(field)).percentiles(percents);
    }

    /**
     * Computes the metric result from the two filter aggregations requested by
     * {@link #aggs(EvaluationParameters, String, String)} and stores it for {@link #getResult()}.
     * <p>
     * The percentiles of each filter aggregation are turned into arrays, combined into a curve,
     * and the score is calculated from that curve. The curve itself is only kept in the result
     * when {@code includeCurve} is set. Calling this method again after a result has been stored
     * is a no-op.
     *
     * @param aggs the aggregations of the search response
     */
    @Override
    public void process(Aggregations aggs) {
        if (this.result != null) {
            // aggs() stops requesting aggregations once a result exists, so a later response cannot be
            // expected to contain the filter aggregations read below. Keep the stored result instead of
            // dereferencing a missing aggregation.
            return;
        }
        Filter classAgg = (Filter)aggs.get(TRUE_AGG_NAME);
        Filter restAgg = (Filter)aggs.get(NON_TRUE_AGG_NAME);
        // percentilesArray rejects NaN percentiles with the given message (bad request).
        double[] tpPercentiles = AucRoc.percentilesArray((Percentiles)classAgg.getAggregations().get(PERCENTILES), "[" + this.getName() + "] requires at least one actual_field to have the value [true]");
        double[] fpPercentiles = AucRoc.percentilesArray((Percentiles)restAgg.getAggregations().get(PERCENTILES), "[" + this.getName() + "] requires at least one actual_field to have a different value than [true]");
        List<AucRocPoint> aucRocCurve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(aucRocCurve);
        this.result = new Result(aucRocScore, this.includeCurve ? aucRocCurve : Collections.emptyList());
    }

    /**
     * Returns the result stored by {@link #process(Aggregations)}.
     *
     * @return the result, or an empty optional if none has been computed yet
     */
    public Optional<EvaluationMetricResult> getResult() {
        Optional<EvaluationMetricResult> maybeResult = Optional.ofNullable(this.result);
        return maybeResult;
    }

    /**
     * Copies the values of a percentiles aggregation into a 99-element array, storing the value for
     * percent {@code p} at index {@code p - 1}.
     *
     * @param percentiles      the percentiles aggregation to read
     * @param errorIfUndefined message of the bad-request exception thrown when a percentile value is NaN
     * @return the percentile values indexed by {@code percent - 1}
     */
    private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
        final double[] values = new double[99];
        percentiles.forEach(entry -> {
            final double value = entry.getValue();
            if (Double.isNaN(value)) {
                throw ExceptionsHelper.badRequestException(errorIfUndefined, new Object[0]);
            }
            final int slot = (int)entry.getPercent() - 1;
            values[slot] = value;
        });
        return values;
    }

    /**
     * Builds the sorted list of curve points from the percentile arrays of the two classes.
     * <p>
     * The list starts from the two end points (tpr 0, fpr 0, threshold 1) and (tpr 1, fpr 1, threshold 0).
     * It then receives the points obtained by scanning each curve against the other, and is finally
     * sorted by the natural order of {@link AucRocPoint}.
     *
     * @param tpPercentiles percentile values of the true class (asserted to have 99 elements)
     * @param fpPercentiles percentile values of the remaining documents (asserted to have the same length)
     * @return the sorted curve points
     */
    static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
        assert (tpPercentiles.length == fpPercentiles.length);
        assert (tpPercentiles.length == 99);

        RateThresholdCurve truePositives = new RateThresholdCurve(tpPercentiles, true);
        RateThresholdCurve falsePositives = new RateThresholdCurve(fpPercentiles, false);

        List<AucRocPoint> curve = new ArrayList<>();
        // Fixed end points of the curve.
        curve.add(new AucRocPoint(0.0, 0.0, 1.0));
        curve.add(new AucRocPoint(1.0, 1.0, 0.0));
        // One point per percentile of each curve, with the other curve's rate interpolated.
        curve.addAll(truePositives.scanPoints(falsePositives));
        curve.addAll(falsePositives.scanPoints(truePositives));
        Collections.sort(curve);
        return curve;
    }

    /**
     * Sums the trapezoid areas between consecutive curve points: for each neighbouring pair the
     * difference in {@code fpr} is multiplied by the average of the two {@code tpr} values.
     *
     * @param rocCurve the curve points in the order they should be traversed
     * @return the accumulated area; {@code 0.0} for lists with fewer than two points
     */
    static double calculateAucScore(List<AucRocPoint> rocCurve) {
        double area = 0.0;
        final int lastIndex = rocCurve.size() - 1;
        for (int i = 0; i < lastIndex; i++) {
            AucRocPoint from = rocCurve.get(i);
            AucRocPoint to = rocCurve.get(i + 1);
            double width = to.fpr - from.fpr;
            double heightSum = to.tpr + from.tpr;
            area += width * heightSum / 2.0;
        }
        return area;
    }

    /**
     * Linear interpolation: returns the y value at {@code x} on the line through
     * {@code (x1, y1)} and {@code (x2, y2)}.
     */
    private static double interpolate(double x, double x1, double y1, double x2, double y2) {
        // Multiply before dividing, keeping the operation order of y1 + (x - x1) * (y2 - y1) / (x2 - x1).
        final double scaledOffset = (x - x1) * (y2 - y1);
        final double run = x2 - x1;
        return y1 + scaledOffset / run;
    }

    static {
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
        TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
        NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
    }

    /**
     * Result of the {@code auc_roc} metric: the score and, optionally, the points of the curve.
     * An empty curve list is not rendered by {@link #toXContent}.
     */
    public static class Result
    implements EvaluationMetricResult {
        private static final String SCORE = "score";
        private static final String CURVE = "curve";

        private final double score;
        private final List<AucRocPoint> curve;

        /**
         * @param score the computed score
         * @param curve the curve points; must not be {@code null}, may be empty
         */
        public Result(double score, List<AucRocPoint> curve) {
            this.curve = Objects.requireNonNull(curve);
            this.score = score;
        }

        /** Deserializes a result: the score followed by the list of curve points. */
        public Result(StreamInput in) throws IOException {
            this.score = in.readDouble();
            this.curve = in.readList(AucRocPoint::new);
        }

        /** Returns the registered name derived from the evaluation name and the metric name. */
        public String getWriteableName() {
            String registeredName = MlEvaluationNamedXContentProvider.registeredMetricName(BinarySoftClassification.NAME, NAME);
            return registeredName;
        }

        /** Returns the preferred name of the metric this result belongs to. */
        @Override
        public String getMetricName() {
            return NAME.getPreferredName();
        }

        /** Serializes the score followed by the list of curve points. */
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(this.score);
            out.writeList(this.curve);
        }

        /** Writes an object with the {@code score} field and, when the curve is not empty, the {@code curve} field. */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(SCORE, this.score);
            if (this.curve.isEmpty() == false) {
                builder.field(CURVE, this.curve);
            }
            builder.endObject();
            return builder;
        }
    }

    /**
     * A single point of the curve, holding a true positive rate, a false positive rate and the
     * threshold they belong to.
     * <p>
     * Points are ordered by descending threshold, then ascending {@code fpr}, then ascending
     * {@code tpr}. {@link #equals(Object)} and {@link #hashCode()} are consistent with that order.
     * The fields are mutable, so a point should not be modified while it is used as a hash key.
     */
    public static final class AucRocPoint
    implements Comparable<AucRocPoint>,
    ToXContentObject,
    Writeable {
        /**
         * Natural order of the points. Built once instead of on every comparison. The explicit
         * lambda parameter type is required so the chained calls are typed as
         * {@code Comparator<AucRocPoint>}, and the primitive variants avoid boxing.
         */
        private static final Comparator<AucRocPoint> NATURAL_ORDER = Comparator
            .comparingDouble((AucRocPoint p) -> p.threshold)
            .reversed()
            .thenComparingDouble(p -> p.fpr)
            .thenComparingDouble(p -> p.tpr);

        double tpr;
        double fpr;
        double threshold;

        private AucRocPoint(double tpr, double fpr, double threshold) {
            this.tpr = tpr;
            this.fpr = fpr;
            this.threshold = threshold;
        }

        /** Deserializes a point; reads {@code tpr}, {@code fpr} and {@code threshold} in that order. */
        private AucRocPoint(StreamInput in) throws IOException {
            this.tpr = in.readDouble();
            this.fpr = in.readDouble();
            this.threshold = in.readDouble();
        }

        /** Compares by descending threshold, then ascending {@code fpr}, then ascending {@code tpr}. */
        @Override
        public int compareTo(AucRocPoint o) {
            return NATURAL_ORDER.compare(this, o);
        }

        /** Two points are equal when {@link #compareTo(AucRocPoint)} considers them equal. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            return NATURAL_ORDER.compare(this, (AucRocPoint) o) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.tpr, this.fpr, this.threshold);
        }

        /** Serializes {@code tpr}, {@code fpr} and {@code threshold} in that order. */
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(this.tpr);
            out.writeDouble(this.fpr);
            out.writeDouble(this.threshold);
        }

        /** Writes an object with the fields {@code tpr}, {@code fpr} and {@code threshold}. */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field("tpr", this.tpr);
            builder.field("fpr", this.fpr);
            builder.field("threshold", this.threshold);
            builder.endObject();
            return builder;
        }

        @Override
        public String toString() {
            return Strings.toString((ToXContent)this);
        }
    }

    /**
     * View of a percentile array as a curve that maps thresholds (the percentile values) to rates.
     * The rate at array index {@code i} is {@code 1.0 - 0.01 * (i + 1)}.
     */
    private static class RateThresholdCurve {
        private final double[] percentiles;
        private final boolean isTp;

        /**
         * @param percentiles percentile values; searched with a binary search, so presumably sorted ascending
         * @param isTp        whether this curve's rate is the {@code tpr} (otherwise the {@code fpr}) of produced points
         */
        private RateThresholdCurve(double[] percentiles, boolean isTp) {
            this.percentiles = percentiles;
            this.isTp = isTp;
        }

        /** Rate associated with the percentile at {@code index}. */
        private double getRate(int index) {
            return 1.0 - 0.01 * (double)(index + 1);
        }

        /** Threshold (percentile value) at {@code index}. */
        private double getThreshold(int index) {
            return this.percentiles[index];
        }

        /**
         * Returns the rate of this curve at {@code threshold}: the exact rate when the threshold is one
         * of the percentile values, {@code 0.0} above the largest value, {@code 1.0} below the smallest
         * value, and a linear interpolation between the two neighbouring percentiles otherwise.
         */
        private double interpolateRate(double threshold) {
            final int found = Arrays.binarySearch(this.percentiles, threshold);
            if (found >= 0) {
                return this.getRate(found);
            }
            // A negative result encodes the insertion point as -(insertionPoint) - 1.
            final int upper = -(found + 1);
            if (upper >= this.percentiles.length) {
                return 0.0;
            }
            final int lower = upper - 1;
            if (lower < 0) {
                return 1.0;
            }
            return AucRoc.interpolate(threshold, this.percentiles[lower], this.getRate(lower), this.percentiles[upper], this.getRate(upper));
        }

        /**
         * Creates one point per percentile of this curve, pairing this curve's rate at that threshold
         * with the rate interpolated from {@code againstCurve} at the same threshold.
         */
        private List<AucRocPoint> scanPoints(RateThresholdCurve againstCurve) {
            List<AucRocPoint> scanned = new ArrayList<>();
            for (int i = 0; i < this.percentiles.length; i++) {
                final double threshold = this.getThreshold(i);
                final double ownRate = this.getRate(i);
                final double otherRate = againstCurve.interpolateRate(threshold);
                if (this.isTp) {
                    scanned.add(new AucRocPoint(ownRate, otherRate, threshold));
                } else {
                    scanned.add(new AucRocPoint(otherRate, ownRate, threshold));
                }
            }
            return scanned;
        }
    }
}

