Skip to content

Commit 5541c80

Browse files
feat(nodes): use TorchDevice to get device/dtype in bria latent noise node
1 parent 7615d17 commit 5541c80

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

invokeai/app/invocations/bria_latent_noise.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
LatentsField,
88
)
99
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
10+
from invokeai.backend.util.devices import TorchDevice
1011
from invokeai.invocation_api import (
1112
BaseInvocation,
1213
Classification,
@@ -42,11 +43,6 @@ class BriaLatentNoiseInvocation(BaseInvocation):
4243
title="Seed",
4344
description="The seed to use for the latent sampler",
4445
)
45-
transformer: TransformerField = InputField(
46-
description="Bria model (Transformer) to load",
47-
input=Input.Connection,
48-
title="Transformer",
49-
)
5046
height: int = InputField(
5147
default=1024,
5248
title="Height",
@@ -60,9 +56,8 @@ class BriaLatentNoiseInvocation(BaseInvocation):
6056

6157
@torch.no_grad()
6258
def invoke(self, context: InvocationContext) -> BriaLatentNoiseInvocationOutput:
63-
with context.models.load(self.transformer.transformer) as transformer:
64-
device = transformer.device
65-
dtype = transformer.dtype
59+
device = TorchDevice.choose_torch_device()
60+
dtype = TorchDevice.choose_torch_dtype()
6661

6762
generator = torch.Generator(device=device).manual_seed(self.seed)
6863

0 commit comments

Comments
 (0)