114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
|
|
"""
|
|||
|
|
用 CLIP 做图块相似度匹配,在大图中找模板图标位置
|
|||
|
|
原理:把大图切成滑动窗口小块,用 CLIP 计算每块和模板的视觉相似度,取最高分的块
|
|||
|
|
"""
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from PIL import Image
|
|||
|
|
import torch
|
|||
|
|
from transformers import CLIPProcessor, CLIPModel
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# Directory holding the screenshot and the template icons.
BASE = Path(__file__).parent / "images"
MODEL_NAME = "openai/clip-vit-base-patch32"  # ~600 MB, smallest official CLIP checkpoint
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_clip():
    """Load the CLIP model and its processor.

    Fix: the original unconditionally called ``.to("cuda")``, which raises
    on machines without a CUDA GPU; we now fall back to CPU.

    Returns:
        tuple: ``(model, processor)`` — the CLIPModel moved to the selected
        device and set to eval mode, plus the matching CLIPProcessor.
    """
    print("加载 CLIP 模型(约600MB,首次自动下载)...")
    # Prefer GPU, but keep the script runnable on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # inference only
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    print("CLIP 加载完成")
    return model, processor
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _clip_embed(model, processor, bgr_img: np.ndarray) -> torch.Tensor:
    """Return the L2-normalized CLIP image embedding of a BGR image."""
    pil_img = Image.fromarray(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
    # Follow the model's device instead of hard-coding "cuda" so this
    # works on CPU-only machines as well.
    inputs = processor(images=pil_img, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.vision_model(**inputs)
        feat = model.visual_projection(out.pooler_output).float()
    return feat / feat.norm(dim=-1, keepdim=True)


def find_by_clip(model, processor, main_img: np.ndarray, template_img: np.ndarray,
                 step=10, win_sizes=None):
    """
    Slide windows over ``main_img`` and score each crop against the template
    via CLIP cosine similarity; return the highest-scoring window.

    Args:
        model: loaded CLIPModel (any device).
        processor: matching CLIPProcessor.
        main_img: large BGR image to search in.
        template_img: BGR template to look for.
        step: sliding-window stride in pixels.
        win_sizes: optional list of ``(width, height)`` window sizes;
            defaults to the template size scaled by 0.8–1.2.

    Returns:
        tuple: ``(best_box, best_score)`` where ``best_box`` is
        ``(x, y, w, h)`` or ``None`` if no window fits in the image.
    """
    if win_sizes is None:
        th, tw = template_img.shape[:2]
        # Try the original template size plus a few scales around it.
        win_sizes = [(int(tw * s), int(th * s)) for s in (0.8, 0.9, 1.0, 1.1, 1.2)]

    # Embed the template once; crops are compared against this vector.
    tmpl_feat = _clip_embed(model, processor, template_img)

    mh, mw = main_img.shape[:2]
    best_score = -1.0
    best_box = None

    for ww, wh in win_sizes:
        if ww > mw or wh > mh:
            continue  # window larger than the search image
        for y in range(0, mh - wh + 1, step):
            for x in range(0, mw - ww + 1, step):
                crop_feat = _clip_embed(model, processor, main_img[y:y + wh, x:x + ww])
                # Both embeddings are unit vectors, so the dot product
                # equals the cosine similarity.
                score = (tmpl_feat * crop_feat).sum().item()
                if score > best_score:
                    best_score = score
                    best_box = (x, y, ww, wh)

    return best_box, best_score
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Locate each template icon inside images/1.jpg and save an annotated copy.

    Fixes vs. original: templates were read twice (a first BGR read whose
    result was immediately overwritten by an IMREAD_UNCHANGED re-read), and
    missing files silently produced ``None`` from ``cv2.imread`` and crashed
    later with an opaque error — both are handled explicitly now.
    """
    main_path = BASE / "1.jpg"
    main_img = cv2.imread(str(main_path))
    if main_img is None:
        raise FileNotFoundError(f"无法读取大图: {main_path}")

    # Read each template once, keeping the alpha channel; composite BGRA
    # icons onto a white background so CLIP sees them as rendered on screen.
    templates = {}
    for name in ("2.png", "3.png", "4.png"):
        img = cv2.imread(str(BASE / name), cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"无法读取模板: {BASE / name}")
        if img.ndim == 3 and img.shape[2] == 4:
            alpha = img[:, :, 3:4] / 255.0
            rgb = img[:, :, :3].astype(float)
            white = np.ones_like(rgb) * 255
            templates[name] = (rgb * alpha + white * (1 - alpha)).astype(np.uint8)
        else:
            templates[name] = img

    model, processor = load_clip()

    vis = main_img.copy()
    # BGR colors: red / green / blue, one per template.
    colors = {"2.png": (0, 0, 255), "3.png": (0, 255, 0), "4.png": (255, 0, 0)}

    print()
    for name, tmpl in templates.items():
        print(f"正在匹配 {name} ...")
        box, score = find_by_clip(model, processor, main_img, tmpl, step=8)
        if box is not None:
            x, y, w, h = box
            cx, cy = x + w // 2, y + h // 2
            color = colors[name]
            cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)
            cv2.circle(vis, (cx, cy), 5, color, -1)
            cv2.putText(vis, name, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            print(f" {name}: 中心点 ({cx}, {cy}) 相似度={score:.4f}")
        else:
            print(f" {name}: 未找到")

    out = BASE / "result_clip.jpg"
    cv2.imwrite(str(out), vis)
    print(f"\n结果保存到: {out}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point — run the matching pipeline when executed directly.
if __name__ == "__main__":
    main()
|