From 1b9eabad0637dee4de67eb3964f18184cf907704 Mon Sep 17 00:00:00 2001 From: Daniel Hollarek Date: Wed, 18 Dec 2024 20:03:39 +0100 Subject: [PATCH] pattern: Merged plot_all back into PatternDB --- xrdpattern/pattern/db.py | 34 ++++++++++++++++++++++++++--- xrdpattern/pattern/visualization.py | 26 ---------------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/xrdpattern/pattern/db.py b/xrdpattern/pattern/db.py index 51bf64c..ccd5969 100644 --- a/xrdpattern/pattern/db.py +++ b/xrdpattern/pattern/db.py @@ -5,11 +5,13 @@ import random from typing import Optional +from matplotlib import pyplot as plt + from holytools.logging import LoggerFactory from holytools.userIO import TrackedCollection from xrdpattern.parsing import MasterParser, Formats from xrdpattern.xrd import XrayInfo, XrdData -from .visualization import histograms, plot_all +from .visualization import histograms, multiplot from .pattern import XrdPattern patterdb_logger = LoggerFactory.get_logger(name=__name__) @@ -119,9 +121,35 @@ def __eq__(self, other : PatternDB): return False return True - def show_all(self, single_plot : bool = False, limit_patterns : int = 100): + def show_all(self, single_plot : bool = False, limit_patterns : int = 100, title : Optional[str] = None): patterns = self.patterns if len(self.patterns) <= limit_patterns else random.sample(self.patterns, limit_patterns) - plot_all(patterns=patterns, single_plot=single_plot, db_name=self.name) + if single_plot: + data = [p.get_pattern_data() for p in patterns] + fig, ax = plt.subplots(dpi=600) + for x, y in data: + ax.plot(x, y, linewidth=0.25, alpha=0.75) + + ax.set_xlabel(r'$2\theta$ [$^\circ$]') + ax.set_ylabel('Standardized relative intensity (a.u.)') + if title: + ax.set_title(title) + else: + ax.set_title(f'XRD Patterns from {self.name}') + plt.show() + + else: + batch_size = 32 + j = 0 + while j < len(patterns): + pattern_batch = patterns[j:j + batch_size] + for k, p in enumerate(pattern_batch): + p.metadata.filename = p.get_name() or f'pattern_{j + k}' + multiplot(patterns=pattern_batch, start_idx=j) + j += batch_size + + user_input = input(f'Press enter to continue or q to quit') + if user_input.lower() == 'q': + break def show_histograms(self, save_fpath : Optional[str] = None, attach_colorbar : bool = True): histograms(patterns=self.patterns, attach_colorbar=attach_colorbar, save_fpath=save_fpath) diff --git a/xrdpattern/pattern/visualization.py b/xrdpattern/pattern/visualization.py index a7d35d5..c576b58 100644 --- a/xrdpattern/pattern/visualization.py +++ b/xrdpattern/pattern/visualization.py @@ -14,32 +14,6 @@ # ----------------------------------------- -def plot_all(patterns : list[XrdPattern], db_name : Optional[str] = None, single_plot : bool = False): - if single_plot: - data = [p.get_pattern_data() for p in patterns] - fig, ax = plt.subplots(dpi=600) - for x, y in data: - ax.plot(x, y, linewidth=0.25, linestyle='--', alpha=0.75) - - ax.set_xlabel(r'$2\theta$ [$^\circ$]') - ax.set_ylabel('Standardized relative intensity (a.u.)') - if db_name: - ax.set_title(f'{len(patterns)} of patterns from {db_name}') - plt.show() - - else: - batch_size = 32 - j = 0 - while j < len(patterns): - pattern_batch = patterns[j:j + batch_size] - for k, p in enumerate(pattern_batch): - p.metadata.filename = p.get_name() or f'pattern_{j + k}' - multiplot(patterns=pattern_batch, start_idx=j) - j += batch_size - - user_input = input(f'Press enter to continue or q to quit') - if user_input.lower() == 'q': - break def multiplot(patterns : list[XrdPattern], start_idx : int): labels = [p.get_name() for p in patterns]