/*
 * Decompiled with CFR 0.152.
 */
package phase;

import ints.IntArray;
import java.util.Optional;
import java.util.Random;
import phase.FixedPhaseData;
import phase.HmmStateProbs;
import phase.LowFreqPhaseIbs;
import phase.PhaseData;
import phase.Stage2Haps;
import vcf.GT;
import vcf.RefGT;

/**
 * Phases heterozygous genotypes and imputes missing genotypes for one target sample at a
 * time, using HMM state probabilities computed at the stage-1 markers.
 *
 * <p>For each target sample, {@code HmmStateProbs.run} fills per-haplotype state indices
 * and state probabilities at every stage-1 marker. For a marker between stage-1 markers,
 * the state probability is linearly interpolated between the preceding stage-1 marker
 * ({@code fpd.prevStage1Marker(m)}) and the next one, weighted by
 * {@code fpd.prevStage1Wt(m)}. The phased/imputed alleles are written to the
 * {@link Stage2Haps} supplied to the constructor.
 *
 * <p>Instances keep mutable per-sample work arrays and a {@link Random}, which are
 * overwritten by every call to {@link #phase(int)}. A single instance must therefore not
 * be used by several threads concurrently.
 */
public class Stage2Baum {

    /** Fraction of a heterozygous target state's weight given to its low-frequency allele. */
    private static final double LOW_FREQ_ALLELE_WT = 0.55;

    /** Fraction of a heterozygous target state's weight given to the other allele. */
    private static final double OTHER_ALLELE_WT = 0.45;

    /** Per-allele fraction used when neither or both state alleles are low-frequency. */
    private static final double EVEN_ALLELE_WT = 0.5;

    private final FixedPhaseData fpd;
    private final PhaseData phaseData;
    private final HmmStateProbs stateProbs;
    /** Number of HMM states stored for each of the current sample's two haplotypes. */
    private final int[] nStates = new int[2];
    /** states[hapBit][stage1Marker][j] = haplotype index of state j. */
    private final int[][][] states;
    /** probs[hapBit][stage1Marker][j] = probability of state j. */
    private final float[][][] probs;
    private final GT unphTargGT;
    private final Optional<RefGT> refGT;
    private final int nTargHaps;
    private final int nStage1Markers;
    private final Stage2Haps stage2Haps;
    /** Maps a stage-1 marker index to its index among all target markers. */
    private final IntArray stage1To2;
    private final Random rand;

    /**
     * Creates a stage-2 phaser.
     *
     * @param lowFreqPhaseIbs IBS data used to build the HMM state probabilities
     * @param stage2Haps destination for the phased and imputed target alleles
     */
    public Stage2Baum(LowFreqPhaseIbs lowFreqPhaseIbs, Stage2Haps stage2Haps) {
        this.phaseData = lowFreqPhaseIbs.phaseData();
        this.fpd = this.phaseData.fpd();
        this.nStage1Markers = this.fpd.stage1TargGT().nMarkers();
        this.stateProbs = new HmmStateProbs(lowFreqPhaseIbs);
        int maxStates = this.stateProbs.maxStates();
        this.states = new int[2][this.nStage1Markers][maxStates];
        this.probs = new float[2][this.nStage1Markers][maxStates];
        this.unphTargGT = this.fpd.targGT();
        this.refGT = this.fpd.refGT();
        this.nTargHaps = this.unphTargGT.nHaps();
        this.stage2Haps = stage2Haps;
        this.stage1To2 = this.fpd.stage1To2();
        this.rand = new Random(this.phaseData.seed());
    }

    /**
     * Returns the number of target samples.
     *
     * @return the number of target samples
     */
    public int nTargSamples() {
        return this.unphTargGT.nSamples();
    }

    /**
     * Phases the heterozygous genotypes and imputes the missing genotypes of one target
     * sample at every target marker that is not a stage-1 marker, and stores the result in
     * the {@code Stage2Haps} object.
     *
     * <p>The random generator is reseeded from the phasing seed and the sample index, so the
     * result for a sample does not depend on which samples were processed before it.
     *
     * @param sample a target sample index
     */
    public void phase(int sample) {
        // Explicit (long) keeps the sum a long regardless of the type seed() returns.
        this.rand.setSeed(this.phaseData.seed() + (long) sample);
        int hap1 = sample << 1;
        int hap2 = hap1 | 1;
        this.nStates[0] = this.stateProbs.run(hap1, this.states[0], this.probs[0]);
        this.nStates[1] = this.stateProbs.run(hap2, this.states[1], this.probs[1]);

        // Process the gaps between consecutive stage-1 markers. The stage-1 markers
        // themselves are skipped (presumably already phased in stage 1).
        int start = 0;
        for (int j = 0; j < this.nStage1Markers; ++j) {
            int stage1Marker = this.stage1To2.get(j);
            this.imputeInterval(sample, start, stage1Marker);
            start = stage1Marker + 1;
        }
        this.imputeInterval(sample, start, this.unphTargGT.nMarkers());
    }

    /**
     * Phases or imputes the sample's genotype at each marker in {@code [start, end)}.
     * If either allele is missing (negative), both alleles are imputed independently and
     * any non-missing allele is discarded. Heterozygotes are ordered by comparing the two
     * possible phasings.
     */
    private void imputeInterval(int sample, int start, int end) {
        for (int m = start; m < end; ++m) {
            int a1 = this.unphTargGT.allele1(m, sample);
            int a2 = this.unphTargGT.allele2(m, sample);
            if (a1 < 0 || a2 < 0) {
                a1 = this.imputeAllele(m, 0);
                a2 = this.imputeAllele(m, 1);
            } else if (a1 != a2 && this.isSwapped(m, a1, a2)) {
                int tmp = a1;
                a1 = a2;
                a2 = tmp;
            }
            this.stage2Haps.setPhasedGT(m, sample, a1, a2);
        }
    }

    /**
     * Returns {@code true} if heterozygous alleles {@code a1, a2} should be placed on the
     * second and first haplotype respectively. Exact ties are broken with {@code rand}.
     */
    private boolean isSwapped(int m, int a1, int a2) {
        float[] hap1Probs = this.unscaledAlProbs(m, 0, a1, a2);
        float[] hap2Probs = this.unscaledAlProbs(m, 1, a1, a2);
        float asGiven = hap1Probs[a1] * hap2Probs[a2];
        float reversed = hap1Probs[a2] * hap2Probs[a1];
        // Short-circuit keeps rand consumption limited to exact ties (preserves RNG stream).
        return asGiven < reversed || (asGiven == reversed && this.rand.nextBoolean());
    }

    /**
     * Returns unnormalized allele weights at marker {@code m} for haplotype {@code hapBit}
     * of the current sample.
     *
     * <p>For each HMM state, the state haplotype and its paired haplotype ({@code hap ^ 1})
     * are examined. States with a missing allele are skipped. If both alleles are equal,
     * the interpolated state probability is added to that allele. Otherwise the probability
     * is credited to {@code a1} or {@code a2} only when exactly one of them is low-frequency
     * and present in the state's allele pair.
     *
     * @param m a target marker index
     * @param hapBit 0 or 1, selecting the sample's first or second haplotype
     * @param a1 the sample's first genotype allele
     * @param a2 the sample's second genotype allele
     * @return an array of length {@code nAlleles} with the accumulated weights
     */
    private float[] unscaledAlProbs(int m, int hapBit, int a1, int a2) {
        float[] alProbs = new float[this.unphTargGT.marker(m).nAlleles()];
        boolean lowFreq1 = this.fpd.isLowFreq(m, a1);
        boolean lowFreq2 = this.fpd.isLowFreq(m, a2);
        int mkrA = this.fpd.prevStage1Marker(m);
        int mkrB = Math.min(mkrA + 1, this.nStage1Markers - 1);
        int[] stateHaps = this.states[hapBit][mkrA];
        float[] probsA = this.probs[hapBit][mkrA];
        float[] probsB = this.probs[hapBit][mkrB];
        // Interpolation weights are constant for marker m; hoisted out of the state loop.
        float wtA = this.fpd.prevStage1Wt(m);
        float wtB = 1.0f - wtA;
        int nHapStates = this.nStates[hapBit];
        // NOTE(review): unlike imputeAllele, reference-panel states (hap >= nTargHaps) are
        // not special-cased here; confirm this asymmetry is intended.
        for (int j = 0; j < nHapStates; ++j) {
            int hap = stateHaps[j];
            int b1 = this.allele(m, hap);
            int b2 = this.allele(m, hap ^ 1);
            if (b1 < 0 || b2 < 0) {
                continue;
            }
            float p = wtA * probsA[j] + wtB * probsB[j];
            if (b1 == b2) {
                alProbs[b1] += p;
            } else {
                boolean match1 = lowFreq1 && (a1 == b1 || a1 == b2);
                boolean match2 = lowFreq2 && (a2 == b1 || a2 == b2);
                if (match1 != match2) {
                    alProbs[match1 ? a1 : a2] += p;
                }
            }
        }
        return alProbs;
    }

    /**
     * Imputes the allele carried by haplotype {@code hapBit} of the current sample at
     * marker {@code m}. Each state's interpolated probability is distributed over the
     * alleles of the state haplotype and its paired haplotype, and the allele with the
     * largest total is returned.
     *
     * @param m a target marker index
     * @param hapBit 0 or 1, selecting the sample's first or second haplotype
     * @return the imputed allele (the lowest index among equal maxima)
     */
    private int imputeAllele(int m, int hapBit) {
        float[] alProbs = new float[this.unphTargGT.marker(m).nAlleles()];
        int mkrA = this.fpd.prevStage1Marker(m);
        int mkrB = Math.min(mkrA + 1, this.nStage1Markers - 1);
        int[] stateHaps = this.states[hapBit][mkrA];
        float[] probsA = this.probs[hapBit][mkrA];
        float[] probsB = this.probs[hapBit][mkrB];
        // Interpolation weights are constant for marker m; hoisted out of the state loop.
        float wtA = this.fpd.prevStage1Wt(m);
        float wtB = 1.0f - wtA;
        int nHapStates = this.nStates[hapBit];
        for (int j = 0; j < nHapStates; ++j) {
            int hap = stateHaps[j];
            int b1 = this.allele(m, hap);
            int b2 = this.allele(m, hap ^ 1);
            if (b1 < 0 || b2 < 0) {
                continue;
            }
            float p = wtA * probsA[j] + wtB * probsB[j];
            if (b1 == b2 || hap >= this.nTargHaps) {
                // Homozygous state, or a reference-panel haplotype (presumably phased):
                // credit the state's own allele only.
                alProbs[b1] += p;
                continue;
            }
            // Heterozygous target state: the target genotype is unphased here, so the
            // weight is split between both alleles, favouring a low-frequency allele.
            // Compound assignment with double weights matches the original float/double
            // rounding exactly.
            boolean lowFreq1 = this.fpd.isLowFreq(m, b1);
            boolean lowFreq2 = this.fpd.isLowFreq(m, b2);
            if (lowFreq1 == lowFreq2) {
                alProbs[b1] += EVEN_ALLELE_WT * p;
                alProbs[b2] += EVEN_ALLELE_WT * p;
            } else if (lowFreq1) {
                alProbs[b1] += LOW_FREQ_ALLELE_WT * p;
                alProbs[b2] += OTHER_ALLELE_WT * p;
            } else {
                alProbs[b1] += OTHER_ALLELE_WT * p;
                alProbs[b2] += LOW_FREQ_ALLELE_WT * p;
            }
        }
        return maxIndex(alProbs);
    }

    /**
     * Returns the allele carried by haplotype {@code hap} at marker {@code m}. Indices below
     * {@code nTargHaps} refer to target haplotypes. Larger indices refer to reference
     * haplotypes, offset by {@code nTargHaps}.
     *
     * @throws IllegalStateException if {@code hap} refers to a reference haplotype but no
     *     reference genotypes are present
     */
    private int allele(int m, int hap) {
        if (hap < this.nTargHaps) {
            return this.unphTargGT.allele(m, hap);
        }
        // Non-capturing supplier: no allocation on this hot path.
        RefGT ref = this.refGT.orElseThrow(() -> new IllegalStateException(
                "reference haplotype index used but no reference genotypes are present"));
        return ref.allele(m, hap - this.nTargHaps);
    }

    /**
     * Returns the index of the largest element, preferring the lowest index among ties.
     * Returns 0 for an empty array.
     */
    private static int maxIndex(float[] values) {
        int best = 0;
        for (int j = 1; j < values.length; ++j) {
            if (values[j] > values[best]) {
                best = j;
            }
        }
        return best;
    }
}

