-
-
Notifications
You must be signed in to change notification settings - Fork 68
Open
Description
Hello, and thank you for providing such valuable code; it has greatly facilitated my research. However, when I tested the model on my own radar dataset, the prediction results were very poor, and I cannot tell where the problem lies. Could you give me some advice?
This is my test plot:
This is my test code:
from dgmr import DGMR
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
def create_synthetic_radar_frames(num_frames=4, size=256, normalize=True):
    """
    Build a synthetic radar sequence showing a rain system drifting across the grid.

    Every frame holds two elliptical rain cells on a zero background: a strong
    core (Gaussian noise around 80, sigma 15) and a weaker trailing cell
    (Gaussian noise around 30, sigma 10). The trailing cell is painted second,
    so it overwrites the core wherever the two overlap. The pair moves 10 px
    to the right and 5 px down per frame, and every frame is clipped to
    [0, 128] (the script labels these intensities as mm/h).

    The ellipse positions are fixed pixel coordinates (the core starts at
    x=80, y=128), i.e. they are laid out for the default ``size=256``.
    Noise comes from NumPy's global random state.

    Args:
        num_frames: number of frames to generate.
        size: height and width of each square frame, in pixels.
        normalize: when true, divide by 128 so values lie in [0, 1];
            otherwise values stay in [0, 128].

    Returns:
        float32 tensor of shape (1, num_frames, 1, size, size).
    """
    # The coordinate grid does not depend on the time step, so build it once.
    rows, cols = np.ogrid[:size, :size]

    def ellipse(cx, cy, rx, ry):
        # Boolean mask of the filled ellipse centred at (cx, cy).
        return ((cols - cx) ** 2 / rx ** 2 + (rows - cy) ** 2 / ry ** 2) <= 1

    sequence = []
    for step in range(num_frames):
        canvas = np.zeros((size, size))

        # Centre of the main cell: drifts right and slightly downwards.
        core_x = 80 + step * 10
        core_y = 128 + step * 5

        # Main rain cell.
        core = ellipse(core_x, core_y, 40, 30)
        canvas[core] = 80 + np.random.randn(int(core.sum())) * 15

        # Secondary, weaker cell offset from the main one.
        tail = ellipse(core_x - 30, core_y + 20, 25, 20)
        canvas[tail] = 30 + np.random.randn(int(tail.sum())) * 10

        # Noise can push values outside the allowed intensity range.
        sequence.append(np.clip(canvas, 0, 128))

    # (num_frames, size, size)
    stacked = np.array(sequence)
    if normalize:
        stacked = stacked / 128.0

    # Insert the batch and channel axes: (1, num_frames, 1, size, size).
    return torch.FloatTensor(stacked)[None, :, None]
def visualize_predictions(input_frames, predictions, save_path="dgmr_prediction_test.png",
                          denormalize=True, max_value=128.0):
    """
    Plot the input frames followed by the first predicted frame and save the figure.

    Panels are laid out three per row: every input frame in time order, then
    the prediction in the next free cell; any remaining cells are hidden.
    With four input frames this gives a 2x3 grid with the prediction in the
    middle of the second row.

    Args:
        input_frames: tensor indexed as (batch, time, channel, height, width).
            Only batch 0 / channel 0 is drawn. The caller in this script
            passes values normalized to [0, 1].
        predictions: tensor with the same layout; only its first time step
            is drawn.
        save_path: destination image file.
        denormalize: when true, multiply both tensors by ``max_value`` before
            plotting and label the figure in mm/h; otherwise plot the raw
            values on a [0, 1] colour scale.
        max_value: scale that was used for normalization. It is also the top
            of the colour scale when ``denormalize`` is true.
    """
    # Drop the batch and channel axes: (time, height, width).
    history = input_frames[0, :, 0, :, :].cpu().numpy()
    forecast = predictions[0, :, 0, :, :].cpu().numpy()

    if denormalize:
        history = history * max_value
        forecast = forecast * max_value
        # The colour scale follows max_value; a fixed 128 would be wrong for
        # data normalized with any other scale.
        vmin, vmax, unit = 0.0, max_value, "mm/h"
    else:
        vmin, vmax, unit = 0.0, 1.0, "normalized"

    n_inputs = history.shape[0]
    n_cols = 3
    # Ceil division: enough rows for every input frame plus one prediction.
    n_rows = -(-(n_inputs + 1) // n_cols)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    cells = axes.ravel()

    def draw_panel(ax, image, title):
        # One image with its own colour bar, no axis ticks.
        im = ax.imshow(image, cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Input frames, oldest first, labelled relative to the forecast start.
    for i in range(n_inputs):
        draw_panel(cells[i], history[i], f'Input t-{n_inputs - i}')

    # The prediction goes into the cell right after the last input frame.
    draw_panel(cells[n_inputs], forecast[0], 'Prediction t+1 (5min)')

    # Hide the cells that hold no image.
    for ax in cells[n_inputs + 1:]:
        ax.axis('off')

    title = f'DGMR Model: Input Frames and Prediction ({unit})'
    plt.suptitle(title, fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ 可视化结果已保存到: {save_path}")
    plt.close(fig)
def main():
    """
    Run the end-to-end demo.

    Steps: build a DGMR model, load weights from ``models/pytorch_model.bin``,
    generate four synthetic input frames, run one forward pass, save a plot of
    the inputs plus the first predicted frame, and print summary statistics.
    Returns early, after printing an error, when the weight file is missing.
    """
    # Single source for the normalization scale used for plotting and stats.
    # create_synthetic_radar_frames divides by the same value internally.
    max_rain_rate = 128.0

    banner = "=" * 60
    print(banner)
    print("DGMR 模型可视化测试")
    print(banner)

    # 1. Build the model.
    print("\n[1/4] 初始化模型...")
    # Author's note: the pretrained checkpoint was trained with
    # forecast_steps=18, so the same configuration must be used here;
    # otherwise the weights do not match and predictions look wrong.
    model = DGMR(
        forecast_steps=18,  # must match the pretrained model
        input_channels=1,
        output_shape=256,
        latent_channels=768,
        context_channels=384,
        num_samples=6,
    )

    # 2. Load the weights.
    print("[2/4] 加载预训练权重...")
    weight_path = Path("models/pytorch_model.bin")
    if not weight_path.exists():
        print(f"❌ 错误: 权重文件不存在: {weight_path}")
        return

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    state_dict = torch.load(weight_path, map_location="cpu")
    # NOTE(review): strict=False silently skips parameters that are absent
    # from the checkpoint; those layers keep their random initialization.
    # A non-zero missing-key count below is therefore worth investigating
    # when predictions look poor.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    model.eval()
    print("✓ 权重加载成功")
    if missing_keys:
        print(f" 警告: {len(missing_keys)} 个权重未加载(可能是正常的,如果模型结构有变化)")
    if unexpected_keys:
        print(f" 警告: {len(unexpected_keys)} 个权重未使用(可能是正常的)")

    # 3. Generate synthetic radar data (normalized to [0, 1]).
    print("[3/4] 生成合成雷达图像...")
    input_frames = create_synthetic_radar_frames(num_frames=4, size=256, normalize=True)
    print(f"✓ 输入形状: {input_frames.shape}")
    print(f" 输入值域: [{input_frames.min():.3f}, {input_frames.max():.3f}] (归一化后)")

    # 4. Run the forecast without tracking gradients.
    print("[4/4] 运行预测...")
    with torch.no_grad():
        # Author's note: expected output shape is (1, 18, 1, 256, 256).
        predictions = model(input_frames)
    print(f"✓ 预测形状: {predictions.shape}")
    print(f" 预测值域: [{predictions.min():.3f}, {predictions.max():.3f}] (归一化后)")

    # Keep only the first forecast step (t+1); the 0:1 slice preserves the time axis.
    predictions_first_frame = predictions[:, 0:1, :, :, :]
    print(f" 提取第一帧形状: {predictions_first_frame.shape}")

    # 5. Plot; visualize_predictions rescales by max_value for display.
    print("\n生成可视化...")
    visualize_predictions(input_frames, predictions_first_frame,
                          denormalize=True, max_value=max_rain_rate)

    # 6. Statistics on the rescaled values. Only the tensors that are actually
    # reported are rescaled.
    input_denorm = input_frames * max_rain_rate
    pred_denorm_first = predictions_first_frame * max_rain_rate

    print("\n" + banner)
    print("测试完成!")
    print(banner)
    print(f"\n输入: {input_frames.shape[1]} 帧历史雷达图像")
    print(f"输出: {predictions.shape[1]} 帧未来预测(仅显示第一帧)")
    print(f"图像尺寸: {predictions.shape[3]}x{predictions.shape[4]}")
    print("\n统计信息(归一化后):")
    print(f" 输入范围: [{input_frames.min():.3f}, {input_frames.max():.3f}]")
    print(f" 所有预测范围: [{predictions.min():.3f}, {predictions.max():.3f}]")
    print(f" 第一帧预测范围: [{predictions_first_frame.min():.3f}, {predictions_first_frame.max():.3f}]")
    print("\n统计信息(反归一化后,mm/h):")
    print(f" 输入范围: [{input_denorm.min():.2f}, {input_denorm.max():.2f}] mm/h")
    print(f" 第一帧预测范围: [{pred_denorm_first.min():.2f}, {pred_denorm_first.max():.2f}] mm/h")
    print(f" 第一帧预测均值: {pred_denorm_first.mean():.2f} mm/h")
    print(f" 第一帧预测标准差: {pred_denorm_first.std():.2f} mm/h")


# Run the demo only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels