During fine-tuning, the world model's loss approaches 0 and its accuracy approaches 1.0 very quickly (within roughly 2,000 steps). At inference time, the fine-tuned model achieves a success rate of around 98% on the Spatial task.
To verify what the model is actually learning, I ran a sanity check: I fed fake data into the world model (masking the visual input / replacing it with dummy values). Surprisingly, the accuracy remained unchanged (~98%). This suggests the model may be ignoring the visual conditioning entirely and relying on other signals.
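For reference, the sanity check was along these lines. This is a simplified sketch, not the actual codebase API: `model`, the batch keys (`"visual"`, `"state"`, `"label"`), and the forward signature are placeholders; the point is only that the visual conditioning is replaced with zeros while everything else is untouched.

```python
import torch

def evaluate_with_masked_visuals(model, dataloader, device="cpu"):
    """Evaluate accuracy with the visual conditioning replaced by dummy values.

    If accuracy is unchanged relative to real inputs, the model is likely
    not using the visual signal. All names here are illustrative.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            visual = batch["visual"].to(device)
            # Replace the real visual input with zeros (random noise works too)
            fake_visual = torch.zeros_like(visual)
            logits = model(fake_visual, batch["state"].to(device))
            preds = logits.argmax(dim=-1)
            labels = batch["label"].to(device)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total
```

A stronger variant of the same check is to shuffle the visual inputs within the batch (so each sample sees a wrong but realistic image) rather than zeroing them, which rules out the model merely being robust to out-of-distribution zeros.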
Could this behavior be related to the small per-device batch size (e.g., unreliable BatchNorm statistics), or is this an expected phenomenon for this model?
Experimental Setup: I am fine-tuning the model using the following configuration:
- Hardware: 4x GPUs (48GB VRAM each).
- Batch Size: Per-device batch size = 1, with gradient accumulation steps = 5 (effective batch size = 4 × 1 × 5 = 20).
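Regarding the BatchNorm hypothesis: with a per-device batch size of 1, any BatchNorm layer sees degenerate batch statistics during training (gradient accumulation does not help here, since normalization statistics are computed per forward pass, not per optimizer step). A minimal PyTorch illustration, independent of the actual model:

```python
import torch
from torch import nn

# In training mode, BatchNorm normalizes with per-batch statistics.
# For a (1, C)-shaped input, the per-channel variance is undefined,
# and PyTorch refuses to compute it.
bn = nn.BatchNorm1d(num_features=8)
bn.train()
x = torch.randn(1, 8)  # per-device batch of a single sample
try:
    bn(x)
except ValueError as e:
    print("BatchNorm fails on batch size 1 in train mode:", e)

# In eval mode the layer falls back to running statistics and works fine.
bn.eval()
y = bn(x)
print(y.shape)
```

Whether this applies depends on the model actually containing BatchNorm layers (as opposed to LayerNorm/GroupNorm, which normalize per sample and are unaffected by batch size); that would be worth checking in the architecture before pursuing this explanation.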
Any insights would be appreciated.
