Skip to content

Commit

Permalink
* Adding partial implementation of runge-kutta integrator
Browse files Browse the repository at this point in the history
* Removing time_stepper implemented specifically only for runge-kutta
  integrator as we now have a general cudm_time_stepper
* Updating CMakelists.txt accordingly
* Removing runge_kutta_test_helpers as we will be using test_mocks
  instead

Signed-off-by: Sachin Pisal <spisal@nvidia.com>
  • Loading branch information
sacpis committed Jan 31, 2025
1 parent b74e2bf commit 97486c1
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 339 deletions.
18 changes: 17 additions & 1 deletion runtime/cudaq/base_integrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,27 @@ class BaseIntegrator {
virtual void post_init() = 0;

public:
/// @brief Default constructor
BaseIntegrator() = default;

/// @brief Constructor to initialize the integrator with a state and time
/// stepper.
/// @param initial_state Initial quantum state.
/// @param t0 Initial time.
/// @param stepper Time stepper instance.
BaseIntegrator(const TState &initial_state, double t0,
std::shared_ptr<BaseTimeStepper<TState>> stepper)
: state(initial_state), t(t0), stepper(std::move(stepper)) {}

virtual ~BaseIntegrator() = default;

/// @brief Set the initial state and time
void set_state(const TState &initial_state, double t0 = 0.0) {
state = initial_state;
t = t0;
}

/// @brief Set the system parameters (dimensions, schedule, and operators)
void set_system(
const std::map<int, int> &dimensions, std::shared_ptr<Schedule> schedule,
std::shared_ptr<operator_sum> hamiltonian,
Expand All @@ -48,8 +62,10 @@ class BaseIntegrator {
this->collapse_operators = collapse_operators;
}

virtual void integrate(double t) = 0;
/// @brief Perform integration to the target time.
virtual void integrate(double target_time) = 0;

/// @brief Get the current time and state.
std::pair<double, TState> get_state() const { return {t, state}; }
};
} // namespace cudaq
13 changes: 12 additions & 1 deletion runtime/cudaq/dynamics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-ctad-maybe-unsupported")
set(INTERFACE_POSITION_INDEPENDENT_CODE ON)

set(CUDAQ_OPS_SRC
scalar_operators.cpp elementary_operators.cpp product_operators.cpp operator_sum.cpp schedule.cpp definition.cpp helpers.cpp rydberg_hamiltonian.cpp cudm_helpers.cpp cudm_state.cpp cudm_time_stepper.cpp
scalar_operators.cpp
elementary_operators.cpp
product_operators.cpp
operator_sum.cpp
schedule.cpp
definition.cpp
helpers.cpp
rydberg_hamiltonian.cpp
cudm_helpers.cpp
cudm_state.cpp
cudm_time_stepper.cpp
runge_kutta_integrator.cpp
)

set(CUQUANTUM_INSTALL_PREFIX "/usr/local/lib/python3.10/dist-packages/cuquantum")
Expand Down
65 changes: 65 additions & 0 deletions runtime/cudaq/dynamics/runge_kutta_integrator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*******************************************************************************
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "cudaq/runge_kutta_integrator.h"
#include <iostream>

using namespace cudaq;

namespace cudaq {
void runge_kutta_integrator::integrate(double target_time) {
if (!stepper) {
throw std::runtime_error("Time stepper is not initialized.");
}

double dt = integrator_options["dt"];
if (dt <= 0) {
throw std::invalid_argument("Invalid time step size for integration.");
}

auto handle = state.get_handle();
auto hilbertSpaceDims = state.get_hilbert_space_dims();

while (t < target_time) {
double step_size = std::min(dt, target_time - 1);

std::cout << "Runge-Kutta step at time " << t << " with step size: " << step_size << std::endl;

// Empty vectors of same size as state.get_raw_data()
std::vector<std::complex<double>> zero_state(state.get_raw_data().size(), {0.0, 0.0});

cudm_state k1(handle, zero_state, hilbertSpaceDims);
cudm_state k2(handle, zero_state, hilbertSpaceDims);
cudm_state k3(handle, zero_state, hilbertSpaceDims);
cudm_state k4(handle, zero_state, hilbertSpaceDims);

if (substeps_ == 1) {
// Euler method (1st order)
k1 = stepper->compute(state, t, step_size);
state = k1;
} else if (substeps_ == 2) {
// Midpoint method (2nd order)
k1 = stepper->compute(state, t, step_size / 2.0);
k2 = stepper->compute(k1, t + step_size / 2.0, step_size);
state = (k1 + k2) * 0.5;
} else if (substeps_ == 4) {
// Runge-Kutta method (4th order)
k1 = stepper->compute(state, t, step_size / 2.0);
k2 = stepper->compute(k1, t + step_size / 2.0, step_size / 2.0);
k3 = stepper->compute(k2, t + step_size / 2.0, step_size);
k4 = stepper->compute(k3, t + step_size, step_size);
state = (k1 + k2 * 2.0 + k3 * 2.0 + k4) * (1.0 / 6.0);
}

// Update time
t += step_size;
}

std::cout << "Integration complete. Final time: " << t << std::endl;
}
}
49 changes: 24 additions & 25 deletions runtime/cudaq/runge_kutta_integrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,38 @@
#pragma once

#include "base_integrator.h"
#include "runge_kutta_time_stepper.h"
#include "cudaq/cudm_state.h"
#include "cudaq/cudm_time_stepper.h"
#include <memory>

namespace cudaq {
template <typename TState>
class RungeKuttaIntegrator : public BaseIntegrator<TState> {
class runge_kutta_integrator : public BaseIntegrator<cudm_state> {
public:
using DerivativeFunction = std::function<TState(const TState &, double)>;

explicit RungeKuttaIntegrator(DerivativeFunction f)
: stepper(std::make_shared<RungeKuttaTimeStepper<TState>>(f)) {}

// Initializes the integrator
void post_init() override {
if (!this->stepper) {
throw std::runtime_error("Time stepper is not set");
/// @brief Constructor to initialize the Runge-Kutta integrator
/// @param initial_state Initial quantum state.
/// @param t0 Initial time.
/// @param stepper Time stepper instance.
/// @param substeps Number of Runge-Kutta substeps (must be 1, 2, or 4)
runge_kutta_integrator(const cudm_state &initial_state, double t0,
std::shared_ptr<cudm_time_stepper> stepper,
int substeps = 4)
: BaseIntegrator(initial_state, t0, stepper), substeps_(substeps) {
if (substeps_ != 1 && substeps_ != 2 && substeps_ != 4) {
throw std::invalid_argument("Runge-Kutta substeps must be 1, 2, or 4.");
}
post_init();
}

// Advances the system's state from current time to `t`
void integrate(double target_t) override {
if (!this->schedule || !this->hamiltonian) {
throw std::runtime_error("System is not properly set!");
}
/// @brief Perform Runge-Kutta integration until the target time.
/// @param target_time The final time to integrate to.
void integrate(double t) override;

while (this->t < target_t) {
stepper->compute(this->state, this->t);
// Time step size
this->t += 0.01;
}
}
protected:
/// @brief Any post-initialization setup
void post_init() override {}

private:
std::shared_ptr<RungeKuttaTimeStepper<TState>> stepper;
// Number of substeps in RK integration (1, 2, or 4)
int substeps_;
};
} // namespace cudaq
} // namespace cudaq
33 changes: 0 additions & 33 deletions runtime/cudaq/runge_kutta_time_stepper.h

This file was deleted.

1 change: 0 additions & 1 deletion unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ set(CUDAQ_RUNTIME_TEST_SOURCES
dynamics/elementary_ops_simple.cpp
dynamics/elementary_ops_arithmetic.cpp
dynamics/product_operators_arithmetic.cpp
dynamics/test_runge_kutta_time_stepper.cpp
dynamics/test_runge_kutta_integrator.cpp
dynamics/test_helpers.cpp
dynamics/rydberg_hamiltonian.cpp
Expand Down
24 changes: 0 additions & 24 deletions unittests/dynamics/runge_kutta_test_helpers.h

This file was deleted.

Loading

0 comments on commit 97486c1

Please sign in to comment.