logo
Browse Source

revise the scoring method.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
b792e276d0
  1. 19
      tn.py

19
tn.py

@ -64,6 +64,7 @@ def tn(sims: np.ndarray,
list of temporal aligned copied segments, [query_min, ref_min, query_max, ref_max] for each segment list of temporal aligned copied segments, [query_min, ref_min, query_max, ref_max] for each segment
""" """
infringe_box_list = [] infringe_box_list = []
infringe_score_list = []
path = 0 path = 0
node_pair2id = {} node_pair2id = {}
node_pair2id[(-1, -1)] = 0 node_pair2id[(-1, -1)] = 0
@ -177,8 +178,9 @@ def tn(sims: np.ndarray,
if ave_length > 0 and score / ave_length > min_sim and min(refer_max - refer_min, if ave_length > 0 and score / ave_length > min_sim and min(refer_max - refer_min,
query_max - query_min) > min_length and ious.max() < max_iou: query_max - query_min) > min_length and ious.max() < max_iou:
infringe_box_list.append([int(query_min), int(refer_min), int(query_max), int(refer_max)]) infringe_box_list.append([int(query_min), int(refer_min), int(query_max), int(refer_max)])
infringe_score_list.append(score / ave_length)
path += 1 path += 1
return infringe_box_list
return infringe_box_list, infringe_score_list
@register(output_schema=['vec']) @register(output_schema=['vec'])
@ -199,16 +201,7 @@ class TemporalNetwork(NNOperator):
def __call__(self, src_video_vec: 'ndarray', dst_video_vec: 'ndarray') -> float: def __call__(self, src_video_vec: 'ndarray', dst_video_vec: 'ndarray') -> float:
sim_map = np.dot(src_video_vec, dst_video_vec.T) sim_map = np.dot(src_video_vec, dst_video_vec.T)
res = tn(sim_map, self._tn_max_step, self._tn_top_k, self._max_path, self._min_sim, self._min_length, self._max_iou)
det_scores = []
for duplicate_det in res:
x1, y1, x2, y2 = duplicate_det
e1, e2 = x2 - x1, y2 - y1
e = max(e1,e2)
crop = sim_map[x1:x2, y1:y2]
standard_crop = cv2.resize(crop, dsize=(e, e), interpolation=cv2.INTER_CUBIC)
diagonal_edge = standard_crop.diagonal()
det_scores.append(diagonal_edge.mean())
return res, det_scores
ranges, scores = tn(sim_map, self._tn_max_step, self._tn_top_k, self._max_path, self._min_sim, self._min_length, self._max_iou)
return ranges, scores

Loading…
Cancel
Save