-
Notifications
You must be signed in to change notification settings - Fork 224
Open
Description
Thank you for your work!
I am trying to provide a demo that can run on CPU:
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from __future__ import print_function, division
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from torchvision import transforms
from PIL import Image
import os
import yaml
from reIDmodel import ft_netAB
import matplotlib.pyplot as plt
# Command-line options. Parsing happens at module level, so `opt` is
# available to every function defined below.
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument(
    '--which_epoch',
    default=100000,
    type=int,
    help='which epoch to load model',
)
parser.add_argument(
    '--name',
    default='E0.5new_reid0.5_w30000',
    type=str,
    help='model name',
)
parser.add_argument(
    '--img_dirs',
    default='',
    type=str,
    help='directory for input images, separate by comma',
)
opt = parser.parse_args()
# Preprocessing applied to every input image: resize to (256, 128), convert
# to a tensor, then normalise each channel with a fixed mean/std.
# NOTE(review): interpolation=3 is presumably PIL's bicubic filter code and
# the mean/std look like the usual ImageNet statistics -- confirm.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 128), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Horizontal flip
def fliplr(img):
    """Return a copy of ``img`` with its 4th dimension (index 3) reversed.

    The caller passes an N x C x H x W batch, so reversing dimension 3
    mirrors every image left-to-right.
    """
    width = img.size(3)
    # Column indices in descending order: width-1, width-2, ..., 0.
    reversed_cols = torch.arange(width - 1, -1, -1).long()
    return img.index_select(3, reversed_cols)
# Feature normalisation
def norm(x, eps=1e-12):
    """Scale ``x`` by its global L2 norm.

    The norm is taken over *all* elements of ``x`` (no ``dim`` argument), so
    the whole tensor is divided by a single scalar.

    Args:
        x: Tensor to normalise.
        eps: Lower bound applied to the norm before dividing. It keeps an
            all-zero input from producing NaN (0 / 0); inputs whose norm is
            above ``eps`` are unaffected.

    Returns:
        A tensor with the same shape as ``x``.
    """
    global_norm = torch.norm(x, p=2)  # p=2 selects the L2 norm
    return x / torch.clamp(global_norm, min=eps)
# Load model weights
def load_network(network):
    """Load checkpoint weights into ``network`` and return it.

    The checkpoint path is built from the command-line options:
    ``./outputs/<opt.name>/checkpoints/id_<opt.which_epoch, 8 digits>.pt``.
    Tensors are mapped to the CPU while loading, and the weights are read
    from the checkpoint's ``'a'`` entry.
    """
    checkpoint_file = 'checkpoints/id_%08d.pt' % opt.which_epoch
    checkpoint_path = os.path.join('./outputs', opt.name, checkpoint_file)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # strict=False: do not require the checkpoint keys to match the model exactly.
    network.load_state_dict(checkpoint['a'], strict=False)
    return network
# Extract a feature vector from one image
def extract_feature(model, img_path):
    """Compute a flip-augmented feature vector for a single image file.

    The image is loaded as RGB, preprocessed with ``data_transforms`` and run
    through ``model`` twice: once as-is and once horizontally flipped. For
    each pass the first two entries of the model's second output are
    normalised with ``norm`` and concatenated along dim 1; the two passes are
    summed. Finally columns 0:512 and 512:1024 of the sum are normalised
    separately.

    Args:
        model: Callable returning ``(f, x)`` where ``x[0]`` and ``x[1]`` are
            tensors. They are assumed to be 512 wide each so that their
            concatenation matches the 1024-wide accumulator -- TODO confirm
            against the model definition.
        img_path: Path of the image file to read.

    Returns:
        A ``(1, 1024)`` float tensor.
    """
    # The context manager closes the underlying file deterministically;
    # convert() returns a new image that no longer depends on the open file.
    with Image.open(img_path) as raw_img:
        img = data_transforms(raw_img.convert('RGB'))
    img = img.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        batch_size = img.size(0)
        ff = torch.zeros(batch_size, 1024, dtype=torch.float32)
        for i in range(2):
            if i == 1:
                # Second pass runs on the mirrored image.
                img = fliplr(img)
            # Tensors are passed directly; the autograd Variable wrapper is
            # deprecated and no longer needed.
            _, x = model(img)
            # Build the normalised halves without writing back into the
            # model's own output container.
            f = torch.cat((norm(x[0]), norm(x[1])), dim=1)
            ff = ff + f
        # Re-normalise each half of the summed feature independently.
        ff[:, 0:512] = norm(ff[:, 0:512])
        ff[:, 512:1024] = norm(ff[:, 512:1024])
    return ff
def compute_similarity(features_list):
    """Compute the pairwise cosine similarity between feature vectors.

    The tensors in ``features_list`` are concatenated along dim 0 and each
    resulting row is scaled to unit L2 length before the matrix product.
    Without that scaling the product is only a dot product: rows whose norm
    is not 1 would give a diagonal different from 1.0.

    Args:
        features_list: Non-empty list of 2-D tensors with the same number of
            columns.

    Returns:
        A tuple ``(cosine_similarity, features_np)`` where
        ``cosine_similarity`` is the square similarity matrix and
        ``features_np`` is the concatenated, *un-normalised* feature array.
    """
    features_np = torch.cat(features_list, dim=0).numpy()
    print('特征形状: {}'.format(features_np.shape))
    # Normalise rows; the floor avoids 0 / 0 for an all-zero feature row.
    row_norms = np.linalg.norm(features_np, axis=1, keepdims=True)
    unit_rows = features_np / np.maximum(row_norms, 1e-12)
    cosine_similarity = np.matmul(unit_rows, unit_rows.transpose())
    return cosine_similarity, features_np
def _collect_image_paths(img_dirs):
    """Return the sorted ``*.png`` paths found in the comma-separated directories."""
    import glob

    img_paths = []
    for img_dir in img_dirs.split(','):
        img_dir = img_dir.strip()
        if img_dir:
            img_paths.extend(glob.glob(img_dir + '/*.png'))
    img_paths.sort()
    return img_paths


def _save_similarity_heatmap(similarity_matrix, img_paths):
    """Draw the annotated heat map, save it as a PNG, then try to display it."""
    n = len(img_paths)
    labels = [os.path.basename(path) for path in img_paths]

    plt.figure(figsize=(n, n))
    plt.imshow(similarity_matrix, cmap='viridis')
    plt.colorbar(label='cosine similarity')
    # Iterate over the matrix itself so annotations can never index past it.
    for (i, j), value in np.ndenumerate(similarity_matrix):
        plt.text(j, i, f'{value:.2f}',
                 ha='center', va='center',
                 color='white' if value < 0.7 else 'black')
    plt.xticks(range(n), labels, rotation=90)
    plt.yticks(range(n), labels)
    plt.tight_layout()
    plt.savefig('similarity_heatmap.png', dpi=300)
    print("相似度热图已保存为 similarity_heatmap.png")
    try:
        # Show while the figure is still open; closing first leaves nothing
        # to display.
        plt.show()
    except Exception as e:
        print(f'无法显示图像: {str(e)}')
    finally:
        plt.close()


def main():
    """Run the CPU re-ID demo end to end.

    Steps:
      1. Load ``./outputs/<opt.name>/config.yaml``; on any failure fall back
         to a built-in default configuration.
      2. Build ``ft_netAB`` from the configuration, load its weights with
         ``load_network`` and replace the fc / classifier layers with empty
         ``nn.Sequential`` modules.
      3. Collect ``*.png`` files from every directory listed (comma
         separated) in ``--img_dirs``.
      4. Extract a feature per image, skipping images that fail.
      5. Compute the similarity matrix, save it as ``similarity_heatmap.png``
         and print it.

    The function returns early (after printing a message) when the weights
    cannot be loaded, no directory is given, no image is found, or no
    feature could be extracted.
    """
    print('---------- 行人重识别推理Demo (CPU版本) ----------')

    config_path = os.path.join('./outputs', opt.name, 'config.yaml')
    try:
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
        print(f'成功加载配置文件: {config_path}')
    except Exception as e:
        # Deliberate best-effort: any failure falls back to the defaults.
        print(f'加载配置文件失败: {str(e)}')
        print('尝试创建默认配置...')
        config = {
            'ID_class': 751,  # default number of classes for Market-1501
            'norm_id': 1,
            'ID_stride': 2,
            'pool': 'avg',
        }

    print('创建模型...')
    model_structure = ft_netAB(config['ID_class'], norm=config['norm_id'],
                               stride=config['ID_stride'], pool=config['pool'])
    try:
        model = load_network(model_structure)
        print('成功加载模型权重')
    except Exception as e:
        print(f'加载模型权重失败: {str(e)}')
        print('请确保模型文件存在于指定路径')
        return

    # Drop the final fc layer and both classifier heads.
    model.model.fc = nn.Sequential()
    model.classifier1.classifier = nn.Sequential()
    model.classifier2.classifier = nn.Sequential()
    model = model.eval()

    if not opt.img_dirs:
        print('请提供图像目录!使用 --img_dirs 参数,多个目录用逗号分隔')
        return

    img_paths = _collect_image_paths(opt.img_dirs)
    if not img_paths:
        print('未找到图像文件!')
        return
    print(f'找到 {len(img_paths)} 张图像')

    # Keep the paths of successfully processed images alongside their
    # features so matrix rows and axis labels stay aligned when an image fails.
    features_list = []
    processed_paths = []
    for i, img_path in enumerate(img_paths, start=1):
        print(f'处理图像 {i}/{len(img_paths)}: {img_path}')
        try:
            feature = extract_feature(model, img_path)
        except Exception as e:
            print(f'处理图像 {img_path} 失败: {str(e)}')
        else:
            features_list.append(feature)
            processed_paths.append(img_path)

    if not features_list:
        print('没有成功提取任何特征!')
        return

    similarity_matrix, features = compute_similarity(features_list)
    _save_similarity_heatmap(similarity_matrix, processed_paths)

    print('余弦相似度矩阵:')
    print(similarity_matrix)
# Script entry point: run the demo only when executed directly.
if __name__ == '__main__':
    main()
Metadata
Metadata
Assignees
Labels
No labels