diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index bda8ed94b..967874ad8 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -41,6 +41,7 @@ jobs: - name: build python wheel - pypi run: | # test that the python wheel builds + python -m pip install -U setuptools wheel versioningit toml build twine python -m build --wheel --no-isolation twine check dist/neutron_tavi-*.whl # - name: build conda package diff --git a/codecov.yaml b/codecov.yaml index 6ff0b3325..41082d82b 100644 --- a/codecov.yaml +++ b/codecov.yaml @@ -9,6 +9,7 @@ coverage: threshold: 0.5% ignore: + - "src/tavi/tavi_model/FileSystem/combine_data" - "src/tavi/tavi_model/FileSystem/filter" - "src/tavi/tavi_model/FileSystem/Facilities" - "src/tavi/tavi_view" diff --git a/src/tavi/tavi_model/combine_data.py b/src/tavi/tavi_model/combine_data.py new file mode 100644 index 000000000..4dd8003c1 --- /dev/null +++ b/src/tavi/tavi_model/combine_data.py @@ -0,0 +1,39 @@ +from typing import Optional + +import numpy as np +from scipy.stats import binned_statistic + +from tavi.tavi_model.FileSystem.tavi_class_factory import Scan + + +class CombineManager: + def __init__(self, target: list[Scan], background: Optional[list[Scan]] = []): + self.target = target + self.background = background + + def combine_1d(self, axis: tuple[str, str], step: float, range: Optional[tuple[float, float]] = None, **kwarg): + x_axis, y_axis = axis + new_x, new_y = np.array([]), np.array([]) + for scan in self.target: + x = getattr(scan.data, x_axis) + y = getattr(scan.data, y_axis) + new_x = np.append(new_x, x) + new_y = np.append(new_y, y) + + # sort based on x + ind = np.argsort(new_x) + new_y = new_y[ind] + new_x = new_x[ind] + new_err = np.sqrt(new_y) + statistics, bin_edges, binnumber = binned_statistic(new_x, new_y, statistic="sum", bins=10, range=range) + bin_center = [(bin_edges[i - 1] + bin_edges[i]) / 2 for i in range(1, len(bin_edges))] + return statistics, bin_edges, binnumber, bin_center, new_x, new_y + + # def _equal_rebin_1d(self, x, y, err): + # return np.histogram2d(x,y) + + def combine_2d(): + pass + + def _equal_bins(): + pass diff --git a/src/tavi/tavi_model/filter.py b/src/tavi/tavi_model/filter.py index 732d78f69..a4c931b9d 100644 --- a/src/tavi/tavi_model/filter.py +++ b/src/tavi/tavi_model/filter.py @@ -39,12 +39,12 @@ class Filter: def __init__( self, - scan_list: dict[str, Scan], + rawdataptr: dict[str, Scan], conditions: Optional[list[Operations]] = None, and_or: Optional[Logic] = None, tol: float = 0.01, # this can be put into a TAVI config json file as filter equal tolerance ): - self.scan_list = scan_list + self.rawdataptr = rawdataptr self.conditions = conditions self.and_or = and_or self.tol = tol @@ -76,7 +76,7 @@ def _condition_factory(self, keyword, value, condition, category): """ tmp_output = set() if category == Category.METADATA: - for filename, scan in self.scan_list.items(): + for filename, scan in self.rawdataptr.items(): if hasattr(scan.metadata, keyword): att = scan.metadata elif hasattr(scan.ubconf, keyword): @@ -99,7 +99,7 @@ def _condition_factory(self, keyword, value, condition, category): tmp_output.add(filename) elif category == Category.DATA: value = float(value) - for filename, scan in self.scan_list.items(): + for filename, scan in self.rawdataptr.items(): if not hasattr(scan.data, keyword): logger.log("No matching entry with", keyword) else: diff --git a/src/tavi/tavi_model/tavi_project.py b/src/tavi/tavi_model/tavi_project.py index 226088c76..59b6d1df6 100644 --- a/src/tavi/tavi_model/tavi_project.py +++ b/src/tavi/tavi_model/tavi_project.py @@ -1,6 +1,7 @@ import os from typing import Iterable, Optional +from tavi.tavi_model.combine_data import CombineManager from tavi.tavi_model.FileSystem.load_manager import LoadManager from tavi.tavi_model.filter import Filter, Logic, Operations from tavi.tavi_model.tavi_data import TaviData @@ -109,8 +110,17 @@ def select_scans( self.tavi_data.show_selected_data[filter_name] = filtered_data # TO DO - def combine_data(): - pass + def combine_data( + self, + target_list: list[str], + background_list: Optional[list[str]] = [], + axis: Optional[tuple[str, str]] = None, + tol: Optional[float] = 0.01, + ): + target = [self.tavi_data.rawdataptr[scan_name] for scan_name in target_list] + background = [self.tavi_data.rawdataptr[scan_name] for scan_name in background_list] + combined_data_1d = CombineManager(target=target, background=background).combine_1d(axis) + return combined_data_1d # TO DO def fit_data(): @@ -129,15 +139,29 @@ def plot_data(): TaviProj.load_scans(filepath) - filename = "CG4C_exp0424_scan0042.dat" - TaviProj.select_scans( - filter_name="scan_contains_42", conditions=([["scan", Operations.CONTAINS, "42"]]), and_or=Logic.OR - ) - - TaviProj.select_scans(filter_name="filter2", conditions=([["scan", Operations.CONTAINS, "4"]]), and_or=Logic.OR) - print(TaviProj.tavi_data.show_selected_data) -# print(type(TaviProj.scans[filename].metadata.scan)) -# print(TaviProj.scans[filename].ubconf) -# print(TaviProj.scans[filename].data.Pt) -# print(TaviProj.scans[filename].error_message) -# print(TaviProj.scans[filename].metadata.time) + # filename = "CG4C_exp0424_scan0042.dat" + # TaviProj.select_scans( + # filter_name="scan_contains_42", conditions=([["scan", Operations.CONTAINS, "42"]]), and_or=Logic.OR + # ) + + # TaviProj.select_scans(filter_name="filter2", conditions=([["scan", Operations.CONTAINS, "4"]]), and_or=Logic.OR) + # print(TaviProj.tavi_data.show_selected_data) + # print(type(TaviProj.scans[filename].metadata.scan)) + # print(TaviProj.scans[filename].ubconf) + # print(TaviProj.scans[filename].data.Pt) + # print(TaviProj.scans[filename].error_message) + # print(TaviProj.scans[filename].metadata.time) + + # -----------------------combine data--------------------------- + target = ["CG4C_exp0424_scan0042.dat", "CG4C_exp0424_scan0042.dat"] + test_return = TaviProj.combine_data(target_list=target, axis=("e", "detector")) + print(test_return[0]) + print(test_return[1]) + print(test_return[2]) + print(test_return[3]) + print(test_return[4]) + print(test_return[5]) + print(len(test_return)) + # import matplotlib.pyplot as plt + # plt.plot(test_return[3], test_return[4], '.') + # plt.show()