Skip to content

Commit

Permalink
Add hooks for remote launch to use argument synthesis. (#2101)
Browse files Browse the repository at this point in the history
* Add hooks for remote launch to use argument synthesis.

* Call the appropriate generate function.
  • Loading branch information
schweitzpgi authored Aug 19, 2024
1 parent ab2c546 commit 37474b2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 29 deletions.
12 changes: 10 additions & 2 deletions runtime/common/BaseRemoteSimulatorQPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,20 @@ class BaseRemoteSimulatorQPU : public cudaq::QPU {

void launchKernel(const std::string &name,
const std::vector<void *> &rawArgs) override {
throw std::runtime_error("launch kernel on raw args not implemented");
launchKernelImpl(name, nullptr, nullptr, 0, 0, &rawArgs);
}

void launchKernel(const std::string &name, void (*kernelFunc)(void *),
void *args, std::uint64_t voidStarSize,
std::uint64_t resultOffset) override {
launchKernelImpl(name, kernelFunc, args, voidStarSize, resultOffset,
nullptr);
}

void launchKernelImpl(const std::string &name, void (*kernelFunc)(void *),
void *args, std::uint64_t voidStarSize,
std::uint64_t resultOffset,
const std::vector<void *> *rawArgs) {
cudaq::info(
"BaseRemoteSimulatorQPU: Launch kernel named '{}' remote QPU {} "
"(simulator = {})",
Expand Down Expand Up @@ -145,7 +153,7 @@ class BaseRemoteSimulatorQPU : public cudaq::QPU {
const bool requestOkay = m_client->sendRequest(
*m_mlirContext, executionContext, /*serializedCodeContext=*/nullptr,
/*vqe_gradient=*/nullptr, /*vqe_optimizer=*/nullptr, /*vqe_n_params=*/0,
m_simName, name, kernelFunc, args, voidStarSize, &errorMsg);
m_simName, name, kernelFunc, args, voidStarSize, &errorMsg, rawArgs);
if (!requestOkay)
throw std::runtime_error("Failed to launch kernel. Error: " + errorMsg);
if (isDirectInvocation &&
Expand Down
68 changes: 42 additions & 26 deletions runtime/common/BaseRestRemoteClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "common/ArgumentConversion.h"
#include "common/Environment.h"
#include "common/JsonConvert.h"
#include "common/Logger.h"
Expand All @@ -18,6 +19,7 @@
#include "common/UnzipUtils.h"
#include "cudaq.h"
#include "cudaq/Frontend/nvqpp/AttributeNames.h"
#include "cudaq/Optimizer/Builder/Runtime.h"
#include "cudaq/Optimizer/CodeGen/OpenQASMEmitter.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
Expand Down Expand Up @@ -116,12 +118,10 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
return cudaq::RestRequest::REST_PAYLOAD_VERSION;
}

std::string constructKernelPayload(mlir::MLIRContext &mlirContext,
const std::string &name,
void (*kernelFunc)(void *),
const void *args,
std::uint64_t voidStarSize,
std::size_t startingArgIdx) {
std::string constructKernelPayload(
mlir::MLIRContext &mlirContext, const std::string &name,
void (*kernelFunc)(void *), const void *args, std::uint64_t voidStarSize,
std::size_t startingArgIdx, const std::vector<void *> *rawArgs) {
enablePrintMLIREachPass =
getEnvBool("CUDAQ_MLIR_PRINT_EACH_PASS", enablePrintMLIREachPass);

Expand Down Expand Up @@ -175,15 +175,28 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
moduleOp.push_back(funcOp.clone());
}
// Add globals defined in the module.
if (auto globalOp = dyn_cast<cudaq::cc::GlobalOp>(op))
if (auto globalOp = dyn_cast<cc::GlobalOp>(op))
moduleOp.push_back(globalOp.clone());
}

if (args) {
cudaq::info("Run Quake Synth.\n");
if (rawArgs || args) {
mlir::PassManager pm(&mlirContext);
pm.addPass(
cudaq::opt::createQuakeSynthesizer(name, args, startingArgIdx));
if (rawArgs && !rawArgs->empty()) {
cudaq::info("Run Argument Synth.\n");
opt::ArgumentConverter argCon(name, moduleOp);
argCon.gen_drop_front(*rawArgs, startingArgIdx);
std::string kernName = runtime::cudaqGenPrefixName + name;
mlir::SmallVector<mlir::StringRef> kernels = {kernName};
std::string substBuff;
llvm::raw_string_ostream ss(substBuff);
ss << argCon.getSubstitutionModule();
mlir::SmallVector<mlir::StringRef> substs = {substBuff};
pm.addNestedPass<mlir::func::FuncOp>(
opt::createArgumentSynthesisPass(kernels, substs));
} else if (args) {
cudaq::info("Run Quake Synth.\n");
pm.addPass(opt::createQuakeSynthesizer(name, args, startingArgIdx));
}
pm.addPass(mlir::createCanonicalizerPass());
if (enablePrintMLIREachPass) {
moduleOp.getContext()->disableMultithreading();
Expand Down Expand Up @@ -215,7 +228,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
"Remote rest platform failed to add passes to pipeline (" + errMsg +
").");

cudaq::opt::addPipelineConvertToQIR(pm);
opt::addPipelineConvertToQIR(pm);

if (failed(pm.run(moduleOp)))
throw std::runtime_error(
Expand All @@ -234,7 +247,8 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
mlir::MLIRContext &mlirContext, cudaq::ExecutionContext &io_context,
const std::string &backendSimName, const std::string &kernelName,
const void *kernelArgs, cudaq::gradient *gradient,
cudaq::optimizer &optimizer, const int n_params) {
cudaq::optimizer &optimizer, const int n_params,
const std::vector<void *> *rawArgs) {
cudaq::RestRequest request(io_context, version());

request.opt = RestRequestOptFields();
Expand All @@ -251,7 +265,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
request.code =
constructKernelPayload(mlirContext, kernelName, /*kernelFunc=*/nullptr,
/*kernelArgs=*/kernelArgs,
/*argsSize=*/0, /*startingArgIdx=*/1);
/*argsSize=*/0, /*startingArgIdx=*/1, rawArgs);
request.simulator = backendSimName;
// Remote server seed
// Note: unlike local executions whereby a static instance of the simulator
Expand All @@ -276,7 +290,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
cudaq::SerializedCodeExecutionContext *serializedCodeContext,
const std::string &backendSimName, const std::string &kernelName,
void (*kernelFunc)(void *), const void *kernelArgs,
std::uint64_t argsSize) {
std::uint64_t argsSize, const std::vector<void *> *rawArgs) {

cudaq::RestRequest request(io_context, version());
if (serializedCodeContext)
Expand Down Expand Up @@ -320,20 +334,20 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
stateIrPayload1.entryPoint = kernelName1;
stateIrPayload1.ir =
constructKernelPayload(mlirContext, kernelName1, nullptr, args1,
argsSize1, /*startingArgIdx=*/0);
argsSize1, /*startingArgIdx=*/0, nullptr);
stateIrPayload2.entryPoint = kernelName2;
stateIrPayload2.ir =
constructKernelPayload(mlirContext, kernelName2, nullptr, args2,
argsSize2, /*startingArgIdx=*/0);
argsSize2, /*startingArgIdx=*/0, nullptr);
// First kernel of the overlap calculation
request.code = stateIrPayload1.ir;
request.entryPoint = stateIrPayload1.entryPoint;
// Second kernel of the overlap calculation
request.overlapKernel = stateIrPayload2;
} else if (serializedCodeContext == nullptr) {
request.code =
constructKernelPayload(mlirContext, kernelName, kernelFunc,
kernelArgs, argsSize, /*startingArgIdx=*/0);
request.code = constructKernelPayload(mlirContext, kernelName, kernelFunc,
kernelArgs, argsSize,
/*startingArgIdx=*/0, rawArgs);
}
request.simulator = backendSimName;
// Remote server seed
Expand Down Expand Up @@ -362,7 +376,8 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
const int vqe_n_params, const std::string &backendSimName,
const std::string &kernelName, void (*kernelFunc)(void *),
const void *kernelArgs, std::uint64_t argsSize,
std::string *optionalErrorMsg) override {
std::string *optionalErrorMsg,
const std::vector<void *> *rawArgs) override {
if (isDisallowed(io_context.name))
throw std::runtime_error(
io_context.name +
Expand All @@ -372,10 +387,10 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
if (vqe_n_params > 0)
return constructVQEJobRequest(mlirContext, io_context, backendSimName,
kernelName, kernelArgs, vqe_gradient,
*vqe_optimizer, vqe_n_params);
*vqe_optimizer, vqe_n_params, rawArgs);
return constructJobRequest(mlirContext, io_context, serializedCodeContext,
backendSimName, kernelName, kernelFunc,
kernelArgs, argsSize);
kernelArgs, argsSize, rawArgs);
}();

if (request.code.empty() && (serializedCodeContext == nullptr ||
Expand Down Expand Up @@ -909,7 +924,8 @@ class BaseNvcfRuntimeClient : public cudaq::BaseRemoteRestRuntimeClient {
const int vqe_n_params, const std::string &backendSimName,
const std::string &kernelName, void (*kernelFunc)(void *),
const void *kernelArgs, std::uint64_t argsSize,
std::string *optionalErrorMsg) override {
std::string *optionalErrorMsg,
const std::vector<void *> *rawArgs) override {
if (isDisallowed(io_context.name))
throw std::runtime_error(
io_context.name +
Expand Down Expand Up @@ -941,10 +957,10 @@ class BaseNvcfRuntimeClient : public cudaq::BaseRemoteRestRuntimeClient {
if (vqe_n_params > 0)
return constructVQEJobRequest(mlirContext, io_context, backendSimName,
kernelName, kernelArgs, vqe_gradient,
*vqe_optimizer, vqe_n_params);
*vqe_optimizer, vqe_n_params, rawArgs);
return constructJobRequest(mlirContext, io_context, serializedCodeContext,
backendSimName, kernelName, kernelFunc,
kernelArgs, argsSize);
kernelArgs, argsSize, rawArgs);
}();

if (request.code.empty() && (serializedCodeContext == nullptr ||
Expand Down
3 changes: 2 additions & 1 deletion runtime/common/RemoteKernelExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class RemoteRuntimeClient
const int vqe_n_params, const std::string &backendSimName,
const std::string &kernelName, void (*kernelFunc)(void *),
const void *kernelArgs, std::uint64_t argsSize,
std::string *optionalErrorMsg = nullptr) = 0;
std::string *optionalErrorMsg = nullptr,
const std::vector<void *> *rawArgs = nullptr) = 0;
// Destructor
virtual ~RemoteRuntimeClient() = default;
};
Expand Down

0 comments on commit 37474b2

Please sign in to comment.