from towhee.operator import PyOperator from typing import List from langchain.text_splitter import ( RecursiveCharacterTextSplitter, MarkdownTextSplitter, PythonCodeTextSplitter, CharacterTextSplitter, NLTKTextSplitter, SpacyTextSplitter, TokenTextSplitter, ) class TextSpliter(PyOperator): '''Split data into a list.''' def __init__(self, type: str = 'RecursiveCharacter', chunk_size: int = 300, **kwargs): super().__init__() self.type = type if self.type == 'RecursiveCharacter': self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'Markdown': self.splitter = MarkdownTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'PythonCode': self.splitter = PythonCodeTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'Character': self.splitter = CharacterTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'NLTK': self.splitter = NLTKTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'Spacy': self.splitter = SpacyTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'Tiktoken': self.splitter = TokenTextSplitter(chunk_size=chunk_size, **kwargs) elif self.type == 'HuggingFace': if 'tokenizer' not in kwargs: from transformers import GPT2TokenizerFast kwargs['tokenizer'] = GPT2TokenizerFast.from_pretrained("gpt2") self.splitter = CharacterTextSplitter.from_huggingface_tokenizer(chunk_size=chunk_size, **kwargs) else: raise ValueError("Invalid type. You need choose in ['RecursiveCharacter', 'Markdown', 'PythonCode' \ 'Character', 'NLTK', 'Spacy', 'Tiktoken', 'HuggingFace'].") def __call__(self, data: str) -> List[str]: texts = self.splitter.split_text(data) return texts