diff --git a/README.md b/README.md index 0c2e081..bcac577 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,88 @@ -# temporal-network +# Video Alignment with Temporal Network + +*author: David Wang* + + +
+ + + +## Description + +This operator can compare two ordered sequences, then detect the range which features from each sequence are computationally similar in order. +
+ + +## Code Example + + +```python +placeholder +``` + + +
+ + + +## Factory Constructor + +Create the operator via the following factory method + +***clip(model_name, modality)*** +***temporal_network(tn_max_step, tn_top_k, max_path, min_sim, min_length, max_iou)*** + + +**Parameters:** + +​ ***tn_max_step:*** *str* + +​ Max step range in TN. + +​ ***tn_top_k:*** *str* + +​ Top k frame similarity selection in TN. + +​ ***max_path:*** *str* + +​ Max loop for multiply segments detection. + +​ ***min_sim:*** *str* + +​ Min average similarity score for each aligned segment. + +​ ***min_length:*** *str* + +​ Min segment length. + +​ ***max_iout:*** *str* + +​ Max iou for filtering overlap segments (bbox). + +
+ + + +## Interface + +A Temporal Network operator takes two numpy.ndarray(shape(N,D) N: number of features. D: dimension of features) and get the duplicated ranges and scores. + + +**Parameters:** + +​ ***src_video_vec data:*** *numpy.ndarray* + +​ Source video feature vectors. + +​ ***dst_video_vec:*** *numpy.ndarray* + +​ Destination video feature vectors. + + +**Returns:** *List[List[Int]], List[float] * + +​ The returned aligned range and similarity score. + + + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..5f4c3f6 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tn import TemporalNetwork + + +def temporal_network(tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3): + return TemporalNetwork(tn_max_step, tn_top_k, max_path, min_sim, min_length, max_iou) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4bb49b5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +towhee +networkx + diff --git a/tn.py b/tn.py new file mode 100644 index 0000000..a5e428f --- /dev/null +++ b/tn.py @@ -0,0 +1,214 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cv2 +import numpy as np +import networkx as nx +from networkx.algorithms.dag import dag_longest_path +from typing import List +from towhee.operator.base import NNOperator, OperatorFlag +from towhee.types.arg import arg, to_image_color +from towhee import register + +def iou(bbox: np.ndarray, gt: np.ndarray) -> np.ndarray: + """ + IoU calculation for next-step filtering + Parameters + ---------- + bbox: bounding box array (n, 4) + gt: bounding box array (m, 4) + Returns + ------- + IoU results with dimension (n, m) + """ + if len(bbox) == 0 or len(gt) == 0: + return np.array(0) + lt = np.maximum(bbox[:, None, :2], gt[:, :2]) # left_top (x, y) + rb = np.minimum(bbox[:, None, 2:], gt[:, 2:]) # right_bottom (x, y) + wh = np.maximum(rb - lt + 1, 0) # inter_area (w, h) + inter_areas = wh[:, :, 0] * wh[:, :, 1] # shape: (n, m) + box_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (bbox[:, 3] - bbox[:, 1] + 1) + gt_areas = (gt[:, 2] - gt[:, 0] + 1) * (gt[:, 3] - gt[:, 1] + 1) + IoU = inter_areas / (box_areas[:, None] + gt_areas - inter_areas) + return np.array(IoU) + +def tn(sims: np.ndarray, + tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, + min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3) -> List[List[int]]: + """ + TN method for video temporal alignment. + Reimplemented paper: + {Tan H K, Ngo C W, Hong R, et al. Scalable detection of partial near-duplicate videos by visual-temporal consistency + [C]//Proceedings of the 17th ACM international conference on Multimedia. 2009: 145-154.} + Parameters + ---------- + sims: input similarity map computed from a copied video pair. + tn_max_step: max step range in TN. + tn_top_k: Top k frame similarity selection in TN. + max_path: max loop for multiply segments detection. + min_sim: min average similarity score for each aligned segment. + min_length: min segment length. + max_iou: max iou for filtering overlap segments (bbox). + Returns + ------- + list of temporal aligned copied segments, [query_min, ref_min, query_max, ref_max] for each segment + """ + infringe_box_list = [] + path = 0 + node_pair2id = {} + node_pair2id[(-1, -1)] = 0 + + node_id2pair = {} + node_id2pair[0] = (-1, -1) # source + + node_num = 1 + + DG = nx.DiGraph() + DG.add_node(0) + + # get top-k values and indices, shape (Q_LEN, top_k) + top = min(tn_top_k, sims.shape[1]) + + topk_indices = np.argsort(-sims)[:, :top] + topk_sims = np.take_along_axis(sims, topk_indices, axis=-1) + + # add nodes + for qf_idx in range(sims.shape[0]): + for k in range(top): + rf_idx = topk_indices[qf_idx][k] + + node_id2pair[node_num] = (qf_idx, rf_idx) + node_pair2id[(qf_idx, rf_idx)] = node_num + + DG.add_node(node_num) + node_num += 1 + + # create graph by adding edges + for q_i in range(sims.shape[0]): + r_i = topk_indices[q_i] + + intermediate_rs = np.empty((0,), dtype=np.int32) + # implements Constraints C1 by limiting range end + for q_j in range(q_i + 1, min(sims.shape[0], q_i + tn_max_step)): + r_j = topk_indices[q_j] # shape (top_k, ) + + r_diff = r_j[:, None] - r_i # dst - src, shape (top_k, top_k) + + # Constraints C2 + C2 = (r_diff > 0) & (r_diff < tn_max_step) + + # Constraints C3 + if len(intermediate_rs) == 0: + C3 = np.ones(C2.shape, dtype=np.bool) + else: + # "the equal sign" in C3 in paper is wrong because it's contradictory to C2 + cond1 = intermediate_rs[None, :] > r_i[:, None] + cond2 = intermediate_rs[None, :] < r_j[:, None] + C3 = np.sum(cond2[:, None, :] & cond1, axis=-1) == 0 + + # Constraints C4 + s_j = topk_sims[q_j] # shape (top_k, ) + s_j = np.repeat(s_j.reshape(-1, 1), r_diff.shape[1], axis=1) # shape (top_k, top_k) + C4 = s_j >= min_sim + + val_rows, val_cols = np.where(C2 & C3 & C4) + val_sims = s_j[val_rows, val_cols] + # update intermediate_rs + valid_r_j = r_j[val_rows] + intermediate_rs = np.unique(np.concatenate([intermediate_rs, valid_r_j])) + + edges = [(node_pair2id[(q_i, r_i[c])], node_pair2id[(q_j, r_j[r])], dict(weight=s)) + for c, r, s in zip(val_cols, val_rows, val_sims)] + + DG.add_edges_from(edges) + + #logger.info("Graph N {} E {} for sim {}x{}", DG.number_of_nodes(), DG.number_of_edges(), sims.shape[0], + # sims.shape[1]) + + # link sink node + for i in range(0, node_num - 1): + j = node_num - 1 + + pair_i = node_id2pair[i] + pair_j = node_id2pair[j] + + if (pair_j[0] > pair_i[0] and pair_j[1] > pair_i[1] and + pair_j[0] - pair_i[0] <= tn_max_step and pair_j[1] - pair_i[1] <= tn_max_step): + DG.add_edge(i, j, weight=0) + + while True: + if path > max_path: + break + longest_path = dag_longest_path(DG) + for i in range(1, len(longest_path)): + DG.add_edge(longest_path[i - 1], longest_path[i], weight=0.0) + if 0 in longest_path: + longest_path.remove(0) # remove source node + if node_num - 1 in longest_path: + longest_path.remove(node_num - 1) # remove sink node + path_query = [node_id2pair[node_id][0] for node_id in longest_path] + path_refer = [node_id2pair[node_id][1] for node_id in longest_path] + + if len(path_query) == 0: + break + score = 0.0 + for (qf_idx, rf_idx) in zip(path_query, path_refer): + score += sims[qf_idx][rf_idx] + if score > 0: + query_min, query_max = min(path_query), max(path_query) + refer_min, refer_max = min(path_refer), max(path_refer) + else: + query_min, query_max = 0, 0 + refer_min, refer_max = 0, 0 + ave_length = (refer_max - refer_min + query_max - query_min) / 2 + ious = iou(np.expand_dims(np.array([query_min, refer_min, query_max, refer_max]), axis=0), + np.array(infringe_box_list)) + + if ave_length > 0 and score / ave_length > min_sim and min(refer_max - refer_min, + query_max - query_min) > min_length and ious.max() < max_iou: + infringe_box_list.append([int(query_min), int(refer_min), int(query_max), int(refer_max)]) + path += 1 + return infringe_box_list + + +@register(output_schema=['vec']) +class TemporalNetwork(NNOperator): + """ + TemporalNetwork + """ + def __init__(self, + tn_max_step: int = 10, tn_top_k: int = 5, max_path: int = 10, + min_sim: float = 0.2, min_length: int = 5, max_iou: float = 0.3): + + self._tn_max_step = tn_max_step + self._tn_top_k = tn_top_k + self._max_path = max_path + self._min_sim = min_sim + self._min_length = min_length + self._max_iou = max_iou + + def __call__(self, src_video_vec: 'ndarray', dst_video_vec: 'ndarray') -> float: + sim_map = np.dot(src_video_vec, dst_video_vec.T) + res = tn(sim_map, self._tn_max_step, self._tn_top_k, self._max_path, self._min_sim, self._min_length, self._max_iou) + det_scores = [] + + for duplicate_det in res: + x1, y1, x2, y2 = duplicate_det + e1, e2 = x2 - x1, y2 - y1 + e = max(e1,e2) + crop = sim_map[x1:x2, y1:y2] + standard_crop = cv2.resize(crop, dsize=(e, e), interpolation=cv2.INTER_CUBIC) + diagonal_edge = standard_crop.diagonal() + det_scores.append(diagonal_edge.mean()) + return res, det_scores +