Files
codex_jxs_code/captcha_vl.py
2026-03-02 10:55:39 +08:00

132 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
用 Qwen2-VL 本地模型识别验证码图片位置
首次运行会自动下载模型(约 15GB)
"""
import base64
from pathlib import Path
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
# Hugging Face Hub identifier of the vision-language model used for recognition.
MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
# Folder next to this script holding the captcha images ("1.jpg", "2.png", ...)
# and receiving the visualisation output.
BASE = Path(__file__).parent.joinpath("images")
def img_to_base64(path):
    """Return the bytes of the file at *path* as a base64-encoded ``str``."""
    with open(path, "rb") as handle:
        raw = handle.read()
    encoded = base64.b64encode(raw)
    return encoded.decode()
def load_model():
    """Load the Qwen2-VL model (4-bit quantized, on CUDA) and its processor.

    Both are fetched from ``MODEL_NAME``; the first run downloads the weights.

    Returns:
        A ``(model, processor)`` tuple.
    """
    print("加载模型中首次运行会下载约15GB...")
    from transformers import BitsAndBytesConfig

    # 4-bit NF4 weights with double quantization; compute runs in float16.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        device_map="cuda",
        quantization_config=quant_cfg,
    )
    vl_processor = AutoProcessor.from_pretrained(MODEL_NAME)
    print("模型加载完成")
    return vl_model, vl_processor
def ask_one(model, processor, main_img_path, template_path, max_new_tokens=50):
    """Ask the model where one template icon appears in the main image.

    Builds a single user turn containing the main image, the template image
    and a Chinese instruction that requests the matching object's centre as
    "(x, y)" pixel coordinates. It then generates an answer and returns it
    unparsed.

    Args:
        model: Qwen2-VL model, e.g. as returned by ``load_model``.
        processor: The ``AutoProcessor`` matching ``model``.
        main_img_path: Path to the background (main) image.
        template_path: Path to the small icon-outline image to locate.
        max_new_tokens: Maximum number of tokens to generate. Defaults to 50;
            the prompt only asks for a short "(x, y)" answer.

    Returns:
        The decoded answer text with surrounding whitespace stripped.
    """
    # NOTE(review): the prompt hard-codes the main image size as 300x200
    # pixels; confirm it matches the actual image. The processor may also
    # resize images before the model sees them, so verify that the returned
    # coordinates are in original-image pixels.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "下面是一张背景大图300x200像素"
                        "以及一个需要在大图中找到的小图标轮廓。\n"
                        "大图:"
                    )
                },
                {"type": "image", "image": str(main_img_path)},
                {
                    "type": "text",
                    "text": "\n小图标轮廓(这个图标出现在大图中某个物体上):"
                },
                {"type": "image", "image": str(template_path)},
                {
                    "type": "text",
                    "text": (
                        "\n请仔细观察小图标的形状,在大图中找到形状最相似的物体,"
                        "给出该物体中心点的像素坐标。"
                        "坐标原点在左上角x向右y向下。"
                        "只需回答坐标,格式:(x, y)"
                    )
                }
            ]
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    # Send inputs to wherever the model's weights actually live instead of a
    # hard-coded "cuda", so this stays in sync with load_model's device_map.
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; keep only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output[0].strip()
def _parse_point(answer):
    """Extract the first "(x, y)" coordinate pair from a model answer.

    ASCII and full-width parentheses/commas are both accepted, since the
    prompt is in Chinese and the reply may use full-width punctuation.
    Optional whitespace inside the parentheses is allowed. Decimal values
    are rounded to the nearest integer pixel. Every answer the previous
    ASCII-only, integer-only pattern matched is still matched.

    Returns:
        ``(x, y)`` as ints, or ``None`` if no coordinate pair is found.
    """
    import re

    match = re.search(
        r"[(\uff08]\s*(\d+(?:\.\d+)?)[,\uff0c\s]+(\d+(?:\.\d+)?)\s*[)\uff09]",
        answer,
    )
    if match is None:
        return None
    return round(float(match.group(1))), round(float(match.group(2)))


def main():
    """Locate each template icon in the main captcha image and visualise it.

    For every template ("2.png", "3.png", "4.png" under ``BASE``), the model
    is asked for the icon's centre in "1.jpg". Each parsed point is recorded
    and drawn onto the main image as a filled dot, a ring and a label, with
    one colour per template. A summary is printed and the annotated image is
    written to ``BASE / "result_vl.jpg"``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the main image. This is
            checked before the model is loaded, so a bad path fails fast.
        OSError: If the annotated image cannot be written.
    """
    import cv2

    main_img = BASE / "1.jpg"
    # Each template paired with its drawing colour (BGR tuples, as OpenCV
    # expects); keeping them together avoids two lists drifting apart.
    templates = [
        (BASE / "2.png", (0, 0, 255)),
        (BASE / "3.png", (0, 255, 0)),
        (BASE / "4.png", (255, 0, 0)),
    ]

    # cv2.imread returns None rather than raising on a missing or unreadable
    # file; check it before the expensive model load.
    img = cv2.imread(str(main_img))
    if img is None:
        raise FileNotFoundError(f"无法读取主图: {main_img}")

    model, processor = load_model()

    results = {}
    for tmpl_path, color in templates:
        name = tmpl_path.name
        print(f"\n正在识别 {name} ...")
        answer = ask_one(model, processor, main_img, tmpl_path)
        print(f"{name} 模型回答: {answer}")
        point = _parse_point(answer)
        if point is None:
            print(f"{name}: 未能解析坐标,原始回答: {answer}")
            continue
        results[name] = point
        x, y = point
        cv2.circle(img, (x, y), 8, color, -1)
        cv2.circle(img, (x, y), 12, color, 2)
        cv2.putText(img, name, (x + 14, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    print("\n=== 点击坐标汇总 ===")
    for name, (x, y) in results.items():
        print(f"{name}: ({x}, {y})")

    out = BASE / "result_vl.jpg"
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(str(out), img):
        raise OSError(f"保存可视化结果失败: {out}")
    print(f"\n可视化结果保存到: {out}")


if __name__ == "__main__":
    main()