From e14750a1737ee07cd1cfa02fea25d789b748435d Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Tue, 30 May 2023 18:39:52 +0800 Subject: [PATCH] Add spliter type param Signed-off-by: shiyu22 --- README.md | 8 ++++++-- requirements.txt | 3 +++ spliter.py | 42 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 7092586..03c31aa 100644 --- a/README.md +++ b/README.md @@ -46,9 +46,13 @@ Create the operator via the following factory method **Parameters:** -​ ***chunk_size***: int +​ ***type***: str -​ The size of each chunk, defaults to 300. +​ The type of spliter, defaults to 'RecursiveCharacter'. You can set this parameter to one of ['[RecursiveCharacter](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/recursive_text_splitter.html)', '[Markdown](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/markdown.html)', '[PythonCode](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/python.html)', '[Character](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/character_text_splitter.html#)', '[NLTK](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/nltk.html)', '[Spacy](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/spacy.html)', '[Tiktoken](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/tiktoken_splitter.html)', '[HuggingFace](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/huggingface_length_function.html)']. + +​ ***chunk_size***: int + +​ The maximum size of each chunk, defaults to 300.
diff --git a/requirements.txt b/requirements.txt index ded1321..d78c97b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,4 @@ langchain>=0.0.151 +transformers +tiktoken +spacy diff --git a/spliter.py b/spliter.py index a33a9b6..927d93c 100644 --- a/spliter.py +++ b/spliter.py @@ -1,15 +1,45 @@ from towhee.operator import PyOperator from typing import List -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.text_splitter import ( + RecursiveCharacterTextSplitter, + MarkdownTextSplitter, + PythonCodeTextSplitter, + CharacterTextSplitter, + NLTKTextSplitter, + SpacyTextSplitter, + TokenTextSplitter, + ) class TextSpliter(PyOperator): '''Split data into a list.''' - def __init__(self, chunk_size: int = 300): + def __init__(self, type: str = 'RecursiveCharacter', chunk_size: int = 300, **kwargs): super().__init__() - self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size) + self.type = type + if self.type == 'RecursiveCharacter': + self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'Markdown': + self.splitter = MarkdownTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'PythonCode': + self.splitter = PythonCodeTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'Character': + self.splitter = CharacterTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'NLTK': + self.splitter = NLTKTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'Spacy': + self.splitter = SpacyTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'Tiktoken': + self.splitter = TokenTextSplitter(chunk_size=chunk_size, **kwargs) + elif self.type == 'HuggingFace': + if 'tokenizer' not in kwargs: + from transformers import GPT2TokenizerFast + kwargs['tokenizer'] = GPT2TokenizerFast.from_pretrained("gpt2") + self.splitter = CharacterTextSplitter.from_huggingface_tokenizer(chunk_size=chunk_size, **kwargs) + else: + raise 
ValueError("Invalid type. You need to choose one of ['RecursiveCharacter', 'Markdown', 'PythonCode', \ + 'Character', 'NLTK', 'Spacy', 'Tiktoken', 'HuggingFace'].") + def __call__(self, data: str) -> List[str]: - texts = self.splitter.create_documents([data]) - docs = self.splitter.split_documents(texts) - return [str(doc.page_content) for doc in docs] + texts = self.splitter.split_text(data) + return texts