#!/usr/bin/env python

"""
Transfers annotation between two genome records, based upon reciprocal
best blast hits. Missing, duplicated and unique proteins are also identified.
"""

import argparse
import drmaa
import glob
import pandas as pd
import re
import shutil
import stat
import sys
import tempfile
import os

from utils import check_format, drmaa_run
from Bio.Blast.Applications import NcbimakeblastdbCommandline
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

def parse_genome(genome: str, locus_tags: list, type: str, tmpdir: str):
    """Extract CDS annotations and protein sequences from a genome record.

    Every CDS feature carrying a locus_tag is recorded. If locus_tags is
    non-empty, only CDSs whose locus_tag is in it are kept; an empty list
    means "keep every CDS". CDSs that carry a translation are written as
    protein FASTA to <tmpdir>/<genome file stem> (no extension), with the
    locus_tag as record ID and "gene<TAB>product" (or just product when there
    is no gene name) as description. CDSs without a translation are still
    recorded in the returned DataFrame but are not written to the FASTA.

    Required params:
        genome: path to the annotated genome file
        locus_tags: locus tags to restrict the output to; empty for all CDSs
        type: Bio.SeqIO format name used to parse genome
        tmpdir: directory the protein FASTA is written to

    Returns:
        (genome_name, df): the genome file stem (also the FASTA file name in
        tmpdir) and a DataFrame indexed by locus_tag with 'gene' and 'product'
        columns (None where the qualifier is absent)
    """
    genome_name=os.path.splitext(os.path.basename(genome))[0]
    output=os.path.join(tmpdir,genome_name)
    # Set lookup instead of a list scan for every CDS
    wanted=set(locus_tags)

    prots=[]
    gene_mapping={}
    for record in SeqIO.parse(genome, type):
        for feature in record.features:
            if feature.type != 'CDS':
                continue

            qualifiers=feature.qualifiers
            locus_tag=qualifiers['locus_tag'][0] if 'locus_tag' in qualifiers else None
            if locus_tag is None:
                # A CDS without a locus_tag cannot be keyed: all such CDSs would
                # otherwise overwrite each other under a single None entry and be
                # written to the FASTA with the ID 'None'.
                continue
            if wanted and locus_tag not in wanted:
                continue

            gene=qualifiers['gene'][0] if 'gene' in qualifiers else None
            translation=qualifiers['translation'][0] if 'translation' in qualifiers else None
            product=qualifiers['product'][0] if 'product' in qualifiers else None

            gene_mapping[locus_tag]={'gene':gene, 'product':product}

            if translation:
                if gene:
                    description="{}\t{}".format(gene,product)
                else:
                    description="{}".format(product)
                prots.append(SeqRecord(Seq(translation), id=locus_tag, description=description))

    SeqIO.write(prots, output, 'fasta')
    # Explicit columns keep 'gene'/'product' present even when no CDS was found
    df=pd.DataFrame.from_dict(data=gene_mapping,orient='index',columns=['gene','product'])

    return(genome_name,df)

def index_db(tmpdir: str,db: str):
    """
    Build a protein BLAST database from a FASTA file held in tmpdir.

    makeblastdb is run (through Biopython's NcbimakeblastdbCommandline) on
    <tmpdir>/<db>. Any text makeblastdb writes to STDERR is printed and the
    program then exits with status 1.

    Required parameters:
        tmpdir - path to temporary directory
        db - name of the fasta file (inside tmpdir) to index

    Returns:
        None
    """
    makeblastdb=NcbimakeblastdbCommandline(dbtype='prot', input_file=os.path.join(tmpdir, db))
    # The command line object returns (stdout, stderr)
    err_text=makeblastdb()[1]

    if not err_text:
        return

    print("Warning - the following STDERR was reported by BLAST")
    print(err_text)
    sys.exit(1)

def create_blast_script(tmpdir:str, dbs: list, cov: int, num_align: int, threads: int = 8):

    """Generates a bash script for queue submission which carries out
    a blastp search of one genome against another

    The script is written to '<basename(dbs[0])>_vs_<basename(dbs[1])>.sh' in
    the current working directory and made executable (mode 0o755). It runs
    blastp with the protein FASTA <tmpdir>/<dbs[0]> as query against the BLAST
    database <tmpdir>/<dbs[1]>, writing tabular output (-outfmt 6) to
    <tmpdir>/<basename(dbs[0])>.<basename(dbs[1])>.blast.

    Required params:
        tmpdir: Path to temp dir
        dbs: List of databases; dbs[0] is the query, dbs[1] the database
        cov: minimum query coverage per HSP, in percent (-qcov_hsp_perc)
        num_align: Number of alignments to report (-num_alignments)

    Optional params:
        threads: number of blastp threads (-num_threads, default 8)

    Returns:
        Name of generated script
    """

    scriptname="{}_vs_{}.sh".format(os.path.basename(dbs[0]),os.path.basename(dbs[1]))
    script='''\
#!/bin/bash
db1_name=$(basename {})
db2_name=$(basename {})
blastp -query {}/{} -db {}/{} -num_threads {} -outfmt 6 -qcov_hsp_perc {} -num_alignments {} -evalue 0.001 -out {}/$db1_name.$db2_name.blast
'''.format(dbs[0], dbs[1], tmpdir, dbs[0], tmpdir, dbs[1], threads, cov, num_align, tmpdir)

    # Write the script in place; the context manager closes the handle even if
    # the write fails (the previous temp-file + copy round trip did not).
    with open(scriptname,'w') as file_:
        file_.write(script)
    os.chmod(scriptname,0o755)

    return(scriptname)

def run_blasts(tmpdir: str, dbs: list, recip: bool, len_thresh: int, num_align: int):
    """Run a blastp search of dbs[0] against dbs[1] and, optionally, the reverse.

    One submission script per direction is generated with create_blast_script()
    and all of them are handed to drmaa_run() in a single call.

    Required params:
        tmpdir: path to temporary directory
        dbs: list of databases [query, database]
        recip: if True also search dbs[1] against dbs[0]
        len_thresh: minimum query coverage per HSP, in percent
        num_align: Number of alignments to report

    Returns:
        None
    """
    directions=[dbs]
    if recip:
        directions.append([dbs[1],dbs[0]])

    scripts=[create_blast_script(tmpdir, pair, len_thresh, num_align) for pair in directions]

    drmaa_run(tmpdir, scripts, 'Blast')

def parse_blasts(tmpdir: str, dbs: list, pident: int, recip: bool):

    """Load tabular (-outfmt 6) BLAST results into pandas DataFrames.

    The first file read is '<tmpdir>/<dbs joined with ".">.blast'; when recip
    is True, '<tmpdir>/<dbs[1]>.<dbs[0]>.blast' is read as well. Only rows
    whose percent identity is strictly greater than pident are kept.

    Required Params:
        tmpdir: path to temporary directory
        dbs: list of blast database names (normally two)
        pident: percent-identity cutoff; hits must exceed it
        recip: Parse reciprocal blast results

    Returns:
        dfs: list of DataFrames - the dbs[0] vs dbs[1] results first and,
            if recip, the dbs[1] vs dbs[0] results second
    """

    columns=['qseq','sseq','pident','length','mismatch','gapopen','qstart',
        'qend','sstart','send','evalue','bitscore']

    stems=[".".join(dbs)]
    if recip:
        stems.append(".".join([dbs[1],dbs[0]]))

    results=[]
    for stem in stems:
        path=os.path.join(tmpdir, "{}.blast".format(stem))
        hits=pd.read_table(path, header=None, names=columns)
        results.append(hits[hits['pident']>pident])

    return results

def intersect_results(blast_dfs: list, gene_dfs: list):

    """Creates DataFrame combining top blast hits with their annotations

    Each distinct (query, subject) pair from the reference->subject search
    (blast_dfs[0]) is joined to the reference annotation of the query and the
    subject annotation of the hit. Pairs whose IDs are missing from either
    annotation table are dropped by the inner joins. Only blast_dfs[0] is
    consulted; the reverse-direction results, if present, are not used here.

    Required Params:
        blast_dfs: List of DataFrames of blast hits; [0] must be the
            reference->subject search with 'qseq' and 'sseq' columns
        gene_dfs: List of DataFrames containing reference [0]
            and subject [1] annotations ('gene', 'product'), indexed by ID

    Returns:
        DataFrame with columns SubjID, SubjGene, SubjProduct, RefID, RefGene,
        RefProduct
    """

    # BLAST can report several HSPs for the same query/subject pair; keeping
    # them would give the subject several identical rows, which
    # update_annotation() treats as a conflict and so skips the transfer.
    ref_hits=blast_dfs[0][['qseq','sseq']].drop_duplicates()
    common_hits=pd.merge(ref_hits,gene_dfs[0][['gene','product']],left_on='qseq', right_index=True)
    # Overlapping 'gene'/'product' columns get _x (reference) / _y (subject) suffixes
    common_hits=pd.merge(common_hits,gene_dfs[1][['gene','product']],left_on='sseq',right_index=True)
    colnames={
        'qseq': 'RefID',
        'sseq': 'SubjID',
        'gene_x': 'RefGene',
        'product_x': 'RefProduct',
        'gene_y': 'SubjGene',
        'product_y': 'SubjProduct'
    }
    common_hits=common_hits[['sseq', 'gene_y', 'product_y', 'qseq', 'gene_x', 'product_x']]
    # Non-inplace rename avoids modifying a slice (SettingWithCopyWarning)
    common_hits=common_hits.rename(columns=colnames)

    return(common_hits)

def run_mmseq(tmpdir: str, db: str):
    """ Runs mmseqs clustering of proteins

    Writes an executable bash script, <tmpdir>/mmseqs_<basename(db)>.sh, that
    runs 'mmseqs easy-cluster' on the protein FASTA <tmpdir>/<db>, using
    <tmpdir>/<db> as the output prefix and tmpdir as mmseqs' temporary
    directory, then submits it with drmaa_run().

    Required params:
        tmpdir: path to temporary directory
        db: database to cluster

    Returns:
        None
    """

    scriptname="{}/mmseqs_{}.sh".format(tmpdir,os.path.basename(db))
    # Write the script in place; the context manager closes the handle even if
    # the write fails (the previous temp-file + copy round trip did not).
    with open(scriptname,'w') as file_:
        file_.write('''\
#!/bin/bash
mmseqs easy-cluster {}/{} {}/{} {}
'''.format(tmpdir,db,tmpdir,db,tmpdir))
    os.chmod(scriptname,0o755)

    drmaa_run(tmpdir, [scriptname], 'mmseq')

def update_annotation(genome: str, common_hits: pd.DataFrame, truncated: pd.DataFrame):

    """
    Assigns reference annotations to subject genome record

    Each CDS in the subject genome is matched by its locus_tag against the
    SubjID column of common_hits. Exactly one matching row transfers RefGene
    and, when the feature already has a product qualifier, RefProduct. With no
    match in common_hits, exactly one matching row in truncated is used
    instead, with '_tr' appended to the gene name and ' (truncated)' to the
    product. More than one matching row for the gene is recorded as a
    conflict and nothing is transferred for the gene. Missing (None/NaN)
    reference values are never copied over the subject's own annotation.
    CDSs without a locus_tag are left untouched.

    The records are written to '<genome stem>.reannotated<suffix>' in the
    current working directory, replacing any existing file.

    Required params:
        genome: path to subject genome file
        common_hits: dataframe containing annotation mappings
        truncated: dataframe containing truncated proteins

    Returns:
        conflicts: list of conflicting annotations, each the string
            '<SubjID array>:<RefID array>' (numpy array reprs)
    """
    conflicts=[]

    genome_name,suffix=os.path.splitext(os.path.basename(genome))
    genome_file='{}.reannotated{}'.format(genome_name,suffix)

    type=check_format(genome)
    # Opened once in 'w' mode: replaces any earlier output without a separate
    # delete, and avoids reopening the file for every record.
    with open(genome_file,'w') as out:
        for record in SeqIO.parse(genome, type):
            for feature in record.features:
                if feature.type!='CDS':
                    continue
                if 'locus_tag' not in feature.qualifiers:
                    # Nothing to match this CDS against
                    continue
                locus_tag=feature.qualifiers['locus_tag'][0]

                # Annotations are first transferred from the top hits, and if there
                # is none present, then we look for a truncated annotation instead.
                gene_data=common_hits.loc[common_hits['SubjID']==locus_tag]
                trunc_data=truncated.loc[truncated['SubjID']==locus_tag]

                gene_names=gene_data['RefGene'].values
                if len(gene_names)>1:
                    # Several reference hits for one subject: conflicting annotations
                    conflicts.append('{}:{}'.format(gene_data['SubjID'].values,gene_data['RefID'].values))
                elif len(gene_names)==1:
                    if pd.notna(gene_names[0]):
                        feature.qualifiers['gene']=[gene_names[0]]
                else:
                    trunc_names=trunc_data['RefGene'].values
                    if len(trunc_names)>1:
                        conflicts.append('{}:{}'.format(trunc_data['SubjID'].values,trunc_data['RefID'].values))
                    elif len(trunc_names)==1 and pd.notna(trunc_names[0]):
                        feature.qualifiers['gene']=['{}_tr'.format(trunc_names[0])]

                if 'product' in feature.qualifiers:
                    products=gene_data['RefProduct'].values
                    if len(products)==1:
                        if pd.notna(products[0]):
                            feature.qualifiers['product']=[products[0]]
                    elif len(products)==0:
                        trunc_products=trunc_data['RefProduct'].values
                        if len(trunc_products)==1 and pd.notna(trunc_products[0]):
                            feature.qualifiers['product']=['{} (truncated)'.format(trunc_products[0])]

            SeqIO.write(record, out, type)

    # Reported once for the whole file (previously repeated for every record)
    print("Updated annotations written to {}".format(genome_file))

    return(conflicts)

def reciprocal_blast(tmpdir: str, dbs: list, pident: int, gene_dfs: list):

    """Search the two protein sets against each other and map annotations

    blastp is run in both directions (90% query coverage per HSP, one
    alignment per query), the results are filtered on pident and combined
    with the annotation tables by intersect_results(). The mapping is also
    written to '<dbs[1]>_<dbs[0]>.annotation_mapping.txt'.

    Required Params:
        tmpdir: path to temporary directory
        dbs: list of blast database names [reference, subject]
        pident: percent-identity cutoff
        gene_dfs: list of dataframes containing ref and subj gene annotations

    Returns:
        common_hits: annotation-mapping DataFrame built by intersect_results()
    """

    print('Identifying reciprocal top hits...')
    run_blasts(tmpdir, dbs, True, 90, 1)

    hits=intersect_results(parse_blasts(tmpdir, dbs, pident, True), gene_dfs)

    mapping_file='{}_{}.annotation_mapping.txt'.format(dbs[1],dbs[0])
    hits.to_csv(mapping_file, sep='\t', index=None)

    return hits

def report_conflicts(conflicts: list, ref_genes: pd.DataFrame, genome_name:str):

    """Produces a report of genes with conflicting hits to the reference

    Each conflict is expected in the form '<subject ids>:<reference ids>',
    where each side may be a numpy-array repr such as "['S1' 'S1']". Brackets,
    quotes and newlines are stripped and the ids are split on whitespace.
    One line per conflict, '<subject id>\t<gene> (<product>); ...', is written
    to '<genome_name>.conflicts.txt'. Nothing is written when there are no
    conflicts.

    Required Parameters:
        conflicts: list of conflicting locus tags
        ref_genes: data frame of reference gene annotations, indexed by
            locus tag; its first two columns are reported (gene, product)
        genome_name: prefix for the report file name

    Returns: None
    """
    if not conflicts:
        return

    with open('{}.conflicts.txt'.format(genome_name),'w') as report:
        for conflict in conflicts:
            conflict=re.sub(r'[\[\]\'\n]', '', conflict)
            subj,ref=conflict.split(':')

            # split() with no argument tolerates runs of whitespace; split(' ')
            # produced empty ids which then failed the .loc lookup below.
            subjs=set(subj.split())
            if len(subjs)>1:
                print('Warning: non-unique subject ids in conflict list')
            subj=subjs.pop()

            ref_strs=[]
            for ref_id in ref.split():
                ref_vals=ref_genes.loc[ref_id].values
                ref_strs.append('{} ({})'.format(ref_vals[0],ref_vals[1]))
            report.write('{}\t{}\n'.format(subj, '; '.join(ref_strs)))

    print('{} conflict(s) identified: written to {}.conflicts.txt'.format(len(conflicts),genome_name))

def find_truncated(tmpdir: str, genome:str, dbs: list, common_hits: pd.DataFrame, gene_dfs: list, pident: int):

    """ Proteins present as a tophit in the subject search but not the reference search may be truncated.
        Identify candidates (subject proteins absent from common_hits), then re-blast them at a
        50% query-coverage threshold to see if they then match a reference protein.

        Required Params:
            tmpdir: path to temporary directory
            genome: path to the subject genome file
            dbs: list of blast db names [reference, subject]
            common_hits: Dataframe of top hits (SubjID column lists annotated subject proteins)
            gene_dfs: list of dataframes containing gene annotations [reference, subject]
            pident: blast %ID cutoff

        Returns:
            truncated: Dataframe of truncated proteins (RefID, RefGene, RefProduct,
                SubjID, SubjGene, SubjProduct), also written to '<dbs[1]>.truncated.txt'
            blast_dfs: List of dataframes of blast hits [reference->subject, subject->reference]

        N.B. this rewrites the subject protein FASTA in tmpdir to hold only the candidates.
    """
    print("Identifying potentially truncated proteins")
    subj_gene_df=gene_dfs[1]

    subjIDs=set(subj_gene_df.index.values)
    annotated_subjIDs=set(common_hits['SubjID'].values)
    unannotated_subjIDs=list(subjIDs-annotated_subjIDs)

    # The subject protein FASTA in tmpdir is rewritten to contain just the
    # candidates, restricting the queries of the subject->reference search.
    # The subject BLAST database index is not rebuilt here, so the
    # reference->subject search still runs against the database built earlier.
    type=check_format(genome)
    parse_genome(genome, unannotated_subjIDs, type, tmpdir)

    # A 50% length threshold is applied for potential truncated proteins
    # Note that reciprocal blasts are carried out, since the results of the
    # opposite search are used in the subsequent find_missing() call
    run_blasts(tmpdir, [dbs[1],dbs[0]], True, 50, 1)
    blast_dfs=parse_blasts(tmpdir, dbs, pident, True)

    # blast_dfs[1] is the subject->reference search. The merge is the opposite
    # way round to that in intersect_results, so that function is not reused.
    # Duplicate rows (several HSPs for one pair) are dropped so they do not
    # appear as conflicts later.
    subj_hits=blast_dfs[1][['sseq','qseq']].drop_duplicates()
    truncated=pd.merge(subj_hits,gene_dfs[0][['gene','product']],left_on='sseq', right_index=True)
    truncated=pd.merge(truncated,gene_dfs[1][['gene','product']],left_on='qseq',right_index=True)
    colnames={
        'sseq': 'RefID',
        'qseq': 'SubjID',
        'gene_x': 'RefGene',
        'product_x': 'RefProduct',
        'gene_y': 'SubjGene',
        'product_y': 'SubjProduct'
    }
    truncated=truncated[['sseq', 'gene_x', 'product_x', 'qseq', 'gene_y', 'product_y']].rename(columns=colnames)
    # parse_genome() treats an empty locus_tag list as "all proteins", so when
    # every subject protein is already annotated the search above covered them
    # all; keep only genuine candidates.
    truncated=truncated[truncated['SubjID'].isin(unannotated_subjIDs)]

    print('{} truncated protein(s) identified: written to {}.truncated.txt'.format(len(truncated),dbs[1]))
    truncated.to_csv('{}.truncated.txt'.format(dbs[1]), header=True, sep="\t",index=None)

    return(truncated,blast_dfs)

def find_missing(blast_df: pd.DataFrame, gene_df: pd.DataFrame, truncated: pd.DataFrame, genome_name: str):
    """Report reference proteins with no counterpart in the subject genome.

    A reference protein is counted as missing when its ID appears neither as a
    query in blast_df nor in the RefID column of truncated. The missing
    proteins are written, tab-separated, with a 'RefId' column followed by
    gene_df's columns, to '<genome_name>.missing.txt'.

    Required parameters:
        blast_df: Dataframe holding results of ref->subj blast search at 50% length threshold
        gene_df: Dataframe holding details of reference genes, indexed by ID
        truncated: Dataframe of truncated proteins with a RefID column
        genome_name: prefix for the output file (main passes the subject name)

    Returns:
        None
    """

    print('Identifying missing proteins')
    accounted_for=set(blast_df['qseq']) | set(truncated['RefID'])
    missing_ids=[ref_id for ref_id in set(gene_df.index.values) if ref_id not in accounted_for]

    missing=(gene_df[gene_df.index.isin(missing_ids)]
             .reset_index()
             .rename(columns={'index': 'RefId'}))

    print('{} missing protein(s) identified: written to {}.missing.txt'.format(len(missing),genome_name))
    missing.to_csv('{}.missing.txt'.format(genome_name), header=True, sep="\t",index=None)

def find_novel(tmpdir: str, dbs: list, gene_df: list, pident: int, genome_name: str):

    """ Report subject proteins that have no hit in the reference.

    A subject->reference blastp search is run (90% query coverage per HSP, one
    alignment per query). Every protein in gene_df that does not appear as a
    query among the hits passing the pident cutoff is reported, tab-separated
    with a 'SubjId' column followed by gene_df's columns, in
    '<genome_name>.novel.txt'.

    Required Parameters:
        tmpdir: Path to temporary directory
        dbs: list of databases [reference, subject]
        gene_df: Dataframe of gene annotations in subject genome, indexed by locus tag
        pident: %ID blast cutoff
        genome_name: name of subject

    Returns:
        None
    """

    print('Identifying novel proteins')
    subj_vs_ref=[dbs[1],dbs[0]]
    run_blasts(tmpdir, subj_vs_ref, False, 90, 1)
    hits=parse_blasts(tmpdir, subj_vs_ref, pident, False)[0]

    hit_ids=set(hits['qseq'].values)
    # NOTE: entries of gene_df that were never written to the query FASTA
    # (e.g. CDSs without a translation) can never have a hit, so they land here too.
    novel=list(set(gene_df.index.values)-hit_ids)

    novel_df=(gene_df[gene_df.index.isin(novel)]
              .reset_index()
              .rename(columns={'index':'SubjId'}))

    print('{} novel protein(s) identified: written to {}.novel.txt'.format(len(novel_df),genome_name))
    novel_df.to_csv('{}.novel.txt'.format(genome_name), header=True, sep="\t",index=None)


def find_duplicates(tmpdir: str, dbs:list, gene_df: list, pident: int, genome_name: str, subject_genome: str):
    """Identify duplicated proteins within the subject genome.

    The subject proteins are searched against themselves (blastp, 90% query
    coverage per HSP, up to 20 alignments per query). Every query with more
    than one hit row passing the pident cutoff is a duplication candidate, and
    those rows are written to '<genome_name>.multihits.blast.txt'. Because some
    hits may reflect only local similarity to a longer sequence, the candidates
    and the proteins they hit are clustered with mmseqs easy-cluster. Every
    cluster with more than one member is written to
    '<genome_name>.duplicates.txt' (a header-only file when there are none).

    Required params:
        tmpdir: path to temporary directory
        dbs: list of databases [reference, subject]
        gene_df: DataFrame of subject gene annotations ('gene', 'product'),
            indexed by locus tag
        pident: %ID blast cutoff
        genome_name: name of the subject protein FASTA/database in tmpdir,
            also used as the prefix of the output files
        subject_genome: path to the subject genome file

    Returns:
        None

    N.B. when there are candidates, the subject protein FASTA in tmpdir is
    rewritten to contain only them before clustering.
    """

    print('Identifying duplicated proteins')
    run_blasts(tmpdir,[dbs[1],dbs[1]], False, 90, 20)
    blast_df=parse_blasts(tmpdir, [dbs[1],dbs[1]], pident, False)[0]

    # Work on the value_counts() Series directly: since pandas 2.0 its name is
    # 'count', so wrapping it in a DataFrame and indexing by the original column
    # name raises KeyError.
    hit_counts=blast_df['qseq'].value_counts()
    multi_hits=hit_counts[hit_counts>1].index.values
    multi_hit_details=blast_df[blast_df['qseq'].isin(multi_hits)]

    multi_hit_details.to_csv('{}.multihits.blast.txt'.format(genome_name), header=True, sep="\t")
    candidates=list(set(multi_hit_details['qseq']).union(multi_hit_details['sseq']))

    duplicates_file='{}.duplicates.txt'.format(genome_name)
    header="Gene name\tProduct\tDuplicates\n"

    if not candidates:
        # parse_genome() treats an empty locus_tag list as "all proteins", so
        # clustering now would report duplicates drawn from the whole genome.
        with open(duplicates_file,'w') as out:
            out.write(header)
        return

    ## Some of these hits may have local similarity to a longer sequence
    # so carry out clustering to determine who is truly duplicated...
    type=check_format(subject_genome)
    parse_genome(subject_genome, candidates, type, tmpdir)
    run_mmseq(tmpdir,genome_name)

    clusters=pd.read_table(os.path.join(tmpdir, '{}_cluster.tsv'.format(genome_name)),
        header=None,names=['query','subject'],index_col=0)
    # Discard singletons (same pandas 2.0 value_counts() naming issue as above)
    cluster_sizes=clusters.index.value_counts()
    clusters=clusters[clusters.index.isin(cluster_sizes[cluster_sizes>1].index)]
    cluster_names=clusters.index.unique()

    with open(duplicates_file,'w') as out:
        out.write(header)
        for cluster_name in cluster_names:
            members=clusters.loc[clusters.index==cluster_name,'subject'].tolist()
            annots=gene_df.loc[cluster_name]
            out.write("{}\t{}\t{}\n".format(annots['gene'],annots['product'],",".join(members)))

    if len(cluster_names):
        print("{} duplicate protein(s) identified: written to {}.duplicates.txt".format(len(cluster_names),genome_name))
    

def main():
    """Parse command-line arguments and run the annotation-transfer pipeline.

    Protein FASTA files, BLAST databases and BLAST/mmseqs output are written
    to a temporary directory created under the current working directory and
    removed on exit. The reports and the re-annotated genome are written to
    the current working directory. The blast submission scripts are also
    written there by create_blast_script() and are not removed here.
    """

    parser = argparse.ArgumentParser(
        description="Transfer annotations between genome records based on reciprocal blast searches")
    parser.add_argument('--ref-genome', dest='ref', 
        help='path to EMBL/Genbank file containing reference annotated genome', required=True)
    parser.add_argument('--subject-genome', dest='subj', 
        help='path to EMBL/Genbank file containing subject annotated genome', required=True)
    parser.add_argument('--percent-id', dest='pident', default=95, type=int,
        help='Percentage ID threshold for blast hits (default: 95)')

    args = parser.parse_args()
    dbs=[]
    gene_dfs=[]

    with tempfile.TemporaryDirectory(dir='.') as tmpdir:
        for genome in (args.ref, args.subj):
            genome_format=check_format(genome)
            dbname,gene_df=parse_genome(genome, [], genome_format, tmpdir)
            dbs.append(dbname)
            gene_dfs.append(gene_df)
            index_db(tmpdir, dbname)

        subj_name=dbs[1]

        # N.B. order of calls is important: find_duplicates() and find_truncated()
        # rewrite the subject protein FASTA in tmpdir with subsets of proteins, so
        # searches needing the full subject protein set must run before them.
        common_hits=reciprocal_blast(tmpdir, dbs, args.pident, gene_dfs)
        find_novel(tmpdir, dbs, gene_dfs[1], args.pident, subj_name)
        find_duplicates(tmpdir, dbs,  gene_dfs[1], args.pident, subj_name, args.subj)
        truncated,blast_dfs=find_truncated(tmpdir, args.subj, dbs, common_hits, gene_dfs, args.pident)
        find_missing(blast_dfs[0], gene_dfs[0], truncated, subj_name)
        conflicts=update_annotation(args.subj, common_hits, truncated)
        report_conflicts(conflicts, gene_dfs[0], subj_name)

    # The former clean-up loop deleted '<genome stem>.fa' from the working
    # directory. Nothing in this script creates such files (the FASTA files live
    # in tmpdir, removed above), so it could only destroy unrelated user files.


if __name__ == '__main__':
    main()