commit 3e58383481ec0d55ddc70c8892c83ddc0467ba88
parent 9b85e20d877c0b1c35b2b39ef5d86bb2a22fd702
Author: Stefan Koch <programming@stefan-koch.name>
Date: Sun, 3 Sep 2023 12:31:16 +0200
add Spanish support
Diffstat:
4 files changed, 81 insertions(+), 41 deletions(-)
diff --git a/src/nlparrot/tokenization/croatian.py b/src/nlparrot/tokenization/croatian.py
@@ -1,56 +1,20 @@
-import collections.abc
-
import stanza
-from .generic import Token, Tokenizer
+from .generic import StanzaVocabularyTokenizer
-class CroatianVocabularyTokenizer(Tokenizer):
+class CroatianVocabularyTokenizer(StanzaVocabularyTokenizer):
def __init__(self, nonstandard=False):
- self._nonstandard = nonstandard
- self._stanza = None
-
- def tokenize(self, text: str) -> collections.abc.Iterable[Token]:
- # classla (and probably also stanza) performs strip() on lines and
- # re-sets the char counter on newline. To maintain a correct char
- # counter we need to perform this on our own and memorize what we
- # removed.
- char_offset = 0
- nlp = self._get_stanza()
-
- for paragraph in text.split("\n"):
- lstripped = paragraph.lstrip()
- lstripped_count = len(paragraph) - len(lstripped)
- stripped = lstripped.rstrip()
- rstripped_count = len(lstripped) - len(stripped)
-
- doc = nlp(stripped)
+ super().__init__("hr")
- char_offset += lstripped_count
-
- for sentence in doc.sentences:
- for token in sentence.tokens:
- data = token.to_dict()[0]
-
- if data["upos"].lower() not in ["punct", "sym"]:
- yield Token(
- start=token.start_char + char_offset,
- end=token.end_char + char_offset,
- token=data["lemma"],
- original_text=data["text"],
- )
-
- char_offset += len(stripped)
- char_offset += rstripped_count
- # add one for the split \n
- char_offset += 1
+ self._nonstandard = nonstandard
def _get_stanza(self):
type_ = "nonstandard" if self._nonstandard else "default"
if self._stanza is None:
# TODO: For some reason I suddenly do not get a "start_char" anymore
- # from classla. Using original stanza for the time being.
+ # from classla. Using original stanza for the time being.
self._stanza = stanza.Pipeline(
"hr",
processors="tokenize,pos,lemma,depparse",
diff --git a/src/nlparrot/tokenization/factory.py b/src/nlparrot/tokenization/factory.py
@@ -1,8 +1,13 @@
from .croatian import CroatianVocabularyTokenizer
+from .generic import StanzaVocabularyTokenizer
from .japanese import JapaneseKanjiTokenizer, JapaneseWordTokenizer
def get_tokenizers(language_code: str):
+ if language_code.lower() == "es":
+ return {
+ "vocabulary": StanzaVocabularyTokenizer("es"),
+ }
if language_code.lower() == "hr":
return {
"vocabulary": CroatianVocabularyTokenizer(),
diff --git a/src/nlparrot/tokenization/generic.py b/src/nlparrot/tokenization/generic.py
@@ -2,6 +2,8 @@ import abc
import collections.abc
import dataclasses
+import stanza
+
@dataclasses.dataclass
class Token:
@@ -27,3 +29,54 @@ class Tokenizer(abc.ABC):
@abc.abstractmethod
def tokenize(self, text: str) -> collections.abc.Iterable[Token]:
pass
+
+
+class StanzaVocabularyTokenizer(Tokenizer):
+ def __init__(self, language: str):
+ self._language = language
+ self._stanza = None
+
+ def tokenize(self, text: str) -> collections.abc.Iterable[Token]:
+ # stanza strips each line and resets the char counter on every newline.
+ # To keep the character offsets correct we strip the lines ourselves and
+ # remember how much whitespace we removed.
+ char_offset = 0
+ nlp = self._get_stanza()
+
+ for paragraph in text.split("\n"):
+ lstripped = paragraph.lstrip()
+ lstripped_count = len(paragraph) - len(lstripped)
+ stripped = lstripped.rstrip()
+ rstripped_count = len(lstripped) - len(stripped)
+
+ doc = nlp(stripped)
+
+ char_offset += lstripped_count
+
+ for sentence in doc.sentences:
+ for token in sentence.tokens:
+ data = token.to_dict()[0]
+
+ if data["upos"].lower() not in ["punct", "sym"]:
+ yield Token(
+ start=token.start_char + char_offset,
+ end=token.end_char + char_offset,
+ token=data["lemma"],
+ original_text=data["text"],
+ )
+
+ char_offset += len(stripped)
+ char_offset += rstripped_count
+ # add one for the split \n
+ char_offset += 1
+
+ def _get_stanza(self):
+ if self._stanza is None:
+ self._stanza = stanza.Pipeline(
+ self._language,
+ processors="tokenize,pos,lemma,depparse",
+ # TODO: Add a DEV mode in which we use DownloadMethod.DOWNLOAD_RESOURCES
+ download_method=stanza.DownloadMethod.REUSE_RESOURCES,
+ )
+
+ return self._stanza
diff --git a/tests/tokenization/test_spanish.py b/tests/tokenization/test_spanish.py
@@ -0,0 +1,18 @@
+from nlparrot.tokenization.generic import StanzaVocabularyTokenizer, Token
+
+
+def test_spanish_vocabulary_tokenizer_keeps_whitespace():
+ tokenizer = StanzaVocabularyTokenizer("es")
+
+ # stanza resets the char counter on each newline and trims each line,
+ # so the expected offsets below include the newline characters and the
+ # stripped whitespace.
+ result = list(
+ tokenizer.tokenize(
+ " Bienvenidos a Wikipedia,\n\n\n la enciclopedia de contenido libre \n que todos pueden editar. ."
+ )
+ )
+
+ assert Token(start=1, end=12, token="bienvenido", original_text="Bienvenidos") in result
+ assert Token(start=34, end=46, token="enciclopedia", original_text="enciclopedia") in result
+ assert Token(start=73, end=78, token="todo", original_text="todos") in result
+ assert Token(start=79, end=85, token="poder", original_text="pueden") in result
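
A minimal usage sketch of the Spanish support added here (illustrative, not part of the patch): it goes through the new "es" branch in get_tokenizers() and assumes the Spanish stanza models are already installed locally, e.g. via stanza.download("es"), since the pipeline is created with DownloadMethod.REUSE_RESOURCES and will not fetch them on its own. The sample sentence is made up for the example.

    from nlparrot.tokenization.factory import get_tokenizers

    tokenizers = get_tokenizers("es")
    vocabulary = tokenizers["vocabulary"]

    # Punctuation and symbols are filtered out; each remaining Token carries
    # the character span in the original text, the lemma, and the surface form.
    for token in vocabulary.tokenize("Bienvenidos a la enciclopedia libre."):
        print(token)
        # e.g. Token(start=0, end=11, token="bienvenido", original_text="Bienvenidos")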