package models;

import Assisted_Classes.MyException;
import java.util.HashSet;

/* loaded from: input_file:models/Multinomial_Logistic_Regression.class */
/**
 * Multinomial (polytomous) logistic regression assembled from binary {@code Logistic_Regression}
 * fits.
 *
 * <p>The distinct target labels are exposed by {@link #get_distinct_values()}. Index 0 is the
 * reference category: the label chosen with {@link #setKvalue(String)}, or otherwise whichever
 * label the {@code HashSet} iteration yields first (not necessarily the first label in the input).
 *
 * <p>For every other category k, one binary model is fitted on the observations labelled k
 * (coded 1) or reference (coded 0). Its coefficients therefore estimate
 * log(P(k | x) / P(reference | x)), which is the Begg and Gray "individualized regressions"
 * approximation of the multinomial logit. These coefficients are combined with a softmax in
 * {@link #getprobabilites(double[][])}.
 *
 * <p>Errors are reported by printing the stack trace of a {@code MyException}. The method then
 * returns, or returns {@code null}. No exception is propagated to the caller.
 */
public class Multinomial_Logistic_Regression {
    /** Wald statistics: row k-1 belongs to category Targets[k]; column 0 is the intercept. */
    private double[][] wald;
    /** P-values of the Wald statistics, same layout as {@code wald}. */
    private double[][] wald_pvalue;
    /** Coefficients: row k-1 belongs to category Targets[k]; column 0 is the intercept. */
    private double[][] Betas;
    /** Distinct target labels of the last successful fit; index 0 is the reference category. */
    private String[] Targets;
    /** Sum of the getMAXIMUMlikelihood() values reported by the binary sub-models. */
    private double MAXIMUMlikelihood;
    /** Label requested as the reference category, or null to keep the default. */
    private String K = null;
    /** Number of distinct target labels (Targets.length). */
    private int Targetsize = 0;
    /** Number of predictor columns (excluding the intercept) used in the last successful fit. */
    private int predictors = 0;

    /**
     * Selects the label to use as the reference category on the next call to {@code regression}.
     * A label that does not occur among the targets is silently ignored.
     *
     * @param str the reference label, or {@code null} for the default choice
     */
    public void setKvalue(String str) {
        this.K = str;
    }

    /**
     * Fits the model.
     *
     * <p>On invalid input, an error is reported and any previously fitted state is left
     * untouched. Fields are only replaced after every binary sub-model has been fitted.
     *
     * @param dArr predictor matrix, one row per observation. Every row must have the same number
     *     of columns. An intercept column is added internally.
     * @param strArr target label of each observation. Needs at least 3 distinct, non-null labels.
     * @param d value forwarded to {@code Logistic_Regression.regression} (presumably its
     *     convergence threshold; confirm). Values {@code <= 0} (and NaN) are replaced by 1e-4.
     * @param i value forwarded to {@code Logistic_Regression.regression} (presumably the maximum
     *     number of iterations; confirm). Values {@code < 1} are replaced by 1.
     */
    public void regression(double[][] dArr, String[] strArr, double d, int i) {
        if (dArr == null || strArr == null || dArr.length == 0) {
            reportError("Your matrix and Target arrays must not be null or empty.");
            return;
        }
        if (dArr.length != strArr.length) {
            reportError("Your matrix and Target arrays need to have the same length.");
            return;
        }
        if (dArr[0] == null || !allRowsHaveLength(dArr, dArr[0].length)) {
            reportError("All rows of your matrix need to be non-null and have the same number of columns.");
            return;
        }

        HashSet<String> distinct = new HashSet<String>();
        for (String label : strArr) {
            if (label == null) {
                reportError("Your Target array must not contain null values.");
                return;
            }
            distinct.add(label);
        }
        String[] categories = distinct.toArray(new String[distinct.size()]);
        if (categories.length < 3) {
            reportError("Your Target Variable needs to have more than 2 distinct categories, else you may use a simple Binary Logistic model");
            return;
        }
        moveReferenceToFront(categories, this.K);

        double threshold = d > 0.0d ? d : 1.0E-4d;
        int iterations = Math.max(i, 1);
        int columns = dArr[0].length;
        double[][] design = withInterceptColumn(dArr);

        int models = categories.length - 1;
        double[][] betas = new double[models][columns + 1];
        double[][] walds = new double[models][columns + 1];
        double[][] waldPValues = new double[models][columns + 1];
        double likelihood = 0.0d;
        for (int k = 1; k < categories.length; k++) {
            Logistic_Regression binary =
                    fitAgainstReference(design, strArr, categories[k], categories[0], threshold, iterations);
            double[] b = binary.getbetas();
            double[] w = binary.getWald();
            double[] p = binary.getWald_P_Values();
            for (int c = 0; c < b.length; c++) {
                betas[k - 1][c] = b[c];
                walds[k - 1][c] = w[c];
                waldPValues[k - 1][c] = p[c];
            }
            likelihood += binary.getMAXIMUMlikelihood();
        }

        // Commit all state together. The likelihood is recomputed from scratch on every fit
        // instead of accumulating across calls.
        this.Targets = categories;
        this.Targetsize = categories.length;
        this.predictors = columns;
        this.Betas = betas;
        this.wald = walds;
        this.wald_pvalue = waldPValues;
        this.MAXIMUMlikelihood = likelihood;
    }

    /**
     * Fits the model with the default settings ({@code d = 1e-4}, {@code i = 20}).
     *
     * @param dArr predictor matrix, one row per observation
     * @param strArr target label of each observation
     */
    public void regression(double[][] dArr, String[] strArr) {
        regression(dArr, strArr, 1.0E-4d, 20);
    }

    /**
     * Computes class probabilities for new observations.
     *
     * @param dArr observations with the same number of columns as the fitted matrix
     * @return one row per observation. Column j is the probability of
     *     {@code get_distinct_values()[j]}, and column 0 is the reference category. Returns an
     *     empty array for empty input, or {@code null} (after reporting an error) if the model is
     *     not fitted or the input is invalid.
     */
    public double[][] getprobabilites(double[][] dArr) {
        if (!isReadyFor(dArr)) {
            return null;
        }
        double[][] probabilities = new double[dArr.length][];
        for (int r = 0; r < dArr.length; r++) {
            double[] z = logits(dArr[r]);
            // Subtract the largest logit before exponentiating so large logits do not overflow
            // to Infinity (which would produce NaN probabilities). The softmax is unchanged.
            double max = z[0];
            for (double value : z) {
                max = Math.max(max, value);
            }
            double sum = 0.0d;
            for (int c = 0; c < z.length; c++) {
                z[c] = Math.exp(z[c] - max);
                sum += z[c];
            }
            for (int c = 0; c < z.length; c++) {
                z[c] /= sum;
            }
            probabilities[r] = z;
        }
        return probabilities;
    }

    /**
     * Assigns each new observation to the category with the largest logit, which is the most
     * probable category. Ties go to the lowest index, so the reference category wins ties.
     *
     * @param dArr observations with the same number of columns as the fitted matrix
     * @return the predicted label per observation. Returns an empty array for empty input, or
     *     {@code null} (after reporting an error) if the model is not fitted or the input is
     *     invalid.
     */
    public String[] getclassification(double[][] dArr) {
        if (!isReadyFor(dArr)) {
            return null;
        }
        String[] labels = new String[dArr.length];
        for (int r = 0; r < dArr.length; r++) {
            // Comparing logits instead of exp(logit) avoids overflow when logits are large.
            double[] z = logits(dArr[r]);
            int best = 0;
            for (int c = 1; c < z.length; c++) {
                if (z[c] > z[best]) {
                    best = c;
                }
            }
            labels[r] = this.Targets[best];
        }
        return labels;
    }

    /** Returns the coefficient matrix (internal array, not a copy); null before a successful fit. */
    public double[][] getbetas() {
        return this.Betas;
    }

    /** Returns the Wald statistics (internal array, not a copy); null before a successful fit. */
    public double[][] getwalds() {
        return this.wald;
    }

    /** Returns the Wald p-values (internal array, not a copy); null before a successful fit. */
    public double[][] getwaldpvalues() {
        return this.wald_pvalue;
    }

    /**
     * Returns the sum of the getMAXIMUMlikelihood() values of the binary sub-models from the last
     * successful fit. This is not, in general, the likelihood of the joint multinomial model.
     */
    public double getmaximumlikelihood() {
        return this.MAXIMUMlikelihood;
    }

    /** Returns the distinct target labels (internal array); index 0 is the reference category. */
    public String[] get_distinct_values() {
        return this.Targets;
    }

    /** Reports an error the same way throughout the class: print a MyException stack trace. */
    private static void reportError(String message) {
        new MyException(message).printStackTrace();
    }

    /** Returns true if every row is non-null and has exactly {@code length} entries. */
    private static boolean allRowsHaveLength(double[][] rows, int length) {
        for (double[] row : rows) {
            if (row == null || row.length != length) {
                return false;
            }
        }
        return true;
    }

    /** Swaps {@code reference} to index 0 if it is present; otherwise leaves the array unchanged. */
    private static void moveReferenceToFront(String[] categories, String reference) {
        if (reference == null) {
            return;
        }
        for (int idx = 1; idx < categories.length; idx++) {
            if (categories[idx].equals(reference)) {
                String previous = categories[0];
                categories[0] = categories[idx];
                categories[idx] = previous;
                return;
            }
        }
    }

    /** Returns a copy of {@code rows} with a leading column of 1.0 for the intercept. */
    private static double[][] withInterceptColumn(double[][] rows) {
        double[][] design = new double[rows.length][];
        for (int r = 0; r < rows.length; r++) {
            double[] row = new double[rows[r].length + 1];
            row[0] = 1.0d;
            System.arraycopy(rows[r], 0, row, 1, rows[r].length);
            design[r] = row;
        }
        return design;
    }

    /**
     * Fits one binary sub-model on the observations labelled {@code category} (coded 1) or
     * {@code reference} (coded 0). Observations of all other categories are left out, so the
     * coefficients estimate the log-odds of {@code category} relative to {@code reference}.
     */
    private static Logistic_Regression fitAgainstReference(double[][] design, String[] labels,
            String category, String reference, double threshold, int iterations) {
        int count = 0;
        for (String label : labels) {
            if (label.equals(category) || label.equals(reference)) {
                count++;
            }
        }
        double[][] x = new double[count][];
        double[] y = new double[count];
        int next = 0;
        for (int row = 0; row < labels.length; row++) {
            boolean inCategory = labels[row].equals(category);
            if (inCategory || labels[row].equals(reference)) {
                // Copy the row so that a sub-model modifying its input cannot affect the other
                // sub-models.
                x[next] = design[row].clone();
                y[next] = inCategory ? 1.0d : 0.0d;
                next++;
            }
        }
        Logistic_Regression model = new Logistic_Regression();
        model.regression(x, y, false, threshold, iterations);
        return model;
    }

    /**
     * Returns the linear predictors of one observation. Entry 0 is the reference category
     * (always 0) and entry k+1 corresponds to {@code Betas[k]}.
     */
    private double[] logits(double[] row) {
        double[] z = new double[this.Targetsize];
        for (int k = 0; k < this.Betas.length; k++) {
            double[] beta = this.Betas[k];
            double sum = beta[0];
            for (int c = 0; c < this.predictors; c++) {
                sum += row[c] * beta[c + 1];
            }
            z[k + 1] = sum;
        }
        return z;
    }

    /** Checks that a model has been fitted and that {@code rows} matches its column count. */
    private boolean isReadyFor(double[][] rows) {
        if (this.Betas == null || this.Targets == null || this.Targets.length <= 2) {
            reportError("The regression method needs to be run successfully in order to create the logic before attempting classifying a new set");
            return false;
        }
        if (rows == null) {
            reportError("The given array must not be null.");
            return false;
        }
        if (!allRowsHaveLength(rows, this.predictors)) {
            reportError("The number of columns in the given array needs to be the same with the number of columns of the principal array used in the regression method");
            return false;
        }
        return true;
    }
}
