Skip to content

Commit af48d51

Browse files
committed
merge
1 parent fd11c33 commit af48d51

File tree

1 file changed

+227
-37
lines changed

1 file changed

+227
-37
lines changed

src/spatialdata_plot/pl/utils.py

Lines changed: 227 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,12 @@ def _set_color_source_vec(
737737
color = np.full(len(element), na_color)
738738
return color, color, False
739739

740+
# First check if value_to_plot is likely a color specification rather than a column name
741+
if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None:
742+
# User passed a color, not a column name
743+
color = np.full(len(element), value_to_plot)
744+
return None, color, False
745+
740746
# Figure out where to get the color from
741747
origins = _locate_value(
742748
value_key=value_to_plot,
@@ -759,9 +765,12 @@ def _set_color_source_vec(
759765
table_layer=table_layer,
760766
)[value_to_plot]
761767

762-
# numerical case, return early
763-
# TODO temporary split until refactor is complete
764-
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
768+
# Check what type of data we're dealing with
769+
is_categorical = isinstance(color_source_vector.dtype, pd.CategoricalDtype)
770+
is_numeric = pd.api.types.is_numeric_dtype(color_source_vector)
771+
772+
# If it's numeric data, handle it appropriately
773+
if is_numeric and not is_categorical:
765774
if (
766775
not isinstance(element, GeoDataFrame)
767776
and isinstance(palette, list)
@@ -775,12 +784,43 @@ def _set_color_source_vec(
775784
)
776785
return None, color_source_vector, False
777786

778-
color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`
779-
780-
# TODO check why table_name is not passed here.
787+
# For non-numeric, non-categorical data (like strings), convert to categorical
788+
if not is_categorical:
789+
try:
790+
color_source_vector = pd.Categorical(color_source_vector)
791+
except (ValueError, TypeError) as e:
792+
logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}")
793+
# Fall back to returning the original values
794+
return None, color_source_vector, False
795+
796+
# At this point color_source_vector should be categorical
797+
798+
# Look for predefined colors in the AnnData object
799+
adata_with_colors = None
800+
cluster_key = value_to_plot
801+
802+
# First check if the table_name is specified
803+
if table_name is not None and table_name in sdata.tables:
804+
adata_with_colors = sdata.tables[table_name]
805+
adata_with_colors.uns["spatialdata_key"] = table_name
806+
807+
# If not, but the element is annotated by any table, use that
808+
elif element_name is not None:
809+
annotator_tables = get_element_annotators(sdata, element_name)
810+
if len(annotator_tables) > 0:
811+
# Use the first table that annotates this element
812+
first_table = next(iter(annotator_tables))
813+
adata_with_colors = sdata.tables[first_table]
814+
adata_with_colors.uns["spatialdata_key"] = first_table
815+
# If no specific table is found, try using the default table
816+
elif sdata.table is not None:
817+
adata_with_colors = sdata.table
818+
adata_with_colors.uns["spatialdata_key"] = "default_table"
819+
820+
# Now generate the color mapping using the appropriate AnnData object and cluster_key
781821
color_mapping = _get_categorical_color_mapping(
782-
adata=sdata["table"],
783-
cluster_key=value_to_plot,
822+
adata=adata_with_colors,
823+
cluster_key=cluster_key,
784824
color_source_vector=color_source_vector,
785825
cmap_params=cmap_params,
786826
alpha=alpha,
@@ -790,16 +830,27 @@ def _set_color_source_vec(
790830
render_type=render_type,
791831
)
792832

833+
# Set categories to match the mapping keys
793834
color_source_vector = color_source_vector.set_categories(color_mapping.keys())
794835
if color_mapping is None:
795836
raise ValueError("Unable to create color palette.")
796837

797-
# do not rename categories, as colors need not be unique
798-
color_vector = color_source_vector.map(color_mapping)
838+
# Map categorical values to colors
839+
# Do not rename categories, as colors need not be unique
840+
try:
841+
color_vector = color_source_vector.map(color_mapping)
842+
except (KeyError, TypeError, ValueError) as e:
843+
logger.warning(f"Error mapping colors: {e}. Attempting alternate approach.")
844+
# Try mapping with string conversion
845+
str_mapping = {str(k): v for k, v in color_mapping.items()}
846+
color_vector = pd.Series(
847+
[str_mapping.get(str(x), color_mapping.get("NaN", "#d3d3d3")) for x in color_source_vector],
848+
index=color_source_vector.index,
849+
)
799850

800851
return color_source_vector, color_vector, True
801852

802-
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.")
853+
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.")
803854
color = np.full(sdata[table_name].n_obs, to_hex(na_color))
804855
return color, color, False
805856

@@ -817,20 +868,35 @@ def _map_color_seg(
817868
) -> ArrayLike:
818869
cell_id = np.array(cell_id)
819870

820-
if pd.api.types.is_categorical_dtype(color_vector.dtype):
821-
# Case A: users wants to plot a categorical column
871+
# Safely handle different types of color_vector
872+
is_categorical = pd.api.types.is_categorical_dtype(getattr(color_vector, "dtype", None))
873+
is_numeric = pd.api.types.is_numeric_dtype(getattr(color_vector, "dtype", None))
874+
is_pandas_series = isinstance(color_vector, pd.Series)
875+
876+
# Case A: categorical column
877+
if is_categorical:
822878
if np.any(color_source_vector.isna()):
823879
cell_id[color_source_vector.isna()] = 0
824880
val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
825881
cols = colors.to_rgba_array(color_vector.categories)
826-
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
827-
# Case B: user wants to plot a continous column
828-
if isinstance(color_vector, pd.Series):
882+
883+
# Case B: continuous column
884+
elif is_numeric:
885+
if is_pandas_series:
829886
color_vector = color_vector.to_numpy()
830887
cols = cmap_params.cmap(cmap_params.norm(color_vector))
831888
val_im = map_array(seg.copy(), cell_id, cell_id)
889+
890+
# Case C & D: Other cases (could be strings, or hex colors)
832891
else:
833-
# Case C: User didn't specify any colors
892+
# Get the first color safely, regardless of index structure
893+
first_color = None
894+
if is_pandas_series and len(color_vector) > 0:
895+
first_color = color_vector.iloc[0]
896+
elif not is_pandas_series and len(color_vector) > 0:
897+
first_color = color_vector[0]
898+
899+
# Case C: Using default colors with random generation
834900
if color_source_vector is not None and (
835901
set(color_vector) == set(color_source_vector)
836902
and len(set(color_vector)) == 1
@@ -840,14 +906,31 @@ def _map_color_seg(
840906
val_im = map_array(seg.copy(), cell_id, cell_id)
841907
RNG = default_rng(42)
842908
cols = RNG.random((len(color_vector), 3))
909+
910+
# Case D: User specified explicit colors or we're using defaults
843911
else:
844-
# Case D: User didn't specify a column to color by, but modified the na_color
845912
val_im = map_array(seg.copy(), cell_id, cell_id)
846-
if "#" in str(color_vector[0]):
847-
# we have hex colors
848-
assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
849-
cols = colors.to_rgba_array(color_vector)
913+
914+
# Check if we're dealing with hex colors
915+
if first_color is not None and isinstance(first_color, str) and "#" in first_color:
916+
# We have hex colors
917+
all_is_color = True
918+
for c in color_vector:
919+
if not _is_color_like(c):
920+
all_is_color = False
921+
break
922+
923+
if all_is_color:
924+
try:
925+
cols = colors.to_rgba_array(color_vector)
926+
except ValueError as e:
927+
logger.warning(f"Error converting colors: {e}, falling back to default colormap")
928+
cols = cmap_params.cmap(cmap_params.norm(np.arange(len(color_vector))))
929+
else:
930+
# Fall back to colormap
931+
cols = cmap_params.cmap(cmap_params.norm(color_vector))
850932
else:
933+
# Use the colormap
851934
cols = cmap_params.cmap(cmap_params.norm(color_vector))
852935

853936
if seg_erosionpx is not None:
@@ -879,21 +962,118 @@ def _generate_base_categorial_color_mapping(
879962
na_color: ColorLike,
880963
cmap_params: CmapParams | None = None,
881964
) -> Mapping[str, str]:
882-
if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns:
883-
colors = adata.uns[f"{cluster_key}_colors"]
884-
categories = color_source_vector.categories.tolist() + ["NaN"]
885-
if "#" not in na_color:
886-
# should be unreachable, but just for safety
887-
raise ValueError("Expected `na_color` to be a hex color, but got a non-hex color.")
888-
889-
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
890-
na_color = to_hex(to_rgba(na_color)[:3])
965+
color_key = f"{cluster_key}_colors"
891966

892-
if na_color and len(categories) > len(colors):
893-
return dict(zip(categories, colors + [na_color], strict=True))
894-
895-
return dict(zip(categories, colors, strict=True))
967+
# Break long string template into multiple lines to fix E501 error
968+
color_found_in_uns_msg_template = (
969+
"Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. "
970+
"If this is unexpected, please delete the column from your AnnData object."
971+
)
896972

973+
# Check if we have a valid AnnData and if the color key exists in uns
974+
if adata is not None and cluster_key is not None:
975+
# Check for direct color dictionary in uns (e.g., {'A': '#FF5733', 'B': '#3498DB'})
976+
if cluster_key in adata.uns and isinstance(adata.uns[cluster_key], dict):
977+
# We have a direct color mapping dictionary
978+
color_dict = adata.uns[cluster_key]
979+
table_name = getattr(adata, "uns", {}).get("spatialdata_key", "")
980+
if table_name:
981+
# Format the template with the actual values
982+
logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name))
983+
984+
# Ensure all values are hex colors
985+
for k, v in color_dict.items():
986+
if isinstance(v, str) and not v.startswith("#"):
987+
color_dict[k] = to_hex(to_rgba(v))
988+
989+
# Add NA color if missing
990+
categories = color_source_vector.categories.tolist()
991+
na_color_hex = to_hex(to_rgba(na_color)[:3])
992+
993+
return {cat: color_dict.get(str(cat), color_dict.get(cat, na_color_hex)) for cat in categories}
994+
995+
if color_key in adata.uns:
996+
colors = adata.uns[color_key]
997+
table_name = getattr(adata, "uns", {}).get("spatialdata_key", "")
998+
if table_name:
999+
if isinstance(colors, dict):
1000+
# Format the template with the actual values
1001+
logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name))
1002+
else:
1003+
# Format the template with the actual values
1004+
logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name))
1005+
1006+
# Ensure colors are in hex format
1007+
if isinstance(colors, list):
1008+
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
1009+
categories = color_source_vector.categories.tolist()
1010+
1011+
# Handle NaN values
1012+
na_color_hex = to_hex(to_rgba(na_color)[:3])
1013+
if "NaN" not in categories:
1014+
categories.append("NaN")
1015+
1016+
# Make sure we have enough colors
1017+
if len(colors) < len(categories) - 1: # -1 for NaN
1018+
logger.warning(
1019+
f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). "
1020+
"Some categories will use default colors."
1021+
)
1022+
# Extend with default colors or duplicate the last color
1023+
colors.extend([na_color_hex] * (len(categories) - 1 - len(colors)))
1024+
1025+
# Create mapping with NaN color
1026+
return dict(zip(categories, colors + [na_color_hex], strict=False))
1027+
1028+
if isinstance(colors, np.ndarray):
1029+
# Convert numpy array to list of hex colors
1030+
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
1031+
categories = color_source_vector.categories.tolist()
1032+
1033+
# Handle NaN values
1034+
na_color_hex = to_hex(to_rgba(na_color)[:3])
1035+
if "NaN" not in categories:
1036+
categories.append("NaN")
1037+
1038+
# Make sure we have enough colors
1039+
if len(colors) < len(categories) - 1: # -1 for NaN
1040+
logger.warning(
1041+
f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). "
1042+
"Some categories will use default colors."
1043+
)
1044+
# Extend with default colors
1045+
colors.extend([na_color_hex] * (len(categories) - 1 - len(colors)))
1046+
1047+
# Create mapping with NaN color
1048+
return dict(zip(categories, colors + [na_color_hex], strict=False))
1049+
1050+
# Dictionary format - direct color mapping
1051+
if isinstance(colors, dict):
1052+
# Ensure all values are hex colors
1053+
for k, v in colors.items():
1054+
if isinstance(v, str) and not v.startswith("#"):
1055+
colors[k] = to_hex(to_rgba(v))
1056+
1057+
# Get categories and handle NaN
1058+
categories = color_source_vector.categories.tolist()
1059+
na_color_hex = to_hex(to_rgba(na_color)[:3])
1060+
1061+
# Try to match color keys to categories, accounting for string/categorical differences
1062+
result = {}
1063+
for cat in categories:
1064+
# Try direct match first
1065+
if cat in colors:
1066+
result[cat] = colors[cat]
1067+
# Then try string conversion - handles int/string mismatches
1068+
elif str(cat) in colors:
1069+
result[cat] = colors[str(cat)]
1070+
else:
1071+
result[cat] = na_color_hex
1072+
1073+
return result
1074+
1075+
# If we reach here, we didn't find usable colors in uns, use default color mapping
1076+
logger.info(f"No colors found for '{cluster_key}' in AnnData.uns, using default colors")
8971077
return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params)
8981078

8991079

@@ -1007,13 +1187,23 @@ def _maybe_set_colors(
10071187
try:
10081188
if palette is not None:
10091189
raise KeyError("Unable to copy the palette when there was other explicitly specified.")
1010-
target.uns[color_key] = source.uns[color_key]
1190+
1191+
# First check if source has the colors
1192+
if color_key in source.uns:
1193+
logger.info(f"Copying color information for '{key}' from source to target AnnData")
1194+
target.uns[color_key] = source.uns[color_key]
1195+
# Then check if the base key has colors (direct dict mapping)
1196+
elif key in source.uns and isinstance(source.uns[key], dict):
1197+
logger.info(f"Copying direct color mappings for '{key}' from source to target AnnData")
1198+
target.uns[key] = source.uns[key]
1199+
else:
1200+
raise KeyError(f"No color information found for '{key}' in source AnnData")
1201+
10111202
except KeyError:
10121203
if isinstance(palette, str):
10131204
palette = ListedColormap([palette])
10141205
if isinstance(palette, ListedColormap): # `scanpy` requires it
10151206
palette = cycler(color=palette.colors)
1016-
palette = None
10171207
add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette)
10181208

10191209

0 commit comments

Comments
 (0)