from Bio import SeqIO
from Bio.Seq import Seq
import bioframe as bf
from datasets import load_dataset
from gpn.data import load_fasta, load_table, Genome
from gpn.data import (
    filter_defined, filter_length, load_table, add_flank, get_annotation_features,
    add_jitter, get_promoters, get_random_intervals, union_intervals,
    intersect_intervals, intervals_size, get_balanced_intervals
)
from gpn.data import make_windows, get_seq
from gpn.data import GenomeMSA
import gzip
from joblib import Parallel, delayed
import numpy as np
import os
import pandas as pd
from pathlib import Path
import scipy.sparse as sp_sparse
from scipy.special import softmax
from scipy.stats import combine_pvalues, entropy
from tqdm import tqdm
tqdm.pandas()


# Workflow configuration file. Snakemake loads this YAML into the global
# `config` dict; `config` is not referenced in this file itself, so it is
# presumably consumed by the included rule files.
configfile: "config/config.yaml"


# TODO: move into config/common.smk
# Workflow-wide constants.
NUCLEOTIDES = ["A", "C", "G", "T"]
WINDOW_SIZE = 128

# Ensembl release 107 downloads for GRCh38 (soft-masked primary assembly and
# the matching gene annotation).
FASTA_URL = "http://ftp.ensembl.org/pub/release-107/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna_sm.primary_assembly.fa.gz"
GTF_URL = "http://ftp.ensembl.org/pub/release-107/gtf/homo_sapiens/Homo_sapiens.GRCh38.107.chr.gtf.gz"

# Chromosome-level split: 21 is held out for validation, 22 for test, and
# chromosomes 1-20 plus X and Y are used for training.
SPLIT_CHROMS = dict(
    train=list(map(str, range(1, 21))) + ["X", "Y"],
    validation=["21"],
    test=["22"],
)
# Live view of the split names, in insertion order (train, validation, test).
SPLITS = SPLIT_CHROMS.keys()

# All chromosome names (1-22, X, Y) as a numpy string array.
CHROMS = np.array(list(map(str, range(1, 23))) + ["X", "Y"])

# Variant sets to score. Each entry is substituted verbatim for {dataset} in
# "results/preds/{dataset}/{model}.parquet" by rule `all`, so the slashes in an
# entry become directory levels of the output path.
# Entries are switched on and off by (un)commenting; only the uncommented ones
# are requested.
datasets = [
    #"results/clinvar/merged",
    #"results/cosmic/merged",
    #"results/omim/merged",
    #"results/gnomad/merged/filt",
    #"results/gnomad/merged/subsampled",
    #"songlab/human_variants",
    #"results/gwas/matched",
    #"results/gwas/matched/missense_variant",
    #"results/eqtl/matched/ge",
    #"results/eqtl/matched/leafcutter",
    #"results/mpra/merged",
    #"results/SGE/DDX3X/processed",
    #"results/dms/merged",
    #"results/variants_enformer",
    #"results/ism_subset",
    #"results/dms2/snv",
    "results/clinvar/likely",
]

# Path layout of a model identifier: one path component per setting. The
# {dataset} value itself contains slashes, so it spans several directory levels.
model_template = "{dataset}/{model_size}/{loss_weight}/{seed}/{max_steps}/{use_aux_features}/{weight_conserved}/{flip_nonconserved}"

# Candidate values for each ablated hyperparameter.
# The FIRST value of every list is the default; the remaining values are
# alternatives that get swapped in one at a time.
hparams = {
    "dataset": [
        # default
        "multiz100way/89/128/64/True/defined.phastCons.percentile-75_0.05_0.001",

        # changing percentage
        "multiz100way/89/128/64/True/defined.phastCons.percentile-75_0.50_0.001",
        "multiz100way/89/128/64/True/defined.phastCons.percentile-75_1.00_0.00",

        # using different species
        "multiz100way/99/128/64/True/defined.phastCons.percentile-75_0.05_0.001",
        "multiz100way_mammals/51/128/64/True/defined.phastCons.percentile-75_0.05_0.001",
        "multiz100way_vertebrates/51/128/64/True/defined.phastCons.percentile-75_0.05_0.001",

        # changing window size
        #"multiz100way/89/32/16/True/defined.phastCons.percentile-75_0.05_0.001",
        #"multiz100way/89/64/32/True/defined.phastCons.percentile-75_0.05_0.001",
        "multiz100way/89/256/128/True/defined.phastCons.percentile-75_0.05_0.001",
        #"multiz100way/89/512/256/True/defined.phastCons.percentile-75_0.05_0.001",
    ],
    "use_aux_features": [True, False],
    "loss_weight": [0.1],
    "weight_conserved": [True, False],
    "flip_nonconserved": [True, False],
}

# Default configuration: the first listed value of every hyperparameter.
default_d = {name: choices[0] for name, choices in hparams.items()}


def _fill_template(settings):
    """Render model_template with `settings`, leaving other placeholders intact.

    `allow_missing=True` keeps the placeholders that `settings` does not
    provide (model_size, seed, max_steps) as literal "{...}" fields so they
    can be expanded later.
    """
    return expand(model_template, **settings, allow_missing=True)[0]


# The default configuration comes first, followed by one ablation at a time:
# every non-default value of every hyperparameter, with all other
# hyperparameters held at their defaults.
hparam_models = [_fill_template(default_d)] + [
    _fill_template({**default_d, name: alternative})
    for name, choices in hparams.items()
    for alternative in choices[1:]
]

# Fill in the remaining placeholders (seed, max_steps, model_size) of every
# hyperparameter configuration, keeping the configurations in their original
# order and flattening everything into one list of model paths.
models = [
    model_path
    for template in hparam_models
    for model_path in expand(
        template,
        seed=[42, 43, 44],
        max_steps=[30_000],
        model_size=["medium"],
    )
]

# The first path comes from hparam_models[0], i.e. the all-default
# configuration. Only that model is kept as a target; the rest of the grid is
# discarded here.
best_model = models[0]
models = [best_model]
#models = []

#models += [
#    f"{models[i]}.{window_size_ablation}"
#    for window_size_ablation in [4, 8, 16, 32, 64]
#    for i in range(3)
#]


# Rule definitions are split across separate files and pulled in here.
# Included files run in this workflow's namespace, so the names defined above
# are visible to them and whatever they define is visible further down.
include: "rules/common.smk"
include: "rules/training.smk"
include: "rules/vep.smk"
include: "rules/embeddings.smk"
include: "rules/logits.smk"

# NOTE(review): judging by the file names, these cover external / baseline
# predictors rather than the model trained by this workflow — confirm.
include: "rules/nucleotide_transformer.smk"
include: "rules/conservation.smk"
include: "rules/esm1b.smk"
include: "rules/cadd.smk"
include: "rules/spliceai.smk"
include: "rules/hyenadna.smk"

#models.append("msa_multiz100way/89")
#models.append("conservation_combination")
#pd.DataFrame(models).to_csv("ablation_models.txt", index=False, header=False)

# Additional model names requested next to the trained model, appended in
# place. The commented entries index lists that are not defined in this file
# (presumably they come from the rule files included before this point).
models.extend([
    #nucleotide_transformer_models[-1],
    #hyenadna_models[-1],
    "ESM-1b",
    "phyloP",
    "phastCons",
    "phyloP-Zoonomia",
    #"PrimateAI-3D",
])


# Default target rule: it is the first rule declared directly in this file,
# and Snakemake does not let rules from `include:`d files take that role.
# The active input requests one predictions file per (dataset, model) pair,
# i.e. the full cross product of `datasets` and `models`. Dataset and model
# names are inserted verbatim, so their slashes become directory levels.
# The commented-out inputs are alternative targets that are switched on by
# uncommenting them.
rule all:
    input:
        expand("results/preds/{dataset}/{model}.parquet", model=models, dataset=datasets),
        #expand("results/preds/{dataset}/{model}.parquet", model=[nucleotide_transformer_models[-1], best_model, "phyloP", "phastCons", "phyloP-Zoonomia"], dataset=["results/omim/merged"]),
        #expand("results/preds/{dataset}/{model}.parquet", model=[best_model, "phyloP", "phastCons", "phyloP-Zoonomia"], dataset=["results/variants_enformer"]),
        #expand("results/preds/vep_embedding/{dataset}/{model}.parquet", model=models, dataset=datasets),
        #expand("results/preds/vep_embedding/{dataset}/{model}.{shard}.parquet", model=models, dataset=datasets, shard=range(n_shards)),
        #expand("results/dataset/{dataset}/train.parquet", dataset=hparams["dataset"]),
        #expand("results/checkpoints/{model}", model=models),
        #expand("results/preds/{dataset}/{model}.{window_size_ablation}.parquet", model=models, dataset=datasets, window_size_ablation=[4, 8, 16, 32, 64]),
        #expand("results/positions/{chrom}/processed_logits/{model}.parquet", chrom=CHROMS, model=[best_model]),
        #f"results/positions/promoter/modisco/{best_model}/report",
        #f"results/embedding/umap/{best_model}.parquet",
        #f"results/embedding/classification/{best_model}.parquet",
        #expand("results/preds/results/variants_enformer/{model}.parquet", model=[best_model, "phyloP"]),
        #f"results/positions/6/bed_probs/{best_model}/bigwig.done",
        #f"results/preds/results/gnomad/all/defined/128/{best_model}.parquet",
        #f"results/add_preds/results/gnomad/all/defined/128/{best_model}.parquet",
        #f"results/add_preds/results/gnomad/all/defined/128/{best_model}.tsv.bgz.tbi",
        #f"results/preds/results/gwas/matched/{best_model}.BestFeature.parquet",
        #f"results/preds/results/gwas/matched/{best_model}.SumFeatures.parquet",
        #f"results/preds/vep_embedding/results/gwas/matched/{best_model}/concat/Enformer.parquet",


# Secondary target rule. It is not the default target, so it is built only
# when requested by name (e.g. `snakemake all2`).
# The active input follows the same results/preds/{dataset}/{model}.parquet
# layout as rule `all`. It is a plain string: it has no placeholders, so an
# f-string prefix would be redundant. The commented alternative does
# interpolate best_model and therefore keeps its f-prefix.
rule all2:
    input:
        #f"results/preds/results/gnomad/merged/all/{best_model}.parquet",
        "results/preds/results/gnomad/merged/all/msa_multiz100way/89.parquet",
