|
|
|
import logging
|
|
|
|
from typing import Union, List
|
|
|
|
|
|
|
|
from elasticsearch import Elasticsearch
|
|
|
|
# import elasticsearch.helpers # type: ignore
|
|
|
|
|
|
|
|
from towhee.operator import PyOperator, SharedType # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
# Module-scoped logger. Records still propagate to the root logger's handlers,
# but callers can now filter/configure this module independently of the root.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class ESSearch(PyOperator):
    """
    Search an ElasticSearch index using a client created at construction time.

    Keyword arguments are forwarded to ``Elasticsearch(...)``. The convenience
    arguments below are translated first, because the client does not accept
    them directly.

    Args:
        host (`str`): host to connect ElasticSearch client; required when ``port`` is given.
        port (`int`): port to connect ElasticSearch client; when given, it is combined
            with ``host`` into ``hosts=['https://{host}:{port}']``.
        user (`str=None`): user name to connect ElasticSearch client, defaults to None.
        password (`str=None`): user password to connect ElasticSearch client, defaults to None.
            When either ``user`` or ``password`` is set, the pair is passed as ``basic_auth``
            (an explicitly supplied ``basic_auth`` takes precedence).
        ca_certs (`str=None`): path to CA certificate, defaults to None (forwarded unchanged).
    """

    # Sentinel meaning "caller did not pass `query`". It replaces the former
    # mutable dict default, which was a single object shared by every call.
    _DEFAULT_QUERY = object()

    def __init__(self, **connection_kwargs):
        super().__init__()

        # Validate with an explicit exception: an `assert` is stripped under `python -O`.
        if 'port' in connection_kwargs:
            if 'host' not in connection_kwargs:
                raise ValueError('Missing host in connection kwargs but given port only.')
            # Pop `host`/`port` so they are not forwarded as unknown client kwargs,
            # then combine them into the `hosts` list the client expects.
            host = connection_kwargs.pop('host')
            port = connection_kwargs.pop('port')
            connection_kwargs['hosts'] = [f'https://{host}:{port}']

        # `user`/`password` are documented inputs but are not client kwargs;
        # translate them into `basic_auth` unless the caller already set it.
        user = connection_kwargs.pop('user', None)
        password = connection_kwargs.pop('password', None)
        if user is not None or password is not None:
            connection_kwargs.setdefault('basic_auth', (user, password))

        try:
            self.client = Elasticsearch(**connection_kwargs)
            logger.info('Successfully connected to ElasticSearch client.')
        except Exception as e:
            # `e` must be consumed by a %-placeholder; without one, logging
            # reports a formatting error instead of emitting this message.
            logger.error('Failed to connect ElasticSearch client:\n%s', e)
            raise

    def __call__(self, index_name: str, query: Union[dict, None] = _DEFAULT_QUERY):
        """
        Run ``query`` against ``index_name`` via ``self.client.search``.

        Args:
            index_name (`str`): name of the index to search.
            query (`dict`): query passed as ``query=`` to the client; defaults to
                ``{'match_all': {}}`` when omitted.

        Returns:
            The value returned by ``self.client.search``, or ``''`` when ``query``
            is explicitly falsy (e.g. ``None`` or ``{}``).
        """
        if query is self._DEFAULT_QUERY:
            # Build a fresh default per call rather than sharing one dict object.
            query = {'match_all': {}}
        if not query:
            return ''
        return self.client.search(index=index_name, query=query)

    @property
    def shared_type(self):
        """Declare this operator as ``SharedType.NotShareable``."""
        return SharedType.NotShareable
|