#!/usr/bin/env python3

"""
Script to generate .xyz trajectory files with solvent molecules from GROMACS output.
"""

import MDAnalysis as mda
from MDAnalysis import transformations
import numpy as np


def create_xyz_traj(
    universe,
    ref,
    molecules,
    n_mol: int = 1,
    file_name = 'traj.xyz',
    n_frames: int = 100,
    n_init: int = 1,
    n_end: int = -1,
    cm_to_origin = False
):
    """
    Generate .xyz trajectory file with the closest 'n_mol' molecules from reference residue 'ref'.

    For every sampled frame, the residues of 'molecules' are ranked by the distance between
    their center of mass and the center of mass of 'ref'; 'ref' followed by the 'n_mol'
    nearest residues is written as one .xyz frame.

    PARAMETERS:
    universe [type: MDAnalysis.core.universe.Universe] - MDAnalysis universe with all data.
    ref [type: MDAnalysis.core.groups.AtomGroup] - MDAnalysis AtomGroup with the reference molecule/residue.
    molecules [type: MDAnalysis.core.groups.AtomGroup] - MDAnalysis AtomGroup with all molecules/residues
                                                         to search for the closest 'n_mol'.
    n_mol [type: int] - Number of closest residues to the reference.
    file_name [type: str] - Name of the output trajectory file.
    n_frames [type: int] - Approximate number of frames to write. The sampling step is
                           round(span / n_frames), so the exact count may differ slightly.
    n_init [type: int] - Index of the first frame (used as the start of a trajectory slice).
    n_end [type: int] - Index of the last frame (inclusive). -1 means "up to the last
                        frame of the trajectory".
    cm_to_origin [type: bool] - If True, translate the target's residue center of mass to origin.
                                This shifts the positions of every atom in 'universe' for the
                                current frame.

    OUTPUT:
    'file_name' file.

    RAISES:
    ValueError - If 'n_frames' is smaller than 1.
    """

    if n_frames < 1:
        raise ValueError("n_frames must be a positive integer, got {}".format(n_frames))

    # Resolve the (exclusive) slice stop. The -1 sentinel must be handled here:
    # using n_end + 1 directly would give a stop of 0 and an empty slice.
    n_total = len(universe.trajectory)
    if n_end == -1:
        stop = n_total
        span = n_total - n_init
    else:
        stop = n_end + 1
        span = n_end - n_init

    # A step of 0 is not a valid slice step; fall back to every frame when
    # more frames are requested than are available in the range.
    step = max(1, round(span / n_frames))

    line_fmt = "{:>8}{:>12.5f}{:>11.5f}{:>11.5f}\n"

    with open(file_name, "w") as fout:

        # Loop over the sampled frames
        for ts in universe.trajectory[n_init:stop:step]:

            # Center of mass of the reference, computed once per frame
            ref_cm = ref.center_of_mass()

            # Center of mass of every candidate residue in a single call.
            # 'residues.atoms' is used so that whole residues are considered,
            # even if 'molecules' only contains part of a residue.
            residues = molecules.residues
            molecule_centers = residues.atoms.center_of_mass(compound='residues')

            # Distance from the reference to each candidate residue
            distances = np.linalg.norm(molecule_centers - ref_cm, axis=1)

            # The 'n_mol' nearest residues, nearest first
            closest_indices = np.argsort(distances)[:n_mol]

            # Reference first, then the selected neighbours
            groups = [ref] + [residues[i].atoms for i in closest_indices]

            # Shift the target's residue center of mass to origin
            if cm_to_origin:
                universe.atoms.positions -= ref_cm

            n_atoms = sum(len(group) for group in groups)

            # .xyz frame: atom count, comment line, then one line per atom
            lines = ["{}\n".format(n_atoms), "Frame {}\n".format(ts.frame)]
            for group in groups:
                for atom in group:
                    x, y, z = atom.position
                    lines.append(line_fmt.format(atom.element, x, y, z))

            fout.writelines(lines)

u = mda.Universe("h2o_nvt.tpr", "h2o_nvt.xtc")

# The 'index' selection keyword is 0-based, whereas 'bynum' is 1-based
# (matching GROMACS atom numbering). 'bynum 1 to 3' therefore picks the first
# three atoms of the system.
# NOTE(review): assumes a 3-site water model in which the first molecule
# occupies atoms 1-3 -- confirm against the topology.
h2o_1 = u.select_atoms("bynum 1 to 3")
h2o_others = u.select_atoms("bynum 4 to 10000")

# Center water 1 in the box and avoid pbc problems: make molecules whole,
# center the box on the reference, then wrap the fragments back into the box.
# https://www.mdanalysis.org/2020/03/09/on-the-fly-transformations/
workflow = (
    transformations.unwrap(u.atoms),
    transformations.center_in_box(h2o_1, center='mass'),
    transformations.wrap(u.atoms, compound='fragments'),
)
u.trajectory.add_transformations(*workflow)

# Write the reference water plus its single nearest neighbour.
create_xyz_traj(
    universe=u,
    ref=h2o_1,
    molecules=h2o_others,
    n_mol=1,
    file_name="h2o_configs_1.xyz",
    n_frames=100,
    n_init=2,
    n_end=1001,
    cm_to_origin=False
)

