Skip to content

Conversation

@jleinonen
Copy link
Collaborator

PhysicsNeMo Pull Request

Adds an example of training a temporal interpolation model withphysicsnemo.models.afno.ModAFNO to examples/weather/temporal_interpolation/. See the README.md located there for details.

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Copy link
Collaborator

@CharlelieLrt CharlelieLrt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First pass on the PR.
Overall looks fine, except a few things that are a little unclear here and there.
My biggest concern is the Trainer. I know this class is used in many of our examples, but here I don't really see the point. I would support having a dedicated abstraction for a trainer only if you expect users to import it into their own codebase and use it in their own external application.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR adds a complete temporal interpolation example to examples/weather/temporal_interpolation/, demonstrating how to train a ModAFNO-based model that increases weather forecast temporal resolution from 6-hour to 1-hour intervals. The implementation includes custom DALI datapipes for efficient HDF5 data loading with random timestep sampling, a latitude-weighted geometric loss function, a modular Trainer class with MLflow/Wandb integration, and comprehensive configuration files. The example is designed for ERA5 hourly atmospheric data (73 channels) and supports distributed training with checkpoint resumption. While the code is well-structured and follows PhysicsNeMo patterns, several issues require attention before merge.

Critical Issues

Training Script (train.py)

The return type annotations are malformed and could mislead developers. Line 56 declares -> tuple[InterpClimateDatapipe, InterpClimateDatapipe] but the function returns three values including num_aux_channels. Line 219 has double brackets: -> tuple[[tuple[torch.Tensor, torch.Tensor], torch.Tensor]] instead of -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor].

Lines 240-242 access batch["latlon"], batch["geopotential"], and batch["land_sea_mask"] without verifying these invariants were enabled in the config. If a user disables these features but the code still tries to access them, it will raise a KeyError at runtime rather than at config validation time.

Line 204 uses .pop("model_name") which destructively modifies the config dictionary and will raise a KeyError if the key is missing. Consider using .get() with a default or explicit validation.

Trainer Class (utils/trainer.py)

Lines 129-132 contain a critical bug in epoch initialization. When load_epoch=None (no checkpoint loaded), the code still increments self.start_epoch = epoch + 1, causing training to incorrectly start from epoch 2 instead of epoch 1. This happens because epoch is initialized to 1 on line 127 regardless of whether a checkpoint was loaded.

The validation loss computation on lines 287-310 uses inconsistent averaging compared to training. Training loss divides by total samples seen (self.samples / self.epoch), while validation averages per batch without accounting for potentially different batch sizes. This could lead to misleading validation metrics if the final validation batch is smaller than others.

Lines 147-149 compute batches = samples // batch_size without validating batch_size > 0, risking division by zero. Additionally, there's no handling for the case where samples % batch_size != 0, meaning the last partial batch is silently dropped.

Data Pipeline (datapipe/climate_interp.py)

Line 77 uses np.random.shuffle(self.indices) without seeding, breaking reproducibility across epochs and workers. The commented-out code on line 75 shows the correct seeded approach that should be used instead.

Lines 181-182 leave debug print statements in production code, which will spam logs during distributed training. These should either be removed or converted to proper logger calls with appropriate rank guards.

Configuration Files

Both train_interp.yaml and train_interp_lite.yaml contain hardcoded absolute paths (lines 39-42 in train_interp.yaml) like /data/era5-73varQ-hourly. The invariant files reference a different dataset subdirectory (/data/era5-wind_gust/) than the main data, which may cause confusion or runtime errors if the directory structure doesn't match expectations.

The batch_size_train: 1 setting in train_interp.yaml (line 42) will severely underutilize GPU memory and compute, resulting in very slow training. This should be increased or documented as a deliberate choice.

Recommendations

  1. Fix return type annotations in train.py lines 56 and 219 to accurately reflect the actual return values.

  2. Add defensive checks in train.py around lines 240-242 to verify invariants are enabled before accessing them, or restructure the code to conditionally build the invariants dict.

  3. Correct epoch initialization logic in trainer.py lines 129-132 to properly handle the load_epoch=None case.

  4. Restore seeded shuffling in climate_interp.py line 77 by uncommenting line 75 and removing line 77.

  5. Remove debug print statements from climate_interp.py and replace with proper logging.

  6. Make paths configurable by using placeholders or environment variables instead of hardcoded absolute paths in both config files.

  7. Add input validation for batch_size, dimensions, and other critical parameters to fail fast with clear error messages.

  8. Consider addressing the modularization feedback from previous reviews. The Trainer abstraction may be appropriate here given its reusability, but ensure the core training logic remains understandable by including better docstrings and possibly a training flow diagram in the README.

Confidence: 3/5 - The implementation is generally solid and follows PhysicsNeMo patterns, but contains several bugs and design issues that could cause runtime failures or confusing behavior. The modularization concerns raised in previous reviews are valid but subjective. Most issues are straightforward to fix, though the epoch initialization bug and validation loss computation require careful attention to avoid breaking checkpoint resumption and metric tracking.

10 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer has addressed many of the previous feedback items by: (1) adding a validation script (validate.py) with area-weighted error histograms and baseline comparison; (2) fixing the missing requirements.txt with MLflow and W&B dependencies; (3) expanding the README with validation instructions and more technical background; and (4) implementing validation-specific utilities. However, several critical issues remain: the validation command example in the README is missing a closing quote (line 137), the train.py script still has unguarded .pop("model_name") and nested config access that will raise KeyError if keys are missing, the validation script has hardcoded grid assumptions and potential division-by-zero, and the datapipe/climate_interp.py can generate duplicate interpolation indices when interp_idx == self.stride.

Important Files Changed

Filename Score Overview
examples/weather/temporal_interpolation/README.md 4/5 Comprehensive documentation added with model theory, setup, training, and validation instructions; syntax error on line 137 (missing closing quote in validation command).
examples/weather/temporal_interpolation/train.py 3/5 Main training script added with Hydra orchestration; unguarded .pop("model_name") on line 215 and nested config access on line 401 will raise KeyError if keys missing.
examples/weather/temporal_interpolation/validate.py 3/5 Validation script with histogram-based error metrics; hardcoded 721-lat grid breaks for other resolutions, potential division by zero on line 131, lacks bounds checking on line 95.
examples/weather/temporal_interpolation/datapipe/climate_interp.py 4/5 Specialized datapipe for interpolation training; interp_idx can equal self.stride creating duplicates, unused Tensor import, MRO may not work as intended.
examples/weather/temporal_interpolation/utils/trainer.py 3/5 Reusable trainer class with checkpoint resume and WandB integration; epoch numbering issue if no checkpoint found, validation loss accumulation doesn't match training's per-sample weighting.
examples/weather/temporal_interpolation/utils/loss.py 4/5 Latitude-weighted MSE loss for spherical geometry; potential NaN if all weights zero, otherwise well-implemented with pole handling.
examples/weather/temporal_interpolation/utils/distribute.py 4.5/5 DDP initialization helper; unusual but harmless CUDA stream pattern, otherwise clean and correct.
examples/weather/temporal_interpolation/config/train_interp.yaml 4/5 Production training config with hardcoded absolute paths that users must modify; otherwise well-documented and reasonable.
examples/weather/temporal_interpolation/config/train_interp_lite.yaml 4/5 Lightweight test config with hardcoded paths; clearly marked as test-only, appropriate for smoke testing.
examples/weather/temporal_interpolation/data/data.json 5/5 Static metadata for 73-channel ERA5 dataset; well-formed JSON with no issues.
examples/weather/temporal_interpolation/requirements.txt 4/5 Dependency declaration for MLflow and W&B; missing trailing newline will cause POSIX warnings.

Confidence score: 3/5

  • This PR requires careful review due to multiple unhandled edge cases and potential runtime crashes in critical training and validation paths.
  • Score reflects unguarded dictionary access in train.py (KeyError risks on lines 215, 401), hardcoded grid assumptions in validate.py that break portability, potential duplicate indices in the datapipe, missing closing quote in README validation command, and validation loss accumulation inconsistency. These are not just untested edge cases—they are likely failure modes in common usage scenarios (missing config keys, different data resolutions, checkpoint resume).
  • Pay close attention to train.py (lines 215, 401), validate.py (lines 95, 131, 200-202), datapipe/climate_interp.py (line 90), and README.md (line 137). The validation script and datapipe also need bounds checking and grid size validation.

11 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR.

The developer has added a train_interp_lite.yaml configuration file to the temporal interpolation example at examples/weather/temporal_interpolation/config/train_interp_lite.yaml. Based on the file metadata showing a confidence score of 0 and empty critical sections, this appears to be a new file that was not fully analyzed in previous reviews. The "lite" suffix suggests this is a simplified or resource-constrained variant of the main train_interp.yaml configuration, likely intended for testing or demonstration on smaller infrastructure. This configuration would allow users to experiment with the temporal interpolation training pipeline without requiring full-scale compute resources or complete ERA5 datasets. The addition follows the repository pattern seen in other examples where multiple configuration variants (full-scale vs. lightweight) are provided to accommodate different user environments and use cases.


Important Files Changed

Filename Score Overview
examples/weather/temporal_interpolation/config/train_interp_lite.yaml 3/5 New lightweight configuration variant added for temporal interpolation training, details not yet fully analyzed

Confidence score: 3/5

  • This change introduces a new configuration file that likely reduces resource requirements for the temporal interpolation example, making it more accessible for users with limited infrastructure.
  • Score reflects the lack of detailed analysis in the file metadata (empty critical sections and 0 initial confidence score suggest incomplete review data), preventing assessment of potential hardcoded paths, resource settings, or compatibility issues that may exist.
  • The train_interp_lite.yaml file requires careful review to ensure it provides a genuinely usable lightweight training path and does not contain the same hardcoded absolute path issues identified in train_interp.yaml (lines 34-48 of that file).

11 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

@CharlelieLrt CharlelieLrt changed the base branch from main to 1.3.0-rc October 30, 2025 22:29
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR adds a comprehensive example for training temporal interpolation models using physicsnemo.models.afno.ModAFNO, including training/validation scripts, custom datapipes, and configuration files.

Key Changes:

  • New example in examples/weather/temporal_interpolation/ with complete training pipeline
  • Custom InterpClimateDatapipe extending PhysicsNeMo's climate datapipe for interpolation tasks
  • Trainer class implementing the training loop with checkpoint management and distributed training support
  • Validation script for computing error histograms across interpolation timesteps
  • Geometric L2 loss with latitude weighting for proper error calculation on spherical grids
  • Full documentation in README with setup, training, and validation instructions

Implementation Quality:
The code is generally well-structured and follows good practices. However, there are several logic issues that need attention, particularly around error handling for missing config keys, epoch initialization, and data reproducibility.

Confidence Score: 2/5

  • This PR has multiple logic issues that could cause runtime errors or incorrect behavior
  • Several critical issues exist: (1) epoch initialization bug when load_epoch=None causes training to start from epoch 2, (2) potential KeyError exceptions in train.py when config keys are missing, (3) reproducibility issue with unseeded shuffle in datapipe, (4) inconsistent loss averaging between training and validation, (5) potential division by zero in trainer.py. While these are fixable, they affect core functionality.
  • Pay close attention to utils/trainer.py (epoch initialization, division by zero, loss averaging), train.py (KeyError handling), and datapipe/climate_interp.py (reproducibility)

Important Files Changed

File Analysis

Filename Score Overview
examples/weather/temporal_interpolation/train.py 3/5 Main training script with setup functions for datapipes, model, and optimizer; has several potential KeyError issues when accessing config keys without fallbacks
examples/weather/temporal_interpolation/validate.py 3/5 Validation script for computing error histograms; has potential division by zero and boundary issues with step calculation
examples/weather/temporal_interpolation/utils/trainer.py 2/5 Training loop implementation; has critical issues with epoch initialization when load_epoch=None, division by zero risk, and inconsistent loss averaging between training/validation
examples/weather/temporal_interpolation/datapipe/climate_interp.py 2/5 Custom datapipe for interpolation; has reproducibility issue with unseeded shuffle and potential duplicate step generation when interp_idx == self.stride

Sequence Diagram

sequenceDiagram
    participant User
    participant Main as train.py::main()
    participant Setup as setup_trainer()
    participant DP as InterpClimateDatapipe
    participant Model as ModAFNO
    participant Trainer
    participant Loss as GeometricL2Loss

    User->>Main: Start training
    Main->>Setup: Initialize training components
    Setup->>DP: setup_datapipes()
    DP-->>Setup: train & valid datapipes
    Setup->>Model: setup_model()
    Model-->>Setup: ModAFNO instance
    Setup->>Setup: distribute_model()
    Setup->>Setup: setup_optimizer()
    Setup->>Loss: GeometricL2Loss()
    Loss-->>Setup: Loss function
    Setup->>Trainer: Trainer(model, datapipes, optimizer, loss)
    Trainer-->>Setup: trainer instance
    Setup-->>Main: trainer
    
    Main->>Trainer: fit()
    loop For each epoch
        loop For each batch
            Trainer->>DP: next(train_iterator)
            DP-->>Trainer: batch data
            Trainer->>Trainer: input_output_from_batch_data()
            Trainer->>Model: forward(*invar)
            Model-->>Trainer: predictions
            Trainer->>Loss: loss(pred, true)
            Loss-->>Trainer: loss value
            Trainer->>Trainer: optimizer.step()
        end
        Trainer->>Trainer: validate_on_epoch()
        loop Validation batches
            Trainer->>DP: valid_datapipe
            DP-->>Trainer: valid batch
            Trainer->>Model: eval_step()
            Model-->>Trainer: predictions
            Trainer->>Loss: loss(pred, true)
        end
        Trainer->>Trainer: save_checkpoint()
        Trainer->>Trainer: lr_scheduler.step()
    end
    Trainer-->>Main: Training complete
    Main-->>User: Done
Loading

11 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

steps = np.arange(self.stride + 1)
else:
steps = np.array([0, self.stride, interp_idx])
state_seq = self._load_sequence(year_idx, in_idx, steps)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a mismatch in the number of arguments passed to self._load_sequence, it expects 3 (self, year_idx, sample_idx), but 4 are given here. It does not seem to accept a steps argument

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The self._load_sequence that actually gets called is here and it does accept a steps argument:

def _load_sequence(
self, year_idx: int, idx: int, steps: np.ndarray, num_retries: int = 10
) -> np.ndarray:

This is different from the call signature in ClimateDaliExternalSource in PhysicsNeMo and I understand it's probably confusing but it doesn't break anything in the example (I have run the training example end-to-end and it runs fine).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand it's probably confusing

Yes, that's so confusing that it's unmaintainable. I gave a try to refactor these sources in climate_interp.py but the class inheritance is so convoluted that I couldn't even get started. Also, the fact that all classes in climate_interp.py have no explicit arguments, but *args and **kwargs that are implicitly passed to the parent class makes it even more difficult to figure out where to start. This pattern is ok when the inheritance is straightforward, but it should be avoided with such a complex inheritance tree.

Copy link
Collaborator

@CharlelieLrt CharlelieLrt Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InterpHDF5DaliExternalSource

such a complex inheritance tree

Actually that's not even a tree, but almost cyclic. This needs to be cleaned up a bit, I think. Would there be a way to merge the two classes that you defined in the climate_interp.py into one? For me it's not clear why one should inherit from the other.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Adds comprehensive temporal interpolation model example using ModAFNO to increase temporal resolution of AI-based weather forecasts from 6-hour to 1-hour intervals.

Key Changes:

  • Complete training pipeline with custom datapipes for HDF5 climate data at hourly resolution
  • ModAFNO-based interpolation model with conditioning on interpolation timestep
  • Custom training loop with checkpoint management, distributed training support, and MLFlow/WandB logging
  • Validation script for computing error histograms across interpolation steps
  • Latitude-weighted L2 loss function with proper geometric weighting
  • Comprehensive documentation and configuration files for both production and testing

Notable Implementation Details:

  • Deterministic seeded sampling ensures reproducibility across distributed workers
  • IO retry logic handles occasional filesystem errors in HDF5 reading
  • Static capture wrappers optimize training and evaluation loops
  • Recent commits added fixes for step clamping and removed unused imports

Minor Issue:

  • train.py:215 uses pop("model_name") without default, which will fail if the key is missing from config (though all provided configs include it)

Confidence Score: 4/5

  • This PR is safe to merge with only one minor syntax issue that should be addressed
  • The implementation is comprehensive and well-structured with proper error handling, documentation, and testing support. Recent commits have already addressed several issues (step clamping, unused imports). The only remaining issue is a missing default value in pop("model_name") which could cause a KeyError if the config is incomplete, though all provided configs include this field. The code demonstrates good practices with proper license headers, type hints, docstrings, and distributed training support.
  • examples/weather/temporal_interpolation/train.py needs the model_name KeyError fix on line 215

Important Files Changed

File Analysis

Filename Score Overview
examples/weather/temporal_interpolation/train.py 4/5 Adds main training script for temporal interpolation model with ModAFNO; includes comprehensive setup for datapipes, models, optimizers, and logging with MLFlow/WandB support
examples/weather/temporal_interpolation/validate.py 5/5 Validation script for computing interpolation error histograms; includes fixes for step clamping and unused imports from recent commits
examples/weather/temporal_interpolation/utils/trainer.py 4/5 Custom training loop implementation with checkpoint management, validation, and WandB logging; has some edge cases with epoch tracking
examples/weather/temporal_interpolation/utils/loss.py 5/5 Implements latitude-weighted L2 loss with proper geometric weighting; well-structured with appropriate pole handling
examples/weather/temporal_interpolation/datapipe/climate_interp.py 4/5 Custom datapipe for interpolation training with specialized DALI external sources; handles HDF5 climate data with retry logic for IO errors

Sequence Diagram

sequenceDiagram
    participant User
    participant train.py
    participant Trainer
    participant DataPipe
    participant Model as ModAFNO
    participant Loss as GeometricL2Loss
    
    User->>train.py: Execute with config
    train.py->>train.py: setup_datapipes()
    train.py->>DataPipe: InterpClimateDatapipe.init()
    DataPipe-->>train.py: train_datapipe, valid_datapipe
    
    train.py->>train.py: setup_model()
    train.py->>Model: ModAFNO.init()
    Model-->>train.py: model instance
    
    train.py->>train.py: setup_optimizer()
    train.py->>Trainer: Trainer.init()
    
    Trainer->>Trainer: load_checkpoint() if load_epoch set
    
    User->>Trainer: fit()
    
    loop For each epoch
        loop For each batch
            Trainer->>DataPipe: next(train_iterator)
            DataPipe-->>Trainer: batch data
            Trainer->>Trainer: input_output_from_batch_data()
            Trainer->>Model: forward(invar)
            Model-->>Trainer: prediction
            Trainer->>Loss: loss(pred, true)
            Loss-->>Trainer: loss value
            Trainer->>Trainer: optimizer.step()
        end
        
        Trainer->>Trainer: validate_on_epoch()
        loop For each validation batch
            Trainer->>DataPipe: iterate valid_datapipe
            DataPipe-->>Trainer: validation batch
            Trainer->>Model: eval_step(invar)
            Model-->>Trainer: prediction
            Trainer->>Loss: loss(pred, true)
            Loss-->>Trainer: loss value
        end
        
        Trainer->>Trainer: lr_scheduler.step()
        Trainer->>Trainer: save_checkpoint() if checkpoint_epoch
    end
    
    Trainer-->>User: Training complete
Loading

12 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Nov 13, 2025

Greptile Overview

Greptile Summary

This PR adds a comprehensive example for training temporal interpolation models using ModAFNO to improve the temporal resolution of weather forecasts from 6 hours to 1 hour. The implementation includes:

  • Complete training pipeline with distributed training support via torchrun/MPI
  • Custom InterpClimateDatapipe for loading hourly climate data with interpolation targets
  • Latitude-weighted geometric L2 loss function for geographically balanced training
  • Validation script for computing error histograms as a function of interpolation timestep
  • Well-documented configuration files and comprehensive README
  • Integration with MLFlow and Weights & Biases for experiment tracking

The latest commit (390d778) addressed several critical issues from previous reviews:

  • Fixed potential KeyError in train.py:215 by adding default value to pop("model_name", None)
  • Fixed hardcoded latitude grid assumption in validate.py by computing it dynamically from data shape
  • Added proper bounds checking with min() to prevent step from exceeding timesteps

The code follows PhysicsNeMo conventions, includes proper license headers, and integrates well with the existing codebase. The example is production-ready for training temporal interpolation models on large-scale climate datasets.

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - all critical issues from previous reviews have been addressed
  • Score reflects thorough review showing: (1) all previously identified critical bugs have been fixed, (2) no new logic or syntax errors detected, (3) comprehensive documentation and testing support, (4) proper integration with existing PhysicsNeMo infrastructure, (5) production-ready code quality suitable for a major feature addition
  • No files require special attention - all previously flagged issues have been resolved

Important Files Changed

File Analysis

Filename Score Overview
examples/weather/temporal_interpolation/train.py 5/5 Main training script with proper config handling, model setup, and distributed training support; previous KeyError issues have been fixed
examples/weather/temporal_interpolation/validate.py 5/5 Validation script with proper error histogram computation; previous issues with hardcoded latitude grid and step bounds have been resolved
examples/weather/temporal_interpolation/utils/trainer.py 4/5 Training loop implementation with checkpoint management and validation; validation loss aggregation may need review
examples/weather/temporal_interpolation/utils/loss.py 5/5 Geometric L2 loss with latitude weighting; properly normalized weights prevent NaN issues
examples/weather/temporal_interpolation/datapipe/climate_interp.py 4/5 Custom datapipe for temporal interpolation with shuffle and HDF5 loading; reproducibility and duplicate index handling need review

Sequence Diagram

sequenceDiagram
    participant User
    participant Hydra
    participant main
    participant setup_trainer
    participant setup_datapipes
    participant setup_model
    participant Trainer
    participant DataPipe
    participant Model
    
    User->>Hydra: python train.py --config-name=train_interp.yaml
    Hydra->>main: Load config
    main->>setup_trainer: Initialize training environment
    
    setup_trainer->>DistributedManager: Initialize distributed training
    
    setup_trainer->>setup_datapipes: Create train & validation datapipes
    setup_datapipes->>ClimateDataSourceSpec: Setup data source specs
    setup_datapipes->>InterpClimateDatapipe: Create train & valid datapipes
    setup_datapipes-->>setup_trainer: Return datapipes + num_aux_channels
    
    setup_trainer->>setup_model: Create ModAFNO model
    setup_model->>ModAFNO: Initialize model with config
    setup_model-->>setup_trainer: Return model
    
    setup_trainer->>distribute_model: Wrap model for distributed training
    setup_trainer->>setup_optimizer: Create optimizer & scheduler
    
    setup_trainer->>Trainer: Create Trainer instance
    Trainer->>load_checkpoint: Load checkpoint if specified
    Trainer->>StaticCaptureTraining: Wrap train_step_forward
    Trainer->>StaticCaptureEvaluateNoGrad: Wrap eval_step
    setup_trainer-->>main: Return trainer
    
    main->>Trainer: fit()
    
    loop For each epoch
        loop For each batch in epoch
            Trainer->>DataPipe: Get next training batch
            DataPipe-->>Trainer: Return batch data
            Trainer->>input_output_from_batch_data: Convert to model input/output
            Trainer->>Model: train_step_forward(invar, outvar_true)
            Model-->>Trainer: Return loss
            Trainer->>Optimizer: Update weights
            Trainer->>WandB: Log batch metrics (optional)
        end
        
        Trainer->>Trainer: validate_on_epoch()
        loop For each validation batch
            Trainer->>DataPipe: Get validation batch
            DataPipe-->>Trainer: Return batch data
            Trainer->>Model: eval_step(invar)
            Model-->>Trainer: Return predictions
            Trainer->>Loss: Compute validation loss
        end
        
        Trainer->>LRScheduler: Step scheduler
        Trainer->>save_checkpoint: Save checkpoint if needed
        Trainer->>WandB: Log epoch metrics (optional)
    end
    
    Trainer-->>main: Training complete
    main->>WandB: Finish logging (optional)
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

14 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants