# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os

import numpy as np

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Dictionary:
    """A mapping from symbols to consecutive integers.

    ``symbols`` holds the symbols in insertion order, ``indices`` maps each
    symbol to its (latest) position and ``count`` holds a per-position count.
    Looking up a symbol that is not present returns the index of the unk
    symbol (which must itself be present, otherwise ``KeyError`` is raised).
    """

    # Built-in vocabulary used by ``add_from_file("default")``, in index order.
    _DEFAULT_ATOMS = (
        "[PAD]", "[CLS]", "[SEP]", "[UNK]",
        "C", "N", "O", "S", "H", "Cl", "F", "Br",
        "I", "Si", "P", "B", "Na", "K", "Al", "Ca",
        "Sn", "As", "Hg", "Fe", "Zn", "Cr", "Se",
        "Gd", "Au", "Li",
    )

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="[CLS]",
        pad="[PAD]",
        eos="[SEP]",
        unk="[UNK]",
        extra_special_symbols=None,
    ):
        """Create an empty dictionary.

        Args:
            bos, pad, eos, unk: Strings used for the beginning-of-sentence,
                padding, end-of-sentence and unknown symbols. They are
                registered as specials but are NOT added to the vocabulary;
                they must come from a loaded file (or ``"default"``) or from
                :meth:`add_symbol`.
            extra_special_symbols: Optional iterable of additional symbol
                strings to register as specials, like ``bos``/``pad``/etc.
        """
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []  # position -> symbol
        self.count = []  # position -> count
        self.indices = {}  # symbol -> position
        self.specials = {bos, unk, pad, eos}
        # This argument used to be accepted and silently dropped.
        if extra_special_symbols is not None:
            self.specials.update(extra_special_symbols)

    def __eq__(self, other):
        """Dictionaries are equal when they map the same symbols to the same indices."""
        if not isinstance(other, Dictionary):
            return NotImplemented
        return self.indices == other.indices

    # Defining __eq__ already disables hashing; make that explicit.
    __hash__ = None

    def __getitem__(self, idx):
        """Return the symbol at position ``idx``, or the unk word past the end.

        NOTE: negative ``idx`` values fall through to ordinary list indexing.
        """
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def vec_index(self, a):
        """Apply :meth:`index` element-wise to an array-like of symbols.

        Returns an integer ndarray with the same shape as ``a``.
        """
        # A fixed output type lets NumPy skip its trial call on the first
        # element and makes size-0 inputs work (it cannot infer a type there).
        return np.vectorize(self.index, otypes=[int])(a)

    def index(self, sym):
        """Return the index of ``sym``, or the unk index if ``sym`` is absent.

        Raises:
            TypeError: if ``sym`` is not a ``str``.
            KeyError: if ``sym`` is absent and the unk symbol is absent too.
        """
        # An explicit check, unlike ``assert``, is not stripped by ``python -O``.
        if not isinstance(sym, str):
            raise TypeError(
                "Dictionary.index expects a str, got {}".format(type(sym).__name__)
            )
        if sym in self.indices:
            return self.indices[sym]
        return self.indices[self.unk_word]

    def special_index(self):
        """Return the indices of all special symbols, sorted ascending.

        Specials missing from the vocabulary map to the unk index, so the
        result may contain repeated values.
        """
        # ``specials`` is a set of strings whose iteration order varies between
        # processes (hash randomization); sorting makes the result reproducible.
        return sorted(self.index(x) for x in self.specials)

    def add_symbol(self, word, n=1, overwrite=False, is_special=False):
        """Add ``word`` to the dictionary and return its index.

        Args:
            word: Symbol to add.
            n: Count stored for a new entry, or added to an existing one.
            overwrite: If True, append ``word`` as a new entry and point
                ``indices[word]`` at it even when it is already present; the
                earlier entry stays in ``symbols``/``count``.
            is_special: Also register ``word`` as a special symbol.
        """
        if is_special:
            self.specials.add(word)
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] += n
            return idx
        idx = len(self.symbols)
        self.indices[word] = idx
        self.symbols.append(word)
        self.count.append(n)
        return idx

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.index(self.bos_word)

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.index(self.pad_word)

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.index(self.eos_word)

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.index(self.unk_word)

    @classmethod
    def load(cls, f):
        """Create a dictionary and fill it via :meth:`add_from_file`.

        ``f`` is ``"default"``, a path, or an open text file whose lines have
        the format:
        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """Add symbols from ``f`` to this dictionary.

        Args:
            f: ``"default"`` for the built-in atom vocabulary, a path (``str``
                or ``os.PathLike``) to a UTF-8 text file, or an open text file
                object.

        Each line is ``<symbol> [<count>] [#overwrite]``. A missing count
        defaults to ``len(lines) - line_number``, so earlier lines get larger
        counts. Symbols already present are logged and skipped unless the
        line ends with ``#overwrite``.

        Raises:
            FileNotFoundError: if the given path does not exist.
            Exception: if the file cannot be decoded as UTF-8.
            ValueError: if a line does not match the expected format.
        """
        if isinstance(f, str) and f == "default":
            total = len(self._DEFAULT_ATOMS)
            for line_idx, word in enumerate(self._DEFAULT_ATOMS):
                self._add_loaded_symbol(word, total - line_idx, overwrite=False)
            return

        if isinstance(f, (str, os.PathLike)):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err
            return

        lines = f.readlines()
        for line_idx, line in enumerate(lines):
            try:
                word, count, overwrite = self._parse_line(line, len(lines) - line_idx)
            except ValueError as err:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                ) from err
            self._add_loaded_symbol(word, count, overwrite)

    @staticmethod
    def _parse_line(line, default_count):
        """Split one dictionary line into ``(word, count, overwrite)``.

        Raises ValueError when the count is not an integer, or when
        ``#overwrite`` is not preceded by a count.
        """
        word, *rest = line.rstrip().rsplit(" ", 1)
        if not rest:
            # Bare symbol: no count given on this line.
            return word, default_count, False
        field = rest[0]
        overwrite = field == "#overwrite"
        if overwrite:
            # "<word> <count> #overwrite": the count precedes the flag.
            word, field = word.rsplit(" ", 1)
        return word, int(field), overwrite

    def _add_loaded_symbol(self, word, count, overwrite):
        """Add a symbol read by :meth:`add_from_file`; log and skip duplicates."""
        if word in self and not overwrite:
            logger.info(
                "Duplicate word found when loading Dictionary: '%s', index is %s.",
                word,
                self.indices[word],
            )
        else:
            self.add_symbol(word, n=count, overwrite=overwrite)