From f34238d92d85a392924488b5449f54d17c9d6396 Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Mon, 10 May 2021 15:09:58 +0100
Subject: [PATCH 01/15] [nn] Added ability to gather activation statistics during LUT inference.

---
 src/logicnets/nn.py | 49 ++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 44 insertions(+), 5 deletions(-)

diff --git a/src/logicnets/nn.py b/src/logicnets/nn.py
index 6effa5cc9..6eb8ab2f9 100644
--- a/src/logicnets/nn.py
+++ b/src/logicnets/nn.py
@@ -44,10 +44,34 @@ def generate_truth_tables(model: nn.Module, verbose: bool = False) -> None:
     model.training = training
 
 # TODO: Create a container module which performs this function.
-def lut_inference(model: nn.Module) -> None:
+def lut_inference(model: nn.Module, track_used_luts: bool = False) -> None:
     for name, module in model.named_modules():
         if type(module) == SparseLinearNeq:
-            module.lut_inference()
+            module.lut_inference(track_used_luts=track_used_luts)
+
+# TODO: Create a container module which performs this function.
+def save_luts(model: nn.Module, path: str) -> None:
+    lut_dict = {}
+    for name, module in model.named_modules():
+        if type(module) == SparseLinearNeq:
+            luts = module.neuron_truth_tables
+            indices = list(map(lambda x: x[0], luts))
+            tt_inputs = list(map(lambda x: x[1], luts))
+            tt_input_bin_str = list(map(lambda x: list(map(lambda y: list(map(lambda z: module.input_quant.get_bin_str(z), y)), x)), tt_inputs))
+            tt_float_outputs = list(map(lambda x: x[2], luts))
+            tt_bin_outputs = list(map(lambda x: x[3], luts))
+            tt_outputs_bin_str = list(map(lambda x: list(map(lambda y: module.output_quant.get_bin_str(y), x)), tt_bin_outputs))
+            histogram = module.used_luts_histogram
+            lut_dict[name] = {
+                'indices': indices,
+                'input_state_space': tt_inputs,
+                'input_state_space_bin_str': tt_input_bin_str,
+                'output_state_space_float': tt_float_outputs,
+                'output_state_space_bin': tt_bin_outputs,
+                'output_state_space_bin_str': tt_outputs_bin_str,
+                'histogram': histogram,
+            }
+    torch.save(lut_dict, path)
 
 # TODO: Create a container module which performs this function.
 def neq_inference(model: nn.Module) -> None:
@@ -111,6 +135,8 @@ def __init__(self, in_features: int, out_features: int, input_quant, output_quan
         self.neuron_truth_tables = None
         self.apply_input_quant = apply_input_quant
         self.apply_output_quant = apply_output_quant
+        self.track_used_luts = False
+        self.used_luts_histogram = None
 
     # TODO: Move the verilog string templates to elsewhere
     # TODO: Move this to another class
@@ -158,8 +184,9 @@ def gen_neuron_verilog(self, index, module_name):
             lut_string += f"\t\t\t{int(cat_input_bitwidth)}'b{entry_str}: M1r = {int(output_bitwidth)}'b{res_str};\n"
         return generate_lut_verilog(module_name, int(cat_input_bitwidth), int(output_bitwidth), lut_string)
 
-    def lut_inference(self):
+    def lut_inference(self, track_used_luts=False):
         self.is_lut_inference = True
+        self.track_used_luts = track_used_luts
         self.input_quant.bin_output()
         self.output_quant.bin_output()
 
@@ -169,7 +196,7 @@ def neq_inference(self):
         self.output_quant.float_output()
 
     # TODO: This function might be a useful utility outside of this class..
- def table_lookup(self, connected_input: Tensor, input_perm_matrix: Tensor, bin_output_states: Tensor) -> Tensor: + def table_lookup(self, connected_input: Tensor, input_perm_matrix: Tensor, bin_output_states: Tensor, neuron_lut_histogram=None) -> Tensor: fan_in_size = connected_input.shape[1] ci_bcast = connected_input.unsqueeze(2) # Reshape to B x Fan-in x 1 pm_bcast = input_perm_matrix.t().unsqueeze(0) # Reshape to 1 x Fan-in x InputStates @@ -178,17 +205,29 @@ def table_lookup(self, connected_input: Tensor, input_perm_matrix: Tensor, bin_o if not (matches == torch.ones_like(matches,dtype=matches.dtype)).all(): raise Exception(f"One or more vectors in the input is not in the possible input state space") indices = torch.argmax(eq.type(torch.int64),dim=1) + if self.track_used_luts: + # TODO: vectorize this loop + for i in indices: + neuron_lut_histogram[i] += 1 return bin_output_states[indices] def lut_forward(self, x: Tensor) -> Tensor: if self.apply_input_quant: x = self.input_quant(x) # Use this to fetch the bin output of the input, if the input isn't already in binary format + # TODO: Put this in a child class(?) + # TODO: Add support for non-uniform fan-in + if self.track_used_luts: + if self.used_luts_histogram is None: + self.used_luts_histogram = self.out_features * [None] + for i in range(self.out_features): + self.used_luts_histogram[i] = torch.zeros(size=(len(self.neuron_truth_tables[i][2]),), dtype=torch.int64) y = torch.zeros((x.shape[0],self.out_features)) # Perform table lookup for each neuron output for i in range(self.out_features): indices, input_perm_matrix, float_output_states, bin_output_states = self.neuron_truth_tables[i] + neuron_lut_histogram = self.used_luts_histogram[i] if self.track_used_luts else None connected_input = x[:,indices] - y[:,i] = self.table_lookup(connected_input, input_perm_matrix, bin_output_states) + y[:,i] = self.table_lookup(connected_input, input_perm_matrix, bin_output_states, neuron_lut_histogram=neuron_lut_histogram) return y def forward(self, x: Tensor) -> Tensor: From aa622a4bde728f5c0407f031cdff5a3334e1dfc8 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 10 May 2021 18:05:33 +0100 Subject: [PATCH 02/15] [jsc] Added script to save LUTs and activation statistics. --- examples/jet_substructure/dump_luts.py | 122 +++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 examples/jet_substructure/dump_luts.py diff --git a/examples/jet_substructure/dump_luts.py b/examples/jet_substructure/dump_luts.py new file mode 100644 index 000000000..c244ee1ba --- /dev/null +++ b/examples/jet_substructure/dump_luts.py @@ -0,0 +1,122 @@ +# Copyright (C) 2021 Xilinx, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+from argparse import ArgumentParser
+
+import torch
+from torch.utils.data import DataLoader
+
+from logicnets.nn import generate_truth_tables, \
+                         lut_inference, \
+                         save_luts, \
+                         module_list_to_verilog_module
+
+from train import configs, model_config, dataset_config, other_options, test
+from dataset import JetSubstructureDataset
+from models import JetSubstructureNeqModel, JetSubstructureLutModel
+from logicnets.synthesis import synthesize_and_get_resource_counts
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Synthesize convert a PyTorch trained model into verilog")
+    parser.add_argument('--arch', type=str, choices=configs.keys(), default="jsc-s",
+        help="Specify the neural network model to use (default: %(default)s)")
+    parser.add_argument('--batch-size', type=int, default=None, metavar='N',
+        help="Batch size for evaluation (default: %(default)s)")
+    parser.add_argument('--input-bitwidth', type=int, default=None,
+        help="Bitwidth to use at the input (default: %(default)s)")
+    parser.add_argument('--hidden-bitwidth', type=int, default=None,
+        help="Bitwidth to use for activations in hidden layers (default: %(default)s)")
+    parser.add_argument('--output-bitwidth', type=int, default=None,
+        help="Bitwidth to use at the output (default: %(default)s)")
+    parser.add_argument('--input-fanin', type=int, default=None,
+        help="Fanin to use at the input (default: %(default)s)")
+    parser.add_argument('--hidden-fanin', type=int, default=None,
+        help="Fanin to use for the hidden layers (default: %(default)s)")
+    parser.add_argument('--output-fanin', type=int, default=None,
+        help="Fanin to use at the output (default: %(default)s)")
+    parser.add_argument('--hidden-layers', nargs='+', type=int, default=None,
+        help="A list of hidden layer neuron sizes (default: %(default)s)")
+    parser.add_argument('--dataset-file', type=str, default='data/processed-pythia82-lhc13-all-pt1-50k-r1_h022_e0175_t220_nonu_truth.z',
+        help="The file to use as the dataset input (default: %(default)s)")
+    parser.add_argument('--dataset-config', type=str, default='config/yaml_IP_OP_config.yml',
+        help="The file to use to configure the input dataset (default: %(default)s)")
+    parser.add_argument('--log-dir', type=str, default='./log',
+        help="A location to store the log output of the training run and the output model (default: %(default)s)")
+    parser.add_argument('--checkpoint', type=str, required=True,
+        help="The checkpoint file which contains the model weights")
+    args = parser.parse_args()
+    defaults = configs[args.arch]
+    options = vars(args)
+    del options['arch']
+    config = {}
+    for k in options.keys():
+        config[k] = options[k] if options[k] is not None else defaults[k] # Override defaults, if specified.
+
+    if not os.path.exists(config['log_dir']):
+        os.makedirs(config['log_dir'])
+
+    # Split up configuration options to be more understandable
+    model_cfg = {}
+    for k in model_config.keys():
+        model_cfg[k] = config[k]
+    dataset_cfg = {}
+    for k in dataset_config.keys():
+        dataset_cfg[k] = config[k]
+    options_cfg = {}
+    for k in other_options.keys():
+        if k == 'cuda':
+            continue
+        options_cfg[k] = config[k]
+
+    # Fetch the test set
+    dataset = {}
+    dataset['train'] = JetSubstructureDataset(dataset_cfg['dataset_file'], dataset_cfg['dataset_config'], split="train")
+    train_loader = DataLoader(dataset["train"], batch_size=config['batch_size'], shuffle=False)
+
+    # Instantiate the PyTorch model
+    x, y = dataset['train'][0]
+    dataset_length = len(dataset['train'])
+    model_cfg['input_length'] = len(x)
+    model_cfg['output_length'] = len(y)
+    model = JetSubstructureNeqModel(model_cfg)
+
+    # Load the model weights
+    checkpoint = torch.load(options_cfg['checkpoint'], map_location='cpu')
+    model.load_state_dict(checkpoint['model_dict'])
+
+    # Test the PyTorch model
+    print("Running inference of baseline model on training set (%d examples)..." % (dataset_length))
+    model.eval()
+    baseline_accuracy = test(model, train_loader, cuda=False)
+    print("Baseline accuracy: %f" % (baseline_accuracy))
+
+    # Instantiate LUT-based model
+    lut_model = JetSubstructureLutModel(model_cfg)
+    lut_model.load_state_dict(checkpoint['model_dict'])
+
+    # Generate the truth tables in the LUT module
+    print("Converting NEQs to LUTs...")
+    generate_truth_tables(lut_model, verbose=True)
+
+    # Test the LUT-based model
+    print("Running inference of LUT-based model on training set (%d examples)..." % (dataset_length))
+    lut_inference(lut_model, track_used_luts=True)
+    lut_model.eval()
+    lut_accuracy = test(lut_model, train_loader, cuda=False)
+    print("LUT-Based Model accuracy: %f" % (lut_accuracy))
+    print("Saving LUTs to %s... " % (options_cfg["log_dir"] + "/luts.pth"))
+    save_luts(lut_model, options_cfg["log_dir"] + "/luts.pth")
+    print("Done!")
+

From 4d75911d263fda879c98587779f1cb3925e07965 Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Tue, 18 May 2021 10:43:32 +0100
Subject: [PATCH 03/15] [nn] Added an extra parameter to specify a cutoff for whether a TT entry should be included in the output verilog.

---
 src/logicnets/nn.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/logicnets/nn.py b/src/logicnets/nn.py
index 6eb8ab2f9..27529e11a 100644
--- a/src/logicnets/nn.py
+++ b/src/logicnets/nn.py
@@ -73,6 +73,12 @@ def save_luts(model: nn.Module, path: str) -> None:
             }
     torch.save(lut_dict, path)
 
+# TODO: Create a container module which performs this function.
+def load_histograms(model: nn.Module, lut_dict: dict) -> None:
+    for name, module in model.named_modules():
+        if name in lut_dict.keys():
+            module.used_luts_histogram = lut_dict[name]['histogram']
+
 # TODO: Create a container module which performs this function.
 def neq_inference(model: nn.Module) -> None:
     for name, module in model.named_modules():
@@ -81,7 +87,7 @@
 
 # TODO: Should this go in with the other verilog functions?
# TODO: Support non-linear topologies -def module_list_to_verilog_module(module_list: nn.ModuleList, module_name: str, output_directory: str): +def module_list_to_verilog_module(module_list: nn.ModuleList, module_name: str, output_directory: str, freq_thresh=None): input_bitwidth = None output_bitwidth = None module_contents = "" @@ -89,7 +95,7 @@ def module_list_to_verilog_module(module_list: nn.ModuleList, module_name: str, m = module_list[i] if type(m) == SparseLinearNeq: module_prefix = f"layer{i}" - module_input_bits, module_output_bits = m.gen_layer_verilog(module_prefix, output_directory) + module_input_bits, module_output_bits = m.gen_layer_verilog(module_prefix, output_directory, freq_thresh=freq_thresh) if i == 0: input_bitwidth = module_input_bits elif i == len(module_list)-1: @@ -141,7 +147,7 @@ def __init__(self, in_features: int, out_features: int, input_quant, output_quan # TODO: Move the verilog string templates to elsewhere # TODO: Move this to another class # TODO: Update this code to support custom bitwidths per input/output - def gen_layer_verilog(self, module_prefix, directory): + def gen_layer_verilog(self, module_prefix, directory, freq_thresh=None): _, input_bitwidth = self.input_quant.get_scale_factor_bits() _, output_bitwidth = self.output_quant.get_scale_factor_bits() input_bitwidth, output_bitwidth = int(input_bitwidth), int(output_bitwidth) @@ -152,7 +158,7 @@ def gen_layer_verilog(self, module_prefix, directory): for index in range(self.out_features): module_name = f"{module_prefix}_N{index}" indices, _, _, _ = self.neuron_truth_tables[index] - neuron_verilog = self.gen_neuron_verilog(index, module_name) # Generate the contents of the neuron verilog + neuron_verilog = self.gen_neuron_verilog(index, module_name, freq_thresh=freq_thresh) # Generate the contents of the neuron verilog with open(f"{directory}/{module_name}.v", "w") as f: f.write(neuron_verilog) connection_string = generate_neuron_connection_verilog(indices, input_bitwidth) # Generate the string which connects the synapses to this neuron @@ -168,7 +174,7 @@ def gen_layer_verilog(self, module_prefix, directory): # TODO: Move the verilog string templates to elsewhere # TODO: Move this to another class - def gen_neuron_verilog(self, index, module_name): + def gen_neuron_verilog(self, index, module_name, freq_thresh=None): indices, input_perm_matrix, float_output_states, bin_output_states = self.neuron_truth_tables[index] _, input_bitwidth = self.input_quant.get_scale_factor_bits() _, output_bitwidth = self.output_quant.get_scale_factor_bits() @@ -181,7 +187,8 @@ def gen_neuron_verilog(self, index, module_name): val = input_perm_matrix[i,idx] entry_str += self.input_quant.get_bin_str(val) res_str = self.output_quant.get_bin_str(bin_output_states[i]) - lut_string += f"\t\t\t{int(cat_input_bitwidth)}'b{entry_str}: M1r = {int(output_bitwidth)}'b{res_str};\n" + if (freq_thresh is None) or (self.used_luts_histogram[index][i] >= freq_thresh): + lut_string += f"\t\t\t{int(cat_input_bitwidth)}'b{entry_str}: M1r = {int(output_bitwidth)}'b{res_str};\n" return generate_lut_verilog(module_name, int(cat_input_bitwidth), int(output_bitwidth), lut_string) def lut_inference(self, track_used_luts=False): From 538281538cd43d0e040c67f10b0aeb7ef68d51ec Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 18 May 2021 10:52:07 +0100 Subject: [PATCH 04/15] [jsc] Made verification of verilog simulation optional with a flag. 
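
Once truth table entries can be pruned from the generated Verilog via the
frequency threshold introduced in the previous patch, the simulated hardware
may legitimately disagree with the bit-exact PyTorch reference on inputs whose
entries were removed, so the per-sample assertion in verilog_forward() has to
be skippable. A minimal sketch of the intended use, mirroring the call sites
in the example scripts (names such as lut_model and options_cfg follow those
scripts and are illustrative here):

    # Skip the bit-exact check when simulating a pruned model:
    lut_model.verilog_inference(options_cfg["log_dir"], "logicnet.v", verify=False)
    verilog_accuracy = test(lut_model, test_loader, cuda=False)
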
--- examples/jet_substructure/models.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/jet_substructure/models.py b/examples/jet_substructure/models.py index 7c0fba9f9..be6891bff 100644 --- a/examples/jet_substructure/models.py +++ b/examples/jet_substructure/models.py @@ -64,12 +64,14 @@ def __init__(self, model_config): self.verilog_dir = None self.top_module_filename = None self.dut = None + self.verify = True - def verilog_inference(self, verilog_dir, top_module_filename): + def verilog_inference(self, verilog_dir, top_module_filename, verify=True): self.verilog_dir = realpath(verilog_dir) self.top_module_filename = top_module_filename self.dut = PyVerilator.build(f"{self.verilog_dir}/{self.top_module_filename}", verilog_path=[self.verilog_dir], build_dir=f"{self.verilog_dir}/verilator") self.is_verilog_inference = True + self.verify = verify def pytorch_inference(self): self.is_verilog_inference = False @@ -92,11 +94,8 @@ def verilog_forward(self, x): self.dut.io.clk = 0 for i in range(x.shape[0]): x_i = x[i,:] - y_i = self.pytorch_forward(x[i:i+1,:])[0] xv_i = list(map(lambda z: input_quant.get_bin_str(z), x_i)) - ys_i = list(map(lambda z: output_quant.get_bin_str(z), y_i)) xvc_i = reduce(lambda a,b: a+b, xv_i[::-1]) - ysc_i = reduce(lambda a,b: a+b, ys_i[::-1]) self.dut["M0"] = int(xvc_i, 2) for j in range(self.latency + 1): #print(self.dut.io.M5) @@ -104,9 +103,13 @@ def verilog_forward(self, x): result = f"{res:0{int(total_output_bits)}b}" self.dut.io.clk = 1 self.dut.io.clk = 0 - expected = f"{int(ysc_i,2):0{int(total_output_bits)}b}" result = f"{res:0{int(total_output_bits)}b}" - assert(expected == result) + if self.verify: + y_i = self.pytorch_forward(x[i:i+1,:])[0] + ys_i = list(map(lambda z: output_quant.get_bin_str(z), y_i)) + ysc_i = reduce(lambda a,b: a+b, ys_i[::-1]) + expected = f"{int(ysc_i,2):0{int(total_output_bits)}b}" + assert(expected == result) res_split = [result[i:i+output_bitwidth] for i in range(0, len(result), output_bitwidth)][::-1] yv_i = torch.Tensor(list(map(lambda z: int(z, 2), res_split))) y[i,:] = yv_i From 45d7955eae74d3ca4af3f7752f7446c305c3a9f0 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 18 May 2021 10:53:03 +0100 Subject: [PATCH 05/15] [jsc] Added loading of calculated histograms and specifying a TT frequency threshold. 
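
The intended flow is two passes over the training set: dump_luts.py runs LUT
inference with track_used_luts=True and writes the per-neuron usage histograms
out via save_luts(); neq2lut.py then reloads that file and passes the
threshold through to Verilog generation so that rarely exercised truth table
entries can be dropped. A sketch of the underlying API calls, with
illustrative paths and threshold value:

    luts = torch.load("./log/luts.pth")  # file produced by save_luts()
    load_histograms(lut_model, luts)     # attach usage histograms to each layer
    # Truth table entries seen fewer than freq_thresh times are omitted:
    module_list_to_verilog_module(lut_model.module_list, "logicnet",
                                  "./log", freq_thresh=10)
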
--- examples/jet_substructure/neq2lut.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/jet_substructure/neq2lut.py b/examples/jet_substructure/neq2lut.py index 15b6d2b80..c106128aa 100644 --- a/examples/jet_substructure/neq2lut.py +++ b/examples/jet_substructure/neq2lut.py @@ -20,7 +20,8 @@ from logicnets.nn import generate_truth_tables, \ lut_inference, \ - module_list_to_verilog_module + module_list_to_verilog_module, \ + load_histograms from train import configs, model_config, dataset_config, other_options, test from dataset import JetSubstructureDataset @@ -55,6 +56,10 @@ help="A location to store the log output of the training run and the output model (default: %(default)s)") parser.add_argument('--checkpoint', type=str, required=True, help="The checkpoint file which contains the model weights") + parser.add_argument('--histograms', type=str, required=True, + help="The checkpoint histograms of LUT usage") + parser.add_argument('--freq-thresh', type=int, default=0, + help="Threshold to use to include this truth table into the model (default: %(default)s)") args = parser.parse_args() defaults = configs[args.arch] options = vars(args) @@ -118,13 +123,15 @@ 'test_accuracy': lut_accuracy} torch.save(modelSave, options_cfg["log_dir"] + "/lut_based_model.pth") + luts = torch.load(args.histograms) + load_histograms(lut_model, luts) print("Generating verilog in %s..." % (options_cfg["log_dir"])) - module_list_to_verilog_module(lut_model.module_list, "logicnet", options_cfg["log_dir"]) + module_list_to_verilog_module(lut_model.module_list, "logicnet", options_cfg["log_dir"], freq_thresh=args.freq_thresh) print("Top level entity stored at: %s/logicnet.v ..." % (options_cfg["log_dir"])) print("Running inference simulation of Verilog-based model...") - lut_model.verilog_inference(options_cfg["log_dir"], "logicnet.v") + lut_model.verilog_inference(options_cfg["log_dir"], "logicnet.v", verify=args.freq_thresh == 0) verilog_accuracy = test(lut_model, test_loader, cuda=False) print("Verilog-Based Model accuracy: %f" % (verilog_accuracy)) From dc302d1f8344c02bcbf312aecdeb88ab052723b4 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 19 May 2021 15:52:34 +0100 Subject: [PATCH 06/15] [nn] Added a default case to verilog LUT generation. 
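
With frequency-threshold pruning in place, some input states no longer have a
case entry, so the generated case statement needs a default arm. Assigning all
output bits 'x' marks the pruned states as don't-cares, which leaves the
synthesizer free to optimise the remaining logic; the matching Verilator build
flag ("--x-assign", "0") pins those x values to zero so that simulation stays
deterministic.
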
--- examples/jet_substructure/models.py | 2 +- src/logicnets/nn.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/jet_substructure/models.py b/examples/jet_substructure/models.py index be6891bff..e8c99a612 100644 --- a/examples/jet_substructure/models.py +++ b/examples/jet_substructure/models.py @@ -69,7 +69,7 @@ def __init__(self, model_config): def verilog_inference(self, verilog_dir, top_module_filename, verify=True): self.verilog_dir = realpath(verilog_dir) self.top_module_filename = top_module_filename - self.dut = PyVerilator.build(f"{self.verilog_dir}/{self.top_module_filename}", verilog_path=[self.verilog_dir], build_dir=f"{self.verilog_dir}/verilator") + self.dut = PyVerilator.build(f"{self.verilog_dir}/{self.top_module_filename}", verilog_path=[self.verilog_dir], build_dir=f"{self.verilog_dir}/verilator", command_args=("--x-assign","0",)) self.is_verilog_inference = True self.verify = verify diff --git a/src/logicnets/nn.py b/src/logicnets/nn.py index 27529e11a..39af6a959 100644 --- a/src/logicnets/nn.py +++ b/src/logicnets/nn.py @@ -189,6 +189,9 @@ def gen_neuron_verilog(self, index, module_name, freq_thresh=None): res_str = self.output_quant.get_bin_str(bin_output_states[i]) if (freq_thresh is None) or (self.used_luts_histogram[index][i] >= freq_thresh): lut_string += f"\t\t\t{int(cat_input_bitwidth)}'b{entry_str}: M1r = {int(output_bitwidth)}'b{res_str};\n" + # Add a default "don't care" statement + default_string = int(output_bitwidth) * 'x' + lut_string += f"\t\t\tdefault: M1r = {int(output_bitwidth)}'b{default_string};\n" return generate_lut_verilog(module_name, int(cat_input_bitwidth), int(output_bitwidth), lut_string) def lut_inference(self, track_used_luts=False): From 536271bf27b9ae490153ce3434c161eed1072edc Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 20 May 2021 16:15:21 +0100 Subject: [PATCH 07/15] [nn/jsc] Made registers optional in verilog generation. Default is no registers. 
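
With registers disabled the generated network is purely combinational: each
layer's input is wired straight through rather than passing through a myreg
stage, and the latency constant in the PyVerilator wrapper drops from one
cycle per layer to a single cycle. A hedged sketch of re-enabling the pipeline
registers, using the add_registers option that the later example scripts pass
through to Verilog generation:

    # add_registers=True places a myreg stage in front of each layer:
    module_list_to_verilog_module(lut_model.module_list, "logicnet",
                                  options_cfg["log_dir"], add_registers=True)
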
--- examples/jet_substructure/models.py | 2 +- src/logicnets/verilog.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/jet_substructure/models.py b/examples/jet_substructure/models.py index e8c99a612..49fec0e1b 100644 --- a/examples/jet_substructure/models.py +++ b/examples/jet_substructure/models.py @@ -60,7 +60,7 @@ def __init__(self, model_config): layer_list.append(layer) self.module_list = nn.ModuleList(layer_list) self.is_verilog_inference = False - self.latency = len(self.num_neurons) + self.latency = 1 self.verilog_dir = None self.top_module_filename = None self.dut = None diff --git a/src/logicnets/verilog.py b/src/logicnets/verilog.py index df33fa701..f073a4692 100644 --- a/src/logicnets/verilog.py +++ b/src/logicnets/verilog.py @@ -45,13 +45,15 @@ def generate_logicnets_verilog(module_name: str, input_name: str, input_bits: in output_bits_1=output_bits-1, module_contents=module_contents) -def layer_connection_verilog(layer_string: str, input_string: str, input_bits: int, output_string: str, output_bits: int, output_wire=True): - layer_connection_template = """\ +def layer_connection_verilog(layer_string: str, input_string: str, input_bits: int, output_string: str, output_bits: int, output_wire=True, register=False): + if register: + layer_connection_template = """\ wire [{input_bits_1:d}:0] {input_string}w; myreg #(.DataWidth({input_bits})) {layer_string}_reg (.data_in({input_string}), .clk(clk), .rst(rst), .data_out({input_string}w));\n""" -# layer_connection_template = """\ -#wire [{input_bits_1:d}:0] {input_string}w; -#assign {input_string}w = {input_string};\n""" + else: + layer_connection_template = """\ +wire [{input_bits_1:d}:0] {input_string}w; +assign {input_string}w = {input_string};\n""" layer_connection_template += "wire [{output_bits_1:d}:0] {output_string};\n" if output_wire else "" layer_connection_template += "{layer_string} {layer_string}_inst (.M0({input_string}w), .M1({output_string}));\n" return layer_connection_template.format( layer_string=layer_string, From f4e7810bc482bbeab012f327f23e2e62c35dae77 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 26 May 2021 15:29:26 +0100 Subject: [PATCH 08/15] [verilog] Added 'parallel case' statement to generated verilog. --- src/logicnets/verilog.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/logicnets/verilog.py b/src/logicnets/verilog.py index f073a4692..742cb64f8 100644 --- a/src/logicnets/verilog.py +++ b/src/logicnets/verilog.py @@ -69,6 +69,7 @@ def generate_lut_verilog(module_name, input_fanin_bits, output_bits, lut_string) (*rom_style = "distributed" *) reg [{output_bits_1:d}:0] M1r; assign M1 = M1r; + (* parallel_case *) always @ (M0) begin case (M0) {lut_string} From 40d72e54102bd9ddb0a852b0662d11f0c9ec6445 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 31 May 2021 10:31:39 +0100 Subject: [PATCH 09/15] Revert "[verilog] Added 'parallel case' statement to generated verilog." This reverts commit f4e7810bc482bbeab012f327f23e2e62c35dae77. 
--- src/logicnets/verilog.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/logicnets/verilog.py b/src/logicnets/verilog.py index 742cb64f8..f073a4692 100644 --- a/src/logicnets/verilog.py +++ b/src/logicnets/verilog.py @@ -69,7 +69,6 @@ def generate_lut_verilog(module_name, input_fanin_bits, output_bits, lut_string) (*rom_style = "distributed" *) reg [{output_bits_1:d}:0] M1r; assign M1 = M1r; - (* parallel_case *) always @ (M0) begin case (M0) {lut_string} From 7970fce539a66ce90dd1412126cdaf43ad0c09b8 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 18 Jun 2021 15:19:58 +0100 Subject: [PATCH 10/15] [jsc] Bugfixes in setting histograms / frequency values --- examples/jet_substructure/neq2lut.py | 7 ++++--- examples/jet_substructure/train.py | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/jet_substructure/neq2lut.py b/examples/jet_substructure/neq2lut.py index e23a8d823..ac6c5ccc3 100644 --- a/examples/jet_substructure/neq2lut.py +++ b/examples/jet_substructure/neq2lut.py @@ -62,7 +62,7 @@ help="The checkpoint file which contains the model weights") parser.add_argument('--histograms', type=str, default=None, help="The checkpoint histograms of LUT usage (default: %(default)s)") - parser.add_argument('--freq-thresh', type=int, default=0, + parser.add_argument('--freq-thresh', type=int, default=None, help="Threshold to use to include this truth table into the model (default: %(default)s)") parser.add_argument('--generate-bench', action='store_true', default=False, help="Generate the truth table in BENCH format as well as verilog (default: %(default)s)") @@ -131,8 +131,9 @@ 'test_accuracy': lut_accuracy} torch.save(modelSave, options_cfg["log_dir"] + "/lut_based_model.pth") - luts = torch.load(args.histograms) - load_histograms(lut_model, luts) + if options_cfg["histograms"] is not None: + luts = torch.load(options_cfg["histograms"]) + load_histograms(lut_model, luts) print("Generating verilog in %s..." % (options_cfg["log_dir"])) module_list_to_verilog_module(lut_model.module_list, "logicnet", options_cfg["log_dir"], freq_thresh=options_cfg["freq_thresh"], generate_bench=options_cfg["generate_bench"]) diff --git a/examples/jet_substructure/train.py b/examples/jet_substructure/train.py index 840fd4f18..be90902b6 100644 --- a/examples/jet_substructure/train.py +++ b/examples/jet_substructure/train.py @@ -44,6 +44,8 @@ "learning_rate": 1e-3, "seed": 2, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "jsc-m": { "hidden_layers": [64, 32, 32, 32], @@ -59,6 +61,8 @@ "learning_rate": 1e-3, "seed": 3, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "jsc-l": { "hidden_layers": [32, 64, 192, 192, 16], @@ -74,6 +78,8 @@ "learning_rate": 1e-3, "seed": 16, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, } @@ -107,6 +113,7 @@ "checkpoint": None, "generate_bench": None, "freq_thresh": None, + "histograms": None, } def train(model, datasets, train_cfg, options): From fbf8238eb2272afa34fb90288b1805bc852dd2e1 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 28 Jun 2021 16:12:23 +0100 Subject: [PATCH 11/15] [jsc] Updated default PCA to be 12 dimensions. 
---
 examples/jet_substructure/config/yaml_IP_OP_config.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/jet_substructure/config/yaml_IP_OP_config.yml b/examples/jet_substructure/config/yaml_IP_OP_config.yml
index e238039bf..95befe1fe 100644
--- a/examples/jet_substructure/config/yaml_IP_OP_config.yml
+++ b/examples/jet_substructure/config/yaml_IP_OP_config.yml
@@ -45,5 +45,5 @@ L1Reg: 0.0001
 NormalizeInputs: 1
 InputType: Dense
 ApplyPca: false
-PcaDimensions: 10
+PcaDimensions: 12

From 01f8d43b21d958a71cfc2884f221ecb133a7994b Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Tue, 31 Aug 2021 16:48:37 +0100
Subject: [PATCH 12/15] [jsc] Fixed description of commandline arguments

---
 examples/jet_substructure/dump_luts.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/jet_substructure/dump_luts.py b/examples/jet_substructure/dump_luts.py
index c244ee1ba..05952268e 100644
--- a/examples/jet_substructure/dump_luts.py
+++ b/examples/jet_substructure/dump_luts.py
@@ -29,7 +29,7 @@
 from logicnets.synthesis import synthesize_and_get_resource_counts
 
 if __name__ == "__main__":
-    parser = ArgumentParser(description="Synthesize convert a PyTorch trained model into verilog")
+    parser = ArgumentParser(description="Generate histograms of states used throughout LogicNets")
     parser.add_argument('--arch', type=str, choices=configs.keys(), default="jsc-s",
         help="Specify the neural network model to use (default: %(default)s)")
@@ -53,7 +53,7 @@
     parser.add_argument('--dataset-config', type=str, default='config/yaml_IP_OP_config.yml',
         help="The file to use to configure the input dataset (default: %(default)s)")
     parser.add_argument('--log-dir', type=str, default='./log',
-        help="A location to store the log output of the training run and the output model (default: %(default)s)")
+        help="A location to store the calculated histograms (default: %(default)s)")
     parser.add_argument('--checkpoint', type=str, required=True,
         help="The checkpoint file which contains the model weights")
     args = parser.parse_args()

From 691944eda9684231030de8392dd81878d0277ca3 Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Thu, 18 Aug 2022 17:28:36 +0100
Subject: [PATCH 13/15] [nids] Initial version supporting histograms.

---
 examples/cybersecurity/dump_luts.py | 119 ++++++++++++++++++++++++++++
 examples/cybersecurity/models.py    |  17 ++--
 examples/cybersecurity/neq2lut.py   |  19 ++++-
 examples/cybersecurity/train.py     |  12 +++
 4 files changed, 156 insertions(+), 11 deletions(-)
 create mode 100644 examples/cybersecurity/dump_luts.py

diff --git a/examples/cybersecurity/dump_luts.py b/examples/cybersecurity/dump_luts.py
new file mode 100644
index 000000000..01a57663f
--- /dev/null
+++ b/examples/cybersecurity/dump_luts.py
@@ -0,0 +1,119 @@
+# Copyright (C) 2021 Xilinx, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from argparse import ArgumentParser
+
+import torch
+from torch.utils.data import DataLoader
+
+from logicnets.nn import generate_truth_tables, \
+                         lut_inference, \
+                         save_luts, \
+                         module_list_to_verilog_module
+
+from train import configs, model_config, dataset_config, other_options, test
+from dataset import get_preqnt_dataset
+from models import UnswNb15NeqModel, UnswNb15LutModel
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Generate histograms of states used throughout LogicNets")
+    parser.add_argument('--arch', type=str, choices=configs.keys(), default="jsc-s",
+        help="Specify the neural network model to use (default: %(default)s)")
+    parser.add_argument('--batch-size', type=int, default=None, metavar='N',
+        help="Batch size for evaluation (default: %(default)s)")
+    parser.add_argument('--input-bitwidth', type=int, default=None,
+        help="Bitwidth to use at the input (default: %(default)s)")
+    parser.add_argument('--hidden-bitwidth', type=int, default=None,
+        help="Bitwidth to use for activations in hidden layers (default: %(default)s)")
+    parser.add_argument('--output-bitwidth', type=int, default=None,
+        help="Bitwidth to use at the output (default: %(default)s)")
+    parser.add_argument('--input-fanin', type=int, default=None,
+        help="Fanin to use at the input (default: %(default)s)")
+    parser.add_argument('--hidden-fanin', type=int, default=None,
+        help="Fanin to use for the hidden layers (default: %(default)s)")
+    parser.add_argument('--output-fanin', type=int, default=None,
+        help="Fanin to use at the output (default: %(default)s)")
+    parser.add_argument('--hidden-layers', nargs='+', type=int, default=None,
+        help="A list of hidden layer neuron sizes (default: %(default)s)")
+    parser.add_argument('--dataset-file', type=str, default='data/unsw_nb15_binarized.npz',
+        help="The file to use as the dataset input (default: %(default)s)")
+    parser.add_argument('--log-dir', type=str, default='./log',
+        help="A location to store the calculated histograms (default: %(default)s)")
+    parser.add_argument('--checkpoint', type=str, required=True,
+        help="The checkpoint file which contains the model weights")
+    args = parser.parse_args()
+    defaults = configs[args.arch]
+    options = vars(args)
+    del options['arch']
+    config = {}
+    for k in options.keys():
+        config[k] = options[k] if options[k] is not None else defaults[k] # Override defaults, if specified.
+
+    if not os.path.exists(config['log_dir']):
+        os.makedirs(config['log_dir'])
+
+    # Split up configuration options to be more understandable
+    model_cfg = {}
+    for k in model_config.keys():
+        model_cfg[k] = config[k]
+    dataset_cfg = {}
+    for k in dataset_config.keys():
+        dataset_cfg[k] = config[k]
+    options_cfg = {}
+    for k in other_options.keys():
+        if k == 'cuda':
+            continue
+        options_cfg[k] = config[k]
+
+    # Fetch the test set
+    dataset = {}
+    dataset['train'] = get_preqnt_dataset(dataset_cfg['dataset_file'], split='train')
+    train_loader = DataLoader(dataset["train"], batch_size=config['batch_size'], shuffle=False)
+
+    # Instantiate the PyTorch model
+    x, y = dataset['train'][0]
+    dataset_length = len(dataset['train'])
+    model_cfg['input_length'] = len(x)
+    model_cfg['output_length'] = 1
+    model = UnswNb15NeqModel(model_cfg)
+
+    # Load the model weights
+    checkpoint = torch.load(options_cfg['checkpoint'], map_location='cpu')
+    model.load_state_dict(checkpoint['model_dict'])
+
+    # Test the PyTorch model
+    print("Running inference of baseline model on training set (%d examples)..."
% (dataset_length))
+    model.eval()
+    baseline_accuracy = test(model, train_loader, cuda=False)
+    print("Baseline accuracy: %f" % (baseline_accuracy))
+
+    # Instantiate LUT-based model
+    lut_model = UnswNb15LutModel(model_cfg)
+    lut_model.load_state_dict(checkpoint['model_dict'])
+
+    # Generate the truth tables in the LUT module
+    print("Converting NEQs to LUTs...")
+    generate_truth_tables(lut_model, verbose=True)
+
+    # Test the LUT-based model
+    print("Running inference of LUT-based model on training set (%d examples)..." % (dataset_length))
+    lut_inference(lut_model, track_used_luts=True)
+    lut_model.eval()
+    lut_accuracy = test(lut_model, train_loader, cuda=False)
+    print("LUT-Based Model accuracy: %f" % (lut_accuracy))
+    print("Saving LUTs to %s... " % (options_cfg["log_dir"] + "/luts.pth"))
+    save_luts(lut_model, options_cfg["log_dir"] + "/luts.pth")
+    print("Done!")
+
diff --git a/examples/cybersecurity/models.py b/examples/cybersecurity/models.py
index b98ab5dc9..bfbaf2ca5 100644
--- a/examples/cybersecurity/models.py
+++ b/examples/cybersecurity/models.py
@@ -63,13 +63,15 @@ def __init__(self, model_config):
         self.verilog_dir = None
         self.top_module_filename = None
         self.dut = None
+        self.verify = True
         self.logfile = None
 
-    def verilog_inference(self, verilog_dir, top_module_filename, logfile: bool = False, add_registers: bool = False):
+    def verilog_inference(self, verilog_dir, top_module_filename, logfile: bool = False, add_registers: bool = False, verify: bool = True):
         self.verilog_dir = realpath(verilog_dir)
         self.top_module_filename = top_module_filename
-        self.dut = PyVerilator.build(f"{self.verilog_dir}/{self.top_module_filename}", verilog_path=[self.verilog_dir], build_dir=f"{self.verilog_dir}/verilator")
+        self.dut = PyVerilator.build(f"{self.verilog_dir}/{self.top_module_filename}", verilog_path=[self.verilog_dir], build_dir=f"{self.verilog_dir}/verilator", command_args=("--x-assign","0",))
         self.is_verilog_inference = True
+        self.verify = verify
         self.logfile = logfile
         if add_registers:
             self.latency = len(self.num_neurons)
@@ -95,11 +97,8 @@ def verilog_forward(self, x):
         self.dut.io.clk = 0
         for i in range(x.shape[0]):
             x_i = x[i,:]
-            y_i = self.pytorch_forward(x[i:i+1,:])[0]
             xv_i = list(map(lambda z: input_quant.get_bin_str(z), x_i))
-            ys_i = list(map(lambda z: output_quant.get_bin_str(z), y_i))
             xvc_i = reduce(lambda a,b: a+b, xv_i[::-1])
-            ysc_i = reduce(lambda a,b: a+b, ys_i[::-1])
             self.dut["M0"] = int(xvc_i, 2)
             for j in range(self.latency + 1):
                 #print(self.dut.io.M5)
@@ -107,9 +106,13 @@ def verilog_forward(self, x):
                 result = f"{res:0{int(total_output_bits)}b}"
                 self.dut.io.clk = 1
                 self.dut.io.clk = 0
-            expected = f"{int(ysc_i,2):0{int(total_output_bits)}b}"
             result = f"{res:0{int(total_output_bits)}b}"
-            assert(expected == result)
+            if self.verify:
+                y_i = self.pytorch_forward(x[i:i+1,:])[0]
+                ys_i = list(map(lambda z: output_quant.get_bin_str(z), y_i))
+                ysc_i = reduce(lambda a,b: a+b, ys_i[::-1])
+                expected = f"{int(ysc_i,2):0{int(total_output_bits)}b}"
+                assert(expected == result)
             res_split = [result[i:i+output_bitwidth] for i in range(0, len(result), output_bitwidth)][::-1]
             yv_i = torch.Tensor(list(map(lambda z: int(z, 2), res_split)))
             y[i,:] = yv_i
diff --git a/examples/cybersecurity/neq2lut.py b/examples/cybersecurity/neq2lut.py
index 4302ec304..bcc7ef049 100644
--- a/examples/cybersecurity/neq2lut.py
+++ b/examples/cybersecurity/neq2lut.py
@@ -20,7 +20,8 @@
 from logicnets.nn import generate_truth_tables, \
                          lut_inference, \
-                         module_list_to_verilog_module
+    
module_list_to_verilog_module, \ + load_histograms from logicnets.synthesis import synthesize_and_get_resource_counts from logicnets.util import proc_postsynth_file @@ -34,6 +35,8 @@ "checkpoint": None, "generate_bench": False, "add_registers": False, + "histograms": None, + "freq_thresh": None, "simulate_pre_synthesis_verilog": False, "simulate_post_synthesis_verilog": False, } @@ -68,6 +71,10 @@ help="A location to store the log output of the training run and the output model (default: %(default)s)") parser.add_argument('--checkpoint', type=str, required=True, help="The checkpoint file which contains the model weights") + parser.add_argument('--histograms', type=str, default=None, + help="The checkpoint histograms of LUT usage (default: %(default)s)") + parser.add_argument('--freq-thresh', type=int, default=None, + help="Threshold to use to include this truth table into the model (default: %(default)s)") parser.add_argument('--generate-bench', action='store_true', default=False, help="Generate the truth table in BENCH format as well as verilog (default: %(default)s)") parser.add_argument('--dump-io', action='store_true', default=False, @@ -141,9 +148,12 @@ 'test_accuracy': lut_accuracy} torch.save(modelSave, options_cfg["log_dir"] + "/lut_based_model.pth") + if options_cfg["histograms"] is not None: + luts = torch.load(options_cfg["histograms"]) + load_histograms(lut_model, luts) print("Generating verilog in %s..." % (options_cfg["log_dir"])) - module_list_to_verilog_module(lut_model.module_list, "logicnet", options_cfg["log_dir"], generate_bench=options_cfg["generate_bench"], add_registers=options_cfg["add_registers"]) + module_list_to_verilog_module(lut_model.module_list, "logicnet", options_cfg["log_dir"], generate_bench=options_cfg["generate_bench"], add_registers=options_cfg["add_registers"], freq_thresh=options_cfg["freq_thresh"]) print("Top level entity stored at: %s/logicnet.v ..." 
% (options_cfg["log_dir"])) if args.dump_io: @@ -154,9 +164,10 @@ else: io_filename = None + if args.simulate_pre_synthesis_verilog: print("Running inference simulation of Verilog-based model...") - lut_model.verilog_inference(options_cfg["log_dir"], "logicnet.v", logfile=io_filename, add_registers=options_cfg["add_registers"]) + lut_model.verilog_inference(options_cfg["log_dir"], "logicnet.v", logfile=io_filename, add_registers=options_cfg["add_registers"], verify=options_cfg["freq_thresh"] is None or options_cfg["freq_thresh"] == 0) verilog_accuracy = test(lut_model, test_loader, cuda=False) print("Verilog-Based Model accuracy: %f" % (verilog_accuracy)) @@ -166,7 +177,7 @@ if args.simulate_post_synthesis_verilog: print("Running post-synthesis inference simulation of Verilog-based model...") proc_postsynth_file(options_cfg["log_dir"]) - lut_model.verilog_inference(options_cfg["log_dir"]+"/post_synth", "logicnet_post_synth.v", io_filename, add_registers=options_cfg["add_registers"]) + lut_model.verilog_inference(options_cfg["log_dir"]+"/post_synth", "logicnet_post_synth.v", io_filename, add_registers=options_cfg["add_registers"], verify=options_cfg["freq_thresh"] is None or options_cfg["freq_thresh"] == 0) post_synth_accuracy = test(lut_model, test_loader, cuda=False) print("Post-synthesis Verilog-Based Model accuracy: %f" % (post_synth_accuracy)) diff --git a/examples/cybersecurity/train.py b/examples/cybersecurity/train.py index 3576ae15a..30dcba634 100644 --- a/examples/cybersecurity/train.py +++ b/examples/cybersecurity/train.py @@ -44,6 +44,8 @@ "learning_rate": 1e-1, "seed": 25, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "nid-s-comp": { "hidden_layers": [49, 7], @@ -59,6 +61,8 @@ "learning_rate": 1e-1, "seed": 81, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "nid-m": { "hidden_layers": [593, 256, 128, 128], @@ -74,6 +78,8 @@ "learning_rate": 1e-1, "seed": 20, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "nid-m-comp": { "hidden_layers": [593, 256, 49, 7], @@ -89,6 +95,8 @@ "learning_rate": 1e-1, "seed": 40, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "nid-l": { "hidden_layers": [593, 100, 100, 100], @@ -104,6 +112,8 @@ "learning_rate": 1e-1, "seed": 2, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, "nid-l-comp": { "hidden_layers": [593, 100, 25, 5], @@ -119,6 +129,8 @@ "learning_rate": 1e-1, "seed": 83, "checkpoint": None, + "histograms": None, + "freq_thresh": None, }, } From 7e83a9c2ebb23fa41306a4243a0cd3b367771f26 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Sun, 5 Mar 2023 12:26:04 +0000 Subject: [PATCH 14/15] [ex/jsc] Bugfix / added AVG ROC-AUC to results --- examples/jet_substructure/dump_luts.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/jet_substructure/dump_luts.py b/examples/jet_substructure/dump_luts.py index 05952268e..f99d21651 100644 --- a/examples/jet_substructure/dump_luts.py +++ b/examples/jet_substructure/dump_luts.py @@ -99,8 +99,9 @@ # Test the PyTorch model print("Running inference of baseline model on training set (%d examples)..." 
% (dataset_length))
     model.eval()
-    baseline_accuracy = test(model, train_loader, cuda=False)
+    baseline_accuracy, baseline_avg_roc_auc = test(model, test_loader, cuda=False)
     print("Baseline accuracy: %f" % (baseline_accuracy))
+    print("Baseline AVG ROC AUC: %f" % (baseline_avg_roc_auc))
 
     # Instantiate LUT-based model
     lut_model = JetSubstructureLutModel(model_cfg)
@@ -114,8 +115,9 @@
     print("Running inference of LUT-based model on training set (%d examples)..." % (dataset_length))
     lut_inference(lut_model, track_used_luts=True)
     lut_model.eval()
-    lut_accuracy = test(lut_model, train_loader, cuda=False)
+    lut_accuracy, lut_avg_roc_auc = test(lut_model, test_loader, cuda=False)
     print("LUT-Based Model accuracy: %f" % (lut_accuracy))
+    print("LUT-Based AVG ROC AUC: %f" % (lut_avg_roc_auc))
     print("Saving LUTs to %s... " % (options_cfg["log_dir"] + "/luts.pth"))

From 3713b6dbdd8100dbd9fa888075e00be2bb1d144d Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Mon, 18 Nov 2024 14:21:05 +0000
Subject: [PATCH 15/15] [example/jsc] Bugfix

---
 examples/jet_substructure/dump_luts.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/jet_substructure/dump_luts.py b/examples/jet_substructure/dump_luts.py
index f99d21651..c5acfb8b2 100644
--- a/examples/jet_substructure/dump_luts.py
+++ b/examples/jet_substructure/dump_luts.py
@@ -99,7 +99,7 @@
     # Test the PyTorch model
    print("Running inference of baseline model on training set (%d examples)..." % (dataset_length))
     model.eval()
-    baseline_accuracy, baseline_avg_roc_auc = test(model, test_loader, cuda=False)
+    baseline_accuracy, baseline_avg_roc_auc = test(model, train_loader, cuda=False)
     print("Baseline accuracy: %f" % (baseline_accuracy))
     print("Baseline AVG ROC AUC: %f" % (baseline_avg_roc_auc))
@@ -115,7 +115,7 @@
     print("Running inference of LUT-based model on training set (%d examples)..." % (dataset_length))
     lut_inference(lut_model, track_used_luts=True)
     lut_model.eval()
-    lut_accuracy, lut_avg_roc_auc = test(lut_model, test_loader, cuda=False)
+    lut_accuracy, lut_avg_roc_auc = test(lut_model, train_loader, cuda=False)
     print("LUT-Based Model accuracy: %f" % (lut_accuracy))
     print("Saving LUTs to %s... " % (options_cfg["log_dir"] + "/luts.pth"))