Skip to content

Using pre-trained weights, but the prediction results are poor #86

@2681704096

Description

@2681704096

Hello, thank you for providing such valuable code, which has greatly facilitated my research. However, when I tested the model on my own radar dataset, the prediction results were very poor, and I cannot tell where the problem lies. Could you give me some advice?

This is my test plot:

Image

This is my test code:

from dgmr import DGMR
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

def create_synthetic_radar_frames(num_frames=4, size=256, normalize=True):
    """
    Build a synthetic radar sequence showing a drifting precipitation system.

    Every frame holds two noisy elliptical rain cells: a strong core and a
    weaker trailing cell. Both shift 10 px right and 5 px down per time
    step. Intensities are drawn from the global ``np.random`` state and
    clipped to [0, 128] (treated as mm/h by the original author).

    Args:
        num_frames: Number of frames to generate.
        size: Side length of each square frame, in pixels.
        normalize: If True, divide the frames by 128 so values lie in [0, 1].

    Returns:
        A float32 tensor of shape (1, num_frames, 1, size, size).
    """
    # Pixel coordinate grids are the same for every frame, so build them once.
    grid_y, grid_x = np.ogrid[:size, :size]

    def _inside_ellipse(cx, cy, radius_x, radius_y):
        # Boolean mask of the pixels inside an axis-aligned ellipse.
        return ((grid_x - cx)**2 / radius_x**2 + (grid_y - cy)**2 / radius_y**2) <= 1

    def _render(step):
        # Draw one frame; the core is painted first so the trailing cell
        # overwrites it wherever the two ellipses overlap.
        canvas = np.zeros((size, size))
        core_x = 80 + step * 10
        core_y = 128 + step * 5

        core = _inside_ellipse(core_x, core_y, 40, 30)
        canvas[core] = 80 + np.random.randn(int(core.sum())) * 15

        trailing = _inside_ellipse(core_x - 30, core_y + 20, 25, 20)
        canvas[trailing] = 30 + np.random.randn(int(trailing.sum())) * 10

        # Noise can push values outside the valid intensity range.
        return np.clip(canvas, 0, 128)

    # (num_frames, size, size)
    stacked = np.array([_render(step) for step in range(num_frames)])
    if normalize:
        stacked = stacked / 128.0

    # Insert the batch axis in front and the channel axis after time.
    return torch.from_numpy(stacked).float()[None, :, None]

def visualize_predictions(input_frames, predictions, save_path="dgmr_prediction_test.png", 
                         denormalize=True, max_value=128.0):
    """
    Plot the input frames next to one predicted frame and save the figure.

    The figure is a 2x3 grid: the first four input frames fill the top row
    and the first cell of the bottom row, and the first predicted frame is
    drawn in the middle of the bottom row.

    Args:
        input_frames: 5-D tensor indexed as (batch, time, channel, H, W).
            Only batch 0 / channel 0 is drawn, and at least four time steps
            are required.
        predictions: 5-D tensor with the same axis layout; only the first
            time step of batch 0 / channel 0 is drawn.
        save_path: Destination path of the saved image.
        denormalize: If True, multiply the data by ``max_value`` before
            plotting and label the figure in mm/h.
        max_value: Scale factor that was used for normalization. It is also
            the upper bound of the colour scale when ``denormalize`` is True.
    """
    # Drop the batch and channel axes and move to numpy for matplotlib.
    input_np = input_frames[0, :, 0, :, :].cpu().numpy()
    pred_np = predictions[0, :, 0, :, :].cpu().numpy()

    vmin = 0.0
    if denormalize:
        input_np = input_np * max_value
        pred_np = pred_np * max_value
        # The colour scale must follow the scale factor actually applied;
        # a fixed 128 would mis-scale the plot for any other max_value.
        vmax = max_value
        unit = "mm/h"
    else:
        vmax = 1.0
        unit = "normalized"

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    try:
        # Four input frames, filled row by row.
        for i in range(4):
            ax = axes[i // 3, i % 3]
            im = ax.imshow(input_np[i], cmap='viridis', vmin=vmin, vmax=vmax)
            ax.set_title(f'Input t-{4-i}', fontsize=12, fontweight='bold')
            ax.axis('off')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Predicted frame in the middle of the second row.
        ax = axes[1, 1]
        im = ax.imshow(pred_np[0], cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_title('Prediction t+1 (5min)', fontsize=12, fontweight='bold')
        ax.axis('off')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Hide the frame of the one unused cell.
        axes[1, 2].axis('off')

        title = f'DGMR Model: Input Frames and Prediction ({unit})'
        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.98)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure, even if plotting or saving failed,
        # so figures do not accumulate and no unrelated figure is closed.
        plt.close(fig)

    print(f"✓ 可视化结果已保存到: {save_path}")

def main():
    """
    Run an end-to-end smoke test of DGMR with pre-trained weights.

    Steps: build the model, load weights from ``models/pytorch_model.bin``,
    generate a synthetic radar sequence, run one forward pass, save a plot
    of the first predicted frame, and print summary statistics. Returns
    early (after printing an error) if the weight file does not exist.
    """
    # Scale between normalized [0, 1] values and the 0-128 intensity range
    # used by create_synthetic_radar_frames.
    max_value = 128.0

    print("=" * 60)
    print("DGMR 模型可视化测试")
    print("=" * 60)

    # 1. Build the model.
    print("\n[1/4] 初始化模型...")
    # Per the original author: the pre-trained model was trained with
    # forecast_steps=18, so the same configuration must be used here or the
    # weights will not match and predictions will be abnormal.
    model = DGMR(
        forecast_steps=18,  # must match the pre-trained model
        input_channels=1,
        output_shape=256,
        latent_channels=768,
        context_channels=384,
        num_samples=6,
    )

    # 2. Load the weights.
    print("[2/4] 加载预训练权重...")
    weight_path = Path("models/pytorch_model.bin")
    if not weight_path.exists():
        print(f"❌ 错误: 权重文件不存在: {weight_path}")
        return

    # SECURITY NOTE: torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    state_dict = torch.load(weight_path, map_location="cpu")
    # NOTE(review): strict=False silently skips mismatched keys. If the
    # counts reported below are large, the model is effectively running with
    # (partly) random weights - inspect missing_keys / unexpected_keys.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    model.eval()
    print("✓ 权重加载成功")
    if missing_keys:
        print(f"  警告: {len(missing_keys)} 个权重未加载(可能是正常的,如果模型结构有变化)")
    if unexpected_keys:
        print(f"  警告: {len(unexpected_keys)} 个权重未使用(可能是正常的)")

    # 3. Generate synthetic radar data (normalized to [0, 1]).
    print("[3/4] 生成合成雷达图像...")
    input_frames = create_synthetic_radar_frames(num_frames=4, size=256, normalize=True)
    print(f"✓ 输入形状: {input_frames.shape}")
    print(f"  输入值域: [{input_frames.min():.3f}, {input_frames.max():.3f}] (归一化后)")

    # 4. Run the forward pass without tracking gradients.
    print("[4/4] 运行预测...")
    with torch.no_grad():
        predictions = model(input_frames)  # expected (1, 18, 1, 256, 256) per the author
    print(f"✓ 预测形状: {predictions.shape}")
    print(f"  预测值域: [{predictions.min():.3f}, {predictions.max():.3f}] (归一化后)")

    # Keep only the first forecast step (t+1) for plotting; slicing with
    # 0:1 preserves the time axis.
    predictions_first_frame = predictions[:, 0:1, :, :, :]
    print(f"  提取第一帧形状: {predictions_first_frame.shape}")

    # 5. Plot the result (rescaled back to the 0-128 range).
    print("\n生成可视化...")
    visualize_predictions(
        input_frames, predictions_first_frame, denormalize=True, max_value=max_value
    )

    # 6. Summary statistics on the rescaled data.
    input_denorm = input_frames * max_value
    pred_denorm_first = predictions_first_frame * max_value

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)
    print(f"\n输入: {input_frames.shape[1]} 帧历史雷达图像")
    print(f"输出: {predictions.shape[1]} 帧未来预测(仅显示第一帧)")
    print(f"图像尺寸: {predictions.shape[3]}x{predictions.shape[4]}")
    print("\n统计信息(归一化后):")
    print(f"  输入范围: [{input_frames.min():.3f}, {input_frames.max():.3f}]")
    print(f"  所有预测范围: [{predictions.min():.3f}, {predictions.max():.3f}]")
    print(f"  第一帧预测范围: [{predictions_first_frame.min():.3f}, {predictions_first_frame.max():.3f}]")
    print("\n统计信息(反归一化后,mm/h):")
    print(f"  输入范围: [{input_denorm.min():.2f}, {input_denorm.max():.2f}] mm/h")
    print(f"  第一帧预测范围: [{pred_denorm_first.min():.2f}, {pred_denorm_first.max():.2f}] mm/h")
    print(f"  第一帧预测均值: {pred_denorm_first.mean():.2f} mm/h")
    print(f"  第一帧预测标准差: {pred_denorm_first.std():.2f} mm/h")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions