We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dd31580 commit 2c67c96Copy full SHA for 2c67c96
examples/offline_inference/rlhf_colocate.py
@@ -80,10 +80,7 @@ def __init__(self):
80
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
81
from vllm.platforms import current_platform
82
83
- if current_platform.is_xpu():
84
- self.model.to("xpu:0")
85
- else:
86
- self.model.to("cuda:0")
+ self.model.to(current_platform.device_type + ":0")
87
# Zero out all the parameters.
88
for name, p in self.model.named_parameters():
89
p.data.zero_()
0 commit comments