/*
 * Decompiled with CFR 0.152.
 */
package io.bidmachine.ml;

import ai.catboost.CatBoostError;
import ai.catboost.CatBoostModel;
import ai.catboost.CatBoostPredictions;
import io.bidmachine.config.data.SchemaConfig;
import io.bidmachine.config.exceptions.ConfigException;
import io.bidmachine.config.ml.ModelFeaturesConfig;
import io.bidmachine.config.ml.ModelFeaturesConfigReader;
import io.bidmachine.config.providers.ConfigProvider;
import io.bidmachine.config.providers.SevenZConfigProvider;
import io.bidmachine.config.providers.ZipConfigProvider;
import io.bidmachine.data.DataFrame;
import io.bidmachine.data.FeatureRecord;
import io.bidmachine.data.InMemoryDataFrameReader;
import io.bidmachine.ml.FloorPredictorException;
import io.bidmachine.ml.FloorPredictorProblemType;
import io.bidmachine.ml.FloorSelectionStrategy;
import io.bidmachine.ml.InputData;
import io.bidmachine.ml.MLParams;
import io.bidmachine.ml.MLParamsBuilder;
import io.bidmachine.ml.Postprocessing;
import io.bidmachine.ml.PostprocessingParams;
import io.bidmachine.mutators.AdRequestTimestampsMutator;
import io.bidmachine.mutators.CommonMutator;
import io.bidmachine.mutators.DemandStatsMutator;
import io.bidmachine.mutators.EcpmMutator;
import io.bidmachine.mutators.LookupMutator;
import io.bidmachine.mutators.Mutator;
import io.bidmachine.mutators.MutatorException;
import io.bidmachine.mutators.TimezoneMutator;
import io.bidmachine.mutators.TypeMutator;
import io.bidmachine.utils.CheckedHashMap;
import io.bidmachine.utils.CollectionUtils;
import io.bidmachine.utils.MathUtils;
import io.bidmachine.utils.PathUtils;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.NotImplementedException;

public class FloorPredictor {
    // Position of the "flr" entry inside each model's continuous-feature list, keyed by the
    // model name captured from the "model_<name>.cbm" file name (populated in init()).
    private Map<String, Integer> indexOfFloorByModelAdt;
    // Loaded CatBoost models, keyed by the same model name; lookups build the key as
    // aModelName + "__" + adt.
    private Map<String, CatBoostModel> catBoostModelsAdt;
    // Feature configuration read from "models/info_<name>.json" for each loaded model.
    private Map<String, ModelFeaturesConfig> modelAdtFeaturesConfigs;
    // Mutators applied, in list order, to every input record before prediction.
    private List<Mutator> mutators;
    // Values of the "flr" column of each "other/cuts*.csv" file as float, keyed by the ad type
    // taken from the file name (the part after "__").
    Map<String, float[]> cutsAdt;
    // The same cut values as double, keyed the same way.
    Map<String, double[]> cutsDoubleAdt;
    // Parameters used by the predictBestFloor overloads that take no explicit parameters.
    private MLParams mlParams;
    // Union of the categorical and continuous feature names of all loaded models.
    private HashSet<String> featuresAll;

    /**
     * Creates a predictor from a bundle path using the default ML parameters.
     *
     * <p>Delegates to {@link #FloorPredictor(String, MLParams)} with
     * {@code MLParamsBuilder.getDefault()}.
     *
     * @param aBundlePath path of the bundle archive; the delegated constructor inspects its
     *                    ".7z" / ".zip" suffix
     * @throws FloorPredictorException declared by the delegated constructor
     */
    public FloorPredictor(String aBundlePath) throws FloorPredictorException {
        this(aBundlePath, MLParamsBuilder.getDefault());
    }

    /**
     * Creates a predictor from a config provider using the default ML parameters.
     *
     * <p>Delegates to {@link #FloorPredictor(ConfigProvider, MLParams)} with
     * {@code MLParamsBuilder.getDefault()}.
     *
     * @param aProvider provider the bundle resources are read from
     * @throws FloorPredictorException declared by the delegated constructor
     */
    public FloorPredictor(ConfigProvider aProvider) throws FloorPredictorException {
        this(aProvider, MLParamsBuilder.getDefault());
    }

    /**
     * Creates a predictor from a bundle path with ML parameters given as a JSON string.
     *
     * <p>Delegates to {@link #FloorPredictor(String, MLParams)}; the JSON is converted with
     * {@code MLParamsBuilder.get(aMLParamsJson)}.
     *
     * @param aBundlePath   path of the bundle archive; the delegated constructor inspects its
     *                      ".7z" / ".zip" suffix
     * @param aMLParamsJson ML parameters passed to {@code MLParamsBuilder.get}
     * @throws FloorPredictorException declared by the delegated constructor
     */
    public FloorPredictor(String aBundlePath, String aMLParamsJson) throws FloorPredictorException {
        this(aBundlePath, MLParamsBuilder.get(aMLParamsJson));
    }

    /**
     * Creates a predictor from a config provider with ML parameters given as a JSON string.
     *
     * <p>Delegates to {@link #FloorPredictor(ConfigProvider, MLParams)}; the JSON is converted
     * with {@code MLParamsBuilder.get(aMLParamsJson)}.
     *
     * @param aProvider     provider the bundle resources are read from
     * @param aMLParamsJson ML parameters passed to {@code MLParamsBuilder.get}
     * @throws FloorPredictorException declared by the delegated constructor
     */
    public FloorPredictor(ConfigProvider aProvider, String aMLParamsJson) throws FloorPredictorException {
        this(aProvider, MLParamsBuilder.get(aMLParamsJson));
    }

    /**
     * Creates a predictor from a bundle archive, choosing the config provider by file suffix.
     *
     * <p>A path ending in ".7z" is opened with {@link SevenZConfigProvider}, a path ending in
     * ".zip" with {@link ZipConfigProvider}. Any other suffix is rejected.
     *
     * @param aBundlePath path of the bundle archive
     * @param aMLParams   ML parameters stored as this predictor's defaults
     * @throws FloorPredictorException  if initialisation from the bundle fails
     * @throws IllegalArgumentException if the path ends in neither ".7z" nor ".zip"
     */
    public FloorPredictor(String aBundlePath, MLParams aMLParams) throws FloorPredictorException {
        if (aBundlePath.endsWith(".7z")) {
            this.init(new SevenZConfigProvider(aBundlePath), aMLParams);
        } else if (aBundlePath.endsWith(".zip")) {
            this.init(new ZipConfigProvider(aBundlePath), aMLParams);
        } else {
            // Only reject unsupported suffixes; previously this throw ran unconditionally,
            // so even a successfully initialised predictor could never be constructed.
            throw new IllegalArgumentException("BundlePath must be either 7z or zip archive");
        }
    }

    /**
     * Creates a predictor that reads its bundle through the given provider.
     *
     * @param aProvider provider the bundle resources are read from
     * @param aMLParams ML parameters stored as this predictor's defaults
     * @throws FloorPredictorException if initialisation from the provider fails
     */
    public FloorPredictor(ConfigProvider aProvider, MLParams aMLParams) throws FloorPredictorException {
        init(aProvider, aMLParams);
    }

    /**
     * Extracts the ad type from a file name of the form {@code <prefix>__<adtype>[.<ext>]}.
     *
     * <p>The extension (everything from the last '.') is dropped first, then the text after the
     * first "__" is returned.
     *
     * @param aFileName file name or path to inspect
     * @return the non-empty text following the first "__"
     * @throws IllegalArgumentException if there is no "__" or nothing follows it
     */
    private static String getAdtFromFilename(String aFileName) {
        int lastDotIndex = aFileName.lastIndexOf('.');
        String nameNoExt = lastDotIndex > 0 ? aFileName.substring(0, lastDotIndex) : aFileName;
        int separatorIndex = nameNoExt.indexOf("__");
        // Require at least one character after the separator; the previous bound was always
        // satisfied and let an empty ad type through.
        if (separatorIndex >= 0 && separatorIndex + 2 < nameNoExt.length()) {
            return nameNoExt.substring(separatorIndex + 2);
        }
        throw new IllegalArgumentException("File name doesn't contain adtype: " + aFileName);
    }

    /**
     * Loads everything the predictor needs from the bundle and builds the mutator chain.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>reads the "flr" column of every "other/cuts*.csv" file into {@code cutsAdt} /
     *       {@code cutsDoubleAdt}, keyed by the ad type taken from the file name;</li>
     *   <li>loads every "models/*.cbm" CatBoost model together with its
     *       "models/info_&lt;name&gt;.json" feature config and records the index of "flr" among
     *       the model's continuous features;</li>
     *   <li>checks that no feature is both categorical and continuous;</li>
     *   <li>reads every "lookups/*.csv" table;</li>
     *   <li>assembles the mutator list.</li>
     * </ol>
     *
     * @param aProvider provider the bundle resources are read from
     * @param aMLParams ML parameters stored as this predictor's defaults
     * @throws FloorPredictorException  if a resource is missing or corrupted, or a model cannot
     *                                  be loaded
     * @throws IllegalArgumentException if a file name has an unexpected format, a model config
     *                                  lacks "flr", or categorical and continuous features overlap
     */
    private void init(ConfigProvider aProvider, MLParams aMLParams) throws FloorPredictorException {
        this.mlParams = aMLParams;

        // 1. Floor cut points per ad type.
        List<String> cutsFiles = aProvider.listFiles("other", "cuts*.csv");
        this.cutsAdt = new HashMap<String, float[]>();
        this.cutsDoubleAdt = new HashMap<String, double[]>();
        for (String cutsFile : cutsFiles) {
            DataFrame cutsTable;
            String adt = FloorPredictor.getAdtFromFilename(cutsFile);
            try {
                cutsTable = InMemoryDataFrameReader.readCsv(aProvider, cutsFile);
            }
            catch (IOException ex) {
                // Report the file that was actually read, not a fixed legacy name.
                throw new FloorPredictorException(String.format("bad %s: no resource %s inside", aProvider, cutsFile), FloorPredictorProblemType.BAD_BUNDLE_SOMETHING_MISSING, ex);
            }
            catch (ConfigException ex) {
                throw new FloorPredictorException(String.format("bad %s: corrupted %s", aProvider, cutsFile), FloorPredictorProblemType.BAD_BUNDLE_BAD_RESORCE, ex);
            }
            Object[] cutsValues = cutsTable.getFieldValues("flr");
            float[] cuts = new float[cutsValues.length];
            double[] cutsDouble = new double[cutsValues.length];
            for (int i = 0; i < cutsValues.length; ++i) {
                // Accept any numeric column type instead of insisting on Float.
                Number cut = (Number)cutsValues[i];
                cuts[i] = cut.floatValue();
                cutsDouble[i] = cut.doubleValue();
            }
            this.cutsAdt.put(adt, cuts);
            this.cutsDoubleAdt.put(adt, cutsDouble);
        }

        // 2. Models and their feature configs.
        HashSet<String> featuresCat = new HashSet<String>();
        HashSet<String> featuresCont = new HashSet<String>();
        this.catBoostModelsAdt = new CheckedHashMap<String, CatBoostModel>();
        this.modelAdtFeaturesConfigs = new CheckedHashMap<String, ModelFeaturesConfig>();
        Pattern pattern = Pattern.compile("model_(.*)\\.cbm");
        this.indexOfFloorByModelAdt = new CheckedHashMap<String, Integer>();
        for (String path : aProvider.listFiles("models", "*.cbm")) {
            ModelFeaturesConfig featuresConfig;
            CatBoostModel model;
            // try-with-resources so the bundle stream is closed whether or not loading succeeds.
            try (InputStream modelStream = aProvider.getInputStream(path)) {
                model = CatBoostModel.loadModel(modelStream);
            }
            catch (CatBoostError ex) {
                throw new FloorPredictorException(String.format("Catboost cannot load model from %s", path), FloorPredictorProblemType.CATBOOST_FAILED_TO_LOAD, ex);
            }
            catch (IOException ex) {
                throw new FloorPredictorException(String.format("bad %s: no resource %s inside", aProvider, path), FloorPredictorProblemType.BAD_BUNDLE_SOMETHING_MISSING, ex);
            }
            Matcher m = pattern.matcher(path);
            if (!m.find()) {
                throw new IllegalArgumentException(String.format("model path %s is of bad format", path));
            }
            String modelAdtName = m.group(1);
            this.catBoostModelsAdt.put(modelAdtName, model);
            String infoPath = "info_" + modelAdtName + ".json";
            try {
                featuresConfig = ModelFeaturesConfigReader.read(aProvider, PathUtils.joinPathGeneric("models", infoPath));
            }
            catch (ConfigException ex) {
                throw new FloorPredictorException(String.format("Info file %s is bad", infoPath), FloorPredictorProblemType.BAD_BUNDLE_BAD_RESORCE, ex);
            }
            List<String> contVars = featuresConfig.getContVars();
            int floorIndex = contVars.indexOf("flr");
            // Only models whose target is "lossPrice_lurl" may omit the floor feature.
            if (floorIndex < 0 && !"lossPrice_lurl".equals(featuresConfig.getTarget())) {
                throw new IllegalArgumentException(String.format("info_%s.json should contain flr feature", modelAdtName));
            }
            if (floorIndex >= 0) {
                this.indexOfFloorByModelAdt.put(modelAdtName, floorIndex);
            }
            featuresCat.addAll(featuresConfig.getCatVars());
            featuresCont.addAll(featuresConfig.getContVars());
            this.modelAdtFeaturesConfigs.put(modelAdtName, featuresConfig);
        }

        // 3. A feature must not be both categorical and continuous.
        this.featuresAll = CollectionUtils.union(featuresCat, featuresCont);
        if (this.featuresAll.size() != featuresCat.size() + featuresCont.size()) {
            HashSet<String> inter = CollectionUtils.intersect(featuresCat, featuresCont);
            String interStr = CollectionUtils.toStr(inter);
            throw new IllegalArgumentException(String.format("Categorical and numerical fields intersect: %s", interStr));
        }

        // 4. Lookup tables.
        HashMap<String, DataFrame> lookupFrames = new HashMap<String, DataFrame>();
        for (String path : aProvider.listFiles("lookups", "*.csv")) {
            try {
                lookupFrames.put(path, InMemoryDataFrameReader.readCsv(aProvider, path));
            }
            catch (ConfigException ex) {
                throw new FloorPredictorException(String.format("Lookup file %s is bad", path), FloorPredictorProblemType.BAD_BUNDLE_BAD_RESORCE, ex);
            }
            catch (IOException ex) {
                throw new FloorPredictorException(String.format("bad %s: lookup file %s is missing", aProvider, path), FloorPredictorProblemType.BAD_BUNDLE_SOMETHING_MISSING, ex);
            }
        }

        // 5. Mutator chain. Lookups are asked for model features plus two intermediate
        //    timezone fields that are not model features themselves.
        Set<String> featuresIntermediate = Set.of("tz_mostfrequent", "tz_name");
        HashSet<String> lookupTargetFeatures = CollectionUtils.union(this.featuresAll, featuresIntermediate);
        this.mutators = new ArrayList<Mutator>();
        this.mutators.add(new CommonMutator("common"));
        this.mutators.add(new EcpmMutator("ecpm"));
        this.mutators.add(new AdRequestTimestampsMutator("adRequestTimestamps"));
        this.mutators.add(new DemandStatsMutator("demand_stats"));
        // NOTE(review): lookup mutators are added in HashMap iteration order — confirm that
        // lookups never depend on each other's outputs.
        for (Map.Entry<String, DataFrame> entry : lookupFrames.entrySet()) {
            String lookupTableName = entry.getKey();
            DataFrame df = entry.getValue();
            SchemaConfig lookupSchema = df.getSchema();
            String[] allLookupKeyFields = lookupSchema.getKey();
            String[] allLookupValueFields = lookupSchema.getValueFields();
            HashSet<String> smallNumberOfFieldsToLookup = CollectionUtils.intersect(lookupTargetFeatures, Arrays.asList(allLookupValueFields));
            // Skip tables that provide none of the requested fields.
            if (smallNumberOfFieldsToLookup.isEmpty()) continue;
            LookupMutator mutator = new LookupMutator(lookupTableName, df, Arrays.asList(allLookupKeyFields), new ArrayList<String>(smallNumberOfFieldsToLookup), null, true);
            this.mutators.add(mutator);
        }
        this.mutators.add(new TimezoneMutator("client_time"));
        // "flr" is excluded from type fixing; predictBestFloor() sets it itself after mutation.
        HashSet<String> featuresContWithoutFloor = new HashSet<String>(featuresCont);
        featuresContWithoutFloor.remove("flr");
        this.mutators.add(new TypeMutator("fix_types", featuresCat, featuresContWithoutFloor));
    }

    /**
     * Lists the input fields a caller has to supply.
     *
     * <p>Starts from all model features, adds every mutator's input features, then removes
     * every feature that some mutator calculates itself.
     *
     * @return the remaining field names in natural sort order
     */
    public List<String> getFieldsToBeSubmittedByBackend() {
        Set<String> required = new HashSet<String>(this.featuresAll);
        // All additions must happen before any removal, hence two separate passes.
        this.mutators.forEach(mutator -> required.addAll(mutator.infoInputFeatures()));
        this.mutators.forEach(mutator -> required.removeAll(mutator.infoCalcFeatures()));
        List<String> sorted = new ArrayList<String>(required);
        sorted.sort(null);
        return sorted;
    }

    /**
     * Predicts the best floor for the given input using this predictor's stored ML parameters.
     *
     * @param aInput input fields by name
     * @return the predicted floor
     * @throws FloorPredictorException if prediction fails
     */
    public double predictBestFloor(Map<String, Object> aInput) throws FloorPredictorException {
        MLParams storedParams = this.mlParams;
        return predictBestFloor(aInput, storedParams);
    }

    /**
     * Predicts the best floor for the given input with ML parameters supplied as JSON.
     *
     * @param aInput        input fields by name
     * @param aJsonMLParams ML parameters passed to {@code MLParamsBuilder.get}
     * @return the predicted floor
     * @throws FloorPredictorException if prediction fails
     */
    public double predictBestFloor(Map<String, Object> aInput, String aJsonMLParams) throws FloorPredictorException {
        MLParams parsedParams = MLParamsBuilder.get(aJsonMLParams);
        return predictBestFloor(aInput, parsedParams);
    }

    /**
     * Predicts the best floor for the given input data using this predictor's stored ML
     * parameters.
     *
     * @param aInput input whose {@code getMap()} supplies the fields
     * @return the predicted floor
     * @throws FloorPredictorException if prediction fails
     */
    public double predictBestFloor(InputData aInput) throws FloorPredictorException {
        Map<String, Object> fields = aInput.getMap();
        return predictBestFloor(fields, this.mlParams);
    }

    /**
     * Predicts the best floor for the given input data with ML parameters supplied as JSON.
     *
     * @param aInput        input whose {@code getMap()} supplies the fields
     * @param aJsonMLParams ML parameters passed to {@code MLParamsBuilder.get}
     * @return the predicted floor
     * @throws FloorPredictorException if prediction fails
     */
    public double predictBestFloor(InputData aInput, String aJsonMLParams) throws FloorPredictorException {
        Map<String, Object> fields = aInput.getMap();
        MLParams parsedParams = MLParamsBuilder.get(aJsonMLParams);
        return predictBestFloor(fields, parsedParams);
    }

    /**
     * Predicts the best floor for one request.
     *
     * <p>The input is wrapped in a {@link FeatureRecord}, run through every mutator in order,
     * given a placeholder "flr" of 0, and then evaluated with the floor selection strategy from
     * {@code aMLParams}. The strategy result is passed through
     * {@code Postprocessing.applyOriginalFloorFix} and rounded to
     * {@code PostprocessingParams.getRoundToDecimals()} decimals.
     *
     * @param aInput    input fields by name; must not contain "flr". "originalFloor" (any
     *                  {@link Number}) and "adt" are read before the mutators run
     * @param aMLParams strategy and post-processing parameters for this call
     * @return the rounded floor
     * @throws FloorPredictorException  if a mutator fails or stops the record, or a model
     *                                  prediction fails
     * @throws IllegalArgumentException if the input already contains "flr"
     * @throws NotImplementedException  if the strategy is not handled here
     */
    public double predictBestFloor(Map<String, Object> aInput, MLParams aMLParams) throws FloorPredictorException {
        if (aInput.containsKey("flr")) {
            throw new IllegalArgumentException("flr must not be in the keys when calling predictBestFloor()");
        }
        FeatureRecord fr = new FeatureRecord(aInput);
        // Accept any numeric type for originalFloor; a missing or null value becomes NaN.
        Number originalFloorObject = (Number)fr.getOrDefault("originalFloor", Double.NaN);
        double originalFloor = originalFloorObject != null ? originalFloorObject.doubleValue() : Double.NaN;
        String adt = (String)fr.get("adt");
        for (Mutator m : this.mutators) {
            try {
                fr = m.mutate(fr);
            }
            catch (MutatorException ex) {
                throw new FloorPredictorException(String.format("Mutator %s failed", m.getId()), FloorPredictorProblemType.MUTATOR_FAILURE, ex);
            }
        }
        if (fr.isStopped()) {
            // Plain concatenation: the stop info is runtime data and must not be interpreted
            // as a format string (a '%' in it would make String.format throw).
            throw new FloorPredictorException("One of mutators stopped: " + fr.getStopInfo(), FloorPredictorProblemType.MUTATOR_STOPPED);
        }
        // Placeholder floor value; predictArrayByModel overwrites it per cut.
        fr.put("flr", Float.valueOf(0.0f));
        FloorSelectionStrategy strategy = aMLParams.floorSelectionStrategy();
        PostprocessingParams pp = aMLParams.postprocessingParams();
        float[] cuts = this.cutsAdt.get(adt);
        double[] cutsDouble = this.cutsDoubleAdt.get(adt);
        // Unqualified enum labels compile on every Java version that has switch expressions.
        double selectedFloor = switch (strategy) {
            case ECPM_PLAIN -> {
                double predictedEcpm = this.predictSingleValueByModel(fr, "lossPrice_lurl", adt);
                yield Postprocessing.calcEcpm(pp, predictedEcpm);
            }
            case ORIGINAL_FLOOR -> (Double.isNaN(originalFloor) || originalFloor == 0.0) ? 0.01 : originalFloor;
            case REACH_WIN_PROB -> {
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                yield Postprocessing.calcReachWinProb(pp, cuts, probIsSpendNurl);
            }
            case SMART_REACH_WIN_PROB -> {
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                yield Postprocessing.calcSmartReachWinProb(pp, cuts, probIsSpendNurl);
            }
            case SMART_WIN -> {
                double[] probHasGoodBids = this.predictArrayByModel(fr, "has_good_bids", adt);
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                double[] predictedSpend = this.predictSpendOrTakeFloor(fr, pp, cutsDouble, adt);
                yield Postprocessing.calcSmartWin(pp, cuts, probHasGoodBids, probIsSpendNurl, predictedSpend);
            }
            case SMART_SPEND -> {
                double[] probHasGoodBids = this.predictArrayByModel(fr, "has_good_bids", adt);
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                double[] predictedSpend = this.predictSpendOrTakeFloor(fr, pp, cutsDouble, adt);
                yield Postprocessing.calcSmartSpend(pp, cuts, probHasGoodBids, probIsSpendNurl, predictedSpend);
            }
            case RANDOM_FLOOR -> {
                double minValue = pp.getRandomMin();
                double maxValue = pp.getRandomMax();
                yield minValue + (maxValue - minValue) * Math.random();
            }
            case PLAIN_WIN -> {
                double[] probHasGoodBids = this.predictArrayByModel(fr, "has_good_bids", adt);
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                yield Postprocessing.calcPlainWin(cuts, probHasGoodBids, probIsSpendNurl);
            }
            case PLAIN_SPEND -> {
                double[] probHasGoodBids = this.predictArrayByModel(fr, "has_good_bids", adt);
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                double[] predictedSpend = this.predictSpendOrTakeFloor(fr, pp, cutsDouble, adt);
                yield Postprocessing.calcPlainSpend(cuts, probHasGoodBids, probIsSpendNurl, predictedSpend);
            }
            case MIDDLE_SMART_WIN_SMART_SPEND -> {
                double[] probHasGoodBids = this.predictArrayByModel(fr, "has_good_bids", adt);
                double[] probIsSpendNurl = this.predictArrayByModel(fr, "is_spend_nurl", adt);
                double[] predictedSpend = this.predictSpendOrTakeFloor(fr, pp, cutsDouble, adt);
                yield Postprocessing.calcMiddleSmartSpendSmartWin(pp, cuts, probHasGoodBids, probIsSpendNurl, predictedSpend);
            }
            default -> throw new NotImplementedException(String.format("%s is not implemented yet", strategy));
        };
        double fixedFloor = Postprocessing.applyOriginalFloorFix(selectedFloor, pp, originalFloor);
        return MathUtils.round(fixedFloor, pp.getRoundToDecimals());
    }

    /**
     * Returns the per-cut spend estimate, either from the "spend_nurl" model or from the cut
     * values themselves.
     *
     * <p>When {@code pp.isFloorInsteadOfSpend()} is set, the cuts scaled by
     * {@code pp.getFloorInsteadOfSpendMultiplier()} are returned instead of a model prediction.
     * A multiplier of exactly 1.0 returns {@code cutsDouble} itself, not a copy.
     *
     * @param fr         mutated feature record
     * @param pp         post-processing parameters
     * @param cutsDouble cut values for the ad type
     * @param adt        ad type
     * @return one value per cut
     * @throws FloorPredictorException if the model prediction fails
     */
    private double[] predictSpendOrTakeFloor(FeatureRecord fr, PostprocessingParams pp, double[] cutsDouble, String adt) throws FloorPredictorException {
        if (!pp.isFloorInsteadOfSpend()) {
            return this.predictArrayByModel(fr, "spend_nurl", adt);
        }
        double multiplier = pp.getFloorInsteadOfSpendMultiplier();
        if (multiplier == 1.0) {
            return cutsDouble;
        }
        double[] scaled = new double[cutsDouble.length];
        for (int idx = 0; idx < scaled.length; idx++) {
            scaled[idx] = cutsDouble[idx] * multiplier;
        }
        return scaled;
    }

    /**
     * Runs one model once per floor cut and returns the transformed predictions.
     *
     * <p>The record's continuous feature vector is copied once per cut of the ad type, with the
     * "flr" slot replaced by that cut; the categorical values are shared by all rows. Raw
     * outputs are passed through a sigmoid for classification models and through exp for
     * regression models.
     *
     * @param aRecord    mutated feature record
     * @param aModelName model name without the ad type suffix
     * @param adt        ad type; the model is looked up as {@code aModelName + "__" + adt}
     * @return one transformed prediction per cut, as returned by
     *         {@code CatBoostPredictions.copyRowMajorPredictions()}
     * @throws FloorPredictorException  if the model or the cuts are missing for the ad type, a
     *                                  feature is missing or not convertible, or CatBoost fails
     * @throws IllegalArgumentException if the model name is neither a known classification nor
     *                                  a known regression model
     */
    private double[] predictArrayByModel(FeatureRecord aRecord, String aModelName, String adt) throws FloorPredictorException {
        String modelAdt = aModelName + "__" + adt;
        CatBoostModel model = this.catBoostModelsAdt.getOrDefault(modelAdt, null);
        ModelFeaturesConfig modelConf = this.modelAdtFeaturesConfigs.getOrDefault(modelAdt, null);
        // Same missing-model handling as predictSingleValueByModel, instead of failing later
        // with an unrelated runtime exception.
        if (model == null || modelConf == null) {
            throw new FloorPredictorException("Model missing: " + modelAdt, FloorPredictorProblemType.MODEL_MISSING_FOR_ADTYPE);
        }
        float[] contValues;
        try {
            contValues = aRecord.getFloatValues(modelConf.getContVars());
        }
        catch (IllegalArgumentException ex) {
            throw new FloorPredictorException(String.format("Model %s requires feature, but it is missing in the input", modelAdt), FloorPredictorProblemType.FEATURE_MISSING, ex);
        }
        catch (ClassCastException ex) {
            throw new FloorPredictorException(String.format("Model %s, input has a feature which can't be converted to float32", modelAdt), FloorPredictorProblemType.FEATURE_NUMERICAL_NOT_FLOAT32, ex);
        }
        String[] catValues;
        try {
            catValues = aRecord.getStringValues(modelConf.getCatVars());
        }
        catch (IllegalArgumentException ex) {
            throw new FloorPredictorException(String.format("Model %s requires feature, but it is missing in the input", modelAdt), FloorPredictorProblemType.FEATURE_MISSING, ex);
        }
        float[] cuts = this.cutsAdt.getOrDefault(adt, null);
        if (cuts == null) {
            throw new FloorPredictorException("Cuts missing for adtype " + adt, FloorPredictorProblemType.CUTS_MISSING_FOR_ADTYPE);
        }
        int n_rows = cuts.length;
        float[][] contMatrix = new float[n_rows][];
        // init() records this index only for models whose config lists "flr".
        int indexOfFloor = this.indexOfFloorByModelAdt.get(modelAdt);
        for (int i = 0; i < n_rows; ++i) {
            // One row per cut: identical features except for the floor value.
            float[] row = Arrays.copyOf(contValues, contValues.length);
            row[indexOfFloor] = cuts[i];
            contMatrix[i] = row;
        }
        // Categorical values do not vary by cut, so every row references the same array.
        String[][] catMatrix = new String[n_rows][];
        Arrays.fill((Object[])catMatrix, catValues);
        try {
            CatBoostPredictions pred = model.predict(contMatrix, catMatrix);
            double[] result = pred.copyRowMajorPredictions();
            if (FloorPredictor.isClassificationWithLogits(aModelName)) {
                MathUtils.applySigmoid(result);
            } else if (FloorPredictor.isRegressionWithLog(aModelName)) {
                MathUtils.applyExp(result);
            } else {
                throw new IllegalArgumentException(String.format("Model %s is of unknown type (regression/classification), we need to handle this manually", modelAdt));
            }
            return result;
        }
        catch (CatBoostError ex) {
            throw new FloorPredictorException(String.format("Failed to do CatBoost prediction for model %s", modelAdt), FloorPredictorProblemType.CATBOOST_FAILED_TO_PREDICT, ex);
        }
    }

    /**
     * Runs one model on the record as-is and returns a single transformed prediction.
     *
     * @param aRecord    mutated feature record
     * @param aModelName model name without the ad type suffix
     * @param adt        ad type; the model is looked up as {@code aModelName + "__" + adt}
     * @return the value produced by {@link #predictSingleValueInternal}
     * @throws FloorPredictorException if the model is missing for the ad type, a feature is
     *                                 missing or not convertible, or CatBoost fails
     */
    private double predictSingleValueByModel(FeatureRecord aRecord, String aModelName, String adt) throws FloorPredictorException {
        String modelKey = aModelName + "__" + adt;
        CatBoostModel booster = this.catBoostModelsAdt.getOrDefault(modelKey, null);
        ModelFeaturesConfig featureSpec = this.modelAdtFeaturesConfigs.getOrDefault(modelKey, null);
        if (booster == null || featureSpec == null) {
            throw new FloorPredictorException("Model missing: " + modelKey, FloorPredictorProblemType.MODEL_MISSING_FOR_ADTYPE);
        }

        // Continuous features: both a missing value and a non-numeric value are reported.
        float[] numericInputs;
        try {
            numericInputs = aRecord.getFloatValues(featureSpec.getContVars());
        } catch (IllegalArgumentException missing) {
            throw new FloorPredictorException(String.format("Model %s requires feature, but it is missing in the input", modelKey), FloorPredictorProblemType.FEATURE_MISSING, missing);
        } catch (ClassCastException notNumeric) {
            throw new FloorPredictorException(String.format("Model %s, input has a feature which can't be converted to float32", modelKey), FloorPredictorProblemType.FEATURE_NUMERICAL_NOT_FLOAT32, notNumeric);
        }

        // Categorical features: only a missing value is translated.
        String[] categoricalInputs;
        try {
            categoricalInputs = aRecord.getStringValues(featureSpec.getCatVars());
        } catch (IllegalArgumentException missing) {
            throw new FloorPredictorException(String.format("Model %s requires feature, but it is missing in the input", modelKey), FloorPredictorProblemType.FEATURE_MISSING, missing);
        }

        try {
            return FloorPredictor.predictSingleValueInternal(booster, aModelName, numericInputs, categoricalInputs);
        } catch (CatBoostError failure) {
            throw new FloorPredictorException(String.format("Failed to do CatBoost prediction for model %s", modelKey), FloorPredictorProblemType.CATBOOST_FAILED_TO_PREDICT, failure);
        }
    }

    /**
     * Tells whether the model's raw output is exponentiated before use.
     *
     * @param aModelName model name without the ad type suffix; may be {@code null}
     * @return {@code true} only for "spend_nurl" and "lossPrice_lurl"
     */
    public static boolean isRegressionWithLog(String aModelName) {
        if (aModelName == null) {
            return false;
        }
        return switch (aModelName) {
            case "spend_nurl", "lossPrice_lurl" -> true;
            default -> false;
        };
    }

    /**
     * Tells whether the model's raw output is passed through a sigmoid before use.
     *
     * @param aModelName model name without the ad type suffix; may be {@code null}
     * @return {@code true} only for "is_spend_nurl" and "has_good_bids"
     */
    public static boolean isClassificationWithLogits(String aModelName) {
        if (aModelName == null) {
            return false;
        }
        return switch (aModelName) {
            case "is_spend_nurl", "has_good_bids" -> true;
            default -> false;
        };
    }

    /**
     * Predicts a single row with the given model and converts the raw output by model type.
     *
     * <p>Regression models ({@link #isRegressionWithLog}) return {@code exp(raw)};
     * classification models ({@link #isClassificationWithLogits}) return the sigmoid of the raw
     * value. The prediction is made before the model type is checked.
     *
     * @param model      CatBoost model to evaluate
     * @param aModelName model name without the ad type suffix, used to pick the conversion
     * @param contValues continuous feature values
     * @param catValues  categorical feature values
     * @return the converted prediction at position (0, 0)
     * @throws CatBoostError            if the model prediction fails
     * @throws IllegalArgumentException if the model name matches neither known type
     */
    public static double predictSingleValueInternal(CatBoostModel model, String aModelName, float[] contValues, String[] catValues) throws CatBoostError {
        double rawScore = model.predict(contValues, catValues).get(0, 0);
        if (FloorPredictor.isRegressionWithLog(aModelName)) {
            return Math.exp(rawScore);
        }
        if (!FloorPredictor.isClassificationWithLogits(aModelName)) {
            throw new IllegalArgumentException(String.format("Model %s is of unknown type (regression/classification), we need to handle this manually", aModelName));
        }
        return 1.0 / (1.0 + Math.exp(-rawScore));
    }

    /**
     * Replaces the ML parameters used by the {@code predictBestFloor} overloads that take no
     * explicit parameters.
     *
     * @param aMLParams new stored parameters
     */
    public void setMLParams(MLParams aMLParams) {
        mlParams = aMLParams;
    }

    /**
     * Replaces the stored ML parameters with ones built from a JSON string.
     *
     * @param aJsonMLParams ML parameters passed to {@code MLParamsBuilder.get}
     */
    public void setMLParams(String aJsonMLParams) {
        MLParams parsedParams = MLParamsBuilder.get(aJsonMLParams);
        this.mlParams = parsedParams;
    }
}

