diff --git a/README.md b/README.md
index 0c2e081..bcac577 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,88 @@
-# temporal-network
+# Video Alignment with Temporal Network
+
+*author: David Wang*
+
+## Description
+
+This operator compares two ordered feature sequences and detects the ranges over which features from each sequence are similar and appear in the same temporal order.
+
+
+
+## Code Example
+
+
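+A minimal sketch of how this operator might be used (this assumes you run from the repository root so that `tn.py` is importable; the random arrays stand in for real per-frame feature vectors):
+
+```python
+import numpy as np
+
+from tn import TemporalNetwork
+
+# Stand-in features: 100 source frames and 120 destination frames, 512-D each,
+# L2-normalized so that dot products behave like cosine similarities.
+src_video_vec = np.random.rand(100, 512).astype(np.float32)
+src_video_vec /= np.linalg.norm(src_video_vec, axis=1, keepdims=True)
+dst_video_vec = np.random.rand(120, 512).astype(np.float32)
+dst_video_vec /= np.linalg.norm(dst_video_vec, axis=1, keepdims=True)
+
+op = TemporalNetwork(tn_max_step=10, tn_top_k=5, max_path=10,
+                     min_sim=0.2, min_length=5, max_iou=0.3)
+aligned_ranges, scores = op(src_video_vec, dst_video_vec)
+```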
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***temporal_network(tn_max_step, tn_top_k, max_path, min_sim, min_length, max_iou)***
+
+
+**Parameters:**
+
+ ***tn_max_step:*** *int*
+
+ Max step range in the temporal network (TN).
+
+ ***tn_top_k:*** *int*
+
+ Top-k frame similarity selection in TN.
+
+ ***max_path:*** *int*
+
+ Max number of iterations for multiple-segment detection.
+
+ ***min_sim:*** *float*
+
+ Min average similarity score for each aligned segment.
+
+ ***min_length:*** *int*
+
+ Min segment length.
+
+ ***max_iou:*** *float*
+
+ Max IoU used to filter overlapping segments (bounding boxes).
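+
+For example, a stricter configuration (hypothetical values) that keeps only longer, more confident segments might be:
+
+```python
+op = temporal_network(tn_max_step=5, tn_top_k=3, max_path=10,
+                      min_sim=0.4, min_length=10, max_iou=0.2)
+```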
+
+
+
+
+
+## Interface
+
+A Temporal Network operator takes two numpy.ndarray inputs of shape (N, D), where N is the number of features and D is the feature dimension, and returns the duplicated (temporally aligned) ranges together with their similarity scores.
+
+
+**Parameters:**
+
+ ***src_video_vec:*** *numpy.ndarray*
+
+ Source video feature vectors.
+
+ ***dst_video_vec:*** *numpy.ndarray*
+
+ Destination video feature vectors.
+
+
+**Returns:** *List[List[int]], List[float]*
+
+ The aligned ranges, each given as [src_begin, dst_begin, src_end, dst_end], and the corresponding average similarity score for each range.
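+
+For illustration, a result like the following (hypothetical values) means that source frames 3 through 48 align with destination frames 10 through 55, with an average similarity of about 0.82:
+
+```python
+ranges, scores = op(src_video_vec, dst_video_vec)
+# ranges == [[3, 10, 48, 55]], scores == [0.82]  (hypothetical values)
+```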
+
+
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..5f4c3f6
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .tn import TemporalNetwork
+
+
+def temporal_network(tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3):
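+    """
+    Factory method for the TemporalNetwork operator; the parameters are
+    documented in tn.py and in the README.
+    """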
+ return TemporalNetwork(tn_max_step, tn_top_k, max_path, min_sim, min_length, max_iou)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..4bb49b5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+towhee
+networkx
+numpy
+opencv-python
diff --git a/tn.py b/tn.py
new file mode 100644
index 0000000..a5e428f
--- /dev/null
+++ b/tn.py
@@ -0,0 +1,214 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import cv2
+import numpy as np
+import networkx as nx
+from networkx.algorithms.dag import dag_longest_path
+from typing import List, Tuple
+from towhee.operator.base import NNOperator
+from towhee import register
+
+def iou(bbox: np.ndarray, gt: np.ndarray) -> np.ndarray:
+ """
+ IoU calculation for next-step filtering
+ Parameters
+ ----------
+ bbox: bounding box array (n, 4)
+ gt: bounding box array (m, 4)
+ Returns
+ -------
+ IoU results with dimension (n, m)
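+    Example (illustrative):
+        >>> iou(np.array([[0, 0, 10, 10]]), np.array([[0, 0, 10, 10]]))
+        array([[1.]])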
+ """
+ if len(bbox) == 0 or len(gt) == 0:
+ return np.array(0)
+ lt = np.maximum(bbox[:, None, :2], gt[:, :2]) # left_top (x, y)
+ rb = np.minimum(bbox[:, None, 2:], gt[:, 2:]) # right_bottom (x, y)
+ wh = np.maximum(rb - lt + 1, 0) # inter_area (w, h)
+ inter_areas = wh[:, :, 0] * wh[:, :, 1] # shape: (n, m)
+ box_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (bbox[:, 3] - bbox[:, 1] + 1)
+ gt_areas = (gt[:, 2] - gt[:, 0] + 1) * (gt[:, 3] - gt[:, 1] + 1)
+ IoU = inter_areas / (box_areas[:, None] + gt_areas - inter_areas)
+ return np.array(IoU)
+
+def tn(sims: np.ndarray,
+ tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10,
+ min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3) -> List[List[int]]:
+ """
+ TN method for video temporal alignment.
+ Reimplemented paper:
+ {Tan H K, Ngo C W, Hong R, et al. Scalable detection of partial near-duplicate videos by visual-temporal consistency
+ [C]//Proceedings of the 17th ACM international conference on Multimedia. 2009: 145-154.}
+ Parameters
+ ----------
+ sims: input similarity map computed from a copied video pair.
+ tn_max_step: max step range in TN.
+ tn_top_k: Top k frame similarity selection in TN.
+    max_path: max number of iterations for multiple-segment detection.
+ min_sim: min average similarity score for each aligned segment.
+ min_length: min segment length.
+ max_iou: max iou for filtering overlap segments (bbox).
+ Returns
+ -------
+ list of temporal aligned copied segments, [query_min, ref_min, query_max, ref_max] for each segment
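+        e.g. [[3, 10, 48, 55]] (hypothetical) means query frames 3..48 align with reference frames 10..55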
+ """
+ infringe_box_list = []
+ path = 0
+ node_pair2id = {}
+ node_pair2id[(-1, -1)] = 0
+
+ node_id2pair = {}
+ node_id2pair[0] = (-1, -1) # source
+
+ node_num = 1
+
+ DG = nx.DiGraph()
+ DG.add_node(0)
+
+ # get top-k values and indices, shape (Q_LEN, top_k)
+ top = min(tn_top_k, sims.shape[1])
+
+ topk_indices = np.argsort(-sims)[:, :top]
+ topk_sims = np.take_along_axis(sims, topk_indices, axis=-1)
+
+ # add nodes
+ for qf_idx in range(sims.shape[0]):
+ for k in range(top):
+ rf_idx = topk_indices[qf_idx][k]
+
+ node_id2pair[node_num] = (qf_idx, rf_idx)
+ node_pair2id[(qf_idx, rf_idx)] = node_num
+
+ DG.add_node(node_num)
+ node_num += 1
+
+ # create graph by adding edges
+ for q_i in range(sims.shape[0]):
+ r_i = topk_indices[q_i]
+
+ intermediate_rs = np.empty((0,), dtype=np.int32)
+ # implements Constraints C1 by limiting range end
+ for q_j in range(q_i + 1, min(sims.shape[0], q_i + tn_max_step)):
+ r_j = topk_indices[q_j] # shape (top_k, )
+
+ r_diff = r_j[:, None] - r_i # dst - src, shape (top_k, top_k)
+
+ # Constraints C2
+ C2 = (r_diff > 0) & (r_diff < tn_max_step)
+
+ # Constraints C3
+ if len(intermediate_rs) == 0:
+                C3 = np.ones(C2.shape, dtype=bool)
+ else:
+ # "the equal sign" in C3 in paper is wrong because it's contradictory to C2
+ cond1 = intermediate_rs[None, :] > r_i[:, None]
+ cond2 = intermediate_rs[None, :] < r_j[:, None]
+ C3 = np.sum(cond2[:, None, :] & cond1, axis=-1) == 0
+
+ # Constraints C4
+ s_j = topk_sims[q_j] # shape (top_k, )
+ s_j = np.repeat(s_j.reshape(-1, 1), r_diff.shape[1], axis=1) # shape (top_k, top_k)
+ C4 = s_j >= min_sim
+
+ val_rows, val_cols = np.where(C2 & C3 & C4)
+ val_sims = s_j[val_rows, val_cols]
+ # update intermediate_rs
+ valid_r_j = r_j[val_rows]
+ intermediate_rs = np.unique(np.concatenate([intermediate_rs, valid_r_j]))
+
+ edges = [(node_pair2id[(q_i, r_i[c])], node_pair2id[(q_j, r_j[r])], dict(weight=s))
+ for c, r, s in zip(val_cols, val_rows, val_sims)]
+
+ DG.add_edges_from(edges)
+
+
+ # link sink node
+ for i in range(0, node_num - 1):
+ j = node_num - 1
+
+ pair_i = node_id2pair[i]
+ pair_j = node_id2pair[j]
+
+ if (pair_j[0] > pair_i[0] and pair_j[1] > pair_i[1] and
+ pair_j[0] - pair_i[0] <= tn_max_step and pair_j[1] - pair_i[1] <= tn_max_step):
+ DG.add_edge(i, j, weight=0)
+
+ while True:
+ if path > max_path:
+ break
+ longest_path = dag_longest_path(DG)
+ for i in range(1, len(longest_path)):
+ DG.add_edge(longest_path[i - 1], longest_path[i], weight=0.0)
+ if 0 in longest_path:
+ longest_path.remove(0) # remove source node
+ if node_num - 1 in longest_path:
+ longest_path.remove(node_num - 1) # remove sink node
+ path_query = [node_id2pair[node_id][0] for node_id in longest_path]
+ path_refer = [node_id2pair[node_id][1] for node_id in longest_path]
+
+ if len(path_query) == 0:
+ break
+ score = 0.0
+ for (qf_idx, rf_idx) in zip(path_query, path_refer):
+ score += sims[qf_idx][rf_idx]
+ if score > 0:
+ query_min, query_max = min(path_query), max(path_query)
+ refer_min, refer_max = min(path_refer), max(path_refer)
+ else:
+ query_min, query_max = 0, 0
+ refer_min, refer_max = 0, 0
+ ave_length = (refer_max - refer_min + query_max - query_min) / 2
+ ious = iou(np.expand_dims(np.array([query_min, refer_min, query_max, refer_max]), axis=0),
+ np.array(infringe_box_list))
+
+ if ave_length > 0 and score / ave_length > min_sim and min(refer_max - refer_min,
+ query_max - query_min) > min_length and ious.max() < max_iou:
+ infringe_box_list.append([int(query_min), int(refer_min), int(query_max), int(refer_max)])
+ path += 1
+ return infringe_box_list
+
+
+@register(output_schema=['ranges', 'scores'])
+class TemporalNetwork(NNOperator):
+    """
+    Operator that temporally aligns two video feature sequences and scores the
+    aligned segments.
+    """
+ def __init__(self,
+ tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10,
+ min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3):
+
+        super().__init__()
+        self._tn_max_step = tn_max_step
+ self._tn_top_k = tn_top_k
+ self._max_path = max_path
+ self._min_sim = min_sim
+ self._min_length = min_length
+ self._max_iou = max_iou
+
+    def __call__(self, src_video_vec: np.ndarray, dst_video_vec: np.ndarray) -> Tuple[List[List[int]], List[float]]:
+        # Frame-to-frame similarity map; assumes both feature sets are
+        # L2-normalized so the dot product behaves like cosine similarity.
+        sim_map = np.dot(src_video_vec, dst_video_vec.T)
+        res = tn(sim_map, self._tn_max_step, self._tn_top_k, self._max_path, self._min_sim, self._min_length, self._max_iou)
+        det_scores = []
+
+        for duplicate_det in res:
+            # Each detection is [src_begin, dst_begin, src_end, dst_end] on the similarity map.
+            x1, y1, x2, y2 = duplicate_det
+            e1, e2 = x2 - x1, y2 - y1
+            e = max(e1, e2)
+            # Resize the matched region to a square and average its diagonal
+            # to score how well the two segments align.
+            crop = sim_map[x1:x2, y1:y2]
+            standard_crop = cv2.resize(crop, dsize=(e, e), interpolation=cv2.INTER_CUBIC)
+            diagonal_edge = standard_crop.diagonal()
+            det_scores.append(diagonal_edge.mean())
+        return res, det_scores
+