diff --git a/embedding_concat.py b/embedding_concat.py index 84d7631..257673b 100644 --- a/embedding_concat.py +++ b/embedding_concat.py @@ -1,7 +1,7 @@ import sys from typing import NamedTuple, List from pathlib import Path -import numpy as np +import numpy from towhee.operator import Operator @@ -11,17 +11,23 @@ class EmbeddingConcat(Operator): """ - def __init__(self) -> None: + def __init__(self, w1: float, w2: float) -> None: super().__init__() + if w1 is not None: + self.w1 = w1 + else: + self.w1 = 1.0 + if w2 is not None: + self.w2 = w2 + else: + self.w2 = 1.0 sys.path.append(str(Path(__file__).parent)) - def __call__(self, input_array_list: List[np.ndarray], input_weight_list: List[float]) -> NamedTuple('Outputs', [('feature_vector', np.ndarray)]): + def __call__(self, emb1: numpy.ndarray, emb2: numpy.ndarray) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): concatemb = [] - if input_weight_list is None or input_weight_list == []: - input_weight_list = [1.0] * len(input_array_list) - for emb_part, w in zip(input_array_list, input_weight_list): - concatemb.append(w * emb_part) - concated_feature = np.hstack(concatemb).flatten() - Outputs = NamedTuple('Outputs', [('feature_vector', np.ndarray)]) + concatemb.append(self.w1 * emb1) + concatemb.append(self.w2 * emb2) + concated_feature = numpy.hstack(concatemb).flatten() + Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(concated_feature) diff --git a/embedding_concat.yaml b/embedding_concat.yaml index d8f8f8d..9100591 100644 --- a/embedding_concat.yaml +++ b/embedding_concat.yaml @@ -5,9 +5,11 @@ labels: others: '' operator: towhee/embedding-concat init: + w1: float + w2: float call: input: - input_array_list: typing.List[numpy.ndarray] - input_weight_list: typing.List[float] + emb1: numpy.ndarray + emb2: numpy.ndarray output: feature_vector: numpy.ndarray