from Bio import AlignIO, SeqIO
from Bio.Seq import Seq
import bioframe as bf
from collections import defaultdict
from gpn.data import (
    filter_defined, filter_length, load_table, add_flank, get_annotation_features,
    add_jitter, get_promoters, get_random_intervals, union_intervals,
    intersect_intervals, intervals_size
)
import gzip
import h5py
import math
import more_itertools
import numpy as np
import os
import pandas as pd
import re
import scipy.sparse as sp_sparse
from scipy.special import softmax
from scipy.stats import entropy
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import torch
from tqdm import tqdm
tqdm.pandas()

from gpn.data import make_windows, get_seq
from gpn.data import load_fasta, save_fasta, load_table, load_repeatmasker, Genome


# Workflow settings are loaded from config.yaml into the global `config` mapping.
configfile: "config.yaml"

# Window size read from the config; used below to pad defined intervals and
# passed on the command line to the gpn.ss tools.
WINDOW_SIZE = config["window_size"]
# Accessions forwarded to gpn.ss.filter_assemblies via --priority_assemblies.
PRIORITY_ASSEMBLIES = [
    "GCF_000001735.4",  # Arabidopsis thaliana
    "GCF_000309985.2",  # Brassica rapa
]
# Dataset split names (only referenced from commented-out targets in `rule all`).
splits = ["train", "validation", "test"]
# Size of the central sub-window annotated in define_embedding_windows and
# passed to gpn.ss.get_embeddings.
EMBEDDING_WINDOW_SIZE = 100
# Chromosome names kept by the rules that filter the genome/annotation.
CHROMS = ["1", "2", "3", "4", "5"]
# Column order used for the per-nucleotide probabilities in make_bed_probs.
NUCLEOTIDES = list("ACGT")
# Model identifiers expanded into the `{model}` wildcard of the targets.
models = [
    "gonzalobenegas/gpn-brassicales",
    #"/scratch/users/gbenegas/checkpoints/GPN_Arabidopsis_multispecies/ConvNet_ss_12k/checkpoint-12000",
]
# Parameter grids for the aragwas LD targets (currently commented out in `rule all`).
all_radius = [
    100_000,
]
all_threshold = [
    0.1,
]

# Both rules write output/embedding/embeddings/<name>.parquet; when a target
# matches both patterns (kmers_<k>), get_kmer_spectrum takes precedence.
ruleorder: get_kmer_spectrum > get_embeddings

# Default target rule: lists the files the workflow should produce. Only the
# uncommented entries are active; the commented lines are alternative targets
# that can be re-enabled as needed.
rule all:
    input:
        expand("output/variants/all/vep_embeddings/{model}.parquet", model=models),
        #expand("output/variants/all/vep/{model}.parquet", model=models),
        #expand("output/embedding/classification/{model}.parquet", model=models + ["kmers_3", "kmers_6"]),
        #expand(f"output/merged_dataset/{config['window_size']}/{config['step_size']}/{config['add_rc']}/balanced/data/{{split}}", split=splits),
        #"output/embedding/embeddings/kmers.parquet",
        #"output/assemblies.tex",
        #expand("output/whole_genome/modisco/{model}/report", model=models),
        #expand("output/whole_genome/subset/promoter/modisco/{model}/report", model=models),
        #expand("output/whole_genome/subset/promoter2/modisco/{model}/report", model=models),
        #expand("output/whole_genome/bed_probs/{model}/{nuc}.bw", model=models, nuc=NUCLEOTIDES),
        #"output/embedding/subset_no_repeats/windows.parquet",
        #expand("output/embedding/subset_no_repeats/umap/{model}.parquet", model=models),
        #expand("output/embedding/subset_no_repeats/leiden_{resolution}_/{model}.parquet", model=models, resolution=[0.3]),
        #expand("output/embedding/umap/{model}.parquet", model=models + ["kmers_5", "kmers_3", "kmers_6", "kmers_4"]),
        #expand("output/funtfbs/logits/{model}.parquet", model=models),
        #expand("output/whole_genome/logits/{model}.parquet", model=models),
        #expand("output/simulated_variants/vep/{model}.parquet", model=models),
        #expand("output/aragwas/LD_{radius}_{threshold}.npz", radius=all_radius, threshold=all_threshold),
        #expand("output/aragwas/LD_chrom_{threshold}.npz", threshold=all_threshold),
        #"output/region/modisco/report",
        #"output/promoter/modisco/report",
        #expand("output/variants/all/{model}.parquet", model=["phyloP", "phastCons"]),
        #"output/variants/all/GPN.parquet",
        #"output/region/modisco/report",
        #expand("output/merged_dataset/balanced/data/{split}", split=SPLITS),
        #"output/dataset/GCF_000001735.4/balanced/train.parquet",
        #expand("output/dataset/{assembly}/balanced/train.parquet", assembly=assemblies.index),
        #expand("output/intervals/{assembly}/balanced.parquet", assembly=assemblies.index),
        #"output/variants/chrom/5/dataset/gpn/128/dataset.parquet",
        #expand("output/dataset/mlm/gpn/{split}.parquet", split=splits),
        #expand("output/intervals_filt/{split}.parquet", split=splits),
        #expand("output/dataset/mlm/gpn/all/512/64/{split}/dataset.parquet", split=splits),
        #"output/variants/all/dataset/gpn/512/dataset.parquet",


rule filter_assemblies:
    # Runs gpn.ss.filter_assemblies on the raw assembly table, keeping one
    # assembly per genus and passing PRIORITY_ASSEMBLIES as priority accessions.
    input:
        config["raw_assemblies_path"],
    output:
        config["assemblies_path"],
    shell:
        """
        python -m gpn.ss.filter_assemblies {input} {output} --keep_one_per_genus \
        --priority_assemblies {PRIORITY_ASSEMBLIES}
        """


# Import the dataset-building workflow from the GPN repository (fetched from
# the `main` branch URL below) and share this workflow's config with it.
module make_dataset_from_ncbi:
    snakefile:
        "https://raw.githubusercontent.com/songlab-cal/gpn/main/workflow/make_dataset_from_ncbi/Snakefile"
    config: config


# Expose every rule of the module, prefixed with `make_dataset_from_ncbi_`.
use rule * from make_dataset_from_ncbi as make_dataset_from_ncbi_*


rule make_assemblies_latex_table:
    # Render the filtered assembly table as a LaTeX tabular: accession and
    # assembly name in \texttt{}, organism name in \textit{}.
    input:
        config["assemblies_path"],
    output:
        "output/assemblies.tex",
    run:
        df = pd.read_csv(
            input[0], sep="\t",
            usecols=["Assembly Accession", "Assembly Name", "Organism Name"],
        )

        def escape_underscores(column):
            # Raw string avoids the invalid "\_" escape sequence; regex=False
            # makes the literal replacement independent of the pandas version.
            return column.str.replace("_", r"\_", regex=False)

        df["Assembly Accession"] = r"\texttt{" + escape_underscores(df["Assembly Accession"]) + "}"
        df["Assembly Name"] = r"\texttt{" + escape_underscores(df["Assembly Name"]) + "}"
        df["Organism Name"] = r"\textit{" + df["Organism Name"] + "}"
        print(df)
        # escape=False: the cells already contain hand-written LaTeX markup.
        df.to_latex(output[0], index=False, escape=False)


rule download_reference:
    # Download the reference FASTA from config[FASTA_URL].
    # NOTE(review): certificate verification is disabled (--no-check-certificate).
    output:
        "output/genome.raw.fa.gz",
    shell:
        "wget --no-check-certificate {config[FASTA_URL]} -O {output}"


# NOTE: in hindsight this renaming step was not necessary; the reference
# could simply have been downloaded from Ensembl, since the
# make_dataset_from_ncbi workflow takes care of creating the training dataset.
rule rename_reference:
    # Rewrite the sequence ids of the raw reference FASTA using a two-column,
    # header-less TSV (old id -> new id) and save the renamed genome.
    input:
        "output/genome.raw.fa.gz",
        "input/id_mapping.tsv",
    output:
        "output/genome.fa.gz",
    run:
        genome = load_fasta(input[0])
        # `read_csv(squeeze=True)` was deprecated in pandas 1.4 and removed in
        # 2.0; squeezing the single remaining column afterwards gives the same
        # Series indexed by the old id.
        id_mapping = pd.read_csv(
            input[1], sep="\t", header=None, index_col=0,
        ).squeeze("columns")
        for chrom in genome.keys():
            genome[chrom].id = id_mapping[genome[chrom].id]
        save_fasta(output[0], genome)


rule download_annotation:
    # Download the gene annotation (GTF) from config[GTF_URL].
    # NOTE(review): certificate verification is disabled (--no-check-certificate).
    output:
        "output/annotation.gtf.gz",
    shell:
        "wget --no-check-certificate {config[GTF_URL]} -O {output}"


rule expand_annotation:
    # Extend the gene annotation with three derived feature types:
    #   "Repeat"     - merged repeat intervals from the repeats table,
    #   "intergenic" - chromosome minus genes, ncRNA genes and repeats,
    #   "intron"     - gaps between consecutive exons of each transcript.
    input:
        "output/annotation.gtf.gz",
        "input/repeats.bed.gz",
    output:
        "output/annotation.expanded.parquet",
    run:
        gtf = load_table(input[0])

        repeats = pd.read_csv(input[1], sep="\t").rename(
            columns=dict(genoName="chrom", genoStart="start", genoEnd="end")
        )[["chrom", "start", "end"]]
        # TODO: figure out if repeatmasker coordinates are 0-based vs. 1-based
        # shouldn't matter much though
        repeats.chrom = repeats.chrom.str.replace("Chr", "")
        repeats = bf.merge(repeats).drop(columns="n_intervals")
        repeats["feature"] = "Repeat"
        gtf = pd.concat([gtf, repeats], ignore_index=True)

        gtf_intergenic = bf.subtract(gtf.query('feature=="chromosome"'), gtf[gtf.feature.isin(["gene", "ncRNA_gene", "Repeat"])])
        gtf_intergenic.feature = "intergenic"
        gtf = pd.concat([gtf, gtf_intergenic], ignore_index=True)

        # .copy(): the new column is added to an independent frame instead of
        # a slice of `gtf` (avoids pandas' SettingWithCopy ambiguity).
        gtf_exon = gtf[gtf.feature=="exon"].copy()
        # The transcript id is taken as the text after the last ":" in the
        # first ";"-separated attribute field.
        gtf_exon["transcript_id"] = gtf_exon.attribute.str.split(";").str[0].str.split(":").str[-1]

        def get_transcript_introns(df_transcript):
            # Introns span from the end of one exon to the start of the next
            # (exons sorted by start).
            df_transcript = df_transcript.sort_values("start")
            exon_pairs = more_itertools.pairwise(df_transcript.loc[:, ["start", "end"]].values)
            introns = [[e1[1], e2[0]] for e1, e2 in exon_pairs]
            introns = pd.DataFrame(introns, columns=["start", "end"])
            introns["chrom"] = df_transcript.chrom.iloc[0]
            return introns

        # Identical introns shared by several transcripts are kept once.
        gtf_introns = gtf_exon.groupby("transcript_id").apply(get_transcript_introns).reset_index().drop_duplicates(subset=["chrom", "start", "end"])
        gtf_introns["feature"] = "intron"
        gtf = pd.concat([gtf, gtf_introns], ignore_index=True)
        gtf.to_parquet(output[0], index=False)


rule define_embedding_windows:
    # Tile the defined part of the genome into WINDOW_SIZE windows (stepping by
    # EMBEDDING_WINDOW_SIZE) and keep those whose central EMBEDDING_WINDOW_SIZE
    # bases are fully covered by exactly one feature of interest. The output
    # holds the full window in start/end, the central part in
    # center_start/center_end and the feature name in "Region".
    input:
        "output/annotation.expanded.parquet",
        "output/genome.fa.gz",
    output:
        "output/embedding/windows.parquet",
    run:
        gtf = pd.read_parquet(input[0])
        genome = Genome(input[1])
        # Use the shared CHROMS constant (same values as the former inline
        # list) so this rule stays consistent with the other rules.
        genome.filter_chroms(CHROMS)
        defined_intervals = genome.get_defined_intervals()
        defined_intervals = filter_length(defined_intervals, WINDOW_SIZE)
        windows = make_windows(defined_intervals, WINDOW_SIZE, EMBEDDING_WINDOW_SIZE)
        windows.rename(columns={"start": "full_start", "end": "full_end"}, inplace=True)

        # bf.coverage below works on start/end, so temporarily make those the
        # central sub-window.
        windows["start"] = (windows.full_start+windows.full_end)//2 - EMBEDDING_WINDOW_SIZE//2
        windows["end"] = windows.start + EMBEDDING_WINDOW_SIZE

        features_of_interest = [
            "intergenic",
            'CDS',
            'intron',
            'three_prime_UTR',
            'five_prime_UTR',
            "ncRNA_gene",
            "Repeat",
        ]

        for f in features_of_interest:
            print(f)
            windows = bf.coverage(windows, gtf[gtf.feature==f])
            windows.rename(columns=dict(coverage=f), inplace=True)

        # Keep windows fully covered by exactly one feature. .copy() so the
        # column edits below act on an independent frame, not a slice.
        fully_covered = (windows[features_of_interest]==EMBEDDING_WINDOW_SIZE).sum(axis=1)==1
        windows = windows[fully_covered].copy()
        windows["Region"] = windows[features_of_interest].idxmax(axis=1)
        windows.drop(columns=features_of_interest, inplace=True)

        windows.rename(columns={"start": "center_start", "end": "center_end"}, inplace=True)
        windows.rename(columns={"full_start": "start", "full_end": "end"}, inplace=True)
        print(windows)
        windows.to_parquet(output[0], index=False)


rule get_embeddings:
    # Runs gpn.ss.get_embeddings for the `{model}` wildcard; `{input}` expands
    # to the windows table followed by the genome FASTA.
    input:
        "output/embedding/windows.parquet",
        "output/genome.fa.gz",
    output:
        "output/embedding/embeddings/{model}.parquet",
    # All cores given to Snakemake are handed to the data loader.
    threads: workflow.cores
    shell:
        """
        python -m gpn.ss.get_embeddings {input} {EMBEDDING_WINDOW_SIZE} \
        {wildcards.model} {output} --per-device-batch-size 4000 --is-file \
        --dataloader-num-workers {threads}
        """


from functools import cache

# source: https://github.com/MindAI/kmer
class kmer_featurization:
    """Turn DNA sequences into k-mer count / frequency vectors.

    Each k-mer is read as a base-4 number using the digit order given by
    ``self.letters`` (A=0, T=1, C=2, G=3); that number is the k-mer's index in
    the feature vector of length ``4**k``.
    """

    def __init__(self, k):
        """
        k: the "k" in k-mer
        """
        self.k = k
        self.letters = ['A', 'T', 'C', 'G']
        self.multiplyBy = 4 ** np.arange(k-1, -1, -1) # the multiplying number for each digit position in the k-number system
        self.n = 4**k # number of possible k-mers
        # Per-instance memo of k-mer -> index. A functools.cache on the method
        # would key on `self` and keep every instance alive in a global cache.
        self._numbering_cache = {}

    def obtain_kmer_feature_for_one_sequence(self, seq, write_number_of_occurrences=False):
        """
        Given a DNA sequence, return its k-mer feature vector (length 4**k).
        Args:
          seq:
            a string, a DNA sequence
          write_number_of_occurrences:
            a boolean. If False, the fraction of k-mers in the sequence equal
            to each k-mer is recorded; otherwise the number of occurrences is
            recorded. Default False.
        Returns:
          a 1-D float numpy array. A sequence shorter than k contains no
          k-mers and yields an all-zero vector.
        Raises:
          ValueError: if the sequence contains a letter not in self.letters.
        """
        number_of_kmers = len(seq) - self.k + 1

        kmer_feature = np.zeros(self.n)

        # No k-mers at all: return zeros instead of dividing by zero (NaN) or
        # by a negative count.
        if number_of_kmers <= 0:
            return kmer_feature

        for i in range(number_of_kmers):
            this_kmer = seq[i:(i+self.k)]
            kmer_feature[self.kmer_numbering_for_one_kmer(this_kmer)] += 1

        if not write_number_of_occurrences:
            kmer_feature = kmer_feature / number_of_kmers

        return kmer_feature

    def kmer_numbering_for_one_kmer(self, kmer):
        """
        Given a k-mer, return its numbering (the 0-based position in the
        feature vector). Results are memoized per instance.
        """
        try:
            return self._numbering_cache[kmer]
        except KeyError:
            pass

        # list.index raises ValueError for letters outside self.letters.
        digits = np.array([self.letters.index(letter) for letter in kmer])
        numbering = (digits * self.multiplyBy).sum()

        self._numbering_cache[kmer] = numbering
        return numbering


rule get_kmer_spectrum:
    # Baseline "embedding": the k-mer frequency vector of the central part of
    # every window, one column per k-mer index.
    input:
        "output/embedding/windows.parquet",
        "output/genome.fa.gz",
    output:
        "output/embedding/embeddings/kmers_{k}.parquet",
    threads: workflow.cores
    run:
        windows = pd.read_parquet(input[0])
        genome = Genome(input[1])
        print(windows)
        featurizer = kmer_featurization(int(wildcards["k"]))

        def window_spectrum(row):
            # Upper-case so soft-masked bases are counted as well.
            center_seq = genome.get_seq(row.chrom, row.center_start, row.center_end).upper()
            return np.array(featurizer.obtain_kmer_feature_for_one_sequence(center_seq))

        kmers = windows.progress_apply(window_spectrum, axis=1, result_type="expand")
        # Column labels are written as strings.
        kmers.columns = kmers.columns.astype(str)
        print(kmers)
        kmers.to_parquet(output[0], index=False)


rule subset_windows:
    # Restrict a windows table to one Region value, or to everything except
    # repeats when the region wildcard is "no_repeats".
    input:
        "{anything}/windows.parquet",
    output:
        "{anything}/subset_{region}/windows.parquet",
    run:
        all_windows = pd.read_parquet(input[0])
        print(all_windows)
        region = wildcards.region
        keep = (
            all_windows.Region != "Repeat"
            if region == "no_repeats"
            else all_windows.Region == region
        )
        selected = all_windows[keep]
        print(selected)
        selected.to_parquet(output[0], index=False)


rule subset_embeddings:
    # Apply the same row selection as subset_windows to the embeddings table:
    # the boolean mask computed on the windows is used to index the embeddings.
    input:
        "{anything}/windows.parquet",
        "{anything}/embeddings/{model}.parquet",
    output:
        "{anything}/subset_{region}/embeddings/{model}.parquet",
    run:
        all_windows = pd.read_parquet(input[0])
        all_embeddings = pd.read_parquet(input[1])
        print(all_windows, all_embeddings)
        region = wildcards.region
        keep = (
            all_windows.Region != "Repeat"
            if region == "no_repeats"
            else all_windows.Region == region
        )
        kept_windows = all_windows[keep]
        kept_embeddings = all_embeddings[keep]
        print(kept_windows, kept_embeddings)
        kept_embeddings.to_parquet(output[0], index=False)


rule run_umap:
    # Standardize the embedding columns and project them to two UMAP
    # dimensions (fixed random_state).
    input:
        "{anything}/embeddings/{model}.parquet",
    output:
        "{anything}/umap/{model}.parquet",
    run:
        from umap import UMAP

        scale_then_umap = Pipeline([
            ("scaler", StandardScaler()),
            ("umap", UMAP(random_state=42, verbose=True)),
        ])
        coords = scale_then_umap.fit_transform(pd.read_parquet(input[0]))
        pd.DataFrame(coords, columns=["UMAP1", "UMAP2"]).to_parquet(
            output[0], index=False
        )


rule run_classification:
    # Predict each window's Region from its features with a standardized,
    # class-balanced logistic regression, evaluated by leave-one-chromosome-out
    # cross-validation. Only the out-of-fold predictions are saved.
    input:
        "{anything}/windows.parquet",
        "{anything}/embeddings/{model}.parquet",
    output:
        "{anything}/classification/{model}.parquet",
    threads: workflow.cores
    run:
        from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
        from sklearn.model_selection import cross_val_predict, LeaveOneGroupOut
        from sklearn.pipeline import Pipeline
        from sklearn.preprocessing import StandardScaler

        windows = pd.read_parquet(input[0])
        features = pd.read_parquet(input[1])

        linear_model = LogisticRegressionCV(
            random_state=42, verbose=True, max_iter=1000,
            class_weight="balanced", n_jobs=-1
        )
        clf = Pipeline([("scaler", StandardScaler()), ("linear", linear_model)])
        # Chromosomes are the CV groups: each fold holds out one chromosome.
        out_of_fold = cross_val_predict(
            clf, features, windows.Region, groups=windows.chrom,
            cv=LeaveOneGroupOut(), verbose=True,
        )
        result = pd.DataFrame({"pred_Region": out_of_fold})
        result.to_parquet(output[0], index=False)


rule run_leiden:
    # Leiden clustering of the standardized embeddings: PCA to 100 components,
    # neighbor graph on those components, then Leiden at the resolution given
    # by the wildcard.
    input:
        "{anything}/embeddings/{model}.parquet",
    output:
        "{anything}/leiden_{resolution}_/{model}.parquet",
    run:
        import anndata
        import scanpy as sc

        scaled = StandardScaler().fit_transform(pd.read_parquet(input[0]))
        adata = anndata.AnnData(scaled)
        sc.tl.pca(adata, n_comps=100)
        sc.pp.neighbors(adata, n_pcs=100)
        resolution = float(wildcards.resolution)
        sc.tl.leiden(adata, resolution=resolution, random_state=42)
        clusters = adata.obs[["leiden"]]
        print(clusters)
        clusters.to_parquet(output[0], index=False)


def find_positions(interval):
    """Expand one interval into a table with one row per covered position.

    `interval` must expose `chrom`, `start` and `end`; the half-open range
    [start, end) is reported as 1-based positions start+1 .. end. The result
    has the columns `pos` and `chrom`, in that order.
    """
    # Shifting both range bounds by one yields the 1-based positions directly.
    positions = pd.DataFrame({"pos": range(interval.start + 1, interval.end + 1)})
    positions["chrom"] = interval.chrom
    return positions


rule make_positions_promoter:
    # Positions inside promoters (get_promoters with 500), after removing
    # exons and repeats, clipping to the defined genome shrunk by half a
    # window, and discarding pieces rejected by filter_length(..., 30).
    input:
        "output/genome.fa.gz",
        "output/annotation.gtf.gz",
        "input/repeats.bed.gz",
    output:
        "output/whole_genome/subset/promoter/positions.parquet",
    run:
        genome = Genome(input[0])
        genome.filter_chroms(CHROMS)
        annotation = load_table(input[1])
        annotation = annotation[annotation.chrom.isin(CHROMS)]

        promoters = get_promoters(annotation, 500)
        print(promoters)

        # Remove exonic sequence.
        merged_exons = bf.merge(annotation[annotation.feature == "exon"])
        merged_exons = merged_exons.drop(columns="n_intervals")
        promoters = bf.subtract(promoters, merged_exons)
        print(promoters)

        # Remove repeats; the repeat table names chromosomes "Chr<N>".
        merged_repeats = bf.merge(load_repeatmasker(input[2]))
        merged_repeats.chrom = merged_repeats.chrom.str.replace("Chr", "")
        merged_repeats = merged_repeats[merged_repeats.chrom.isin(CHROMS)]
        promoters = bf.subtract(promoters, merged_repeats)
        print(promoters)

        # Keep only positions at least half a window inside defined sequence.
        usable = bf.expand(genome.get_defined_intervals(), pad=-WINDOW_SIZE//2)
        promoters = intersect_intervals(promoters, usable)
        print(promoters)

        promoters = filter_length(promoters, 30)
        print(promoters)

        per_interval = promoters.progress_apply(find_positions, axis=1).values
        positions = pd.concat(per_interval, ignore_index=True)
        print(positions)
        positions.to_parquet(output[0], index=False)


rule make_positions_promoter2:
    # Variant of make_positions_promoter: get_promoters is called with
    # (1000, 1000) and CDS (rather than exon) intervals are removed.
    input:
        "output/genome.fa.gz",
        "output/annotation.gtf.gz",
        "input/repeats.bed.gz",
    output:
        "output/whole_genome/subset/promoter2/positions.parquet",
    run:
        genome = Genome(input[0])
        genome.filter_chroms(CHROMS)
        annotation = load_table(input[1])
        annotation = annotation[annotation.chrom.isin(CHROMS)]

        promoters = get_promoters(annotation, 1000, 1000)
        print(promoters)

        # Remove coding sequence.
        merged_cds = bf.merge(annotation[annotation.feature == "CDS"])
        merged_cds = merged_cds.drop(columns="n_intervals")
        promoters = bf.subtract(promoters, merged_cds)
        print(promoters)

        # Remove repeats; the repeat table names chromosomes "Chr<N>".
        merged_repeats = bf.merge(load_repeatmasker(input[2]))
        merged_repeats.chrom = merged_repeats.chrom.str.replace("Chr", "")
        merged_repeats = merged_repeats[merged_repeats.chrom.isin(CHROMS)]
        promoters = bf.subtract(promoters, merged_repeats)
        print(promoters)

        # Keep only positions at least half a window inside defined sequence.
        usable = bf.expand(genome.get_defined_intervals(), pad=-WINDOW_SIZE//2)
        promoters = intersect_intervals(promoters, usable)
        print(promoters)

        promoters = filter_length(promoters, 30)
        print(promoters)

        per_interval = promoters.progress_apply(find_positions, axis=1).values
        positions = pd.concat(per_interval, ignore_index=True)
        print(positions)
        positions.to_parquet(output[0], index=False)


rule make_positions_funtfbs:
    # Positions covered by FunTFBS binding sites on the chromosomes in CHROMS.
    input:
        "output/funtfbs/TFBS_from_FunTFBS_genome-wide_Ath.gff",
    output:
        "output/funtfbs/positions.parquet",
    run:
        gff_columns = [
            "chrom", "source1", "source2", "start", "end",
            "score", "strand", "unknown", "annotation",
        ]
        tfbs = pd.read_csv(
            input[0], sep="\t", header=None, names=gff_columns,
            usecols=["chrom", "start", "end", "strand", "annotation"],
        )
        print(tfbs)
        print(tfbs.annotation.iloc[0])
        tfbs.chrom = tfbs.chrom.str.replace("Chr", "")
        on_known_chroms = tfbs.chrom.isin(CHROMS)
        tfbs = tfbs[on_known_chroms]
        # Shift starts down by one before merging overlapping sites.
        tfbs.start -= 1
        merged_sites = bf.merge(tfbs)

        per_site = merged_sites.progress_apply(find_positions, axis=1).values
        positions = pd.concat(per_site, ignore_index=True)
        print(positions)
        positions.to_parquet(output[0], index=False)


rule make_positions_region:
    # Positions of a fixed 10 Mb stretch of chromosome 5 (1000-10001000) with
    # repeats removed.
    input:
        "input/repeats.bed.gz",
    output:
        "output/region/positions.parquet",
    run:
        # example defined region: chr5:0-10131131
        region = pd.DataFrame({"chrom": ["5"], "start": [1000], "end": [10001000]})

        chr5_repeats = load_repeatmasker(input[0]).query('chrom == "Chr5"')
        chr5_repeats = bf.merge(chr5_repeats)
        # Rename to match the chromosome naming used for `region`.
        chr5_repeats.chrom = "5"
        non_repeat = bf.subtract(region, chr5_repeats)
        print(non_repeat)

        per_interval = non_repeat.progress_apply(find_positions, axis=1).values
        positions = pd.concat(per_interval, ignore_index=True)
        print(positions)
        positions.to_parquet(output[0], index=False)


rule make_positions_whole_genome:
    # Every position of the defined genome (chromosomes in CHROMS) that lies at
    # least half a window away from the edge of its defined interval.
    input:
        "output/genome.fa.gz",
    output:
        "output/whole_genome/positions.parquet",
    run:
        genome = Genome(input[0])
        genome.filter_chroms(CHROMS)
        shrunk = bf.expand(genome.get_defined_intervals(), pad=-WINDOW_SIZE//2)
        usable = filter_length(shrunk, 1)
        print(usable)
        per_interval = usable.progress_apply(find_positions, axis=1).values
        positions = pd.concat(per_interval, ignore_index=True)
        print(positions)
        positions.to_parquet(output[0], index=False)


rule get_logits:
    # Runs gpn.ss.get_logits for the `{model}` wildcard; `{input}` expands to
    # the positions table followed by the genome FASTA.
    input:
        "output/whole_genome/positions.parquet",
        "output/genome.fa.gz",
    output:
        "output/whole_genome/logits/{model}.parquet",
    # Uses half of the cores given to Snakemake for data loading.
    threads: workflow.cores // 2
    shell:
        """
        python -m gpn.ss.get_logits {input} {WINDOW_SIZE} {wildcards.model} {output} \
        --per-device-batch-size 4000 --is-file \
        --dataloader-num-workers {threads}
        """


rule subset_logits:
    # Pick, from the whole-genome logits, the rows belonging to the positions
    # of a subset. Logits rows are matched to positions by row order.
    input:
        "output/whole_genome/positions.parquet",
        "output/whole_genome/logits/{model}.parquet",
        "output/whole_genome/subset/{subset}/positions.parquet",
    output:
        "output/whole_genome/subset/{subset}/logits/{model}.parquet",
    run:
        key = ["chrom", "pos"]
        # Map every (chrom, pos) to its row number in the whole-genome table.
        genome_pos = pd.read_parquet(input[0]).set_index(key)
        genome_pos["idx"] = np.arange(len(genome_pos))
        logits = pd.read_parquet(input[1])
        wanted = pd.MultiIndex.from_frame(pd.read_parquet(input[2])[key])
        row_numbers = genome_pos.loc[wanted].idx.values
        logits.iloc[row_numbers].to_parquet(output[0], index=False)


rule make_bed_probs:
    # Write one 4-column text file (chrom, start, end, value) per nucleotide,
    # where the value is the model probability of that nucleotide scaled by
    # (2 - entropy in bits) of the position.
    input:
        "output/whole_genome/positions.parquet",
        "output/whole_genome/logits/{model}.parquet",
    output:
        # this should be made temp
        expand("output/whole_genome/bed_probs/{{model}}/{nuc}.bed", nuc=NUCLEOTIDES),
    run:
        df = pd.read_parquet(input[0])
        # Row-wise softmax of the logits gives per-position probabilities.
        # NOTE(review): assumes the logits rows line up with the positions rows
        # and that the logits columns follow the NUCLEOTIDES order - confirm.
        df.loc[:, NUCLEOTIDES] = softmax(pd.read_parquet(input[1]), axis=1)
        #df = df.head(1000) # just to iterate faster
        # Entropy of each position's distribution, in bits (base=2).
        df["entropy"] = entropy(df[NUCLEOTIDES], base=2, axis=1)
        print(df)
        # Scale probabilities by (2 - entropy); 2 bits is the entropy of a
        # uniform distribution over four nucleotides.
        df.loc[:, NUCLEOTIDES] = df[NUCLEOTIDES].values * (2-df[["entropy"]].values)
        print(df)
        # `pos` is 1-based; convert to a half-open (start, end) interval.
        df["start"] = df.pos-1
        df["end"] = df.pos
        df.chrom = "chr" + df.chrom
        # `output` is ordered like NUCLEOTIDES (see the expand above).
        for nuc, path in zip(NUCLEOTIDES, output):
            df.to_csv(
                path, sep="\t", index=False, header=False,
                columns=["chrom", "start", "end", nuc],
            )


rule make_chrom_sizes:
    # Two-column, header-less TSV of "chr"-prefixed chromosome name and the
    # `end` of the chromosome's interval.
    input:
        "output/genome.fa.gz",
    output:
        "output/chrom.sizes",
    run:
        chrom_table = Genome(input[0]).get_all_intervals()
        chrom_table.chrom = "chr" + chrom_table.chrom
        chrom_table.to_csv(
            output[0],
            sep="\t",
            columns=["chrom", "end"],
            header=False,
            index=False,
        )


rule convert_bed_to_bigwig:
    # Convert any <name>.bed into <name>.bw; `{input}` expands to the bed file
    # followed by the chrom sizes file.
    input:
        "{anything}.bed",
        "output/chrom.sizes",
    output:
        "{anything}.bw"
    shell:
        # using a local download because conda version didn't work
        # (the binary is expected in the working directory)
        "./bedGraphToBigWig {input} {output}"


# example:
# snakemake --cores all --use-conda --conda-frontend mamba 
rule run_modisco:
    # Runs the modisco_run.py script inside the modisco-lite conda
    # environment; requires Snakemake to be invoked with --use-conda.
    input:
        "{anything}/positions.parquet",
        "{anything}/logits/{model}.parquet",
        "output/genome.fa.gz",
    output:
        "{anything}/modisco/{model}/results.h5",
    conda:
        "envs/modisco-lite.yaml"
    script:
        "modisco_run.py"


rule plot_modisco:
    # Runs modisco_report.py on the modisco results together with a motif
    # database in MEME format; the output is a report directory.
    input:
        "{anything}/modisco/{model}/results.h5",
        "output/plantfdb/Ath_TF_binding_motifs.meme",
        #"output/motif_db/jaspar.meme",
    output:
        directory("{anything}/modisco/{model}/report"),
    conda:
        "envs/modisco-lite.yaml"
    script:
        "modisco_report.py"


rule download_FunTFBS:
    # Download the genome-wide FunTFBS binding-site GFF for Arabidopsis
    # thaliana from PlantRegMap.
    # NOTE(review): the URL contains an unquoted `?`, which the shell may
    # treat as a glob character - consider quoting it.
    output:
        "output/funtfbs/TFBS_from_FunTFBS_genome-wide_Ath.gff",
    shell:
        """
        wget http://plantregmap.gao-lab.org/download_ftp.php?filepath=08-download/Arabidopsis_thaliana/binding/TFBS_from_FunTFBS_genome-wide_Ath.gff -O {output[0]}
        """
        # Disabled downloads of two further binding-site files:
        #&& wget http://plantregmap.gao-lab.org/download_ftp.php?filepath=08-download/Arabidopsis_thaliana/binding/TFBS_from_motif_CE_genome-wide_Ath.gff -O {output[1]} \
        #&& wget http://plantregmap.gao-lab.org/download_ftp.php?filepath=08-download/Arabidopsis_thaliana/binding/TFBS_from_motif_genome-wide_Ath.gff -O {output[2]}


rule download_PlanTFDB:
    # Download the PlantTFDB Arabidopsis motif collection: the combined MEME
    # file (decompressed on the fly) and the per-motif archive, which is
    # stored in and unpacked into the declared output directory.
    output:
        "output/plantfdb/Ath_TF_binding_motifs.meme",
        directory("output/plantfdb/individual"),
    # Fixes over the previous command: the archive is saved to a file inside
    # the directory (wget -O cannot write to a directory path), tar extracts
    # into that directory (-C) instead of the working directory, and mkdir -p
    # tolerates an already existing directory.
    shell:
        """
        wget http://planttfdb.gao-lab.org/download/motif/Ath_TF_binding_motifs.meme.gz -O - | gunzip > {output[0]} && \
        mkdir -p {output[1]} && \
        wget http://planttfdb.gao-lab.org/download/motif/Ath_TF_binding_motifs_individual.tar.gz -O {output[1]}/Ath_TF_binding_motifs_individual.tar.gz && \
        tar xzvf {output[1]}/Ath_TF_binding_motifs_individual.tar.gz -C {output[1]}
        """


rule download_vcf:
    # Download the 1001 Genomes v3.1 SNP / short-indel VCF.
    # NOTE(review): certificate verification is disabled (--no-check-certificate).
    output:
        "output/variants/all.vcf.gz",
    shell:
        "wget --no-check-certificate https://1001genomes.org/data/GMI-MPI/releases/v3.1/1001genomes_snp-short-indel_only_ACGTN.vcf.gz -O {output}"


rule filter_vcf:
    # Allele counts for biallelic, non-indel sites, computed with vcftools
    # (--counts writes tmp.frq.count), then moved and gzipped to the output.
    # NOTE(review): the command uses the fixed prefix `tmp` in the working
    # directory and hard-codes the output path instead of using {output}.
    input:
        "output/variants/all.vcf.gz",
    output:
        "output/variants/filt.bed.gz",
    shell:
        "vcftools --gzvcf {input} --counts --out tmp --min-alleles 2 --max-alleles 2 --remove-indels && mv tmp.frq.count output/variants/filt.bed && gzip output/variants/filt.bed"


rule process_variants:
    # Turn the vcftools allele-count table into a tidy variants table with
    # chrom, pos, ref, alt, AC (alternate allele count) and AN.
    input:
        "output/variants/filt.bed.gz",
        "output/genome.fa.gz",
    output:
        "output/variants/filt.processed.parquet",
    run:
        count_columns = ["chrom", "pos", "N_ALLELES", "AN", "ref_count", "alt_count"]
        variants = pd.read_csv(input[0], sep="\t", header=0, names=count_columns)
        variants = variants.drop(columns="N_ALLELES")
        print(variants)
        variants.chrom = variants.chrom.astype(str)
        print(variants)
        genome = Genome(input[1])

        def find_ref_alt_AC(row):
            # The count fields look like "<allele>:<count>"; the reference
            # allele must agree with the genome sequence at that position.
            ref = genome.get_nuc(row.chrom, row.pos).upper()
            assert(ref == row.ref_count[0])
            alt, alt_count = row.alt_count.split(":")
            return ref, alt, int(alt_count)

        per_variant = variants.progress_apply(find_ref_alt_AC, axis=1)
        variants["ref"], variants["alt"], variants["AC"] = zip(*per_variant)
        print(variants)
        variants = variants[["chrom", "pos", "ref", "alt", "AC", "AN"]]
        variants.to_parquet(output[0], index=False)


rule download_ensembl_vep:
    # Download the Ensembl Plants (release 55) Arabidopsis thaliana variation
    # VCF that includes consequence annotations.
    # NOTE(review): certificate verification is disabled (--no-check-certificate).
    output:
        "output/arabidopsis_thaliana_incl_consequences.vcf.gz",
    shell:
        "wget --no-check-certificate https://ftp.ensemblgenomes.org/pub/plants/release-55/variation/vcf/arabidopsis_thaliana/arabidopsis_thaliana_incl_consequences.vcf.gz -O {output}"


rule process_ensembl_vep:
    # Extract, for every biallelic SNV without a FILTER value, the sorted
    # unique set of consequence terms (first "|"-separated field of each "VE"
    # entry) as a comma-joined string.
    input:
        "output/arabidopsis_thaliana_incl_consequences.vcf.gz",
    output:
        "output/ensembl_vep.parquet",
    run:
        from cyvcf2 import VCF

        rows = []
        for variant in VCF(input[0]):
            # FILTER is None is supposed to mean PASS
            if (
                variant.INFO.get("TSA") != "SNV"
                or len(variant.ALT) > 1
                or variant.FILTER is not None
            ):
                continue
            per_transcript = variant.INFO.get("VE").split(",")
            terms = np.unique([entry.split("|")[0] for entry in per_transcript])
            rows.append([
                variant.CHROM, variant.POS, variant.REF, variant.ALT[0],
                ",".join(terms),
            ])
            # Progress report every 100000 kept variants.
            n_kept = len(rows)
            if n_kept % 100000 == 0:
                print(n_kept)
        df = pd.DataFrame(data=rows, columns=["chrom", "pos", "ref", "alt", "consequence"])
        print(df)
        df.to_parquet(output[0], index=False)


rule add_info_variants:
    # Attach the Ensembl consequence annotation to the processed variants
    # (left join, so variants without annotation are kept).
    input:
        "output/variants/filt.processed.parquet",
        "output/ensembl_vep.parquet",
    output:
        "output/variants/all.parquet",
    run:
        variants = pd.read_parquet(input[0])
        print(variants)
        consequences = pd.read_parquet(input[1])
        print(consequences)
        join_keys = ["chrom", "pos", "ref", "alt"]
        annotated = variants.merge(consequences, how="left", on=join_keys)
        print(annotated)
        annotated.to_parquet(output[0], index=False)



rule filter_variants:
    # Build the final variant table from the genotype matrix: allele count and
    # frequency come from the matrix, ref/alt/consequence from the annotated
    # variants (inner join on chrom, pos). Variants are kept only when the
    # 512 bp window around them is complete and contains only A/C/G/T.
    input:
        "output/variants/genotype_matrix.hdf5",
        "output/variants/all.parquet",
        "output/genome.fa.gz",
    output:
        "output/variants/all/variants.parquet",
    run:
        # Context manager so the HDF5 handle is closed once the arrays are read.
        with h5py.File(input[0], 'r') as f:
            _, n_accessions = f["snps"].shape
            AC = f["snps"][:].sum(axis=1)
            pos = f["positions"][:]
            # dtype=object: the previous dtype=str allocated '<U1' and would
            # silently truncate chromosome labels to a single character.
            chrom = np.empty_like(pos, dtype=object)
            for c, (left, right) in zip(
                f["positions"].attrs["chrs"], f['positions'].attrs['chr_regions']
            ):
                # HDF5 string attributes may come back as bytes; decode those
                # rather than taking str(b"1") == "b'1'".
                chrom[left:right] = c.decode() if isinstance(c, bytes) else str(c)
        variants = pd.DataFrame({"chrom": chrom, "pos": pos, "AC": AC})
        variants["AF"] = variants.AC / n_accessions

        variant_info = pd.read_parquet(input[1])
        variants = variants.merge(
            variant_info.drop(columns=["AC", "AN"]), how="inner", on=["chrom", "pos"]
        )
        variants = variants[["chrom", "pos", "ref", "alt", "AC", "AF", "consequence"]]
        print(variants)

        TOTAL_SIZE = 512
        variants["start"] = variants.pos - TOTAL_SIZE // 2
        variants["end"] = variants.start + TOTAL_SIZE

        genome = Genome(input[2])

        def check_seq(w):
            # True when the window is full length and contains no character
            # other than A, C, G, T.
            seq = genome.get_seq(w.chrom, w.start, w.end).upper()
            if len(seq) != TOTAL_SIZE: return False
            if re.search("[^ACTG]", seq) is not None: return False
            return True

        # Non-inplace drop on the filtered frame avoids mutating a slice.
        variants = variants[variants.progress_apply(check_seq, axis=1)]
        variants = variants.drop(columns=["start", "end"])

        print(variants)
        variants.to_parquet(output[0], index=False)


rule filter_chrom:
    # Keep only the variants located on the chromosome named by the wildcard.
    input:
        "output/variants/all/variants.parquet",
    output:
        "output/variants/chrom/{chrom}/variants.parquet",
    run:
        all_variants = pd.read_parquet(input[0])
        print(all_variants.shape)
        on_chrom = all_variants[all_variants.chrom == wildcards["chrom"]]
        print(on_chrom.shape)
        on_chrom.to_parquet(output[0], index=False)


rule make_simulated_variants:
    # Enumerate every possible single-nucleotide substitution in a 1 Mb stretch
    # of chromosome 5 and write it both as a header-less VCF-like table and as
    # a parquet table (chrom, pos, ref, alt).
    input:
        "output/genome.fa.gz",
    output:
        # then take this file to ensembl vep online and run with option
        # upstream/downstream = 500
        "output/simulated_variants/variants.vcf.gz",
        "output/simulated_variants/variants.parquet",
    run:
        genome = Genome(input[0])
        chrom = "5"
        start = 3500000
        end = start + 1_000_000
        nucleotides = list("ACGT")

        def substitutions():
            # One record per position and per alternative differing from ref.
            for pos in range(start, end):
                ref = genome.get_nuc(chrom, pos).upper()
                for alt in nucleotides:
                    if alt != ref:
                        yield [chrom, pos, '.', ref, alt, '.', '.', '.']

        df = pd.DataFrame(data=list(substitutions()))
        print(df)
        df.to_csv(output[0], sep="\t", index=False, header=False)
        # Columns 0, 1, 3, 4 hold chrom, pos, ref and alt.
        named = df[[0, 1, 3, 4]].rename(
            columns={0: "chrom", 1: "pos", 3: "ref", 4: "alt"}
        )
        named.to_parquet(output[1], index=False)


rule run_vep:
    # Runs gpn.ss.run_vep for the `{model}` wildcard on any variants table;
    # `{input}` expands to the variants file followed by the genome FASTA.
    input:
        "{anything}/variants.parquet",
        "output/genome.fa.gz",
    output:
        "{anything}/vep/{model}.parquet",
    threads: workflow.cores
    shell:
        """
        python -m gpn.ss.run_vep {input} {WINDOW_SIZE} {wildcards.model} {output} \
        --per-device-batch-size 4000 --is-file \
        --dataloader-num-workers {threads}
        """


rule run_vep_embeddings:
    # Runs gpn.ss.run_vep_embeddings for the `{model}` wildcard on any variants
    # table; `{input}` expands to the variants file followed by the genome FASTA.
    input:
        "{anything}/variants.parquet",
        "output/genome.fa.gz",
    output:
        "{anything}/vep_embeddings/{model}.parquet",
    threads: workflow.cores
    shell:
        """
        python -m gpn.ss.run_vep_embeddings {input} {WINDOW_SIZE} {wildcards.model} {output} \
        --per-device-batch-size 4000 --is-file \
        --dataloader-num-workers {threads}
        """


# Download the Arabidopsis thaliana PhastCons bedGraph from PlantRegMap.
rule download_phastcons:
    output:
        "output/conservation/phastCons.bedGraph.gz"
    # The URL is single-quoted because it contains `?`, which an unquoted shell
    # word treats as a glob pattern.
    shell:
        "wget 'http://plantregmap.gao-lab.org/download_ftp.php?filepath=08-download/Arabidopsis_thaliana/sequence_conservation/Ath_PhastCons.bedGraph.gz' -O {output}"


# Download the Arabidopsis thaliana PhyloP bedGraph from PlantRegMap.
rule download_phylop:
    output:
        "output/conservation/phyloP.bedGraph.gz"
    # The URL is single-quoted because it contains `?`, which an unquoted shell
    # word treats as a glob pattern.
    shell:
        "wget 'http://plantregmap.gao-lab.org/download_ftp.php?filepath=08-download/Arabidopsis_thaliana/sequence_conservation/Ath_PhyloP.bedGraph.gz' -O {output}"


# Attach a per-position conservation score (phyloP or phastCons) to each row of
# a variants table and write a single-column `score` parquet file.
rule process_conservation:
    input:
        "{anything1}/variants.parquet",
        "output/conservation/{model}.bedGraph.gz",
    output:
        "{anything1}/{model}.parquet",
    # Snakemake ignores constraints written inside input patterns, so the
    # restriction is declared here; without it this rule's output pattern
    # matches almost any parquet path.
    wildcard_constraints:
        model="phyloP|phastCons",
    run:
        variants = pd.read_parquet(input[0])

        conservation = pd.read_csv(
            input[1], sep="\t", header=None,
            names=["chrom", "pos", "end", "score"],
        ).drop(columns=["end"])
        # Shift the interval start by one so it lines up with `pos` in the
        # variants table (presumably 0-based bedGraph vs 1-based variants;
        # confirm against the data source).
        conservation.pos += 1
        # Strip the "Chr" prefix so chromosome labels match the variants table.
        conservation.chrom = conservation.chrom.str.replace("Chr", "", regex=False)
        # Flip the sign of the score (presumably so more conserved positions
        # get lower values, like the model scores; verify with consumers).
        conservation.score *= -1

        # Left merge keeps every variant; positions absent from the
        # conservation track end up with a missing score.
        variants = variants.merge(conservation, how="left", on=["chrom", "pos"])
        variants = variants[["score"]]
        variants.to_parquet(output[0], index=False)


# Download the AraGWAS genotype archive and keep only `GENOTYPES/4.hdf5`.
rule download_genotype:
    output:
        "output/variants/genotype_matrix.hdf5",
    # The archive is fetched and unpacked inside a scratch directory created
    # next to the output, so nothing is left behind in the working directory
    # and a rerun never hits unzip's interactive overwrite prompt. The trap
    # removes the scratch directory whether or not the download succeeds.
    shell:
        """
        tmpdir=$(mktemp -d "$(dirname {output})/genotype.XXXXXX")
        trap 'rm -rf "$tmpdir"' EXIT
        wget https://aragwas.1001genomes.org/api/genotypes/download -O "$tmpdir/genotype.zip"
        unzip -o "$tmpdir/genotype.zip" -d "$tmpdir"
        mv "$tmpdir/GENOTYPES/4.hdf5" {output}
        """


# Download the AraGWAS association table for one chromosome (query parameter
# `mac=5`) and store it gzip-compressed.
rule download_aragwas:
    output:
        "output/aragwas/api_chr{chrom}.csv.gz",
    # wget streams the response to stdout (`-O -`), which is piped through gzip
    # into the output file.
    shell:
        """
        wget "https://aragwas.1001genomes.org/api/associations/download/?&chr={wildcards.chrom}&mac=5" -O - | gzip > {output}
        """


# Concatenate the per-chromosome AraGWAS association tables into one parquet
# file with normalized column names and chromosome labels.
rule merge_aragwas:
    input:
        expand("output/aragwas/api_chr{chrom}.csv.gz", chrom=CHROMS),
    output:
        "output/aragwas/api.parquet",
    run:
        # Columns read from each CSV; everything else is dropped at load time.
        wanted_columns = [
            "snp.chr", "snp.position", "snp.ref", "snp.alt", "score", "maf", "mac",
            "study.name", "study.phenotype.name", "overBonferroni", "overPermutation",
        ]
        # Rename the coordinate columns to the names used by the variant tables.
        new_names = {
            "snp.chr": "chrom",
            "snp.position": "pos",
            "snp.ref": "ref",
            "snp.alt": "alt",
            "score": "aragwas_score",
        }

        per_chrom_tables = []
        for path in input:
            table = pd.read_csv(path, usecols=wanted_columns)
            per_chrom_tables.append(table.rename(columns=new_names))

        aragwas = pd.concat(per_chrom_tables, ignore_index=True)
        # Drop the "chr" prefix from the chromosome labels.
        aragwas.chrom = aragwas.chrom.str.replace("chr", "")
        print(aragwas)
        aragwas.to_parquet(output[0], index=False)


# Intersect the AraGWAS associations with the variant set, then pull the
# matching rows of the genotype matrix out of the HDF5 file and save them as a
# sparse matrix. Row i of the saved matrix corresponds to row i of the saved
# coordinates table.
rule make_aragwas_genotype:
    input:
        "output/aragwas/api.parquet",
        "output/variants/all/variants.parquet",
        "output/variants/genotype_matrix.hdf5",
    output:
        "output/aragwas/coordinates.parquet",
        "output/aragwas/genotype.npz",
    run:
        cols = ["chrom", "pos", "ref", "alt"]
        aragwas = pd.read_parquet(input[0]).drop_duplicates(cols)
        variants = pd.read_parquet(input[1])
        # remember to put sort=True in the vep_aragwas.ipynb
        variants = variants.merge(aragwas, how="inner", on=cols, sort=True)
        print(variants.shape)
        variants[cols].to_parquet(output[0], index=False)

        # The context manager closes the HDF5 file even if a lookup fails.
        with h5py.File(input[2], 'r') as f:
            pos = f["positions"][:]
            # Object dtype: a `str` dtype array would be created as '<U1' and
            # silently truncate any chromosome label longer than one character.
            chrom = np.empty(pos.shape, dtype=object)
            for c, (left, right) in zip(
                f["positions"].attrs["chrs"], f['positions'].attrs['chr_regions']
            ):
                # Decode bytes labels instead of producing "b'1'" via str().
                chrom[left:right] = c.decode() if isinstance(c, bytes) else str(c)
            # Map (chrom, pos) -> row number in the HDF5 genotype matrix.
            snp_index = pd.DataFrame({
                "chrom": chrom, "pos": pos, "snp_index": np.arange(len(pos))
            }).set_index(["chrom", "pos"]).snp_index
            print(snp_index)
            locations = snp_index[pd.MultiIndex.from_frame(variants[["chrom", "pos"]])].values
            X = sp_sparse.csr_matrix(f["snps"][locations])
        print(X.shape)
        sp_sparse.save_npz(output[1], X)


# Write the AraGWAS genotype matrix as a BGZF-compressed VCF with one record
# per variant and one GT call per sample. Samples are named by their column
# index in the genotype matrix.
rule make_aragwas_vcf:
    input:
        "output/aragwas/coordinates.parquet",
        "output/aragwas/genotype.npz",
    output:
        "output/aragwas/variants.vcf.bgz",
    run:
        coordinates = pd.read_parquet(input[0])
        genotype = sp_sparse.load_npz(input[1])
        n_variants, n_samples = genotype.shape
        sample_names = [str(j) for j in range(n_samples)]

        # Pull the columns out once instead of doing a DataFrame row lookup
        # for every variant.
        chroms = coordinates["chrom"].tolist()
        positions = coordinates["pos"].tolist()
        refs = coordinates["ref"].tolist()
        alts = coordinates["alt"].tolist()

        # it's pretty slow, cyvcf2 might be faster
        import vcfpy
        header_lines = [
            vcfpy.HeaderLine("fileformat", "VCFv4.2"),
            vcfpy.FormatHeaderLine.from_mapping({"ID": "GT", "Number": "1", "Type": "String", "Description": "Genotype"}),
        ]
        header = vcfpy.Header(
            lines=header_lines,
            samples=vcfpy.SamplesInfos(sample_names)
        )
        with open(output[0], "wb") as f:
            writer = vcfpy.Writer.from_stream(f, header, use_bgzf=True)
            # Close the writer even when a record fails to serialize.
            try:
                for i in tqdm(range(n_variants)):
                    # Dense genotype vector of variant i across all samples.
                    snp = genotype[i].toarray().ravel()
                    record = vcfpy.Record(
                        chroms[i],  # chromosome
                        positions[i],  # position
                        ".",  # ID (missing)
                        refs[i],  # reference allele
                        [vcfpy.Substitution("SNV", alts[i])],  # list of alternate alleles
                        ".",  # quality (missing)
                        [],  # list of filters (missing)
                        {},  # dictionary of INFO fields (missing)
                        ["GT"],  # list of FORMAT keys present in each call
                        calls=[
                            vcfpy.Call(sample=name, data={"GT": str(value)})
                            for name, value in zip(sample_names, snp)
                        ]
                    )
                    writer.write_record(record)
            finally:
                writer.close()


# NOTE: there might be easier ways to create a zarr file
# without passing through a VCF.
#rule vcf2zarr:
#    input:
#
#        "output/aragwas/variants.vcf.bgz",
#    output:
#        "output/aragwas/variants.zarr",
#    run:
#        shell(f"tabix -p vcf {input}")
#        vcf_to_zarr(input, output, ploidy=1, n_alt_alleles=1)


# Compute pairwise linkage disequilibrium between variants that fall in the
# same position window and save the pairs above `threshold` as a sparse
# variants x variants matrix.
rule calculate_LD:
    input: "output/aragwas/variants.zarr",
    output: "output/aragwas/LD_{radius,[0-9]+}_{threshold}.npz",
    threads: workflow.cores
    run:
        import sgkit as sg
        from dask.diagnostics import ProgressBar
        # Show dask progress while the LD computation runs.
        pbar = ProgressBar()
        pbar.register()

        ds = sg.load_dataset(input[0])
        n_variants = ds.sizes["variants"]
        # Dosage = genotype calls summed over the ploidy dimension.
        ds["call_dosage"] = ds.call_genotype.sum(dim="ploidy")
        radius = int(wildcards.radius)
        ds = sg.window_by_position(
            ds, size=radius, offset=-radius/2, window_start_position="variant_position",
        )
        ld = sg.ld_matrix(ds, threshold=float(wildcards.threshold))
        ld = ld.compute()
        # Pass the shape explicitly: left to inference it is taken from the
        # largest indices present, which yields a non-square matrix smaller
        # than the variant count when trailing variants have no reported pair.
        ld = sp_sparse.coo_matrix(
            (ld.value.values, (ld.i.values, ld.j.values)),
            shape=(n_variants, n_variants),
        ).tocsr()
        sp_sparse.save_npz(output[0], ld)
