diff --git a/README.md b/README.md index 567d760..a83a693 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,9 @@ This operator can compare two ordered sequences, then detect the range which fea ```python +import towhee +import numpy as np + # simulate a video feature by 10 frames of 512d vectors. videos_embeddings = np.random.randn(10,512) videos_embeddings = videos_embeddings / np.linalg.norm(videos_embeddings,axis=1).reshape(10,-1) diff --git a/tn.py b/tn.py index b0aacf0..c4785eb 100644 --- a/tn.py +++ b/tn.py @@ -16,7 +16,7 @@ import numpy as np import networkx as nx from networkx.algorithms.dag import dag_longest_path from typing import List -from towhee.operator.base import NNOperator, OperatorFlag +from towhee.operator.base import Operator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register @@ -184,14 +184,14 @@ def tn(sims: np.ndarray, @register(output_schema=['vec']) -class TemporalNetwork(NNOperator): +class TemporalNetwork(Operator): """ TemporalNetwork """ def __init__(self, tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3): - + super().__init__() self._tn_max_step = tn_max_step self._tn_top_k = tn_top_k self._max_path = max_path