diff --git a/image_crop_cv2.py b/image_crop_cv2.py index 6eca7a9..0c18a10 100644 --- a/image_crop_cv2.py +++ b/image_crop_cv2.py @@ -31,20 +31,28 @@ log = logging.getLogger() class ImageCropCV2(PyOperator): def __init__(self, clamp = False): self.clamp = clamp + self.h = None + self.w = None @staticmethod def _clamp(x, minimum, maximum): return max(minimum, min(x, maximum)) + + def get_box(self, img, box): + x1, y1, x2, y2 = box + x1 = ImageCropCV2._clamp(x1, 0, self.w) + x2 = ImageCropCV2._clamp(x2, 0, self.w) + y1 = ImageCropCV2._clamp(y1, 0, self.h) + y2 = ImageCropCV2._clamp(y2, 0, self.h) + return img[y1:y2,x1:x2,:] def __call__(self, img: np.ndarray, bboxes: List[Tuple]): - h, w, _ = img.shape + self.h, self.w, _ = img.shape res = [] - for box in bboxes: - x1, y1, x2, y2 = box - x1 = ImageCropCV2._clamp(x1, 0, w) - x2 = ImageCropCV2._clamp(x2, 0, w) - y1 = ImageCropCV2._clamp(y1, 0, h) - y2 = ImageCropCV2._clamp(y2, 0, h) - res.append(img[y1:y2,x1:x2,:]) + try: + res.append(self.get_box(img, bboxes)) + except: + for box in bboxes: + res.append(self.get_box(img, box)) return res