logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

33 lines
907 B

import sys
from pathlib import Path
from typing import List, NamedTuple, Optional

import numpy
from towhee.operator import Operator
class EmbeddingConcat(Operator):
    """
    Concatenate two embeddings into one feature vector, scaling each by a weight.

    Each input is multiplied by its weight. The two weighted arrays are joined
    with ``numpy.hstack``, and the joined array is flattened to 1-D.

    Args:
        w1: Weight applied to the first embedding. ``None`` (the default) means 1.0.
        w2: Weight applied to the second embedding. ``None`` (the default) means 1.0.
    """

    # Output record type returned by ``__call__``. It is defined once here
    # rather than rebuilt on every call.
    Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

    def __init__(
        self,
        w1: Optional[float] = None,
        w2: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.w1 = 1.0 if w1 is None else w1
        self.w2 = 1.0 if w2 is None else w2
        # NOTE(review): presumably lets modules that live next to this file be
        # imported; confirm it is still needed. The membership check keeps
        # repeated instantiation from appending the same directory again and
        # again.
        operator_dir = str(Path(__file__).parent)
        if operator_dir not in sys.path:
            sys.path.append(operator_dir)

    def __call__(self, emb1: numpy.ndarray, emb2: numpy.ndarray) -> Outputs:
        """
        Return the weighted concatenation of ``emb1`` and ``emb2``.

        Args:
            emb1: First embedding. Array-likes are converted with ``numpy.asarray``,
                so a plain list is scaled element-wise instead of repeated.
            emb2: Second embedding, converted the same way.

        Returns:
            ``Outputs`` whose ``feature_vector`` is
            ``numpy.hstack((w1 * emb1, w2 * emb2)).flatten()``.
        """
        weighted = (self.w1 * numpy.asarray(emb1), self.w2 * numpy.asarray(emb2))
        return self.Outputs(numpy.hstack(weighted).flatten())