Skip to content

Conversation

ligerlac
Copy link

@ligerlac ligerlac commented May 21, 2025

This fixes a bug in the jax_backend.tile() method. Consider the following minimal example:

import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax", default=True)  # works without this line

spec = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": jnp.array([0.0, 0.0, 0.0]),
                    "modifiers": [
                        {
                            "name": "mu",
                            "type": "normfactor",
                            "data": None,
                        },
                    ],
                },
            ],
        },
    ],
}

my_model = pyhf.Model(spec, validate=False)

The last line fails with TypeError: tile requires ndarray or scalar arguments, got <class 'list'> at position 0.. However, it works fine when using the numpy backend. The problem stems from differences between np.tile and jnp.tile:

import numpy as np
import jax.numpy as jnp

tensor_in = [[[0, 1, 2]]]
repeats = (0, 1, 1)

np.tile(tensor_in, repeats)  # works fine
jnp.tile(tensor_in, repeats)  # fails with same error message as above
jnp.tile(jnp.array(tensor_in), repeats)  # works fine

Unlike jnp.tile, np.tile implicitly converts the input to the correct type.
This PR ensures tensor_in is a jnp.array to make the behaviour of numpy_backend.tile() and jax_backend.tile() consistent.

@matthewfeickert matthewfeickert changed the title Bugfix: make numpy_backend.tile() and jax_backend.tile() consistent fix: make numpy_backend.tile() and jax_backend.tile() consistent May 22, 2025
@matthewfeickert matthewfeickert added the fix A bug fix label May 22, 2025
@matthewfeickert
Copy link
Member

@ligerlac Thanks for the PR. Today I have been clawing myself out of travel related time dependent TODOs, but I can review this on Thursday (2025-05-22).

I haven't looked/thought about this yet, but I assume that this isn't something unique to tile but more generic to how things are being dealt with in spec validation of pyhf.Model (though maybe if I actually think about the PR the reason would be clear to me). Is this a general solution or more of a targeted use patch?

@matthewfeickert matthewfeickert requested review from matthewfeickert, a team and kratsg and removed request for a team May 22, 2025 06:29
@ligerlac
Copy link
Author

It's more of a patch. You are right, the problem is not unqiue to tile(). There are similar problems with concatenate():

pyhf.set_backend("jax", default=True)  # works without this line

spec = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": jnp.array([0.0, 0.0, 0.0]),
                    "modifiers": [
                        {
                            "name": "mu",
                            "type": "normfactor",
                            "data": None,
                        }, 
                    ],
                },
                {
                    "name": "background",
                    "data": jnp.array([0.0, 0.0, 0.0]),  # dummy data
                    "modifiers": [
                        {
                            "name": "correlated_bkg_uncertainty",
                            "type": "histosys",
                            "data": {
                                "hi_data": jnp.array([0.0, 0.0, 0.0]),
                                "lo_data": jnp.array([0.0, 0.0, 0.0]),
                            },
                        },
                    ],
                },
            ],
        },
    ],
}

my_model = pyhf.Model(spec, validate=False)

last line fails with TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.. Again, this can be tracked down to a difference between np.concatenate() and jnp.concatenate():

np.concatenate([[True, True, True]])  # works
jnp.concatenate([[True, True, True]])  # fails
jnp.concatenate(jnp.array([[True, True, True]]))  # works

We could also patch that in the jax backend. But I guess a more elegant solution would be to make sure that each backend is only receiving arguments of the correct type by calling tensorlib.astensor in all the right places (like the _precompute() methods). I'll try to find some time over the weekend to have another look at this (including missings tests).

@ligerlac ligerlac marked this pull request as draft May 23, 2025 09:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fix A bug fix
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

2 participants