Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Montecarlo globals (limited scope) #2705

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,22 @@ jobs:
echo "TARDIS_PIP_PATH=$directory_path" >> $GITHUB_ENV

- name: Run tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m "not continuum"
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m "not (continuum or rpacket_tracking)"
working-directory: ${{ env.TARDIS_PIP_PATH }}
if: always()

- name: Upload to Codecov
run: bash <(curl -s https://codecov.io/bash)

- name: Run continuum tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m continuum
working-directory: ${{ env.TARDIS_PIP_PATH }}
if: always()

- name: Upload to Codecov
run: bash <(curl -s https://codecov.io/bash)
- name: Run rpacket tracking tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m rpacket_tracking
working-directory: ${{ env.TARDIS_PIP_PATH }}
if: always()

- name: Refdata Generation tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} --generate-reference
Expand Down
32 changes: 12 additions & 20 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from tardis.simulation import Simulation
from tardis.tests.fixtures.atom_data import DEFAULT_ATOM_DATA_UUID
from tardis.tests.fixtures.regression_data import RegressionData
from tardis.transport.montecarlo import RPacket, montecarlo_configuration
from tardis.transport.montecarlo import RPacket
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.base import (
MonteCarloConfiguration,
)
from tardis.transport.montecarlo.estimators import radfield_mc_estimators
from tardis.transport.montecarlo.numba_interface import opacity_state_initialize
from tardis.transport.montecarlo.packet_collections import (
Expand Down Expand Up @@ -235,9 +239,7 @@ def packet(self):

@property
def verysimple_packet_collection(self):
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)
return self.nb_simulation_verysimple.transport.transport_state.packet_collection

@property
def nb_simulation_verysimple(self):
Expand All @@ -259,7 +261,6 @@ def verysimple_opacity_state(self):
self.nb_simulation_verysimple.plasma,
line_interaction_type="macroatom",
disable_line_scattering=self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING,
continuum_processes_enabled=self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)

@property
Expand All @@ -268,27 +269,19 @@ def verysimple_enable_full_relativity(self):

@property
def verysimple_disable_line_scattering(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING

@property
def verysimple_continuum_processes_enabled(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED
)
return montecarlo_globals.CONTINUUM_PROCESSES_ENABLED

@property
def verysimple_tau_russian(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN

@property
def verysimple_survival_probability(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY

@property
def static_packet(self):
Expand Down Expand Up @@ -359,10 +352,10 @@ def verysimple_radfield_mc_estimators(self):

@property
def montecarlo_configuration(self):
return montecarlo_configuration.MonteCarloConfiguration()
return MonteCarloConfiguration()

@property
def rpacket_tracker(self):
def rpacket_tracker(self):
return RPacketTracker(0)

@property
Expand Down Expand Up @@ -396,7 +389,6 @@ def geometry(self):
v_outer=np.array([-1, -1], dtype=np.float64),
)


@property
def estimators(self):
return radfield_mc_estimators.RadiationFieldMCEstimators(
Expand Down
2 changes: 0 additions & 2 deletions benchmarks/transport_montecarlo_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from asv_runner.benchmarks.mark import parameterize



class BenchmarkMontecarloMontecarloNumbaInteraction(BenchmarkBase):
"""
Class to benchmark the numba interaction function.
Expand Down Expand Up @@ -52,7 +51,6 @@ def time_line_scatter(self, line_interaction_type):
line_interaction_type,
self.verysimple_opacity_state,
self.verysimple_enable_full_relativity,
self.verysimple_continuum_processes_enabled,
)

@parameterize(
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/transport_montecarlo_numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,4 @@ def time_opacity_state_initialize(self, input_params):
plasma,
line_interaction_type,
self.verysimple_disable_line_scattering,
self.verysimple_continuum_processes_enabled,
)
)
27 changes: 4 additions & 23 deletions benchmarks/transport_montecarlo_vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def v_packet(self):
next_line_id=0,
index=0,
)

@property
def r_packet(self):
return RPacket(
Expand Down Expand Up @@ -62,9 +62,6 @@ def time_trace_vpacket_within_shell(self):
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
enable_full_relativity = self.verysimple_enable_full_relativity
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)

# Give the vpacket a reasonable line ID
self.v_packet_initialize_line_id(
Expand All @@ -80,7 +77,6 @@ def time_trace_vpacket_within_shell(self):
verysimple_time_explosion,
verysimple_opacity_state,
enable_full_relativity,
continuum_processes_enabled,
)

def time_trace_vpacket(self):
Expand All @@ -91,9 +87,6 @@ def time_trace_vpacket(self):
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
enable_full_relativity = self.verysimple_enable_full_relativity
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)
tau_russian = self.verysimple_tau_russian
survival_probability = self.verysimple_survival_probability

Expand All @@ -116,7 +109,6 @@ def time_trace_vpacket(self):
tau_russian,
survival_probability,
enable_full_relativity,
continuum_processes_enabled,
)

@property
Expand All @@ -139,9 +131,6 @@ def time_trace_bad_vpacket(self):
enable_full_relativity = self.verysimple_enable_full_relativity
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)
tau_russian = self.verysimple_tau_russian
survival_probability = self.verysimple_survival_probability

Expand All @@ -153,20 +142,13 @@ def time_trace_bad_vpacket(self):
tau_russian,
survival_probability,
enable_full_relativity,
continuum_processes_enabled,
)

@parameterize(
{
"Paramters": [
{
"tau_russian": 10.0,
"survival_possibility": 0.0
},
{
"tau_russian": 15.0,
"survival_possibility": 0.1
},
{"tau_russian": 10.0, "survival_possibility": 0.0},
{"tau_russian": 15.0, "survival_possibility": 0.1},
]
}
)
Expand All @@ -180,6 +162,5 @@ def time_trace_vpacket_volley(self, parameters):
False,
parameters["tau_russian"],
parameters["survival_possibility"],
False
False,
)

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ text_file_format = "rst"
markers = [
# continuum tests
"continuum",
# rpacket tracking tests
"rpacket_tracking"
]

[tool.tardis]
Expand Down
2 changes: 1 addition & 1 deletion tardis/opacities/opacities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
SIGMA_THOMSON,
)

Expand Down
6 changes: 3 additions & 3 deletions tardis/opacities/opacity_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numba.experimental import jitclass

from tardis.opacities.tau_sobolev import calculate_sobolev_line_opacity
from tardis.transport.montecarlo.configuration import montecarlo_globals

opacity_state_spec = [
("electron_density", float64[:]),
Expand Down Expand Up @@ -110,7 +111,6 @@ def opacity_state_initialize(
plasma,
line_interaction_type,
disable_line_scattering,
continuum_processes_enabled,
):
"""
Initialize the OpacityState object and copy over the data over from TARDIS Plasma
Expand Down Expand Up @@ -156,7 +156,7 @@ def opacity_state_initialize(
)
# TODO: Fix setting of block references for non-continuum mode

if continuum_processes_enabled:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
macro_block_references = plasma.macro_block_references
else:
macro_block_references = plasma.atomic_data.macro_atom_references[
Expand All @@ -169,7 +169,7 @@ def opacity_state_initialize(
"destination_level_idx"
].values
transition_line_id = plasma.macro_atom_data["lines_idx"].values
if continuum_processes_enabled:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
bf_threshold_list_nu = plasma.nu_i.loc[
plasma.level2continuum_idx.index
].values
Expand Down
3 changes: 2 additions & 1 deletion tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots

Expand Down Expand Up @@ -199,7 +200,7 @@ def __init__(
self._callbacks = OrderedDict()
self._cb_next_id = 0

self.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = (
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED = (
not self.plasma.continuum_interaction_species.empty
)

Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/frame_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import C_SPEED_OF_LIGHT
from tardis.transport.montecarlo.configuration.constants import C_SPEED_OF_LIGHT


@njit(**njit_dict_no_parallel)
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/geometry/calculate_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
MISS_DISTANCE,
SIGMA_THOMSON,
Expand Down
18 changes: 9 additions & 9 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from astropy import units as u
from numba import cuda, set_num_threads

import tardis.transport.montecarlo.configuration.constants as constants
from tardis import constants as const
from tardis.io.logger import montecarlo_tracking as mc_tracker
from tardis.io.util import HDFWriterMixin
from tardis.transport.montecarlo import (
montecarlo_main_loop,
numba_config,
)
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.base import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.transport.montecarlo.estimators.radfield_mc_estimators import (
initialize_estimator_statistics,
)
from tardis.transport.montecarlo.formal_integral import FormalIntegrator
from tardis.transport.montecarlo.montecarlo_configuration import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.transport.montecarlo.montecarlo_transport_state import (
MonteCarloTransportState,
)
Expand Down Expand Up @@ -116,7 +117,6 @@ def initialize_transport_state(
plasma,
self.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
transport_state = MonteCarloTransportState(
packet_collection,
Expand Down Expand Up @@ -210,7 +210,7 @@ def run(
update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if self.enable_rpacket_tracking:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
Expand Down Expand Up @@ -245,10 +245,10 @@ def from_config(
"Likely bug in formal integral - "
"will not give same results."
)
numba_config.SIGMA_THOMSON = 1e-200
constants.SIGMA_THOMSON = 1e-200
else:
logger.debug("Electron scattering switched on")
numba_config.SIGMA_THOMSON = const.sigma_T.to("cm^2").value
constants.SIGMA_THOMSON = const.sigma_T.to("cm^2").value

spectrum_frequency = quantity_linspace(
config.spectrum.stop.to("Hz", u.spectral()),
Expand Down
Empty file.
Loading
Loading