diff --git a/datasets.py b/datasets.py
index 7797ae6..b836c14 100644
--- a/datasets.py
+++ b/datasets.py
@@ -1,10 +1,13 @@
 import numpy as np
+import scipy.sparse as sparse
+import pickle
+from collections import defaultdict
 
 print("Datasets Available")
 print("-" * 20)
 
 # ################### Binary Tree
-depth = 5
+depth = 10
 N = sum(2 ** i for i in range(depth))
 mat = np.zeros((N, N))
 for i in range(N):
@@ -13,7 +16,7 @@
         break
     mat[i, j] = 1
     mat[i, j + 1] = 1
-binary_tree = mat
+binary_tree = sparse.csr_matrix(mat)
 print("Binary Tree : binary_tree")
@@ -29,6 +32,27 @@
     mat[i, j + 1] = 1
     mat[i, j + 2] = 1
     mat[i, j + 3] = 1
-quad_tree = mat
+quad_tree = sparse.csr_matrix(mat)
 print("Quad Tree : quad_tree")
 print("-" * 20)
+
+
+def get_pickle_dataset(filename):
+    # close the file instead of leaking the handle
+    with open(filename, 'rb') as f:
+        return pickle.load(f)
+
+
+datasets = {
+    'binary_tree': binary_tree,
+    'quad_tree': quad_tree
+}
+
+print(' '.join(datasets.keys()))
+
+
+def get_dataset(dataset_name_or_filename):
+    print("using {} dataset".format(dataset_name_or_filename))
+    if dataset_name_or_filename in datasets:
+        print("using builtin dataset")
+        return datasets[dataset_name_or_filename]
+    else:
+        return get_pickle_dataset(dataset_name_or_filename)
+
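`get_dataset` above accepts either a builtin name or a path to a pickled pairwise matrix, so new graphs can be added without editing this file. A minimal sketch of producing a compatible pickle (the file name `my_graph.pkl` and the tiny 3-node chain are only illustrative):

    import pickle

    import numpy as np
    import scipy.sparse as sparse

    # same convention as the builtin trees: mat[parent, child] = 1
    mat = np.zeros((3, 3))
    mat[0, 1] = 1  # 0 -> 1
    mat[1, 2] = 1  # 1 -> 2

    with open("my_graph.pkl", "wb") as f:
        pickle.dump(sparse.csr_matrix(mat), f)

`get_dataset("my_graph.pkl")` then returns the matrix exactly as a builtin name would.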
diff --git a/lorentz.py b/lorentz.py
deleted file mode 100644
index 70ddec5..0000000
--- a/lorentz.py
+++ /dev/null
@@ -1,430 +0,0 @@
-import os
-import sys
-import torch
-import random
-import numpy as np
-from torch import nn
-from torch import optim
-from tqdm import trange, tqdm
-from collections import Counter
-from datetime import datetime
-from tensorboardX import SummaryWriter
-from torch.utils.data import Dataset, DataLoader
-
-import matplotlib
-
-matplotlib.use("Agg")  # this needs to come before other matplotlib imports
-import matplotlib.pyplot as plt
-
-plt.style.use("ggplot")
-
-
-def arcosh(x):
-    return torch.log(x + torch.sqrt(x ** 2 - 1))
-
-
-def lorentz_scalar_product(x, y):
-    # BD, BD -> B
-    m = x * y
-    result = m[:, 1:].sum(dim=1) - m[:, 0]
-    return result
-
-
-def tangent_norm(x):
-    # BD -> B
-    return torch.sqrt(lorentz_scalar_product(x, x))
-
-
-def exp_map(x, v):
-    # BD, BD -> BD
-    tn = tangent_norm(v).unsqueeze(dim=1)
-    tn_expand = tn.repeat(1, x.size()[-1])
-    result = torch.cosh(tn) * x + torch.sinh(tn) * (v / tn)
-    result = torch.where(tn_expand > 0, result, x)  # only update if tangent norm is > 0
-    return result
-
-
-def set_dim0(x):
-    x = torch.renorm(x, p=2, dim=0, maxnorm=1e2)  # otherwise leaves will explode
-    # NOTE: the paper does not mention the square part of the equation but if
-    # you try to derive it you get a square term in the equation
-    dim0 = torch.sqrt(1 + (x[:, 1:] ** 2).sum(dim=1))
-    x[:, 0] = dim0
-    return x
-
-
-# ========================= models
-
-
-class RSGD(optim.Optimizer):
-    def __init__(self, params, learning_rate=None):
-        learning_rate = learning_rate if learning_rate is not None else 0.01
-        defaults = {"learning_rate": learning_rate}
-        super().__init__(params, defaults=defaults)
-
-    def step(self):
-        for group in self.param_groups:
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                B, D = p.size()
-                gl = torch.eye(D, device=p.device, dtype=p.dtype)
-                gl[0, 0] = -1
-                grad_norm = torch.norm(p.grad.data)
-                grad_norm = torch.where(grad_norm > 1, grad_norm, torch.tensor(1.0))
-                # only normalize if global grad_norm is more than 1
-                h = (p.grad.data / grad_norm) @ gl
-                proj = (
-                    h
-                    - (
-                        lorentz_scalar_product(p, h) / lorentz_scalar_product(p, p)
-                    ).unsqueeze(1)
-                    * p
-                )
-                # print(p, lorentz_scalar_product(p, p))
-                update = exp_map(p, -group["learning_rate"] * proj)
-                is_nan_inf = torch.isnan(update) | torch.isinf(update)
-                update = torch.where(is_nan_inf, p, update)
-                update[0, :] = p[0, :]  # no ❤️ for embedding
-                update = set_dim0(update)
-                p.data.copy_(update)
-
-
-class Lorentz(nn.Module):
-    """
-    This will embed `n_items` in a `dim` dimensional lorentz space.
-    """
-
-    def __init__(self, n_items, dim, init_range=0.001):
-        super().__init__()
-        self.n_items = n_items
-        self.dim = dim
-        self.table = nn.Embedding(n_items + 1, dim, padding_idx=0)
-        nn.init.uniform_(self.table.weight, -init_range, init_range)
-        # equation 6
-        with torch.no_grad():
-            self.table.weight[0] = 5  # padding idx push it to corner
-            set_dim0(self.table.weight)
-
-    def forward(self, I, Ks):
-        """
-        Using the pairwise similarity matrix, generate the following inputs and
-        provide to this function.
-
-        Inputs:
-            - I : - long tensor
-                  - size (B,)
-                  - This denotes the `i` used in all equations.
-            - Ks : - long tensor
-                   - size (B, N)
-                   - This denotes at max `N` documents which come from the
-                     nearest neighbor sample.
-                   - The `j` document must be the first of the N indices.
-                     This is used to calculate the losses
-        Return:
-            - size (B,)
-            - Ranking loss calculated using
-              document to the given `i` document.
-        """
-        n_ks = Ks.size()[1]
-        ui = torch.stack([self.table(I)] * n_ks, dim=1)
-        uks = self.table(Ks)
-        # ---------- reshape for calculation
-        B, N, D = ui.size()
-        ui = ui.reshape(B * N, D)
-        uks = uks.reshape(B * N, D)
-        dists = -lorentz_scalar_product(ui, uks)
-        dists = torch.where(dists <= 1, torch.ones_like(dists) + 1e-6, dists)
-        # sometimes 2 embedding can come very close in R^D.
-        # when calculating the lorenrz inner product,
-        # -1 can become -0.99(no idea!), then arcosh will become nan
-        dists = -arcosh(dists)
-        # print(dists)
-        # ---------- turn back to per-sample shape
-        dists = dists.reshape(B, N)
-        loss = -(dists[:, 0] - torch.log(torch.exp(dists).sum(dim=1) + 1e-6))
-        return loss
-
-    def lorentz_to_poincare(self):
-        table = self.table.weight.data.numpy()
-        return table[:, 1:] / (
-            table[:, :1] + 1
-        )  # diffeomorphism transform to poincare ball
-
-    def get_lorentz_table(self):
-        return self.table.weight.data.numpy()
-
-    def _test_table(self):
-        x = self.table.weight.data
-        check = lorentz_scalar_product(x, x) + 1.0
-        return check.numpy().sum()
-
-
-class Graph(Dataset):
-    def __init__(self, pairwise_matrix, sample_size=10):
-        self.pairwise_matrix = pairwise_matrix
-        self.n_items = len(pairwise_matrix)
-        self.sample_size = sample_size
-        self.arange = np.arange(0, self.n_items)
-
-    def __len__(self):
-        return self.n_items
-
-    def __getitem__(self, i):
-        I = torch.Tensor([i + 1]).squeeze().long()
-        has_child = (self.pairwise_matrix[i] > 0).sum()
-        has_parent = (self.pairwise_matrix[:, i] > 0).sum()
-        arange = np.random.permutation(self.arange)
-        if has_parent:  # if no child go for parent
-            for j in arange:
-                if self.pairwise_matrix[j, i] > 0:  # assuming no disconneted nodes
-                    min = self.pairwise_matrix[j, i]
-                    break
-        elif has_child:
-            for j in arange:
-                if self.pairwise_matrix[i, j] > 0:  # assuming no self loop
-                    min = self.pairwise_matrix[i, j]
-                    break
-        else:
-            raise Exception(f"Node {i} has no parent and no child")
-        arange = np.random.permutation(self.arange)
-        if has_child:
-            indices = [x for x in arange if i != x and self.pairwise_matrix[i, x] < min]
-        else:
-            indices = [x for x in arange if i != x and self.pairwise_matrix[x, i] < min]
-        indices = indices[: self.sample_size]
-        Ks = ([i + 1 for i in [j] + indices] + [0] * self.sample_size)[
-            : self.sample_size
-        ]
-        # print(I, Ks)
-        return I, torch.Tensor(Ks).long()
-
-
-def recon(table, pair_mat):
-    "Reconstruction accuracy"
-    count = 0
-    table = torch.tensor(table[1:])
-    for i in range(1, len(pair_mat)):  # 0 padding, 1 root, we leave those two
-        x = table[i].repeat(len(table)).reshape([len(table), len(table[i])])  # N, D
-        mask = torch.tensor([0.0] * len(table))
-        mask[i] = 1
-        mask = mask * -10000.0
-        dists = lorentz_scalar_product(x, table) + mask
-        dists = (
-            dists.numpy()
-        )  # arccosh is monotonically increasing, so no need of that here
-        # and no -dist also, as acosh in m i, -acosh(-l(x,y)) is nothing but l(x,y)
-        # print(dists)
-        predicted_parent = np.argmax(dists)
-        actual_parent = np.argmax(pair_mat[:, i])
-        # print(predicted_parent, actual_parent, i, end="\n\n")
-        count += actual_parent == predicted_parent
-    count = count / (len(pair_mat) - 1) * 100
-    return count
-
-
-_moon_count = 0
-
-
-def _moon(loss, phases="🌕🌖🌗🌘🌑🌒🌓🌔"):
-    global _moon_count
-    _moon_count += 1
-    p = phases[_moon_count % 8]
-    return f"{p} Loss: {float(loss)}"
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("dataset", help="File:pairwise_matrix")
-    parser.add_argument(
-        "-sample_size", help="How many samples in the N matrix", default=5, type=int
-    )
-    parser.add_argument(
-        "-batch_size", help="How many samples in the batch", default=32, type=int
-    )
-    parser.add_argument(
-        "-burn_c",
-        help="Divide learning rate by this for the burn epochs",
-        default=10,
-        type=int,
-    )
-    parser.add_argument(
-        "-burn_epochs",
-        help="How many epochs to run the burn phase for?",
-        default=100,
-        type=int,
-    )
-    parser.add_argument(
-        "-plot", help="Plot the embeddings", default=False, action="store_true"
-    )
-    parser.add_argument("-plot_size", help="Size of the plot", default=3, type=int)
-    parser.add_argument(
-        "-plot_graph",
-        help="Plot the Graph associated with the embeddings",
-        default=False,
-        action="store_true",
-    )
-    parser.add_argument(
-        "-overwrite_plots",
-        help="Overwrite the plots?",
-        default=False,
-        action="store_true",
-    )
-    parser.add_argument(
-        "-ckpt", help="Which checkpoint to use?", default=None, type=str
-    )
-    parser.add_argument(
-        "-shuffle", help="Shuffle within batch while learning?", default=True, type=bool
-    )
-    parser.add_argument(
-        "-epochs", help="How many epochs to optimize for?", default=1_000_000, type=int
-    )
-    parser.add_argument(
-        "-poincare_dim",
-        help="Poincare projection time. Lorentz will be + 1",
-        default=2,
-        type=int,
-    )
-    parser.add_argument(
-        "-n_items", help="How many items to embed?", default=None, type=int
-    )
-    parser.add_argument(
-        "-learning_rate", help="RSGD learning rate", default=0.1, type=float
-    )
-    parser.add_argument(
-        "-log_step", help="Log at what multiple of epochs?", default=1, type=int
-    )
-    parser.add_argument(
-        "-logdir", help="What folder to put logs in", default="runs", type=str
-    )
-    parser.add_argument(
-        "-save_step", help="Save at what multiple of epochs?", default=100, type=int
-    )
-    parser.add_argument(
-        "-savedir", help="What folder to put checkpoints in", default="ckpt", type=str
-    )
-    parser.add_argument(
-        "-loader_workers",
-        help="How many workers to generate tensors",
-        default=4,
-        type=int,
-    )
-    args = parser.parse_args()
-    # ----------------------------------- get the correct matrix
-    if not os.path.exists(args.logdir):
-        os.mkdir(args.logdir)
-    if not os.path.exists(args.savedir):
-        os.mkdir(args.savedir)
-
-    exec(f"from datasets import {args.dataset} as pairwise")
-    pairwise = pairwise[: args.n_items, : args.n_items]
-    args.n_items = len(pairwise) if args.n_items is None else args.n_items
-    print(f"{args.n_items} being embedded")
-
-    # ---------------------------------- Generate the proper objects
-    net = Lorentz(
-        args.n_items, args.poincare_dim + 1
-    )  # as the paper follows R^(n+1) for this space
-    if args.plot:
-        if args.poincare_dim != 2:
-            print("Only embeddings with `-poincare_dim` = 2 are supported for now.")
-            sys.exit(1)
-        if args.ckpt is None:
-            print("Please provide `-ckpt` when using `-plot`")
-            sys.exit(1)
-        if os.path.isdir(args.ckpt):
-            paths = [
-                os.path.join(args.ckpt, c)
-                for c in os.listdir(args.ckpt)
-                if c.endswith("ckpt")
-            ]
-        else:
-            paths = [args.ckpt]
-        paths = list(sorted(paths))
-        edges = [
-            tuple(edge)
-            for edge in set(
-                [
-                    frozenset((a + 1, b + 1))
-                    for a, row in enumerate(pairwise > 0)
-                    for b, is_non_zero in enumerate(row)
-                    if is_non_zero
-                ]
-            )
-        ]
-        print(len(edges), "nodes")
-        internal_nodes = set(
-            node
-            for node, count in Counter(
-                [node for edge in edges for node in edge]
-            ).items()
-            if count > 1
-        )
-        edges = np.array([edge for edge in edges if edge[1] in internal_nodes])
-        print(len(edges), "internal nodes")
-        for path in tqdm(paths, desc="Plotting"):
-            save_path = f"{path}.svg"
-            if os.path.exists(save_path) and not args.overwrite_plots:
-                continue
-            net.load_state_dict(torch.load(path))
-            table = net.lorentz_to_poincare()
-            # skip padding. plot x y
-            plt.figure(figsize=(7, 7))
-            if args.plot_graph:
-                for edge in edges:
-                    plt.plot(
-                        table[edge, 0],
-                        table[edge, 1],
-                        color="black",
-                        marker="o",
-                        alpha=0.5,
-                    )
-            else:
-                plt.scatter(table[1:, 0], table[1:, 1])
-            plt.title(path)
-            plt.gca().set_xlim(-1, 1)
-            plt.gca().set_ylim(-1, 1)
-            plt.gca().add_artist(plt.Circle((0, 0), 1, fill=False, edgecolor="black"))
-            plt.savefig(save_path)
-            plt.close()
-        sys.exit(0)
-
-    dataloader = DataLoader(
-        Graph(pairwise, args.sample_size),
-        shuffle=args.shuffle,
-        batch_size=args.batch_size,
-        num_workers=args.loader_workers,
-    )
-    rsgd = RSGD(net.parameters(), learning_rate=args.learning_rate)
-
-    name = f"{args.dataset} {datetime.utcnow()}"
-    writer = SummaryWriter(f"{args.logdir}/{name}")
-
-    with tqdm(ncols=80, mininterval=0.2) as epoch_bar:
-        for epoch in range(args.epochs):
-            rsgd.learning_rate = (
-                args.learning_rate / args.burn_c
-                if epoch < args.burn_epochs
-                else args.learning_rate
-            )
-            for I, Ks in dataloader:
-                rsgd.zero_grad()
-                loss = net(I, Ks).mean()
-                loss.backward()
-                rsgd.step()
-            writer.add_scalar("loss", loss, epoch)
-            writer.add_scalar(
-                "recon_preform", recon(net.get_lorentz_table(), pairwise), epoch
-            )
-            writer.add_scalar("table_test", net._test_table(), epoch)
-            if epoch % args.save_step == 0:
-                torch.save(net.state_dict(), f"{args.savedir}/{epoch} {name}.ckpt")
-            epoch_bar.set_description(
-                f"🔥 Burn phase loss: {float(loss)}"
-                if epoch < args.burn_epochs
-                else _moon(loss)
-            )
-            epoch_bar.update(1)
diff --git a/lorentz_embeddings/lorentz.py b/lorentz_embeddings/lorentz.py
new file mode 100644
index 0000000..34cf16e
--- /dev/null
+++ b/lorentz_embeddings/lorentz.py
@@ -0,0 +1,236 @@
+import os
+import sys
+import torch
+import random
+import numpy as np
+from torch import nn
+from torch import optim
+from tqdm import trange, tqdm
+from collections import Counter
+from datetime import datetime
+from tensorboardX import SummaryWriter
+from torch.utils.data import Dataset, DataLoader
+import datasets
+import pickle
+
+import matplotlib
+
+matplotlib.use("Agg")  # this needs to come before other matplotlib imports
+import matplotlib.pyplot as plt
+
+plt.style.use("ggplot")
+
+
+def arcosh(x):
+    return torch.log(x + torch.sqrt(x ** 2 - 1))
+
+
+def lorentz_scalar_product(x, y):
+    # BD, BD -> B
+    m = x * y
+    result = m[:, 1:].sum(dim=1) - m[:, 0]
+    return result
+
+
+def tangent_norm(x):
+    # BD -> B
+    return torch.sqrt(lorentz_scalar_product(x, x))
+
+
+def exp_map(x, v):
+    # BD, BD -> BD
+    tn = tangent_norm(v).unsqueeze(dim=1)
+    tn_expand = tn.repeat(1, x.size()[-1])
+    result = torch.cosh(tn) * x + torch.sinh(tn) * (v / tn)
+    result = torch.where(tn_expand > 0, result, x)  # only update if tangent norm is > 0
+    return result
+
+
+def set_dim0(x):
+    x = torch.renorm(x, p=2, dim=0, maxnorm=1e2)  # otherwise leaves will explode
+    # NOTE: the paper does not mention the square part of the equation but if
+    # you try to derive it you get a square term in the equation
+    dim0 = torch.sqrt(1 + (x[:, 1:] ** 2).sum(dim=1))
+    x[:, 0] = dim0
+    return x
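+
+
+# A point x lies on the hyperboloid model exactly when
+# lorentz_scalar_product(x, x) == -1, i.e. when
+# x[0] == sqrt(1 + (x[1:] ** 2).sum()); set_dim0 re-imposes that
+# constraint by recomputing the time-like coordinate after every update.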
+
+
+# ========================= models
+
+
+class RSGD(optim.Optimizer):
+    def __init__(self, params, learning_rate=None):
+        learning_rate = learning_rate if learning_rate is not None else 0.01
+        defaults = {"learning_rate": learning_rate}
+        super().__init__(params, defaults=defaults)
+
+    def step(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                B, D = p.size()
+                gl = torch.eye(D, device=p.device, dtype=p.dtype)
+                gl[0, 0] = -1
+                grad_norm = torch.norm(p.grad.data)
+                grad_norm = torch.where(
+                    grad_norm > 1, grad_norm, torch.tensor(1.0).to(p.device)
+                )
+                # only normalize if global grad_norm is more than 1
+                h = (p.grad.data / grad_norm) @ gl
+                proj = (
+                    h
+                    - (
+                        lorentz_scalar_product(p, h) / lorentz_scalar_product(p, p)
+                    ).unsqueeze(1)
+                    * p
+                )
+                update = exp_map(p, -group["learning_rate"] * proj)
+                is_nan_inf = torch.isnan(update) | torch.isinf(update)
+                update = torch.where(is_nan_inf, p, update)
+                update[0, :] = p[0, :]  # keep the padding entry fixed
+                update = set_dim0(update)
+                p.data.copy_(update)
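+
+
+# Reading the update above: gl is the Minkowski metric diag(-1, 1, ..., 1),
+# h is the gradient rescaled by it, proj removes the component of h that
+# points off the tangent space at p, and exp_map follows the resulting
+# geodesic. A plain SGD step in the ambient R^D would leave the hyperboloid.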
+
+
+class Lorentz(nn.Module):
+    """
+    This will embed `n_items` in a `dim` dimensional Lorentz space.
+    """
+
+    def __init__(self, n_items, dim, init_range=0.001):
+        super().__init__()
+        self.n_items = n_items
+        self.dim = dim
+        self.table = nn.Embedding(n_items + 1, dim, padding_idx=0)
+        nn.init.uniform_(self.table.weight, -init_range, init_range)
+        # equation 6
+        with torch.no_grad():
+            self.table.weight[0] = 5  # padding idx, push it to a corner
+            set_dim0(self.table.weight)
+
+    def forward(self, I, Ks):
+        """
+        Using the pairwise similarity matrix, generate the following inputs and
+        provide to this function.
+
+        Inputs:
+            - I : - long tensor
+                  - size (B,)
+                  - This denotes the `i` used in all equations.
+            - Ks : - long tensor
+                   - size (B, N)
+                   - This denotes at max `N` documents which come from the
+                     nearest neighbor sample.
+                   - The `j` document must be the first of the N indices.
+                     This is used to calculate the losses.
+        Return:
+            - size (B,)
+            - Ranking loss calculated comparing the `j` document to the
+              given `i` document.
+        """
+        n_ks = Ks.size()[1]
+        ui = torch.stack([self.table(I)] * n_ks, dim=1)
+        uks = self.table(Ks)
+        # ---------- reshape for calculation
+        B, N, D = ui.size()
+        ui = ui.reshape(B * N, D)
+        uks = uks.reshape(B * N, D)
+        dists = -lorentz_scalar_product(ui, uks)
+        dists = torch.where(dists <= 1, torch.ones_like(dists) + 1e-6, dists)
+        # sometimes 2 embeddings can come very close in R^D;
+        # when calculating the lorentz inner product, -1 can become -0.99
+        # numerically and arcosh would return nan, so clamp above 1 first
+        dists = -arcosh(dists)
+        # ---------- turn back to per-sample shape
+        dists = dists.reshape(B, N)
+        loss = -(dists[:, 0] - torch.log(torch.exp(dists).sum(dim=1) + 1e-6))
+        return loss
+
+    def lorentz_to_poincare(self):
+        table = self.table.weight.data.cpu().numpy()
+        return table[:, 1:] / (
+            table[:, :1] + 1
+        )  # diffeomorphism transform to poincare ball
+
+    def get_lorentz_table(self):
+        return self.table.weight.data.cpu().numpy()
+
+    def _test_table(self):
+        x = self.table.weight.data
+        check = lorentz_scalar_product(x, x) + 1.0
+        return check.cpu().numpy().sum()
+
+
+class Graph(Dataset):
+    def __init__(self, pairwise_matrix, batch_size, sample_size=10):
+        # assumes a scipy sparse pairwise matrix (see datasets.py)
+        self.pairwise_matrix = pairwise_matrix
+        self.n_items = pairwise_matrix.shape[0]
+        self.sample_size = sample_size
+        self.arange = np.arange(0, self.n_items)
+        self.cnter = 0
+        self.batch_size = batch_size
+
+    def __len__(self):
+        return self.n_items
+
+    def __getitem__(self, i):
+        self.cnter = (self.cnter + 1) % self.batch_size
+        I = torch.Tensor([i + 1]).squeeze().long()
+        has_child = (self.pairwise_matrix[i] > 0).sum()
+        has_parent = (self.pairwise_matrix[:, i] > 0).sum()
+        # reshuffle only once per batch to keep sampling cheap
+        if self.cnter == 0:
+            arange = np.random.permutation(self.arange)
+        else:
+            arange = self.arange
+
+        if has_parent:  # if no child go for parent
+            valid_idxs = arange[self.pairwise_matrix[arange, i].nonzero()[0]]
+            j = valid_idxs[0]
+            min_dist = self.pairwise_matrix[j, i]
+        elif has_child:
+            valid_idxs = arange[self.pairwise_matrix[i, arange].nonzero()[1]]
+            j = valid_idxs[0]
+            min_dist = self.pairwise_matrix[i, j]
+        else:
+            raise Exception(f"Node {i} has no parent and no child")
+        indices = arange
+        indices = indices[indices != i]
+        if has_child:
+            # row slice has shape (1, n): nonzero()[1] holds the positions
+            indices = indices[
+                (self.pairwise_matrix[i, indices] < min_dist).nonzero()[1]
+            ]
+        else:
+            # column slice has shape (n, 1): nonzero()[0] holds the positions
+            indices = indices[
+                (self.pairwise_matrix[indices, i] < min_dist).nonzero()[0]
+            ]
+
+        indices = indices[: self.sample_size]
+        # shift by one: index 0 of the embedding table is the padding entry
+        Ks = np.concatenate([[j + 1], indices + 1, np.zeros(self.sample_size)])[
+            : self.sample_size
+        ]
+        return I, torch.Tensor(Ks).long()
+
+
+def recon(table, pair_mat):
+    "Reconstruction accuracy"
+    count = 0
+    table = torch.tensor(table[1:])
+    n = pair_mat.shape[0]
+    for i in range(1, n):  # 0 padding, 1 root, we leave those two
+        x = table[i].repeat(len(table)).reshape([len(table), len(table[i])])  # N, D
+        mask = torch.tensor([0.0] * len(table))
+        mask[i] = 1
+        mask = mask * -10000.0
+        dists = lorentz_scalar_product(x, table) + mask
+        dists = (
+            dists.cpu().numpy()
+        )  # arccosh is monotonically increasing, so no need of it here,
+        # and no -dist either: argmax of -acosh(-l(x, y)) equals argmax of l(x, y)
+        predicted_parent = np.argmax(dists)
+        actual_parent = np.argmax(pair_mat[:, i])
+        count += actual_parent == predicted_parent
+    count = count / (pair_mat.shape[0] - 1) * 100
+    return count
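With the model code split into an importable package, the geometry can be smoke-tested without the trainer. A small sketch, assuming `lorentz_embeddings` is on the import path:

    import torch

    from lorentz_embeddings.lorentz import Lorentz, arcosh, lorentz_scalar_product

    net = Lorentz(10, 3)  # 10 items in a 2+1 dimensional Lorentz space
    x = net.table.weight.data

    # every row should sit on the hyperboloid: <x, x>_L == -1 up to float error
    print(lorentz_scalar_product(x, x) + 1.0)

    # geodesic distance between items 1 and 2, with the same clamp
    # forward() uses to keep arcosh away from arguments below 1
    d = arcosh((-lorentz_scalar_product(x[1:2], x[2:3])).clamp(min=1.0))
    print(float(d))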
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..506e373
--- /dev/null
+++ b/main.py
@@ -0,0 +1,245 @@
+import os
+import sys
+import torch
+import random
+import numpy as np
+from torch import nn
+from torch import optim
+from tqdm import trange, tqdm
+from collections import Counter
+from datetime import datetime
+from tensorboardX import SummaryWriter
+from torch.utils.data import Dataset, DataLoader
+import datasets
+import pickle
+from lorentz_embeddings.lorentz import *
+
+import matplotlib
+
+matplotlib.use("Agg")  # this needs to come before other matplotlib imports
+import matplotlib.pyplot as plt
+
+plt.style.use("ggplot")
+
+
+import warnings
+
+warnings.filterwarnings('ignore')
+
+
+_moon_count = 0
+
+
+def _moon(loss, phases="🌕🌖🌗🌘🌑🌒🌓🌔"):
+    global _moon_count
+    _moon_count += 1
+    p = phases[_moon_count % 8]
+    return f"{p} Loss: {float(loss)}"
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "dataset", help="Builtin dataset name or path to a pickled pairwise matrix"
+    )
+    parser.add_argument(
+        "-sample_size", help="How many samples in the N matrix", default=5, type=int
+    )
+    parser.add_argument(
+        "-batch_size", help="How many samples in the batch", default=32, type=int
+    )
+    parser.add_argument(
+        "-burn_c",
+        help="Divide learning rate by this for the burn epochs",
+        default=10,
+        type=int,
+    )
+    parser.add_argument(
+        "-burn_epochs",
+        help="How many epochs to run the burn phase for?",
+        default=1,
+        type=int,
+    )
+    parser.add_argument(
+        "-plot", help="Plot the embeddings", default=False, action="store_true"
+    )
+    parser.add_argument("-plot_size", help="Size of the plot", default=3, type=int)
+    parser.add_argument(
+        "-plot_graph",
+        help="Plot the Graph associated with the embeddings",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "-overwrite_plots",
+        help="Overwrite the plots?",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "-ckpt", help="Which checkpoint to use?", default=None, type=str
+    )
+    parser.add_argument(
+        "-shuffle", help="Shuffle within batch while learning?", default=True, type=bool
+    )
+    parser.add_argument(
+        "-epochs", help="How many epochs to optimize for?", default=10, type=int
+    )
+    parser.add_argument(
+        "-poincare_dim",
+        help="Poincare projection dim. Lorentz will be + 1",
+        default=2,
+        type=int,
+    )
+    parser.add_argument(
+        "-n_items", help="How many items to embed?", default=None, type=int
+    )
+    parser.add_argument(
+        "-learning_rate", help="RSGD learning rate", default=0.1, type=float
+    )
+    parser.add_argument(
+        "-log_step", help="Log at what multiple of epochs?", default=1, type=int
+    )
+    parser.add_argument(
+        "-log", help="Log metrics to TensorBoard?", default=False, type=bool
+    )
+    parser.add_argument(
+        "-logdir", help="What folder to put logs in", default="runs", type=str
+    )
+    parser.add_argument(
+        "-save_step", help="Save at what multiple of epochs?", default=100, type=int
+    )
+    parser.add_argument(
+        "-savedir", help="What folder to put checkpoints in", default="ckpt", type=str
+    )
+    parser.add_argument(
+        "-loader_workers",
+        help="How many workers to generate tensors",
+        default=4,
+        type=int,
+    )
+    parser.add_argument(
+        "-device", help="Torch device to train on", default='cuda'
+    )
+    args = parser.parse_args()
+    # ----------------------------------- get the correct matrix
+    if not os.path.exists(args.logdir):
+        os.mkdir(args.logdir)
+    if not os.path.exists(args.savedir):
+        os.mkdir(args.savedir)
+
+    # check the file extension, not the first component of the name
+    if args.dataset.split('.')[-1] == 'pkl' and os.path.exists(args.dataset):
+        with open(args.dataset, 'rb') as f:
+            pairwise = pickle.load(f)
+    else:
+        pairwise = datasets.get_dataset(args.dataset)
+    pairwise = pairwise[: args.n_items, : args.n_items]
+    args.n_items = pairwise.shape[0] if args.n_items is None else args.n_items
+    print(f"{args.n_items} being embedded")
+
+    # ---------------------------------- Generate the proper objects
+    net = Lorentz(
+        args.n_items, args.poincare_dim + 1
+    ).to(args.device)  # as the paper follows R^(n+1) for this space
+    if args.plot:
+        if args.poincare_dim != 2:
+            print("Only embeddings with `-poincare_dim` = 2 are supported for now.")
+            sys.exit(1)
+        if args.ckpt is None:
+            print("Please provide `-ckpt` when using `-plot`")
+            sys.exit(1)
+        if os.path.isdir(args.ckpt):
+            paths = [
+                os.path.join(args.ckpt, c)
+                for c in os.listdir(args.ckpt)
+                if c.endswith("ckpt")
+            ]
+        else:
+            paths = [args.ckpt]
+        paths = list(sorted(paths))
+        edges = [
+            tuple(edge)
+            for edge in set(
+                # nonzero() works for dense arrays and scipy sparse matrices alike
+                frozenset((a + 1, b + 1))
+                for a, b in zip(*pairwise.nonzero())
+            )
+        ]
+        print(len(edges), "edges")
+        internal_nodes = set(
+            node
+            for node, count in Counter(
+                [node for edge in edges for node in edge]
+            ).items()
+            if count > 1
+        )
+        edges = np.array([edge for edge in edges if edge[1] in internal_nodes])
+        print(len(edges), "edges to internal nodes")
+        for path in tqdm(paths, desc="Plotting"):
+            save_path = f"{path}.svg"
+            if os.path.exists(save_path) and not args.overwrite_plots:
+                continue
+            net.load_state_dict(torch.load(path))
+            table = net.lorentz_to_poincare()
+            # skip padding; plot x, y
+            plt.figure(figsize=(7, 7))
+            if args.plot_graph:
+                for edge in edges:
+                    plt.plot(
+                        table[edge, 0],
+                        table[edge, 1],
+                        color="black",
+                        marker="o",
+                        alpha=0.5,
+                    )
+            else:
+                plt.scatter(table[1:, 0], table[1:, 1])
+            plt.title(path)
+            plt.gca().set_xlim(-1, 1)
+            plt.gca().set_ylim(-1, 1)
+            plt.gca().add_artist(plt.Circle((0, 0), 1, fill=False, edgecolor="black"))
+            plt.savefig(save_path)
+            plt.close()
+        sys.exit(0)
+
+    # note the keywords: Graph takes batch_size before sample_size
+    graph_dataset = Graph(
+        pairwise, batch_size=args.batch_size, sample_size=args.sample_size
+    )
+    dataloader = DataLoader(
+        graph_dataset,
+        shuffle=args.shuffle,
+        batch_size=args.batch_size,
+        num_workers=args.loader_workers,
+    )
+    rsgd = RSGD(net.parameters(), learning_rate=args.learning_rate)
+
+    name = f"{args.dataset} {datetime.utcnow()}"
+    if args.log:
+        writer = SummaryWriter(f"{args.logdir}/{name}")
+
+    with tqdm(ncols=80, mininterval=0.2, total=args.epochs) as epoch_bar:
+        for epoch in range(args.epochs):
+            rsgd.learning_rate = (
+                args.learning_rate / args.burn_c
+                if epoch < args.burn_epochs
+                else args.learning_rate
+            )
+            for I, Ks in tqdm(
+                dataloader,
+                total=int(np.ceil(graph_dataset.n_items / args.batch_size)),
+            ):
+                I = I.to(args.device)
+                Ks = Ks.to(args.device)
+                rsgd.zero_grad()
+                loss = net(I, Ks).mean()
+                loss.backward()
+                rsgd.step()
+            if args.log:
+                writer.add_scalar("loss", loss, epoch)
+                writer.add_scalar(
+                    "recon_performance", recon(net.get_lorentz_table(), pairwise), epoch
+                )
+                writer.add_scalar("table_test", net._test_table(), epoch)
+            if epoch % args.save_step == 0:
+                torch.save(net.state_dict(), f"{args.savedir}/{epoch} {name}.ckpt")
+            epoch_bar.set_description(
+                f"🔥 Burn phase loss: {float(loss)}"
+                if epoch < args.burn_epochs
+                else _moon(loss)
+            )
+            epoch_bar.update(1)
+    torch.save(
+        net.state_dict(),
+        args.dataset.replace(".", "_") + "_call_graph_hyper_embeddings.pt",
+    )
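With all of the above in place, a typical run looks like this (the flag values are only examples; use `-device cpu` when no GPU is available):

    # train on the builtin binary tree
    python main.py binary_tree -epochs 10 -batch_size 32 -device cpu

    # once checkpoints exist under ckpt/, render them to SVG
    python main.py binary_tree -plot -plot_graph -ckpt ckpt -device cpu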