IndexError in get_one_hot #5

@NinelK

Hi!

I'm trying to run:

idSysF = DPADModel()
args = DPADModel.prepare_args(selectedMethodCode)
idSysF.fit(yTrain.T, Z=zTrain.T, nx=n_factors, n1=n_beh_factors, epochs=epochs, **args)
zTestPredF, yTestPredF, xTestPredF = idSysF.predict(yTest) # Run inference to generate predictions

It works fine on most datasets, but for one of the datasets, it throws:

Traceback (most recent call last):
  File "/disk/scratch/nkudryas/BAND-torch/scripts/chewie/run_DPAD.py", line 76, in <module>
    idSysF.fit(yTrain.T, Z=zTrain.T, nx=n_factors, n1=n_beh_factors, epochs=epochs, **args)
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/DPADModel.py", line 1661, in fit
    history1_Cy = model1_Cy.fit(
                  ^^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/RegressionModel.py", line 443, in fit
    inputs_val, outputs_val = prep_IO_data(
                              ^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/RegressionModel.py", line 418, in prep_IO_data
    outputs = get_one_hot(np.array(outputs, dtype=int), self.num_classes)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/tools/tools.py", line 452, in get_one_hot
    res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
          ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 5 is out of bounds for axis 0 with size 5
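
For reference, the failure can be reproduced outside DPAD with a few lines of NumPy. This is only a sketch: the `get_one_hot` below is a stand-in built from the single indexing line visible in the traceback, not the library's full function, and the label arrays are made up.

```python
import numpy as np

def get_one_hot(targets, nb_classes):
    # Stand-in for DPAD's get_one_hot: only this indexing expression is
    # visible in the traceback, the rest of the real function is unknown.
    return np.eye(nb_classes)[np.array(targets).reshape(-1)]

# Labels 0..4 with 5 classes: every label is a valid row index of np.eye(5).
labels_ok = np.array([0, 1, 2, 3, 4])
one_hot_ok = get_one_hot(labels_ok, 5)

# Labels 0..5 with 5 classes: label 5 has no matching row, so indexing fails.
labels_bad = np.array([0, 1, 2, 3, 4, 5])
try:
    get_one_hot(labels_bad, 5)
    error_message = None
except IndexError as exc:
    error_message = str(exc)

print(one_hot_ok.shape)
print(error_message)
```

So the error only needs one label value greater than or equal to `nb_classes`; it does not say which of the two numbers is the wrong one.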

I added an assertion to get_one_hot to catch the bad index and to see which values are passed in. It looks like an off-by-one issue:

  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/RegressionModel.py", line 418, in prep_IO_data
    outputs = get_one_hot(np.array(outputs, dtype=int), self.num_classes)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/tools/tools.py", line 453, in get_one_hot
    assert np.all((targets_array >= 0) & (targets_array < nb_classes)), f"Targets must be integers from 0 to {nb_classes-1}, not {np.unique(targets_array)}"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Targets must be integers from 0 to 4, not [0 1 2 3 4 5]

How should I best fix it? I tried increasing num_classes by 1 on line 1150 of DPADModel.py (Cy2_args["num_classes"] = len(YClasses)), but that causes shape mismatch errors.
Why is num_classes smaller than the number of distinct target values? Should the targets be corrected somehow?
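
One check that could be run before calling fit is sketched below. It assumes, without confirmation from the DPAD code, that the class count might be derived from one split while another split contains an extra label. The helper `remap_labels` and the example arrays are hypothetical and not part of DPAD.

```python
import numpy as np

def remap_labels(*label_arrays):
    """Hypothetical helper (not part of DPAD): map integer labels from several
    splits onto contiguous indices 0..K-1 using the union of observed labels."""
    all_labels = np.concatenate([np.asarray(a).reshape(-1) for a in label_arrays])
    classes = np.unique(all_labels)  # sorted unique labels over all splits
    # Every label is present in `classes`, so searchsorted returns its exact index.
    remapped = [np.searchsorted(classes, np.asarray(a)) for a in label_arrays]
    return remapped, classes

# Made-up labels: the second split contains a label (5) absent from the first.
z_first = np.array([0, 1, 2, 3, 4, 4])
z_second = np.array([1, 5, 3])

print(len(np.unique(z_first)))  # class count seen from the first split alone

(z_first_idx, z_second_idx), classes = remap_labels(z_first, z_second)
num_classes = len(classes)
print(num_classes, z_second_idx.max())
```

If the two counts differ on the failing dataset, that would point to a label that appears in only one split; if they agree, the mismatch comes from somewhere else.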

Many thanks,
Nina