"""
|
||
用 CLIP 做图块相似度匹配,在大图中找模板图标位置
|
||
原理:把大图切成滑动窗口小块,用 CLIP 计算每块和模板的视觉相似度,取最高分的块
|
||
"""
# Standard library first, then third-party packages (alphabetical).
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Directory holding the input/output images, resolved next to this script.
BASE = Path(__file__).parent / "images"

# Small CLIP checkpoint (~600 MB, downloaded automatically on first use).
MODEL_NAME = "openai/clip-vit-base-patch32"
def load_clip():
    """Load the CLIP model and its matching processor.

    The model is placed on CUDA when a GPU is available, otherwise on CPU.
    (The original hard-coded ``.to("cuda")`` crashed on CPU-only machines.)

    Returns:
        (model, processor): the ``CLIPModel`` on its device and the
        ``CLIPProcessor`` used to prepare image inputs for it.
    """
    # Fall back to CPU so the script still runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("加载 CLIP 模型(约600MB,首次自动下载)...")
    model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    print("CLIP 加载完成")
    return model, processor
def find_by_clip(model, processor, main_img: np.ndarray, template_img: np.ndarray,
                 step=10, win_sizes=None):
    """Locate *template_img* inside *main_img* via sliding-window CLIP similarity.

    Slides windows of several sizes over the BGR image, embeds each crop and
    the template with CLIP's vision tower, and returns the window whose
    embedding has the highest cosine similarity to the template's.

    Args:
        model: a CLIPModel; inputs are moved to whatever device it lives on.
        processor: the matching CLIPProcessor (accepts numpy RGB arrays).
        main_img: large BGR uint8 image of shape (H, W, 3).
        template_img: BGR uint8 template to search for.
        step: sliding-window stride in pixels.
        win_sizes: optional list of (width, height) windows; defaults to the
            template size scaled by 0.8–1.2.

    Returns:
        ((x, y, w, h), score) for the best window, or (None, -1.0) when no
        window fits inside the image.
    """
    # Fix: send inputs to the model's actual device instead of assuming
    # "cuda", so CPU-hosted models work too.
    device = next(model.parameters()).device

    if win_sizes is None:
        th, tw = template_img.shape[:2]
        # Try the original template size plus a little scale jitter.
        win_sizes = [(int(tw * s), int(th * s)) for s in [0.8, 0.9, 1.0, 1.1, 1.2]]

    def _embed(bgr: np.ndarray) -> torch.Tensor:
        """Return the L2-normalised CLIP image embedding of a BGR crop."""
        # BGR -> RGB channel flip; HF image processors accept numpy arrays
        # directly, so no cv2/PIL round-trip is needed here.
        rgb = np.ascontiguousarray(bgr[:, :, ::-1])
        inputs = processor(images=rgb, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.vision_model(**inputs)
            feat = model.visual_projection(out.pooler_output).float()
        return feat / feat.norm(dim=-1, keepdim=True)

    tmpl_feat = _embed(template_img)

    mh, mw = main_img.shape[:2]
    best_score = -1.0  # cosine lower bound; doubles as the "no match" sentinel
    best_box = None

    for ww, wh in win_sizes:
        if ww > mw or wh > mh:
            continue  # window larger than the image at this scale
        for y in range(0, mh - wh + 1, step):
            for x in range(0, mw - ww + 1, step):
                feat = _embed(main_img[y:y + wh, x:x + ww])
                # Dot product of unit vectors == cosine similarity.
                score = (tmpl_feat * feat).sum().item()
                if score > best_score:
                    best_score = score
                    best_box = (x, y, ww, wh)

    return best_box, best_score
def main():
    """Match three template icons inside images/1.jpg with CLIP and save an
    annotated copy to images/result_clip.jpg.

    Raises:
        FileNotFoundError: if the main image or any template cannot be read.
    """
    main_img = cv2.imread(str(BASE / "1.jpg"))
    if main_img is None:
        # cv2.imread returns None on failure; fail loudly with the path.
        raise FileNotFoundError(f"cannot read {BASE / '1.jpg'}")

    # Read each template exactly once (the original code loaded every file
    # twice: once in a dict literal and again with IMREAD_UNCHANGED).
    templates = {}
    for name in ("2.png", "3.png", "4.png"):
        img = cv2.imread(str(BASE / name), cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"cannot read {BASE / name}")
        if img.ndim == 3 and img.shape[2] == 4:
            # BGRA template: composite the alpha channel onto a white
            # background so CLIP sees what a screenshot would show.
            alpha = img[:, :, 3:4] / 255.0
            rgb = img[:, :, :3].astype(float)
            white = np.ones_like(rgb) * 255
            img = (rgb * alpha + white * (1 - alpha)).astype(np.uint8)
        templates[name] = img

    model, processor = load_clip()

    vis = main_img.copy()
    # One BGR colour per template for the annotated output.
    colors = {"2.png": (0, 0, 255), "3.png": (0, 255, 0), "4.png": (255, 0, 0)}

    print()
    for name, tmpl in templates.items():
        print(f"正在匹配 {name} ...")
        box, score = find_by_clip(model, processor, main_img, tmpl, step=8)
        if box:
            x, y, w, h = box
            cx, cy = x + w // 2, y + h // 2
            color = colors[name]
            cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)
            cv2.circle(vis, (cx, cy), 5, color, -1)
            cv2.putText(vis, name, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            print(f"  {name}: 中心点 ({cx}, {cy}) 相似度={score:.4f}")
        else:
            print(f"  {name}: 未找到")

    out = BASE / "result_clip.jpg"
    cv2.imwrite(str(out), vis)
    print(f"\n结果保存到: {out}")
if __name__ == "__main__":
|
||
main()
|