towhee
/
text-splitter
copied
3 changed files with 45 additions and 8 deletions
@ -1 +1,4 @@ |
|||||
langchain>=0.0.151 |
langchain>=0.0.151 |
||||
|
transformers |
||||
|
tiktoken |
||||
|
spacy |
||||
|
@ -1,15 +1,45 @@ |
|||||
from towhee.operator import PyOperator |
from towhee.operator import PyOperator |
||||
from typing import List |
from typing import List |
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
||||
|
from langchain.text_splitter import ( |
||||
|
RecursiveCharacterTextSplitter, |
||||
|
MarkdownTextSplitter, |
||||
|
PythonCodeTextSplitter, |
||||
|
CharacterTextSplitter, |
||||
|
NLTKTextSplitter, |
||||
|
SpacyTextSplitter, |
||||
|
TokenTextSplitter, |
||||
|
) |
||||
|
|
||||
|
|
||||
class TextSpliter(PyOperator):
    """Split text into a list of chunks using a configurable langchain splitter.

    Args:
        type (str): Which splitter to use; one of 'RecursiveCharacter',
            'Markdown', 'PythonCode', 'Character', 'NLTK', 'Spacy',
            'Tiktoken', 'HuggingFace'. Defaults to 'RecursiveCharacter'.
        chunk_size (int): Maximum chunk size forwarded to the splitter.
        **kwargs: Extra keyword arguments forwarded to the underlying
            langchain splitter constructor.

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    """

    # Maps the public `type` name to the corresponding langchain splitter
    # class; 'HuggingFace' is handled separately because it is built via a
    # classmethod that needs a transformers tokenizer.
    _SPLITTERS = {
        'RecursiveCharacter': RecursiveCharacterTextSplitter,
        'Markdown': MarkdownTextSplitter,
        'PythonCode': PythonCodeTextSplitter,
        'Character': CharacterTextSplitter,
        'NLTK': NLTKTextSplitter,
        'Spacy': SpacyTextSplitter,
        'Tiktoken': TokenTextSplitter,
    }

    def __init__(self, type: str = 'RecursiveCharacter', chunk_size: int = 300, **kwargs):
        super().__init__()
        self.type = type
        if self.type == 'HuggingFace':
            # HuggingFace splitting is CharacterTextSplitter measured by a
            # transformers tokenizer; default to GPT-2 when none is supplied.
            # Import is local so transformers is only required for this type.
            if 'tokenizer' not in kwargs:
                from transformers import GPT2TokenizerFast
                kwargs['tokenizer'] = GPT2TokenizerFast.from_pretrained("gpt2")
            self.splitter = CharacterTextSplitter.from_huggingface_tokenizer(chunk_size=chunk_size, **kwargs)
        elif self.type in self._SPLITTERS:
            self.splitter = self._SPLITTERS[self.type](chunk_size=chunk_size, **kwargs)
        else:
            # BUG FIX: the original message was missing the comma between
            # 'PythonCode' and 'Character' and used a backslash continuation
            # inside the string literal, which fused the two items and leaked
            # indentation whitespace into the error text.
            raise ValueError("Invalid type. You need choose in ['RecursiveCharacter', 'Markdown', 'PythonCode', "
                             "'Character', 'NLTK', 'Spacy', 'Tiktoken', 'HuggingFace'].")

    def __call__(self, data: str) -> List[str]:
        """Split ``data`` with the configured splitter and return the chunks."""
        texts = self.splitter.split_text(data)
        return texts
|
Loading…
Reference in new issue