haha
This commit is contained in:
113
captcha_clip.py
Normal file
113
captcha_clip.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
用 CLIP 做图块相似度匹配,在大图中找模板图标位置
|
||||
原理:把大图切成滑动窗口小块,用 CLIP 计算每块和模板的视觉相似度,取最高分的块
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from pathlib import Path
|
||||
|
||||
BASE = Path(__file__).parent / "images"
|
||||
MODEL_NAME = "openai/clip-vit-base-patch32" # 约600MB,小模型
|
||||
|
||||
|
||||
def load_clip():
|
||||
print("加载 CLIP 模型(约600MB,首次自动下载)...")
|
||||
model = CLIPModel.from_pretrained(MODEL_NAME).to("cuda")
|
||||
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
|
||||
print("CLIP 加载完成")
|
||||
return model, processor
|
||||
|
||||
|
||||
def find_by_clip(model, processor, main_img: np.ndarray, template_img: np.ndarray,
|
||||
step=10, win_sizes=None):
|
||||
"""
|
||||
滑动窗口 + CLIP 相似度,找模板在大图中的最佳位置
|
||||
"""
|
||||
if win_sizes is None:
|
||||
th, tw = template_img.shape[:2]
|
||||
# 尝试原始尺寸及上下浮动
|
||||
win_sizes = [(int(tw * s), int(th * s)) for s in [0.8, 0.9, 1.0, 1.1, 1.2]]
|
||||
|
||||
# 预处理模板
|
||||
tmpl_pil = Image.fromarray(cv2.cvtColor(template_img, cv2.COLOR_BGR2RGB))
|
||||
tmpl_inputs = processor(images=tmpl_pil, return_tensors="pt").to("cuda")
|
||||
with torch.no_grad():
|
||||
tmpl_out = model.vision_model(**tmpl_inputs)
|
||||
tmpl_feat = model.visual_projection(tmpl_out.pooler_output).float()
|
||||
tmpl_feat = tmpl_feat / tmpl_feat.norm(dim=-1, keepdim=True)
|
||||
|
||||
mh, mw = main_img.shape[:2]
|
||||
best_score = -1
|
||||
best_box = None
|
||||
|
||||
for (ww, wh) in win_sizes:
|
||||
if ww > mw or wh > mh:
|
||||
continue
|
||||
for y in range(0, mh - wh + 1, step):
|
||||
for x in range(0, mw - ww + 1, step):
|
||||
crop = main_img[y:y+wh, x:x+ww]
|
||||
crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
|
||||
crop_inputs = processor(images=crop_pil, return_tensors="pt").to("cuda")
|
||||
with torch.no_grad():
|
||||
crop_out = model.vision_model(**crop_inputs)
|
||||
crop_feat = model.visual_projection(crop_out.pooler_output).float()
|
||||
crop_feat = crop_feat / crop_feat.norm(dim=-1, keepdim=True)
|
||||
score = (tmpl_feat * crop_feat).sum().item()
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_box = (x, y, ww, wh)
|
||||
|
||||
return best_box, best_score
|
||||
|
||||
|
||||
def main():
|
||||
main_img = cv2.imread(str(BASE / "1.jpg"))
|
||||
templates = {
|
||||
"2.png": cv2.imread(str(BASE / "2.png")),
|
||||
"3.png": cv2.imread(str(BASE / "3.png")),
|
||||
"4.png": cv2.imread(str(BASE / "4.png")),
|
||||
}
|
||||
|
||||
# 模板是 BGRA,转 BGR
|
||||
for name in templates:
|
||||
img = cv2.imread(str(BASE / name), cv2.IMREAD_UNCHANGED)
|
||||
if img.shape[2] == 4:
|
||||
# alpha 通道合成白底
|
||||
alpha = img[:, :, 3:4] / 255.0
|
||||
rgb = img[:, :, :3].astype(float)
|
||||
white = np.ones_like(rgb) * 255
|
||||
merged = (rgb * alpha + white * (1 - alpha)).astype(np.uint8)
|
||||
templates[name] = merged
|
||||
else:
|
||||
templates[name] = img
|
||||
|
||||
model, processor = load_clip()
|
||||
|
||||
vis = main_img.copy()
|
||||
colors = {"2.png": (0, 0, 255), "3.png": (0, 255, 0), "4.png": (255, 0, 0)}
|
||||
|
||||
print()
|
||||
for name, tmpl in templates.items():
|
||||
print(f"正在匹配 {name} ...")
|
||||
box, score = find_by_clip(model, processor, main_img, tmpl, step=8)
|
||||
if box:
|
||||
x, y, w, h = box
|
||||
cx, cy = x + w // 2, y + h // 2
|
||||
color = colors[name]
|
||||
cv2.rectangle(vis, (x, y), (x+w, y+h), color, 2)
|
||||
cv2.circle(vis, (cx, cy), 5, color, -1)
|
||||
cv2.putText(vis, name, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||||
print(f" {name}: 中心点 ({cx}, {cy}) 相似度={score:.4f}")
|
||||
else:
|
||||
print(f" {name}: 未找到")
|
||||
|
||||
out = BASE / "result_clip.jpg"
|
||||
cv2.imwrite(str(out), vis)
|
||||
print(f"\n结果保存到: {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user