Skip to content

Commit 6aaf4a4

Browse files
polvalentedimamik
andauthored
chore: update nx (#618)
* chore: update nx * fix doctests * fix: update ci * chore: format * fix: tolerances * fix: formatter * fix: tolerances * fix: tolerances * chore: bring changes from #616 Co-Authored-By: Dima Mikielewicz <[email protected]> * chore: format * fix: flaky test --------- Co-authored-by: Dima Mikielewicz <[email protected]>
1 parent a3ed2b9 commit 6aaf4a4

File tree

10 files changed

+67
-43
lines changed

10 files changed

+67
-43
lines changed

.github/workflows/gh-pages.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ concurrency:
1010
cancel-in-progress: true
1111

1212
env:
13-
OTP_VERSION: "25.0"
14-
ELIXIR_VERSION: "1.14.0"
13+
OTP_VERSION: "26.1.1"
14+
ELIXIR_VERSION: "1.15.6"
1515

1616
jobs:
1717
deploy:

.github/workflows/test.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ jobs:
2222
fail-fast: false
2323
matrix:
2424
include:
25+
- otp: "27.1"
26+
elixir: "1.19.0"
27+
lint: true
28+
- otp: "26.1.1"
29+
elixir: "1.15.6"
2530
- otp: "26.1.1"
2631
elixir: "1.15.6"
27-
lint: true
28-
- otp: "25.3.2.6"
29-
elixir: "1.14.5"
30-
- otp: "25.3.2.6"
31-
elixir: "1.14.5"
3232
test_command_prepend: "USE_EXLA=true"
33-
- otp: "25.3.2.6"
34-
elixir: "1.14.5"
33+
- otp: "26.1.1"
34+
elixir: "1.15.6"
3535
test_command_prepend: "USE_TORCHX=true"
3636
steps:
3737
- uses: actions/checkout@v3

lib/axon/activations.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ defmodule Axon.Activations do
214214
#Nx.Tensor<
215215
bf16[batch: 2][data: 3]
216216
[
217-
[7.781982421875e-4, 0.0, 0.0],
217+
[0.0, 0.0, 0.0],
218218
[0.3984375, 0.59765625, 0.796875]
219219
]
220220
>
@@ -249,7 +249,7 @@ defmodule Axon.Activations do
249249
#Nx.Tensor<
250250
bf16[batch: 2][data: 3]
251251
[
252-
[-7.781982421875e-4, -0.0, -0.0],
252+
[-0.0, -0.0, -0.0],
253253
[0.3984375, 1.1953125, 2.390625]
254254
]
255255
>
@@ -645,7 +645,7 @@ defmodule Axon.Activations do
645645
#Nx.Tensor<
646646
bf16[batch: 2][data: 3]
647647
[
648-
[-1.09375, -1.5078125, -1.6640625],
648+
[-1.09375, -1.5, -1.65625],
649649
[1.046875, 2.09375, 3.140625]
650650
]
651651
>

lib/axon/defn.ex

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ defmodule Axon.Defn do
1414
[fun.(vars)]
1515
end
1616

17-
@impl true
18-
def __stream__(_, _, _, _, _, _, _), do: raise("not implemented")
19-
2017
@impl true
2118
def __compile__(_, _, _, _), do: raise("not implemented")
2219

lib/axon/loop.ex

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,7 +1632,7 @@ defmodule Axon.Loop do
16321632
final_metrics_map = loop_state.metrics
16331633
loop_state = %{loop_state | metrics: zero_metrics}
16341634

1635-
{status, final_metrics_map, state} =
1635+
{status, final_metrics_map, %State{} = state} =
16361636
case fire_event(:started, handler_fns, loop_state, debug?) do
16371637
{:halt_epoch, state} ->
16381638
{:halted, final_metrics_map, state}
@@ -1647,7 +1647,7 @@ defmodule Axon.Loop do
16471647
Enum.reduce_while(
16481648
epoch_start..epoch_end//1,
16491649
{batch_fn, final_metrics_map, state},
1650-
fn epoch, {batch_fn, final_metrics_map, loop_state} ->
1650+
fn epoch, {batch_fn, final_metrics_map, %State{} = loop_state} ->
16511651
case fire_event(:epoch_started, handler_fns, loop_state, debug?) do
16521652
{:halt_epoch, state} ->
16531653
halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?)
@@ -1691,7 +1691,7 @@ defmodule Axon.Loop do
16911691
{:halt_loop, state} ->
16921692
{:halt, {final_metrics_map, state}}
16931693

1694-
{:continue, state} ->
1694+
{:continue, %State{} = state} ->
16951695
{:cont,
16961696
{batch_fn, Map.put(final_metrics_map, epoch, state.metrics),
16971697
%State{
@@ -1924,7 +1924,7 @@ defmodule Axon.Loop do
19241924
# Halts an epoch during looping
19251925
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
19261926
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
1927-
{:halt_epoch, %{epoch: epoch, metrics: metrics} = state} ->
1927+
{:halt_epoch, %State{epoch: epoch, metrics: metrics} = state} ->
19281928
final_metrics_map = Map.put(final_metrics_map, epoch, metrics)
19291929
{:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}}
19301930

lib/axon/quantization/layers.ex

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@ defmodule Axon.Quantization.Layers do
4545
end
4646

4747
deftransformp reshape_scales(scales, y) do
48-
ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
49-
Nx.reshape(scales, Tuple.append(ones, :auto))
48+
n = Nx.rank(y) - 1
49+
ones = Tuple.duplicate(1, n)
50+
Nx.reshape(scales, Tuple.insert_at(ones, n, :auto))
5051
end
5152

5253
deftransformp reshape_output(output, x_shape) do
53-
all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
54-
new_shape = Tuple.append(all_but_last, :auto)
54+
n = tuple_size(x_shape) - 1
55+
all_but_last = Tuple.delete_at(x_shape, n)
56+
new_shape = Tuple.insert_at(all_but_last, n, :auto)
5557
Nx.reshape(output, new_shape)
5658
end
5759
end

mix.exs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@ defmodule Axon.MixProject do
1414
deps: deps(),
1515
docs: docs(),
1616
description: "Create and train neural networks in Elixir",
17-
package: package(),
18-
preferred_cli_env: [
19-
docs: :docs,
20-
"hex.publish": :docs
21-
]
17+
package: package()
18+
]
19+
end
20+
21+
def cli do
22+
[
23+
docs: :docs,
24+
"hex.publish": :docs
2225
]
2326
end
2427

@@ -35,9 +38,9 @@ defmodule Axon.MixProject do
3538
# Run "mix help deps" to learn about dependencies.
3639
defp deps do
3740
[
38-
{:nx, "~> 0.9", nx_opts()},
39-
{:exla, "~> 0.9", [only: :test] ++ exla_opts()},
40-
{:torchx, "~> 0.9", [only: :test] ++ torchx_opts()},
41+
{:nx, "~> 0.10", nx_opts()},
42+
{:exla, "~> 0.10", [only: :test] ++ exla_opts()},
43+
{:torchx, "~> 0.10", [only: :test] ++ torchx_opts()},
4144
{:ex_doc, "~> 0.23", only: :docs},
4245
{:table_rex, "~> 3.1 or ~> 4.1", optional: true},
4346
{:kino, "~> 0.7", optional: true},

mix.lock

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
%{
2-
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
2+
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
33
"earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
4-
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
4+
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
55
"ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
6-
"exla": {:hex, :exla, "0.9.0", "e048c7a3d33917c214774a7ea1a0c626eb9de01e3fb2423cf9e2b89ef6dada3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.8.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "cbd30b54992d0da01a5aaee361a3160fc29de05a9f6c3dbcbd1fa04b4aa72302"},
6+
"exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
7+
"fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
78
"fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"},
89
"kino": {:hex, :kino, "0.14.1", "c499afb1cd0be462feaf0a75c0631aa65aacc545b1c10f431b439b74f104be22", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "090aea1aaa267e42e5ac24ee6bc5ed515aecc0a9edb8619aa4ee839201e704aa"},
910
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.13", "03c00405987a2202e4b8014ee55eb7f5727691b3f13d76a3764f6eeccef45322", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "00c72bc270e7b9d3c339f726cdab0012fd3f2fc75e36c7548e0f250fe420fa10"},
@@ -12,12 +13,12 @@
1213
"makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"},
1314
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
1415
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
15-
"nx": {:hex, :nx, "0.9.0", "03a622a27d93eaaa2d24ff9b812d9f675cc04eb0340ca3dd065674f3642867d3", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3810a5a90db0654b6e538430c0fb473a22bfc11b3d02ea7834db493cf3f56153"},
16+
"nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"},
1617
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
1718
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
1819
"table_rex": {:hex, :table_rex, "4.1.0", "fbaa8b1ce154c9772012bf445bfb86b587430fb96f3b12022d3f35ee4a68c918", [:mix], [], "hexpm", "95932701df195d43bc2d1c6531178fc8338aa8f38c80f098504d529c43bc2601"},
1920
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
20-
"torchx": {:hex, :torchx, "0.9.0", "936cbd32233f89d73700c39b7ef56f94b3f3541db03c90f8ddf6b3fe73260e28", [:mix], [{:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "4e057d6b93fc91191957230b2c61c408861b888abdf6a900baf0db4125405505"},
21+
"torchx": {:hex, :torchx, "0.10.2", "4b8529bfc4b0e641232497c99ef6d2508e652198840b212373333361352f0bae", [:mix], [{:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "cad541c64df8ddcbf50d9b0f212961632361a03050c8e01493f0fc8d4fed96d9"},
2122
"vega_lite": {:hex, :vega_lite, "0.1.9", "d7a288665f916181b68d0a3617f1b3611d16a4dcd5fafb51b847b71db1159d4c", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "c6a056e763162198e73ae6dfb46c09753bb0298474410fd085074e1cdcee7418"},
22-
"xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"},
23+
"xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},
2324
}

test/axon/compiler_test.exs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4062,9 +4062,10 @@ defmodule CompilerTest do
40624062
}
40634063
} = params = init_fn.(input, ModelState.empty())
40644064

4065-
assert_equal(
4065+
assert_all_close(
40664066
predict_fn.(params, input),
4067-
Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, carry, Nx.tensor(0), k, h, b)
4067+
Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, carry, Nx.tensor(0), k, h, b),
4068+
atol: 1.0e-7
40684069
)
40694070
end
40704071

@@ -4192,7 +4193,11 @@ defmodule CompilerTest do
41924193
enc = {eik, ehk, eb}
41934194
dec = {dik, dhk, db}
41944195

4195-
assert_equal(predict_fn.(params, input), equiv_fn.(input, enc, dec))
4196+
assert_all_close(
4197+
predict_fn.(params, input),
4198+
equiv_fn.(input, enc, dec),
4199+
atol: 1.0e-7
4200+
)
41964201
end
41974202

41984203
test "initializes with use_bias false" do
@@ -5246,7 +5251,11 @@ defmodule CompilerTest do
52465251

52475252
input = random({1, 1})
52485253

5249-
assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2))
5254+
assert_all_close(
5255+
predict_fn.(params, input),
5256+
expected_predict_fn.(input, k1, b1, k2, b2),
5257+
atol: 1.0e-7
5258+
)
52505259
end
52515260

52525261
test "predicts correctly with multiple dense, used twice" do
@@ -5290,7 +5299,11 @@ defmodule CompilerTest do
52905299

52915300
input = random({1, 1})
52925301

5293-
assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2))
5302+
assert_all_close(
5303+
predict_fn.(params, input),
5304+
expected_predict_fn.(input, k1, b1, k2, b2),
5305+
atol: 1.0e-7
5306+
)
52945307
end
52955308

52965309
test "predicts correctly with multiple blocks in network" do
@@ -5703,6 +5716,8 @@ defmodule CompilerTest do
57035716
out =
57045717
ExUnit.CaptureIO.capture_io(fn ->
57055718
predict_fn.(model_state, input)
5719+
# Wait for async print operations to flush
5720+
Process.sleep(1000)
57065721
end)
57075722

57085723
assert out =~ "x:"

test/axon/layers_test.exs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,13 @@ defmodule Axon.LayersTest do
377377
bias = 0.0
378378

379379
assert_equal(
380-
Axon.Layers.conv_transpose(inp, kernel, bias, padding: [{0, 1}, {1, 2}], channels: :first),
380+
Axon.Layers.conv_transpose(
381+
inp,
382+
kernel,
383+
bias,
384+
padding: [{0, 1}, {1, 2}],
385+
channels: :first
386+
),
381387
Nx.tensor([[[[0.0, 2.0, 3.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]]])
382388
)
383389
end

0 commit comments

Comments
 (0)