
Unexpected keys in state_dict when loading checkpoint for inference #75

@Yuhuoo

Thanks for your nice work.

I encountered an error when trying to use a trained checkpoint for inference. Training completed successfully, but loading the checkpoint for inference fails because of unexpected keys in the `state_dict`.

  1. I trained the model using this command:
     ```bash
     CUDA_VISIBLE_DEVICES=5 python -m src.main +experiment=dl3dv \
     data_loader.train.batch_size=1 \
     dataset.roots=[dl3dv-dataset] \
     dataset.view_sampler.num_target_views=4 \
     dataset.view_sampler.num_context_views=2 \
     dataset.min_views=2 \
     dataset.max_views=6 \
     trainer.max_steps=100000 \
     trainer.num_nodes=1 \
     model.encoder.num_scales=2 \
     model.encoder.upsample_factor=4 \
     model.encoder.lowest_feature_resolution=8 \
     model.encoder.monodepth_vit_type=vitb \
     checkpointing.pretrained_model=pretrained/depthsplat-gs-base-dl3dv-256x448-randview2-6-02c7b19d.pth \
     wandb.project=depthsplat \
     wandb.mode=disabled \
     output_dir=checkpoints/dl3dv-256x448-depthsplat-base-randview2-6
     ```
  2. Then I tried to run inference with the resulting checkpoint:
     ```bash
     CUDA_VISIBLE_DEVICES=5 python -m src.main \
     +experiment=dl3dv \
     dataset.test_chunk_interval=1 \
     dataset.roots=[dl3dv-dataset/dl3dv_960p] \
     dataset.image_shape=[512,960] \
     dataset.ori_image_shape=[540,960] \
     model.encoder.upsample_factor=8 \
     model.encoder.lowest_feature_resolution=8 \
     model.encoder.gaussian_adapter.gaussian_scale_max=0.1 \
     model.encoder.monodepth_vit_type=vitb \
     checkpointing.pretrained_model=checkpoints/dl3dv-256x448-depthsplat-base-randview2-6/checkpoints/epoch_60-step_30000.ckpt \
     mode=test \
     dataset/view_sampler=evaluation \
     dataset.view_sampler.num_context_views=12 \
     dataset.view_sampler.index_path=assets/dl3dv_start_0_distance_100_ctx_12v_video.json \
     test.save_video=true \
     test.stablize_camera=true \
     test.compute_scores=false \
     test.render_chunk_size=10 \
     output_dir=outputs/depthsplat-dl3dv-512x960-custom
     ```
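For comparison, the `model.encoder.*` overrides are not the same in the two commands. The small Python sketch here lists the differences. It only parses the override strings copied from the commands; a key that one command does not set falls back to whatever default the `dl3dv` experiment config defines (I have not checked those defaults). `parse_overrides` and `diff_overrides` are just hypothetical helper names for this illustration.

```python
# Compare the model.encoder.* Hydra overrides passed to the training and
# inference commands above. None means "not overridden on that command line",
# i.e. the experiment config's default applies (defaults not shown here).

TRAIN_OVERRIDES = [
    "model.encoder.num_scales=2",
    "model.encoder.upsample_factor=4",
    "model.encoder.lowest_feature_resolution=8",
    "model.encoder.monodepth_vit_type=vitb",
]

TEST_OVERRIDES = [
    "model.encoder.upsample_factor=8",
    "model.encoder.lowest_feature_resolution=8",
    "model.encoder.gaussian_adapter.gaussian_scale_max=0.1",
    "model.encoder.monodepth_vit_type=vitb",
]


def parse_overrides(overrides):
    """Turn ["a.b=1", ...] into {"a.b": "1", ...}."""
    parsed = {}
    for item in overrides:
        key, _, value = item.partition("=")
        parsed[key] = value
    return parsed


def diff_overrides(train, test):
    """Return {key: (train_value, test_value)} for every key whose value differs."""
    a, b = parse_overrides(train), parse_overrides(test)
    return {
        key: (a.get(key), b.get(key))
        for key in sorted(a.keys() | b.keys())
        if a.get(key) != b.get(key)
    }


if __name__ == "__main__":
    for key, (train_val, test_val) in diff_overrides(TRAIN_OVERRIDES, TEST_OVERRIDES).items():
        print(f"{key}: train={train_val} test={test_val}")
```

On these two commands it reports three differences: `num_scales` is 2 for training but not set for inference, `upsample_factor` changes from 4 to 8, and `gaussian_adapter.gaussian_scale_max` is only set for inference.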

But the following error occurs:
```text
Unexpected key(s) in state_dict: "encoder.depth_predictor.mv_pyramid.stages.1.0.weight", "encoder.depth_predictor.mv_pyramid.stages.1.0.bias", "encoder.depth_predictor.mv_pyramid.stages.1.2.weight", "encoder.depth_predictor.mv_pyramid.stages.1.2.bias", "encoder.depth_predictor.mono_pyramid.stages.1.0.weight", "encoder.depth_predictor.mono_pyramid.stages.1.0.bias", "encoder.depth_predictor.mono_pyramid.stages.1.2.weight", "encoder.depth_predictor.mono_pyramid.stages.1.2.bias", "encoder.depth_predictor.regressor.1.0.weight", "encoder.depth_predictor.regressor.1.0.bias", "encoder.depth_predictor.regressor.1.1.weight", "encoder.depth_predictor.regressor.1.1.bias",.....
```
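The unexpected keys visible in the (truncated) message seem to share a pattern. This Python sketch groups them by submodule path and the first numeric index in each key. Reading that index as a per-scale position in a module list is my assumption, not something I verified in the DepthSplat code, and `group_by_scale` / `diff_keys` are hypothetical helper names.

```python
import re
from collections import Counter

# Keys copied from the error message above; the original list is truncated.
UNEXPECTED_KEYS = [
    "encoder.depth_predictor.mv_pyramid.stages.1.0.weight",
    "encoder.depth_predictor.mv_pyramid.stages.1.0.bias",
    "encoder.depth_predictor.mv_pyramid.stages.1.2.weight",
    "encoder.depth_predictor.mv_pyramid.stages.1.2.bias",
    "encoder.depth_predictor.mono_pyramid.stages.1.0.weight",
    "encoder.depth_predictor.mono_pyramid.stages.1.0.bias",
    "encoder.depth_predictor.mono_pyramid.stages.1.2.weight",
    "encoder.depth_predictor.mono_pyramid.stages.1.2.bias",
    "encoder.depth_predictor.regressor.1.0.weight",
    "encoder.depth_predictor.regressor.1.0.bias",
    "encoder.depth_predictor.regressor.1.1.weight",
    "encoder.depth_predictor.regressor.1.1.bias",
]

# Submodule path up to the first purely numeric component, then that number.
# Assumption: that number is the position in a per-scale nn.ModuleList.
FIRST_INDEX = re.compile(r"^(?P<module>.+?)\.(?P<index>\d+)\.")


def group_by_scale(keys):
    """Count keys per (submodule path, first numeric index) pair."""
    counts = Counter()
    for key in keys:
        match = FIRST_INDEX.match(key)
        if match:
            counts[(match["module"], match["index"])] += 1
    return counts


def diff_keys(ckpt_keys, model_keys):
    """Return (unexpected, missing): keys only in the checkpoint, keys only in the model."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    return sorted(ckpt - model), sorted(model - ckpt)


# With torch installed, the full key lists could be collected roughly like this
# (assumption: a PyTorch Lightning .ckpt stores weights under "state_dict";
# only pass weights_only=False for a checkpoint you trust):
#   ckpt_keys = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"].keys()
#   model_keys = model.state_dict().keys()

if __name__ == "__main__":
    for (module, index), n in sorted(group_by_scale(UNEXPECTED_KEYS).items()):
        print(f"{module} [index {index}]: {n} keys")
```

For the keys visible above, all twelve fall under index `1` of `mv_pyramid.stages`, `mono_pyramid.stages`, or `regressor`. That looks consistent with the checkpoint containing a second scale that the inference model does not build (training set `model.encoder.num_scales=2`, the inference command does not), but I am not sure this is the actual cause.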

Could you help me solve this problem?
