From 404a9939683f84c5a2942c4e3d5b7cf29c2ef23e Mon Sep 17 00:00:00 2001 From: Douglas Orr Date: Thu, 10 Jul 2025 21:38:18 +0100 Subject: [PATCH 1/3] Fix isort ignore for _version.py --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5997432..c9d83e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,3 +55,7 @@ version = {attr = "unit_scaling._version.__version__"} [tool.setuptools_scm] version_file = "unit_scaling/_version.py" + +[tool.isort] +profile = "black" +extend_skip = ["unit_scaling/_version.py"] From 6f842436de2eb4a174229b527ec27582e058b634 Mon Sep 17 00:00:00 2001 From: Douglas Orr Date: Thu, 10 Jul 2025 22:00:09 +0100 Subject: [PATCH 2/3] Make analysis dependencies optional --- pyproject.toml | 8 ++++++-- unit_scaling/__init__.py | 2 -- unit_scaling/analysis.py | 24 +++++++++++++++--------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c9d83e4..9f43f06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,11 +28,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "datasets", "docstring-parser", "einops", "numpy<2.0.0", - "seaborn", "tabulate", "torch>=2.2", ] @@ -46,6 +44,12 @@ dynamic = ["version"] [project.optional-dependencies] dev = ["check-manifest"] test = ["pytest"] +analysis = [ + "datasets", + "matplotlib", + "pandas", + "seaborn", +] [tool.setuptools] packages = ["unit_scaling", "unit_scaling.core", "unit_scaling.transforms"] diff --git a/unit_scaling/__init__.py b/unit_scaling/__init__.py index 42357ae..cab0a46 100644 --- a/unit_scaling/__init__.py +++ b/unit_scaling/__init__.py @@ -26,7 +26,6 @@ TransformerLayer, ) from ._version import __version__ -from .analysis import visualiser from .core.functional import transformer_residual_scaling_rule from .parameter import MupType, Parameter @@ -58,6 +57,5 @@ # Functions "Parameter", "transformer_residual_scaling_rule", - "visualiser", "__version__", ] diff --git a/unit_scaling/analysis.py b/unit_scaling/analysis.py index 1b47001..6ddb092 100644 --- a/unit_scaling/analysis.py +++ b/unit_scaling/analysis.py @@ -8,15 +8,21 @@ from math import isnan from typing import TYPE_CHECKING, Any, List, Optional, Tuple -import matplotlib -import matplotlib.colors -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns # type: ignore[import-untyped] -from datasets import load_dataset # type: ignore[import-untyped] -from torch import Tensor, nn -from torch.fx.graph import Graph -from torch.fx.node import Node +try: + import matplotlib + import matplotlib.colors + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns # type: ignore[import-untyped] + from datasets import load_dataset # type: ignore[import-untyped] + from torch import Tensor, nn + from torch.fx.graph import Graph + from torch.fx.node import Node +except ImportError as e: + raise ImportError( + "Optional dependencies for `unit_scaling.analysis` are missing." + " Please install `unit-scaling[analysis]`" + ) from e from ._internal_utils import generate__all__ from .transforms import ( From 3efc36664e50c93e6f5f0a2a5869dda52a7005a7 Mon Sep 17 00:00:00 2001 From: Douglas Orr Date: Thu, 10 Jul 2025 22:02:38 +0100 Subject: [PATCH 3/3] Remove numpy<2.0.0 pin --- pyproject.toml | 1 - requirements-dev.txt | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9f43f06..eb47592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ classifiers = [ dependencies = [ "docstring-parser", "einops", - "numpy<2.0.0", "tabulate", "torch>=2.2", ] diff --git a/requirements-dev.txt b/requirements-dev.txt index e8946f6..3318a98 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ datasets==3.1.0 docstring-parser==0.16 einops==0.8.0 -numpy==1.26.4 +numpy==2.2.6 seaborn==0.13.2 tabulate==0.9.0 torch==2.5.1+cpu