Skip to content

Commit 98ede19

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ccf064e commit 98ede19

File tree

81 files changed

+12125
-9644
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+12125
-9644
lines changed

scripts/_mkl/notebooks/00a - Types.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"metadata": {},
77
"outputs": [],
88
"source": [
9-
"#|default_exp types"
9+
"# |default_exp types"
1010
]
1111
},
1212
{
@@ -15,7 +15,7 @@
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18-
"#|export\n",
18+
"# |export\n",
1919
"from typing import Any, NamedTuple\n",
2020
"import numpy as np\n",
2121
"import jax\n",
@@ -29,18 +29,18 @@
2929
"Int = Array\n",
3030
"FaceIndex = int\n",
3131
"FaceIndices = Array\n",
32-
"ArrayN = Array\n",
33-
"Array3 = Array\n",
34-
"Array2 = Array\n",
35-
"ArrayNx2 = Array\n",
36-
"ArrayNx3 = Array\n",
37-
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
38-
"PrecisionMatrix = Matrix\n",
32+
"ArrayN = Array\n",
33+
"Array3 = Array\n",
34+
"Array2 = Array\n",
35+
"ArrayNx2 = Array\n",
36+
"ArrayNx3 = Array\n",
37+
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
38+
"PrecisionMatrix = Matrix\n",
3939
"CovarianceMatrix = Matrix\n",
40-
"CholeskyMatrix = Matrix\n",
41-
"SquareMatrix = Matrix\n",
42-
"Vector = Array\n",
43-
"Direction = Vector\n",
40+
"CholeskyMatrix = Matrix\n",
41+
"SquareMatrix = Matrix\n",
42+
"Vector = Array\n",
43+
"Direction = Vector\n",
4444
"BaseVector = Vector"
4545
]
4646
},

scripts/_mkl/notebooks/00b - Utils.ipynb

Lines changed: 70 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16-
"#|default_exp utils"
16+
"# |default_exp utils"
1717
]
1818
},
1919
{
@@ -22,9 +22,9 @@
2222
"metadata": {},
2323
"outputs": [],
2424
"source": [
25-
"#|export\n",
25+
"# |export\n",
2626
"import matplotlib.pyplot as plt\n",
27-
"from matplotlib.collections import LineCollection\n",
27+
"from matplotlib.collections import LineCollection\n",
2828
"import numpy as np\n",
2929
"import jax\n",
3030
"import jax.numpy as jnp\n",
@@ -44,8 +44,8 @@
4444
"metadata": {},
4545
"outputs": [],
4646
"source": [
47-
"#|export\n",
48-
"key = jax.random.PRNGKey(0)\n",
47+
"# |export\n",
48+
"key = jax.random.PRNGKey(0)\n",
4949
"logsumexp = jax.scipy.special.logsumexp"
5050
]
5151
},
@@ -55,18 +55,21 @@
5555
"metadata": {},
5656
"outputs": [],
5757
"source": [
58-
"#|export\n",
58+
"# |export\n",
5959
"def keysplit(key, *ns):\n",
60-
" if len(ns) == 0: \n",
60+
" if len(ns) == 0:\n",
6161
" return jax.random.split(key, 1)[0]\n",
6262
" elif len(ns) == 1:\n",
63-
" n, = ns\n",
64-
" if n == 1: return keysplit(key)\n",
65-
" else: return jax.random.split(key, ns[0])\n",
63+
" (n,) = ns\n",
64+
" if n == 1:\n",
65+
" return keysplit(key)\n",
66+
" else:\n",
67+
" return jax.random.split(key, ns[0])\n",
6668
" else:\n",
6769
" keys = []\n",
68-
" for n in ns: keys.append(keysplit(key, n))\n",
69-
" return keys\n"
70+
" for n in ns:\n",
71+
" keys.append(keysplit(key, n))\n",
72+
" return keys"
7073
]
7174
},
7275
{
@@ -122,13 +125,15 @@
122125
"metadata": {},
123126
"outputs": [],
124127
"source": [
125-
"#|export\n",
128+
"# |export\n",
126129
"def bounding_box(arr, pad=0):\n",
127130
" \"\"\"Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.\"\"\"\n",
128-
" return jnp.array([\n",
129-
" [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],\n",
130-
" [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]\n",
131-
" ])"
131+
" return jnp.array(\n",
132+
" [\n",
133+
" [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad],\n",
134+
" [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad],\n",
135+
" ]\n",
136+
" )"
132137
]
133138
},
134139
{
@@ -137,24 +142,27 @@
137142
"metadata": {},
138143
"outputs": [],
139144
"source": [
140-
"#|export\n",
145+
"# |export\n",
141146
"def argmax_axes(a, axes=None):\n",
142147
" \"\"\"Argmax along specified axes\"\"\"\n",
143-
" if axes is None: return jnp.argmax(a)\n",
144-
" \n",
145-
" n = len(axes) \n",
146-
" axes_ = set(range(a.ndim))\n",
148+
" if axes is None:\n",
149+
" return jnp.argmax(a)\n",
150+
"\n",
151+
" n = len(axes)\n",
152+
" axes_ = set(range(a.ndim))\n",
147153
" axes_0 = axes\n",
148-
" axes_1 = sorted(axes_ - set(axes_0)) \n",
149-
" axes_ = axes_0 + axes_1\n",
154+
" axes_1 = sorted(axes_ - set(axes_0))\n",
155+
" axes_ = axes_0 + axes_1\n",
150156
"\n",
151157
" b = jnp.transpose(a, axes=axes_)\n",
152158
" c = b.reshape(np.prod(b.shape[:n]), -1)\n",
153159
"\n",
154160
" I = jnp.argmax(c, axis=0)\n",
155-
" I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))\n",
161+
" I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(\n",
162+
" b.shape[n:] + (n,)\n",
163+
" )\n",
156164
"\n",
157-
" return I"
165+
" return I"
158166
]
159167
},
160168
{
@@ -177,7 +185,7 @@
177185
"test_shape = (3, 99, 5, 9)\n",
178186
"a = jnp.arange(np.prod(test_shape)).reshape(test_shape)\n",
179187
"\n",
180-
"I = argmax_axes(a, axes=[0,1])\n",
188+
"I = argmax_axes(a, axes=[0, 1])\n",
181189
"I.shape"
182190
]
183191
},
@@ -194,9 +202,13 @@
194202
"metadata": {},
195203
"outputs": [],
196204
"source": [
197-
"#|export\n",
198-
"def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)])\n",
199-
"def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0])"
205+
"# |export\n",
206+
"def cam_to_screen(x):\n",
207+
" return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)])\n",
208+
"\n",
209+
"\n",
210+
"def screen_to_cam(y):\n",
211+
" return y[2] * jnp.array([y[0], y[1], 1.0])"
200212
]
201213
},
202214
{
@@ -205,24 +217,26 @@
205217
"metadata": {},
206218
"outputs": [],
207219
"source": [
208-
"#|export\n",
209-
"def rot2d(hd): return jnp.array([\n",
210-
" [jnp.cos(hd), -jnp.sin(hd)], \n",
211-
" [jnp.sin(hd), jnp.cos(hd)]\n",
212-
" ]);\n",
220+
"# |export\n",
221+
"def rot2d(hd):\n",
222+
" return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n",
223+
"\n",
213224
"\n",
214-
"def pack_2dpose(x,hd): \n",
215-
" return jnp.concatenate([x,jnp.array([hd])])\n",
225+
"def pack_2dpose(x, hd):\n",
226+
" return jnp.concatenate([x, jnp.array([hd])])\n",
216227
"\n",
217-
"def apply_2dpose(p, ys): \n",
218-
" return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n",
219228
"\n",
220-
"def unit_vec(hd): \n",
229+
"def apply_2dpose(p, ys):\n",
230+
" return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n",
231+
"\n",
232+
"\n",
233+
"def unit_vec(hd):\n",
221234
" return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n",
222235
"\n",
236+
"\n",
223237
"def adjust_angle(hd):\n",
224238
" \"\"\"Adjusts angle to lie in the interval [-pi,pi).\"\"\"\n",
225-
" return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi"
239+
" return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi"
226240
]
227241
},
228242
{
@@ -238,12 +252,12 @@
238252
"metadata": {},
239253
"outputs": [],
240254
"source": [
241-
"#|export\n",
255+
"# |export\n",
242256
"from genjax.incremental import UnknownChange, NoChange, Diff\n",
243257
"\n",
244258
"\n",
245259
"def argdiffs(args, other=None):\n",
246-
" return tuple(map(lambda v: Diff(v, UnknownChange), args))\n"
260+
" return tuple(map(lambda v: Diff(v, UnknownChange), args))"
247261
]
248262
},
249263
{
@@ -252,18 +266,18 @@
252266
"metadata": {},
253267
"outputs": [],
254268
"source": [
255-
"#|export\n",
269+
"# |export\n",
256270
"from builtins import property as _property, tuple as _tuple\n",
257271
"from typing import Any\n",
258272
"\n",
259273
"\n",
260274
"class Args(tuple):\n",
261275
" def __new__(cls, *args, **kwargs):\n",
262276
" return _tuple.__new__(cls, list(args) + list(kwargs.values()))\n",
263-
" \n",
277+
"\n",
264278
" def __init__(self, *args, **kwargs):\n",
265279
" self._d = dict()\n",
266-
" for k,v in kwargs.items():\n",
280+
" for k, v in kwargs.items():\n",
267281
" self._d[k] = v\n",
268282
" setattr(self, k, v)\n",
269283
"\n",
@@ -297,30 +311,35 @@
297311
"metadata": {},
298312
"outputs": [],
299313
"source": [
300-
"#|export\n",
301-
"# \n",
314+
"# |export\n",
315+
"#\n",
302316
"# Monkey patching `sample` for `BuiltinGenerativeFunction`\n",
303-
"# \n",
317+
"#\n",
304318
"cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction\n",
305319
"\n",
320+
"\n",
306321
"def genjax_sample(self, key, *args, **kwargs):\n",
307322
" tr = self.simulate(key, args)\n",
308323
" return tr.get_retval()\n",
309324
"\n",
325+
"\n",
310326
"setattr(cls, \"sample\", genjax_sample)\n",
311327
"\n",
312328
"\n",
313-
"# \n",
329+
"#\n",
314330
"# Monkey patching `sample` for `DeferredGenerativeFunctionCall`\n",
315-
"# \n",
331+
"#\n",
316332
"cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall\n",
317333
"\n",
334+
"\n",
318335
"def deff_gen_func_call(self, key, **kwargs):\n",
319336
" return self.gen_fn.sample(key, *self.args, **kwargs)\n",
320337
"\n",
338+
"\n",
321339
"def deff_gen_func_logpdf(self, x, **kwargs):\n",
322340
" return self.gen_fn.logpdf(x, *self.args, **kwargs)\n",
323341
"\n",
342+
"\n",
324343
"setattr(cls, \"__call__\", deff_gen_func_call)\n",
325344
"setattr(cls, \"sample\", deff_gen_func_call)\n",
326345
"setattr(cls, \"logpdf\", deff_gen_func_logpdf)"

0 commit comments

Comments
 (0)