osschat-index
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
84 lines
2.8 KiB
84 lines
2.8 KiB
import logging
|
|
from typing import Union, List
|
|
|
|
from elasticsearch import Elasticsearch
|
|
import elasticsearch.helpers # type: ignore
|
|
|
|
from towhee.operator import PyOperator, SharedType # type: ignore
|
|
|
|
|
|
# Module-level logger; getLogger() is called without a name, so this is the root logger.
logger = logging.getLogger()
|
|
|
|
|
|
class ESIndex(PyOperator):
    """
    Use bulk to insert docs into ElasticSearch index, using auto id generated.

    Args:
        host (`str`): host to connect ElasticSearch client
        port (`int`): port to connect ElasticSearch client
        user (`str=None`): user name to connect ElasticSearch client, defaults to None
        password (`str=None`): user password to connect ElasticSearch client, defaults to None
        ca_certs (`str=None`): path to CA certificate, defaults to None

    All arguments are passed as keywords. `host` and `port` are combined into
    a single `https://host:port` entry under `hosts`; `user` and `password`
    are combined into `basic_auth`. Any other keyword argument (for example
    `ca_certs`, or a ready-made `hosts` / `basic_auth`) is forwarded to
    `Elasticsearch` unchanged.
    """

    def __init__(self, **connection_kwargs):
        """
        Build the ElasticSearch client from the given connection kwargs.

        Raises:
            AssertionError: if `port` is given without `host`.
            Exception: any error raised while constructing the client is
                logged and re-raised unchanged.
        """
        super().__init__()
        try:
            if 'port' in connection_kwargs:
                assert 'host' in connection_kwargs, 'Missing host in connection kwargs but given port only.'
                # Pop both keys so they are not forwarded to the client
                # next to the `hosts` entry that is built from them.
                host = connection_kwargs.pop('host')
                port = connection_kwargs.pop('port')
                connection_kwargs['hosts'] = [f'https://{host}:{port}']

            # Translate the documented `user` / `password` pair into the
            # client's `basic_auth` argument; an explicitly supplied
            # `basic_auth` takes precedence.
            user = connection_kwargs.pop('user', None)
            password = connection_kwargs.pop('password', None)
            if user is not None and password is not None:
                connection_kwargs.setdefault('basic_auth', (user, password))

            self.client = Elasticsearch(**connection_kwargs)
            logger.info('Successfully connected to ElasticSearch client.')
        except Exception as e:
            # The '%s' placeholder is required; without it the logging module
            # fails to format the record and the error text is lost.
            logger.error('Failed to connect ElasticSearch client:\n%s', e)
            raise

    def __call__(self, index_name: str, doc: Union[dict, List[dict]]):
        """
        Insert one document or a list of documents into `index_name`.

        Args:
            index_name (`str`): target index; created first if it does not exist.
            doc (`Union[dict, List[dict]]`): a single document or a list of documents.

        Returns:
            The value returned by `elasticsearch.helpers.bulk`.

        Raises:
            AssertionError: if any document is not a dict.
        """
        # If the index does not exist, create it with the standard analyzer
        # and the custom BM25 similarity defined in `create_index`.
        if not self.is_index_exist(index_name):
            self.create_index(index_name)

        docs = [doc] if isinstance(doc, dict) else doc

        for x in docs:
            assert isinstance(x, dict)

        # No '_id' is set, so ElasticSearch generates the document ids.
        actions = [
            {
                '_op_type': 'index',
                '_index': index_name,
                '_source': d,
            }
            for d in docs
        ]
        res = elasticsearch.helpers.bulk(self.client, actions, refresh=True)
        return res

    def is_index_exist(self, index_name: str):
        """Return the result of the client's existence check for `index_name`."""
        return self.client.indices.exists(index=index_name)

    def create_index(self, index_name: str):
        """
        Create `index_name` with a default `standard` analyzer and a `sentence`
        text field scored with a custom BM25 similarity (k1=2.0, b=0.75).
        """
        settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": 2.0,
                    "b": 0.75,
                }
            },
        }
        mappings = {
            "properties": {
                "sentence": {
                    "type": "text",
                    "similarity": "custom_bm25",  # Use the custom BM25 similarity
                }
            }
        }

        # Create the index with the specified settings and mappings
        self.client.indices.create(index=index_name, mappings=mappings, settings=settings)
|