Skip to content

Commit 81891a1

Browse files
committed
allow users to pass 1dim conditions
1 parent 323602f commit 81891a1

File tree

4 files changed

+6
-2
lines changed

4 files changed

+6
-2
lines changed
File renamed without changes.

flowmatching_bdt/flow_bdt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def train_single(self, xt, vt, conditions=None):
9494
)
9595

9696
if conditions is not None:
97+
if conditions.ndim == 1:
98+
conditions = np.expand_dims(conditions, axis=1)
9799
xt = np.concatenate([xt, conditions], axis=1)
98100

99101
# learn to predict the velocity field given a noised input
@@ -158,6 +160,8 @@ def model_t(self, t, xt, conditions=None):
158160
flow_step = int(round(t * (self.n_flow_steps - 1)))
159161

160162
if conditions is not None:
163+
if conditions.ndim == 1:
164+
conditions = np.expand_dims(conditions, axis=1)
161165
xt = np.concatenate([xt, conditions], axis=1)
162166

163167
return self.models[flow_step].predict(xt)

flowmatching_bdt/test/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_condtional():
2525

2626
# get new samples
2727
num_samples = 1000
28-
conditions = np.ones((num_samples, 1))
28+
conditions = np.ones(num_samples)
2929
samples = model.predict(num_samples=num_samples, conditions=conditions)
3030

3131

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'flowmatching-bdt',
55
packages = find_packages(exclude=['assets']),
6-
version = '0.1.0',
6+
version = '0.2.0',
77
license='MIT',
88
description = 'Flow Matching with BDTs',
99
long_description_content_type = 'text/markdown',

0 commit comments

Comments
 (0)