# Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import groupby
from typing import List

from towhee import register
from towhee.operator import PyOperator


@register(output_schema=['vec'])
class SelectVideo(PyOperator):
    """Pick the top-k video urls after aggregating the scores of each url.

    Scores that belong to the same url are collapsed into a single value with
    ``reduce_function``; urls are then ranked by that value and the first
    ``top_k`` of them are returned.

    Args:
        top_k: Maximum number of video urls to return.
        reduce_function: How scores of the same url are combined. One of
            ``'sum'``, ``'mean'``, ``'max'`` or ``'min'``.
        reverse: When True (default) urls are ranked from the highest
            aggregated score to the lowest; when False, lowest first.

    Example:
        >>> op = SelectVideo(top_k=2, reverse=True, reduce_function='min')
        >>> op(['a', 'a', 'c', 'a', 'b', 'b', 'c', 'c'],
        ...    [2, 1, 9, 5, 2, 1, 2, 2])
        ['c', 'a']
    """

    # Maps each supported ``reduce_function`` name to the callable that
    # collapses a non-empty list of scores into one value.
    _REDUCERS = {
        'sum': sum,
        'mean': lambda values: sum(values) / len(values),
        'max': max,
        'min': min,
    }

    def __init__(self, top_k: int, reduce_function: str = 'sum', reverse: bool = True):
        # Let the operator base class run its own initialisation.
        super().__init__()
        self.top_k = top_k
        self.reduce_function = reduce_function
        self.reverse = reverse

    def _select(self, video_urls, scores):
        """Aggregate ``scores`` per url and return the best ``top_k`` urls.

        Args:
            video_urls: Flat sequence of urls, one per score.
            scores: Flat sequence of scores aligned with ``video_urls``.

        Returns:
            Up to ``top_k`` distinct urls ordered by aggregated score. Urls
            with equal aggregated scores keep ascending url order.

        Raises:
            ValueError: If the two sequences differ in length or
                ``reduce_function`` is not a supported name.
        """
        # Explicit raise instead of ``assert``: asserts vanish under
        # ``python -O`` and ``zip`` would then silently drop the surplus.
        if len(video_urls) != len(scores):
            raise ValueError('len(video_urls) must equal len(scores)')

        try:
            reduce_scores = self._REDUCERS[self.reduce_function]
        except KeyError:
            raise ValueError('unknown reduce_function') from None

        # ``groupby`` only merges adjacent items, so order by url first. The
        # sort is stable, so scores of one url keep their input order.
        pairs = sorted(zip(video_urls, scores), key=lambda pair: pair[0])
        reduced = {
            url: reduce_scores([score for _, score in group])
            for url, group in groupby(pairs, key=lambda pair: pair[0])
        }

        # ``reduced`` is in ascending url order and ``sorted`` is stable (also
        # with ``reverse=True``), which fixes the tie-breaking order.
        ranked = sorted(reduced, key=reduced.get, reverse=self.reverse)
        return ranked[:min(self.top_k, len(ranked))]

    def __call__(self, video_urls: List[str], scores: List[float]) -> List[str]:
        """Return the top-k video urls for the given urls and scores.

        Args:
            video_urls: Either a flat list of urls or a list of url lists;
                nested input is flattened before selection.
            scores: Scores aligned with ``video_urls``. When ``video_urls``
                is nested, ``scores`` must be nested the same way.

        Returns:
            The selected video urls; an empty list when there are no urls.

        Raises:
            TypeError: If the urls are neither strings nor lists of strings.
            ValueError: If urls and scores differ in length or
                ``reduce_function`` is not a supported name.
        """
        # ``len`` rather than truthiness so array-like inputs are accepted.
        if len(video_urls) > 0 and isinstance(video_urls[0], list):
            # Linear-time flattening (``sum(lists, [])`` is quadratic).
            video_urls = [url for group in video_urls for url in group]
            scores = [score for group in scores for score in group]

        if len(video_urls) > 0 and not isinstance(video_urls[0], str):
            raise TypeError(
                'video_urls must contain str items or lists of str, got '
                f'{type(video_urls[0]).__name__}'
            )

        # Empty input falls through and yields an empty selection.
        return self._select(video_urls, scores)