diff --git a/nemo_text_processing/inverse_text_normalization/ta/__init__.py b/nemo_text_processing/inverse_text_normalization/ta/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nemo_text_processing/inverse_text_normalization/ta/data/numbers/digit.tsv b/nemo_text_processing/inverse_text_normalization/ta/data/numbers/digit.tsv new file mode 100644 index 000000000..c3298e6d8 --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/data/numbers/digit.tsv @@ -0,0 +1,10 @@ +1 ஒன்று +2 இரண்டு +3 மூன்று +4 நான்கு +5 ஐந்து +6 ஆறு +7 ஏழு +8 எட்டு +9 ஒன்பது + diff --git a/nemo_text_processing/inverse_text_normalization/ta/data/numbers/teens_and_ties.tsv b/nemo_text_processing/inverse_text_normalization/ta/data/numbers/teens_and_ties.tsv new file mode 100644 index 000000000..b4889a64f --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/data/numbers/teens_and_ties.tsv @@ -0,0 +1,19 @@ +10 பத்து +11 பதினொன்று +12 பன்னிரண்டு +13 பதின்மூன்று +14 பதினான்கு +15 பதினைந்து +16 பதினாறு +17 பதினேழு +18 பதினெட்டு +19 பத்தொன்பது +20 இருபது +30 முப்பது +40 நாற்பது +50 ஐம்பது +60 அறுபது +70 எழுபது +80 எண்பது +90 தொண்ணூறு + diff --git a/nemo_text_processing/inverse_text_normalization/ta/data/numbers/zero.tsv b/nemo_text_processing/inverse_text_normalization/ta/data/numbers/zero.tsv new file mode 100644 index 000000000..1f8d200b9 --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/data/numbers/zero.tsv @@ -0,0 +1,2 @@ +0 சுழியம் + diff --git a/nemo_text_processing/inverse_text_normalization/ta/graph_utils.py b/nemo_text_processing/inverse_text_normalization/ta/graph_utils.py new file mode 100644 index 000000000..b002efa52 --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/graph_utils.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 and onwards Google, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import string +from pathlib import Path +from typing import Dict + +import pynini +from pynini import Far +from pynini.examples import plurals +from pynini.export import export +from pynini.lib import byte, pynutil, utf8 + +from nemo_text_processing.inverse_text_normalization.hi.utils import get_abs_path, load_labels + +NEMO_CHAR = utf8.VALID_UTF8_CHAR + +graph_digit = pynini.string_file(get_abs_path("data/numbers/digit.tsv")) + +NEMO_HI_DIGIT = pynini.union("०", "१", "२", "३", "४", "५", "६", "७", "८", "९").optimize() +DEVANAGARI_DIGIT = ["०", "१", "२", "३", "४", "५", "६", "७", "८", "९"] + +NEMO_HEX = pynini.union(*string.hexdigits).optimize() +NEMO_NON_BREAKING_SPACE = u"\u00a0" +NEMO_ZWNJ = u"\u200c" +NEMO_SPACE = " " +NEMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00a0").optimize() +NEMO_NOT_SPACE = pynini.difference(NEMO_CHAR, NEMO_WHITE_SPACE).optimize() +NEMO_NOT_QUOTE = pynini.difference(NEMO_CHAR, r'"').optimize() + +NEMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize() +NEMO_GRAPH = pynini.union(NEMO_CHAR, NEMO_PUNCT).optimize() + +NEMO_SIGMA = pynini.closure(NEMO_CHAR) + +delete_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE)) +delete_zero_or_one_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE, 0, 1)) +insert_space = pynutil.insert(" ") +delete_extra_space = pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 1), " ") +delete_preserve_order = pynini.closure( + pynutil.delete(" preserve_order: true") + | (pynutil.delete(" field_order: \"") + NEMO_NOT_QUOTE + pynutil.delete("\"")) +) + + +MIN_NEG_WEIGHT = -0.0001 +MIN_POS_WEIGHT = 0.0001 +INPUT_CASED = "cased" +INPUT_LOWER_CASED = "lower_cased" +MINUS = pynini.union("ऋणात्मक", "नकारात्मक").optimize() + + +def integer_to_devanagari(n: int) -> str: + return ''.join(DEVANAGARI_DIGIT[int(d)] for d in str(n)) + + +def generator_main(file_name: str, graphs: Dict[str, 'pynini.FstLike']): + """ + Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name. + + Args: + file_name: exported file name + graphs: Mapping of a rule name and Pynini WFST graph to be exported + """ + exporter = export.Exporter(file_name) + for rule, graph in graphs.items(): + exporter[rule] = graph.optimize() + exporter.close() + logging.info(f'Created {file_name}') + + +def convert_space(fst) -> 'pynini.FstLike': + """ + Converts space to nonbreaking space. + Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty" + This is making transducer significantly slower, so only use when there could be potential spaces within quotes, otherwise leave it. + + Args: + fst: input fst + + Returns output fst where breaking spaces are converted to non breaking spaces + """ + return fst @ pynini.cdrewrite(pynini.cross(NEMO_SPACE, NEMO_NON_BREAKING_SPACE), "", "", NEMO_SIGMA) + + +def string_map_cased(input_file: str, input_case: str = INPUT_LOWER_CASED): + labels = load_labels(input_file) + + if input_case == INPUT_CASED: + additional_labels = [] + for written, spoken, *weight in labels: + written_capitalized = written[0].upper() + written[1:] + additional_labels.extend( + [ + [written_capitalized, spoken.capitalize()], # first letter capitalized + [ + written_capitalized, + spoken.upper().replace(" AND ", " and "), + ], # # add pairs with the all letters capitalized + ] + ) + + spoken_no_space = spoken.replace(" ", "") + # add abbreviations without spaces (both lower and upper case), i.e. "BMW" not "B M W" + if len(spoken) == (2 * len(spoken_no_space) - 1): + logging.debug(f"This is weight {weight}") + if len(weight) == 0: + additional_labels.extend( + [[written, spoken_no_space], [written_capitalized, spoken_no_space.upper()]] + ) + else: + additional_labels.extend( + [ + [written, spoken_no_space, weight[0]], + [written_capitalized, spoken_no_space.upper(), weight[0]], + ] + ) + labels += additional_labels + + whitelist = pynini.string_map(labels).invert().optimize() + return whitelist + + +class GraphFst: + """ + Base class for all grammar fsts. + + Args: + name: name of grammar class + kind: either 'classify' or 'verbalize' + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, name: str, kind: str, deterministic: bool = True): + self.name = name + self.kind = kind + self._fst = None + self.deterministic = deterministic + + self.far_path = Path(os.path.dirname(__file__) + '/grammars/' + kind + '/' + name + '.far') + if self.far_exist(): + self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst() + + def far_exist(self) -> bool: + """ + Returns true if FAR can be loaded + """ + return self.far_path.exists() + + @property + def fst(self) -> 'pynini.FstLike': + return self._fst + + @fst.setter + def fst(self, fst): + self._fst = fst + + def add_tokens(self, fst) -> 'pynini.FstLike': + """ + Wraps class name around to given fst + + Args: + fst: input fst + + Returns: + Fst: fst + """ + return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }") + + def delete_tokens(self, fst) -> 'pynini.FstLike': + """ + Deletes class name wrap around output of given fst + + Args: + fst: input fst + + Returns: + Fst: fst + """ + res = ( + pynutil.delete(f"{self.name}") + + delete_space + + pynutil.delete("{") + + delete_space + + fst + + delete_space + + pynutil.delete("}") + ) + return res @ pynini.cdrewrite(pynini.cross(u"\u00a0", " "), "", "", NEMO_SIGMA) diff --git a/nemo_text_processing/inverse_text_normalization/ta/taggers/__init__.py b/nemo_text_processing/inverse_text_normalization/ta/taggers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nemo_text_processing/inverse_text_normalization/ta/taggers/cardinal.py b/nemo_text_processing/inverse_text_normalization/ta/taggers/cardinal.py new file mode 100644 index 000000000..f6d8173cc --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/taggers/cardinal.py @@ -0,0 +1,29 @@ +import pynini +from pynini.lib import pynutil + +from nemo_text_processing.inverse_text_normalization.ta.graph_utils import GraphFst +from nemo_text_processing.inverse_text_normalization.ta.utils import get_abs_path + + +class CardinalFst(GraphFst): + """ + Classifies spoken numbers back to digits, e.g. -> cardinal { integer: "5" } + """ + + def __init__(self): + super().__init__(name="cardinal", kind="classify") + + # The SAME data files (number -> word). For ITN we read them BACKWARDS + # (word -> number) using .invert(). + # TODO 1: add .invert() to each of the three lines below. + graph_digit = pynini.string_file(get_abs_path("data/numbers/digit.tsv")).invert() + graph_zero = pynini.string_file(get_abs_path("data/numbers/zero.tsv")).invert() + graph_teens_and_ties = pynini.string_file(get_abs_path("data/numbers/teens_and_ties.tsv")).invert() + + # TODO 2: Combine them with the union operator | + graph = graph_digit | graph_zero | graph_teens_and_ties + graph = graph.optimize() + + final_graph = pynutil.insert('integer: "') + graph + pynutil.insert('"') + final_graph = self.add_tokens(final_graph) + self.fst = final_graph.optimize() diff --git a/nemo_text_processing/inverse_text_normalization/ta/utils.py b/nemo_text_processing/inverse_text_normalization/ta/utils.py new file mode 100644 index 000000000..8e3f62c3c --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import logging +import os +import pynini + + +def get_abs_path(rel_path): + """ + Get absolute path + + Args: + rel_path: relative path to this file + + Returns absolute path + """ + abs_path = os.path.dirname(os.path.abspath(__file__)) + os.sep + rel_path + + if not os.path.exists(abs_path): + logging.warning(f'{abs_path} does not exist') + return abs_path + + +def load_labels(abs_path): + """ + loads relative path file as dictionary + + Args: + abs_path: absolute path + + Returns dictionary of mappings + """ + label_tsv = open(abs_path, encoding="utf-8") + labels = list(csv.reader(label_tsv, delimiter="\t")) + return labels + + +from pynini.lib import pynutil + + +def apply_fst(text, fst): + """Given a string input, returns the output string + produced by traversing the path with lowest weight. + If no valid path accepts input string, returns an + error. + """ + try: + print(pynini.shortestpath(text @ fst).string()) + except pynini.FstOpError: + print(f"Error: No valid output with given input: '{text}'") diff --git a/nemo_text_processing/inverse_text_normalization/ta/verbalizers/__init__.py b/nemo_text_processing/inverse_text_normalization/ta/verbalizers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nemo_text_processing/inverse_text_normalization/ta/verbalizers/cardinal.py b/nemo_text_processing/inverse_text_normalization/ta/verbalizers/cardinal.py new file mode 100644 index 000000000..18469b02b --- /dev/null +++ b/nemo_text_processing/inverse_text_normalization/ta/verbalizers/cardinal.py @@ -0,0 +1,25 @@ +import pynini +from pynini.lib import pynutil + +from nemo_text_processing.inverse_text_normalization.ta.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space + + +class CardinalFst(GraphFst): + """ + Verbalizes the digits, e.g. cardinal { integer: "5" } -> 5 + """ + + def __init__(self): + super().__init__(name="cardinal", kind="verbalize") + + # TODO 3: keep the digits between the quotes (1 or more non-quote chars). + graph = ( + pynutil.delete("integer:") + + delete_space + + pynutil.delete('"') + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete('"') + ) + + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/run_cardinal_tests.py b/run_cardinal_tests.py new file mode 100644 index 000000000..26b6497af --- /dev/null +++ b/run_cardinal_tests.py @@ -0,0 +1,47 @@ +# run_cardinal_tests.py -- simple checker for the Cardinal exercise. +# Usage (from the repo root, inside your conda env): +# python run_cardinal_tests.py +# Example: +# python run_cardinal_tests.py LANGCODE DIRECTION test_cases_cardinal.txt +import importlib +import sys + +import pynini + + +def apply_fst(fst, text): + lattice = text @ fst + if lattice.num_states() == 0: + return None # input was rejected by the grammar + out = pynini.shortestpath(lattice) + return out.string() if out.num_states() else None + + +def main(): + lang, direction, path = sys.argv[1], sys.argv[2], sys.argv[3] + base = "text_normalization" if direction == "tn" else "inverse_text_normalization" + print("Script started") + tagger = importlib.import_module(f"nemo_text_processing.{base}.{lang}.taggers.cardinal").CardinalFst().fst + verbalizer = importlib.import_module(f"nemo_text_processing.{base}.{lang}.verbalizers.cardinal").CardinalFst().fst + + passed = failed = 0 + with open(path, encoding="utf-8") as f: + for line in f: + line = line.rstrip("\n") + if not line.strip() or "~" not in line: + continue + inp, expected = [s.strip() for s in line.split("~", 1)] + tagged = apply_fst(tagger, inp) + result = apply_fst(verbalizer, tagged) if tagged is not None else None + if result == expected: + passed += 1 + else: + failed += 1 + print(f"FAIL: {inp!r} -> got {result!r}, expected {expected!r}") + + print(f"\n{passed} passed, {failed} failed.") + sys.exit(1 if failed else 0) + + +if __name__ == "__main__": + main() diff --git a/test_cases_cardinal.txt b/test_cases_cardinal.txt new file mode 100644 index 000000000..626c9bdec --- /dev/null +++ b/test_cases_cardinal.txt @@ -0,0 +1,28 @@ +சுழியம்~0 +ஒன்று~1 +இரண்டு~2 +மூன்று~3 +நான்கு~4 +ஐந்து~5 +ஆறு~6 +ஏழு~7 +எட்டு~8 +ஒன்பது~9 +பத்து~10 +பதினொன்று~11 +பன்னிரண்டு~12 +பதின்மூன்று~13 +பதினான்கு~14 +பதினைந்து~15 +பதினாறு~16 +பதினேழு~17 +பதினெட்டு~18 +பத்தொன்பது~19 +இருபது~20 +முப்பது~30 +நாற்பது~40 +ஐம்பது~50 +அறுபது~60 +எழுபது~70 +எண்பது~80 +தொண்ணூறு~90 \ No newline at end of file