osschat-index
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
108 lines
3.5 KiB
108 lines
3.5 KiB
2 years ago
|
import logging
|
||
|
from typing import Union, List
|
||
|
|
||
|
from elasticsearch import Elasticsearch
|
||
|
import elasticsearch.helpers # type: ignore
|
||
|
|
||
|
from towhee.operator import PyOperator, SharedType # type: ignore
|
||
|
|
||
|
|
||
|
# Logger shared by everything in this module. No name is passed to
# getLogger(), so this is the root logger rather than a per-module one.
logger = logging.getLogger()
|
||
|
|
||
|
|
||
|
class ESIndex(PyOperator):
    """
    Use bulk to insert docs into ElasticSearch index, using auto id generated.

    The target index is created on first use (with a stop-words analyzer) when
    it does not exist yet.

    Args:
        host (`str`): host to connect ElasticSearch client
        port (`int`): port to connect ElasticSearch client
        user (`str=None`): user name to connect ElasticSearch client, defaults to None
        password (`str=None`): user password to connect ElasticSearch client, defaults to None
        ca_certs (`str=None`): path to CA certificate, defaults to None
    """

    def __init__(
        self,
        host: str,
        port: int,
        user: Union[str, None] = None,
        password: Union[str, None] = None,
        ca_certs: Union[str, None] = None,
    ):
        super().__init__()
        # Only send credentials when a user name was supplied: a (None, None)
        # tuple is not a valid `basic_auth` value, so the documented defaults
        # must translate to "no authentication" instead.
        basic_auth = (user, password) if user is not None else None
        try:
            self.client = Elasticsearch(
                f'https://{host}:{port}',
                ca_certs=ca_certs,
                basic_auth=basic_auth)
            logger.info('Successfully connected to ElasticSearch client.')
        except Exception as e:
            # The exception needs a %s placeholder; without one the logging
            # module fails to format the record and the cause is lost.
            logger.error('Failed to connect ElasticSearch client:\n%s', e)
            raise

    def __call__(self, index_name: str, doc: Union[dict, List[dict]]):
        """
        Index one document or a list of documents into `index_name`.

        Args:
            index_name (`str`): name of the target index; created if missing
            doc (`Union[dict, List[dict]]`): a single document or a list of documents

        Returns:
            The result of `elasticsearch.helpers.bulk` for the submitted actions.

        Raises:
            AssertionError: if any document is not a dict.
        """
        # if index not exist, create with stop words analyzer to strengthen the search accuracy
        if not self.is_index_exist(index_name):
            logger.info(f'index{index_name} not exists, will create the index with stopwords analyzer')
            self.create_index_with_stopwords(index_name)

        # Normalise the input so the rest of the method only deals with a list.
        docs = [doc] if isinstance(doc, dict) else doc

        for x in docs:
            assert isinstance(x, dict), 'each document must be a dict'

        # No '_id' is set, so ElasticSearch generates the document ids.
        actions = [
            {
                '_op_type': 'index',
                '_index': index_name,
                '_source': x,
            }
            for x in docs
        ]
        # refresh=True is forwarded to the bulk request.
        res = elasticsearch.helpers.bulk(self.client, actions, refresh=True)
        return res

    def is_index_exist(self, index_name: str):
        """Return the client's answer to whether the index `index_name` exists."""
        return self.client.indices.exists(index=index_name)

    def create_index_with_stopwords(self, index_name: str):
        """
        Create the index `index_name` with the `my_stop_analyzer` stop-words analyzer.

        The `paragraph` and `path` fields are analyzed text with a `keyword`
        sub-field; `milvus_id` is stored as a long.
        """
        mappings = {
            "properties": {
                "milvus_id": {
                    "type": "long"
                },
                "paragraph": {
                    "type": "text",
                    "analyzer": "my_stop_analyzer",
                    "fields": {
                        "keyword": {
                            "type": "keyword",
                            "ignore_above": 256
                        }
                    }
                },
                "path": {
                    "type": "text",
                    "analyzer": "my_stop_analyzer",
                    "fields": {
                        "keyword": {
                            "type": "keyword",
                            "ignore_above": 256
                        }
                    }
                }
            }
        }
        settings = {
            "analysis": {
                "analyzer": {
                    "my_stop_analyzer": {
                        "type": "stop",
                        # NOTE(review): presumably resolved by the ElasticSearch
                        # server, not this process - confirm the file is deployed there.
                        "stopwords_path": "stopwords/stopwords-en.txt"
                    }
                }
            },
            "number_of_shards": 3,
            "number_of_replicas": 0
        }
        self.client.indices.create(index=index_name, mappings=mappings, settings=settings)
        logger.info(f"created index{index_name}")
|
||
|
|