|
@ -16,7 +16,7 @@ import numpy as np |
|
|
import networkx as nx |
|
|
import networkx as nx |
|
|
from networkx.algorithms.dag import dag_longest_path |
|
|
from networkx.algorithms.dag import dag_longest_path |
|
|
from typing import List |
|
|
from typing import List |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
|
|
|
|
|
from towhee.operator.base import Operator, OperatorFlag |
|
|
from towhee.types.arg import arg, to_image_color |
|
|
from towhee.types.arg import arg, to_image_color |
|
|
from towhee import register |
|
|
from towhee import register |
|
|
|
|
|
|
|
@ -184,14 +184,14 @@ def tn(sims: np.ndarray, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
@register(output_schema=['vec']) |
|
|
class TemporalNetwork(NNOperator): |
|
|
|
|
|
|
|
|
class TemporalNetwork(Operator): |
|
|
""" |
|
|
""" |
|
|
TemporalNetwork |
|
|
TemporalNetwork |
|
|
""" |
|
|
""" |
|
|
def __init__(self, |
|
|
def __init__(self, |
|
|
tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, |
|
|
tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, |
|
|
min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3): |
|
|
min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3): |
|
|
|
|
|
|
|
|
|
|
|
super().__init__() |
|
|
self._tn_max_step = tn_max_step |
|
|
self._tn_max_step = tn_max_step |
|
|
self._tn_top_k = tn_top_k |
|
|
self._tn_top_k = tn_top_k |
|
|
self._max_path = max_path |
|
|
self._max_path = max_path |
|
|