diff --git a/src/algos/complex_nn.py b/src/algos/complex_nn.py
index e69de29..71c0143 100644
--- a/src/algos/complex_nn.py
+++ b/src/algos/complex_nn.py
@@ -0,0 +1,155 @@
+from algos.simba_algo import SimbaDefence
+import torch.nn.functional as F
+import torch
+import torch.nn as nn
+import numpy as np
+from models.complex_models import Discriminator, RealToComplex, ComplexToReal, ResNetEncoderComplex, ResNetDecoderComplex
+
+def get_encoder_output_size(encoder, dims):
+    x = torch.randn((1,) + dims)
+    with torch.no_grad():
+        out = encoder(x)
+    if isinstance(out, tuple):
+        out = out[0]
+    return list(out.size())[1:]
+
+class ComplexNN(SimbaDefence):
+    def __init__(self, config, utils) -> None:
+        super(ComplexNN, self).__init__(utils)
+        self.initialize(config)
+
+    def initialize(self, config):
+        self.optimizer_idx = 0
+        self.encoder_model, self.decoder_model = self.init_client_model(config)
+        img_size = config["img_size"]
+        size = get_encoder_output_size(self.encoder_model, (3, img_size, img_size))
+        self.discriminator = Discriminator(size=size)
+        models = [self.encoder_model, self.decoder_model, self.discriminator]
+        self.put_on_gpus(models)
+
+        self.utils.register_model("encoder_model", self.encoder_model)
+        self.utils.register_model("discriminator_model", self.discriminator)
+        self.utils.register_model("decoder_model", self.decoder_model)
+        self.optim_encoder, self.optim_decoder, self.optim_discriminator = self.init_optim(config, self.encoder_model, self.decoder_model, self.discriminator)
+
+        self.real_to_complex = RealToComplex()
+        self.complex_to_real = ComplexToReal()
+        self.loss_fn = F.cross_entropy
+        self.alpha = config["alpha"]
+        self.k = config["k"]
+
+        self.loss_tag = "decoder_loss"
+        self.acc_tag = "decoder_acc"
+        tags = [self.loss_tag, self.acc_tag]
+        for tag in tags:
+            self.utils.logger.register_tag("train/" + tag)
+            self.utils.logger.register_tag("val/" + tag)
+
+    def put_on_gpus(self, models):
+        for model in models:
+            model = self.utils.model_on_gpus(model)
+
+    def init_client_model(self, config):
+        if config["model_name"] == "resnet20complex":
+            encoder_model = ResNetEncoderComplex(3)
+            decoder_model = ResNetDecoderComplex(3, config["logits"], "alpha")
+        else:
+            print("Can't find complex client model {}".format(config["model_name"]))
+            exit()
+
+        return encoder_model, decoder_model
+
+    def init_optim(self, config, encoder, decoder, discriminator):
+        encoder_parameters = encoder.parameters()
+        decoder_parameters = decoder.parameters()
+
+        if config["optimizer"] == "adam":
+            optimizer_e = torch.optim.Adam(encoder_parameters,
+                                           lr=config["lr"],
+                                           )
+            optimizer_decoder = torch.optim.Adam(decoder_parameters)
+            optimizer_discriminator = torch.optim.Adam(
+                discriminator.parameters(),
+                lr=config["lr"],
+            )
+        else:
+            print("Unknown optimizer {}".format(config["optimizer"]))
+            exit()
+        return optimizer_e, optimizer_decoder, optimizer_discriminator
+
+    def train(self):
+        self.mode = "train"
+        self.encoder_model.train()
+        self.decoder_model.train()
+
+    def eval(self):
+        self.mode = "val"
+        self.encoder_model.eval()
+        self.decoder_model.eval()
+
+    def forward(self, items):
+        inp = items["x"]
+        # Pass through encoder
+        a = self.encoder_model(inp)
+        self.a = a
+        # Shuffle batch elements of a to create b
+        with torch.no_grad():
+            indices = np.random.permutation(a.size(0))
+            b = a[indices]
+
+        self.z, self.theta = self.real_to_complex(a, b)
+
+        # Get discriminator score expectation over k rotations
+        self.score_fake = 0
+        for k in range(self.k):
+            # Shuffle batch to get b
+            indices = np.random.permutation(a.size(0))
+            b = a[indices]
+
+            # Rotate a
+            x, _ = self.real_to_complex(a, b)
+            a_rotated = x[:, 0]
+            # Get discriminator score
+            self.score_fake += self.discriminator(a_rotated)
+
+        self.score_fake /= self.k # Average score
+        z = self.z.detach()
+        z.requires_grad = True
+        return z
+
+    def infer(self, h, labels):
+        h.retain_grad()
+        y = self.complex_to_real(h, self.theta)
+        y.retain_grad()
+        self.preds = self.decoder_model(y)
+        self.acc = (self.preds.argmax(dim=1) == labels).sum().item() / self.preds.shape[0]
+        self.utils.logger.add_entry(self.mode + "/" + self.acc_tag, self.acc)
+        if self.optimizer_idx % 2 == 0:
+            g_loss_adv = -torch.mean(self.score_fake)
+            g_loss_ce = self.loss_fn(self.preds, labels)
+            loss = g_loss_adv + g_loss_ce
+            self.optim_decoder.zero_grad()
+            loss.backward(retain_graph=True)
+            self.optim_decoder.step()
+            self.utils.logger.add_entry(self.mode + "/" + self.loss_tag, loss.item())
+            return h.grad
+        else:
+            # WGAN-style critic update with weight clipping
+            for p in self.discriminator.parameters():
+                p.data.clamp_(-0.01, 0.01)
+            self.d_loss_adv = -torch.mean(self.discriminator(self.a)) + torch.mean(self.score_fake)
+            self.optim_discriminator.zero_grad()
+            self.d_loss_adv.backward()
+            self.optim_discriminator.step()
+            return None
+
+    def backward(self, items):
+        if self.optimizer_idx % 2 == 0:
+            self.optim_encoder.zero_grad()
+            self.z.backward(items["server_grads"])
+            self.optim_encoder.step()
+        self.optimizer_idx += 1
+
+
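For reference, a minimal round-trip sketch of the rotation encoding that ComplexNN
relies on (RealToComplex and ComplexToReal come from models/complex_models.py in
this patch; the shapes and data are illustrative only):

    import torch
    from models.complex_models import RealToComplex, ComplexToReal

    a = torch.randn(4, 16, 8, 8)            # activations to protect
    b = a[torch.randperm(a.size(0))]        # shuffled fooling counterpart
    x, theta = RealToComplex()(a, b)        # x: [4,2,16,8,8], theta: [4]
    y = ComplexToReal()(x, theta)           # rotate back by -theta
    assert torch.allclose(y, a, atol=1e-5)  # a is recovered up to float error
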
diff --git a/src/algos/gaussian_blur.py b/src/algos/gaussian_blur.py
index 63d0250..9c5fa6a 100644
--- a/src/algos/gaussian_blur.py
+++ b/src/algos/gaussian_blur.py
@@ -21,7 +21,7 @@ class GaussianSmoothing(nn.Module):
     def __init__(self, channels, kernel_size, sigma, device, dim=2):
         super(GaussianSmoothing, self).__init__()
         if isinstance(kernel_size, numbers.Number):
-            kernel_size = [kernel_size] * dim
+            kernel_size = [int(kernel_size * math.sqrt(sigma))] * dim  # keep an integer size after scaling by sqrt(sigma)
         if isinstance(sigma, numbers.Number):
             sigma = [sigma] * dim
diff --git a/src/algos/maxentropy.py b/src/algos/maxentropy.py
index 1b48a5f..bfc29ea 100644
--- a/src/algos/maxentropy.py
+++ b/src/algos/maxentropy.py
@@ -1,3 +1,4 @@
+from algos.simba_algo import SimbaDefence
 from algos.deepobfuscator import DeepObfuscator
 from utils.metrics import MetricLoader
 
@@ -12,6 +13,33 @@ def forward(self, input):
             raise Exception('Entropy Loss takes probabilities 0<=input<=1')
 
         input = input + 1e-16 # for numerical stability while taking log
         H = torch.mean(torch.sum(input * torch.log(input), dim=1))
 
         return H
+
+class MaxEntropy(SimbaDefence):
+    def __init__(self, config, utils) -> None:
+        super(MaxEntropy, self).__init__(utils)
+        self.initialize(config)
+
+    def initialize(self, config):
+        self.client_model = self.init_client_model(config)
+        self.put_on_gpus()
+        self.utils.register_model("client_model", self.client_model)
+        self.client_optim = self.init_optim(config, self.client_model)
+        self.entropy_loss_fn = EntropyLoss()
+
+    def forward(self, items):
+        x = items["x"]
+        self.z = self.client_model(x)
+        z = self.z.detach()
+        z.requires_grad = True
+        return z
+
+    def backward(self, items):
+        self.client_optim.zero_grad()
+        entropy_loss = self.entropy_loss_fn(items["pred_lbls"])
+        entropy_loss.requires_grad = True  # labels carry no grad; avoids a backward() error
+        entropy_loss.backward()
+        self.z.backward(items["server_grads"])
+        self.client_optim.step()
@@ -31,4 +59,4 @@ def 
update_loss(self): def get_adv_loss(self): # Since it is L1, it has to be minimized - return self.adv_loss \ No newline at end of file + return self.adv_loss diff --git a/src/algos/simba_algo.py b/src/algos/simba_algo.py index 71ec47f..6479a17 100644 --- a/src/algos/simba_algo.py +++ b/src/algos/simba_algo.py @@ -64,6 +64,9 @@ def init_optim(self, config, model): def put_on_gpus(self): self.client_model = self.utils.model_on_gpus(self.client_model) + def infer(self,data,labels): + pass + class SimbaAttack(nn.Module): def __init__(self, utils): @@ -101,6 +104,3 @@ def train(self): def eval(self): self.mode = "val" self.model.eval() - - - diff --git a/src/algos/supervised_decoder.py b/src/algos/supervised_decoder.py index 424db5a..b100365 100644 --- a/src/algos/supervised_decoder.py +++ b/src/algos/supervised_decoder.py @@ -53,6 +53,10 @@ def forward(self, items): self.loss = self.loss_fn(self.x, x) self.utils.logger.add_entry(self.mode + "/" + self.loss_tag, self.loss.item()) + + return self.x + + def backward(self, items): if self.mode == "val" and self.attribute == "data": prefix = "val/" diff --git a/src/configs/complex_nn.json b/src/configs/complex_nn.json new file mode 100644 index 0000000..6a41c1d --- /dev/null +++ b/src/configs/complex_nn.json @@ -0,0 +1,20 @@ +{ + "experiment_type": "challenge", + "method": "complex_nn", + "client": {"model_name": "resnet20complex", "split_layer": 6, + "pretrained": false,"logits": 2, "optimizer": "adam", "lr": 3e-4, + "alpha": 0.99, "k":5, "img_size":32}, + "server": {"model_name": "resnet20complex", "split_layer":6, "logits": 2, "pretrained": false, + "lr": 3e-4, "optimizer": "adam", "momentum": 0.99}, + "learning_rate": 0.1, + "total_epochs": 150, + "training_batch_size": 128, + "dataset": "fairface", + "protected_attribute": "data", + "prediction_attribute": "gender", + "img_size": 32, + "split": false, + "test_batch_size": 32, + "exp_id": "1", + "exp_keys": ["client.alpha","client.optimizer"] +} diff --git a/src/configs/decoder_attack.json b/src/configs/decoder_attack.json index adc63a2..059b091 100644 --- a/src/configs/decoder_attack.json +++ b/src/configs/decoder_attack.json @@ -14,4 +14,4 @@ "train_split": 0.9, "test_batch_size": 64, "exp_keys": ["train_split", "adversary.loss_fn"] -} \ No newline at end of file +} diff --git a/src/configs/deep_obfuscator.json b/src/configs/deep_obfuscator.json index 465967b..dcb4d32 100644 --- a/src/configs/deep_obfuscator.json +++ b/src/configs/deep_obfuscator.json @@ -17,4 +17,4 @@ "test_batch_size": 64, "exp_id": "1", "exp_keys": ["client.alpha"] -} \ No newline at end of file +} diff --git a/src/configs/maxentropy.json b/src/configs/maxentropy.json new file mode 100644 index 0000000..bbaccaa --- /dev/null +++ b/src/configs/maxentropy.json @@ -0,0 +1,18 @@ +{ + "method": "maxentropy", + "client": {"model_name": "resnet18", "split_layer": 6, + "pretrained": false, "optimizer": "adam", "lr": 3e-4}, + "server": {"model_name": "resnet18", "split_layer":6, "logits": 2, "pretrained": false, + "lr": 3e-4, "optimizer": "adam"}, + "learning_rate": 0.01, + "total_epochs": 150, + "training_batch_size": 256, + "dataset": "fairface", + "protected_attribute": "data", + "prediction_attribute": "gender", + "img_size": 128, + "split": false, + "test_batch_size": 64, + "exp_id": "1", + "exp_keys": ["client.optimizer"] +} \ No newline at end of file diff --git a/src/configs/nopeek.json b/src/configs/nopeek.json index 72ac56d..620bc88 100644 --- a/src/configs/nopeek.json +++ b/src/configs/nopeek.json @@ -16,4 +16,4 @@ 
"test_batch_size": 64, "exp_id": "1", "exp_keys": ["client.alpha"] -} \ No newline at end of file +} diff --git a/src/configs/pan.json b/src/configs/pan.json index 28c402d..380b76b 100644 --- a/src/configs/pan.json +++ b/src/configs/pan.json @@ -1,5 +1,4 @@ { - "experiment_type": "challenge", "method": "pan", "client": {"model_name": "resnet18", "split_layer": 6, "pretrained": false, "optimizer": "adam", "lr": 3e-4, @@ -18,4 +17,4 @@ "test_batch_size": 64, "exp_id": "1", "exp_keys": ["client.alpha"] -} \ No newline at end of file +} diff --git a/src/configs/siamese_embedding.json b/src/configs/siamese_embedding.json index 60ee109..1d35bc2 100644 --- a/src/configs/siamese_embedding.json +++ b/src/configs/siamese_embedding.json @@ -16,4 +16,4 @@ "test_batch_size": 64, "exp_id": "1", "exp_keys": ["client.alpha", "client.margin"] -} \ No newline at end of file +} diff --git a/src/configs/uniform_noise.json b/src/configs/uniform_noise.json index c445293..77962c2 100644 --- a/src/configs/uniform_noise.json +++ b/src/configs/uniform_noise.json @@ -16,4 +16,4 @@ "test_batch_size": 64, "exp_id": "1", "exp_keys": ["client.distribution", "client.mean", "client.sigma"] -} \ No newline at end of file +} diff --git a/src/data/loaders.py b/src/data/loaders.py index ca175a9..9a2e9d9 100644 --- a/src/data/loaders.py +++ b/src/data/loaders.py @@ -1,6 +1,7 @@ import numpy as np import torch from torchvision import transforms +from data.dataset_utils import FairFace,Cifar10, CelebA, Cifar10_2, LFW#, UTKFace from data.dataset_utils import FairFace, CelebA, Cifar10, LFW#, UTKFace, Cifar10_2 from data.dataset_utils import Challenge diff --git a/src/interface.py b/src/interface.py index 416f3f5..95d34fd 100644 --- a/src/interface.py +++ b/src/interface.py @@ -5,9 +5,10 @@ from algos.pca_embedding import PCAEmbedding from algos.deepobfuscator import DeepObfuscator from algos.pan import PAN +from algos.complex_nn import ComplexNN from algos.gaussian_blur import GaussianBlur from algos.linear_correlation import LinearCorrelation - +from algos.maxentropy import MaxEntropy from algos.supervised_decoder import SupervisedDecoder from algos.cloak import Cloak from algos.shredder import Shredder @@ -46,6 +47,8 @@ def load_algo(config, utils, dataloader=None): algo = UniformNoise(config["client"], utils) elif method == "siamese_embedding": algo = SiameseEmbedding(config["client"], utils) + elif method == "complex_nn": + algo = ComplexNN(config["client"], utils) elif method == "pca_embedding": algo = PCAEmbedding(config["client"], utils) elif method == "deep_obfuscator": @@ -62,6 +65,8 @@ def load_algo(config, utils, dataloader=None): algo = GaussianBlur(config["client"], utils) elif method == "linear_correlation": algo = LinearCorrelation(config["client"], utils) + elif method == "maxentropy": + algo = MaxEntropy(config["client"], utils) elif method == "supervised_decoder": item = next(iter(dataloader)) z = item["z"] diff --git a/src/main.py b/src/main.py index 933b3da..b382ad5 100644 --- a/src/main.py +++ b/src/main.py @@ -14,4 +14,3 @@ scheduler = Scheduler(args) scheduler.run_job() - diff --git a/src/models/complex_models.py b/src/models/complex_models.py new file mode 100644 index 0000000..ff22833 --- /dev/null +++ b/src/models/complex_models.py @@ -0,0 +1,526 @@ +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + +def get_real_imag_parts(x): + ''' + Extracts the real and imaginary component tensors from a complex number tensor + + Input: + x: Complex number tensor of size 
[b,2,c,h,w]
+    Output:
+        real component tensor of size [b,c,h,w]
+        imaginary component tensor of size [b,c,h,w]
+    '''
+    assert(x.size(1) == 2) # Complex tensor has real and imaginary components in 2nd dim
+    return x[:,0], x[:,1]
+
+def complex_norm(x):
+    '''
+    Calculates the complex norm for each complex element in a tensor
+
+    Input:
+        x: Complex number tensor of size [b,2,c,h,w]
+    Output:
+        tensor of norm values of size [b,c,h,w]
+    '''
+    assert(x.size(1) == 2) # Complex tensor has real and imaginary components in 2nd dim
+    x_real, x_imag = get_real_imag_parts(x)
+    return torch.sqrt(torch.pow(x_real, 2) + torch.pow(x_imag, 2) + 1e-5)
+
+class RealToComplex(nn.Module):
+    '''
+    Converts a real value tensor a into a complex value tensor x (Eq. 2).
+    Adds a fooling counterpart b and rotates the tensor by a random angle theta.
+    Returns theta for later decoding.
+
+    Shape:
+        Input:
+            a: [b,c,h,w]
+            b: [b,c,h,w]
+        Output:
+            x: [b,2,c,h,w]
+            theta: [b]
+    '''
+    def __init__(self):
+        super(RealToComplex, self).__init__()
+
+    def forward(self, a, b):
+        # Randomly choose a theta for each batch element
+        theta = a.new(a.size(0)).uniform_(0, 2*np.pi)
+
+        # Convert to complex and rotate by theta
+        real = a*torch.cos(theta)[:, None, None, None] - \
+               b*torch.sin(theta)[:, None, None, None]
+        imag = b*torch.cos(theta)[:, None, None, None] + \
+               a*torch.sin(theta)[:, None, None, None]
+        x = torch.stack((real, imag), dim=1)
+
+        return x, theta
+
+class ComplexToReal(nn.Module):
+    '''
+    Decodes a complex value tensor h into a real value tensor y by rotating
+    by -theta (Eq. 3).
+
+    Shape:
+        Input:
+            h: [b,2,c,h,w]
+            theta: [b]
+        Output: [b,c,h,w]
+    '''
+    def __init__(self):
+        super(ComplexToReal, self).__init__()
+
+    def forward(self, h, theta):
+        # Apply opposite rotation to decode
+        a, b = get_real_imag_parts(h)
+        if a.dim() == 4:
+            y = a*torch.cos(-theta)[:, None, None, None] - \
+                b*torch.sin(-theta)[:, None, None, None] # Only need real component
+        else:
+            y = a*torch.cos(-theta)[:, None] - \
+                b*torch.sin(-theta)[:, None] # Only need real component
+        return y
+
+class ActivationComplex(nn.Module):
+    '''
+    Complex activation function from Eq. 6.
+
+    Args:
+        c: Positive constant (>0) from Eq. 6. Default: 1
+    Shape:
+        Input: [b,2,c,h,w]
+        Output: [b,2,c,h,w]
+    '''
+    def __init__(self, c=1):
+        super(ActivationComplex, self).__init__()
+        assert(c > 0)
+        self.c = torch.Tensor([c])
+
+    def forward(self, x):
+        x_norm = complex_norm(x).unsqueeze(1)
+        c = self.c.to(x.device)
+        scale = x_norm/torch.maximum(x_norm, c)
+        return x*scale
+
+def activation_complex(x, c):
+    '''
+    Complex activation function from Eq. 6. This is a functional api to
+    use in networks that don't have a static c value (AlexNet, LeNet, etc.).
+
+    Input:
+        x: Complex number tensor of size [b,2,c,h,w]
+        c: Positive constant (>0) from Eq. 6.
+    Output:
+        output tensor of size [b,2,c,h,w]
+    '''
+    assert(c > 0)
+    x_norm = complex_norm(x).unsqueeze(1)
+    c = torch.Tensor([c]).to(x.device)
+    scale = x_norm/torch.maximum(x_norm, c)
+    return x*scale
+
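+# Worked example of Eq. 6: with c=1, a feature of norm 0.5 is multiplied by
+# 0.5/max(0.5, 1) = 0.5 (weak activations are damped), while a feature of
+# norm 2 is multiplied by 2/max(2, 1) = 1 and passes through unchanged.
+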
+def activation_complex_dynamic(x):
+    '''
+    Complex activation function from Eq. 6 with a dynamic threshold. This is a
+    functional api to use in networks that don't have a static c value
+    (AlexNet, LeNet, etc.); the mean norm of the input stands in for c.
+
+    Input:
+        x: Complex number tensor of size [b,2,c,h,w] or [b,2,f]
+    Output:
+        output tensor of size [b,2,c,h,w] or [b,2,f]
+    '''
+    x_norm = complex_norm(x)
+    if x.dim() == 5:
+        # for [b,2,c,h,w] inputs
+        scale = x_norm.unsqueeze(1)/torch.maximum(x_norm.unsqueeze(1),
+                    x_norm.mean((2, 3))[:, :, None, None].unsqueeze(1))
+    else:
+        # for [b,2,f] inputs
+        scale = x_norm.unsqueeze(1)/torch.maximum(x_norm.unsqueeze(1),
+                    x_norm.mean(1)[:, None, None])
+
+    return x*scale
+
+class MaxPool2dComplex(nn.Module):
+    '''
+    Complex max pooling operation. Keeps the complex number feature with the maximum norm within
+    the window, keeping both the corresponding real and imaginary components.
+
+    Args:
+        kernel_size: size of the window
+        stride: stride of the window. Default: kernel_size
+        padding: amount of zero padding. Default: 0
+        dilation: element-wise stride in the window. Default: 1
+        ceil_mode: use ceil instead of floor to compute the output shape. Default: False
+    Shape:
+        Input: [b,2,c,h_in,w_in]
+        Output: [b,2,c,h_out,w_out]
+    '''
+    def __init__(
+        self,
+        kernel_size,
+        stride=None,
+        padding=0,
+        dilation=1,
+        ceil_mode=False
+    ):
+        super(MaxPool2dComplex, self).__init__()
+        self.pool = nn.MaxPool2d(
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            ceil_mode=ceil_mode,
+            return_indices=True
+        )
+
+    def get_indice_elements(self, x, indices):
+        ''' From: https://discuss.pytorch.org/t/pooling-using-idices-from-another-max-pooling/37209/4 '''
+        x_flat = x.flatten(start_dim=2)
+        output = x_flat.gather(dim=2, index=indices.flatten(start_dim=2)).view_as(indices)
+        return output
+
+    def forward(self, x):
+        x_real, x_imag = get_real_imag_parts(x)
+        x_norm = complex_norm(x)
+
+        # Max pool complex feature norms and get indices
+        _, indices = self.pool(x_norm)
+
+        # Extract the matching real and imaginary components of the max pooling indices
+        x_real = self.get_indice_elements(x_real, indices)
+        x_imag = self.get_indice_elements(x_imag, indices)
+
+        return torch.stack((x_real, x_imag), dim=1)
+
+class DropoutComplex(nn.Module):
+    '''
+    Complex dropout operation. Randomly zero out both the real and imaginary
+    components of a complex number feature.
+
+    Args:
+        p: probability of an element being zeroed
+    Shape:
+        Input: [b,2,c,h,w]
+        Output: [b,2,c,h,w]
+    '''
+    def __init__(self, p):
+        super(DropoutComplex, self).__init__()
+        self.p = p
+        self.dropout = nn.Dropout(p)
+
+    def forward(self, x):
+        x_real, x_imag = get_real_imag_parts(x)
+
+        # Apply dropout to real part
+        x_real = self.dropout(x_real)
+
+        # Drop the same indices in the imaginary part
+        # and scale the rest by 1/(1-p)
+        if self.training:
+            mask = (x_real != 0).float()*(1/(1-self.p))
+            x_imag *= mask
+
+        return torch.stack((x_real, x_imag), dim=1)
+
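+# Usage sketch for the layers above (shapes follow the [b,2,c,h,w] convention
+# of this file; the numbers are illustrative only):
+#   x = torch.randn(8, 2, 16, 32, 32)        # batch of complex activations
+#   x = MaxPool2dComplex(kernel_size=2)(x)   # -> [8,2,16,16,16], pooled by norm
+#   x = DropoutComplex(p=0.5)(x)             # zeroes matching real/imag pairs
+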
+class LinearComplex(nn.Module):
+    '''
+    Complex linear layer. The bias terms are removed so that the layer
+    leaves the phase of its input invariant.
+
+    Args:
+        in_features: number of features of the input
+        out_features: number of channels of the produced output
+    Shape:
+        Input: [b,2,in_features]
+        Output: [b,2,out_features]
+    '''
+    def __init__(self, in_features, out_features):
+        super(LinearComplex, self).__init__()
+
+        self.linear_real = nn.Linear(in_features, out_features, bias=False)
+        self.linear_imag = nn.Linear(in_features, out_features, bias=False)
+
+    def forward(self, x):
+        x_real, x_imag = get_real_imag_parts(x)
+        out_real = self.linear_real(x_real) - self.linear_imag(x_imag)
+        out_imag = self.linear_real(x_imag) + self.linear_imag(x_real)
+        return torch.stack((out_real, out_imag), dim=1)
+
+class Conv2dComplex(nn.Module):
+    '''
+    Complex 2d convolution operation. Implements the complex convolution from
+    https://arxiv.org/abs/1705.09792 (Section 3.2) and removes the bias term
+    to preserve phase.
+
+    Args:
+        in_channels: number of channels in the input
+        out_channels: number of channels produced in the output
+        kernel_size: size of convolution window
+        stride: stride of convolution. Default: 1
+        padding: amount of zero padding. Default: 0
+        padding_mode: 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'
+        groups: number of blocked connections from input to output channels. Default: 1
+    Shape:
+        Input: [b,2,in_channels,h_in,w_in]
+        Output: [b,2,out_channels,h_out,w_out]
+    '''
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        padding_mode='zeros',
+        groups=1,
+    ):
+        super(Conv2dComplex, self).__init__()
+
+        self.in_channels = in_channels
+
+        self.conv_real = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            padding_mode=padding_mode,
+            groups=groups,
+            bias=False # Bias always false
+        )
+        self.conv_imag = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            padding_mode=padding_mode,
+            groups=groups,
+            bias=False # Bias always false
+        )
+
+    def forward(self, x):
+        x_real, x_imag = get_real_imag_parts(x)
+        out_real = self.conv_real(x_real) - self.conv_imag(x_imag)
+        out_imag = self.conv_real(x_imag) + self.conv_imag(x_real)
+        return torch.stack((out_real, out_imag), dim=1)
+
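+# The paired real/imaginary ops above realize complex multiplication:
+# (a + ib)(W_r + iW_i) = (a*W_r - b*W_i) + i(a*W_i + b*W_r),
+# which is exactly the out_real / out_imag computed in LinearComplex.forward
+# and Conv2dComplex.forward.
+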
+class BatchNormComplex(nn.Module):
+    '''
+    Complex batch normalization from Eq. 7. Code adapted from
+    https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py#L39
+
+    Args:
+        momentum: exponential averaging momentum term for running mean.
+                  Set to None for simple average. Default: 0.1
+        track_running_stats: track the running mean for evaluation mode.
+                             Default: False
+    Shape:
+        Input: [b,2,c,h,w]
+        Output: [b,2,c,h,w]
+    '''
+    def __init__(
+        self,
+        momentum=0.1,
+        track_running_stats=False
+    ):
+        super(BatchNormComplex, self).__init__()
+
+        self.track_running_stats = track_running_stats
+        self.num_batches_tracked = 0
+        self.momentum = momentum
+        self.running_mean = 0
+
+    def forward(self, x):
+        if not self.track_running_stats:
+            # Calculate mean of complex norm
+            x_norm = torch.pow(complex_norm(x), 2)
+            mean = x_norm.mean([0])
+        else:
+            # Setup exponential factor
+            ema_factor = 0.0
+            if self.training and self.track_running_stats:
+                if self.num_batches_tracked is not None:
+                    self.num_batches_tracked += 1
+                    if self.momentum is None:
+                        ema_factor = 1.0/float(self.num_batches_tracked) # Cumulative moving average
+                    else:
+                        ema_factor = self.momentum
+
+            # Calculate mean of complex norm
+            if self.training:
+                x_norm = torch.pow(complex_norm(x), 2)
+                mean = x_norm.mean([0])
+                with torch.no_grad():
+                    self.running_mean = ema_factor * mean + (1-ema_factor) * self.running_mean
+            else:
+                if isinstance(self.running_mean, int):
+                    # No statistics tracked yet; fall back to zeros
+                    mean = x.new_zeros(x.size(2), x.size(3), x.size(4))
+                else:
+                    mean = self.running_mean
+
+        # Normalize (out of place, so autograd keeps the pre-normalized x)
+        x = x / torch.sqrt(mean[None, None, :, :, :] + 1e-5)
+        return x
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels, downsample):
+        super(ResidualBlock, self).__init__()
+
+        self.downsample = downsample
+        self.channels = channels
+
+        self.network = nn.Sequential(
+            nn.Conv2d(channels // 2 if downsample else channels, channels, kernel_size=3, padding=1,
+                      stride=2 if downsample else 1, bias=False),
+            nn.BatchNorm2d(channels),
+            nn.ReLU(),
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1, bias=False),
+            nn.BatchNorm2d(channels)
+        )
+
+    def forward(self, x):
+        if self.downsample:
+            out = self.network(x) + F.pad(x[..., ::2, ::2], (0, 0, 0, 0, self.channels // 4, self.channels // 4))
+        else:
+            out = self.network(x) + x
+
+        return F.relu(out)
+
+
+class ResidualBlockComplex(nn.Module):
+    def __init__(self, channels, downsample):
+        super(ResidualBlockComplex, self).__init__()
+
+        self.downsample = downsample
+        self.channels = channels
+        self.network = nn.Sequential(
+            Conv2dComplex(channels // 2 if downsample else channels, channels, kernel_size=3, padding=1,
+                          stride=2 if downsample else 1),
+            BatchNormComplex(),
+            ActivationComplex(),
+            Conv2dComplex(channels, channels, kernel_size=3, padding=1, stride=1),
+            BatchNormComplex()
+        )
+
+    def forward(self, x):
+        if self.downsample:
+            out = self.network(x) + F.pad(x[..., ::2, ::2], (0, 0, 0, 0, self.channels // 4, self.channels // 4))
+        else:
+            out = self.network(x) + x
+
+        return activation_complex(out, 1)
+
+class ResNetDecoderComplex(nn.Module):
+    def __init__(self, n, num_classes, variant="alpha"):
+        super(ResNetDecoderComplex, self).__init__()
+
+        conv_layers = [ResidualBlock(64, False)]
+        if variant == "alpha":
+            for i in range(n - 2):
+                conv_layers.append(ResidualBlock(64, False))
+
+        self.conv_layers = nn.Sequential(*conv_layers)
+        self.linear = nn.Linear(64, num_classes)
+
+    def forward(self, x):
+        out = self.conv_layers(x)
+        out = out.mean([2, 3]) # global average pooling
+        return self.linear(torch.flatten(out, 1))
+
+
+class ResNetEncoderComplex(nn.Module):
+    def __init__(self, n, additional_layers=True):
+        super(ResNetEncoderComplex, self).__init__()
+
+        network = [
+            nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1, bias=False),
+            nn.BatchNorm2d(16),
+            nn.ReLU()
+        ]
+
+        for i in range(n):
+            network.append(ResidualBlock(16, False))
+
+        if 
additional_layers: + network += [ + nn.Conv2d(16, 16, kernel_size=3, padding=1, stride=1), + nn.ReLU(True) + ] + + self.network = nn.Sequential(*network) + + def forward(self, x): + return self.network(x) + +class ResNetProcessorComplex(nn.Module): + def __init__(self, n, variant="alpha"): + super(ResNetProcessorComplex, self).__init__() + + network = [ResidualBlockComplex(32, True)] + for i in range(n - 1): + network.append(ResidualBlockComplex(32, False)) + network.append(ResidualBlockComplex(64, True)) + + if variant == "beta": + for i in range(n - 2): + network.append(ResidualBlockComplex(64, False)) + + self.network = nn.Sequential(*network) + + def forward(self, x): + return self.network(x) + +class Discriminator(nn.Module): + ''' + Adversarial discriminator network. + Args: + size: List of input shape [c,h,w] + Shape: + Input: [b,c,h,w] + Output: [b,1] + ''' + def __init__(self, size): + super(Discriminator, self).__init__() + in_channels = size[0] + + self.net = nn.Sequential( + nn.Conv2d(in_channels, in_channels*2, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(in_channels*2), + nn.ReLU(True), + nn.Flatten(), + nn.Linear(2*size[0]*size[1]//2*size[2]//2, 1) + ) + + def forward(self, x): + x = self.net(x) + return x + + diff --git a/src/models/image_decoder.py b/src/models/image_decoder.py index 5a9375a..b0066f1 100644 --- a/src/models/image_decoder.py +++ b/src/models/image_decoder.py @@ -64,7 +64,6 @@ def __init__(self, config): input_nc = config["channels"] img_size = config["img_size"] patch_size = config["patch_size"] - # Because usually images get downsampled by a factor of 2 feat_map_size = img_size feat_map_list = [] @@ -77,7 +76,6 @@ def __init__(self, config): patch_channels.append(patch_channels[-1] * 2) # remove the last element to ensure the closest match is smallest patch_channels.pop() - # Find closest feature map index = closest(feat_map_list, patch_size) index_channel = closest(patch_channels, input_nc) diff --git a/src/models/model_zoo.py b/src/models/model_zoo.py index 67b7474..77b756b 100644 --- a/src/models/model_zoo.py +++ b/src/models/model_zoo.py @@ -1,7 +1,8 @@ import torch import torchvision.models as models import torch.nn as nn - +import torch.nn.functional as F +from models.complex_models import ResNetProcessorComplex from models.Unet import StochasticUNet from models.Xception import Xception @@ -10,13 +11,14 @@ def __init__(self, config, utils): super(Model, self).__init__() self.loss_fn = nn.CrossEntropyLoss() self.utils = utils - self.loss_tag = "server_loss" - self.acc_tag = "server_acc" - self.utils.logger.register_tag("train/" + self.loss_tag) - self.utils.logger.register_tag("val/" + self.loss_tag) - self.utils.logger.register_tag("train/" + self.acc_tag) - self.utils.logger.register_tag("val/" + self.acc_tag) - + if config["model_name"] != "resnet20complex": + self.loss_tag = "server_loss" + self.acc_tag = "server_acc" + self.utils.logger.register_tag("train/" + self.loss_tag) + self.utils.logger.register_tag("val/" + self.loss_tag) + self.utils.logger.register_tag("train/" + self.acc_tag) + self.utils.logger.register_tag("val/" + self.acc_tag) + self.config = config self.assign_model(config) self.assign_optim(config) @@ -63,6 +65,8 @@ def assign_model(self, config): nn.Linear(num_ftrs, logits)) model = nn.ModuleList(list(model.children())[self.split_layer:]) self.model = nn.Sequential(*model) + elif config["model_name"] == "resnet20complex": + self.model = ResNetProcessorComplex(3,'alpha') if config["model_name"].startswith("vgg"): 
num_ftrs = model.classifier[0].in_features @@ -71,9 +75,6 @@ def assign_model(self, config): model = nn.ModuleList(list(model.children())[self.split_layer:]) self.model = nn.Sequential(*model) - - - self.model = self.utils.model_on_gpus(self.model) self.utils.register_model("server_model", self.model) @@ -82,28 +83,48 @@ def assign_optim(self, config): if config["optimizer"] == "adam": self.optim = torch.optim.Adam(self.model.parameters(), lr) - def forward(self, x): - x = self.model(x) - return nn.functional.softmax(x, dim=1) + def forward(self, z): + self.z = z + self.z.retain_grad() + if self.config["model_name"] == "resnet20complex": + self.h = self.model(self.z) + h = self.h.detach() + h.requires_grad = True + return h + else: + x = self.model(self.z) + self.preds = nn.functional.softmax(x, dim=1) + return self.preds def compute_loss(self, preds, y): - self.loss = self.loss_fn(preds, y) - self.utils.logger.add_entry(self.mode + "/" + self.loss_tag, - self.loss.item()) - self.utils.logger.add_entry(self.mode + "/" + self.acc_tag, - (preds, y), "acc") + if self.config["model_name"] != "resnet20complex": + self.loss = self.loss_fn(preds, y) + self.utils.logger.add_entry(self.mode + "/" + self.loss_tag, + self.loss.item()) + self.utils.logger.add_entry(self.mode + "/" + self.acc_tag, + (preds, y), "acc") def optimize(self): - self.optim.zero_grad() - self.loss.backward() - self.optim.step() + if self.config["model_name"] == "resnet20complex": + self.optim.zero_grad() + self.h.backward(self.decoder_grads) + self.optim.step() + else: + self.optim.zero_grad() + self.loss.backward() + self.optim.step() + + def backward(self,y,decoder_grads=None): + if decoder_grads != None: + self.decoder_grads = decoder_grads + self.optimize() + if self.config["model_name"] != "resnet20complex": + self.compute_loss(self.preds, y) + self.optimize() + return self.z.grad def processing(self, z, y): z.retain_grad() preds = self.forward(z) self.compute_loss(preds, y) self.optimize() - return z.grad - - - diff --git a/src/scheduler.py b/src/scheduler.py index cb1a219..6091b01 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -57,7 +57,10 @@ def defense_train(self) -> None: for _, sample in enumerate(self.dataloader.train): items = self.utils.get_data(sample) z = self.algo.forward(items) - items["server_grads"] = self.model.processing(z, items["pred_lbls"]) + # self.utils.save_images((items["x"],z),items["filename"]) + data = self.model.forward(z) + items["decoder_grads"] = self.algo.infer(data,items["pred_lbls"]) + items["server_grads"] = self.model.backward(items["pred_lbls"],items["decoder_grads"]) self.algo.backward(items) def defense_test(self) -> None: @@ -66,20 +69,24 @@ def defense_test(self) -> None: for _, sample in enumerate(self.dataloader.test): items = self.utils.get_data(sample) z = self.algo.forward(items) - self.model.processing(z, items["pred_lbls"]) + # self.utils.save_images((items["x"],z),items["filename"]) + data = self.model.forward(z) + self.algo.infer(data,items["pred_lbls"]) + self.model.compute_loss(data,items["pred_lbls"]) def attack_train(self) -> None: self.algo.train() - for _, sample in enumerate(self.dataloader.train): + for _, sample in enumerate(self.dataloader.test): items = self.utils.get_data(sample) z = self.algo.forward(items) self.algo.backward(items) def attack_test(self): self.algo.eval() - for _, sample in enumerate(self.dataloader.train): + for _, sample in enumerate(self.dataloader.test): items = self.utils.get_data(sample) z = self.algo.forward(items) + 
self.utils.save_images((items["x"],z),items["filename"]) def epoch_summary(self): self.utils.logger.flush_epoch() diff --git a/src/utils/utils.py b/src/utils/utils.py index d9b9e0f..ea48e83 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -6,7 +6,7 @@ from shutil import copytree, copy2 from glob import glob import shutil - +from torchvision.utils import save_image class Utils(): def __init__(self, config) -> None: @@ -43,7 +43,7 @@ def get_data(self, sample): items["x"] = Variable(sample["img"]).to(self.device) items["pred_lbls"] = Variable(sample["prediction_label"]).to(self.device) items["prvt_lbls"] = Variable(sample["private_label"]).to(self.device) - items["filename"] = sample["filename"] + items["filename"] = sample["filename"] return items def copy_source_code(self, path): @@ -90,7 +90,22 @@ def save_data(self, z, filename, challenge_dir): for ele in range(int(z.shape[0])): z_path = challenge_dir + filename[ele] + '.pt' torch.save(z[ele].detach().cpu(), z_path) - + + def save_images(self,x_and_z,filename): + x,z = x_and_z + filepath = self.config["log_path"] + "rec_images/" + if not os.path.isdir(filepath): + os.mkdir(filepath) + filename = [name.split("/")[-1].split('.')[0] for name in filename] + for ele in range(int(z.shape[0])): + path = filepath + filename[ele] + "/" + if not os.path.isdir(path): + os.mkdir(path) + z_path = path + filename[ele] + "_rec.jpg" + x_path = path + filename[ele] + "_orig.jpg" + save_image(z[ele],z_path) + save_image(x[ele],x_path) + def check_path_status(self, path): """experiment_path = None if auto: # This is to not duplicate work already done and to continue running experiments