towhee
/
text-splitter
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
45 lines
2.0 KiB
45 lines
2.0 KiB
# Standard library
from typing import List

# Third-party: langchain provides all of the concrete splitter backends.
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    MarkdownTextSplitter,
    PythonCodeTextSplitter,
    CharacterTextSplitter,
    NLTKTextSplitter,
    SpacyTextSplitter,
    TokenTextSplitter,
)

# Local: towhee operator base class.
from towhee.operator import PyOperator


class TextSplitter(PyOperator):
    """Towhee operator that splits a text string into a list of chunks.

    Wraps the langchain text splitters behind a single ``type`` selector.

    Args:
        type: Name of the splitter backend. One of 'RecursiveCharacter',
            'Markdown', 'PythonCode', 'Character', 'NLTK', 'Spacy',
            'Tiktoken', or 'HuggingFace'. Defaults to 'RecursiveCharacter'.
        chunk_size: Maximum size of each chunk, forwarded to the splitter.
            Defaults to 300.
        **kwargs: Extra keyword arguments forwarded to the underlying
            splitter constructor (e.g. ``chunk_overlap``, ``separator``,
            or ``tokenizer`` for the 'HuggingFace' type).

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    """

    # Dispatch table mapping a ``type`` name to its langchain splitter class.
    # 'HuggingFace' is handled separately because it needs a tokenizer and
    # uses the ``from_huggingface_tokenizer`` alternate constructor.
    _SPLITTERS = {
        'RecursiveCharacter': RecursiveCharacterTextSplitter,
        'Markdown': MarkdownTextSplitter,
        'PythonCode': PythonCodeTextSplitter,
        'Character': CharacterTextSplitter,
        'NLTK': NLTKTextSplitter,
        'Spacy': SpacyTextSplitter,
        'Tiktoken': TokenTextSplitter,
    }

    def __init__(self, type: str = 'RecursiveCharacter', chunk_size: int = 300, **kwargs):
        super().__init__()
        self.type = type
        if self.type == 'HuggingFace':
            # Default to the GPT-2 tokenizer when the caller supplies none.
            # Imported lazily so ``transformers`` is only required for this type.
            if 'tokenizer' not in kwargs:
                from transformers import GPT2TokenizerFast
                kwargs['tokenizer'] = GPT2TokenizerFast.from_pretrained("gpt2")
            self.splitter = CharacterTextSplitter.from_huggingface_tokenizer(chunk_size=chunk_size, **kwargs)
        elif self.type in self._SPLITTERS:
            self.splitter = self._SPLITTERS[self.type](chunk_size=chunk_size, **kwargs)
        else:
            # NOTE: the original message concatenated "'PythonCode' 'Character'"
            # with a missing comma; fixed here along with the grammar.
            raise ValueError("Invalid type. You need to choose from ['RecursiveCharacter', 'Markdown', "
                             "'PythonCode', 'Character', 'NLTK', 'Spacy', 'Tiktoken', 'HuggingFace'].")

    def __call__(self, data: str) -> List[str]:
        """Split ``data`` into chunks using the configured splitter.

        Args:
            data: The text to split.

        Returns:
            A list of text chunks.
        """
        return self.splitter.split_text(data)