/*
 * Decompiled with CFR 0.152.
 */
package com.gkano.bioinfo.vcf;

import com.gkano.bioinfo.var.Logger;
import com.gkano.bioinfo.vcf.SNPEncoder;
import com.gkano.bioinfo.vcf.VCFDecoder;
import com.gkano.bioinfo.vcf.VCFStreamingIterator;
import com.gkano.bioinfo.vcf.VariantEmbeddingLoader;
import com.gkano.bioinfo.vcf.VariantKeyExtractor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class VCFManager
implements Runnable {
    private final List<String> POISON_PILL_BATCH = Collections.emptyList();
    private final int batchSize = 1000;
    private final List<String> inputFileNames;
    private int usingThreads;
    private final int maxSizeOfVariantCache;
    private final Function<String, String> variantParser;
    private boolean verbose = false;
    private BlockingQueue<List<String>> variantRawCache;
    private Map<String, int[]> genotypeEncodingCache;
    private List<String> commentData;
    private String headerData;
    private AtomicInteger currVariantCount;
    private AtomicInteger skippedVariantCount;
    private int numVariants;
    private int numSamples;
    private List<String> sampleNames;
    private int ploidy;
    private int maxAlleles;
    private CountDownLatch startSignal;
    private CountDownLatch doneSignal;
    private ExecutorService pool;
    private List<CompletableFuture<ProcessorResult>> variantProcessors;
    private List<CompletableFuture<EmbeddingProcessorResult>> embeddingProcessors;
    private int numBootstraps = 0;
    private boolean embeddingMode = false;
    private Map<String, double[]> variantEmbeddings;
    private int embeddingDim = -1;
    private VariantKeyExtractor keyExtractor;

    public VCFManager(List<String> inputFileNames, int usingThreads, Function<String, String> variantParser, boolean verbose) {
        this.inputFileNames = inputFileNames;
        this.usingThreads = usingThreads;
        this.variantParser = variantParser;
        this.verbose = verbose;
        this.maxSizeOfVariantCache = 4 * usingThreads;
    }

    public void setEmbeddings(Map<String, double[]> embeddings, VariantKeyExtractor.KeyFormat keyFormat) {
        if (embeddings == null || embeddings.isEmpty()) {
            this.embeddingMode = false;
            this.variantEmbeddings = null;
            this.embeddingDim = -1;
            this.keyExtractor = null;
            return;
        }
        this.embeddingMode = true;
        this.variantEmbeddings = embeddings;
        this.embeddingDim = VariantEmbeddingLoader.getEmbeddingDimension(embeddings);
        this.keyExtractor = new VariantKeyExtractor(keyFormat);
        Logger.info(this, "Embedding mode enabled: " + embeddings.size() + " embeddings with dimension " + this.embeddingDim);
    }

    public boolean isEmbeddingMode() {
        return this.embeddingMode;
    }

    public int getEmbeddingDim() {
        return this.embeddingDim;
    }

    public void init() {
        Logger.setVerbose(this.verbose);
        if (this.inputFileNames == null || this.inputFileNames.isEmpty()) {
            Logger.error(this, "No VCF input files provided.");
            throw new IllegalArgumentException("No VCF input files provided.");
        }
        this.variantRawCache = new LinkedBlockingQueue<List<String>>(this.maxSizeOfVariantCache);
        this.genotypeEncodingCache = new ConcurrentHashMap<String, int[]>();
        this.commentData = new ArrayList<String>();
        this.currVariantCount = new AtomicInteger(0);
        this.skippedVariantCount = new AtomicInteger(0);
        int cpus = Runtime.getRuntime().availableProcessors();
        this.usingThreads = cpus < this.usingThreads ? cpus : this.usingThreads;
        Logger.info(this, "cpus=" + cpus + "\tusing=" + this.usingThreads);
        this.startSignal = new CountDownLatch(1);
        this.doneSignal = new CountDownLatch(1);
        this.pool = Executors.newFixedThreadPool(this.usingThreads);
        if (this.embeddingMode) {
            this.initEmbeddingProcessors();
        } else {
            this.initStandardProcessors();
        }
    }

    private void initStandardProcessors() {
        this.variantProcessors = new ArrayList<CompletableFuture<ProcessorResult>>();
        for (int t = 0; t < this.usingThreads; ++t) {
            CompletableFuture<ProcessorResult> variantProcessor = CompletableFuture.supplyAsync(() -> {
                /*
                 * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
                 * 
                 * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [6[DOLOOP]], but top level block is 9[UNCONDITIONALDOLOOP]
                 *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
                 *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
                 *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
                 *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
                 *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
                 *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
                 *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
                 *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
                 *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1050)
                 *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
                 *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
                 *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
                 *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
                 *     at org.benf.cfr.reader.Main.main(Main.java:54)
                 */
                throw new IllegalStateException("Decompilation failed");
            }, this.pool);
            this.variantProcessors.add(variantProcessor);
        }
    }

    private void initEmbeddingProcessors() {
        this.embeddingProcessors = new ArrayList<CompletableFuture<EmbeddingProcessorResult>>();
        for (int t = 0; t < this.usingThreads; ++t) {
            CompletableFuture<EmbeddingProcessorResult> embeddingProcessor = CompletableFuture.supplyAsync(() -> {
                /*
                 * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
                 * 
                 * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [6[DOLOOP]], but top level block is 9[UNCONDITIONALDOLOOP]
                 *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
                 *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
                 *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
                 *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
                 *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
                 *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
                 *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
                 *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
                 *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1050)
                 *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
                 *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
                 *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
                 *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
                 *     at org.benf.cfr.reader.Main.main(Main.java:54)
                 */
                throw new IllegalStateException("Decompilation failed");
            }, this.pool);
            this.embeddingProcessors.add(embeddingProcessor);
        }
    }

    public void awaitFinalization() throws InterruptedException {
        this.doneSignal.await();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void run() {
        try {
            Logger.info(this, "START READ" + (this.embeddingMode ? " (embedding mode)" : ""));
            VCFDecoder decoder = new VCFDecoder();
            ArrayList<String> batch = new ArrayList<String>(2500);
            try (VCFStreamingIterator iterator = new VCFStreamingIterator(decoder, this.verbose, this.inputFileNames);){
                for (String line : iterator) {
                    if (line == null) continue;
                    this.processVariantLine(line, batch);
                }
            }
            if (!batch.isEmpty()) {
                this.variantRawCache.put(batch);
            }
            if (this.ploidy <= 0) {
                this.startSignal.countDown();
            }
            for (int i = 0; i < this.usingThreads; ++i) {
                this.variantRawCache.put(this.POISON_PILL_BATCH);
            }
            Logger.info(this, "END READ");
            this.pool.shutdown();
            this.pool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
            this.numVariants = this.currVariantCount.get();
            Logger.info(this, "Processed variants :\t" + this.numVariants);
            if (this.embeddingMode) {
                int skipped = this.skippedVariantCount.get();
                int used = this.numVariants - skipped;
                Logger.info(this, "Variants with embeddings:\t" + used + " (skipped " + skipped + " without embeddings)");
            }
            this.doneSignal.countDown();
        }
        catch (Exception e) {
            Logger.error(this, e.getMessage());
            this.shutdown();
            this.doneSignal.countDown();
            Thread.currentThread().interrupt();
        }
        finally {
            this.shutdown();
        }
    }

    private void processVariantLine(String line, List<String> batch) throws Exception {
        block10: {
            try {
                if (line.startsWith("##")) {
                    this.commentData.add(line);
                    break block10;
                }
                if (line.startsWith("#")) {
                    this.headerData = line;
                    this.sampleNames = this.getSampleNamesFromHeader();
                    if (this.sampleNames == null || this.sampleNames.size() <= 0) {
                        throw new IllegalStateException("No samples detected from #CHROM header.");
                    }
                    this.numSamples = this.sampleNames.size();
                    break block10;
                }
                if (this.ploidy <= 0) {
                    try {
                        int[] ploidy_maxAlleles = SNPEncoder.guessPloidyAndMaxAllele(line);
                        this.ploidy = ploidy_maxAlleles[0];
                        this.maxAlleles = ploidy_maxAlleles[1];
                        if (this.ploidy > 0) {
                            this.startSignal.countDown();
                        }
                    }
                    catch (IllegalArgumentException e) {
                        throw new IllegalStateException("Failed to infer ploidy / maxAlleles: " + e.getMessage(), e);
                    }
                }
                String parsed = this.variantParser.apply(line);
                batch.add(parsed);
                if (batch.size() >= 1000) {
                    this.variantRawCache.put(new ArrayList<String>(batch));
                    batch.clear();
                }
            }
            catch (IllegalArgumentException | IllegalStateException | InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    }

    public double[][] reduceDotProdToDistances() {
        if (this.embeddingMode) {
            return this.reduceEmbeddingsToDistances();
        }
        try {
            int j;
            long[][] finalDotProd = new long[this.numSamples][this.numSamples];
            long[] finalNorm = new long[this.numSamples];
            for (CompletableFuture<ProcessorResult> vp : this.variantProcessors) {
                ProcessorResult r = vp.join();
                for (int i = 0; i < this.numSamples; ++i) {
                    int n = i;
                    finalNorm[n] = finalNorm[n] + r.norm[0][i];
                    for (j = i; j < this.numSamples; ++j) {
                        long[] lArray = finalDotProd[i];
                        int n2 = j;
                        lArray[n2] = lArray[n2] + r.dotProd[0][i][j];
                    }
                }
            }
            double[][] cosineDist = new double[this.numSamples][this.numSamples];
            for (int i = 0; i < this.numSamples; ++i) {
                double normI = Math.sqrt(finalNorm[i]);
                for (j = i; j < this.numSamples; ++j) {
                    double dist;
                    double similarity;
                    double normJ = Math.sqrt(finalNorm[j]);
                    double dot = finalDotProd[i][j];
                    double d = similarity = normI > 0.0 && normJ > 0.0 ? dot / (normI * normJ) : 0.0;
                    if (j == i && normI == 0.0 && normJ == 0.0) {
                        similarity = 1.0;
                    }
                    if ((dist = 1.0 - similarity) < 0.0) {
                        dist = 0.0;
                    }
                    double d2 = dist;
                    cosineDist[j][i] = d2;
                    cosineDist[i][j] = d2;
                }
            }
            return cosineDist;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public double[][] reduceEmbeddingsToDistances() {
        try {
            double[][] finalEmbeddings = new double[this.numSamples][this.embeddingDim];
            for (CompletableFuture<EmbeddingProcessorResult> ep : this.embeddingProcessors) {
                EmbeddingProcessorResult r = ep.join();
                for (int i = 0; i < this.numSamples; ++i) {
                    for (int d = 0; d < this.embeddingDim; ++d) {
                        double[] dArray = finalEmbeddings[i];
                        int n = d;
                        dArray[n] = dArray[n] + r.sampleEmbeddings[0][i][d];
                    }
                }
            }
            return this.computeCosineDistancesFromEmbeddings(finalEmbeddings);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public List<double[][]> reduceDotProdToDistancesBootstraps() {
        if (this.embeddingMode) {
            return this.reduceEmbeddingsToDistancesBootstraps();
        }
        try {
            int numReplicates = this.numBootstraps > 0 ? this.numBootstraps + 1 : 1;
            long[][][] finalDotProd = new long[numReplicates][this.numSamples][this.numSamples];
            long[][] finalNorm = new long[numReplicates][this.numSamples];
            for (CompletableFuture<ProcessorResult> vp : this.variantProcessors) {
                ProcessorResult r = vp.join();
                for (int rep = 0; rep < numReplicates; ++rep) {
                    for (int i = 0; i < this.numSamples; ++i) {
                        long[] lArray = finalNorm[rep];
                        int n = i;
                        lArray[n] = lArray[n] + r.norm[rep][i];
                        for (int j = i; j < this.numSamples; ++j) {
                            long[] lArray2 = finalDotProd[rep][i];
                            int n2 = j;
                            lArray2[n2] = lArray2[n2] + r.dotProd[rep][i][j];
                        }
                    }
                }
            }
            ArrayList<double[][]> allDistances = new ArrayList<double[][]>();
            for (int rep = 0; rep < numReplicates; ++rep) {
                double[][] cosineDist = new double[this.numSamples][this.numSamples];
                for (int i = 0; i < this.numSamples; ++i) {
                    double normI = Math.sqrt(finalNorm[rep][i]);
                    for (int j = i; j < this.numSamples; ++j) {
                        double dist;
                        double similarity;
                        double normJ = Math.sqrt(finalNorm[rep][j]);
                        double dot = finalDotProd[rep][i][j];
                        double d = similarity = normI > 0.0 && normJ > 0.0 ? dot / (normI * normJ) : 0.0;
                        if (j == i && normI == 0.0 && normJ == 0.0) {
                            similarity = 1.0;
                        }
                        if ((dist = 1.0 - similarity) < 0.0) {
                            dist = 0.0;
                        }
                        double d2 = dist;
                        cosineDist[j][i] = d2;
                        cosineDist[i][j] = d2;
                    }
                }
                allDistances.add(cosineDist);
            }
            return allDistances;
        }
        catch (OutOfMemoryError e) {
            Logger.error(this, "Could not allocate memory for bootstrap distance matrices: use --mem option to increase memory available to JVM.");
            throw new RuntimeException("Could not allocate memory for bootstrap distance matrices: " + e.getMessage(), e);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public List<double[][]> reduceEmbeddingsToDistancesBootstraps() {
        try {
            int numReplicates = this.numBootstraps > 0 ? this.numBootstraps + 1 : 1;
            double[][][] finalEmbeddings = new double[numReplicates][this.numSamples][this.embeddingDim];
            for (CompletableFuture<EmbeddingProcessorResult> ep : this.embeddingProcessors) {
                EmbeddingProcessorResult r = ep.join();
                for (int rep = 0; rep < numReplicates; ++rep) {
                    for (int i = 0; i < this.numSamples; ++i) {
                        for (int d = 0; d < this.embeddingDim; ++d) {
                            double[] dArray = finalEmbeddings[rep][i];
                            int n = d;
                            dArray[n] = dArray[n] + r.sampleEmbeddings[rep][i][d];
                        }
                    }
                }
            }
            ArrayList<double[][]> allDistances = new ArrayList<double[][]>();
            for (int rep = 0; rep < numReplicates; ++rep) {
                double[][] cosineDist = this.computeCosineDistancesFromEmbeddings(finalEmbeddings[rep]);
                allDistances.add(cosineDist);
            }
            return allDistances;
        }
        catch (OutOfMemoryError e) {
            Logger.error(this, "Could not allocate memory for bootstrap distance matrices: use --mem option to increase memory available to JVM.");
            throw new RuntimeException("Could not allocate memory for bootstrap distance matrices: " + e.getMessage(), e);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private double[][] computeCosineDistancesFromEmbeddings(double[][] embeddings) {
        int i;
        int n = embeddings.length;
        double[][] cosineDist = new double[n][n];
        double[] norms = new double[n];
        for (i = 0; i < n; ++i) {
            double sumSq = 0.0;
            for (int d = 0; d < embeddings[i].length; ++d) {
                sumSq += embeddings[i][d] * embeddings[i][d];
            }
            norms[i] = Math.sqrt(sumSq);
        }
        for (i = 0; i < n; ++i) {
            double normI = norms[i];
            for (int j = i; j < n; ++j) {
                double normJ = norms[j];
                double dot = 0.0;
                for (int d = 0; d < embeddings[i].length; ++d) {
                    dot += embeddings[i][d] * embeddings[j][d];
                }
                double similarity = normI > 0.0 && normJ > 0.0 ? dot / (normI * normJ) : (i == j ? 1.0 : 0.0);
                double dist = 1.0 - similarity;
                if (dist < 0.0) {
                    dist = 0.0;
                }
                if (dist > 2.0) {
                    dist = 2.0;
                }
                double d = dist;
                cosineDist[j][i] = d;
                cosineDist[i][j] = d;
            }
        }
        return cosineDist;
    }

    public List<String> getSampleNamesFromHeader() {
        if (this.headerData == null || !this.headerData.startsWith("#CHROM")) {
            throw new IllegalArgumentException("VCF header missing #CHROM line: " + this.headerData);
        }
        String[] fields = this.headerData.split("\t", -1);
        if (fields.length <= 9) {
            return Collections.emptyList();
        }
        return Arrays.asList(Arrays.copyOfRange(fields, 9, fields.length));
    }

    private void shutdown() {
        this.pool.shutdownNow();
    }

    public List<String> getSampleNames() {
        return this.sampleNames;
    }

    public int getNumVariants() {
        return this.numVariants;
    }

    public int getNumSamples() {
        return this.numSamples;
    }

    public int getPloidy() {
        return this.ploidy;
    }

    public int getMaxAlleles() {
        return this.maxAlleles;
    }

    public void setNumBootstraps(int replicates) {
        if (replicates < 0) {
            throw new IllegalArgumentException("bootstrap replicates must be >= 0");
        }
        this.numBootstraps = replicates;
    }

    public int getSkippedVariants() {
        return this.skippedVariantCount != null ? this.skippedVariantCount.get() : 0;
    }

    private static int poisson1(Random rand) {
        double L = Math.exp(-1.0);
        int k = 0;
        double p = 1.0;
        do {
            ++k;
        } while ((p *= rand.nextDouble()) > L);
        return k - 1;
    }

    public static class ProcessorResult {
        long[][][] dotProd;
        long[][] norm;
        int numReplicates;

        private ProcessorResult(int numReplicates, int numSamples) {
            this.numReplicates = numReplicates;
            this.dotProd = new long[numReplicates][numSamples][numSamples];
            this.norm = new long[numReplicates][numSamples];
        }

        public void merge(ProcessorResult other) {
            for (int r = 0; r < this.numReplicates; ++r) {
                for (int i = 0; i < this.norm[r].length; ++i) {
                    long[] lArray = this.norm[r];
                    int n = i;
                    lArray[n] = lArray[n] + other.norm[r][i];
                    for (int j = i; j < this.norm[r].length; ++j) {
                        long[] lArray2 = this.dotProd[r][i];
                        int n2 = j;
                        lArray2[n2] = lArray2[n2] + other.dotProd[r][i][j];
                    }
                }
            }
        }
    }

    public static class EmbeddingProcessorResult {
        double[][][] sampleEmbeddings;
        int numReplicates;
        int embeddingDim;

        private EmbeddingProcessorResult(int numReplicates, int numSamples, int embeddingDim) {
            this.numReplicates = numReplicates;
            this.embeddingDim = embeddingDim;
            this.sampleEmbeddings = new double[numReplicates][numSamples][embeddingDim];
        }

        public void merge(EmbeddingProcessorResult other) {
            for (int r = 0; r < this.numReplicates; ++r) {
                for (int i = 0; i < this.sampleEmbeddings[r].length; ++i) {
                    for (int d = 0; d < this.embeddingDim; ++d) {
                        double[] dArray = this.sampleEmbeddings[r][i];
                        int n = d;
                        dArray[n] = dArray[n] + other.sampleEmbeddings[r][i][d];
                    }
                }
            }
        }
    }
}

