from gpn.data import load_table, get_balanced_intervals
from gpn.data import make_windows, get_seq
from gpn.data import Genome
import gzip
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()


# Assembly metadata downloaded from NCBI Genome
# (https://www.ncbi.nlm.nih.gov/data-hub/genome).
# You can choose a set of taxa and apply filters such as annotation level,
# assembly level.
# Check out the script gpn/ss/filter_assemblies.py for more details, such as
# how to subsample, or how to keep only one assembly per genus.
# Assembly table: tab-separated, indexed by its first column. The index values
# are the accessions used as the {assembly} wildcard in the rules below.
assemblies = pd.read_csv(config["assemblies_path"], sep="\t", index_col=0)
assemblies["Assembly Name"] = assemblies["Assembly Name"].str.replace(" ", "_")

# Per-accession directory into which the NCBI download is unpacked; both the
# genome FASTA and the GFF annotation are located inside it.
_ncbi_data_dir = (
    "tmp/" + assemblies.index + "/ncbi_dataset/data/" + assemblies.index + "/"
)
assemblies["genome_path"] = (
    _ncbi_data_dir + assemblies.index + "_" + assemblies["Assembly Name"]
    + "_genomic.fna"
)
assemblies["annotation_path"] = _ncbi_data_dir + "genomic.gff"

splits = ["train", "validation", "test"]
# In general we will randomly select chroms for validation and test.
# Proportions are listed in the same order as `splits` and must add up to 1.
split_proportions = [
    config.get("proportion_train", 0.99),
    config.get("proportion_validation", 0.005),
    config.get("proportion_test", 0.005),
]
assert np.isclose(sum(split_proportions), 1)

# Chromosomes listed here are always assigned to validation / test,
# overriding the random assignment.
whitelist_validation_chroms = config.get("whitelist_validation_chroms", [])
whitelist_test_chroms = config.get("whitelist_test_chroms", [])

# Target number of samples per output shard.
# 2M works well with sequences of length 512, for example.
samples_per_file = config.get("samples_per_file", 2_000_000)


rule download_genome:
    # Download one assembly (genome FASTA + GFF3 annotation) with the NCBI
    # `datasets` command-line tool, gzip both files into output/ and remove
    # the scratch directory afterwards.
    output:
        "output/genome/{assembly}.fa.gz",
        "output/annotation/{assembly}.gff.gz",
    params:
        # Scratch directory. A plain wildcard string: `directory()` is a flag
        # for output files and has no meaning for params.
        tmp_dir="tmp/{assembly}",
        genome_path=lambda wildcards: assemblies.loc[wildcards.assembly, "genome_path"],
        annotation_path=lambda wildcards: assemblies.loc[wildcards.assembly, "annotation_path"],
    # `unzip -o` overwrites without prompting, so a rerun after a failed
    # attempt (which leaves extracted files behind in tmp/) does not stall on
    # an interactive "replace?" question.
    shell:
        """
        mkdir -p {params.tmp_dir} && cd {params.tmp_dir} && 
        datasets download genome accession {wildcards.assembly} --include genome,gff3 \
        && unzip -o ncbi_dataset.zip && cd - && gzip -c {params.genome_path} > {output[0]}\
         && gzip -c {params.annotation_path} > {output[1]} && rm -r {params.tmp_dir}
        """


rule make_defined_intervals:
    # Compute the intervals returned by Genome.get_defined_intervals() for one
    # assembly, pass them through `filter_length` with the requested window
    # size and store the result as parquet.
    input:
        "output/genome/{assembly}.fa.gz",
    output:
        "output/intervals/{window_size}/{assembly}/defined.parquet",
    threads: 2
    run:
        # Wildcard values are always strings; convert to int, as the other
        # rules in this file do, so the length filter receives a number.
        window_size = int(wildcards.window_size)
        defined = Genome(input[0]).get_defined_intervals()
        # NOTE(review): `filter_length` is not imported in the visible part of
        # this file -- confirm it is in scope (presumably from gpn.data).
        defined = filter_length(defined, window_size)
        defined.to_parquet(output[0], index=False)


rule make_balanced_intervals:
    # Combine the defined intervals with the GFF annotation through
    # get_balanced_intervals() and store the result as parquet.
    input:
        "output/intervals/{window_size}/{assembly}/defined.parquet",
        "output/annotation/{assembly}.gff.gz",
    output:
        "output/intervals/{window_size}/{assembly}/balanced.parquet",
    run:
        defined = load_table(input[0])
        gff = load_table(input[1])
        window_size = int(wildcards.window_size)
        # Defaults to 1000 when "promoter_upstream" is absent from the config
        promoter_upstream = config.get("promoter_upstream", 1000)
        balanced = get_balanced_intervals(
            defined, gff, window_size, promoter_upstream,
        )
        balanced.to_parquet(output[0], index=False)


rule make_dataset_assembly:
    # Cut the intervals of one assembly into windows, attach their sequences
    # and write one parquet file per split. Splits are assigned per
    # chromosome, so all windows of a chromosome land in the same split.
    input:
        "output/intervals/{window_size}/{assembly}/{anything}.parquet",
        "output/genome/{assembly}.fa.gz",
    output:
        expand("output/dataset/{{window_size}}/{{step_size}}/{{add_rc}}/{{assembly}}/{{anything}}/{split}.parquet", split=splits),
    threads: 2
    run:
        source_intervals = pd.read_parquet(input[0])
        genome = Genome(input[1])
        windows = make_windows(
            source_intervals,
            int(wildcards.window_size),
            int(wildcards.step_size),
            wildcards.add_rc == "True",  # wildcard is a string, compare as such
        )
        print(windows)

        # Shuffle with a fixed seed, tag every row with its assembly and keep
        # only the columns handed on to get_seq()
        windows = windows.sample(frac=1.0, random_state=42)
        windows["assembly"] = wildcards["assembly"]
        windows = windows[["assembly", "chrom", "start", "end", "strand"]]
        windows = get_seq(windows, genome)

        # Every sequence must consist solely of A/C/G/T (either case)
        only_acgt = windows.seq.apply(
            lambda seq: re.search("[^ACGTacgt]", seq) is None
        )
        assert only_acgt.all()
        print(windows)

        # Draw a split for each chromosome (this draw is not seeded), then
        # force the whitelisted chromosomes into validation / test; the test
        # whitelist is applied last and therefore wins on overlap.
        chroms = windows.chrom.unique()
        assigned = np.random.choice(
            splits, p=split_proportions, size=len(chroms),
        )
        assigned[np.isin(chroms, whitelist_validation_chroms)] = "validation"
        assigned[np.isin(chroms, whitelist_test_chroms)] = "test"
        split_by_chrom = pd.Series(assigned, index=chroms)

        # Split label of every window, looked up through its chromosome
        window_split = split_by_chrom[windows.chrom]

        # to parquet to be able to load faster later
        for path, split in zip(output, splits):
            print(path, split)
            in_split = (window_split == split).values
            windows[in_split].to_parquet(path, index=False)


# before uploading, remove data/split/.snakemake_timestamp files
rule merge_datasets:
    # Concatenate one split across all assemblies, shuffle it with a fixed
    # seed and write it out as zstd-compressed JSON-lines shards.
    input:
        expand("output/dataset/{{window_size}}/{{step_size}}/{{add_rc}}/{assembly}/{{anything}}/{{split}}.parquet", assembly=assemblies.index),
    output:
        directory("output/merged_dataset/{window_size}/{step_size}/{add_rc}/{anything}/data/{split}"),
    threads: workflow.cores
    run:
        merged = pd.concat(
            [pd.read_parquet(path) for path in tqdm(input)],
            ignore_index=True,
        )
        merged = merged.sample(frac=1, random_state=42)
        print(merged)

        # Optionally bring every assembly down to the number of training
        # samples of the target assembly (train split only)
        if config.get("subsample_to_target", False) and wildcards.split == "train":
            n_target = (merged.assembly == config["target_assembly"]).sum()
            merged = merged.groupby("assembly").sample(
                n=n_target, random_state=42
            )
            merged = merged.sample(frac=1, random_state=42)
            print(wildcards.split, merged.assembly.value_counts())
            print(merged)

        n_shards = math.ceil(len(merged) / samples_per_file)
        assert n_shards < 10000
        os.makedirs(output[0])
        shard_dir = Path(output[0])
        for i in tqdm(range(n_shards)):
            # Rows i, i + n_shards, i + 2 * n_shards, ... go to shard i
            shard = merged.iloc[i::n_shards]
            shard.to_json(
                shard_dir / f"shard_{i:05}.jsonl.zst",
                orient="records",
                lines=True,
                compression={'method': 'zstd', 'threads': -1},
            )
