logo
Browse Source

modify

main
zhang chen 4 years ago
parent
commit
afab897944
  1. 24
      embedding_concat.py
  2. 6
      embedding_concat.yaml

24
embedding_concat.py

@ -1,7 +1,7 @@
import sys
from typing import NamedTuple, List
from pathlib import Path
import numpy as np
import numpy
from towhee.operator import Operator
@ -11,17 +11,23 @@ class EmbeddingConcat(Operator):
"""
def __init__(self) -> None:
def __init__(self, w1: float, w2: float) -> None:
super().__init__()
if w1 is not None:
self.w1 = w1
else:
self.w1 = 1.0
if w2 is not None:
self.w2 = w2
else:
self.w2 = 1.0
sys.path.append(str(Path(__file__).parent))
def __call__(self, input_array_list: List[np.ndarray], input_weight_list: List[float]) -> NamedTuple('Outputs', [('feature_vector', np.ndarray)]):
def __call__(self, emb1: numpy.ndarray, emb2: numpy.ndarray) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
concatemb = []
if input_weight_list is None or input_weight_list == []:
input_weight_list = [1.0] * len(input_array_list)
for emb_part, w in zip(input_array_list, input_weight_list):
concatemb.append(w * emb_part)
concated_feature = np.hstack(concatemb).flatten()
Outputs = NamedTuple('Outputs', [('feature_vector', np.ndarray)])
concatemb.append(self.w1 * emb1)
concatemb.append(self.w2 * emb2)
concated_feature = numpy.hstack(concatemb).flatten()
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(concated_feature)

6
embedding_concat.yaml

@ -5,9 +5,11 @@ labels:
others: ''
operator: towhee/embedding-concat
init:
w1: float
w2: float
call:
input:
input_array_list: typing.List[numpy.ndarray]
input_weight_list: typing.List[float]
emb1: numpy.ndarray
emb2: numpy.ndarray
output:
feature_vector: numpy.ndarray

Loading…
Cancel
Save