haha
This commit is contained in:
131
captcha_vl.py
Normal file
131
captcha_vl.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
用 Qwen2-VL 本地模型识别验证码图片位置
|
||||
首次运行会自动下载模型(约 15GB)
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
BASE = Path(__file__).parent / "images"
|
||||
|
||||
def img_to_base64(path):
    """Return the contents of the file at *path* as a base64 ASCII string."""
    with open(path, "rb") as handle:
        payload = handle.read()
    return base64.b64encode(payload).decode()
|
||||
|
||||
def load_model():
    """Load the 4-bit quantized Qwen2-VL model and its matching processor.

    Returns:
        A ``(model, processor)`` pair ready for inference on CUDA.
    """
    print("加载模型中(首次运行会下载约15GB)...")
    from transformers import BitsAndBytesConfig

    # NF4 quantization with double quantization keeps VRAM usage low enough
    # for the 7B model on a consumer GPU.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_cfg,
        device_map="cuda",
    )
    vl_processor = AutoProcessor.from_pretrained(MODEL_NAME)

    print("模型加载完成")
    return vl_model, vl_processor
|
||||
|
||||
def ask_one(model, processor, main_img_path, template_path):
    """Ask the VL model where one template icon appears in the main image.

    Args:
        model: loaded Qwen2-VL generation model.
        processor: the matching ``AutoProcessor``.
        main_img_path: path to the 300x200 background image.
        template_path: path to the small icon template to locate.

    Returns:
        The model's raw text answer (expected format: ``(x, y)``).
    """
    intro = (
        "下面是一张背景大图(300x200像素),"
        "以及一个需要在大图中找到的小图标轮廓。\n"
        "大图:"
    )
    bridge = "\n小图标轮廓(这个图标出现在大图中某个物体上):"
    instruction = (
        "\n请仔细观察小图标的形状,在大图中找到形状最相似的物体,"
        "给出该物体中心点的像素坐标。"
        "坐标原点在左上角,x向右,y向下。"
        "只需回答坐标,格式:(x, y)"
    )
    content = [
        {"type": "text", "text": intro},
        {"type": "image", "image": str(main_img_path)},
        {"type": "text", "text": bridge},
        {"type": "image", "image": str(template_path)},
        {"type": "text", "text": instruction},
    ]
    messages = [{"role": "user", "content": content}]

    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # A coordinate answer is short; 50 new tokens is plenty.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=50)

    # Drop the echoed prompt tokens, keeping only the newly generated tail.
    trimmed = [
        full[len(prompt):]
        for prompt, full in zip(inputs.input_ids, generated_ids)
    ]
    decoded = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0].strip()
|
||||
|
||||
|
||||
def _parse_xy(answer):
    """Extract an ``(x, y)`` integer pair from the model's free-form answer.

    The original regex required ASCII parentheses, but Chinese VL models
    frequently answer with full-width punctuation ("(123,45)") or with
    bare numbers ("x=123, y=45"), so accept those too.

    Returns:
        ``(x, y)`` as ints, or ``None`` when no coordinate pair is found.
    """
    import re

    # Preferred: an explicit "(x, y)" pair, tolerating full-width parens/commas.
    match = re.search(r"[((]\s*(\d+)\s*[,,\s]+(\d+)\s*[))]", answer)
    if match is None:
        # Fallback: the first two integers anywhere in the text.
        match = re.search(r"(\d+)\D+(\d+)", answer)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def main():
    """Locate each template icon in the captcha image and save a visualization."""
    main_img = BASE / "1.jpg"
    templates = [BASE / "2.png", BASE / "3.png", BASE / "4.png"]

    model, processor = load_model()

    import cv2

    img = cv2.imread(str(main_img))
    if img is None:
        # Fail fast with a clear message instead of crashing inside cv2 later.
        print(f"无法读取主图: {main_img}")
        return

    # BGR colors used to mark each template's predicted position.
    colors = {"2.png": (0, 0, 255), "3.png": (0, 255, 0), "4.png": (255, 0, 0)}
    results = {}

    for tmpl_path in templates:
        name = tmpl_path.name
        print(f"\n正在识别 {name} ...")
        answer = ask_one(model, processor, main_img, tmpl_path)
        print(f"{name} 模型回答: {answer}")

        coords = _parse_xy(answer)
        if coords is not None:
            x, y = coords
            results[name] = (x, y)
            color = colors[name]
            cv2.circle(img, (x, y), 8, color, -1)   # filled center dot
            cv2.circle(img, (x, y), 12, color, 2)   # outline ring
            cv2.putText(img, name, (x + 14, y + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        else:
            print(f"{name}: 未能解析坐标,原始回答: {answer}")

    print("\n=== 点击坐标汇总 ===")
    for name, (x, y) in results.items():
        print(f"{name}: ({x}, {y})")

    out = BASE / "result_vl.jpg"
    cv2.imwrite(str(out), img)
    print(f"\n可视化结果保存到: {out}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user