#!/usr/bin/env python

import argparse
import sys
import subprocess

# Encoding used for the text exchanged with the starcode subprocesses.
encoding = 'ascii'


def _non_negative_int(text):
   """argparse type: parse `text` as an int and reject negative values.

   The lengths given on the command line are used as slice bounds on the
   reads; a negative value would silently slice from the end of the read
   instead of being reported as an error.
   """
   try:
      value = int(text)
   except ValueError:
      raise argparse.ArgumentTypeError("invalid int value: {!r}".format(text))
   if value < 0:
      raise argparse.ArgumentTypeError("must be >= 0 (got {})".format(value))
   return value


# Command line definition. Options are grouped as UMI/sequence pairs; the two
# positional arguments are the input file(s).
parser = argparse.ArgumentParser(description='Use starcode to demultiplex reads containing UMI identifiers.')
parser.add_argument('--umi-len', required=True, type=_non_negative_int, help="UMI length (must be at the beginning of the read)")
parser.add_argument('--umi-d', required=False, default=0, type=int, help="UMI match distance (default: 0)")
# A negative sequence distance is legitimate here: -1 is the "auto" default.
parser.add_argument('--seq-d', required=False, default=-1, type=int, help="Sequence match distance (default: auto)")
parser.add_argument('--seq-trim', required=False, default=50, type=_non_negative_int, help="Trim sequence length used for clustering. Set to 0 to disable trimming. (default: 50)")
parser.add_argument('--umi-cluster', required=False, default='mp', type=str, help="Algorithm used to cluster UMIs: 'mp' for message passing, 's' for spheres or 'cc' for connected components (default: 'mp')")
parser.add_argument('--seq-cluster', required=False, default='mp', type=str, help="Algorithm used to cluster sequences: 'mp' for message passing, 's' for spheres or 'cc' for connected components (default: 'mp')")
parser.add_argument('--seq-id', action='store_true', help="Print sequence id numbers (1-based)")
parser.add_argument('--umi-threads', required=False, default=1, type=int, help="Starcode threads for UMI clustering (recommended: 1, default: 1)")
parser.add_argument('--seq-threads', required=False, default=1, type=int, help="Starcode threads for seq clustering (default: 1)")
parser.add_argument('--umi-cluster-ratio', required=False, default=3, type=int, help="Min size ratio for merging clusters in UMI message passing (default: 3)")
parser.add_argument('--seq-cluster-ratio', required=False, default=3, type=int, help="Min size ratio for merging clusters in seq message passing (default: 3)")
parser.add_argument('--starcode-path', required=False, default='./starcode', type=str, help="Path to starcode binary file (default: './starcode')")
parser.add_argument('input_file_1',help="Input file in raw, FASTA or FASTQ format. (Forward read file for paired-end mode with one UMI at each end)")
parser.add_argument('input_file_2', nargs='?', default=None, help="Only for paired-end mode with one UMI at each end: Reverse read file.")

# Read the command line once; every setting below is derived from `params`.
params = parser.parse_args()

# UMI and sequence clustering parameters.
umi_len, umi_tau = params.umi_len, params.umi_d   # UMI length / UMI mismatches
seq_tau, seq_trim = params.seq_d, params.seq_trim # seq mismatches / clustered prefix length
output_seq_ids = params.seq_id

# starcode invocation settings.
path = params.starcode_path
umi_cluster, umi_ratio = params.umi_cluster, params.umi_cluster_ratio
seq_cluster, seq_ratio = params.seq_cluster, params.seq_cluster_ratio

# Both clustering algorithms must be one of the supported choices; anything
# else is reported on stderr and ends the program with status 1.
clust_opts = ['s','mp','cc']
if not (umi_cluster in clust_opts and seq_cluster in clust_opts):
   sys.stderr.write("Error, cluster option must be 's', 'mp' or 'cc'.\n")
   sys.exit(1)

# starcode subprocesses: translate each clustering choice into the flag passed
# to starcode. Only message passing ('mp') carries a value, the cluster size
# ratio, which is glued directly onto -r. The choices were validated above,
# so the lookups cannot miss.

# sequence options
seq_args = {'s': "-s", 'mp': "-r{}".format(seq_ratio), 'cc': "-c"}[seq_cluster]

# UMI options
umi_args = {'s': "-s", 'mp': "-r{}".format(umi_ratio), 'cc': "-c"}[umi_cluster]

def _starcode_command(distance, cluster_flag, threads):
   """Build the starcode argument list shared by both clustering runs."""
   return [path, "--seq-id", "-qd{}".format(distance), cluster_flag, "-t{}".format(threads)]

# Both starcode instances are started right away; each one is fed through its
# stdin pipe and its clusters are read back from its stdout pipe.

# Subprocess: cluster sequences.
seqarg = _starcode_command(seq_tau, seq_args, params.seq_threads)
seqproc = subprocess.Popen(seqarg, stdin=subprocess.PIPE, stdout=subprocess.PIPE)

# Subprocess: cluster UMIs.
umiarg = _starcode_command(umi_tau, umi_args, params.umi_threads)
umiproc = subprocess.Popen(umiarg, stdin=subprocess.PIPE, stdout=subprocess.PIPE)

# Per-read bookkeeping, indexed by starcode seq-id. Element 0 is a filler
# because starcode seq-ids are 1 based. Each entry is created here as
# (mate break position, sequence tail left out of the clustering).
data = [None,]
# Number of lines per record: 4 for FASTQ, 2 for FASTA, 1 for raw reads.
form = 0

_NUCLEOTIDES = 'AaCcGgTtUuNn'


def _submit_read(mates):
   """Send one read (or read pair) to starcode and record its bookkeeping.

   `mates` holds the right-stripped read line taken from each input file:
   one element in single-end mode, forward then reverse in paired-end mode.
   The first `umi_len` characters of every mate form the UMI, the remaining
   characters form the sequence; the parts of both mates are concatenated.
   """
   barcode = "".join(mate[:umi_len] for mate in mates)
   sequence = "".join(mate[umi_len:] for mate in mates)
   # Length of the forward part inside `sequence`; single-end reads store 0.
   seq_break = len(mates[0]) - umi_len if len(mates) > 1 else 0
   # Only the first `seq_trim` characters are clustered (0 disables
   # trimming); the remainder is stored alongside the read.
   if seq_trim == 0:
      clustered, remainder = sequence, ""
   else:
      clustered, remainder = sequence[:seq_trim], sequence[seq_trim:]
   # Send sequence to seq process.
   seqproc.stdin.write(bytearray(clustered+"\n", encoding))
   data.append((seq_break, remainder))
   # Send barcode to UMI process.
   umiproc.stdin.write(bytearray(barcode+"\n", encoding))


def _parse_reads(handles):
   """Feed every read found in `handles` to the starcode processes.

   `handles` contains one open text file (single-end mode) or two
   (paired-end mode, forward file first). The format is detected from the
   first character of the first line of each file; in paired-end mode both
   files must agree.

   Returns the number of lines per record (4 FASTQ, 2 FASTA, 1 raw).
   Writes an error to stderr and exits with status 1 when a file is empty,
   starts with a blank line, or has an unrecognized format.
   """
   paired = len(handles) > 1
   # Detect file format.
   first = [handle.readline().rstrip() for handle in handles]
   if not all(first):
      # Indexing the first character below would fail on an empty line.
      sys.stderr.write("Error, input is empty or starts with a blank line.\n")
      sys.exit(1)
   heads = [line[0] for line in first]
   if all(head == '@' for head in heads):
      lines_per_record = 4
   elif all(head == '>' for head in heads):
      lines_per_record = 2
   elif all(head in _NUCLEOTIDES for head in heads):
      lines_per_record = 1
      # Raw input has no header: the line just read is already a read.
      _submit_read(first)
   else:
      sys.stderr.write("Unrecognized input format. Compatible formats: FASTQ, FASTA or raw nucleotides (AaCcGgTtUuNn).\n")
      sys.exit(1)

   sys.stderr.write("({}{} format)...".format(
      "FASTQ" if lines_per_record == 4 else "FASTA" if lines_per_record == 2 else "raw",
      " paired-end" if paired else ""))

   if paired:
      # Read forward and reverse files simultaneously.
      # NOTE: zip stops at the end of the shorter file; surplus lines of the
      # longer one are not processed.
      records = zip(*handles)
   else:
      records = ((line,) for line in handles[0])
   for lineno, lines in enumerate(records):
      # Counting from the line after the one consumed by the format
      # detection, reads sit on every `lines_per_record`-th line.
      if lineno % lines_per_record == 0:
         _submit_read([line.rstrip() for line in lines])
   return lines_per_record


sys.stderr.write("parsing sequences ")

# Read input file (single end mode)
if params.input_file_2 is None:
   with open(params.input_file_1) as f:
      form = _parse_reads((f,))
# Read input files (paired end mode)
else:
   with open(params.input_file_1) as f1, open(params.input_file_2) as f2:
      form = _parse_reads((f1, f2))

   

# Close pipes and let starcode run.
seqproc.stdin.close()
umiproc.stdin.close()
sys.stderr.write("\nclustering sequences...\n")

# Read UMI clusters.
clustno = 0
for line in umiproc.stdout:
   # Verbose
   clustno += 1
   sys.stderr.write("processing UMI (cluster {})...\r".format(clustno))
   # Parse starcode output: tab-separated fields, of which field 0 is the
   # canonical UMI and field 2 the comma-separated seq-ids of the members.
   mout = line.decode(encoding).rstrip().split("\t")
   umi_canon = mout[0]
   # Tag every member read with the canonical UMI of its cluster; the entry
   # of that read in `data` gains the canonical UMI as an extra element.
   for read_id in mout[2].split(","):
      data[int(read_id)] += (umi_canon,)

# Wait UMI process. A failed run leaves reads without a canonical UMI, so
# stop here instead of reporting the process as finished.
if umiproc.wait() != 0:
   sys.stderr.write("\nError, starcode UMI clustering failed (exit status {}).\n".format(umiproc.returncode))
   sys.exit(1)
sys.stderr.write("processing UMI (cluster {})...\nUMI process finished.\n".format(clustno))

# Read seq clusters. Every sequence cluster is split by the canonical UMI of
# its member reads; each (UMI, sequence) combination is printed to stdout as
# one line: cluster text, tab, read count and, with --seq-id, tab and the
# comma-separated seq-ids.
clustno = 0
for line in seqproc.stdout:
   # Verbose
   clustno += 1
   sys.stderr.write("processing sequences (cluster {})\r".format(clustno))
   # Parse starcode output: field 0 is the canonical sequence, field 2 the
   # comma-separated seq-ids of the cluster members.
   mout = line.decode(encoding).rstrip().split("\t")
   seq_id = mout[2].split(",")

   # Aux vars to recover canonical end.
   canon_end = ""
   end_counts = {}
   maxcnt = 0
   # Aux vars for final clusters: canonical UMI -> list of member seq-ids.
   clusters = {}
   for read_id in seq_id:
      entry = data[int(read_id)]
      # Merge seq id and umi canonical to form final clusters.
      clusters.setdefault(entry[2], []).append(read_id)

      # If we trimmed, we need to find the canonical for the rest of the
      # sequence: the most frequent tail, ties going to the tail that reached
      # the winning count first.
      if seq_trim > 0:
         seq_end = entry[1]
         end_counts[seq_end] = end_counts.get(seq_end, 0) + 1
         if end_counts[seq_end] > maxcnt:
            maxcnt = end_counts[seq_end]
            canon_end = seq_end

   # Print clusters
   canonical_sequence = mout[0]+canon_end
   if params.input_file_2 is not None:
      # PE sequence break, taken from the first member listed by starcode.
      # NOTE(review): presumably meant to be the break of the canonical read;
      # confirm that starcode lists the canonical first.
      seq_break = data[int(seq_id[0])][0]
   for umi, members in clusters.items():
      # Output format for single end.
      if params.input_file_2 is None:
         # Append canonical sequence to UMI cluster.
         clust_out = umi+canonical_sequence
      # Output format for paired end.
      else:
         # Split barcode and join canonical sequence: forward UMI + forward
         # sequence, "/", reverse UMI + reverse sequence.
         clust_out = umi[:umi_len]+canonical_sequence[:seq_break]+"/"+umi[umi_len:]+canonical_sequence[seq_break:]
      # Append sequence count.
      clust_out += "\t"+str(len(members))
      # Append sequence ids.
      if output_seq_ids:
         clust_out += "\t"+",".join(members)
      # Write to stdout.
      sys.stdout.write(clust_out.rstrip()+'\n')

# A non-zero exit status means the output above may be incomplete.
if seqproc.wait() != 0:
   sys.stderr.write("\nError, starcode sequence clustering failed (exit status {}).\n".format(seqproc.returncode))
   sys.exit(1)
sys.stderr.write("processing sequences (cluster {})\nsequence process finished.\n".format(clustno))
