@@ -737,6 +737,12 @@ def _set_color_source_vec(
737
737
color = np .full (len (element ), na_color )
738
738
return color , color , False
739
739
740
+ # First check if value_to_plot is likely a color specification rather than a column name
741
+ if value_to_plot is not None and _is_color_like (value_to_plot ) and element is not None :
742
+ # User passed a color, not a column name
743
+ color = np .full (len (element ), value_to_plot )
744
+ return None , color , False
745
+
740
746
# Figure out where to get the color from
741
747
origins = _locate_value (
742
748
value_key = value_to_plot ,
@@ -759,9 +765,12 @@ def _set_color_source_vec(
759
765
table_layer = table_layer ,
760
766
)[value_to_plot ]
761
767
762
- # numerical case, return early
763
- # TODO temporary split until refactor is complete
764
- if color_source_vector is not None and not isinstance (color_source_vector .dtype , pd .CategoricalDtype ):
768
+ # Check what type of data we're dealing with
769
+ is_categorical = isinstance (color_source_vector .dtype , pd .CategoricalDtype )
770
+ is_numeric = pd .api .types .is_numeric_dtype (color_source_vector )
771
+
772
+ # If it's numeric data, handle it appropriately
773
+ if is_numeric and not is_categorical :
765
774
if (
766
775
not isinstance (element , GeoDataFrame )
767
776
and isinstance (palette , list )
@@ -775,12 +784,43 @@ def _set_color_source_vec(
775
784
)
776
785
return None , color_source_vector , False
777
786
778
- color_source_vector = pd .Categorical (color_source_vector ) # convert, e.g., `pd.Series`
779
-
780
- # TODO check why table_name is not passed here.
787
+ # For non-numeric, non-categorical data (like strings), convert to categorical
788
+ if not is_categorical :
789
+ try :
790
+ color_source_vector = pd .Categorical (color_source_vector )
791
+ except (ValueError , TypeError ) as e :
792
+ logger .warning (f"Could not convert '{ value_to_plot } ' to categorical: { e } " )
793
+ # Fall back to returning the original values
794
+ return None , color_source_vector , False
795
+
796
+ # At this point color_source_vector should be categorical
797
+
798
+ # Look for predefined colors in the AnnData object
799
+ adata_with_colors = None
800
+ cluster_key = value_to_plot
801
+
802
+ # First check if the table_name is specified
803
+ if table_name is not None and table_name in sdata .tables :
804
+ adata_with_colors = sdata .tables [table_name ]
805
+ adata_with_colors .uns ["spatialdata_key" ] = table_name
806
+
807
+ # If not, but the element is annotated by any table, use that
808
+ elif element_name is not None :
809
+ annotator_tables = get_element_annotators (sdata , element_name )
810
+ if len (annotator_tables ) > 0 :
811
+ # Use the first table that annotates this element
812
+ first_table = next (iter (annotator_tables ))
813
+ adata_with_colors = sdata .tables [first_table ]
814
+ adata_with_colors .uns ["spatialdata_key" ] = first_table
815
+ # If no specific table is found, try using the default table
816
+ elif sdata .table is not None :
817
+ adata_with_colors = sdata .table
818
+ adata_with_colors .uns ["spatialdata_key" ] = "default_table"
819
+
820
+ # Now generate the color mapping using the appropriate AnnData object and cluster_key
781
821
color_mapping = _get_categorical_color_mapping (
782
- adata = sdata [ "table" ] ,
783
- cluster_key = value_to_plot ,
822
+ adata = adata_with_colors ,
823
+ cluster_key = cluster_key ,
784
824
color_source_vector = color_source_vector ,
785
825
cmap_params = cmap_params ,
786
826
alpha = alpha ,
@@ -790,16 +830,27 @@ def _set_color_source_vec(
790
830
render_type = render_type ,
791
831
)
792
832
833
+ # Set categories to match the mapping keys
793
834
color_source_vector = color_source_vector .set_categories (color_mapping .keys ())
794
835
if color_mapping is None :
795
836
raise ValueError ("Unable to create color palette." )
796
837
797
- # do not rename categories, as colors need not be unique
798
- color_vector = color_source_vector .map (color_mapping )
838
+ # Map categorical values to colors
839
+ # Do not rename categories, as colors need not be unique
840
+ try :
841
+ color_vector = color_source_vector .map (color_mapping )
842
+ except (KeyError , TypeError , ValueError ) as e :
843
+ logger .warning (f"Error mapping colors: { e } . Attempting alternate approach." )
844
+ # Try mapping with string conversion
845
+ str_mapping = {str (k ): v for k , v in color_mapping .items ()}
846
+ color_vector = pd .Series (
847
+ [str_mapping .get (str (x ), color_mapping .get ("NaN" , "#d3d3d3" )) for x in color_source_vector ],
848
+ index = color_source_vector .index ,
849
+ )
799
850
800
851
return color_source_vector , color_vector , True
801
852
802
- logger .warning (f"Color key '{ value_to_plot } ' for element '{ element_name } ' not been found, using default colors." )
853
+ logger .warning (f"Color key '{ value_to_plot } ' for element '{ element_name } ' not found, using default colors." )
803
854
color = np .full (sdata [table_name ].n_obs , to_hex (na_color ))
804
855
return color , color , False
805
856
@@ -817,20 +868,35 @@ def _map_color_seg(
817
868
) -> ArrayLike :
818
869
cell_id = np .array (cell_id )
819
870
820
- if pd .api .types .is_categorical_dtype (color_vector .dtype ):
821
- # Case A: users wants to plot a categorical column
871
+ # Safely handle different types of color_vector
872
+ is_categorical = pd .api .types .is_categorical_dtype (getattr (color_vector , "dtype" , None ))
873
+ is_numeric = pd .api .types .is_numeric_dtype (getattr (color_vector , "dtype" , None ))
874
+ is_pandas_series = isinstance (color_vector , pd .Series )
875
+
876
+ # Case A: categorical column
877
+ if is_categorical :
822
878
if np .any (color_source_vector .isna ()):
823
879
cell_id [color_source_vector .isna ()] = 0
824
880
val_im : ArrayLike = map_array (seg .copy (), cell_id , color_vector .codes + 1 )
825
881
cols = colors .to_rgba_array (color_vector .categories )
826
- elif pd .api .types .is_numeric_dtype (color_vector .dtype ):
827
- # Case B: user wants to plot a continous column
828
- if isinstance (color_vector , pd .Series ):
882
+
883
+ # Case B: continuous column
884
+ elif is_numeric :
885
+ if is_pandas_series :
829
886
color_vector = color_vector .to_numpy ()
830
887
cols = cmap_params .cmap (cmap_params .norm (color_vector ))
831
888
val_im = map_array (seg .copy (), cell_id , cell_id )
889
+
890
+ # Case C & D: Other cases (could be strings, or hex colors)
832
891
else :
833
- # Case C: User didn't specify any colors
892
+ # Get the first color safely, regardless of index structure
893
+ first_color = None
894
+ if is_pandas_series and len (color_vector ) > 0 :
895
+ first_color = color_vector .iloc [0 ]
896
+ elif not is_pandas_series and len (color_vector ) > 0 :
897
+ first_color = color_vector [0 ]
898
+
899
+ # Case C: Using default colors with random generation
834
900
if color_source_vector is not None and (
835
901
set (color_vector ) == set (color_source_vector )
836
902
and len (set (color_vector )) == 1
@@ -840,14 +906,31 @@ def _map_color_seg(
840
906
val_im = map_array (seg .copy (), cell_id , cell_id )
841
907
RNG = default_rng (42 )
842
908
cols = RNG .random ((len (color_vector ), 3 ))
909
+
910
+ # Case D: User specified explicit colors or we're using defaults
843
911
else :
844
- # Case D: User didn't specify a column to color by, but modified the na_color
845
912
val_im = map_array (seg .copy (), cell_id , cell_id )
846
- if "#" in str (color_vector [0 ]):
847
- # we have hex colors
848
- assert all (_is_color_like (c ) for c in color_vector ), "Not all values are color-like."
849
- cols = colors .to_rgba_array (color_vector )
913
+
914
+ # Check if we're dealing with hex colors
915
+ if first_color is not None and isinstance (first_color , str ) and "#" in first_color :
916
+ # We have hex colors
917
+ all_is_color = True
918
+ for c in color_vector :
919
+ if not _is_color_like (c ):
920
+ all_is_color = False
921
+ break
922
+
923
+ if all_is_color :
924
+ try :
925
+ cols = colors .to_rgba_array (color_vector )
926
+ except ValueError as e :
927
+ logger .warning (f"Error converting colors: { e } , falling back to default colormap" )
928
+ cols = cmap_params .cmap (cmap_params .norm (np .arange (len (color_vector ))))
929
+ else :
930
+ # Fall back to colormap
931
+ cols = cmap_params .cmap (cmap_params .norm (color_vector ))
850
932
else :
933
+ # Use the colormap
851
934
cols = cmap_params .cmap (cmap_params .norm (color_vector ))
852
935
853
936
if seg_erosionpx is not None :
@@ -879,21 +962,118 @@ def _generate_base_categorial_color_mapping(
879
962
na_color : ColorLike ,
880
963
cmap_params : CmapParams | None = None ,
881
964
) -> Mapping [str , str ]:
882
- if adata is not None and cluster_key in adata .uns and f"{ cluster_key } _colors" in adata .uns :
883
- colors = adata .uns [f"{ cluster_key } _colors" ]
884
- categories = color_source_vector .categories .tolist () + ["NaN" ]
885
- if "#" not in na_color :
886
- # should be unreachable, but just for safety
887
- raise ValueError ("Expected `na_color` to be a hex color, but got a non-hex color." )
888
-
889
- colors = [to_hex (to_rgba (color )[:3 ]) for color in colors ]
890
- na_color = to_hex (to_rgba (na_color )[:3 ])
965
+ color_key = f"{ cluster_key } _colors"
891
966
892
- if na_color and len (categories ) > len (colors ):
893
- return dict (zip (categories , colors + [na_color ], strict = True ))
894
-
895
- return dict (zip (categories , colors , strict = True ))
967
+ # Break long string template into multiple lines to fix E501 error
968
+ color_found_in_uns_msg_template = (
969
+ "Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. "
970
+ "If this is unexpected, please delete the column from your AnnData object."
971
+ )
896
972
973
+ # Check if we have a valid AnnData and if the color key exists in uns
974
+ if adata is not None and cluster_key is not None :
975
+ # Check for direct color dictionary in uns (e.g., {'A': '#FF5733', 'B': '#3498DB'})
976
+ if cluster_key in adata .uns and isinstance (adata .uns [cluster_key ], dict ):
977
+ # We have a direct color mapping dictionary
978
+ color_dict = adata .uns [cluster_key ]
979
+ table_name = getattr (adata , "uns" , {}).get ("spatialdata_key" , "" )
980
+ if table_name :
981
+ # Format the template with the actual values
982
+ logger .info (color_found_in_uns_msg_template .format (cluster = cluster_key , table = table_name ))
983
+
984
+ # Ensure all values are hex colors
985
+ for k , v in color_dict .items ():
986
+ if isinstance (v , str ) and not v .startswith ("#" ):
987
+ color_dict [k ] = to_hex (to_rgba (v ))
988
+
989
+ # Add NA color if missing
990
+ categories = color_source_vector .categories .tolist ()
991
+ na_color_hex = to_hex (to_rgba (na_color )[:3 ])
992
+
993
+ return {cat : color_dict .get (str (cat ), color_dict .get (cat , na_color_hex )) for cat in categories }
994
+
995
+ if color_key in adata .uns :
996
+ colors = adata .uns [color_key ]
997
+ table_name = getattr (adata , "uns" , {}).get ("spatialdata_key" , "" )
998
+ if table_name :
999
+ if isinstance (colors , dict ):
1000
+ # Format the template with the actual values
1001
+ logger .info (color_found_in_uns_msg_template .format (cluster = cluster_key , table = table_name ))
1002
+ else :
1003
+ # Format the template with the actual values
1004
+ logger .info (color_found_in_uns_msg_template .format (cluster = cluster_key , table = table_name ))
1005
+
1006
+ # Ensure colors are in hex format
1007
+ if isinstance (colors , list ):
1008
+ colors = [to_hex (to_rgba (color )[:3 ]) for color in colors ]
1009
+ categories = color_source_vector .categories .tolist ()
1010
+
1011
+ # Handle NaN values
1012
+ na_color_hex = to_hex (to_rgba (na_color )[:3 ])
1013
+ if "NaN" not in categories :
1014
+ categories .append ("NaN" )
1015
+
1016
+ # Make sure we have enough colors
1017
+ if len (colors ) < len (categories ) - 1 : # -1 for NaN
1018
+ logger .warning (
1019
+ f"Not enough colors in { color_key } ({ len (colors )} ) for all categories ({ len (categories ) - 1 } ). "
1020
+ "Some categories will use default colors."
1021
+ )
1022
+ # Extend with default colors or duplicate the last color
1023
+ colors .extend ([na_color_hex ] * (len (categories ) - 1 - len (colors )))
1024
+
1025
+ # Create mapping with NaN color
1026
+ return dict (zip (categories , colors + [na_color_hex ], strict = False ))
1027
+
1028
+ if isinstance (colors , np .ndarray ):
1029
+ # Convert numpy array to list of hex colors
1030
+ colors = [to_hex (to_rgba (color )[:3 ]) for color in colors ]
1031
+ categories = color_source_vector .categories .tolist ()
1032
+
1033
+ # Handle NaN values
1034
+ na_color_hex = to_hex (to_rgba (na_color )[:3 ])
1035
+ if "NaN" not in categories :
1036
+ categories .append ("NaN" )
1037
+
1038
+ # Make sure we have enough colors
1039
+ if len (colors ) < len (categories ) - 1 : # -1 for NaN
1040
+ logger .warning (
1041
+ f"Not enough colors in { color_key } ({ len (colors )} ) for all categories ({ len (categories ) - 1 } ). "
1042
+ "Some categories will use default colors."
1043
+ )
1044
+ # Extend with default colors
1045
+ colors .extend ([na_color_hex ] * (len (categories ) - 1 - len (colors )))
1046
+
1047
+ # Create mapping with NaN color
1048
+ return dict (zip (categories , colors + [na_color_hex ], strict = False ))
1049
+
1050
+ # Dictionary format - direct color mapping
1051
+ if isinstance (colors , dict ):
1052
+ # Ensure all values are hex colors
1053
+ for k , v in colors .items ():
1054
+ if isinstance (v , str ) and not v .startswith ("#" ):
1055
+ colors [k ] = to_hex (to_rgba (v ))
1056
+
1057
+ # Get categories and handle NaN
1058
+ categories = color_source_vector .categories .tolist ()
1059
+ na_color_hex = to_hex (to_rgba (na_color )[:3 ])
1060
+
1061
+ # Try to match color keys to categories, accounting for string/categorical differences
1062
+ result = {}
1063
+ for cat in categories :
1064
+ # Try direct match first
1065
+ if cat in colors :
1066
+ result [cat ] = colors [cat ]
1067
+ # Then try string conversion - handles int/string mismatches
1068
+ elif str (cat ) in colors :
1069
+ result [cat ] = colors [str (cat )]
1070
+ else :
1071
+ result [cat ] = na_color_hex
1072
+
1073
+ return result
1074
+
1075
+ # If we reach here, we didn't find usable colors in uns, use default color mapping
1076
+ logger .info (f"No colors found for '{ cluster_key } ' in AnnData.uns, using default colors" )
897
1077
return _get_default_categorial_color_mapping (color_source_vector = color_source_vector , cmap_params = cmap_params )
898
1078
899
1079
@@ -1007,13 +1187,23 @@ def _maybe_set_colors(
1007
1187
try :
1008
1188
if palette is not None :
1009
1189
raise KeyError ("Unable to copy the palette when there was other explicitly specified." )
1010
- target .uns [color_key ] = source .uns [color_key ]
1190
+
1191
+ # First check if source has the colors
1192
+ if color_key in source .uns :
1193
+ logger .info (f"Copying color information for '{ key } ' from source to target AnnData" )
1194
+ target .uns [color_key ] = source .uns [color_key ]
1195
+ # Then check if the base key has colors (direct dict mapping)
1196
+ elif key in source .uns and isinstance (source .uns [key ], dict ):
1197
+ logger .info (f"Copying direct color mappings for '{ key } ' from source to target AnnData" )
1198
+ target .uns [key ] = source .uns [key ]
1199
+ else :
1200
+ raise KeyError (f"No color information found for '{ key } ' in source AnnData" )
1201
+
1011
1202
except KeyError :
1012
1203
if isinstance (palette , str ):
1013
1204
palette = ListedColormap ([palette ])
1014
1205
if isinstance (palette , ListedColormap ): # `scanpy` requires it
1015
1206
palette = cycler (color = palette .colors )
1016
- palette = None
1017
1207
add_colors_for_categorical_sample_annotation (target , key = key , force_update_colors = True , palette = palette )
1018
1208
1019
1209
0 commit comments