Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hooks for remote launch to use argument synthesis. #2101

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions runtime/common/BaseRemoteSimulatorQPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,20 @@ class BaseRemoteSimulatorQPU : public cudaq::QPU {

void launchKernel(const std::string &name,
const std::vector<void *> &rawArgs) override {
throw std::runtime_error("launch kernel on raw args not implemented");
launchKernelImpl(name, nullptr, nullptr, 0, 0, &rawArgs);
}

void launchKernel(const std::string &name, void (*kernelFunc)(void *),
void *args, std::uint64_t voidStarSize,
std::uint64_t resultOffset) override {
launchKernelImpl(name, kernelFunc, args, voidStarSize, resultOffset,
nullptr);
}

void launchKernelImpl(const std::string &name, void (*kernelFunc)(void *),
void *args, std::uint64_t voidStarSize,
std::uint64_t resultOffset,
const std::vector<void *> *rawArgs) {
cudaq::info(
"BaseRemoteSimulatorQPU: Launch kernel named '{}' remote QPU {} "
"(simulator = {})",
Expand Down Expand Up @@ -145,7 +153,7 @@ class BaseRemoteSimulatorQPU : public cudaq::QPU {
const bool requestOkay = m_client->sendRequest(
*m_mlirContext, executionContext, /*serializedCodeContext=*/nullptr,
/*vqe_gradient=*/nullptr, /*vqe_optimizer=*/nullptr, /*vqe_n_params=*/0,
m_simName, name, kernelFunc, args, voidStarSize, &errorMsg);
m_simName, name, kernelFunc, args, voidStarSize, &errorMsg, rawArgs);
if (!requestOkay)
throw std::runtime_error("Failed to launch kernel. Error: " + errorMsg);
if (isDirectInvocation &&
Expand Down
68 changes: 42 additions & 26 deletions runtime/common/BaseRestRemoteClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "common/ArgumentConversion.h"
#include "common/Environment.h"
#include "common/JsonConvert.h"
#include "common/Logger.h"
Expand All @@ -18,6 +19,7 @@
#include "common/UnzipUtils.h"
#include "cudaq.h"
#include "cudaq/Frontend/nvqpp/AttributeNames.h"
#include "cudaq/Optimizer/Builder/Runtime.h"
#include "cudaq/Optimizer/CodeGen/OpenQASMEmitter.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
Expand Down Expand Up @@ -116,12 +118,10 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
return cudaq::RestRequest::REST_PAYLOAD_VERSION;
}

std::string constructKernelPayload(mlir::MLIRContext &mlirContext,
const std::string &name,
void (*kernelFunc)(void *),
const void *args,
std::uint64_t voidStarSize,
std::size_t startingArgIdx) {
std::string constructKernelPayload(
mlir::MLIRContext &mlirContext, const std::string &name,
void (*kernelFunc)(void *), const void *args, std::uint64_t voidStarSize,
std::size_t startingArgIdx, const std::vector<void *> *rawArgs) {
enablePrintMLIREachPass =
getEnvBool("CUDAQ_MLIR_PRINT_EACH_PASS", enablePrintMLIREachPass);

Expand Down Expand Up @@ -175,15 +175,28 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
moduleOp.push_back(funcOp.clone());
}
// Add globals defined in the module.
if (auto globalOp = dyn_cast<cudaq::cc::GlobalOp>(op))
if (auto globalOp = dyn_cast<cc::GlobalOp>(op))
moduleOp.push_back(globalOp.clone());
}

if (args) {
cudaq::info("Run Quake Synth.\n");
if (rawArgs || args) {
mlir::PassManager pm(&mlirContext);
pm.addPass(
cudaq::opt::createQuakeSynthesizer(name, args, startingArgIdx));
if (rawArgs && !rawArgs->empty()) {
cudaq::info("Run Argument Synth.\n");
opt::ArgumentConverter argCon(name, moduleOp);
argCon.gen_drop_front(*rawArgs, startingArgIdx);
std::string kernName = runtime::cudaqGenPrefixName + name;
mlir::SmallVector<mlir::StringRef> kernels = {kernName};
std::string substBuff;
llvm::raw_string_ostream ss(substBuff);
ss << argCon.getSubstitutionModule();
mlir::SmallVector<mlir::StringRef> substs = {substBuff};
pm.addNestedPass<mlir::func::FuncOp>(
opt::createArgumentSynthesisPass(kernels, substs));
} else if (args) {
cudaq::info("Run Quake Synth.\n");
pm.addPass(opt::createQuakeSynthesizer(name, args, startingArgIdx));
}
pm.addPass(mlir::createCanonicalizerPass());
if (enablePrintMLIREachPass) {
moduleOp.getContext()->disableMultithreading();
Expand Down Expand Up @@ -215,7 +228,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
"Remote rest platform failed to add passes to pipeline (" + errMsg +
").");

cudaq::opt::addPipelineConvertToQIR(pm);
opt::addPipelineConvertToQIR(pm);

if (failed(pm.run(moduleOp)))
throw std::runtime_error(
Expand All @@ -234,7 +247,8 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
mlir::MLIRContext &mlirContext, cudaq::ExecutionContext &io_context,
const std::string &backendSimName, const std::string &kernelName,
const void *kernelArgs, cudaq::gradient *gradient,
cudaq::optimizer &optimizer, const int n_params) {
cudaq::optimizer &optimizer, const int n_params,
const std::vector<void *> *rawArgs) {
cudaq::RestRequest request(io_context, version());

request.opt = RestRequestOptFields();
Expand All @@ -251,7 +265,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
request.code =
constructKernelPayload(mlirContext, kernelName, /*kernelFunc=*/nullptr,
/*kernelArgs=*/kernelArgs,
/*argsSize=*/0, /*startingArgIdx=*/1);
/*argsSize=*/0, /*startingArgIdx=*/1, rawArgs);
request.simulator = backendSimName;
// Remote server seed
// Note: unlike local executions whereby a static instance of the simulator
Expand All @@ -276,7 +290,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
cudaq::SerializedCodeExecutionContext *serializedCodeContext,
const std::string &backendSimName, const std::string &kernelName,
void (*kernelFunc)(void *), const void *kernelArgs,
std::uint64_t argsSize) {
std::uint64_t argsSize, const std::vector<void *> *rawArgs) {

cudaq::RestRequest request(io_context, version());
if (serializedCodeContext)
Expand Down Expand Up @@ -320,20 +334,20 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
stateIrPayload1.entryPoint = kernelName1;
stateIrPayload1.ir =
constructKernelPayload(mlirContext, kernelName1, nullptr, args1,
argsSize1, /*startingArgIdx=*/0);
argsSize1, /*startingArgIdx=*/0, nullptr);
schweitzpgi marked this conversation as resolved.
Show resolved Hide resolved
stateIrPayload2.entryPoint = kernelName2;
stateIrPayload2.ir =
constructKernelPayload(mlirContext, kernelName2, nullptr, args2,
argsSize2, /*startingArgIdx=*/0);
argsSize2, /*startingArgIdx=*/0, nullptr);
// First kernel of the overlap calculation
request.code = stateIrPayload1.ir;
request.entryPoint = stateIrPayload1.entryPoint;
// Second kernel of the overlap calculation
request.overlapKernel = stateIrPayload2;
} else if (serializedCodeContext == nullptr) {
request.code =
constructKernelPayload(mlirContext, kernelName, kernelFunc,
kernelArgs, argsSize, /*startingArgIdx=*/0);
request.code = constructKernelPayload(mlirContext, kernelName, kernelFunc,
kernelArgs, argsSize,
/*startingArgIdx=*/0, rawArgs);
}
request.simulator = backendSimName;
// Remote server seed
Expand Down Expand Up @@ -362,7 +376,8 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
const int vqe_n_params, const std::string &backendSimName,
const std::string &kernelName, void (*kernelFunc)(void *),
const void *kernelArgs, std::uint64_t argsSize,
std::string *optionalErrorMsg) override {
std::string *optionalErrorMsg,
const std::vector<void *> *rawArgs) override {
if (isDisallowed(io_context.name))
throw std::runtime_error(
io_context.name +
Expand All @@ -372,10 +387,10 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
if (vqe_n_params > 0)
return constructVQEJobRequest(mlirContext, io_context, backendSimName,
kernelName, kernelArgs, vqe_gradient,
*vqe_optimizer, vqe_n_params);
*vqe_optimizer, vqe_n_params, rawArgs);
return constructJobRequest(mlirContext, io_context, serializedCodeContext,
backendSimName, kernelName, kernelFunc,
kernelArgs, argsSize);
kernelArgs, argsSize, rawArgs);
}();

if (request.code.empty() && (serializedCodeContext == nullptr ||
Expand Down Expand Up @@ -909,7 +924,8 @@ class BaseNvcfRuntimeClient : public cudaq::BaseRemoteRestRuntimeClient {
const int vqe_n_params, const std::string &backendSimName,
const std::string &kernelName, void (*kernelFunc)(void *),
const void *kernelArgs, std::uint64_t argsSize,
std::string *optionalErrorMsg) override {
std::string *optionalErrorMsg,
const std::vector<void *> *rawArgs) override {
if (isDisallowed(io_context.name))
throw std::runtime_error(
io_context.name +
Expand Down Expand Up @@ -941,10 +957,10 @@ class BaseNvcfRuntimeClient : public cudaq::BaseRemoteRestRuntimeClient {
if (vqe_n_params > 0)
return constructVQEJobRequest(mlirContext, io_context, backendSimName,
kernelName, kernelArgs, vqe_gradient,
*vqe_optimizer, vqe_n_params);
*vqe_optimizer, vqe_n_params, rawArgs);
return constructJobRequest(mlirContext, io_context, serializedCodeContext,
backendSimName, kernelName, kernelFunc,
kernelArgs, argsSize);
kernelArgs, argsSize, rawArgs);
}();

if (request.code.empty() && (serializedCodeContext == nullptr ||
Expand Down
3 changes: 2 additions & 1 deletion runtime/common/RemoteKernelExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class RemoteRuntimeClient
const int vqe_n_params, const std::string &backendSimName,
const std::string &kernelName, void (*kernelFunc)(void *),
const void *kernelArgs, std::uint64_t argsSize,
std::string *optionalErrorMsg = nullptr) = 0;
std::string *optionalErrorMsg = nullptr,
const std::vector<void *> *rawArgs = nullptr) = 0;
// Destructor
virtual ~RemoteRuntimeClient() = default;
};
Expand Down
Loading