Python结合FastSAM实现图像自动标注的完整指南
目录
- 概述
- 环境准备
- 项目结构
- 核心代码解析
- 1. 主类FastSAMAutoLabeler
- 2. 图像处理流程
- 3. 掩码处理与边界框提取
- 4. 对象分类
- 辅助工具类
- 1. 标注可视化器
- 2. 标注验证器
- 使用方法
- 命令行参数
- 批量处理
- 扩展功能
- 1. 自定义过滤规则
- 2. 多类别支持
- 完整代码
- 总结与展望
在计算机视觉领域,数据标注是模型训练的基础,但手动标注耗时耗力。本文将介绍一个基于python的自动标注工具,它结合了FastSAM(快速分割任何东西模型)和YOLO分类模型,能够高效地生成高质量的标注数据。
概述
FastSAM是SAM(Segment Anything Model)的加速版本,能够在保持较高精度的同时大幅提升处理速度。我们的自动标注工具利用FastSAM进行对象检测和分割,再通过YOLO模型对检测到的对象进行分类,最终输出YOLO格式的标注文件。
本文将详细解析代码结构、实现原理和使用方法,帮助读者快速掌握图像自动标注的核心技术。
环境准备
在开始之前,需要安装必要的Python库。建议使用Python 3.7或更高版本。
pip install torch torchvision pip install opencv-python pillow pip install ultralytics pip install numpy tqdm pathlib argparse
确保已下载FastSAM模型文件(如FastSAM-x.pt)和YOLO分类模型(如yolov8n-cls.pt)。
项目结构
fastsam_autolabeler/
├── autolabeler.py # 主程序文件├── models/ # 模型目录│ ├── FastSAM-x.pt│ └── yolov8n-cls.pt├── images/ # 输入图片目录├── dataset/ # 输出标注数据│ ├── images/│ ├── labels/│ └── visualization/└── README.md
核心代码解析
1. 主类FastSAMAutoLabeler
FastSAMAutoLabeler类是自动标注工具的核心,负责协调整个标注流程。
class FastSAMAutoLabeler:
def __init__(self, model_path='FastSAM-x.pt', device='cuda' if torch.cuda.is_available() else 'cpu', classification_model='yolov8n-cls.pt'):
self.device = device
self.model = FastSAM(model_path)
self.classification_model = YOLO(classification_model)
初始化过程会加载两个模型:FastSAM用于对象检测和分割,YOLO用于对象分类。代码自动检测可用的计算设备,优先使用GPU加速处理。
2. 图像处理流程
process_image方法是主要的处理管道,包含以下步骤:
def process_image(self, image_path, output_dir, conf=0.4, iou=0.9, min_area_ratio=0.001, max_area_ratio=0.95):
# 读取图片
image = cv2.imread(image_path)
height, width = image.shape[:2]
# FastSAM推理
everything_results = self.model(image_path, device=self.device, retina_masks=True, imgsz=1024, conf=conf, iou=iou)
# 处理掩码数据
detections = self._process_masks_manually(ann, everything_results, width, height)
# 过滤检测结果
filtered_detections = self._filter_detections(detections, image_area, min_area_ratio, max_area_ratio)
# 对象分类
classified_detections = self._classify_objects(image, filtered_detections)
# 生成标注
return self._generate_annotations(image, classified_detections, output_dir, Path(image_path).stem)
此方法完整实现了从图像读取到标注生成的整个流程,每个步骤都设计了适当的错误处理机制。
3. 掩码处理与边界框提取
_process_masks_manually方法将FastSAM输出的分割掩码转换为边界框:
def _process_masks_manually(self, ann, everything_results, img_width, img_height):
masks_np = ann.cpu().numpy()
boxes = []
for i in range(num_masks):
mask = masks_np[i]
y_indices, x_indices = np.where(mask > 0.5) # 阈值处理
# 计算边界框
x1 = np.min(x_indices)
y1 = np.min(y_indices)
x2 = np.max(x_indices)
y2 = np.max(y_indices)
boxes.append([x1, y1, x2, y2])
这种方法不依赖额外的计算机视觉库,实现了自包含的掩码处理功能。
4. 对象分类
_classify_objects方法对每个检测到的对象进行分类:
def _classify_objects(self, image, detections):
for i, bbox in enumerate(detections['boxes']):
x1, y1, x2, y2 = map(int, bbox)
object_image = image[y1:y2, x1:x2]
object_image_resized = cv2.resize(object_image, (224, 224))
# 使用YOLO分类模型
results = self.classification_model(object_image_resized)
top1 = results[0].probs.top1
top1conf = results[0].probs.top1conf.item()
通过结合实例分割和分类模型,工具能够准确识别和分类图像中的各个对象。
辅助工具类
1. 标注可视化器
ManualAnnotationVisualizer类提供标注结果的可视化功能:
class ManualAnnotationVisualizer:
def draw_annotations(self, image_path, label_path=None, detections=None, output_path=None):
# 绘制边界框和标签
for i, bbox in enumerate(detections['boxes']):
color = self.colors[i % len(self.colors)]
cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
可视化结果使用不同颜色区分各个检测对象,并显示类别标签和置信度。
2. 标注验证器
AnnotationValidator类检查标注文件的质量:
class AnnotationValidator:
def validate_annotations(self, image_path, label_path):
# 检查数值范围、边界框有效性、重叠等
issues = []
for i, line in enumerate(lines):
# 验证每个标注行的格式和数值
if not (0 <= x_center <= 1):
issues.append(f"第{i+1}行x_center超出范围 [0,1]: {x_center}")
验证器帮助用户发现标注中的问题,确保生成的数据集质量。
使用方法
命令行参数
工具支持丰富的命令行参数,满足不同场景的需求:
python autolabeler.py \
--input images/ \
--output dataset/ \
--model models/FastSAM-x.pt \
--classification-model models/yolov8n-cls.pt \
--conf 0.4 \
--iou 0.9 \
--visualize \
--validate
主要参数包括:
--input: 输入图片路径(文件或目录)--output: 输出目录--conf: 检测置信度阈值--iou: 非极大值抑制IOU阈值--visualize: 生成可视化结果--validate: 验证标注质量
批量处理
工具支持单张图片和批量处理模式。当输入为目录时,会自动遍历所有支持格式的图片文件:
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
for img_path in tqdm(image_files, desc="处理图片"):
result = labeler.process_image(str(img_path), args.output, args.conf, args.iou)
批量处理时显示进度条,方便用户了解处理进度。
扩展功能
1. 自定义过滤规则
用户可以调整检测结果的过滤条件,如基于对象面积的比例:
min_area_ratio=0.001 # 最小面积比例(相对于图像面积) max_area_ratio=0.95 # 最大面积比例
这有助于过滤掉过小或过大的检测结果,提高标注质量。
2. 多类别支持
通过替换YOLO分类模型,工具可以适应不同的领域和类别需求。例如,使用针对特定场景训练的专用分类器。
完整代码
import torch
import cv2
import numpy as np
from PIL import Image
from ultralytics import FastSAM, YOLO
import os
import json
from pathlib import Path
import argparse
import glob
from tqdm import tqdm
import math
class FastSAMAutoLabeler:
def __init__(self, model_path='FastSAM-x.pt', device='cuda' if torch.cuda.is_available() else 'cpu', classification_model='yolov8n-cls.pt'):
"""
FastSAM自动标注工具(不依赖supervision)
"""
self.device = device
self.model = FastSAM(model_path)
# 加载分类模型
self.classification_model = YOLO(classification_model)
print(f"模型加载完成,使用设备: {device}")
def process_image(self, image_path, output_dir, conf=0.4, iou=0.9,
min_area_ratio=0.001, max_area_ratio=0.95):
"""
处理单张图片并生成标注
"""
try:
# 读取图片
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法读取图片: {image_path}")
height, width = image.shape[:2]
image_area = height * width
print(f"图片尺寸: {width}x{height}, 面积: {image_area}")
# 使用FastSAM进行推理
everything_results = self.model(
image_path,
device=self.device,
retina_masks=True,
imgsz=1024,
conf=conf,
iou=iou
)
# 检查是否有检测结果
if not everything_results or len(everything_results) == 0:
print("警告: 未获得任何检测结果")
return self._create_empty_result(image, output_dir, Path(image_path).stem)
# 获取掩码数据
masks_data = everything_results[0].masks
if masks_data is None:
print("警告: 未检测到任何掩码")
return self._create_empty_result(image, output_dir, Path(image_path).stem)
ann = masks_data.data
print(f"原始掩码形状: {ann.shape}")
# 处理掩码维度
if len(ann.shape) == 2:
ann = ann.unsqueeze(0)
elif len(ann.shape) != 3:
raise ValueError(f"不支持的掩码形状: {ann.shape}")
# 手动处理检测结果
detections = self._process_masks_manually(ann, everything_results, width, height)
# 过滤检测结果
filtered_detections = self._filter_detections(
detections, image_area, min_area_ratio, max_area_ratio
)
# 对每个检测到的对象进行分类
classified_detections = self._classify_objects(image, filtered_detections)
return self._generate_annotations(image, classified_detections, output_dir, Path(image_path).stem)
except Exception as e:
print(f"处理图片时发生错误: {e}")
import traceback
traceback.print_exc()
return self._create_error_result(image, output_dir, Path(image_path).stem, str(e))
def _process_masks_manually(self, ann, everything_results, img_width, img_height):
"""
手动处理掩码数据,替代supervision的功能
"""
# 将张量转换为numpy数组
if isinstance(ann, torch.Tensor):
masks_np = ann.cpu().numpy()
else:
masks_np = ann
print(f"掩码numpy数组形状: {masks_np.shape}")
# 确保是三维的 [N, H, W]
if len(masks_np.shape) == 2:
masks_np = np.expand_dims(masks_np, axis=0)
num_masks = masks_np.shape[0]
print(f"检测到 {num_masks} 个掩码")
if num_masks == 0:
return self._create_empty_detections()
# 为每个掩码计算边界框和相关信息
boxes = []
confidences = []
class_ids = []
masks = []
for i in range(num_masks):
mask = masks_np[i]
# 找到掩码中为True的像素位置
y_indices, x_indices = np.where(mask > 0.5) # 阈值处理
if len(x_indices) == 0 or len(y_indices) == 0:
continue
# 计算边界框
x1 = np.min(x_indices)
y1 = np.min(y_indices)
x2 = np.max(x_indices)
y2 = np.max(y_indices)
# 计算面积和置信度(使用掩码面积作为置信度参考)
bbox_area = (x2 - x1) * (y2 - y1)
mask_area = len(x_indices)
confidence = min(mask_area / bbox_area, 1.0) if bbox_area > 0 else 0
boxes.append([x1, y1, x2, y2])
confidences.append(confidence)
class_ids.append(0) # 默认类别ID
masks.append(mask)
if not boxes:
return self._create_empty_detections()
return {
'boxes': np.array(boxes),
'confidences': np.array(confidences),
'class_ids': np.array(class_ids),
'masks': np.array(masks)
}
def _filter_detections(self, detections, image_area, min_area_ratio, max_area_ratio):
"""根据面积过滤检测结果"""
if len(detections['boxes']) == 0:
return detections
filtered_boxes = []
filtered_confidences = []
filtered_class_ids = []
filtered_masks = []
for i, bbox in enumerate(detections['boxes']):
x1, y1, x2, y2 = bbox
area = (x2 - x1) * (y2 - y1)
area_ratio = area / image_area
if min_area_ratio <= area_ratio <= max_area_ratio:
filtered_boxes.append(bbox)
filtered_confidences.append(detections['confidences'][i])
filtered_class_ids.append(detections['class_ids'][i])
if i < len(detections['masks']):
filtered_masks.append(detections['masks'][i])
filtered_detections = {
'boxes': np.array(filtered_boxes) if filtered_boxes else np.empty((0, 4)),
'confidences': np.array(filtered_confidences) if filtered_confidences else np.empty(0),
'class_ids': np.array(filtered_class_ids) if filtered_class_ids else np.empty(0),
'masks': np.array(filtered_masks) if filtered_masks else np.empty(0)
}
print(f"过滤后保留 {len(filtered_boxes)} 个检测结果")
return filtered_detections
def _classify_objects(self, image, detections):
"""
对检测到的对象进行分类
"""
if len(detections['boxes']) == 0:
return detections
classified_class_ids = []
classified_confidences = []
for i, bbox in enumerate(detections['boxes']):
x1, y1, x2, y2 = map(int, bbox)
# 提取对象区域
object_image = image[y1:y2, x1:x2]
if object_image.size == 0:
classified_class_ids.append(0)
classified_confidences.append(detections['confidences'][i])
continue
# 调整图像大小以适应分类模型
object_image_resized = cv2.resize(object_image, (224, 224))
# 使用分类模型进行预测
try:
results = self.classification_model(object_image_resized)
# 获取最高置信度的类别
top1 = results[0].probs.top1
top1conf = results[0].probs.top1conf.item()
classified_class_ids.append(top1)
classified_confidences.append(top1conf)
except Exception as e:
print(f"分类时出错: {e}")
# 如果分类失败,保持原始类别
classified_class_ids.append(detections['class_ids'][i])
classified_confidences.append(detections['confidences'][i])
# 更新检测结果
detections['class_ids'] = np.array(classified_class_ids)
detections['confidences'] = np.array(classified_confidences)
return detections
def _generate_annotations(self, image, detections, output_dir, image_name):
"""生成YOLO格式标注文件"""
height, width = image.shape[:2]
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
os.makedirs(os.path.join(output_dir, 'labels'), exist_ok=True)
# 保存图片
image_output_path = os.path.join(output_dir, 'images', f'{image_name}.jpg')
cv2.imwrite(image_output_path, image)
# 生成YOLO格式标注
yolo_annotations = []
for i, bbox in enumerate(detections['boxes']):
x1, y1, x2, y2 = bbox
# 转换为YOLO格式 (中心点坐标和宽高,归一化)
x_center = ((x1 + x2) / 2) / width
y_center = ((y1 + y2) / 2) / height
w = (x2 - x1) / width
h = (y2 - y1) / height
# 边界检查
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
w = max(0, min(1, w))
h = max(0, min(1, h))
# 如果宽高太小则跳过
if w < 0.001 or h < 0.001:
continue
# 获取类别ID和置信度
class_id = int(detections['class_ids'][i]) if i < len(detections['class_ids']) else 0
confidence = detections['confidences'][i] if i < len(detections['confidences']) else 1.0
yolo_annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")
# 保存YOLO标签
label_path = os.path.join(output_dir, 'labels', f'{image_name}.txt')
with open(label_path, 'w') as f:
f.write('\n'.join(yolo_annotations))
return {
'image_path': image_output_path,
'label_path': label_path,
'detections_count': len(yolo_annotations),
'image_name': image_name
}
def _create_empty_detections(self):
"""创建空的检测结果"""
return {
'boxes': np.empty((0, 4)),
'confidences': np.empty(0),
'class_ids': np.empty(0),
'masks': np.empty(0)
}
def _create_empty_result(self, image, output_dir, image_name):
"""创建空结果"""
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
os.makedirs(os.path.join(output_dir, 'labels'), exist_ok=True)
image_output_path = os.path.join(output_dir, 'images', f'{image_name}.jpg')
cv2.imwrite(image_output_path, image)
label_path = os.path.join(output_dir, 'labels', f'{image_name}.txt')
with open(label_path, 'w') as f:
pass
return {
'image_path': image_output_path,
'label_path': label_path,
编程客栈'detections_count': 0,
'image_name': image_name
}
def _create_error_result(self, image, output_dir, image_name, error_msg):
"""创建错误结果"""
print(f"为图片 {image_name} 创建错误结果: {error_msg}")
return self._create_empty_result(image, output_dir, image_name)
class ManualAnnotationVisualizer:
"""
手动实现的标注可视化工具(不依赖supervision)
"""
def __init__(self, class_names=None, colors=None):
self.class_names = class_names or ['object']
self.colors = colors or self._generate_default_colors()
def _generate_default_colors(self):
"""生成默认颜色列表"""
return [
(255, 0, 0), # 红色
(0, 255, 0), # 绿色
(0, 0, 255), # 蓝色
(255, 255, 0), # 青色
(255, 0, 255), # 紫色
(0, 255, 255), # 黄色
(255, 165, 0), # 橙色
(128, 0, 128), # 紫色
(255, 192, 203), # 粉色
(165, 42, 42), # 棕色
]
def draw_annotations(self, image_path, label_path=None, detections=None,
output_path=None, show_labels=True, show_confidences=True):
"""
绘制标注结果
"""
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法读取图片: {image_path}")
height, width = image.shape[:2]
if detections is None and label_path is not None:
# 从YOLO标签文件读取检测结果
detections = self._read_yolo_labels(label_path, width, height)
elif detections is None:
raise ValueError("必须提供label_path或detections参数")
# 绘制边界框和标签
annotated_image = self._draw_bounding_boxes(image, detections, show_labels, show_confidences)
if output_path:
cv2.imwrite(output_path, annotated_image)
print(f"可视化结果已保存: {output_path}")
return annotated_image
def _draw_bounding_boxes(self, image, detections, show_labels, show_confidences):
"""绘制边界框和标签"""
annotated_image = image.copy()
for i, bbox in enumerate(detections['boxes']):
x1, y1, x2, y2 = map(int, bbox)
# 选择颜色
color = self.colors[i % len(self.colors)]
# 绘制边界框
cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
if show_labels:
# 准备标签文本
class_id = int(detections['class_ids'][i]) if i < len(detections['class_ids']) else 0
class_name = self.class_names[class_id] if class_id < len(self.class_names) else f'class_{class_id}'
label = class_name
if show_confidences and i < len(detections['confidences']):
confidence = detections['confidences'][i]
label += f" {confidence:.2f}"
# 绘制标签背景
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
cv2.rectangle(annotated_image,
(x1, y1 - label_size[1] - 10),
(x1 + label_size[0], y1),
color, -1)
# 绘制标签文本
cv2.putText(annotated_image, label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
return annotated_image
def _read_yolo_labels(self, label_path, img_width, img_height):
"""读取YOLO格式标签并转换为检测结果格式"""
boxes = []
class_ids = []
confidences = []
if not os.path.exists(label_path):
return self._create_empty_detections()
with open(label_path, 'r') as f:
for label_line in f:
parts = label_line.strip().split()
if len(parts) < 5:
continue
class_id = int(parts[0])
x_center, y_center, width, height = map(float, parts[1:5])
# 转换为绝对坐标
x_center_abs = x_center * img_width
y_center_abs = y_center * img_height
width_abs = width * img_width
height_abs = height * img_height
x1 = max(0, x_center_abs - width_abs / 2)
y1 = max(0, y_center_abs - height_abs / 2)
x2 = min(img_width, x_center_abs + width_abs / 2)
y2 = min(img_height, y_center_abs + height_abs / 2)
boxes.append([x1, y1, x2, y2])
class_ids.append(class_id)
confidences.append(1.0) # YOLO格式没有置信度,设为1.0
return {
'boxes': np.array(boxes) if boxes else np.empty((0, 4)),
'confidences': np.array(confidences) if confidences else np.empty(0),
'class_ids': np.array(class_ids) if class_ids else np.empty(0),
'masks': np.empty(0)
}
def _create_empty_detections(self):
"""创建空的检测结果"""
return {
'boxes': np.empty((0, 4)),
'confidences': np.empty(0),
'class_ids': np.empty(0),
'masks': np.empty(0)
}
class AnnotationValidator:
"""
标注验证器,用于检查标注质量和提供统计信息
"""
def __init__(self, class_names=None):
self.class_names = class_names or ['object']
def validate_annotations(self, image_path, label_path):
"""
验证标注文件的质量
"""
# 读取图像
image = cv2.imread(image_path)
if image is None:
return {"error": f"无法读取图像: {image_path}"}
height, width = image.shape[:2]
# 读取标注文件
if not os.path.exists(label_path):
return {"error": f"标注文件不存在: {label_path}"}
with open(label_path, 'r') as f:
lines = f.readlines()
if not lines:
return {"warning": "标注文件为空"}
issues = []
class_counts = {}
boxes = []
for i, line in enumerate(lines):
parts = line.strip().split()
if len(parts) < 5:
issues.append(f"第{i+1}行格式错误: 需要至少5个值,实际得到{len(parts)}个")
continue
try:
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
w = float(parts[3])
h = float(parts[4])
# 检查数值范围
if not (0 <= x_center <= 1):
issues.append(f"第{i+1}行x_center超出范围 [0,1]: {x_center}")
if not (0 <= y_center <= 1):
issues.append(f"第{i+1}行y_center超出范围 [0,1]: {y_center}")
if not (0 <= w <= 1):
issues.append(f"第{i+1}行width超出范围 [0,1]: {w}")
if not (0 <= h <= 1):
issues.append(f"第{i+1}行height超出范围 [0,1]: {h}")
编程客栈
# 检查边界框是否有效
if w <= 0 or h <= 0:
issues.append(f"第{i+1}行边界框尺寸无效: width={w}, height={h}")
# 统计类别
class_counts[class_id] = class_counts.get(class_id, 0) + 1
# 转换为像素坐标用于重叠检查
x1 = max(0, (x_center - w/2) * width)
y1 = max(0, (y_center - h/2) * height)
x2 = min(width, (x_center + w/2) * width)
y2 = min(height, (y_center + h/2) * height)
boxes.append((x1, y1, x2, y2))
except ValueError as e:
issues.append(f"第{i+1}行数值转换错误: {str(e)}")
# 检查重叠的边界框
overlapping_boxes = self._check_overlapping_boxes(boxes)
if overlapping_boxes:
issues.append(f"发现{len(overlapping_boxes)}对重叠的边界框")
# 生成报告
report = {
"total_objects": len(lines),
"class_distribution": class_counts,
"issues": issues,
"image_size": (width, height)
}
if class_counts:
# 添加类别名称映射
class_names_mapping = {}
for class_id in class_counts:
if class_id < len(self.class_names):
class_names_mapping[class_id] = self.class_names[class_id]
else:
class_names_mapping[class_id] = f"未知类别_{class_id}"
report["class_names"] = class_names_mapping
return report
def _check_overlapping_boxes(self, boxes, overlap_threshold=0.5):
"""
检查重叠的边界框
"""
overlapping = []
for i in range(len(boxes)):
for j in range(i+1, len(boxes)):
x1_a, y1_a, x2_a, y2_a = boxes[i]
x1_b, y1_b, x2_b, y2_b = bopythonxes[j]
# 计算交集
x_left = max(x1_a, x1_b)
y_top = max(y1_a, y1_b)
x_right = min(x2_a, x2_b)
y_bottom = min(y2_a, y2_b)
if x_right > x_left and y_bottom > y_top:
# 计算交集面积
intersection_area = (x_right - x_left) * (y_bottom - y_top)
# 计算两个框的面积
area_a = (x2_a - x1_a) * (y2_a - y1_a)
area_b = (x2_b - x1_b) * (y2_b - y1_b)
# 计算重叠率
overlap = intersection_area / min(area_a, area_b)
if overlap > overlap_threshold:
overlapping.append((i, j, overlap))
return overlapping
def main():
"""主函数示例"""
parser = argparse.ArgumentParser(description='FastSAM自动标注工具(无依赖版)')
parser.add_argument('--input', type=str,default="images", help='输入图片目录或文件路径')
parser.add_argument('--output', type=str, default="dataset", help='输出目录')
parser.add_argument('--model', type=str, default='FastSAM-x.pt', help='FastSAM模型路径')
编程客栈 parser.add_argument('--classification-model', type=str, default='yolov8x-cls.pt', help='分类模型路径')
parser.add_argument('--conf', type=float, default=0.4, help='置信度阈值')
parser.add_argument('--iou', type=float, default=0.9, help='IOU阈值')
parser.add_argument('--min-area', type=float, default=0.001, help='最小面积比例')
parser.add_argument('--max-area', type=float, default=0.95, help='最大面积比例')
parser.add_argument('--visualize', action='store_true', help='是否生成可视化结果')
parser.add_argument('--validate', action='store_true', help='是否验证标注结果')
args = parser.parse_args()
# 创建输出目录
os.makedirs(args.output, exist_ok=True)
# 初始化标注器
labeler = FastSAMAutoLabeler(args.model, classification_model=args.classification_model)
# 处理输入
if os.path.isfile(args.input):
# 单文件处理
result = labeler.process_image(
args.input, args.output, args.conf, args.iou, args.min_area, args.max_area
)
print(f"处理完成: {result}")
if args.visualize:
# 获取分类模型的类别名称
class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
visualizer = ManualAnnotationVisualizer(class_names=class_names)
vis_path = os.path.join(args.output, 'visualization', f"{Path(args.input).stem}_annotated.jpg")
os.makedirs(os.path.dirname(vis_path), exist_ok=True)
visualizer.draw_annotations(
result['image_path'], result['label_path'], output_path=vis_path
)
# 如果需要验证,执行验证
if args.validate:
class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
validator = AnnotationValidator(class_names=class_names)
validation_report = validator.validate_annotations(result['image_path'], result['label_path'])
print("\n标注验证报告:")
print("=" * 50)
if "error" in validation_report:
print(f"错误: {validation_report['error']}")
elif "warning" in validation_report:
print(f"警告: {validation_report['warning']}")
else:
print(f"图像尺寸: {validation_report['image_size'][0]}x{validation_report['image_size'][1]}")
print(f"总对象数: {validation_report['total_objects']}")
print("\n类别分布:")
for class_id, count in validation_report['class_distribution'].items():
class_name = validation_report['class_names'].get(class_id, f"未知类别_{class_id}")
print(f" {class_name} ({class_id}): {count}个")
if validation_report['issues']:
print(f"\n发现问题 ({len(validation_report['issues'])}个):")
for issue in validation_report['issues']:
print(f" - {issue}")
else:
print("\n标注质量良好,未发现问题。")
else:
# 目录处理
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
image_files = []
for ext in image_extensions:
pattern = f'**/*{ext}'
image_files.extend(Path(args.input).glob(pattern))
image_files.extend(Path(args.input).glob(pattern.upper()))
image_files = list(set(image_files))
print(f"找到 {len(image_files)} 张图片")
results = []
successful = 0
failed = 0
for img_path in tqdm(image_files, desc="处理图片"):
try:
result = labeler.process_image(
str(img_path), args.output, args.conf, args.iou, args.min_area, args.max_area
)
results.append(result)
successful += 1
print(f"✓ 成功处理: {img_path.name} (检测到 {result['detections_count']} 个对象)")
except Exception as e:
failed += 1
print(f"✗ 处理失败: {img_path.name} - 错误: {e}")
print(f"\n处理完成!")
print(f"成功: {successful}, 失败: {failed}")
if args.visualize and results:
# 为前几张图片生成可视化结果
# 获取分类模型的类别名称
class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
visualizer = ManualAnnotationVisualizer(class_names=class_names)
vis_dir = os.path.join(args.output, 'visualization')
os.makedirs(vis_dir, exist_ok=True)
sample_count = min(5, len(results))
print(f"\n为前 {sample_count} 张图片生成可视化结果...")
for i, result in enumerate(results[:sample_count]):
vis_path = os.path.join(vis_dir, f"{result['image_name']}_annotated.jpg")
visualizer.draw_annotations(
result['image_path'], result['label_path'], output_path=vis_path
)
# 如果需要验证,执行验证
if args.validate and results:
print("\n开始验证标注结果...")
class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
validator = AnnotationValidator(class_names=class_names)
total_issues = 0
for result in results:
validation_report = validator.validate_annotations(result['image_path'], result['label_path'])
if "issues" in validation_report and validation_report["issues"]:
total_issues += len(validation_report["issues"])
print(f"\n{result['image_name']}发现问题:")
for issue in validation_report["issues"]:
print(f" - {issue}")
if total_issues == 0:
print("所有标注文件验证通过,未发现问题。")
else:
print(f"\n总共发现 {total_issues} 个问题。")
if __name__ == "__main__":
main()
总结与展望
本文介绍的FastSAM自动标注工具展示了如何将先进的计算机视觉模型应用于实际数据标注任务。其主要优势包括:
- 高效性:结合FastSAM的快速分割和YOLO的准确分类,大幅提升标注效率
- 灵活性:支持参数调整和自定义过滤规则,适应不同场景需求
- 质量保证:内置验证和可视化功能,确保标注数据质量
- 易用性:简单的命令行接口,支持批量处理
未来可能的改进方向包括:
- 支持更多标注格式(如COCO、Pascal VOC)
- 添加交互式修正界面
- 集成主动学习策略,优先标注不确定性高的样本
- 优化模型推理速度,支持实时标注
这个工具不仅适用于学术研究,也可用于工业界的实际项目,为计算机视觉模型训练提供高质量的数据支持。通过本文的详http://www.devze.com细解析,读者可以深入了解实现原理,并根据自身需求进行定制化开发。
希望本指南能帮助您更高效地处理图像标注任务,欢迎在实践中进一步探索和优化这个工具。
以上就是Python结合FastSAM实现图像自动标注的完整指南的详细内容,更多关于Python FastSAM图像自动标注的资料请关注编程客栈(www.devze.com)其它相关文章!
加载中,请稍侯......
精彩评论