diff --git a/pip/qsharp/_native.pyi b/pip/qsharp/_native.pyi index 0bf47a2d5b..adef263e44 100644 --- a/pip/qsharp/_native.pyi +++ b/pip/qsharp/_native.pyi @@ -262,6 +262,9 @@ class Output: def __str__(self) -> str: ... def _repr_markdown_(self) -> Optional[str]: ... def state_dump(self) -> Optional[StateDumpData]: ... + def is_state_dump(self) -> bool: ... + def is_matrix(self) -> bool: ... + def is_message(self) -> bool: ... class StateDumpData: """ diff --git a/pip/qsharp/_qsharp.py b/pip/qsharp/_qsharp.py index 84a083c770..5bf6ba4f1a 100644 --- a/pip/qsharp/_qsharp.py +++ b/pip/qsharp/_qsharp.py @@ -219,18 +219,133 @@ def get_interpreter() -> Interpreter: return _interpreter -def eval(source: str) -> Any: +class StateDump: + """ + A state dump returned from the Q# interpreter. + """ + + """ + The number of allocated qubits at the time of the dump. + """ + qubit_count: int + + __inner: dict + __data: StateDumpData + + def __init__(self, data: StateDumpData): + self.__data = data + self.__inner = data.get_dict() + self.qubit_count = data.qubit_count + + def __getitem__(self, index: int) -> complex: + return self.__inner.__getitem__(index) + + def __iter__(self): + return self.__inner.__iter__() + + def __len__(self) -> int: + return len(self.__inner) + + def __repr__(self) -> str: + return self.__data.__repr__() + + def __str__(self) -> str: + return self.__data.__str__() + + def _repr_markdown_(self) -> str: + return self.__data._repr_markdown_() + + def check_eq( + self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10 + ) -> bool: + """ + Checks if the state dump is equal to the given state. This is not mathematical equality, + as the check ignores global phase. + + :param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes, + or as a list of real amplitudes. + :param tolerance: The tolerance for the check. Defaults to 1e-10. + """ + phase = None + # Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes + if isinstance(state, list): + state = {i: val for i, val in enumerate(state)} + # Filter out zero states from the state dump and the given state based on tolerance + state = {k: v for k, v in state.items() if abs(v) > tolerance} + inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance} + if len(state) != len(inner_state): + return False + for key in state: + if key not in inner_state: + return False + if phase is None: + # Calculate the phase based on the first state pair encountered. + # Every pair of states after this must have the same phase for the states to be equivalent. + phase = inner_state[key] / state[key] + elif abs(phase - inner_state[key] / state[key]) > tolerance: + # This pair of states does not have the same phase, + # within tolerance, so the equivalence check fails. + return False + return True + + def as_dense_state(self) -> List[complex]: + """ + Returns the state dump as a dense list of complex amplitudes. This will include zero amplitudes. + """ + return [self.__inner.get(i, complex(0)) for i in range(2**self.qubit_count)] + + +class ShotResult(TypedDict): + """ + A single result of a shot. + """ + + events: List[Output] + result: Any + messages: List[str] + matrices: List[Output] + dumps: List[StateDump] + + +def eval( + source: str, + *, + save_events: bool = False, +) -> Any: """ Evaluates Q# source code. Output is printed to console. :param source: The Q# source code to evaluate. - :returns value: The value returned by the last statement in the source code. + :param save_events: If true, all output will be saved and returned. If false, they will be printed. + :returns value: The value returned by the last statement in the source code or the saved output if `save_events` is true. :raises QSharpError: If there is an error evaluating the source code. """ ipython_helper() + results: ShotResult = { + "events": [], + "result": None, + "messages": [], + "matrices": [], + "dumps": [], + } + + def on_save_events(output: Output) -> None: + # Append the output to the last shot's output list + if output.is_matrix(): + results["events"].append(output) + results["matrices"].append(output) + elif output.is_state_dump(): + state_dump = StateDump(output.state_dump()) + results["events"].append(state_dump) + results["dumps"].append(state_dump) + elif output.is_message(): + stringified = str(output) + results["events"].append(stringified) + results["messages"].append(stringified) + def callback(output: Output) -> None: if _in_jupyter: try: @@ -244,21 +359,17 @@ def callback(output: Output) -> None: telemetry_events.on_eval() start_time = monotonic() - results = get_interpreter().interpret(source, callback) + results["result"] = get_interpreter().interpret( + source, on_save_events if save_events else callback + ) durationMs = (monotonic() - start_time) * 1000 telemetry_events.on_eval_end(durationMs) - return results - - -class ShotResult(TypedDict): - """ - A single result of a shot. - """ - - events: List[Output] - result: Any + if save_events: + return results + else: + return results["result"] def run( @@ -315,9 +426,17 @@ def print_output(output: Output) -> None: def on_save_events(output: Output) -> None: # Append the output to the last shot's output list results[-1]["events"].append(output) + if output.is_matrix(): + results[-1]["matrices"].append(output) + elif output.is_state_dump(): + results[-1]["dumps"].append(StateDump(output.state_dump())) + elif output.is_message(): + results[-1]["messages"].append(str(output)) for shot in range(shots): - results.append({"result": None, "events": []}) + results.append( + {"result": None, "events": [], "messages": [], "matrices": [], "dumps": []} + ) run_results = get_interpreter().run( entry_expr, on_save_events if save_events else print_output, @@ -482,82 +601,6 @@ def set_classical_seed(seed: Optional[int]) -> None: get_interpreter().set_classical_seed(seed) -class StateDump: - """ - A state dump returned from the Q# interpreter. - """ - - """ - The number of allocated qubits at the time of the dump. - """ - qubit_count: int - - __inner: dict - __data: StateDumpData - - def __init__(self, data: StateDumpData): - self.__data = data - self.__inner = data.get_dict() - self.qubit_count = data.qubit_count - - def __getitem__(self, index: int) -> complex: - return self.__inner.__getitem__(index) - - def __iter__(self): - return self.__inner.__iter__() - - def __len__(self) -> int: - return len(self.__inner) - - def __repr__(self) -> str: - return self.__data.__repr__() - - def __str__(self) -> str: - return self.__data.__str__() - - def _repr_markdown_(self) -> str: - return self.__data._repr_markdown_() - - def check_eq( - self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10 - ) -> bool: - """ - Checks if the state dump is equal to the given state. This is not mathematical equality, - as the check ignores global phase. - - :param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes, - or as a list of real amplitudes. - :param tolerance: The tolerance for the check. Defaults to 1e-10. - """ - phase = None - # Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes - if isinstance(state, list): - state = {i: state[i] for i in range(len(state))} - # Filter out zero states from the state dump and the given state based on tolerance - state = {k: v for k, v in state.items() if abs(v) > tolerance} - inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance} - if len(state) != len(inner_state): - return False - for key in state: - if key not in inner_state: - return False - if phase is None: - # Calculate the phase based on the first state pair encountered. - # Every pair of states after this must have the same phase for the states to be equivalent. - phase = inner_state[key] / state[key] - elif abs(phase - inner_state[key] / state[key]) > tolerance: - # This pair of states does not have the same phase, - # within tolerance, so the equivalence check fails. - return False - return True - - def as_dense_state(self) -> List[complex]: - """ - Returns the state dump as a dense list of complex amplitudes. This will include zero amplitudes. - """ - return [self.__inner.get(i, complex(0)) for i in range(2**self.qubit_count)] - - def dump_machine() -> StateDump: """ Returns the sparse state vector of the simulator as a StateDump object. diff --git a/pip/src/interpreter.rs b/pip/src/interpreter.rs index 78b4b00a10..f540360f77 100644 --- a/pip/src/interpreter.rs +++ b/pip/src/interpreter.rs @@ -587,6 +587,18 @@ impl Output { DisplayableOutput::Matrix(_) | DisplayableOutput::Message(_) => None, } } + + fn is_state_dump(&self) -> bool { + matches!(&self.0, DisplayableOutput::State(_)) + } + + fn is_matrix(&self) -> bool { + matches!(&self.0, DisplayableOutput::Matrix(_)) + } + + fn is_message(&self) -> bool { + matches!(&self.0, DisplayableOutput::Message(_)) + } } #[pyclass] diff --git a/pip/tests/test_qsharp.py b/pip/tests/test_qsharp.py index 0779972fa5..a1f9892587 100644 --- a/pip/tests/test_qsharp.py +++ b/pip/tests/test_qsharp.py @@ -35,6 +35,35 @@ def test_stdout_multiple_lines() -> None: assert f.getvalue() == "STATE:\n|0⟩: 1.0000+0.0000𝑖\nHello!\n" +def test_captured_stdout() -> None: + qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted) + f = io.StringIO() + with redirect_stdout(f): + result = qsharp.eval( + '{Message("Hello, world!"); Message("Goodbye!")}', save_events=True + ) + assert f.getvalue() == "" + assert len(result["messages"]) == 2 + assert result["messages"][0] == "Hello, world!" + assert result["messages"][1] == "Goodbye!" + + +def test_captured_matrix() -> None: + qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted) + f = io.StringIO() + with redirect_stdout(f): + result = qsharp.eval( + "Std.Diagnostics.DumpOperation(1, qs => H(qs[0]))", + save_events=True, + ) + assert f.getvalue() == "" + assert len(result["matrices"]) == 1 + assert ( + str(result["matrices"][0]) + == "MATRIX:\n 0.7071+0.0000𝑖 0.7071+0.0000𝑖\n 0.7071+0.0000𝑖 −0.7071+0.0000𝑖" + ) + + def test_quantum_seed() -> None: qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted) qsharp.set_quantum_seed(42) @@ -257,6 +286,7 @@ def test_dump_operation() -> None: else: assert res[i][j] == complex(0.0, 0.0) + def test_run_with_noise_produces_noisy_results() -> None: qsharp.init() qsharp.set_quantum_seed(0) @@ -273,6 +303,7 @@ def test_run_with_noise_produces_noisy_results() -> None: ) assert result[0] > 5 + def test_compile_qir_input_data() -> None: qsharp.init(target_profile=qsharp.TargetProfile.Base) qsharp.eval("operation Program() : Result { use q = Qubit(); return M(q) }") @@ -324,7 +355,7 @@ def on_result(result): results = qsharp.run("Foo()", 3, on_result=on_result, save_events=True) assert ( str(results) - == "[{'result': Zero, 'events': [Hello, world!]}, {'result': Zero, 'events': [Hello, world!]}, {'result': Zero, 'events': [Hello, world!]}]" + == "[{'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}, {'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}, {'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}]" ) stdout = capsys.readouterr().out assert stdout == ""