towhee
/
text-splitter
copied
3 changed files with 45 additions and 8 deletions
@ -1 +1,4 @@ |
|||
langchain>=0.0.151 |
|||
transformers |
|||
tiktoken |
|||
spacy |
|||
|
@ -1,15 +1,45 @@ |
|||
from towhee.operator import PyOperator |
|||
from typing import List |
|||
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|||
from langchain.text_splitter import ( |
|||
RecursiveCharacterTextSplitter, |
|||
MarkdownTextSplitter, |
|||
PythonCodeTextSplitter, |
|||
CharacterTextSplitter, |
|||
NLTKTextSplitter, |
|||
SpacyTextSplitter, |
|||
TokenTextSplitter, |
|||
) |
|||
|
|||
|
|||
class TextSpliter(PyOperator):
    '''Split a text string into a list of chunks using a LangChain text splitter.

    The splitter implementation is selected by name at construction time and
    reused for every call.

    Args:
        type: Name of the splitter to use. One of 'RecursiveCharacter',
            'Markdown', 'PythonCode', 'Character', 'NLTK', 'Spacy',
            'Tiktoken' or 'HuggingFace'. (The name shadows the builtin, but it
            is part of the public signature and is kept for compatibility.)
        chunk_size: Passed to the selected splitter as ``chunk_size``.
        **kwargs: Extra keyword arguments forwarded unchanged to the selected
            splitter. For 'HuggingFace', a ``tokenizer`` entry may be supplied;
            when absent, the pretrained "gpt2" ``GPT2TokenizerFast`` is loaded.

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    '''

    # Splitter types that are all built the same way:
    # ``cls(chunk_size=chunk_size, **kwargs)``.
    # 'HuggingFace' is handled separately because it goes through an
    # alternate constructor and may need a default tokenizer.
    _SPLITTERS = {
        'RecursiveCharacter': RecursiveCharacterTextSplitter,
        'Markdown': MarkdownTextSplitter,
        'PythonCode': PythonCodeTextSplitter,
        'Character': CharacterTextSplitter,
        'NLTK': NLTKTextSplitter,
        'Spacy': SpacyTextSplitter,
        'Tiktoken': TokenTextSplitter,
    }

    def __init__(self, type: str = 'RecursiveCharacter', chunk_size: int = 300, **kwargs):
        super().__init__()
        self.type = type

        # Guard the lookup so an unhashable ``type`` still ends up in the
        # ValueError branch below instead of raising TypeError from the dict.
        splitter_cls = self._SPLITTERS.get(type) if isinstance(type, str) else None

        if splitter_cls is not None:
            self.splitter = splitter_cls(chunk_size=chunk_size, **kwargs)
        elif type == 'HuggingFace':
            if 'tokenizer' not in kwargs:
                # Imported lazily so ``transformers`` is only required when the
                # default tokenizer is actually needed.
                from transformers import GPT2TokenizerFast
                kwargs['tokenizer'] = GPT2TokenizerFast.from_pretrained("gpt2")
            self.splitter = CharacterTextSplitter.from_huggingface_tokenizer(chunk_size=chunk_size, **kwargs)
        else:
            valid_types = [*self._SPLITTERS, 'HuggingFace']
            raise ValueError(f'Invalid type {type!r}. Choose one of {valid_types}.')

    def __call__(self, data: str) -> List[str]:
        '''Split ``data`` with the configured splitter and return the chunks.'''
        return self.splitter.split_text(data)
|||
|
Loading…
Reference in new issue