|
13 | 13 | "metadata": {},
|
14 | 14 | "outputs": [],
|
15 | 15 | "source": [
|
16 |
| - "#|default_exp utils" |
| 16 | + "# |default_exp utils" |
17 | 17 | ]
|
18 | 18 | },
|
19 | 19 | {
|
|
22 | 22 | "metadata": {},
|
23 | 23 | "outputs": [],
|
24 | 24 | "source": [
|
25 |
| - "#|export\n", |
| 25 | + "# |export\n", |
26 | 26 | "import matplotlib.pyplot as plt\n",
|
27 |
| - "from matplotlib.collections import LineCollection\n", |
| 27 | + "from matplotlib.collections import LineCollection\n", |
28 | 28 | "import numpy as np\n",
|
29 | 29 | "import jax\n",
|
30 | 30 | "import jax.numpy as jnp\n",
|
|
44 | 44 | "metadata": {},
|
45 | 45 | "outputs": [],
|
46 | 46 | "source": [
|
47 |
| - "#|export\n", |
48 |
| - "key = jax.random.PRNGKey(0)\n", |
| 47 | + "# |export\n", |
| 48 | + "key = jax.random.PRNGKey(0)\n", |
49 | 49 | "logsumexp = jax.scipy.special.logsumexp"
|
50 | 50 | ]
|
51 | 51 | },
|
|
55 | 55 | "metadata": {},
|
56 | 56 | "outputs": [],
|
57 | 57 | "source": [
|
58 |
| - "#|export\n", |
| 58 | + "# |export\n", |
59 | 59 | "def keysplit(key, *ns):\n",
|
60 |
| - " if len(ns) == 0: \n", |
| 60 | + " if len(ns) == 0:\n", |
61 | 61 | " return jax.random.split(key, 1)[0]\n",
|
62 | 62 | " elif len(ns) == 1:\n",
|
63 |
| - " n, = ns\n", |
64 |
| - " if n == 1: return keysplit(key)\n", |
65 |
| - " else: return jax.random.split(key, ns[0])\n", |
| 63 | + " (n,) = ns\n", |
| 64 | + " if n == 1:\n", |
| 65 | + " return keysplit(key)\n", |
| 66 | + " else:\n", |
| 67 | + " return jax.random.split(key, ns[0])\n", |
66 | 68 | " else:\n",
|
67 | 69 | " keys = []\n",
|
68 |
| - " for n in ns: keys.append(keysplit(key, n))\n", |
69 |
| - " return keys\n" |
| 70 | + " for n in ns:\n", |
| 71 | + " keys.append(keysplit(key, n))\n", |
| 72 | + " return keys" |
70 | 73 | ]
|
71 | 74 | },
|
72 | 75 | {
|
|
122 | 125 | "metadata": {},
|
123 | 126 | "outputs": [],
|
124 | 127 | "source": [
|
125 |
| - "#|export\n", |
| 128 | + "# |export\n", |
126 | 129 | "def bounding_box(arr, pad=0):\n",
|
127 | 130 | " \"\"\"Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.\"\"\"\n",
|
128 |
| - " return jnp.array([\n", |
129 |
| - " [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],\n", |
130 |
| - " [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]\n", |
131 |
| - " ])" |
| 131 | + " return jnp.array(\n", |
| 132 | + " [\n", |
| 133 | + " [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad],\n", |
| 134 | + " [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad],\n", |
| 135 | + " ]\n", |
| 136 | + " )" |
132 | 137 | ]
|
133 | 138 | },
|
134 | 139 | {
|
|
137 | 142 | "metadata": {},
|
138 | 143 | "outputs": [],
|
139 | 144 | "source": [
|
140 |
| - "#|export\n", |
| 145 | + "# |export\n", |
141 | 146 | "def argmax_axes(a, axes=None):\n",
|
142 | 147 | " \"\"\"Argmax along specified axes\"\"\"\n",
|
143 |
| - " if axes is None: return jnp.argmax(a)\n", |
144 |
| - " \n", |
145 |
| - " n = len(axes) \n", |
146 |
| - " axes_ = set(range(a.ndim))\n", |
| 148 | + " if axes is None:\n", |
| 149 | + " return jnp.argmax(a)\n", |
| 150 | + "\n", |
| 151 | + " n = len(axes)\n", |
| 152 | + " axes_ = set(range(a.ndim))\n", |
147 | 153 | " axes_0 = axes\n",
|
148 |
| - " axes_1 = sorted(axes_ - set(axes_0)) \n", |
149 |
| - " axes_ = axes_0 + axes_1\n", |
| 154 | + " axes_1 = sorted(axes_ - set(axes_0))\n", |
| 155 | + " axes_ = axes_0 + axes_1\n", |
150 | 156 | "\n",
|
151 | 157 | " b = jnp.transpose(a, axes=axes_)\n",
|
152 | 158 | " c = b.reshape(np.prod(b.shape[:n]), -1)\n",
|
153 | 159 | "\n",
|
154 | 160 | " I = jnp.argmax(c, axis=0)\n",
|
155 |
| - " I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))\n", |
| 161 | + " I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(\n", |
| 162 | + " b.shape[n:] + (n,)\n", |
| 163 | + " )\n", |
156 | 164 | "\n",
|
157 |
| - " return I" |
| 165 | + " return I" |
158 | 166 | ]
|
159 | 167 | },
|
160 | 168 | {
|
|
177 | 185 | "test_shape = (3, 99, 5, 9)\n",
|
178 | 186 | "a = jnp.arange(np.prod(test_shape)).reshape(test_shape)\n",
|
179 | 187 | "\n",
|
180 |
| - "I = argmax_axes(a, axes=[0,1])\n", |
| 188 | + "I = argmax_axes(a, axes=[0, 1])\n", |
181 | 189 | "I.shape"
|
182 | 190 | ]
|
183 | 191 | },
|
|
194 | 202 | "metadata": {},
|
195 | 203 | "outputs": [],
|
196 | 204 | "source": [
|
197 |
| - "#|export\n", |
198 |
| - "def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)])\n", |
199 |
| - "def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0])" |
| 205 | + "# |export\n", |
| 206 | + "def cam_to_screen(x):\n", |
| 207 | + " return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)])\n", |
| 208 | + "\n", |
| 209 | + "\n", |
| 210 | + "def screen_to_cam(y):\n", |
| 211 | + " return y[2] * jnp.array([y[0], y[1], 1.0])" |
200 | 212 | ]
|
201 | 213 | },
|
202 | 214 | {
|
|
205 | 217 | "metadata": {},
|
206 | 218 | "outputs": [],
|
207 | 219 | "source": [
|
208 |
| - "#|export\n", |
209 |
| - "def rot2d(hd): return jnp.array([\n", |
210 |
| - " [jnp.cos(hd), -jnp.sin(hd)], \n", |
211 |
| - " [jnp.sin(hd), jnp.cos(hd)]\n", |
212 |
| - " ]);\n", |
| 220 | + "# |export\n", |
| 221 | + "def rot2d(hd):\n", |
| 222 | + " return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n", |
| 223 | + "\n", |
213 | 224 | "\n",
|
214 |
| - "def pack_2dpose(x,hd): \n", |
215 |
| - " return jnp.concatenate([x,jnp.array([hd])])\n", |
| 225 | + "def pack_2dpose(x, hd):\n", |
| 226 | + " return jnp.concatenate([x, jnp.array([hd])])\n", |
216 | 227 | "\n",
|
217 |
| - "def apply_2dpose(p, ys): \n", |
218 |
| - " return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n", |
219 | 228 | "\n",
|
220 |
| - "def unit_vec(hd): \n", |
| 229 | + "def apply_2dpose(p, ys):\n", |
| 230 | + " return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n", |
| 231 | + "\n", |
| 232 | + "\n", |
| 233 | + "def unit_vec(hd):\n", |
221 | 234 | " return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n",
|
222 | 235 | "\n",
|
| 236 | + "\n", |
223 | 237 | "def adjust_angle(hd):\n",
|
224 | 238 | " \"\"\"Adjusts angle to lie in the interval [-pi,pi).\"\"\"\n",
|
225 |
| - " return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi" |
| 239 | + " return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi" |
226 | 240 | ]
|
227 | 241 | },
|
228 | 242 | {
|
|
238 | 252 | "metadata": {},
|
239 | 253 | "outputs": [],
|
240 | 254 | "source": [
|
241 |
| - "#|export\n", |
| 255 | + "# |export\n", |
242 | 256 | "from genjax.incremental import UnknownChange, NoChange, Diff\n",
|
243 | 257 | "\n",
|
244 | 258 | "\n",
|
245 | 259 | "def argdiffs(args, other=None):\n",
|
246 |
| - " return tuple(map(lambda v: Diff(v, UnknownChange), args))\n" |
| 260 | + " return tuple(map(lambda v: Diff(v, UnknownChange), args))" |
247 | 261 | ]
|
248 | 262 | },
|
249 | 263 | {
|
|
252 | 266 | "metadata": {},
|
253 | 267 | "outputs": [],
|
254 | 268 | "source": [
|
255 |
| - "#|export\n", |
| 269 | + "# |export\n", |
256 | 270 | "from builtins import property as _property, tuple as _tuple\n",
|
257 | 271 | "from typing import Any\n",
|
258 | 272 | "\n",
|
259 | 273 | "\n",
|
260 | 274 | "class Args(tuple):\n",
|
261 | 275 | " def __new__(cls, *args, **kwargs):\n",
|
262 | 276 | " return _tuple.__new__(cls, list(args) + list(kwargs.values()))\n",
|
263 |
| - " \n", |
| 277 | + "\n", |
264 | 278 | " def __init__(self, *args, **kwargs):\n",
|
265 | 279 | " self._d = dict()\n",
|
266 |
| - " for k,v in kwargs.items():\n", |
| 280 | + " for k, v in kwargs.items():\n", |
267 | 281 | " self._d[k] = v\n",
|
268 | 282 | " setattr(self, k, v)\n",
|
269 | 283 | "\n",
|
|
297 | 311 | "metadata": {},
|
298 | 312 | "outputs": [],
|
299 | 313 | "source": [
|
300 |
| - "#|export\n", |
301 |
| - "# \n", |
| 314 | + "# |export\n", |
| 315 | + "#\n", |
302 | 316 | "# Monkey patching `sample` for `BuiltinGenerativeFunction`\n",
|
303 |
| - "# \n", |
| 317 | + "#\n", |
304 | 318 | "cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction\n",
|
305 | 319 | "\n",
|
| 320 | + "\n", |
306 | 321 | "def genjax_sample(self, key, *args, **kwargs):\n",
|
307 | 322 | " tr = self.simulate(key, args)\n",
|
308 | 323 | " return tr.get_retval()\n",
|
309 | 324 | "\n",
|
| 325 | + "\n", |
310 | 326 | "setattr(cls, \"sample\", genjax_sample)\n",
|
311 | 327 | "\n",
|
312 | 328 | "\n",
|
313 |
| - "# \n", |
| 329 | + "#\n", |
314 | 330 | "# Monkey patching `sample` for `DeferredGenerativeFunctionCall`\n",
|
315 |
| - "# \n", |
| 331 | + "#\n", |
316 | 332 | "cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall\n",
|
317 | 333 | "\n",
|
| 334 | + "\n", |
318 | 335 | "def deff_gen_func_call(self, key, **kwargs):\n",
|
319 | 336 | " return self.gen_fn.sample(key, *self.args, **kwargs)\n",
|
320 | 337 | "\n",
|
| 338 | + "\n", |
321 | 339 | "def deff_gen_func_logpdf(self, x, **kwargs):\n",
|
322 | 340 | " return self.gen_fn.logpdf(x, *self.args, **kwargs)\n",
|
323 | 341 | "\n",
|
| 342 | + "\n", |
324 | 343 | "setattr(cls, \"__call__\", deff_gen_func_call)\n",
|
325 | 344 | "setattr(cls, \"sample\", deff_gen_func_call)\n",
|
326 | 345 | "setattr(cls, \"logpdf\", deff_gen_func_logpdf)"
|
|
0 commit comments