logo
Browse Source

modify

main
zhang chen 4 years ago
parent
commit
afab897944
  1. 24
      embedding_concat.py
  2. 6
      embedding_concat.yaml

24
embedding_concat.py

@ -1,7 +1,7 @@
import sys import sys
from typing import NamedTuple, List from typing import NamedTuple, List
from pathlib import Path from pathlib import Path
import numpy as np
import numpy
from towhee.operator import Operator from towhee.operator import Operator
@ -11,17 +11,23 @@ class EmbeddingConcat(Operator):
""" """
def __init__(self) -> None:
def __init__(self, w1: float, w2: float) -> None:
super().__init__() super().__init__()
if w1 is not None:
self.w1 = w1
else:
self.w1 = 1.0
if w2 is not None:
self.w2 = w2
else:
self.w2 = 1.0
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
def __call__(self, input_array_list: List[np.ndarray], input_weight_list: List[float]) -> NamedTuple('Outputs', [('feature_vector', np.ndarray)]):
def __call__(self, emb1: numpy.ndarray, emb2: numpy.ndarray) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
concatemb = [] concatemb = []
if input_weight_list is None or input_weight_list == []:
input_weight_list = [1.0] * len(input_array_list)
for emb_part, w in zip(input_array_list, input_weight_list):
concatemb.append(w * emb_part)
concated_feature = np.hstack(concatemb).flatten()
Outputs = NamedTuple('Outputs', [('feature_vector', np.ndarray)])
concatemb.append(self.w1 * emb1)
concatemb.append(self.w2 * emb2)
concated_feature = numpy.hstack(concatemb).flatten()
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(concated_feature) return Outputs(concated_feature)

6
embedding_concat.yaml

@ -5,9 +5,11 @@ labels:
others: '' others: ''
operator: towhee/embedding-concat operator: towhee/embedding-concat
init: init:
w1: float
w2: float
call: call:
input: input:
input_array_list: typing.List[numpy.ndarray]
input_weight_list: typing.List[float]
emb1: numpy.ndarray
emb2: numpy.ndarray
output: output:
feature_vector: numpy.ndarray feature_vector: numpy.ndarray

Loading…
Cancel
Save