132 lines
4.3 KiB
Python
132 lines
4.3 KiB
Python
"""
|
||
用 Qwen2-VL 本地模型识别验证码图片位置
|
||
首次运行会自动下载模型(约 4GB)
|
||
"""
|
||
import base64
|
||
from pathlib import Path
|
||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||
from qwen_vl_utils import process_vision_info
|
||
import torch
|
||
|
||
MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
|
||
BASE = Path(__file__).parent / "images"
|
||
|
||
def img_to_base64(path):
|
||
with open(path, "rb") as f:
|
||
return base64.b64encode(f.read()).decode()
|
||
|
||
def load_model():
|
||
print("加载模型中(首次运行会下载约15GB)...")
|
||
from transformers import BitsAndBytesConfig
|
||
bnb_config = BitsAndBytesConfig(
|
||
load_in_4bit=True,
|
||
bnb_4bit_compute_dtype=torch.float16,
|
||
bnb_4bit_use_double_quant=True,
|
||
bnb_4bit_quant_type="nf4",
|
||
)
|
||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||
MODEL_NAME,
|
||
quantization_config=bnb_config,
|
||
device_map="cuda",
|
||
)
|
||
processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
||
print("模型加载完成")
|
||
return model, processor
|
||
|
||
def ask_one(model, processor, main_img_path, template_path):
|
||
"""让模型找出单个模板图在主图中的位置,返回原始回答"""
|
||
messages = [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "text",
|
||
"text": (
|
||
"下面是一张背景大图(300x200像素),"
|
||
"以及一个需要在大图中找到的小图标轮廓。\n"
|
||
"大图:"
|
||
)
|
||
},
|
||
{"type": "image", "image": str(main_img_path)},
|
||
{
|
||
"type": "text",
|
||
"text": "\n小图标轮廓(这个图标出现在大图中某个物体上):"
|
||
},
|
||
{"type": "image", "image": str(template_path)},
|
||
{
|
||
"type": "text",
|
||
"text": (
|
||
"\n请仔细观察小图标的形状,在大图中找到形状最相似的物体,"
|
||
"给出该物体中心点的像素坐标。"
|
||
"坐标原点在左上角,x向右,y向下。"
|
||
"只需回答坐标,格式:(x, y)"
|
||
)
|
||
}
|
||
]
|
||
}
|
||
]
|
||
|
||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||
image_inputs, video_inputs = process_vision_info(messages)
|
||
inputs = processor(
|
||
text=[text],
|
||
images=image_inputs,
|
||
videos=video_inputs,
|
||
padding=True,
|
||
return_tensors="pt",
|
||
).to("cuda")
|
||
|
||
with torch.no_grad():
|
||
generated_ids = model.generate(**inputs, max_new_tokens=50)
|
||
|
||
generated_ids_trimmed = [
|
||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||
]
|
||
output = processor.batch_decode(
|
||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||
)
|
||
return output[0].strip()
|
||
|
||
|
||
def main():
|
||
main_img = BASE / "1.jpg"
|
||
templates = [BASE / "2.png", BASE / "3.png", BASE / "4.png"]
|
||
|
||
model, processor = load_model()
|
||
|
||
import re
|
||
import cv2
|
||
|
||
img = cv2.imread(str(main_img))
|
||
colors = {"2.png": (0, 0, 255), "3.png": (0, 255, 0), "4.png": (255, 0, 0)}
|
||
results = {}
|
||
|
||
for tmpl_path in templates:
|
||
name = tmpl_path.name
|
||
print(f"\n正在识别 {name} ...")
|
||
answer = ask_one(model, processor, main_img, tmpl_path)
|
||
print(f"{name} 模型回答: {answer}")
|
||
|
||
match = re.search(r"\((\d+)[,,\s]+(\d+)\)", answer)
|
||
if match:
|
||
x, y = int(match.group(1)), int(match.group(2))
|
||
results[name] = (x, y)
|
||
color = colors[name]
|
||
cv2.circle(img, (x, y), 8, color, -1)
|
||
cv2.circle(img, (x, y), 12, color, 2)
|
||
cv2.putText(img, name, (x + 14, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||
else:
|
||
print(f"{name}: 未能解析坐标,原始回答: {answer}")
|
||
|
||
print("\n=== 点击坐标汇总 ===")
|
||
for name, (x, y) in results.items():
|
||
print(f"{name}: ({x}, {y})")
|
||
|
||
out = BASE / "result_vl.jpg"
|
||
cv2.imwrite(str(out), img)
|
||
print(f"\n可视化结果保存到: {out}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|