Skip to content

Commit 1b359b5

Browse files
Suggested changes
Co-Authored-By: Ryan Dick <[email protected]>
1 parent 9a14202 commit 1b359b5

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

invokeai/backend/stable_diffusion/denoise_context.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,47 +83,47 @@ class DenoiseContext:
8383
unet: Optional[UNet2DConditionModel] = None
8484

8585
# Current state of latent-space image in denoising process.
86-
# None until `pre_denoise_loop` callback.
86+
# None until `PRE_DENOISE_LOOP` callback.
8787
# Shape: [batch, channels, latent_height, latent_width]
8888
latents: Optional[torch.Tensor] = None
8989

9090
# Current denoising step index.
91-
# None until `pre_step` callback.
91+
# None until `PRE_STEP` callback.
9292
step_index: Optional[int] = None
9393

9494
# Current denoising step timestep.
95-
# None until `pre_step` callback.
95+
# None until `PRE_STEP` callback.
9696
timestep: Optional[torch.Tensor] = None
9797

9898
# Arguments which will be passed to UNet model.
99-
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
99+
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
100100
unet_kwargs: Optional[UNetKwargs] = None
101101

102102
# SchedulerOutput class returned from step function(normally, generated by scheduler).
103-
# Supposed to be used only in `post_step` callback, otherwise can be None.
103+
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
104104
step_output: Optional[SchedulerOutput] = None
105105

106106
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
107-
# Available in events inside step(between `pre_step` and `post_stop`).
107+
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
108108
# Shape: [batch, channels, latent_height, latent_width]
109109
latent_model_input: Optional[torch.Tensor] = None
110110

111111
# [TMP] Defines on which conditionings current unet call will be runned.
112-
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
112+
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
113113
conditioning_mode: Optional[ConditioningMode] = None
114114

115115
# [TMP] Noise predictions from negative conditioning.
116-
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
116+
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
117117
# Shape: [batch, channels, latent_height, latent_width]
118118
negative_noise_pred: Optional[torch.Tensor] = None
119119

120120
# [TMP] Noise predictions from positive conditioning.
121-
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
121+
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
122122
# Shape: [batch, channels, latent_height, latent_width]
123123
positive_noise_pred: Optional[torch.Tensor] = None
124124

125125
# Combined noise prediction from passed conditionings.
126-
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
126+
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
127127
# Shape: [batch, channels, latent_height, latent_width]
128128
noise_pred: Optional[torch.Tensor] = None
129129

invokeai/backend/stable_diffusion/extensions/rescale_cfg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class RescaleCFGExt(ExtensionBase):
1515
def __init__(self, rescale_multiplier: float):
1616
super().__init__()
17-
self.rescale_multiplier = rescale_multiplier
17+
self._rescale_multiplier = rescale_multiplier
1818

1919
@staticmethod
2020
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
@@ -28,9 +28,9 @@ def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, m
2828

2929
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
3030
def rescale_noise_pred(self, ctx: DenoiseContext):
31-
if self.rescale_multiplier > 0:
31+
if self._rescale_multiplier > 0:
3232
ctx.noise_pred = self._rescale_cfg(
3333
ctx.noise_pred,
3434
ctx.positive_noise_pred,
35-
self.rescale_multiplier,
35+
self._rescale_multiplier,
3636
)

0 commit comments

Comments
 (0)