towhee
/
embedding-concat
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
32 lines
855 B
32 lines
855 B
import sys
from pathlib import Path
from typing import List, NamedTuple, Optional

import numpy

from towhee.operator import Operator
|
|
|
|
|
|
class EmbeddingConcat(Operator):
    """
    Concatenate two embeddings into a single feature vector.

    Each embedding is multiplied by its own weight, the two weighted arrays
    are joined with ``numpy.hstack`` and the result is flattened to 1-D.

    Args:
        w1 (`Optional[float]`):
            Weight applied to the first embedding. ``None`` means 1.0.
        w2 (`Optional[float]`):
            Weight applied to the second embedding. ``None`` means 1.0.
    """

    # Output record type, created once at class-definition time so that every
    # call returns instances of the same class and the ``__call__`` return
    # annotation names the class that is actually returned.
    Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

    def __init__(self, w1: Optional[float] = None, w2: Optional[float] = None) -> None:
        super().__init__()
        # ``None`` (explicit or defaulted) falls back to an unweighted 1.0.
        self.w1 = 1.0 if w1 is None else w1
        self.w2 = 1.0 if w2 is None else w2

    def __call__(self, emb1: numpy.ndarray, emb2: numpy.ndarray) -> Outputs:
        """
        Weight and concatenate ``emb1`` and ``emb2``.

        Args:
            emb1 (`numpy.ndarray`):
                First embedding; multiplied by ``self.w1``.
            emb2 (`numpy.ndarray`):
                Second embedding; multiplied by ``self.w2``.

        Returns:
            ``Outputs`` whose ``feature_vector`` is
            ``numpy.hstack([w1 * emb1, w2 * emb2]).flatten()``. For 1-D inputs
            this is a plain end-to-end concatenation; for higher-rank inputs
            ``hstack`` joins along axis 1 before flattening, so the two arrays
            must agree on every other axis.
        """
        weighted = [self.w1 * emb1, self.w2 * emb2]
        concated_feature = numpy.hstack(weighted).flatten()
        return self.Outputs(concated_feature)
|
|
|