diff --git a/examples/sparse_mnist.ipynb b/examples/sparse_mnist.ipynb index 11ccb3f..4c8fa20 100644 --- a/examples/sparse_mnist.ipynb +++ b/examples/sparse_mnist.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,8 @@ "# with Numenta, Inc., for a separate license for this software code, the\n", "# following terms and conditions apply:\n", "#\n", - "# This program is free software: you can redistribute it and/or modify\n", + "# This program is free software: \n", + "# you can redistribute it and/or modify\n", "# it under the terms of the GNU Affero Public License version 3 as\n", "# published by the Free Software Foundation.\n", "#\n", @@ -30,9 +31,35 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting nupic.torch\n", + " Cloning https://github.com/numenta/nupic.torch.git to /private/var/folders/0k/nx51st811lj2bkn32z3c0_dc0000gn/T/pip-install-h0b3r322/nupic-torch_6389283cd1294fb79c8a7c3c4e32176d\n", + " Running command git clone --filter=blob:none --quiet https://github.com/numenta/nupic.torch.git /private/var/folders/0k/nx51st811lj2bkn32z3c0_dc0000gn/T/pip-install-h0b3r322/nupic-torch_6389283cd1294fb79c8a7c3c4e32176d\n", + " Resolved https://github.com/numenta/nupic.torch.git to commit 51269b7aae1024ee301837320a3b3ebee538c8c0\n", + " Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", + "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: torch<=1.11.0,>=1.6 in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from nupic.torch) (1.11.0)\n", + "Requirement already satisfied: typing-extensions in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from torch<=1.11.0,>=1.6->nupic.torch) (4.3.0)\n", + "Requirement already satisfied: torch in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (1.11.0)\n", + "Requirement already satisfied: torchvision in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (0.13.0)\n", + "Requirement already satisfied: typing-extensions in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from torch) (4.3.0)\n", + "Requirement already satisfied: numpy in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from torchvision) (1.23.1)\n", + "Requirement already satisfied: requests in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from torchvision) (2.28.1)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from torchvision) (9.2.0)\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from requests->torchvision) (2.1.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from requests->torchvision) (2022.6.15)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from requests->torchvision) (3.3)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/bard/miniforge3/envs/nupic.torch/lib/python3.8/site-packages (from requests->torchvision) (1.26.11)\n" + ] + } + ], "source": [ "!pip install git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch\n", "!pip install torch torchvision" @@ -40,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -51,7 +78,8 @@ "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", - "from tqdm import tqdm_notebook as tqdm\n", + "# from tqdm import tqdm_notebook as tqdm\n", + "from tqdm.notebook import tqdm\n", "\n", "SEED = 18\n", "random.seed(SEED)\n", @@ -64,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -159,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -172,21 +200,21 @@ " (module): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))\n", " )\n", " (cnn1_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " (cnn1_kwinner): KWinners2d(channels=32, local=False, n=0, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n", + " (cnn1_kwinner): KWinners2d(channels=32, local=False, break_ties=False, n=0, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n", " (cnn2): SparseWeights2d(\n", " sparsity=0.55\n", " (module): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))\n", " )\n", " (cnn2_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " (cnn2_kwinner): KWinners2d(channels=64, local=False, n=0, percent_on=0.2, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n", + " (cnn2_kwinner): KWinners2d(channels=64, local=False, break_ties=False, n=0, percent_on=0.2, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n", " (flatten): Flatten()\n", " (linear): SparseWeights(\n", " sparsity=0.8\n", " (module): Linear(in_features=1024, out_features=700, bias=True)\n", " )\n", - " (linear_kwinner): KWinners(n=700, percent_on=0.2, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000)\n", + " (linear_kwinner): KWinners(n=700, percent_on=0.2, boost_strength=1.5, boost_strength_factor=0.85, k_inference_factor=1.0, duty_cycle_period=1000, break_ties=False)\n", " (output): Linear(in_features=700, out_features=10, bias=True)\n", - " (softmax): LogSoftmax()\n", + " (softmax): LogSoftmax(dim=1)\n", ")\n" ] } @@ -208,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -232,15 +260,22 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 19, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/15000 [00:00