Skip to content

Commit

Permalink
Make it easier to access dumps and matrices from Python (#2042)
Browse files Browse the repository at this point in the history
This adds some convenience accessors for getting dumps, messages, and
matrices as separate entries from an invocation of `qsharp.run` and
`qsharp.eval`. Previously, the only way to get a state dump from a run
was the awkward:

```python
state = qsharp.StateDump(qsharp.run("DumpMachine()", shots=1, save_events=True)[0]["events"][0].state_dump())
```

This change preserves the existings "events" entry in the saved output,
which has everything intermingled in the order from each shot, but also
introduces dumps, messages, and matrices that will keep just the ordered
output of that type. This makes the above pattern slightly better (and
more discoverable):

```python
state = qsharp.run("DumpMachine()", shots=1, save_events=True)[0]["dumps"][0]
```

This adds similar functionality to `qsharp.eval` which now supports
`save_events=True` to capture output, so for single shot execution you
can use:

```python
state = qsharp.eval("DumpMachine()", save_events=True)["dumps"][0]
```
  • Loading branch information
swernli authored Dec 12, 2024
1 parent 609bef5 commit 502ae34
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 91 deletions.
3 changes: 3 additions & 0 deletions pip/qsharp/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ class Output:
def __str__(self) -> str: ...
def _repr_markdown_(self) -> Optional[str]: ...
def state_dump(self) -> Optional[StateDumpData]: ...
def is_state_dump(self) -> bool: ...
def is_matrix(self) -> bool: ...
def is_message(self) -> bool: ...

class StateDumpData:
"""
Expand Down
223 changes: 133 additions & 90 deletions pip/qsharp/_qsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,133 @@ def get_interpreter() -> Interpreter:
return _interpreter


def eval(source: str) -> Any:
class StateDump:
"""
A state dump returned from the Q# interpreter.
"""

"""
The number of allocated qubits at the time of the dump.
"""
qubit_count: int

__inner: dict
__data: StateDumpData

def __init__(self, data: StateDumpData):
self.__data = data
self.__inner = data.get_dict()
self.qubit_count = data.qubit_count

def __getitem__(self, index: int) -> complex:
return self.__inner.__getitem__(index)

def __iter__(self):
return self.__inner.__iter__()

def __len__(self) -> int:
return len(self.__inner)

def __repr__(self) -> str:
return self.__data.__repr__()

def __str__(self) -> str:
return self.__data.__str__()

def _repr_markdown_(self) -> str:
return self.__data._repr_markdown_()

def check_eq(
self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10
) -> bool:
"""
Checks if the state dump is equal to the given state. This is not mathematical equality,
as the check ignores global phase.
:param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes,
or as a list of real amplitudes.
:param tolerance: The tolerance for the check. Defaults to 1e-10.
"""
phase = None
# Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes
if isinstance(state, list):
state = {i: val for i, val in enumerate(state)}
# Filter out zero states from the state dump and the given state based on tolerance
state = {k: v for k, v in state.items() if abs(v) > tolerance}
inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance}
if len(state) != len(inner_state):
return False
for key in state:
if key not in inner_state:
return False
if phase is None:
# Calculate the phase based on the first state pair encountered.
# Every pair of states after this must have the same phase for the states to be equivalent.
phase = inner_state[key] / state[key]
elif abs(phase - inner_state[key] / state[key]) > tolerance:
# This pair of states does not have the same phase,
# within tolerance, so the equivalence check fails.
return False
return True

def as_dense_state(self) -> List[complex]:
"""
Returns the state dump as a dense list of complex amplitudes. This will include zero amplitudes.
"""
return [self.__inner.get(i, complex(0)) for i in range(2**self.qubit_count)]


class ShotResult(TypedDict):
"""
A single result of a shot.
"""

events: List[Output]
result: Any
messages: List[str]
matrices: List[Output]
dumps: List[StateDump]


def eval(
source: str,
*,
save_events: bool = False,
) -> Any:
"""
Evaluates Q# source code.
Output is printed to console.
:param source: The Q# source code to evaluate.
:returns value: The value returned by the last statement in the source code.
:param save_events: If true, all output will be saved and returned. If false, they will be printed.
:returns value: The value returned by the last statement in the source code or the saved output if `save_events` is true.
:raises QSharpError: If there is an error evaluating the source code.
"""
ipython_helper()

results: ShotResult = {
"events": [],
"result": None,
"messages": [],
"matrices": [],
"dumps": [],
}

def on_save_events(output: Output) -> None:
# Append the output to the last shot's output list
if output.is_matrix():
results["events"].append(output)
results["matrices"].append(output)
elif output.is_state_dump():
state_dump = StateDump(output.state_dump())
results["events"].append(state_dump)
results["dumps"].append(state_dump)
elif output.is_message():
stringified = str(output)
results["events"].append(stringified)
results["messages"].append(stringified)

def callback(output: Output) -> None:
if _in_jupyter:
try:
Expand All @@ -244,21 +359,17 @@ def callback(output: Output) -> None:
telemetry_events.on_eval()
start_time = monotonic()

results = get_interpreter().interpret(source, callback)
results["result"] = get_interpreter().interpret(
source, on_save_events if save_events else callback
)

durationMs = (monotonic() - start_time) * 1000
telemetry_events.on_eval_end(durationMs)

return results


class ShotResult(TypedDict):
"""
A single result of a shot.
"""

events: List[Output]
result: Any
if save_events:
return results
else:
return results["result"]


def run(
Expand Down Expand Up @@ -315,9 +426,17 @@ def print_output(output: Output) -> None:
def on_save_events(output: Output) -> None:
# Append the output to the last shot's output list
results[-1]["events"].append(output)
if output.is_matrix():
results[-1]["matrices"].append(output)
elif output.is_state_dump():
results[-1]["dumps"].append(StateDump(output.state_dump()))
elif output.is_message():
results[-1]["messages"].append(str(output))

for shot in range(shots):
results.append({"result": None, "events": []})
results.append(
{"result": None, "events": [], "messages": [], "matrices": [], "dumps": []}
)
run_results = get_interpreter().run(
entry_expr,
on_save_events if save_events else print_output,
Expand Down Expand Up @@ -482,82 +601,6 @@ def set_classical_seed(seed: Optional[int]) -> None:
get_interpreter().set_classical_seed(seed)


class StateDump:
"""
A state dump returned from the Q# interpreter.
"""

"""
The number of allocated qubits at the time of the dump.
"""
qubit_count: int

__inner: dict
__data: StateDumpData

def __init__(self, data: StateDumpData):
self.__data = data
self.__inner = data.get_dict()
self.qubit_count = data.qubit_count

def __getitem__(self, index: int) -> complex:
return self.__inner.__getitem__(index)

def __iter__(self):
return self.__inner.__iter__()

def __len__(self) -> int:
return len(self.__inner)

def __repr__(self) -> str:
return self.__data.__repr__()

def __str__(self) -> str:
return self.__data.__str__()

def _repr_markdown_(self) -> str:
return self.__data._repr_markdown_()

def check_eq(
self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10
) -> bool:
"""
Checks if the state dump is equal to the given state. This is not mathematical equality,
as the check ignores global phase.
:param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes,
or as a list of real amplitudes.
:param tolerance: The tolerance for the check. Defaults to 1e-10.
"""
phase = None
# Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes
if isinstance(state, list):
state = {i: state[i] for i in range(len(state))}
# Filter out zero states from the state dump and the given state based on tolerance
state = {k: v for k, v in state.items() if abs(v) > tolerance}
inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance}
if len(state) != len(inner_state):
return False
for key in state:
if key not in inner_state:
return False
if phase is None:
# Calculate the phase based on the first state pair encountered.
# Every pair of states after this must have the same phase for the states to be equivalent.
phase = inner_state[key] / state[key]
elif abs(phase - inner_state[key] / state[key]) > tolerance:
# This pair of states does not have the same phase,
# within tolerance, so the equivalence check fails.
return False
return True

def as_dense_state(self) -> List[complex]:
"""
Returns the state dump as a dense list of complex amplitudes. This will include zero amplitudes.
"""
return [self.__inner.get(i, complex(0)) for i in range(2**self.qubit_count)]


def dump_machine() -> StateDump:
"""
Returns the sparse state vector of the simulator as a StateDump object.
Expand Down
12 changes: 12 additions & 0 deletions pip/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,18 @@ impl Output {
DisplayableOutput::Matrix(_) | DisplayableOutput::Message(_) => None,
}
}

fn is_state_dump(&self) -> bool {
matches!(&self.0, DisplayableOutput::State(_))
}

fn is_matrix(&self) -> bool {
matches!(&self.0, DisplayableOutput::Matrix(_))
}

fn is_message(&self) -> bool {
matches!(&self.0, DisplayableOutput::Message(_))
}
}

#[pyclass]
Expand Down
33 changes: 32 additions & 1 deletion pip/tests/test_qsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,35 @@ def test_stdout_multiple_lines() -> None:
assert f.getvalue() == "STATE:\n|0⟩: 1.0000+0.0000𝑖\nHello!\n"


def test_captured_stdout() -> None:
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
f = io.StringIO()
with redirect_stdout(f):
result = qsharp.eval(
'{Message("Hello, world!"); Message("Goodbye!")}', save_events=True
)
assert f.getvalue() == ""
assert len(result["messages"]) == 2
assert result["messages"][0] == "Hello, world!"
assert result["messages"][1] == "Goodbye!"


def test_captured_matrix() -> None:
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
f = io.StringIO()
with redirect_stdout(f):
result = qsharp.eval(
"Std.Diagnostics.DumpOperation(1, qs => H(qs[0]))",
save_events=True,
)
assert f.getvalue() == ""
assert len(result["matrices"]) == 1
assert (
str(result["matrices"][0])
== "MATRIX:\n 0.7071+0.0000𝑖 0.7071+0.0000𝑖\n 0.7071+0.0000𝑖 −0.7071+0.0000𝑖"
)


def test_quantum_seed() -> None:
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
qsharp.set_quantum_seed(42)
Expand Down Expand Up @@ -257,6 +286,7 @@ def test_dump_operation() -> None:
else:
assert res[i][j] == complex(0.0, 0.0)


def test_run_with_noise_produces_noisy_results() -> None:
qsharp.init()
qsharp.set_quantum_seed(0)
Expand All @@ -273,6 +303,7 @@ def test_run_with_noise_produces_noisy_results() -> None:
)
assert result[0] > 5


def test_compile_qir_input_data() -> None:
qsharp.init(target_profile=qsharp.TargetProfile.Base)
qsharp.eval("operation Program() : Result { use q = Qubit(); return M(q) }")
Expand Down Expand Up @@ -324,7 +355,7 @@ def on_result(result):
results = qsharp.run("Foo()", 3, on_result=on_result, save_events=True)
assert (
str(results)
== "[{'result': Zero, 'events': [Hello, world!]}, {'result': Zero, 'events': [Hello, world!]}, {'result': Zero, 'events': [Hello, world!]}]"
== "[{'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}, {'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}, {'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}]"
)
stdout = capsys.readouterr().out
assert stdout == ""
Expand Down

0 comments on commit 502ae34

Please sign in to comment.