From a38e6647a633d3d3e848d6924d676274e2fc8ca4 Mon Sep 17 00:00:00 2001 From: shun suzuki Date: Thu, 3 Oct 2024 16:40:30 +0900 Subject: [PATCH] add missing impl --- .coveragerc | 2 - pyautd3/__init__.py | 6 + pyautd3/autd_error.py | 5 - pyautd3/controller/controller.py | 18 +- pyautd3/driver/datagram/__init__.py | 6 +- pyautd3/driver/datagram/modulation/fir.py | 42 ++++ .../driver/datagram/modulation/modulation.py | 2 + pyautd3/driver/datagram/phase_corr.py | 39 ++++ pyautd3/driver/datagram/silencer.py | 172 +++++++-------- pyautd3/driver/utils.py | 8 + pyautd3/modulation/__init__.py | 5 + pyautd3/modulation/audio_file/csv.py | 54 ++++- pyautd3/modulation/audio_file/raw_pcm.py | 51 ++++- pyautd3/modulation/audio_file/wav.py | 43 +++- pyautd3/modulation/custom.py | 43 +++- pyautd3/modulation/fourier.py | 2 + pyautd3/modulation/mixer.py | 2 + pyautd3/modulation/resample.py | 47 +++++ pyautd3/modulation/static.py | 2 + pyautd3/native_methods/autd3capi_driver.py | 2 +- tests/driver/datagram/test_debug.py | 19 ++ tests/driver/datagram/test_modulation.py | 198 ++++++++++++++++++ tests/driver/datagram/test_phase_corr.py | 21 ++ tests/driver/datagram/test_silencer.py | 43 ++-- tests/driver/firmware/fpga/test_fpga_state.py | 2 +- tests/driver/test_utils.py | 16 +- tests/link/test_soem.py | 6 +- tests/modulation/audio_file/.gitignore | 1 + tests/modulation/audio_file/test_csv.py | 22 +- tests/modulation/audio_file/test_rawpcm.py | 21 +- tests/modulation/audio_file/test_wav.py | 25 ++- tests/modulation/test_custom.py | 26 ++- tests/test_autd.py | 14 +- tools/wrapper-generator/src/python.rs | 2 +- 34 files changed, 801 insertions(+), 166 deletions(-) create mode 100644 pyautd3/driver/datagram/modulation/fir.py create mode 100644 pyautd3/driver/datagram/phase_corr.py create mode 100644 pyautd3/modulation/resample.py create mode 100644 tests/driver/datagram/test_phase_corr.py create mode 100644 tests/modulation/audio_file/.gitignore diff --git a/.coveragerc b/.coveragerc index ec7aee0..0d3c303 100644 --- a/.coveragerc +++ b/.coveragerc @@ -8,8 +8,6 @@ exclude_lines = pragma: no cover - def update_geometry - async def update_geometry_async [run] omit = diff --git a/pyautd3/__init__.py b/pyautd3/__init__.py index 95bac9e..761ea95 100644 --- a/pyautd3/__init__.py +++ b/pyautd3/__init__.py @@ -6,10 +6,13 @@ Clear, DebugSettings, DebugType, + FixedCompletionTime, + FixedUpdateRate, FociSTM, ForceFan, GainSTM, GainSTMMode, + PhaseCorrection, PulseWidthEncoder, ReadsFPGAState, Silencer, @@ -53,6 +56,7 @@ def tracing_init() -> None: "tracing_init", "Controller", "AUTD3", + "PhaseCorrection", "Drive", "EmitIntensity", "Phase", @@ -60,6 +64,8 @@ def tracing_init() -> None: "SamplingConfig", "Clear", "Silencer", + "FixedCompletionTime", + "FixedUpdateRate", "DebugSettings", "DebugType", "ReadsFPGAState", diff --git a/pyautd3/autd_error.py b/pyautd3/autd_error.py index c1fa2dd..94cdf33 100644 --- a/pyautd3/autd_error.py +++ b/pyautd3/autd_error.py @@ -29,11 +29,6 @@ def __init__(self: "InvalidDatagramTypeError") -> None: super().__init__("Invalid datagram type") -class InvalidPlotConfigError(AUTDError): - def __init__(self: "InvalidPlotConfigError") -> None: - super().__init__("Invalid plot config type") - - class CantBeZeroError(AUTDError): def __init__(self: "CantBeZeroError", v: int) -> None: super().__init__(f"Value must be greater than 0: {v}") diff --git a/pyautd3/controller/controller.py b/pyautd3/controller/controller.py index fae8248..b8b39c6 100644 --- a/pyautd3/controller/controller.py +++ b/pyautd3/controller/controller.py @@ -14,6 +14,7 @@ from pyautd3.driver.firmware_version import FirmwareInfo from pyautd3.driver.geometry import Device, Geometry from pyautd3.driver.link import Link, LinkBuilder +from pyautd3.driver.utils import _validate_nonzero_u32 from pyautd3.native_methods.autd3capi import ControllerBuilderPtr, ControllerPtr, RuntimePtr from pyautd3.native_methods.autd3capi import NativeMethods as Base from pyautd3.native_methods.autd3capi_driver import DatagramPtr, GeometryPtr, HandlePtr, ResultI32 @@ -45,7 +46,12 @@ def with_send_interval(self: "_Builder", interval: timedelta) -> "_Builder": self._ptr = Base().controller_builder_with_send_interval(self._ptr, int(interval.total_seconds() * 1000 * 1000 * 1000)) return self - def with_timer_resolution(self: "_Builder", resolution: int) -> "_Builder": + def with_receive_interval(self: "_Builder", interval: timedelta) -> "_Builder": + self._ptr = Base().controller_builder_with_receive_interval(self._ptr, int(interval.total_seconds() * 1000 * 1000 * 1000)) + return self + + def with_timer_resolution(self: "_Builder", resolution: int | None) -> "_Builder": + resolution = 0 if resolution is None else _validate_nonzero_u32(resolution) self._ptr = Base().controller_builder_with_timer_resolution(self._ptr, resolution) return self @@ -265,7 +271,7 @@ def get_firmware_info(i: int) -> FirmwareInfo: async def close_async(self: "Controller") -> None: r: ResultI32 | None = None - if self._handle._0 is not None and self._ptr._0 is not None: + if self._handle._0: future: asyncio.Future = asyncio.Future() loop = asyncio.get_event_loop() ffi_future = Base().controller_close(self._ptr) @@ -275,9 +281,8 @@ async def close_async(self: "Controller") -> None: ), ) r = await future - self._ptr._0 = None - if self._handle._0 is not None: Base().delete_runtime(self._runtime) + self._ptr._0 = None self._runtime._0 = None self._handle._0 = None if r is not None: @@ -285,11 +290,10 @@ async def close_async(self: "Controller") -> None: def close(self: "Controller") -> None: r: ResultI32 | None = None - if self._handle._0 is not None and self._ptr._0 is not None: - r = Base().wait_result_i_32(self._handle, Base().controller_close(self._ptr)) - self._ptr._0 = None if self._handle._0 is not None: + r = Base().wait_result_i_32(self._handle, Base().controller_close(self._ptr)) Base().delete_runtime(self._runtime) + self._ptr._0 = None self._runtime._0 = None self._handle._0 = None if r is not None: diff --git a/pyautd3/driver/datagram/__init__.py b/pyautd3/driver/datagram/__init__.py index a004c98..b2a1873 100644 --- a/pyautd3/driver/datagram/__init__.py +++ b/pyautd3/driver/datagram/__init__.py @@ -10,16 +10,19 @@ from .datagram import Datagram from .debug import DebugSettings, DebugType from .force_fan import ForceFan +from .phase_corr import PhaseCorrection from .pulse_width_encoder import PulseWidthEncoder from .reads_fpga_state import ReadsFPGAState from .segment import SwapSegment -from .silencer import Silencer +from .silencer import FixedCompletionTime, FixedUpdateRate, Silencer from .stm import FociSTM, GainSTM, GainSTMMode from .synchronize import Synchronize __all__ = [ "Clear", "Silencer", + "FixedCompletionTime", + "FixedUpdateRate", "DebugSettings", "DebugType", "ReadsFPGAState", @@ -31,6 +34,7 @@ "GainSTMMode", "SwapSegment", "PulseWidthEncoder", + "PhaseCorrection", ] diff --git a/pyautd3/driver/datagram/modulation/fir.py b/pyautd3/driver/datagram/modulation/fir.py new file mode 100644 index 0000000..cc0b2fd --- /dev/null +++ b/pyautd3/driver/datagram/modulation/fir.py @@ -0,0 +1,42 @@ +import ctypes +from collections.abc import Iterable +from typing import Generic, TypeVar + +import numpy as np + +from pyautd3.driver.datagram.modulation.modulation import ModulationBase +from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure +from pyautd3.native_methods.autd3capi import NativeMethods as Base +from pyautd3.native_methods.autd3capi_driver import ModulationPtr + +from .cache import IntoModulationCache + +M = TypeVar("M", bound=ModulationBase) + + +class Fir( + IntoModulationCache["Fir[M]"], + IntoModulationRadiationPressure["Fir[M]"], + ModulationBase["Fir[M]"], + Generic[M], +): + _m: M + _coef: np.ndarray + + def __init__(self: "Fir[M]", m: M, iterable: Iterable[float]) -> None: + self._m = m + self._loop_behavior = m._loop_behavior + self._coef = np.fromiter(iterable, dtype=ctypes.c_float) + + def _modulation_ptr(self: "Fir[M]") -> ModulationPtr: + return Base().modulation_with_fir( + self._m._modulation_ptr(), + self._loop_behavior, + self._coef.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), # type: ignore[arg-type] + len(self._coef), + ) + + +class IntoModulationFir(ModulationBase[M], Generic[M]): + def with_fir(self: M, iterable: Iterable[float]) -> "Fir[M]": + return Fir(self, iterable) diff --git a/pyautd3/driver/datagram/modulation/modulation.py b/pyautd3/driver/datagram/modulation/modulation.py index c45be6f..41cf74e 100644 --- a/pyautd3/driver/datagram/modulation/modulation.py +++ b/pyautd3/driver/datagram/modulation/modulation.py @@ -4,6 +4,7 @@ from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure from pyautd3.driver.defined.freq import Freq from pyautd3.driver.firmware.fpga.sampling_config import SamplingConfig @@ -16,6 +17,7 @@ class Modulation( IntoModulationCache[M], IntoModulationRadiationPressure[M], + IntoModulationFir[M], ModulationBase[M], Generic[M], metaclass=ABCMeta, diff --git a/pyautd3/driver/datagram/phase_corr.py b/pyautd3/driver/datagram/phase_corr.py new file mode 100644 index 0000000..1fccd9b --- /dev/null +++ b/pyautd3/driver/datagram/phase_corr.py @@ -0,0 +1,39 @@ +import ctypes +import threading +from collections.abc import Callable + +from pyautd3.driver.datagram.with_parallel_threshold import IntoDatagramWithParallelThreshold +from pyautd3.driver.datagram.with_timeout import IntoDatagramWithTimeout +from pyautd3.driver.firmware.fpga.phase import Phase +from pyautd3.driver.geometry import Geometry +from pyautd3.driver.geometry.device import Device +from pyautd3.driver.geometry.transducer import Transducer +from pyautd3.native_methods.autd3capi import NativeMethods as Base +from pyautd3.native_methods.autd3capi_driver import DatagramPtr, GeometryPtr + +from .datagram import Datagram + + +class PhaseCorrection( + IntoDatagramWithTimeout["PhaseCorrection"], + IntoDatagramWithParallelThreshold["PhaseCorrection"], + Datagram, +): + _cache: dict[int, Callable[[Transducer], Phase]] + _lock: threading.Lock + + def __init__(self: "PhaseCorrection", f: Callable[[Device], Callable[[Transducer], Phase]]) -> None: + super().__init__() + self._cache = {} + self._lock = threading.Lock() + + def f_native(_context: ctypes.c_void_p, geometry_ptr: GeometryPtr, dev_idx: int, tr_idx: int) -> int: + if dev_idx not in self._cache: + with self._lock: + self._cache[dev_idx] = f(Device(dev_idx, geometry_ptr)) + return self._cache[dev_idx](Transducer(tr_idx, Base().device(geometry_ptr, dev_idx))).value + + self._f_native = ctypes.CFUNCTYPE(ctypes.c_uint8, ctypes.c_void_p, GeometryPtr, ctypes.c_uint16, ctypes.c_uint8)(f_native) + + def _datagram_ptr(self: "PhaseCorrection", geometry: Geometry) -> DatagramPtr: + return Base().datagram_phase_corr(self._f_native, None, geometry._ptr) # type: ignore[arg-type] diff --git a/pyautd3/driver/datagram/silencer.py b/pyautd3/driver/datagram/silencer.py index f13cbe4..7d9435d 100644 --- a/pyautd3/driver/datagram/silencer.py +++ b/pyautd3/driver/datagram/silencer.py @@ -1,4 +1,5 @@ from datetime import timedelta +from typing import Generic, TypeVar from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.stm.foci import FociSTM @@ -6,102 +7,109 @@ from pyautd3.driver.datagram.with_parallel_threshold import IntoDatagramWithParallelThreshold from pyautd3.driver.datagram.with_timeout import IntoDatagramWithTimeout from pyautd3.driver.geometry import Geometry +from pyautd3.driver.utils import _validate_nonzero_u16 from pyautd3.native_methods.autd3capi import NativeMethods as Base from pyautd3.native_methods.autd3capi_driver import DatagramPtr, SilencerTarget from .datagram import Datagram +class FixedCompletionTime: + intensity: timedelta + phase: timedelta + + def __init__(self: "FixedCompletionTime", *, intensity: timedelta, phase: timedelta) -> None: + self.intensity = intensity + self.phase = phase + + def _is_valid( + self: "FixedCompletionTime", + v: ModulationBase | FociSTM | GainSTM, + strict_mode: bool, # noqa: FBT001 + target: SilencerTarget, + ) -> bool: + return bool( + Base().datagram_silencer_fixed_completion_time_is_valid( + self._datagram_ptr(strict_mode, target), + v._sampling_config_intensity()._inner, + v._sampling_config_phase()._inner, + ), + ) + + def _datagram_ptr(self: "FixedCompletionTime", strict_mode: bool, target: SilencerTarget) -> DatagramPtr: # noqa: FBT001 + return Base().datagram_silencer_from_completion_time( + int(self.intensity.total_seconds() * 1000 * 1000 * 1000), + int(self.phase.total_seconds() * 1000 * 1000 * 1000), + strict_mode, + target, + ) + + +class FixedUpdateRate: + intensity: int + phase: int + + def __init__(self: "FixedUpdateRate", *, intensity: int, phase: int) -> None: + self.intensity = _validate_nonzero_u16(intensity) + self.phase = _validate_nonzero_u16(phase) + + def _is_valid(self: "FixedUpdateRate", v: ModulationBase | FociSTM | GainSTM, strict_mode: bool, target: SilencerTarget) -> bool: # noqa: FBT001 + return bool( + Base().datagram_silencer_fixed_update_rate_is_valid( + self._datagram_ptr(strict_mode, target), + v._sampling_config_intensity()._inner, + v._sampling_config_phase()._inner, + ), + ) + + def _datagram_ptr(self: "FixedUpdateRate", _strict_mode: bool, target: SilencerTarget) -> DatagramPtr: # noqa: FBT001 + return Base().datagram_silencer_from_update_rate( + self.intensity, + self.phase, + target, + ) + + +T = TypeVar("T", FixedCompletionTime, FixedUpdateRate) + + class Silencer( IntoDatagramWithTimeout["Silencer"], IntoDatagramWithParallelThreshold["Silencer"], Datagram, + Generic[T], ): - class FixedUpdateRate( - IntoDatagramWithTimeout["Silencer.FixedUpdateRate"], - IntoDatagramWithParallelThreshold["Silencer.FixedUpdateRate"], - Datagram, - ): - _value_intensity: int - _value_phase: int - _target: SilencerTarget - - def __init__(self: "Silencer.FixedUpdateRate", value_intensity: int, value_phase: int) -> None: - super().__init__() - self._value_intensity = value_intensity - self._value_phase = value_phase - self._target = SilencerTarget.Intensity - - def with_target(self: "Silencer.FixedUpdateRate", target: SilencerTarget) -> "Silencer.FixedUpdateRate": - self._target = target - return self - - def _datagram_ptr(self: "Silencer.FixedUpdateRate", _: Geometry) -> DatagramPtr: - return Base().datagram_silencer_from_update_rate( - self._value_intensity, - self._value_phase, - self._target, - ) - - def is_valid(self: "Silencer.FixedUpdateRate", target: ModulationBase | FociSTM | GainSTM) -> bool: - return bool( - Base().datagram_silencer_fixed_update_rate_is_valid( - self._datagram_ptr(None), # type: ignore[arg-type] - target._sampling_config_intensity()._inner, - target._sampling_config_phase()._inner, - ), - ) - - class FixedCompletionTime(Datagram): - _value_intensity: timedelta - _value_phase: timedelta - _strict_mode: bool - _target: SilencerTarget - - def __init__(self: "Silencer.FixedCompletionTime", value_intensity: timedelta, value_phase: timedelta) -> None: - super().__init__() - self._value_intensity = value_intensity - self._value_phase = value_phase - self._strict_mode = True - self._target = SilencerTarget.Intensity - - def with_target(self: "Silencer.FixedCompletionTime", target: SilencerTarget) -> "Silencer.FixedCompletionTime": - self._target = target - return self - - def with_strict_mode(self: "Silencer.FixedCompletionTime", mode: bool) -> "Silencer.FixedCompletionTime": # noqa: FBT001 - self._strict_mode = mode - return self - - def _datagram_ptr(self: "Silencer.FixedCompletionTime", _: Geometry) -> DatagramPtr: - return Base().datagram_silencer_from_completion_time( - int(self._value_intensity.total_seconds() * 1000 * 1000 * 1000), - int(self._value_phase.total_seconds() * 1000 * 1000 * 1000), - self._strict_mode, - self._target, - ) - - def is_valid(self: "Silencer.FixedCompletionTime", target: ModulationBase | FociSTM | GainSTM) -> bool: - return bool( - Base().datagram_silencer_fixed_completion_time_is_valid( - self._datagram_ptr(None), # type: ignore[arg-type] - target._sampling_config_intensity()._inner, - target._sampling_config_phase()._inner, - ), - ) + _inner: T + _strict_mode: bool + _target: SilencerTarget - @staticmethod - def from_update_rate(value_intensity: int, value_phase: int) -> "FixedUpdateRate": - return Silencer.FixedUpdateRate(value_intensity, value_phase) + def __init__(self: "Silencer[T]", config: T) -> None: + super().__init__() + self._inner = config + self._strict_mode = True + self._target = SilencerTarget.Intensity - @staticmethod - def from_completion_time(value_intensity: timedelta, value_phase: timedelta) -> "FixedCompletionTime": - return Silencer.FixedCompletionTime(value_intensity, value_phase) + def with_target(self: "Silencer[T]", target: SilencerTarget) -> "Silencer[T]": + self._target = target + return self + + def with_strict_mode(self: "Silencer[FixedCompletionTime]", mode: bool) -> "Silencer[FixedCompletionTime]": # noqa: FBT001 + if not isinstance(self._inner, FixedCompletionTime): # pragma: no cover + msg = "Strict mode is only available for Silencer[FixedCompletionTime]" # pragma: no cover + raise TypeError(msg) # pragma: no cover + self._strict_mode = mode + return self + + def is_valid(self: "Silencer[T]", target: ModulationBase | FociSTM | GainSTM) -> bool: + return self._inner._is_valid(target, self._strict_mode, self._target) + + def _datagram_ptr(self: "Silencer[T]", _: Geometry) -> DatagramPtr: + return self._inner._datagram_ptr(self._strict_mode, self._target) @staticmethod - def disable() -> "FixedCompletionTime": - return Silencer.from_completion_time(timedelta(microseconds=25), timedelta(microseconds=25)) + def disable() -> "Silencer[FixedCompletionTime]": + return Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25), phase=timedelta(microseconds=25))) @staticmethod - def default() -> "FixedCompletionTime": - return Silencer.from_completion_time(timedelta(microseconds=250), timedelta(microseconds=1000)) + def default() -> "Silencer[FixedCompletionTime]": + return Silencer(FixedCompletionTime(intensity=timedelta(microseconds=250), phase=timedelta(microseconds=1000))) diff --git a/pyautd3/driver/utils.py b/pyautd3/driver/utils.py index a488d7e..57b8e69 100644 --- a/pyautd3/driver/utils.py +++ b/pyautd3/driver/utils.py @@ -12,3 +12,11 @@ def _validate_nonzero_u16(value: int) -> int: if value <= 0 or value > 0xFFFF: # noqa: PLR2004 raise ValueError return value + + +def _validate_nonzero_u32(value: int) -> int: + if not isinstance(value, int): + raise TypeError + if value <= 0 or value > 0xFFFFFFFF: # noqa: PLR2004 + raise ValueError + return value diff --git a/pyautd3/modulation/__init__.py b/pyautd3/modulation/__init__.py index 87fc26f..4a183d9 100644 --- a/pyautd3/modulation/__init__.py +++ b/pyautd3/modulation/__init__.py @@ -2,6 +2,7 @@ from .fourier import Fourier from .mixer import Mixer from .modulation import Modulation +from .resample import BlackMan, Rectangular, Resampler, SincInterpolation from .sine import Sine from .square import Square from .static import Static @@ -14,4 +15,8 @@ "Mixer", "Square", "Custom", + "Resampler", + "BlackMan", + "Rectangular", + "SincInterpolation", ] diff --git a/pyautd3/modulation/audio_file/csv.py b/pyautd3/modulation/audio_file/csv.py index 0643b93..fe6d43b 100644 --- a/pyautd3/modulation/audio_file/csv.py +++ b/pyautd3/modulation/audio_file/csv.py @@ -3,9 +3,11 @@ from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure from pyautd3.driver.defined.freq import Freq from pyautd3.driver.firmware.fpga.sampling_config import SamplingConfig +from pyautd3.modulation.resample import Resampler from pyautd3.native_methods.autd3capi_driver import ModulationPtr from pyautd3.native_methods.autd3capi_modulation_audio_file import ( NativeMethods as ModulationAudioFile, @@ -15,30 +17,60 @@ class Csv( IntoModulationCache["Csv"], + IntoModulationFir["Csv"], IntoModulationRadiationPressure["Csv"], ModulationBase["Csv"], ): _path: Path - _config: SamplingConfig + _config: SamplingConfig | tuple[Freq[float], SamplingConfig, Resampler] _deliminator: str - def __init__(self: "Csv", path: Path, config: SamplingConfig | Freq[int] | Freq[float] | timedelta) -> None: + def __private_init__(self: "Csv", path: Path, config: SamplingConfig | tuple[Freq[float], SamplingConfig, Resampler]) -> None: super().__init__() self._path = path - self._config = SamplingConfig(config) + self._config = config self._deliminator = "," + def __init__(self: "Csv", path: Path, config: SamplingConfig | Freq[int] | Freq[float] | timedelta) -> None: + self.__private_init__(path, SamplingConfig(config)) + + @staticmethod + def new_with_resampler( + path: Path, + source: Freq[float], + target: SamplingConfig | Freq[int] | Freq[float] | timedelta, + resampler: Resampler, + ) -> "Csv": + instance = super(Csv, Csv).__new__(Csv) + instance.__private_init__(path, (source, SamplingConfig(target), resampler)) + return instance + def with_deliminator(self: "Csv", deliminator: str) -> "Csv": self._deliminator = deliminator return self def _modulation_ptr(self: "Csv") -> ModulationPtr: delim = self._deliminator.encode("utf-8") - return _validate_ptr( - ModulationAudioFile().modulation_audio_file_csv( - str(self._path).encode("utf-8"), - self._config._inner, - delim[0], - self._loop_behavior, - ), - ) + path = str(self._path).encode("utf-8") + match self._config: + case (Freq(), SamplingConfig(), Resampler()): + (source, target, resampler) = self._config # type: ignore[misc] + return _validate_ptr( + ModulationAudioFile().modulation_audio_file_csv_with_resample( + path, + delim[0], + self._loop_behavior, + source.hz, + target._inner, + resampler._dyn_resampler(), + ), + ) + case _: + return _validate_ptr( + ModulationAudioFile().modulation_audio_file_csv( + str(self._path).encode("utf-8"), + self._config._inner, # type: ignore[union-attr] + delim[0], + self._loop_behavior, + ), + ) diff --git a/pyautd3/modulation/audio_file/raw_pcm.py b/pyautd3/modulation/audio_file/raw_pcm.py index 4f9ace4..3b754bf 100644 --- a/pyautd3/modulation/audio_file/raw_pcm.py +++ b/pyautd3/modulation/audio_file/raw_pcm.py @@ -3,9 +3,11 @@ from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure from pyautd3.driver.defined.freq import Freq from pyautd3.driver.firmware.fpga.sampling_config import SamplingConfig +from pyautd3.modulation.resample import Resampler from pyautd3.native_methods.autd3capi_driver import ModulationPtr from pyautd3.native_methods.autd3capi_modulation_audio_file import ( NativeMethods as ModulationAudioFile, @@ -15,23 +17,52 @@ class RawPCM( IntoModulationCache["RawPCM"], + IntoModulationFir["RawPCM"], IntoModulationRadiationPressure["RawPCM"], ModulationBase["RawPCM"], ): _path: Path - _config: SamplingConfig + _config: SamplingConfig | tuple[Freq[float], SamplingConfig, Resampler] _sample_rate: Freq[int] - def __init__(self: "RawPCM", path: Path, config: SamplingConfig | Freq[int] | Freq[float] | timedelta) -> None: + def __private_init__(self: "RawPCM", path: Path, config: SamplingConfig | tuple[Freq[float], SamplingConfig, Resampler]) -> None: super().__init__() self._path = path - self._config = SamplingConfig(config) + self._config = config + + def __init__(self: "RawPCM", path: Path, config: SamplingConfig | Freq[int] | Freq[float] | timedelta) -> None: + self.__private_init__(path, SamplingConfig(config)) + + @staticmethod + def new_with_resampler( + path: Path, + source: Freq[float], + target: SamplingConfig | Freq[int] | Freq[float] | timedelta, + resampler: Resampler, + ) -> "RawPCM": + instance = super(RawPCM, RawPCM).__new__(RawPCM) + instance.__private_init__(path, (source, SamplingConfig(target), resampler)) + return instance def _modulation_ptr(self: "RawPCM") -> ModulationPtr: - return _validate_ptr( - ModulationAudioFile().modulation_audio_file_raw_pcm( - str(self._path).encode("utf-8"), - self._config._inner, - self._loop_behavior, - ), - ) + path = str(self._path).encode("utf-8") + match self._config: + case (Freq(), SamplingConfig(), Resampler()): + (source, target, resampler) = self._config # type: ignore[misc] + return _validate_ptr( + ModulationAudioFile().modulation_audio_file_raw_pcm_with_resample( + path, + self._loop_behavior, + source.hz, + target._inner, + resampler._dyn_resampler(), + ), + ) + case _: + return _validate_ptr( + ModulationAudioFile().modulation_audio_file_raw_pcm( + str(self._path).encode("utf-8"), + self._config._inner, # type: ignore[union-attr] + self._loop_behavior, + ), + ) diff --git a/pyautd3/modulation/audio_file/wav.py b/pyautd3/modulation/audio_file/wav.py index 5311326..656b72d 100644 --- a/pyautd3/modulation/audio_file/wav.py +++ b/pyautd3/modulation/audio_file/wav.py @@ -1,8 +1,13 @@ +from datetime import timedelta from pathlib import Path +from pyautd3 import SamplingConfig from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure +from pyautd3.driver.defined.freq import Freq +from pyautd3.modulation.resample import Resampler from pyautd3.native_methods.autd3capi_driver import ModulationPtr from pyautd3.native_methods.autd3capi_modulation_audio_file import ( NativeMethods as ModulationAudioFile, @@ -12,19 +17,45 @@ class Wav( IntoModulationCache["Wav"], + IntoModulationFir["Wav"], IntoModulationRadiationPressure["Wav"], ModulationBase["Wav"], ): _path: Path + _resampler: tuple[SamplingConfig, Resampler] | None def __init__(self: "Wav", path: Path) -> None: super().__init__() self._path = path + self._resampler = None + + @staticmethod + def new_with_resampler( + path: Path, + target: SamplingConfig | Freq[int] | Freq[float] | timedelta, + resampler: Resampler, + ) -> "Wav": + instance = Wav(path) + instance._resampler = (SamplingConfig(target), resampler) + return instance def _modulation_ptr(self: "Wav") -> ModulationPtr: - return _validate_ptr( - ModulationAudioFile().modulation_audio_file_wav( - str(self._path).encode("utf-8"), - self._loop_behavior, - ), - ) + path = str(self._path).encode("utf-8") + match self._resampler: + case (SamplingConfig(), Resampler()): + (target, resampler) = self._resampler # type: ignore[misc] + return _validate_ptr( + ModulationAudioFile().modulation_audio_file_wav_with_resample( + path, + self._loop_behavior, + target._inner, + resampler._dyn_resampler(), + ), + ) + case _: + return _validate_ptr( + ModulationAudioFile().modulation_audio_file_wav( + path, + self._loop_behavior, + ), + ) diff --git a/pyautd3/modulation/custom.py b/pyautd3/modulation/custom.py index e819a1f..0bf1f30 100644 --- a/pyautd3/modulation/custom.py +++ b/pyautd3/modulation/custom.py @@ -1,4 +1,5 @@ import ctypes +from collections.abc import Iterable from datetime import timedelta import numpy as np @@ -6,21 +7,47 @@ from pyautd3.driver.datagram.modulation import Modulation from pyautd3.driver.defined.freq import Freq from pyautd3.driver.firmware.fpga.sampling_config import SamplingConfig +from pyautd3.modulation.resample import Resampler from pyautd3.native_methods.autd3capi import NativeMethods as Base from pyautd3.native_methods.autd3capi_driver import ModulationPtr class Custom(Modulation["Custom"]): _buf: np.ndarray + _resampler: tuple[Freq[float], SamplingConfig, Resampler] | None - def __init__(self: "Custom", buf: np.ndarray, config: SamplingConfig | Freq[int] | timedelta) -> None: + def __init__(self: "Custom", buf: Iterable[int], config: SamplingConfig | Freq[int] | Freq[float] | timedelta) -> None: super().__init__(config) - self._buf = buf + self._buf = np.fromiter(buf, dtype=np.uint8) + self._resampler = None + + @staticmethod + def new_with_resample( + buf: Iterable[int], + source: Freq[float], + target: SamplingConfig | Freq[int] | Freq[float] | timedelta, + resampler: Resampler, + ) -> "Custom": + instance = Custom(buf, SamplingConfig(target)) + instance._resampler = (source, SamplingConfig(target), resampler) + return instance def _modulation_ptr(self: "Custom") -> ModulationPtr: - return Base().modulation_custom( - self._config._inner, - self._loop_behavior, - self._buf.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), # type: ignore[arg-type] - len(self._buf), - ) + match self._resampler: + case (Freq(), SamplingConfig(), Resampler()): + (source, target, resampler) = self._resampler # type: ignore[misc] + return Base().modulation_custom_with_resample( + self._loop_behavior, + self._buf.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), # type: ignore[arg-type] + len(self._buf), + source.hz, + target._inner, + resampler._dyn_resampler(), + ) + case _: + return Base().modulation_custom( + self._config._inner, + self._loop_behavior, + self._buf.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), # type: ignore[arg-type] + len(self._buf), + ) diff --git a/pyautd3/modulation/fourier.py b/pyautd3/modulation/fourier.py index ccbd7c8..ec980e1 100644 --- a/pyautd3/modulation/fourier.py +++ b/pyautd3/modulation/fourier.py @@ -4,6 +4,7 @@ from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure from pyautd3.native_methods.autd3capi_driver import ModulationPtr @@ -12,6 +13,7 @@ class Fourier( IntoModulationCache["Fourier"], + IntoModulationFir["Fourier"], IntoModulationRadiationPressure["Fourier"], ModulationBase["Fourier"], ): diff --git a/pyautd3/modulation/mixer.py b/pyautd3/modulation/mixer.py index 2655c96..b30f02b 100644 --- a/pyautd3/modulation/mixer.py +++ b/pyautd3/modulation/mixer.py @@ -4,6 +4,7 @@ from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure from pyautd3.native_methods.autd3capi_driver import ModulationPtr @@ -12,6 +13,7 @@ class Mixer( IntoModulationCache["Mixer"], + IntoModulationFir["Mixer"], IntoModulationRadiationPressure["Mixer"], ModulationBase["Mixer"], ): diff --git a/pyautd3/modulation/resample.py b/pyautd3/modulation/resample.py new file mode 100644 index 0000000..5ae7eed --- /dev/null +++ b/pyautd3/modulation/resample.py @@ -0,0 +1,47 @@ +import ctypes +from abc import ABCMeta, abstractmethod +from typing import Generic, TypeVar + +from pyautd3.native_methods.autd3capi_driver import DynSincInterpolator, DynWindow + + +class Resampler(metaclass=ABCMeta): + @abstractmethod + def _dyn_resampler(self) -> DynSincInterpolator: + pass + + +class BlackMan: + _window_size: int + + def __init__(self, window_size: int) -> None: + self._window_size = window_size + + def _window(self) -> ctypes.c_uint32: + return ctypes.c_uint32(int(DynWindow.Blackman)) + + +class Rectangular: + _window_size: int + + def __init__(self, window_size: int) -> None: + self._window_size = window_size + + def _window(self) -> ctypes.c_uint32: + return ctypes.c_uint32(int(DynWindow.Rectangular)) + + +T = TypeVar("T", BlackMan, Rectangular) + + +class SincInterpolation( + Resampler, + Generic[T], +): + _window: T + + def __init__(self: "SincInterpolation", window: T | None = None) -> None: + self._window = window if window is not None else BlackMan(32) + + def _dyn_resampler(self) -> DynSincInterpolator: + return DynSincInterpolator(self._window._window(), ctypes.c_uint32(self._window._window_size)) diff --git a/pyautd3/modulation/static.py b/pyautd3/modulation/static.py index cd16589..0b06dc4 100644 --- a/pyautd3/modulation/static.py +++ b/pyautd3/modulation/static.py @@ -1,5 +1,6 @@ from pyautd3.driver.datagram.modulation.base import ModulationBase from pyautd3.driver.datagram.modulation.cache import IntoModulationCache +from pyautd3.driver.datagram.modulation.fir import IntoModulationFir from pyautd3.driver.datagram.modulation.radiation_pressure import IntoModulationRadiationPressure from pyautd3.driver.utils import _validate_u8 from pyautd3.native_methods.autd3capi import NativeMethods as Base @@ -8,6 +9,7 @@ class Static( IntoModulationCache["Static"], + IntoModulationFir["Static"], IntoModulationRadiationPressure["Static"], ModulationBase["Static"], ): diff --git a/pyautd3/native_methods/autd3capi_driver.py b/pyautd3/native_methods/autd3capi_driver.py index 69ed45d..10c4747 100644 --- a/pyautd3/native_methods/autd3capi_driver.py +++ b/pyautd3/native_methods/autd3capi_driver.py @@ -227,7 +227,7 @@ def __eq__(self, other: object) -> bool: class DynSincInterpolator(ctypes.Structure): - _fields_ = [("window", ctypes.c_int32), ("window_size", ctypes.c_uint32)] + _fields_ = [("window", ctypes.c_uint32), ("window_size", ctypes.c_uint32)] def __eq__(self, other: object) -> bool: diff --git a/tests/driver/datagram/test_debug.py b/tests/driver/datagram/test_debug.py index f091333..cd580b1 100644 --- a/tests/driver/datagram/test_debug.py +++ b/tests/driver/datagram/test_debug.py @@ -4,6 +4,7 @@ from pyautd3 import ( Controller, + DcSysTime, DebugSettings, DebugType, ) @@ -69,3 +70,21 @@ def f2(dev: Device, gpio: GPIOOut) -> DebugTypeWrap: for dev in autd.geometry: assert np.array_equal([0x51, 0x52, 0xE0, 0xF0], autd.link.debug_types(dev.idx)) assert np.array_equal([0x0002, 0x0000, 0x0003, 0x0001], autd.link.debug_values(dev.idx)) + + sys_time = DcSysTime.now() + + def f3(_dev: Device, gpio: GPIOOut) -> DebugTypeWrap: + match gpio: + case GPIOOut.O0: + return DebugType.SysTimeEq(sys_time) + case GPIOOut.O1: + return DebugType.NONE + case GPIOOut.O2: + return DebugType.NONE + case GPIOOut.O3: + return DebugType.NONE + + autd.send(DebugSettings(f3)) + for dev in autd.geometry: + assert np.array_equal([0x60, 0x00, 0x00, 0x00], autd.link.debug_types(dev.idx)) + assert np.array_equal([(sys_time.sys_time // 50000) << 9, 0x00, 0x00, 0x00], autd.link.debug_values(dev.idx)) diff --git a/tests/driver/datagram/test_modulation.py b/tests/driver/datagram/test_modulation.py index 9adf708..047fa2f 100644 --- a/tests/driver/datagram/test_modulation.py +++ b/tests/driver/datagram/test_modulation.py @@ -7,6 +7,7 @@ from pyautd3.driver.defined.freq import Hz from pyautd3.driver.firmware.fpga.transition_mode import TransitionMode from pyautd3.modulation import Modulation, Sine +from pyautd3.modulation.fourier import Fourier from tests.test_autd import create_controller if TYPE_CHECKING: @@ -161,6 +162,203 @@ def test_radiation_pressure(): assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 +def test_radiation_fir(): + autd: Controller[Audit] + with create_controller() as autd: + m = Fourier([Sine(50 * Hz), Sine(1000 * Hz)]).with_fir( + [ + 0.0, + 2.336_732_5e-6, + 8.982_681e-6, + 1.888_706_2e-5, + 3.030_097e-5, + 4.075_849e-5, + 4.708_182e-5, + 4.542_212e-5, + 3.134_882_4e-5, + 0.0, + -5.369_572_3e-5, + -0.000_134_718_74, + -0.000_247_578_05, + -0.000_395_855_98, + -0.000_581_690_7, + -0.000_805_217_2, + -0.001_063_996, + -0.001_352_463_7, + -0.001_661_447_3, + -0.001_977_784_6, + -0.002_284_095_4, + -0.002_558_745, + -0.002_776_031, + -0.002_906_624_2, + -0.002_918_272_5, + -0.002_776_767_4, + -0.002_447_156_7, + -0.001_895_169_7, + -0.001_088_802_4, + 0.0, + 0.001_393_638_8, + 0.003_107_224_6, + 0.005_147_092_5, + 0.007_509_561, + 0.010_180_013, + 0.013_132_379, + 0.016_329_063, + 0.019_721_36, + 0.023_250_382, + 0.026_848_452, + 0.030_440_966, + 0.033_948_626, + 0.037_290_003, + 0.040_384_263, + 0.043_154_005, + 0.045_528_06, + 0.047_444_11, + 0.048_851_013, + 0.049_710_777, + 0.05, + 0.049_710_777, + 0.048_851_013, + 0.047_444_11, + 0.045_528_06, + 0.043_154_005, + 0.040_384_263, + 0.037_290_003, + 0.033_948_626, + 0.030_440_966, + 0.026_848_452, + 0.023_250_382, + 0.019_721_36, + 0.016_329_063, + 0.013_132_379, + 0.010_180_013, + 0.007_509_561, + 0.005_147_092_5, + 0.003_107_224_6, + 0.001_393_638_8, + 0.0, + -0.001_088_802_4, + -0.001_895_169_7, + -0.002_447_156_7, + -0.002_776_767_4, + -0.002_918_272_5, + -0.002_906_624_2, + -0.002_776_031, + -0.002_558_745, + -0.002_284_095_4, + -0.001_977_784_6, + -0.001_661_447_3, + -0.001_352_463_7, + -0.001_063_996, + -0.000_805_217_2, + -0.000_581_690_7, + -0.000_395_855_98, + -0.000_247_578_05, + -0.000_134_718_74, + -5.369_572_3e-5, + 0.0, + 3.134_882_4e-5, + 4.542_212e-5, + 4.708_182e-5, + 4.075_849e-5, + 3.030_097e-5, + 1.888_706_2e-5, + 8.982_681e-6, + 2.336_732_5e-6, + 0.0, + ], + ) + + autd.send(m) + + for dev in autd.geometry: + mod = autd.link.modulation_buffer(dev.idx, Segment.S0) + mod_expect = [ + 126, + 131, + 135, + 140, + 144, + 148, + 152, + 156, + 160, + 164, + 167, + 170, + 173, + 175, + 178, + 180, + 181, + 182, + 183, + 184, + 184, + 184, + 183, + 182, + 181, + 180, + 178, + 175, + 173, + 170, + 167, + 164, + 160, + 156, + 152, + 148, + 144, + 140, + 135, + 131, + 126, + 122, + 117, + 113, + 108, + 104, + 100, + 96, + 92, + 88, + 85, + 82, + 79, + 77, + 74, + 73, + 71, + 70, + 69, + 68, + 68, + 68, + 69, + 70, + 71, + 73, + 74, + 77, + 79, + 82, + 85, + 88, + 92, + 96, + 100, + 104, + 108, + 113, + 117, + 122, + ] + assert np.array_equal(mod, mod_expect) + assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 + + def test_mod_segment(): autd: Controller[Audit] with create_controller() as autd: diff --git a/tests/driver/datagram/test_phase_corr.py b/tests/driver/datagram/test_phase_corr.py new file mode 100644 index 0000000..b1fd16f --- /dev/null +++ b/tests/driver/datagram/test_phase_corr.py @@ -0,0 +1,21 @@ +from typing import TYPE_CHECKING + +import numpy as np + +from pyautd3 import Controller, Phase, PhaseCorrection, Segment +from tests.test_autd import create_controller + +if TYPE_CHECKING: + from pyautd3.link.audit import Audit + + +def test_phase_corr(): + autd: Controller[Audit] + with create_controller() as autd: + autd.send(PhaseCorrection(lambda dev: lambda tr: Phase(dev.idx + tr.idx))) + + for dev in autd.geometry: + intensities, phases = autd.link.drives_at(dev.idx, Segment.S0, 0) + assert np.all(intensities == 0x00) + for i, phase in enumerate(phases): + assert phase == dev.idx + i diff --git a/tests/driver/datagram/test_silencer.py b/tests/driver/datagram/test_silencer.py index 9e3d4af..e4a380f 100644 --- a/tests/driver/datagram/test_silencer.py +++ b/tests/driver/datagram/test_silencer.py @@ -4,12 +4,7 @@ import numpy as np import pytest -from pyautd3 import ( - Controller, - GainSTM, - SamplingConfig, - Silencer, -) +from pyautd3 import Controller, FixedCompletionTime, FixedUpdateRate, GainSTM, SamplingConfig, Silencer from pyautd3.autd_error import AUTDError from pyautd3.driver.datagram.stm.foci import FociSTM from pyautd3.driver.defined.freq import Hz @@ -32,7 +27,7 @@ def test_silencer_from_completion_time(): assert autd.link.silencer_fixed_completion_steps_mode(dev.idx) assert autd.link.silencer_strict_mode(dev.idx) - autd.send(Silencer.from_completion_time(timedelta(microseconds=25 * 2), timedelta(microseconds=25 * 3))) + autd.send(Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 2), phase=timedelta(microseconds=25 * 3)))) for dev in autd.geometry: assert autd.link.silencer_completion_steps_intensity(dev.idx) == 2 @@ -42,7 +37,7 @@ def test_silencer_from_completion_time(): assert autd.link.silencer_target(dev.idx) == SilencerTarget.Intensity autd.send( - Silencer.from_completion_time(timedelta(microseconds=25 * 2), timedelta(microseconds=25 * 3)) + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 2), phase=timedelta(microseconds=25 * 3))) .with_strict_mode(mode=False) .with_target(SilencerTarget.PulseWidth), ) @@ -74,8 +69,8 @@ def test_silencer_from_update_rate(): assert autd.link.silencer_fixed_completion_steps_mode(dev.idx) assert autd.link.silencer_target(dev.idx) == SilencerTarget.Intensity - assert Silencer.from_update_rate(2, 3).is_valid(Sine(150 * Hz).with_sampling_config(SamplingConfig(1))) - autd.send(Silencer.from_update_rate(2, 3).with_target(SilencerTarget.PulseWidth)) + assert Silencer(FixedUpdateRate(intensity=2, phase=3)).is_valid(Sine(150 * Hz).with_sampling_config(SamplingConfig(1))) + autd.send(Silencer(FixedUpdateRate(intensity=2, phase=3)).with_target(SilencerTarget.PulseWidth)) for dev in autd.geometry: assert autd.link.silencer_update_rate_intensity(dev.idx) == 2 @@ -95,11 +90,11 @@ def test_silencer_large_steps(): assert Silencer.disable().is_valid(Sine(150 * Hz).with_sampling_config(SamplingConfig(1))) autd.send(Sine(150 * Hz).with_sampling_config(SamplingConfig(1))) - assert not Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)).is_valid( + assert not Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))).is_valid( Sine(150 * Hz).with_sampling_config(SamplingConfig(1)), ) with pytest.raises(AUTDError) as e: - autd.send(Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40))) + autd.send(Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40)))) assert ( str(e.value) == "Silencer cannot complete phase/intensity completion in the specified sampling period. Please lower the sampling frequency or make the completion time of Silencer longer than the sampling period." # noqa: E501 @@ -122,13 +117,17 @@ def test_silencer_small_freq_div_mod(): == "Silencer cannot complete phase/intensity completion in the specified sampling period. Please lower the sampling frequency or make the completion time of Silencer longer than the sampling period." # noqa: E501 ) - autd.send(Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)).with_strict_mode(mode=False)) + autd.send( + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))).with_strict_mode( + mode=False, + ), + ) for dev in autd.geometry: assert autd.link.silencer_completion_steps_intensity(dev.idx) == 10 assert autd.link.silencer_completion_steps_phase(dev.idx) == 40 assert autd.link.silencer_fixed_completion_steps_mode(dev.idx) assert ( - Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)) + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))) .with_strict_mode(mode=False) .is_valid(Sine(150 * Hz).with_sampling_config(SamplingConfig(1))) ) @@ -151,13 +150,17 @@ def test_silencer_small_freq_div_gain_stm(): == "Silencer cannot complete phase/intensity completion in the specified sampling period. Please lower the sampling frequency or make the completion time of Silencer longer than the sampling period." # noqa: E501 ) - autd.send(Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)).with_strict_mode(mode=False)) + autd.send( + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))).with_strict_mode( + mode=False, + ), + ) for dev in autd.geometry: assert autd.link.silencer_completion_steps_intensity(dev.idx) == 10 assert autd.link.silencer_completion_steps_phase(dev.idx) == 40 assert autd.link.silencer_fixed_completion_steps_mode(dev.idx) assert ( - Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)) + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))) .with_strict_mode(mode=False) .is_valid(GainSTM(SamplingConfig(1), [Null(), Null()])) ) @@ -180,13 +183,17 @@ def test_silencer_small_freq_div_foci_stm(): == "Silencer cannot complete phase/intensity completion in the specified sampling period. Please lower the sampling frequency or make the completion time of Silencer longer than the sampling period." # noqa: E501 ) - autd.send(Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)).with_strict_mode(mode=False)) + autd.send( + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))).with_strict_mode( + mode=False, + ), + ) for dev in autd.geometry: assert autd.link.silencer_completion_steps_intensity(dev.idx) == 10 assert autd.link.silencer_completion_steps_phase(dev.idx) == 40 assert autd.link.silencer_fixed_completion_steps_mode(dev.idx) assert ( - Silencer.from_completion_time(timedelta(microseconds=25 * 10), timedelta(microseconds=25 * 40)) + Silencer(FixedCompletionTime(intensity=timedelta(microseconds=25 * 10), phase=timedelta(microseconds=25 * 40))) .with_strict_mode(mode=False) .is_valid(FociSTM(SamplingConfig(1), [np.zeros(3), np.zeros(3)])) ) diff --git a/tests/driver/firmware/fpga/test_fpga_state.py b/tests/driver/firmware/fpga/test_fpga_state.py index 6cfa043..d97c293 100644 --- a/tests/driver/firmware/fpga/test_fpga_state.py +++ b/tests/driver/firmware/fpga/test_fpga_state.py @@ -81,7 +81,7 @@ def test_fpga_state(): autd.link.repair() -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_fpga_state_async(): autd: Controller[Audit] with await create_controller_async() as autd: diff --git a/tests/driver/test_utils.py b/tests/driver/test_utils.py index b8573f4..64e688b 100644 --- a/tests/driver/test_utils.py +++ b/tests/driver/test_utils.py @@ -1,6 +1,6 @@ import pytest -from pyautd3.driver.utils import _validate_nonzero_u16 +from pyautd3.driver.utils import _validate_nonzero_u16, _validate_nonzero_u32 def test_validate_nonzero_u16(): @@ -15,3 +15,17 @@ def test_validate_nonzero_u16(): with pytest.raises(ValueError): # noqa: PT011 _ = _validate_nonzero_u16(0xFFFF + 1) + + +def test_validate_nonzero_u32(): + assert _validate_nonzero_u32(1) == 1 + assert _validate_nonzero_u32(0xFFFFFFFF) == 0xFFFFFFFF + + with pytest.raises(TypeError): + _ = _validate_nonzero_u32(0.1) # type: ignore[arg-type] + + with pytest.raises(ValueError): # noqa: PT011 + _ = _validate_nonzero_u32(0) + + with pytest.raises(ValueError): # noqa: PT011 + _ = _validate_nonzero_u32(0xFFFFFFFF + 1) diff --git a/tests/link/test_soem.py b/tests/link/test_soem.py index 1ec2afe..b34a023 100644 --- a/tests/link/test_soem.py +++ b/tests/link/test_soem.py @@ -7,7 +7,7 @@ from pyautd3.link.soem import SOEM, RemoteSOEM, Status, SyncMode, TimerStrategy -@pytest.mark.soem() +@pytest.mark.soem def test_soem_adapers(): adapters = SOEM.enumerate_adapters() for adapter in adapters: @@ -18,7 +18,7 @@ def err_handler(slave: int, status: Status, msg: str) -> None: print(f"slave: {slave}, status: {status}, msg: {msg}") -@pytest.mark.soem() +@pytest.mark.soem def test_soem(): with ( pytest.raises(AUTDError) as _, @@ -42,6 +42,6 @@ def test_soem(): pass -@pytest.mark.soem() +@pytest.mark.soem def test_remote_soem(): _ = RemoteSOEM.builder("127.0.0.1:8080").with_timeout(timedelta(milliseconds=200)) diff --git a/tests/modulation/audio_file/.gitignore b/tests/modulation/audio_file/.gitignore new file mode 100644 index 0000000..20250e0 --- /dev/null +++ b/tests/modulation/audio_file/.gitignore @@ -0,0 +1 @@ +custom* \ No newline at end of file diff --git a/tests/modulation/audio_file/test_csv.py b/tests/modulation/audio_file/test_csv.py index c98bfc0..d69e676 100644 --- a/tests/modulation/audio_file/test_csv.py +++ b/tests/modulation/audio_file/test_csv.py @@ -1,11 +1,13 @@ +import csv from pathlib import Path from typing import TYPE_CHECKING import numpy as np -from pyautd3 import Controller, Segment +from pyautd3 import Controller, Segment, kHz from pyautd3.driver.defined.freq import Hz from pyautd3.modulation.audio_file import Csv +from pyautd3.modulation.resample import SincInterpolation from tests.test_autd import create_controller if TYPE_CHECKING: @@ -107,3 +109,21 @@ def test_csv(): autd.send(Csv(Path(__file__).parent / "sin150.csv", 2000 * Hz)) for dev in autd.geometry: assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 20 + + +def test_csv_with_resample(): + autd: Controller[Audit] + with create_controller() as autd: + expect = [127, 217, 255, 217, 127, 37, 0, 37] + buf = [127, 255, 127, 0] + + with Path.open(Path(__file__).parent / "custom.csv", "w") as f: + writer = csv.writer(f) + writer.writerow(buf) + + autd.send(Csv.new_with_resampler(Path(__file__).parent / "custom.csv", 2.0 * kHz, 4 * kHz, SincInterpolation())) + + for dev in autd.geometry: + mod = autd.link.modulation_buffer(dev.idx, Segment.S0) + assert np.array_equal(expect, mod) + assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 diff --git a/tests/modulation/audio_file/test_rawpcm.py b/tests/modulation/audio_file/test_rawpcm.py index 699ee60..b0fd6cb 100644 --- a/tests/modulation/audio_file/test_rawpcm.py +++ b/tests/modulation/audio_file/test_rawpcm.py @@ -3,9 +3,10 @@ import numpy as np -from pyautd3 import Controller, Segment +from pyautd3 import Controller, Segment, kHz from pyautd3.driver.defined.freq import Hz from pyautd3.modulation.audio_file import RawPCM +from pyautd3.modulation.resample import SincInterpolation from tests.test_autd import create_controller if TYPE_CHECKING: @@ -107,3 +108,21 @@ def test_rawpcm(): autd.send(RawPCM(Path(__file__).parent / "sin150.dat", 2000 * Hz)) for dev in autd.geometry: assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 20 + + +def test_rawpcm_with_resample(): + autd: Controller[Audit] + with create_controller() as autd: + expect = [127, 217, 255, 217, 127, 37, 0, 37] + buf = [127, 255, 127, 0] + + with Path.open(Path(__file__).parent / "custom.dat", "wb") as f: + for b in buf: + f.write(b.to_bytes(1, byteorder="little")) + + autd.send(RawPCM.new_with_resampler(Path(__file__).parent / "custom.dat", 2.0 * kHz, 4 * kHz, SincInterpolation())) + + for dev in autd.geometry: + mod = autd.link.modulation_buffer(dev.idx, Segment.S0) + assert np.array_equal(expect, mod) + assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 diff --git a/tests/modulation/audio_file/test_wav.py b/tests/modulation/audio_file/test_wav.py index 80a089e..ccb7a04 100644 --- a/tests/modulation/audio_file/test_wav.py +++ b/tests/modulation/audio_file/test_wav.py @@ -1,10 +1,12 @@ +import wave from pathlib import Path from typing import TYPE_CHECKING import numpy as np -from pyautd3 import Controller, Segment +from pyautd3 import Controller, Segment, kHz from pyautd3.modulation.audio_file import Wav +from pyautd3.modulation.resample import SincInterpolation from tests.test_autd import create_controller if TYPE_CHECKING: @@ -102,3 +104,24 @@ def test_wav(): ] assert np.array_equal(mod, mod_expect) assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 + + +def test_wav_with_resample(): + autd: Controller[Audit] + with create_controller() as autd: + expect = [127, 217, 255, 217, 127, 37, 0, 37] + buf = [127, 255, 127, 0] + + with wave.open(str(Path(__file__).parent / "custom.wav"), "wb") as f: + f.setnchannels(1) + f.setsampwidth(1) + f.setframerate(2000) + f.setnframes(len(buf)) + f.writeframes(np.array(buf, dtype=np.uint8).tobytes()) + + autd.send(Wav.new_with_resampler(Path(__file__).parent / "custom.wav", 4 * kHz, SincInterpolation())) + + for dev in autd.geometry: + mod = autd.link.modulation_buffer(dev.idx, Segment.S0) + assert np.array_equal(expect, mod) + assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 diff --git a/tests/modulation/test_custom.py b/tests/modulation/test_custom.py index 4ab4788..2a4237e 100644 --- a/tests/modulation/test_custom.py +++ b/tests/modulation/test_custom.py @@ -2,8 +2,9 @@ import numpy as np -from pyautd3 import Controller, SamplingConfig, Segment +from pyautd3 import Controller, SamplingConfig, Segment, kHz from pyautd3.modulation import Custom +from pyautd3.modulation.resample import Rectangular, SincInterpolation from tests.test_autd import create_controller if TYPE_CHECKING: @@ -25,3 +26,26 @@ def test_modulation_custom(): assert mod[0] == 0xFF assert np.all(mod[1:] == 0) assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 + + +def test_modulation_custom_with_resample(): + autd: Controller[Audit] + with create_controller() as autd: + m = Custom.new_with_resample([127, 255, 127, 0], 2.0 * kHz, 4.0 * kHz, SincInterpolation()) + + autd.send(m) + + for dev in autd.geometry: + mod = autd.link.modulation_buffer(dev.idx, Segment.S0) + assert np.array_equal([127, 217, 255, 217, 127, 37, 0, 37], mod) + assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 + + with create_controller() as autd: + m = Custom.new_with_resample([127, 255, 127, 0], 2.0 * kHz, 4.0 * kHz, SincInterpolation(Rectangular(32))) + + autd.send(m) + + for dev in autd.geometry: + mod = autd.link.modulation_buffer(dev.idx, Segment.S0) + assert np.array_equal([127, 217, 255, 223, 127, 42, 0, 37], mod) + assert autd.link.modulation_frequency_division(dev.idx, Segment.S0) == 10 diff --git a/tests/test_autd.py b/tests/test_autd.py index 3a57c29..04e006f 100644 --- a/tests/test_autd.py +++ b/tests/test_autd.py @@ -19,6 +19,7 @@ async def create_controller_async() -> Controller[Audit]: return ( await Controller.builder([AUTD3([0.0, 0.0, 0.0]), AUTD3([0.0, 0.0, 0.0])]) .with_send_interval(timedelta(milliseconds=1)) + .with_receive_interval(timedelta(milliseconds=1)) .with_timer_resolution(1) .open_async( Audit.builder(), @@ -30,6 +31,7 @@ def create_controller() -> Controller[Audit]: return ( Controller.builder([AUTD3([0.0, 0.0, 0.0]), AUTD3([0.0, 0.0, 0.0])]) .with_send_interval(timedelta(milliseconds=1)) + .with_receive_interval(timedelta(milliseconds=1)) .with_timer_resolution(1) .open( Audit.builder(), @@ -51,7 +53,7 @@ def test_firmware_info(): autd.link.up() -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_firmware_info_async(): autd: Controller[Audit] with await create_controller_async() as autd: @@ -74,6 +76,7 @@ def test_close(): assert autd.link.is_open() autd.close() + autd.close() with create_controller() as autd: autd.link.break_down() @@ -83,13 +86,14 @@ def test_close(): autd.link.repair() -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_close_async(): autd: Controller[Audit] with await create_controller_async() as autd: assert autd.link.is_open() await autd.close_async() + await autd.close_async() with create_controller() as autd: autd.link.break_down() @@ -123,7 +127,7 @@ def test_send_single(): autd.link.repair() -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_send_async_single(): autd: Controller[Audit] with await create_controller_async() as autd: @@ -148,7 +152,7 @@ async def test_send_async_single(): autd.link.repair() -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_send_async_tuple(): autd: Controller[Audit] with await create_controller_async() as autd: @@ -213,7 +217,7 @@ def test_send_tuple(): autd.link.repair() -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_group_async(): autd: Controller[Audit] with await create_controller_async() as autd: diff --git a/tools/wrapper-generator/src/python.rs b/tools/wrapper-generator/src/python.rs index 8612fb7..2bbe566 100644 --- a/tools/wrapper-generator/src/python.rs +++ b/tools/wrapper-generator/src/python.rs @@ -57,7 +57,7 @@ impl PythonGenerator { Type::Custom(ref s) => match s.as_str() { "* mut c_char" => "ctypes.c_char_p".to_string(), "[u8 ; 2]" => "ctypes.c_uint8 * 2".to_string(), - "DynWindow" => "ctypes.c_int32".to_string(), + "DynWindow" => "ctypes.c_uint32".to_string(), s if s.ends_with("Tag") => "ctypes.c_uint8".to_string(), s => s.to_owned(), },