logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

34 lines
907 B

4 years ago
import sys
from typing import NamedTuple, List
from pathlib import Path
4 years ago
import numpy
4 years ago
from towhee.operator import Operator
class EmbeddingConcat(Operator):
    """
    Concatenate two embeddings into a single weighted feature vector.

    Each embedding is multiplied by its weight, the two scaled embeddings are
    stacked with ``numpy.hstack`` and the result is flattened to one
    dimension.

    Args:
        w1: Multiplier applied to the first embedding. ``None`` falls back
            to 1.0.
        w2: Multiplier applied to the second embedding. ``None`` falls back
            to 1.0.
    """

    # Built once at class-creation time; creating a NamedTuple class is far
    # more expensive than instantiating one, so it must not happen per call.
    _Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

    def __init__(self, w1: float = 1.0, w2: float = 1.0) -> None:
        super().__init__()
        self.w1 = 1.0 if w1 is None else w1
        self.w2 = 1.0 if w2 is None else w2

        # Make this file's directory importable (presumably for modules
        # shipped next to the operator). Guarded so that creating several
        # instances does not keep growing sys.path with duplicates.
        operator_dir = str(Path(__file__).parent)
        if operator_dir not in sys.path:
            sys.path.append(operator_dir)

    def __call__(self, emb1: numpy.ndarray, emb2: numpy.ndarray) -> _Outputs:
        """
        Scale both embeddings by their weights and join them.

        Args:
            emb1: First embedding; multiplied by ``w1``.
            emb2: Second embedding; multiplied by ``w2``.

        Returns:
            An ``Outputs`` named tuple whose ``feature_vector`` field holds
            the flattened horizontal stack of the two scaled embeddings.
        """
        scaled = (self.w1 * emb1, self.w2 * emb2)
        return self._Outputs(numpy.hstack(scaled).flatten())