# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import os
import numpy as np
import warnings
from scipy.spatial import distance_matrix
warnings.filterwarnings(action='ignore')
from .dictionary import Dictionary
from multiprocessing import Pool
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)
def _conformer_positions(mol):
    """Return the atom positions of the molecule's conformer as a float32 array."""
    return mol.GetConformer().GetPositions().astype(np.float32)


def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True):
    '''
    Convert a SMILES string into atom symbols and per-atom coordinates.

    Explicit hydrogens are added to the parsed molecule, then a 3D conformer is
    embedded and, when possible, relaxed with MMFF. The fallbacks are, in order:

    * embedding succeeded (return value 0): use the embedded geometry, MMFF
      optimized if the optimizer did not raise;
    * embedding failed (return value -1) and ``mode == 'heavy'``: retry the
      embedding with ``maxAttempts=5000``; if optimizing or reading that
      conformer raises, use 2D coordinates instead;
    * any other case: use 2D coordinates;
    * any exception during the steps above: use all-zero coordinates.

    :param smi: (str) The SMILES representation of the molecule.
    :param seed: (int, optional) The random seed for conformation generation. Defaults to 42.
    :param mode: (str, optional) The mode of conformation generation, 'fast' for quick generation, 'heavy' for more attempts. Defaults to 'fast'.
    :param remove_hs: (bool, optional) Whether to remove hydrogen atoms from the final coordinates. Defaults to True.
    :return: A tuple ``(atoms, coordinates)``: the list of atom symbols and a float32 array with one coordinate row per returned atom.
    :raises AssertionError: If no atoms are present in the molecule or if the coordinates do not align with the atom count.
    '''
    # NOTE(review): `Chem` and `AllChem` are not imported in the visible part of
    # this file; presumably they are the RDKit modules -- confirm the import.
    # NOTE(review): an unparsable SMILES is not handled here; RDKit documents
    # that MolFromSmiles returns None in that case -- verify the caller filters.
    mol = Chem.MolFromSmiles(smi)
    mol = AllChem.AddHs(mol)
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
    assert len(atoms)>0, 'No atoms in molecule: {}'.format(smi)
    try:
        # A seed of -1 gives a random conformer; any other value fixes the seed.
        res = AllChem.EmbedMolecule(mol, randomSeed=seed)
        if res == 0:
            try:
                # Some conformers cannot be optimized with MMFF.
                AllChem.MMFFOptimizeMolecule(mol)
            except Exception:
                # Best effort: keep the embedded (unoptimized) geometry.
                logger.debug("MMFF optimization failed for %s, keep embedded geometry.", smi)
            coordinates = _conformer_positions(mol)
        elif res == -1 and mode == 'heavy':
            # Retry with many more attempts. The return value is not checked:
            # if there is still no conformer, the calls below raise and the
            # 2D fallback is used.
            AllChem.EmbedMolecule(mol, maxAttempts=5000, randomSeed=seed)
            try:
                AllChem.MMFFOptimizeMolecule(mol)
                coordinates = _conformer_positions(mol)
            except Exception:
                AllChem.Compute2DCoords(mol)
                coordinates = _conformer_positions(mol)
        else:
            # Fast mode (or an unexpected return value): settle for 2D coordinates.
            AllChem.Compute2DCoords(mol)
            coordinates = _conformer_positions(mol)
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.
        logger.info("Failed to generate conformer, replace with zeros.")
        # float32 to match the dtype returned by every other path.
        coordinates = np.zeros((len(atoms), 3), dtype=np.float32)
    assert len(atoms) == len(coordinates), "coordinates shape is not align with {}".format(smi)
    if not remove_hs:
        return atoms, coordinates
    # Keep only the atoms whose symbol is not 'H', with their coordinate rows.
    idx = [i for i, atom in enumerate(atoms) if atom != 'H']
    atoms_no_h = [atoms[i] for i in idx]
    coordinates_no_h = coordinates[idx]
    assert len(atoms_no_h) == len(coordinates_no_h), "coordinates shape is not align with {}".format(smi)
    return atoms_no_h, coordinates_no_h
def inner_coords(atoms, coordinates, remove_hs=True):
    """
    Pair atom symbols with float32 coordinates, optionally dropping hydrogens.

    The coordinates are always converted to a float32 numpy array. When
    `remove_hs` is true, every atom whose symbol equals 'H' is removed together
    with its coordinate row; the relative order of the remaining atoms is kept.

    :param atoms: (list) A list of atom symbols (e.g., ['C', 'H', 'O']).
    :param coordinates: (list of tuples or list of lists) Coordinates corresponding to each atom in the `atoms` list.
    :param remove_hs: (bool, optional) A flag to indicate whether hydrogen atoms should be removed from the output.
        Defaults to True.
    :return: A tuple `(atoms, coordinates)`. With `remove_hs` false, `atoms` is the
        input object itself and `coordinates` is the float32 array of all rows.
    :raises AssertionError: If the length of `atoms` list does not match the length of `coordinates` list.
    """
    assert len(atoms) == len(coordinates), "coordinates shape is not align atoms"
    coords = np.array(coordinates).astype(np.float32)
    if not remove_hs:
        return atoms, coords
    # Collect (position, symbol) pairs of the non-hydrogen atoms in one pass.
    heavy = [(pos, symbol) for pos, symbol in enumerate(atoms) if symbol != 'H']
    keep = [pos for pos, _ in heavy]
    heavy_atoms = [symbol for _, symbol in heavy]
    heavy_coords = coords[keep]
    assert len(heavy_atoms) == len(heavy_coords), "coordinates shape is not align with atoms"
    return heavy_atoms, heavy_coords
def coords2unimol(atoms, coordinates, dictionary, max_atoms=256, remove_hs=True, **params):
    """
    Converts atom symbols and coordinates into a unified molecular representation.

    Steps: optionally drop hydrogens (via `inner_coords`), randomly crop to at
    most `max_atoms` atoms, map the symbols to token ids framed by the
    dictionary's bos/eos tokens, centre the coordinates on their mean, add an
    all-zero coordinate row for each of the bos/eos tokens, and derive the
    pairwise distance matrix and the pairwise edge-type matrix.

    :param atoms: (list) List of atom symbols.
    :param coordinates: (ndarray) Array of atomic coordinates.
    :param dictionary: (Dictionary) An object that maps atom symbols to unique integers.
        It must provide `bos()`, `eos()`, `index(symbol)` and `len()`.
    :param max_atoms: (int) The maximum number of atoms to consider for the molecule.
        Larger molecules are cropped to a random subset drawn with
        `np.random.choice`, so the result depends on NumPy's global random state.
    :param remove_hs: (bool) Whether to remove hydrogen atoms from the representation.
    :param params: Additional parameters (accepted but not used here).
    :return: A dictionary containing the molecular representation with tokens, distances, coordinates, and edge types.
    """
    atoms, coordinates = inner_coords(atoms, coordinates, remove_hs=remove_hs)
    atoms = np.array(atoms)
    # No extra copy is made when the coordinates already are a float32 array.
    coordinates = np.asarray(coordinates, dtype=np.float32)
    # cropping atoms and coordinates
    if len(atoms) > max_atoms:
        idx = np.random.choice(len(atoms), max_atoms, replace=False)
        atoms = atoms[idx]
        coordinates = coordinates[idx]
    # tokens padding: [bos, atom tokens..., eos]
    src_tokens = np.array([dictionary.bos()] + [dictionary.index(atom) for atom in atoms] + [dictionary.eos()])
    # coordinates normalize & padding: centre on the mean of the (cropped)
    # atoms, then add a zero row for bos and for eos
    src_coord = coordinates - coordinates.mean(axis=0)
    src_coord = np.concatenate([np.zeros((1,3)), src_coord, np.zeros((1,3))], axis=0)
    # distance matrix over all rows, including the bos/eos rows
    src_distance = distance_matrix(src_coord, src_coord)
    # edge type: entry (i, j) is token_i * len(dictionary) + token_j
    src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(1, -1)
    return {
        'src_tokens': src_tokens.astype(int),
        'src_distance': src_distance.astype(np.float32),
        'src_coord': src_coord.astype(np.float32),
        'src_edge_type': src_edge_type.astype(int),
    }