Files
codex_jxs_code/captcha_clip.py
2026-03-02 10:55:39 +08:00

114 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Patch-similarity matching with CLIP: locate template icons inside a larger image.

Approach: cut the large image into sliding-window patches, score each patch's
visual similarity to the template with CLIP, and keep the highest-scoring patch.
"""
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from pathlib import Path

# Directory holding the captcha images next to this script.
BASE = Path(__file__).parent / "images"
MODEL_NAME = "openai/clip-vit-base-patch32"  # small model, ~600 MB first download
def load_clip():
    """Load the CLIP model and its processor, preferring GPU when available.

    Returns:
        (model, processor): the ``CLIPModel`` moved to CUDA if present,
        otherwise CPU, plus the matching ``CLIPProcessor``.
    """
    print("加载 CLIP 模型约600MB首次自动下载...")
    # Fall back to CPU so the script still runs on machines without CUDA
    # (the original hard-coded "cuda" and crashed there).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    print("CLIP 加载完成")
    return model, processor
def find_by_clip(model, processor, main_img: np.ndarray, template_img: np.ndarray,
                 step=10, win_sizes=None):
    """Sliding-window + CLIP similarity search for a template inside a big image.

    Args:
        model: a loaded ``CLIPModel``.
        processor: the matching ``CLIPProcessor``.
        main_img: large image, BGR uint8 (OpenCV convention).
        template_img: template image, BGR uint8.
        step: sliding-window stride in pixels.
        win_sizes: optional list of (width, height) windows to try; defaults
            to the template's own size scaled by 0.8-1.2.

    Returns:
        (best_box, best_score): best_box is (x, y, w, h) in main_img
        coordinates, or None when no window fits; best_score is the cosine
        similarity of the best window (-1 when none fit).
    """
    if win_sizes is None:
        th, tw = template_img.shape[:2]
        # Try the original template size plus a few scales around it.
        win_sizes = [(int(tw * s), int(th * s)) for s in [0.8, 0.9, 1.0, 1.1, 1.2]]

    # Run on whatever device the model actually lives on instead of
    # hard-coding "cuda" (the original crashed on CPU-only machines).
    device = next(model.parameters()).device

    def _embed(pil_img):
        # One PIL image -> L2-normalized CLIP image embedding (1, dim).
        inputs = processor(images=pil_img, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.vision_model(**inputs)
            feat = model.visual_projection(out.pooler_output).float()
        return feat / feat.norm(dim=-1, keepdim=True)

    tmpl_feat = _embed(Image.fromarray(cv2.cvtColor(template_img, cv2.COLOR_BGR2RGB)))

    mh, mw = main_img.shape[:2]
    best_score = -1
    best_box = None
    for (ww, wh) in win_sizes:
        if ww > mw or wh > mh:
            continue  # this window cannot fit inside the main image
        for y in range(0, mh - wh + 1, step):
            for x in range(0, mw - ww + 1, step):
                crop = main_img[y:y+wh, x:x+ww]
                crop_feat = _embed(Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)))
                # Dot product of unit vectors == cosine similarity.
                score = (tmpl_feat * crop_feat).sum().item()
                if score > best_score:
                    best_score = score
                    best_box = (x, y, ww, wh)
    return best_box, best_score
def main():
    """Match each template icon against images/1.jpg and save an annotated copy."""
    main_img = cv2.imread(str(BASE / "1.jpg"))
    if main_img is None:
        # Fail loudly with the path instead of an opaque AttributeError later.
        raise FileNotFoundError(f"cannot read main image: {BASE / '1.jpg'}")

    # Load each template exactly once with its alpha channel preserved.
    # (The original loaded every file twice: a plain BGR read whose results
    # were immediately discarded, then an IMREAD_UNCHANGED re-read.)
    templates = {}
    for name in ("2.png", "3.png", "4.png"):
        img = cv2.imread(str(BASE / name), cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"cannot read template: {BASE / name}")
        if img.ndim == 3 and img.shape[2] == 4:
            # BGRA: composite onto a white background so transparent pixels
            # match the light captcha backdrop.
            alpha = img[:, :, 3:4] / 255.0
            rgb = img[:, :, :3].astype(float)
            white = np.ones_like(rgb) * 255
            templates[name] = (rgb * alpha + white * (1 - alpha)).astype(np.uint8)
        else:
            templates[name] = img

    model, processor = load_clip()
    vis = main_img.copy()
    # Per-template annotation colors (BGR): red, green, blue.
    colors = {"2.png": (0, 0, 255), "3.png": (0, 255, 0), "4.png": (255, 0, 0)}
    print()
    for name, tmpl in templates.items():
        print(f"正在匹配 {name} ...")
        box, score = find_by_clip(model, processor, main_img, tmpl, step=8)
        if box:
            x, y, w, h = box
            cx, cy = x + w // 2, y + h // 2
            color = colors[name]
            cv2.rectangle(vis, (x, y), (x+w, y+h), color, 2)
            cv2.circle(vis, (cx, cy), 5, color, -1)
            cv2.putText(vis, name, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            print(f"  {name}: 中心点 ({cx}, {cy}) 相似度={score:.4f}")
        else:
            print(f"  {name}: 未找到")

    out = BASE / "result_clip.jpg"
    cv2.imwrite(str(out), vis)
    print(f"\n结果保存到: {out}")
# Run the full matching pipeline only when executed as a script.
if __name__ == "__main__":
    main()