diff --git a/CHANGELOG b/CHANGELOG index 7927ade7..03b4b2ec 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,9 @@ +Version 0.1.1: + - Fixed type promotions in scalar arithmetic - left hand types are now promoted when appropriate. + - Added "tensor_functions" module for implementing additional functions on tensors. Currently only Log is implemented. + + + Version 0.1.0: - Added framework for integrating device support and redesigned scalars module to accommodate the changes. - Made changes to type deduction in constructors to avoid exceptions when providing lists of python ints/floats. diff --git a/algebra/include/roughpy/algebra/algebra_fwd.h b/algebra/include/roughpy/algebra/algebra_fwd.h index aaa4bf4c..d2497a3a 100644 --- a/algebra/include/roughpy/algebra/algebra_fwd.h +++ b/algebra/include/roughpy/algebra/algebra_fwd.h @@ -30,6 +30,7 @@ #include #include +#include #include #include "roughpy_algebra_export.h" diff --git a/algebra/include/roughpy/algebra/basis.h b/algebra/include/roughpy/algebra/basis.h index 0f7872ae..887a71c2 100644 --- a/algebra/include/roughpy/algebra/basis.h +++ b/algebra/include/roughpy/algebra/basis.h @@ -33,6 +33,7 @@ #include #include +#include #include @@ -54,15 +55,17 @@ class BasisInterface template using impl_t = dtl::BasisImplementation; + using base_interface_t = BasisInterface; + using mixin_t = void; virtual ~BasisInterface() = default; - RPY_NO_DISCARD - virtual string key_to_string(const key_type& key) const = 0; + RPY_NO_DISCARD virtual string key_to_string(const key_type& key) const = 0; + + RPY_NO_DISCARD virtual dimn_t dimension() const noexcept = 0; - RPY_NO_DISCARD - virtual dimn_t dimension() const noexcept = 0; + RPY_NO_DISCARD virtual bool are_same(const BasisInterface &other) const noexcept = 0; }; namespace dtl { @@ -85,14 +88,14 @@ class OrderedBasisInterface : public void_or_base template using impl_t = dtl::OrderedBasisImplementationMixin< - T, Derived, typename Base::template impl_t>; + T, + Derived, + typename Base::template impl_t>; virtual ~OrderedBasisInterface() = default; - RPY_NO_DISCARD - virtual key_type index_to_key(dimn_t index) const = 0; - RPY_NO_DISCARD - virtual dimn_t key_to_index(const key_type& key) const = 0; + RPY_NO_DISCARD virtual key_type index_to_key(dimn_t index) const = 0; + RPY_NO_DISCARD virtual dimn_t key_to_index(const key_type& key) const = 0; }; namespace dtl { @@ -114,40 +117,39 @@ class WordLikeBasisInterface : public void_or_base template using impl_t = dtl::WordLikeBasisImplementationMixin< - T, Derived, typename Base::template impl_t>; - - RPY_NO_DISCARD - virtual deg_t width() const noexcept = 0; - RPY_NO_DISCARD - virtual deg_t depth() const noexcept = 0; - RPY_NO_DISCARD - virtual deg_t degree(const key_type& key) const noexcept = 0; - RPY_NO_DISCARD - virtual deg_t size(deg_t degree) const noexcept = 0; - RPY_NO_DISCARD - virtual let_t first_letter(const key_type& key) const noexcept = 0; - RPY_NO_DISCARD - virtual let_t to_letter(const key_type& key) const noexcept = 0; - - RPY_NO_DISCARD - virtual dimn_t start_of_degree(deg_t degree) const noexcept = 0; - RPY_NO_DISCARD - virtual pair, optional> + T, + Derived, + typename Base::template impl_t>; + + RPY_NO_DISCARD virtual deg_t width() const noexcept = 0; + RPY_NO_DISCARD virtual deg_t depth() const noexcept = 0; + RPY_NO_DISCARD virtual deg_t degree(const key_type& key) const noexcept = 0; + RPY_NO_DISCARD virtual deg_t size(deg_t degree) const noexcept = 0; + RPY_NO_DISCARD virtual let_t first_letter(const key_type& key + ) const noexcept + = 0; + RPY_NO_DISCARD virtual let_t to_letter(const key_type& key) const noexcept + = 0; + + RPY_NO_DISCARD virtual dimn_t start_of_degree(deg_t degree) const noexcept + = 0; + RPY_NO_DISCARD virtual pair, optional> parents(const key_type& key) const = 0; - RPY_NO_DISCARD - optional lparent(const key_type& key) const + RPY_NO_DISCARD optional lparent(const key_type& key) const { return parents(key).first; } - RPY_NO_DISCARD - optional rparent(const key_type& key) const + RPY_NO_DISCARD optional rparent(const key_type& key) const { return parents(key).second; } - RPY_NO_DISCARD - virtual key_type key_of_letter(let_t letter) const noexcept = 0; - RPY_NO_DISCARD - virtual bool letter(const key_type& key) const = 0; + + RPY_NO_DISCARD virtual optional + child(const key_type& lparent, const key_type& rparent) const = 0; + + RPY_NO_DISCARD virtual key_type key_of_letter(let_t letter) const noexcept + = 0; + RPY_NO_DISCARD virtual bool letter(const key_type& key) const = 0; }; template @@ -182,54 +184,67 @@ class Basis : public PrimaryInterface::mixin_t Basis(const Basis&) = default; Basis(Basis&&) noexcept = default; - RPY_NO_UBSAN - ~Basis() = default; + RPY_NO_UBSAN ~Basis() = default; Basis& operator=(const Basis&) = default; Basis& operator=(Basis&&) noexcept = default; - RPY_NO_DISCARD - const basis_interface& instance() const noexcept { return *p_impl; } + RPY_NO_DISCARD const basis_interface& instance() const noexcept + { + return *p_impl; + } - RPY_NO_DISCARD - constexpr operator bool() const noexcept { return static_cast(p_impl); } + RPY_NO_DISCARD constexpr operator bool() const noexcept + { + return static_cast(p_impl); + } - RPY_NO_DISCARD - string key_to_string(const key_type& key) const noexcept + RPY_NO_DISCARD string key_to_string(const key_type& key) const noexcept { return instance().key_to_string(key); } - RPY_NO_DISCARD - dimn_t dimension() const noexcept { return instance().dimension(); } - + RPY_NO_DISCARD dimn_t dimension() const noexcept + { + return instance().dimension(); + } friend bool operator==(const Basis& left, const Basis& right) noexcept { - return left.p_impl == right.p_impl; + if (left.p_impl == right.p_impl) { + return true; + } + + return left.p_impl->are_same(*right.p_impl); } friend bool operator!=(const Basis& left, const Basis& right) noexcept { - return left.p_impl != right.p_impl; + return !operator==(left, right); } + friend hash_t hash_value(const Basis& basis) noexcept + { + return hash_value(basis.p_impl); + } }; - - namespace dtl { template < - typename Derived, typename KeyType, - template class... Interfaces> + typename Derived, + typename KeyType, + template + class... Interfaces> struct make_basis_interface_impl; } template < - typename Derived, typename KeyType, - template class... Interfaces> -using make_basis_interface = typename dtl::make_basis_interface_impl< - Derived, KeyType, Interfaces...>::type; + typename Derived, + typename KeyType, + template + class... Interfaces> +using make_basis_interface = typename dtl:: + make_basis_interface_impl::type; namespace dtl { @@ -237,8 +252,7 @@ template class OrderedBasisMixin : public Base { - RPY_NO_DISCARD - const Derived& instance() const noexcept + RPY_NO_DISCARD const Derived& instance() const noexcept { return static_cast*>(this)->instance(); } @@ -246,14 +260,12 @@ class OrderedBasisMixin : public Base public: using key_type = typename Derived::key_type; - RPY_NO_DISCARD - key_type index_to_key(dimn_t index) const + RPY_NO_DISCARD key_type index_to_key(dimn_t index) const { return instance().index_to_key(index); } - RPY_NO_DISCARD - dimn_t key_to_index(const key_type& key) const + RPY_NO_DISCARD dimn_t key_to_index(const key_type& key) const { return instance().key_to_index(key); } @@ -263,8 +275,7 @@ template class WordLikeBasisMixin { - RPY_NO_DISCARD - const Derived& instance() const noexcept + RPY_NO_DISCARD const Derived& instance() const noexcept { return static_cast*>(this)->instance(); } @@ -272,61 +283,69 @@ class WordLikeBasisMixin public: using key_type = typename Derived::key_type; - RPY_NO_DISCARD - deg_t width() const noexcept { return instance().width(); } - RPY_NO_DISCARD - deg_t depth() const noexcept { return instance().depth(); } - RPY_NO_DISCARD - deg_t degree(const key_type& key) const noexcept + RPY_NO_DISCARD deg_t width() const noexcept { return instance().width(); } + RPY_NO_DISCARD deg_t depth() const noexcept { return instance().depth(); } + RPY_NO_DISCARD deg_t degree(const key_type& key) const noexcept { return instance().degree(key); } - RPY_NO_DISCARD - deg_t size(deg_t degree) const noexcept { return instance().size(degree); } - RPY_NO_DISCARD - let_t first_letter(const key_type& key) const noexcept + RPY_NO_DISCARD deg_t size(deg_t degree) const noexcept + { + return instance().size(degree); + } + RPY_NO_DISCARD let_t first_letter(const key_type& key) const noexcept { return instance().first_letter(key); } - RPY_NO_DISCARD let_t to_letter(const key_type& key) const noexcept { + RPY_NO_DISCARD let_t to_letter(const key_type& key) const noexcept + { return instance().to_letter(key); } - RPY_NO_DISCARD - dimn_t start_of_degree(deg_t degree) const noexcept + RPY_NO_DISCARD dimn_t start_of_degree(deg_t degree) const noexcept { return instance().start_of_degree(degree); } - RPY_NO_DISCARD - pair, optional> parents(const key_type& key - ) const + RPY_NO_DISCARD pair, optional> + parents(const key_type& key) const { return instance().parents(key); } - RPY_NO_DISCARD - optional lparent(const key_type& key) const + RPY_NO_DISCARD optional lparent(const key_type& key) const { return instance().lparent(key); } - RPY_NO_DISCARD - optional rparent(const key_type& key) const + RPY_NO_DISCARD optional rparent(const key_type& key) const { return instance().rparent(key); } - RPY_NO_DISCARD - key_type key_of_letter(let_t letter) const noexcept + RPY_NO_DISCARD optional + child(const key_type& lparent, const key_type& rparent) const + { + return instance().child(lparent, rparent); + } + + RPY_NO_DISCARD key_type key_of_letter(let_t letter) const noexcept { return instance().key_of_letter(letter); } - RPY_NO_DISCARD - bool letter(const key_type& key) const { return instance().letter(key); } + RPY_NO_DISCARD bool letter(const key_type& key) const + { + return instance().letter(key); + } }; template < - typename Derived, typename KeyType, - template class FirstInterface, - template class... Interfaces> + typename Derived, + typename KeyType, + template + class FirstInterface, + template + class... Interfaces> struct make_basis_interface_impl< - Derived, KeyType, FirstInterface, Interfaces...> { + Derived, + KeyType, + FirstInterface, + Interfaces...> { using next_t = make_basis_interface_impl; using type = FirstInterface; }; diff --git a/algebra/include/roughpy/algebra/basis_impl.h b/algebra/include/roughpy/algebra/basis_impl.h index 0799d866..e83f665a 100644 --- a/algebra/include/roughpy/algebra/basis_impl.h +++ b/algebra/include/roughpy/algebra/basis_impl.h @@ -64,6 +64,8 @@ class BasisImplementation : public PrimaryInterface string key_to_string(const typename PrimaryInterface::key_type& key ) const override; dimn_t dimension() const noexcept override; + + bool are_same(const typename PrimaryInterface::base_interface_t &other) const noexcept override; }; template @@ -101,6 +103,7 @@ class WordLikeBasisImplementationMixin : public Base dimn_t start_of_degree(deg_t degree) const noexcept override; pair, optional> parents(const key_type& key ) const override; + optional child(const key_type& lparent, const key_type& rparent) const override; key_type key_of_letter(let_t letter) const noexcept override; bool letter(const key_type& key) const override; }; @@ -118,6 +121,15 @@ dimn_t BasisImplementation::dimension() const noexcept return basis_traits::dimension(m_impl); } +template +bool BasisImplementation::are_same(const typename PrimaryInterface::base_interface_t &other) const noexcept { + const auto *pother = dynamic_cast(&other); + if (pother == nullptr) { return false; } + + return basis_traits::are_same(m_impl, pother->m_impl); +} + template typename Derived::key_type OrderedBasisImplementationMixin::index_to_key(dimn_t index @@ -182,6 +194,12 @@ WordLikeBasisImplementationMixin::parents(const key_type& key { return basis_traits::parents(m_impl, key); } + +template +optional WordLikeBasisImplementationMixin::child(const key_type& lparent, const key_type& rparent) const { + return basis_traits::child(m_impl, lparent, rparent); +} + template typename Derived::key_type WordLikeBasisImplementationMixin::key_of_letter(let_t letter diff --git a/algebra/include/roughpy/algebra/basis_info.h b/algebra/include/roughpy/algebra/basis_info.h index af6bf4da..f480d54c 100644 --- a/algebra/include/roughpy/algebra/basis_info.h +++ b/algebra/include/roughpy/algebra/basis_info.h @@ -157,6 +157,12 @@ struct BasisInfo { convert_from_impl(basis, parents_pair.second)}; } + + static optional child(storage_t basis, const our_key_type& lparent, const our_key_type& rparent) + { + return {}; + } + /// Get the first letter of the key as a word static let_t first_letter(storage_t basis, const our_key_type& key) { @@ -174,6 +180,9 @@ struct BasisInfo { { return basis->letter(convert_to_impl(basis, key)); } + + /// Determine whether two bases are the same + static bool are_same(storage_t basis1, storage_t basis2) noexcept { return basis1 == basis2; } }; }// namespace algebra diff --git a/algebra/src/libalgebra_lite_internal/lie_basis_info.h b/algebra/src/libalgebra_lite_internal/lie_basis_info.h index fa940326..f135218f 100644 --- a/algebra/src/libalgebra_lite_internal/lie_basis_info.h +++ b/algebra/src/libalgebra_lite_internal/lie_basis_info.h @@ -170,6 +170,25 @@ struct BasisInfo { convert_from_impl(basis, parents_pair.second)}; } + static optional child(storage_t basis, const our_key_type& lparent, const our_key_type& rparent) + { + using parent_type = typename lal::hall_basis::parent_type; + + parent_type parents { + convert_to_impl(basis, lparent), + convert_to_impl(basis, rparent) + }; + + auto result = basis->find(parents); + + optional out {}; + if (result.found) { + out = convert_from_impl(basis, result.it->second); + } + + return out; + } + /// Get the first letter of the key as a word static let_t first_letter(storage_t basis, const our_key_type& key) { @@ -191,6 +210,8 @@ struct BasisInfo { { return basis->letter(convert_to_impl(basis, key)); } + + static bool are_same(storage_t basis1, storage_t basis2) noexcept { return basis1 == basis2; } }; }// namespace algebra }// namespace rpy diff --git a/algebra/src/libalgebra_lite_internal/tensor_basis_info.h b/algebra/src/libalgebra_lite_internal/tensor_basis_info.h index 119db37a..e4e085e9 100644 --- a/algebra/src/libalgebra_lite_internal/tensor_basis_info.h +++ b/algebra/src/libalgebra_lite_internal/tensor_basis_info.h @@ -165,6 +165,27 @@ struct BasisInfo { convert_from_impl(basis, basis->rparent(tmpkey))}; } + static optional child(storage_t basis, const our_key_type& lparent, const our_key_type& rparent) + { + const auto left = convert_to_impl(basis, lparent); + const auto right = convert_to_impl(basis, rparent); + + const auto ldegree = left.degree(); + const auto rdegree = right.degree(); + + optional out {}; + + const auto degree = ldegree + rdegree; + if (degree <= basis->depth()) { + const auto shift = basis->powers()[rdegree]; + const auto idx = left.index() * shift + right.index(); + + out = convert_from_impl(basis, impl_key_type {degree, idx}); + } + + return out; + } + /// Get the first letter of the key as a word static let_t first_letter(storage_t basis, const our_key_type& key) { @@ -188,6 +209,9 @@ struct BasisInfo { { return lal::tensor_basis::letter(convert_to_impl(basis, key)); } + + + static bool are_same(storage_t basis1, storage_t basis2) noexcept { return basis1 == basis2; } }; }// namespace algebra }// namespace rpy diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 64cefa2e..14146668 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -9,12 +9,13 @@ include(GNUInstallDirs) add_roughpy_component(Core INTERFACE PUBLIC_HEADERS + include/roughpy/core/alloc.h + include/roughpy/core/hash.h include/roughpy/core/types.h include/roughpy/core/traits.h include/roughpy/core/macros.h include/roughpy/core/helpers.h include/roughpy/core/slice.h - include/roughpy/core/alloc.h DEPENDENCIES INTERFACE Boost::boost ) diff --git a/core/include/roughpy/core/hash.h b/core/include/roughpy/core/hash.h new file mode 100644 index 00000000..1ac06ce9 --- /dev/null +++ b/core/include/roughpy/core/hash.h @@ -0,0 +1,25 @@ +// +// Created by sam on 09/01/24. +// + +#ifndef ROUGHPY_CORE_HASH_H +#define ROUGHPY_CORE_HASH_H + +#include "types.h" + +#include + +namespace rpy { + + +using hash_t = std::size_t; + +template +using Hash = boost::hash; + +using boost::hash_combine; + +} + + +#endif// ROUGHPY_CORE_HASH_H diff --git a/core/include/roughpy/core/macros.h b/core/include/roughpy/core/macros.h index a0924962..4a63f086 100644 --- a/core/include/roughpy/core/macros.h +++ b/core/include/roughpy/core/macros.h @@ -34,9 +34,7 @@ #define ROUGHPY_CORE_MACROS_H #include -#include -#include -#include + #ifdef __has_builtin # define RPY_HAS_BUILTIN(x) __has_builtin(x) @@ -334,80 +332,7 @@ # define RPY_FILE_NAME __FILE__ #endif -/* - * Check macro definition. - * - * This macro checks that the given expression evaluates to true (under the - * assumption that it will usually be true), and throws an error if this - * evaluates to false. - * - * Optionally, one can provide a message string literal that will be used - * instead of the default, and an optional error type. The default error type - * is a std::runtime_error. - */ -namespace rpy { -namespace errors { - -template -RPY_NO_RETURN RPY_INLINE_ALWAYS void throw_exception( - std::string msg, const char* filename, int lineno, const char* func -) -{ - std::stringstream ss; - ss << msg << " at lineno " << lineno << " in " << filename - << " in function " << func; - throw E(ss.str()); -} - -template -RPY_NO_RETURN RPY_INLINE_ALWAYS void throw_exception( - const char* msg, const char* filename, int lineno, const char* func -) -{ - std::stringstream ss; - ss << msg << " at lineno " << lineno << " in " << filename - << " in function " << func; - throw E(ss.str()); -} - -}// namespace errors -}// namespace rpy - -// Dispatch the check macro on the number of arguments -// See: https://stackoverflow.com/a/16683147/9225581 -#define RPY_CHECK_3(EXPR, MSG, TYPE) \ - do { \ - if (RPY_UNLIKELY(!(EXPR))) { \ - ::rpy::errors::throw_exception( \ - MSG, RPY_FILE_NAME, __LINE__, RPY_FUNC_NAME \ - ); \ - } \ - } while (0) - -#define RPY_CHECK_2(EXPR, MSG) RPY_CHECK_3(EXPR, MSG, std::runtime_error) - -#define RPY_CHECK_1(EXPR) RPY_CHECK_2(EXPR, "failed check \"" #EXPR "\"") - -#define RPY_CHECK_CNT_IMPL(_1, _2, _3, COUNT, ...) COUNT -// Always pass one more argument than expected, so clang doesn't complain about -// empty parameter packs. -#define RPY_CHECK_CNT(...) RPY_CHECK_CNT_IMPL(__VA_ARGS__, 3, 2, 1, 0) -#define RPY_CHECK_SEL(NUM) RPY_JOIN(RPY_CHECK_, NUM) - -#define RPY_CHECK(...) \ - RPY_INVOKE_VA(RPY_CHECK_SEL(RPY_COUNT_ARGS(__VA_ARGS__)), (__VA_ARGS__)) - -#define RPY_THROW_2(EXC_TYPE, MSG) \ - ::rpy::errors::throw_exception( \ - MSG, __FILE__, __LINE__, RPY_FUNC_NAME \ - ) -#define RPY_THROW_1(MSG) RPY_THROW_2(std::runtime_error, MSG) - -#define RPY_THROW_SEL(NUM) RPY_JOIN(RPY_THROW_, NUM) -#define RPY_THROW_CNT_IMPL(_1, _2, COUNT, ...) COUNT -#define RPY_THROW_CNT(...) RPY_THROW_CNT_IMPL(__VA_ARGS__, 2, 1, 0) -#define RPY_THROW(...) \ - RPY_INVOKE_VA(RPY_THROW_SEL(RPY_COUNT_ARGS(__VA_ARGS__)), (__VA_ARGS__)) + #ifdef RPY_DEBUG # if defined(RPY_GCC) || defined(RPY_CLANG) diff --git a/intervals/src/dyadic_interval.cpp b/intervals/src/dyadic_interval.cpp index fcaff6b8..a3177436 100644 --- a/intervals/src/dyadic_interval.cpp +++ b/intervals/src/dyadic_interval.cpp @@ -34,6 +34,7 @@ #include #include +#include #include diff --git a/intervals/src/dyadic_searcher.cpp b/intervals/src/dyadic_searcher.cpp index 441f5c1d..c607f2e9 100644 --- a/intervals/src/dyadic_searcher.cpp +++ b/intervals/src/dyadic_searcher.cpp @@ -32,6 +32,8 @@ #include "dyadic_searcher.h" +#include + #include using namespace rpy::intervals; diff --git a/intervals/src/interval.cpp b/intervals/src/interval.cpp index a06496a9..7b8d48fb 100644 --- a/intervals/src/interval.cpp +++ b/intervals/src/interval.cpp @@ -31,6 +31,8 @@ // #include + +#include #include using namespace rpy; diff --git a/intervals/src/partition.cpp b/intervals/src/partition.cpp index 9e6c14b5..abdfb89c 100644 --- a/intervals/src/partition.cpp +++ b/intervals/src/partition.cpp @@ -30,6 +30,8 @@ // #include "partition.h" +#include + #include #include diff --git a/platform/CMakeLists.txt b/platform/CMakeLists.txt index 4e4a7915..b67c266b 100644 --- a/platform/CMakeLists.txt +++ b/platform/CMakeLists.txt @@ -3,15 +3,11 @@ cmake_minimum_required(VERSION 3.21) project(Roughpy_Platform VERSION 0.0.1) - - - - - add_roughpy_component(Platform SOURCES src/configuration.cpp # src/threading/openmp_threading.cpp + src/errors.cpp src/fs_path_serialization.cpp src/devices/host/host_buffer.cpp src/devices/host/host_buffer.h @@ -70,6 +66,7 @@ add_roughpy_component(Platform include/roughpy/platform/serialization.h include/roughpy/platform/threads.h include/roughpy/platform/devices.h + include/roughpy/platform/errors.h include/roughpy/platform/devices/buffer.h include/roughpy/platform/devices/core.h include/roughpy/platform/devices/device_handle.h @@ -103,7 +100,6 @@ add_roughpy_component(Platform RoughPy::Core ) - target_precompile_headers(RoughPy_Platform PUBLIC diff --git a/platform/include/roughpy/platform/devices/core.h b/platform/include/roughpy/platform/devices/core.h index a6607488..0063a5e7 100644 --- a/platform/include/roughpy/platform/devices/core.h +++ b/platform/include/roughpy/platform/devices/core.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include diff --git a/platform/include/roughpy/platform/errors.h b/platform/include/roughpy/platform/errors.h new file mode 100644 index 00000000..a4c9a198 --- /dev/null +++ b/platform/include/roughpy/platform/errors.h @@ -0,0 +1,86 @@ +// +// Created by sam on 10/01/24. +// + +#ifndef ROUGHPY_PLATFORM_ERRORS_H +#define ROUGHPY_PLATFORM_ERRORS_H + +#include +#include + +#include + +/* + * Check macro definition. + * + * This macro checks that the given expression evaluates to true (under the + * assumption that it will usually be true), and throws an error if this + * evaluates to false. + * + * Optionally, one can provide a message string literal that will be used + * instead of the default, and an optional error type. The default error type + * is a std::runtime_error. + */ +namespace rpy { +namespace errors { + + +string format_error_message(string_view user_message, const char* filename, int lineno, const char* func, const boost::stacktrace::stacktrace& st); + +template +RPY_NO_RETURN RPY_INLINE_ALWAYS void throw_exception( + const string& msg, const char* filename, int lineno, const char* func +) +{ + throw E(format_error_message(msg, filename, lineno, func, boost::stacktrace::stacktrace())); +} + +template +RPY_NO_RETURN RPY_INLINE_ALWAYS void throw_exception( + const char* msg, const char* filename, int lineno, const char* func +) +{ + throw E(format_error_message(msg, filename, lineno, func, boost::stacktrace::stacktrace())); +} + +}// namespace errors +}// namespace rpy + +// Dispatch the check macro on the number of arguments +// See: https://stackoverflow.com/a/16683147/9225581 +#define RPY_CHECK_3(EXPR, MSG, TYPE) \ + do { \ + if (RPY_UNLIKELY(!(EXPR))) { \ + ::rpy::errors::throw_exception( \ + MSG, RPY_FILE_NAME, __LINE__, RPY_FUNC_NAME \ + ); \ + } \ + } while (0) + +#define RPY_CHECK_2(EXPR, MSG) RPY_CHECK_3(EXPR, MSG, std::runtime_error) + +#define RPY_CHECK_1(EXPR) RPY_CHECK_2(EXPR, "failed check \"" #EXPR "\"") + +#define RPY_CHECK_CNT_IMPL(_1, _2, _3, COUNT, ...) COUNT +// Always pass one more argument than expected, so clang doesn't complain about +// empty parameter packs. +#define RPY_CHECK_CNT(...) RPY_CHECK_CNT_IMPL(__VA_ARGS__, 3, 2, 1, 0) +#define RPY_CHECK_SEL(NUM) RPY_JOIN(RPY_CHECK_, NUM) + +#define RPY_CHECK(...) \ + RPY_INVOKE_VA(RPY_CHECK_SEL(RPY_COUNT_ARGS(__VA_ARGS__)), (__VA_ARGS__)) + +#define RPY_THROW_2(EXC_TYPE, MSG) \ + ::rpy::errors::throw_exception( \ + MSG, __FILE__, __LINE__, RPY_FUNC_NAME \ + ) +#define RPY_THROW_1(MSG) RPY_THROW_2(std::runtime_error, MSG) + +#define RPY_THROW_SEL(NUM) RPY_JOIN(RPY_THROW_, NUM) +#define RPY_THROW_CNT_IMPL(_1, _2, COUNT, ...) COUNT +#define RPY_THROW_CNT(...) RPY_THROW_CNT_IMPL(__VA_ARGS__, 2, 1, 0) +#define RPY_THROW(...) \ + RPY_INVOKE_VA(RPY_THROW_SEL(RPY_COUNT_ARGS(__VA_ARGS__)), (__VA_ARGS__)) + + +#endif// ROUGHPY_PLATFORM_ERRORS_H diff --git a/platform/src/devices/opencl/ocl_handle_errors.cpp b/platform/src/devices/opencl/ocl_handle_errors.cpp index 952d4068..a26b5d1b 100644 --- a/platform/src/devices/opencl/ocl_handle_errors.cpp +++ b/platform/src/devices/opencl/ocl_handle_errors.cpp @@ -30,6 +30,7 @@ // #include "ocl_handle_errors.h" +#include #include void rpy::devices::cl::handle_cl_error( diff --git a/platform/src/devices/opencl/ocl_version.cpp b/platform/src/devices/opencl/ocl_version.cpp index 362e204f..fb33d861 100644 --- a/platform/src/devices/opencl/ocl_version.cpp +++ b/platform/src/devices/opencl/ocl_version.cpp @@ -8,6 +8,7 @@ #include #include #include +#include using namespace rpy; using namespace devices; diff --git a/platform/src/errors.cpp b/platform/src/errors.cpp new file mode 100644 index 00000000..c1549eeb --- /dev/null +++ b/platform/src/errors.cpp @@ -0,0 +1,23 @@ +// +// Created by sam on 10/01/24. +// + +#include "errors.h" +#include + +using namespace rpy; + +string errors::format_error_message( + string_view user_message, + const char* filename, + int lineno, + const char* func, + const boost::stacktrace::stacktrace& st +) +{ + std::stringstream ss; + ss << user_message << " at lineno " << lineno << " in " << filename + << " in function " << func << '\n' + << st << '\n'; + return ss.str(); +} \ No newline at end of file diff --git a/roughpy/CMakeLists.txt b/roughpy/CMakeLists.txt index 4fa34c16..65eb1cf6 100644 --- a/roughpy/CMakeLists.txt +++ b/roughpy/CMakeLists.txt @@ -111,6 +111,8 @@ target_sources(RoughPy_PyModule PRIVATE src/scalars/parse_key_scalar_stream.h src/scalars/r_py_polynomial.cpp src/scalars/r_py_polynomial.h + # src/scalars/scalar.h + # src/scalars/scalar.cpp src/scalars/scalar_type.h src/scalars/scalar_type.cpp src/scalars/scalars.cpp diff --git a/roughpy/__init__.py b/roughpy/__init__.py index 78a8781e..efc4763c 100644 --- a/roughpy/__init__.py +++ b/roughpy/__init__.py @@ -33,3 +33,6 @@ def _add_dynload_location(path: Path): import roughpy._roughpy from roughpy._roughpy import * + +from . import tensor_functions + diff --git a/roughpy/src/algebra/basis.cpp b/roughpy/src/algebra/basis.cpp index 62e02a83..f1abe618 100644 --- a/roughpy/src/algebra/basis.cpp +++ b/roughpy/src/algebra/basis.cpp @@ -47,13 +47,19 @@ static py::class_ wordlike_basis_setup(py::module_& m, const char* name) py::class_ basis(m, name); -// basis.def(py::init([](deg_t width, deg_t depth) { -// -// })); + // basis.def(py::init([](deg_t width, deg_t depth) { + // + // })); - basis.def_property_readonly("width", &T::width); - basis.def_property_readonly("depth", &T::depth); - basis.def_property_readonly("dimension", &T::dimension); + basis.def_property_readonly("width", [](const T& basis) { + return basis.width(); + }); + basis.def_property_readonly("depth", [](const T& basis) { + return basis.depth(); + }); + basis.def_property_readonly("dimension", [](const T& basis) { + return basis.dimension(); + }); basis.def( "index_to_key", @@ -75,12 +81,12 @@ static py::class_ wordlike_basis_setup(py::module_& m, const char* name) [](const T& self, const K& key) { return self.parents(0); }, "key"_a ); - basis.def("size", &T::size); - - basis.def("__iter__", [](const T& self) { - return KIter(self); + basis.def("size", [](const T& basis, deg_t degree) { + return basis.size(degree); }); + basis.def("__iter__", [](const T& self) { return KIter(self); }); + return basis; } @@ -88,10 +94,9 @@ void python::init_basis(py::module_& m) { wordlike_basis_setup( - m, "TensorBasis" + m, + "TensorBasis" ); - wordlike_basis_setup( - m, "LieBasis" - ); + wordlike_basis_setup(m, "LieBasis"); } diff --git a/roughpy/src/algebra/context.cpp b/roughpy/src/algebra/context.cpp index 2db80a4a..f087b8e5 100644 --- a/roughpy/src/algebra/context.cpp +++ b/roughpy/src/algebra/context.cpp @@ -349,6 +349,36 @@ RPyContext_new(PyObject* self, PyObject* args, PyObject* kwargs) RPY_UNREACHABLE_RETURN(nullptr); } +static PyObject * + RPyContext_richcompare(PyObject * o1, PyObject * o2, int +opid ) { + +switch ( opid ) { +case Py_EQ : +if ( ctx_cast(o1) == ctx_cast(o2)) { +Py_RETURN_TRUE ; +} +Py_RETURN_FALSE; +case Py_NE: +if ( +ctx_cast(o1) +== +ctx_cast(o2) +) { +Py_RETURN_FALSE; +} +Py_RETURN_TRUE; +case Py_LT: +case Py_LE: +case Py_GT: +case Py_GE: +break; +} + +return nullptr; +} + + static const char* CONTEXT_DOC = R"rpydoc()rpydoc"; PyTypeObject rpy::python::RPyContext_Type = { PyVarObject_HEAD_INIT(nullptr, 0) // @@ -374,7 +404,7 @@ PyTypeObject rpy::python::RPyContext_Type = { CONTEXT_DOC, /* tp_doc */ nullptr, /* tp_traverse */ nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ + RPyContext_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ (getiterfunc) nullptr, /* tp_iter */ nullptr, /* tp_iternext */ diff --git a/roughpy/src/algebra/free_tensor.cpp b/roughpy/src/algebra/free_tensor.cpp index 10ee433c..7ce1027f 100644 --- a/roughpy/src/algebra/free_tensor.cpp +++ b/roughpy/src/algebra/free_tensor.cpp @@ -160,6 +160,9 @@ void python::init_free_tensor(py::module_& m) klass.def("__getitem__", [](const FreeTensor& self, key_type key) { return self[key]; }); + klass.def("__getitem__", [](const FreeTensor& self, const PyTensorKey& tkey) { + return self[static_cast(tkey)]; + }); klass.def("exp", &FreeTensor::exp); klass.def("log", &FreeTensor::log); diff --git a/roughpy/src/algebra/lie.cpp b/roughpy/src/algebra/lie.cpp index 8dd3f781..4e370672 100644 --- a/roughpy/src/algebra/lie.cpp +++ b/roughpy/src/algebra/lie.cpp @@ -105,6 +105,9 @@ void python::init_lie(py::module_& m) klass.def("__getitem__", [](const Lie& self, key_type key) { return self[key]; }); + klass.def("__getitem__", [](const Lie& self, const PyLieKey& lkey) { + return self[static_cast(lkey)]; + }); klass.def("__repr__", [](const Lie& self) { std::stringstream ss; diff --git a/roughpy/src/algebra/lie_key.cpp b/roughpy/src/algebra/lie_key.cpp index 0826e996..09fbca91 100644 --- a/roughpy/src/algebra/lie_key.cpp +++ b/roughpy/src/algebra/lie_key.cpp @@ -469,3 +469,40 @@ python::parse_lie_key(const algebra::LieBasis& basis, const py::args& args, cons ToLieKeyHelper helper(basis, basis.width()); return PyLieKey(basis, helper(args[0])); } + +namespace { + +optional to_internal_key_recurse(const algebra::LieBasis& basis, const python::PyLieLetter* base) noexcept +{ + if (!base->is_offset()) { + return static_cast(*base); + } + + optional result; + + auto left = to_internal_key_recurse(basis, base + static_cast(*base)); + auto right = to_internal_key_recurse(basis, ++base); + + if (left && right) { + result = basis.child(*left, *right); + } + + return result; +} + +} + + +python::PyLieKey::operator key_type() const noexcept { + if (m_data.size() == 1) { + return static_cast(m_data[0]); + } + + auto combined = to_internal_key_recurse(m_basis, m_data.data()); + + if (combined) { + return *combined; + } + + return 0; +} diff --git a/roughpy/src/algebra/lie_key.h b/roughpy/src/algebra/lie_key.h index 932540c4..1e5d704f 100644 --- a/roughpy/src/algebra/lie_key.h +++ b/roughpy/src/algebra/lie_key.h @@ -75,6 +75,8 @@ class PyLieKey deg_t degree() const; bool equals(const PyLieKey& other) const noexcept; + + explicit operator key_type() const noexcept; }; PyLieKey diff --git a/roughpy/src/algebra/setup_algebra_type.h b/roughpy/src/algebra/setup_algebra_type.h index b63135a8..fc8b8206 100644 --- a/roughpy/src/algebra/setup_algebra_type.h +++ b/roughpy/src/algebra/setup_algebra_type.h @@ -49,6 +49,7 @@ RPY_MSVC_DISABLE_WARNING(4661) #include #include "args/numpy.h" +#include "context.h" #include "scalars/scalar_type.h" namespace rpy { @@ -86,6 +87,9 @@ void setup_algebra_type(py::class_& klass) klass.def_property_readonly("storage_type", [](const Alg& arg) { return arg.storage_type(); }); + klass.def_property_readonly("context", [](const Alg& arg) { + return py::handle(python::RPyContext_FromContext(arg.context())); + }); // setup dynamic properties klass.def("size", &Alg::size); diff --git a/roughpy/src/algebra/shuffle_tensor.cpp b/roughpy/src/algebra/shuffle_tensor.cpp index 028fb333..4817082b 100644 --- a/roughpy/src/algebra/shuffle_tensor.cpp +++ b/roughpy/src/algebra/shuffle_tensor.cpp @@ -123,6 +123,10 @@ void rpy::python::init_shuffle_tensor(py::module_& m) klass.def("__getitem__", [](const ShuffleTensor& self, key_type key) { return self[key]; }); + klass.def("__getitem__", [](const ShuffleTensor& self, const PyTensorKey& tkey) { + return self[static_cast(tkey)]; + }); + klass.def( "__matmul__", diff --git a/roughpy/src/algebra/tensor_key.cpp b/roughpy/src/algebra/tensor_key.cpp index 1f025d93..c34e95a1 100644 --- a/roughpy/src/algebra/tensor_key.cpp +++ b/roughpy/src/algebra/tensor_key.cpp @@ -26,8 +26,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "tensor_key.h" + #include #include +#include #include #include @@ -40,6 +42,16 @@ python::PyTensorKey::PyTensorKey(algebra::TensorBasis basis, key_type key) : m_key(key), m_basis(std::move(basis)) {} +python::PyTensorKey::PyTensorKey(algebra::TensorBasis basis, Slice letters) + : m_key(0), m_basis(std::move(basis)) +{ + const auto width = m_basis.width(); + for (const auto& letter : letters) { + m_key *= static_cast(width); + m_key += letter; + } +} + python::PyTensorKey::operator key_type() const noexcept { return m_key; } string python::PyTensorKey::to_string() const { @@ -118,6 +130,12 @@ bool python::PyTensorKey::less(const python::PyTensorKey& other) const noexcept { return m_key < other.m_key; } +python::PyTensorKey python::PyTensorKey::reverse() const +{ + auto letters = to_letters(); + std::reverse(letters.begin(), letters.end()); + return {m_basis, letters}; +} static python::PyTensorKey construct_key(const py::args& args, const py::kwargs& kwargs) @@ -181,6 +199,30 @@ construct_key(const py::args& args, const py::kwargs& kwargs) return python::PyTensorKey(ctx->get_tensor_basis(), result); } +python::PyTensorKey python::operator*( + const python::PyTensorKey& lhs, + const python::PyTensorKey& rhs +) +{ + const auto basis = lhs.basis(); + RPY_CHECK(basis == rhs.basis()); + const auto left_letters = lhs.to_letters(); + const auto right_letters = rhs.to_letters(); + + std::vector letters; + letters.reserve(left_letters.size() + right_letters.size()); + auto it = letters.insert(letters.end(), left_letters.begin(), left_letters.end()); + letters.insert(it, right_letters.begin(), right_letters.end()); + + return {basis, letters}; +} + +hash_t python::hash_value(const python::PyTensorKey& key) noexcept { + auto seed = hash_value(key.m_basis); + hash_combine(seed, key.m_key); + return seed; +} + void python::init_py_tensor_key(py::module_& m) { py::class_ klass(m, "TensorKey"); @@ -188,6 +230,7 @@ void python::init_py_tensor_key(py::module_& m) klass.def_property_readonly("width", &PyTensorKey::width); klass.def_property_readonly("max_degree", &PyTensorKey::depth); + klass.def("basis", &PyTensorKey::basis); klass.def("to_index", [](const PyTensorKey& key) { return static_cast(key); @@ -195,10 +238,15 @@ void python::init_py_tensor_key(py::module_& m) klass.def("degree", [](const PyTensorKey& key) { return key.to_letters().size(); }); + klass.def("to_letters", &PyTensorKey::to_letters); klass.def("split_n", &PyTensorKey::split_n, "n"_a); + klass.def("reverse", &PyTensorKey::reverse); klass.def("__str__", &PyTensorKey::to_string); klass.def("__repr__", &PyTensorKey::to_string); - klass.def("__eq__", &PyTensorKey::equals); + + klass.def("__hash__", [](const PyTensorKey& key) { return hash_value(key); }); + + klass.def(py::self * py::self); } diff --git a/roughpy/src/algebra/tensor_key.h b/roughpy/src/algebra/tensor_key.h index 002d4983..433171b1 100644 --- a/roughpy/src/algebra/tensor_key.h +++ b/roughpy/src/algebra/tensor_key.h @@ -31,6 +31,8 @@ #include "roughpy_module.h" +#include +#include #include namespace rpy { @@ -61,25 +63,41 @@ class PyTensorKey public: explicit PyTensorKey(algebra::TensorBasis basis, key_type key); + PyTensorKey(algebra::TensorBasis basis, Slice letters); + explicit operator key_type() const noexcept; - string to_string() const; - PyTensorKey lparent() const; - PyTensorKey rparent() const; - bool is_letter() const; + RPY_NO_DISCARD string to_string() const; + RPY_NO_DISCARD PyTensorKey lparent() const; + RPY_NO_DISCARD PyTensorKey rparent() const; + RPY_NO_DISCARD bool is_letter() const; + + RPY_NO_DISCARD deg_t width() const; + RPY_NO_DISCARD deg_t depth() const; + RPY_NO_DISCARD algebra::TensorBasis basis() const { return m_basis; } + RPY_NO_DISCARD pair split_n(deg_t n) const; - deg_t width() const; - deg_t depth() const; - algebra::TensorBasis basis() const { return m_basis; } - pair split_n(deg_t n) const; + RPY_NO_DISCARD deg_t degree() const; + RPY_NO_DISCARD std::vector to_letters() const; - deg_t degree() const; - std::vector to_letters() const; + RPY_NO_DISCARD PyTensorKey reverse() const; + RPY_NO_DISCARD bool equals(const PyTensorKey& other) const noexcept; + RPY_NO_DISCARD bool less(const PyTensorKey& other) const noexcept; + + + + friend hash_t hash_value(const PyTensorKey& key) noexcept; + }; +hash_t hash_value(const PyTensorKey& key) noexcept; + + +RPY_NO_DISCARD PyTensorKey operator*(const PyTensorKey& lhs, const PyTensorKey& rhs); + void init_py_tensor_key(py::module_& m); }// namespace python diff --git a/roughpy/src/roughpy_module.h b/roughpy/src/roughpy_module.h index 4e577643..9627133c 100644 --- a/roughpy/src/roughpy_module.h +++ b/roughpy/src/roughpy_module.h @@ -36,7 +36,9 @@ #include #include #include + #include +#include #if defined(RPY_GCC) # define RPY_NO_EXPORT __attribute__((visibility("hidden"))) diff --git a/roughpy/src/scalars/scalar.cpp b/roughpy/src/scalars/scalar.cpp new file mode 100644 index 00000000..c47addb8 --- /dev/null +++ b/roughpy/src/scalars/scalar.cpp @@ -0,0 +1,437 @@ +// +// Created by sam on 1/17/24. +// + +#include "scalar.h" +#include "r_py_polynomial.h" +#include "scalar_type.h" +#include "scalars.h" + +#include + +using namespace rpy; +using namespace rpy::python; + +scalars::Scalar pyobject_to_scalar(PyObject* obj); + +void PyScalarProxy::do_conversion() +{ + if (!PyScalar_Check(p_object)) { + m_converted = scalars::Scalar(); + assign_py_object_to_scalar(*m_converted, p_object); + } +} + +namespace { + +template +inline PyObject* +do_arithmetic(PyScalarProxy lhs, PyScalarProxy rhs, Op&& op) noexcept +{ + try { + return PyScalar_FromScalar(op(lhs.ref(), rhs.ref())); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_ArithmeticError, exc.what()); + return nullptr; + } +} + +struct ScalarInplaceAdd { + scalars::Scalar& + operator()(scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs += rhs; + } + + scalars::Scalar + operator()(const scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs + rhs; + } +}; + +struct ScalarInplaceSub { + scalars::Scalar& + operator()(scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs -= rhs; + } + scalars::Scalar + operator()(const scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs - rhs; + } +}; + +struct ScalarInplaceMul { + scalars::Scalar& + operator()(scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs *= rhs; + } + scalars::Scalar + operator()(const scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs * rhs; + } +}; + +struct ScalarInplaceDiv { + scalars::Scalar& + operator()(scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs /= rhs; + } + scalars::Scalar + operator()(const scalars::Scalar& lhs, const scalars::Scalar& rhs) const + { + return lhs / rhs; + } +}; + +template +inline PyObject* +do_inplace_arithmetic(PyScalarProxy lhs, PyScalarProxy rhs, Op&& op) +{ + if (!lhs.is_scalar()) { + return do_arithmetic( + std::move(lhs), + std::move(rhs), + std::forward(op) + ); + } + + try { + op(lhs.mut_ref(), rhs.ref()); + // Py_INCREF(lhs.object()); + // return lhs.object(); + Py_RETURN_NONE; + } catch (std::exception& exc) { + PyErr_SetString(PyExc_ArithmeticError, exc.what()); + return nullptr; + } +} + +template +PyObject* do_compare(PyScalarProxy lhs, PyScalarProxy rhs, Comp&& comp) noexcept +{ + try { + if (comp(lhs.ref(), rhs.ref())) { Py_RETURN_TRUE; } + Py_RETURN_FALSE; + } catch (...) { + Py_RETURN_NOTIMPLEMENTED; + } +} + +}// namespace + +extern "C" { + +static PyObject* PyScalar_type(PyObject* self) +{ + const auto& scalar = cast_pyscalar(self); + auto tp = scalar.type(); + if (tp) { return PyScalarType_FromScalarType(*tp); } + + Py_INCREF(&PyLong_Type); + return reinterpret_cast(&PyLong_Type); +} + +static PyObject* PyScalar_to_float(PyObject* self) +{ + const auto& scalar = cast_pyscalar(self); + + try { + return PyFloat_FromDouble(scalars::scalar_cast(scalar)); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, "unable to convert to float"); + return nullptr; + } +} + +static PyMethodDef PyScalar_methods[] = { + {"type", (PyCFunction) PyScalar_type, METH_METHOD | METH_NOARGS}, + {"to_float", (PyCFunction) PyScalar_to_float, METH_METHOD | METH_NOARGS + }, + {nullptr, nullptr, 0, nullptr} +}; + +static PyObject* PyScalar_add(PyObject* self, PyObject* other) +{ + std::plus op; + return do_arithmetic(self, other, op); +} + +static PyObject* PyScalar_sub(PyObject* self, PyObject* other) +{ + std::minus op; + return do_arithmetic(self, other, op); +} + +static PyObject* PyScalar_mul(PyObject* self, PyObject* other) +{ + std::multiplies op; + return do_arithmetic(self, other, op); +} + +static PyObject* +PyScalar_rem(PyObject* RPY_UNUSED_VAR self, PyObject* RPY_UNUSED_VAR other) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* PyScalar_pow( + PyObject* RPY_UNUSED_VAR self, + PyObject* RPY_UNUSED_VAR other, + PyObject* RPY_UNUSED_VAR modulo +) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* PyScalar_neg(PyObject* self) +{ + try { + return PyScalar_FromScalar(-cast_pyscalar(self)); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_ArithmeticError, exc.what()); + return nullptr; + } +} + +static PyObject* PyScalar_pos(PyObject* self) +{ + try { + return PyScalar_FromScalar(scalars::Scalar(cast_pyscalar(self))); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_ArithmeticError, exc.what()); + return nullptr; + } +} + +static int PyScalar_bool(PyObject* self) +{ + return cast_pyscalar(self).is_zero() ? 0 : 1; +} + +static PyObject* PyScalar_inplace_add(PyObject* self, PyObject* other) +{ + ScalarInplaceAdd add; + return do_inplace_arithmetic(self, other, add); +} + +static PyObject* PyScalar_inplace_sub(PyObject* self, PyObject* other) +{ + ScalarInplaceSub sub; + return do_inplace_arithmetic(self, other, sub); +} + +static PyObject* PyScalar_inplace_mul(PyObject* self, PyObject* other) +{ + ScalarInplaceMul mul; + return do_inplace_arithmetic(self, other, mul); +} + +static PyObject* PyScalar_inplace_rem( + PyObject* RPY_UNUSED_VAR self, + PyObject* RPY_UNUSED_VAR other +) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* PyScalar_inplace_pow( + PyObject* RPY_UNUSED_VAR self, + PyObject* RPY_UNUSED_VAR other, + PyObject* RPY_UNUSED_VAR modulo +) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* +PyScalar_floordiv(PyObject* RPY_UNUSED_VAR self, PyObject* RPY_UNUSED_VAR other) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* PyScalar_div(PyObject* self, PyObject* other) +{ + ScalarInplaceDiv div; + return do_inplace_arithmetic(self, other, div); +} + +static PyObject* PyScalar_inplace_floordiv( + PyObject* RPY_UNUSED_VAR self, + PyObject* RPY_UNUSED_VAR other +) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* PyScalar_inplace_div( + PyObject* RPY_UNUSED_VAR self, + PyObject* RPY_UNUSED_VAR other +) +{ + Py_RETURN_NOTIMPLEMENTED; +} + +static PyNumberMethods PyScalar_number{ + (binaryfunc) PyScalar_add, /* nb_add */ + (binaryfunc) PyScalar_sub, /* nb_subtract */ + (binaryfunc) PyScalar_mul, /* nb_multiply */ + (binaryfunc) PyScalar_rem, /* nb_remainder */ + nullptr, /* nb_divmod */ + (ternaryfunc) PyScalar_pow, /* nb_power */ + PyScalar_neg, /* nb_negative */ + PyScalar_pos, /* nb_positive */ + nullptr, /* nb_absolute */ + (inquiry) PyScalar_bool, /* nb_bool */ + nullptr, /* nb_invert */ + nullptr, /* nb_lshift */ + nullptr, /* nb_rshift */ + nullptr, /* nb_and */ + nullptr, /* nb_xor */ + nullptr, /* nb_or */ + nullptr, /* nb_int */ + nullptr, /* nb_reserved */ + nullptr, /* nb_float */ + (binaryfunc) PyScalar_inplace_add, /* nb_inplace_add */ + (binaryfunc) PyScalar_inplace_sub, /* nb_inplace_subtract */ + (binaryfunc) PyScalar_inplace_mul, /* nb_inplace_multiply */ + (binaryfunc) PyScalar_inplace_rem, /* nb_inplace_remainder */ + (ternaryfunc) PyScalar_inplace_pow, /* nb_inplace_power */ + nullptr, /* nb_inplace_lshift */ + nullptr, /* nb_inplace_rshift */ + nullptr, /* nb_inplace_and */ + nullptr, /* nb_inplace_xor */ + nullptr, /* nb_inplace_or */ + (binaryfunc) PyScalar_floordiv, /* nb_floor_divide */ + (binaryfunc) PyScalar_div, /* nb_true_divide */ + (binaryfunc) PyScalar_inplace_floordiv, /* nb_inplace_floor_divide */ + (binaryfunc) PyScalar_inplace_div, /* nb_inplace_true_div */ + nullptr, /* nb_index */ + nullptr, /* nb_matrix_multiply */ + nullptr /* nb_inplace_matrix_multiply */ +}; + +static void PyScalar_finalize(PyObject* self) +{ + cast_pyscalar_mut(self).~Scalar(); +} + +static PyObject* PyScalar_repr(PyObject* self) +{ + try { + std::stringstream ss; + ss << cast_pyscalar(self); + return PyUnicode_FromString(ss.str().c_str()); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, exc.what()); + return nullptr; + } +} + +static PyObject* PyScalar_str(PyObject* self) +{ + try { + std::stringstream ss; + ss << cast_pyscalar(self); + return PyUnicode_FromString(ss.str().c_str()); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, exc.what()); + return nullptr; + } +} + +static PyObject* +PyScalar_richcmp(PyObject* self, PyObject* other, const int optype) +{ + switch (optype) { + case Py_EQ: + return do_compare(self, other, std::equal_to()); + case Py_NE: + return do_compare( + self, + other, + std::not_equal_to() + ); + case Py_LT: + case Py_LE: + case Py_GT: + case Py_GE: + default: break; + } + Py_RETURN_NOTIMPLEMENTED; +} + +PyTypeObject rpy::python::PyScalar_Type = { + PyVarObject_HEAD_INIT(nullptr, 0)// + "roughpy.Scalar", /* tp_name */ + sizeof(PyScalar), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_as_async */ + PyScalar_repr, /* tp_repr */ + &PyScalar_number, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + PyScalar_str, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + PyDoc_STR("My objects"), /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + PyScalar_richcmp, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ + nullptr, /* tp_free */ + nullptr, /* tp_is_gc */ + nullptr, /* tp_bases */ + nullptr, /* tp_mro */ + nullptr, /* tp_cache */ + nullptr, /* tp_subclasses */ + nullptr, /* tp_weaklist */ + nullptr, /* tp_del */ + 0, /* tp_version_tag */ + PyScalar_finalize, /* tp_finalize */ + nullptr /* tp_vectorcall */ +}; +} + +scalars::Scalar pyobject_to_scalar(PyObject* obj) +{ + scalars::Scalar result; + + if (RPyPolynomial_Check(obj)) { + result = RPyPolynomial_cast(obj); + } else if (PyFloat_CheckExact(obj)) { + result = PyFloat_AsDouble(obj); + } else if (PyLong_CheckExact(obj)) { + result = PyLong_AsLongLong(obj); + } else { + result = PyFloat_AsDouble(obj); + } + + return result; +} diff --git a/roughpy/src/scalars/scalar.h b/roughpy/src/scalars/scalar.h new file mode 100644 index 00000000..40f72897 --- /dev/null +++ b/roughpy/src/scalars/scalar.h @@ -0,0 +1,80 @@ +// +// Created by sam on 1/17/24. +// + +#ifndef ROUGHPY_SCALAR_H +#define ROUGHPY_SCALAR_H + +#include "roughpy_module.h" + +#include + +namespace rpy { +namespace python { + +extern "C" PyTypeObject PyScalar_Type; + +struct PyScalar { + PyObject_HEAD scalars::Scalar m_content; +}; + +inline bool PyScalar_Check(PyObject* obj) noexcept +{ + return Py_TYPE(obj) == &PyScalar_Type; +} + +inline scalars::Scalar& cast_pyscalar_mut(PyObject* obj) noexcept +{ + return reinterpret_cast(obj)->m_content; +} + +inline const scalars::Scalar& cast_pyscalar(PyObject* obj) noexcept +{ + return reinterpret_cast(obj)->m_content; +} + +class PyScalarProxy +{ + optional m_converted; + PyObject* p_object; + bool m_needs_conversion; + + void do_conversion(); + +public: + PyScalarProxy(PyObject* obj) + : p_object(obj), + m_needs_conversion(PyScalar_Check(obj)) + {} + + bool is_scalar() const noexcept { return !m_needs_conversion; } + + const scalars::Scalar& ref() + { + if (m_needs_conversion) { + if (!m_converted) { do_conversion(); } + RPY_DBG_ASSERT(m_converted); + return *m_converted; + } + return reinterpret_cast(p_object)->m_content; + } + + scalars::Scalar& mut_ref() + { + if (m_needs_conversion) { + if (!m_converted) { do_conversion(); } + RPY_DBG_ASSERT(m_converted); + return *m_converted; + } + return reinterpret_cast(p_object)->m_content; + } + + PyObject* object() const noexcept { return p_object; } +}; + +// PyObject* PyScalar_FromScalar(scalars::Scalar&& obj); + +}// namespace python +}// namespace rpy + +#endif// ROUGHPY_SCALAR_H diff --git a/roughpy/src/scalars/scalar_type.cpp b/roughpy/src/scalars/scalar_type.cpp index 4e7652c5..fcb00640 100644 --- a/roughpy/src/scalars/scalar_type.cpp +++ b/roughpy/src/scalars/scalar_type.cpp @@ -28,6 +28,9 @@ #include "scalar_type.h" +// #include "scalar.h" +#include "scalars.h" + #include @@ -43,11 +46,109 @@ static PyMethodDef PyScalarMetaType_methods[] = { {nullptr, nullptr, 0, nullptr} }; -PyObject* PyScalarMetaType_call(PyObject*, PyObject*, PyObject*) +extern "C" PyObject* +PyScalarMetaType_call(PyObject* self, PyObject* args, PyObject* kwargs) { - // TODO: implement this? - PyErr_SetString(PyExc_AssertionError, "doh"); - return nullptr; + using namespace python; + auto* type = reinterpret_cast(self); + const auto* stype = reinterpret_cast(type)->tp_ctype; + + static const char* kwords[] = {"data", "numerator", "denominator"}; + + PyObject* data = nullptr; + PyObject* numerator = nullptr; + PyObject* denominator = nullptr; + + if (PyArg_ParseTupleAndKeywords( + args, + kwargs, + "|OOO", + const_cast(kwords), + &data, + &numerator, + &denominator + ) + == 0) { + return nullptr; + } + + scalars::Scalar result(stype); + if (data != nullptr) { + if (numerator != nullptr) { + if (denominator != nullptr) { + PyErr_SetString( + PyExc_ValueError, + "cannot specify data and numerator/denominator" + ); + return nullptr; + } + denominator = numerator; + numerator = data; + data = nullptr; + } else if (denominator != nullptr) { + numerator = data; + data = nullptr; + } + } else if (numerator != nullptr) { + if (denominator == nullptr) { + PyErr_SetString(PyExc_ValueError, "no denominator provided"); + return nullptr; + } + } else if (denominator != nullptr) { + PyErr_SetString(PyExc_ValueError, "no numerator provided"); + return nullptr; + } else { + // No arguments means zero + + try { + return py::cast(std::move(result)).release().ptr(); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, exc.what()); + return nullptr; + } + } + + if (data != nullptr) { + try { + assign_py_object_to_scalar(result, data); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, exc.what()); + return nullptr; + } + } else { + if (PyLong_Check(numerator) == 0) { + PyErr_SetString(PyExc_TypeError, "numerator should be an int"); + return nullptr; + } + if (PyLong_Check(denominator) == 0) { + PyErr_SetString(PyExc_TypeError, "denominator should be an int"); + return nullptr; + } + + auto num = PyLong_AsLongLong(numerator); + auto denom = PyLong_AsLongLong(denominator); + + try { + result = scalars::Scalar(stype, num, denom); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, exc.what()); + return nullptr; + } + } + + // auto* obj = reinterpret_cast(type->tp_alloc(type, 0)); + // if (obj == nullptr) { return nullptr; } + // + // construct_inplace(&obj->m_content); + + try { + return py::cast(std::move(result)).release().ptr(); + } catch (std::exception& exc) { + PyErr_SetString(PyExc_RuntimeError, exc.what()); + return nullptr; + } + + // return reinterpret_cast(obj); } static PyTypeObject PyScalarMetaType_type = { PyVarObject_HEAD_INIT( @@ -243,50 +344,50 @@ char python::format_to_type_char(const string& fmt) after_loop: return python_format; } -//string python::py_buffer_to_type_id(const py::buffer_info& info) +// string python::py_buffer_to_type_id(const py::buffer_info& info) //{ -// using scalars::type_id_of; +// using scalars::type_id_of; // -// auto python_format = format_to_type_char(info.format); -// string format; -// switch (python_format) { -// case 'd': format = type_id_of(); break; -// case 'f': format = type_id_of(); break; -// case 'l': { -// if (info.itemsize == sizeof(int)) { -// format = type_id_of(); -// } else { -// format = type_id_of(); -// } -// break; -// } -// case 'q': format = scalars::type_id_of(); break; -// case 'L': -// if (info.itemsize == sizeof(int)) { -// format = type_id_of(); -// } else { -// format = type_id_of(); -// } -// break; -// case 'Q': format = type_id_of(); break; -// case 'i': format = type_id_of(); break; -// case 'I': format = type_id_of(); break; -// case 'n': -// format = type_id_of(); -// break; -// case 'N': -// format = type_id_of(); -// break; -// case 'h': format = type_id_of(); break; -// case 'H': format = type_id_of(); break; -// case 'b': -// case 'c': format = type_id_of(); break; -// case 'B': format = type_id_of(); break; -// default: RPY_THROW(std::runtime_error, "Unrecognised data format"); -// } +// auto python_format = format_to_type_char(info.format); +// string format; +// switch (python_format) { +// case 'd': format = type_id_of(); break; +// case 'f': format = type_id_of(); break; +// case 'l': { +// if (info.itemsize == sizeof(int)) { +// format = type_id_of(); +// } else { +// format = type_id_of(); +// } +// break; +// } +// case 'q': format = scalars::type_id_of(); break; +// case 'L': +// if (info.itemsize == sizeof(int)) { +// format = type_id_of(); +// } else { +// format = type_id_of(); +// } +// break; +// case 'Q': format = type_id_of(); break; +// case 'i': format = type_id_of(); break; +// case 'I': format = type_id_of(); break; +// case 'n': +// format = type_id_of(); +// break; +// case 'N': +// format = type_id_of(); +// break; +// case 'h': format = type_id_of(); break; +// case 'H': format = type_id_of(); break; +// case 'b': +// case 'c': format = type_id_of(); break; +// case 'B': format = type_id_of(); break; +// default: RPY_THROW(std::runtime_error, "Unrecognised data format"); +// } // -// return format; -//} +// return format; +// } devices::TypeInfo python::py_buffer_to_type_info(const py::buffer_info& info) { @@ -367,7 +468,10 @@ py::type python::scalar_type_to_py_type(const scalars::ScalarType* type) ); } - RPY_THROW(py::type_error, "no matching type for type " + string(type->name())); + RPY_THROW( + py::type_error, + "no matching type for type " + string(type->name()) + ); } void python::init_scalar_types(pybind11::module_& m) @@ -409,3 +513,21 @@ void python::init_scalar_types(pybind11::module_& m) *scalars::ScalarType::of() ); } + +PyObject* python::PyScalarType_FromScalarType(const scalars::ScalarType* type) +{ + if (type == nullptr) { + PyErr_SetString(PyExc_RuntimeError, "invalid scalar type"); + return nullptr; + } + + const auto found = ctype_type_cache.find(type); + if (found != ctype_type_cache.end()) { + return py::reinterpret_borrow(found->second) + .release() + .ptr(); + } + + PyErr_SetString(PyExc_RuntimeError, "unregistered scalar type"); + return nullptr; +} diff --git a/roughpy/src/scalars/scalar_type.h b/roughpy/src/scalars/scalar_type.h index 12187eb5..48ea391f 100644 --- a/roughpy/src/scalars/scalar_type.h +++ b/roughpy/src/scalars/scalar_type.h @@ -73,6 +73,10 @@ py::handle get_scalar_baseclass(); extern "C" void PyScalarMetaType_dealloc(PyObject* arg); void register_scalar_type(const scalars::ScalarType* ctype, py::handle py_type); + +extern "C" PyObject* PyScalarType_FromScalarType(const scalars::ScalarType* type +); + py::object to_ctype_type(const scalars::ScalarType* type); inline void make_scalar_type(py::module_& m, const scalars::ScalarType* ctype) diff --git a/roughpy/src/scalars/scalars.cpp b/roughpy/src/scalars/scalars.cpp index d34ee3e8..b83d1b96 100644 --- a/roughpy/src/scalars/scalars.cpp +++ b/roughpy/src/scalars/scalars.cpp @@ -39,6 +39,7 @@ #include "dlpack.h" #include "r_py_polynomial.h" +// #include "scalar.h" #include "scalar_type.h" #include "args/dlpack_helpers.h" @@ -62,10 +63,12 @@ void python::assign_py_object_to_scalar( dst = object.cast(); } else if (RPyPolynomial_Check(object.ptr())) { dst = RPyPolynomial_cast(object.ptr()); + } else if (py::isinstance(object)) { + dst = object.cast(); } else { // TODO: other checks - auto tp = py::type::of(object); + auto tp = py::str(py::type::of(object)); RPY_THROW( py::value_error, "bad conversion from " + tp.cast() + " to " @@ -82,6 +85,11 @@ void python::init_scalars(pybind11::module_& m) python::init_scalar_types(m); python::init_monomial(m); + // + // if (PyType_Ready(&PyScalar_Type) < 0) { + // throw py::error_already_set(); + // } + // m.add_object("Scalar", (PyObject*) &PyScalar_Type); py::class_ klass(m, "Scalar", SCALAR_DOC); @@ -93,25 +101,142 @@ void python::init_scalars(pybind11::module_& m) RPY_CLANG_DISABLE_WARNING(-Wself-assign-overloaded) klass.def(-py::self); - klass.def(py::self + py::self); - klass.def(py::self - py::self); - klass.def(py::self * py::self); - klass.def(py::self / py::self); + // klass.def(py::self + py::self); + // klass.def(py::self - py::self); + // klass.def(py::self * py::self); + // klass.def(py::self / py::self); + // + // klass.def(py::self += py::self); + // klass.def(py::self -= py::self); + // klass.def(py::self *= py::self); + // klass.def(py::self /= py::self); + + // klass.def(py::self == py::self); + // klass.def(py::self != py::self); + + klass.def("__add__", [](const Scalar& self, const Scalar& other) { + return self + other; + }); + klass.def("__add__", [](const Scalar& self, scalar_t other) { + return self + Scalar(other); + }); + klass.def("__add__", [](const Scalar& self, long long other) { + return self + Scalar(other); + }); + klass.def("__radd__", [](const Scalar& self, scalar_t other) { + return Scalar(other) + self; + }); + klass.def("__radd__", [](const Scalar& self, long long other) { + return Scalar(other) + self; + }); + klass.def("__iadd__", [](Scalar& self, const Scalar& other) { + return self += other; + }); + klass.def("__iadd__", [](Scalar& self, scalar_t other) { + return self += Scalar(other); + }); + klass.def("__iadd__", [](Scalar& self, long long other) { + return self += Scalar(other); + }); + + klass.def("__sub__", [](const Scalar& self, scalar_t other) { + return self - Scalar(other); + }); + klass.def("__sub__", [](const Scalar& self, long long other) { + return self - Scalar(other); + }); + klass.def("__rsub__", [](const Scalar& self, scalar_t other) { + return Scalar(other) - self; + }); + klass.def("__rsub__", [](const Scalar& self, long long other) { + return Scalar(other) - self; + }); + klass.def("__isub__", [](Scalar& self, scalar_t other) { + return self -= Scalar(other); + }); + klass.def("__isub__", [](Scalar& self, long long other) { + return self -= Scalar(other); + }); - klass.def(py::self += py::self); - klass.def(py::self -= py::self); - klass.def(py::self *= py::self); - klass.def(py::self /= py::self); + klass.def("__mul__", [](const Scalar& self, const Scalar& other) { + return self * other; + }); + klass.def("__mul__", [](const Scalar& self, scalar_t other) { + return self * Scalar(other); + }); + klass.def("__mul__", [](const Scalar& self, long long other) { + return self * Scalar(other); + }); + klass.def("__rmul__", [](const Scalar& self, scalar_t other) { + return Scalar(other) * self; + }); + klass.def("__rmul__", [](const Scalar& self, long long other) { + return Scalar(other) * self; + }); + klass.def("__imul__", [](Scalar& self, const Scalar& other) { + return self *= other; + }); + klass.def("__imul__", [](Scalar& self, scalar_t other) { + return self *= Scalar(other); + }); + klass.def("__imul__", [](Scalar& self, long long other) { + return self *= Scalar(other); + }); - klass.def(py::self == py::self); - klass.def(py::self != py::self); + klass.def("__div__", [](const Scalar& self, scalar_t other) { + if (other == 0.0) { throw py::value_error("division by zero"); } + return self / Scalar(other); + }); + klass.def("__div__", [](const Scalar& self, long long other) { + if (other == 0) { throw py::value_error("division by zero"); } + return self / Scalar(other); + }); + klass.def("__rdiv__", [](const Scalar& self, scalar_t other) { + return Scalar(other) / self; + }); + klass.def("__rdiv__", [](const Scalar& self, long long other) { + return Scalar(other) / self; + }); + klass.def("__idiv__", [](Scalar& self, scalar_t other) { + if (other == 0.0) { throw py::value_error("division by zero"); } + return self /= Scalar(other); + }); + klass.def("__idiv__", [](Scalar& self, long long other) { + if (other == 0) { throw py::value_error("division by zero"); } + return self /= Scalar(other); + }); + klass.def("__eq__", [](const Scalar& lhs, const Scalar& rhs) { + return lhs == rhs; + }); klass.def("__eq__", [](const Scalar& self, scalar_t other) { return self == Scalar(other); }); klass.def("__eq__", [](scalar_t other, const Scalar& self) { return self == Scalar(other); }); + klass.def("__eq__", [](const Scalar& self, long long other) { + return self == Scalar(other); + }); + klass.def("__eq__", [](long long other, const Scalar& self) { + return self == Scalar(other); + }); + + klass.def("__ne__", [](const Scalar& lhs, const Scalar& rhs) { + return lhs != rhs; + }); + klass.def("__ne__", [](const Scalar& self, scalar_t other) { + return self != Scalar(other); + }); + klass.def("__ne__", [](scalar_t other, const Scalar& self) { + return self != Scalar(other); + }); + klass.def("__ne__", [](const Scalar& self, long long other) { + return self != Scalar(other); + }); + klass.def("__ne__", [](long long other, const Scalar& self) { + return self != Scalar(other); + }); klass.def("__str__", [](const Scalar& self) { std::stringstream ss; @@ -120,8 +245,20 @@ void python::init_scalars(pybind11::module_& m) }); klass.def("__repr__", [](const Scalar& self) { std::stringstream ss; - ss << "Scalar(type=" << string((*self.type())->name()) - << ", value approx " << self << ")"; + ss << "Scalar(type="; + auto tp = self.type(); + if (tp) { + ss << string((*tp)->name()); + } else { + auto info = self.type_info(); + if (info.code == devices::TypeCode::Int) { + ss << "int" << CHAR_BIT * info.bytes; + } else if (info.code == devices::TypeCode::UInt) { + ss << "uint" << CHAR_BIT * info.bytes; + } + } + ss << ", value approx " << self << ")"; + return ss.str(); }); @@ -261,7 +398,9 @@ static bool try_fill_buffer_dlpack( const auto tensor_stype = scalars::scalar_type_of(tensor_stype_info); if (options.type == nullptr) { - if (tensor_stype) { options.type = *tensor_stype; } else { + if (tensor_stype) { + options.type = *tensor_stype; + } else { options.type = scalars::ScalarType::for_info(tensor_stype_info); } } @@ -312,7 +451,16 @@ static void check_and_set_dtype(python::PyToBufferOptions& options, py::handle arg) { if (options.type == nullptr) { - if (options.no_check_imported) { + + if (py::isinstance(arg)) { + const auto& scal = arg.cast(); + auto arg_type = scal.type(); + if (arg_type) { + options.type = *arg_type; + } else { + options.type = scalars::ScalarType::for_info(scal.type_info()); + } + } else if (options.no_check_imported) { options.type = *scalars::ScalarType::of(); } else { options.type = python::py_type_to_scalar_type(py::type::of(arg)); @@ -653,7 +801,8 @@ scalars::KeyScalarArray python::py_to_buffer( } } } - } else if (object.is_none()) {} else { + } else if (object.is_none()) { + } else { RPY_THROW( std::invalid_argument, "could not parse argument to a valid scalar array type" diff --git a/roughpy/tensor_functions.py b/roughpy/tensor_functions.py new file mode 100644 index 00000000..0142e746 --- /dev/null +++ b/roughpy/tensor_functions.py @@ -0,0 +1,281 @@ +""" +This module provides additional tensor functions that are not yet implemented in +the RoughPy core library. + +The implementation of the functions in this module are far from optimal, but +should serve as both a temporary implementation and a demonstration of how one +can build on top of the RoughPy core library. +""" + + +import functools +from collections import defaultdict +from typing import Any, Union, TypeVar + +import roughpy as rp + + +def tensor_word_factory(basis): + """ + Create a factory function that constructs tensor words objects. + + Since the tensor words are specific to their basis, the basis is needed to + construct the words. This function creates a factory from the correct basis + to make tensor words that correspond. The arguments to the factory function + are just a sequence of letters in the same order as they will appear in the + tensor word. + + :param basis: RoughPy tensor basis object. + :return: function from a sequence of letters to tensor words + """ + width = basis.width + depth = basis.depth + + # noinspection PyArgumentList + def factory(*args): + return rp.TensorKey(*args, width=width, depth=depth) + + return factory + + +class TensorTensorProduct: + """ + External tensor product of two free tensors (or shuffle tensors). + + This is an intermediate container that is used to implement some of the + tensor functions such as Log. + """ + + data: dict[tuple[rp.TensorKey], Any] + ctx: rp.Context + + def __init__(self, data, ctx=None): + + if isinstance(data, (tuple, list)): + assert len(data) == 2 + + assert isinstance(data[0], (rp.FreeTensor, rp.ShuffleTensor)) + assert isinstance(data[1], (rp.FreeTensor, rp.ShuffleTensor)) + + if ctx is not None: + self.ctx = ctx + assert data[0].context == ctx + assert data[1].context == ctx + else: + self.ctx = data[0].context + assert self.ctx == data[2].context + + self.data = odata = {} + for lhs in data[0]: + for rhs in data[1]: + odata[(lhs.key(), rhs.key())] = lhs.value() * rhs.value() + + # self.data = {k: v for k, v in odata.items() if v != 0} + + elif isinstance(data, (dict, defaultdict)): + self.data = {k: v for k, v in data.items() if v != 0} + assert ctx is not None + self.ctx = ctx + + def __str__(self): + return " ".join([ + '{', + *(f"{v}{k}" for k, v in sorted(self.data.items(), + key=lambda t: tuple( + map(lambda r: tuple( + r.to_letters()), t[0])))), + '}' + ]) + + def __repr__(self): + return " ".join([ + '{', + *(f"{v}{k}" for k, v in sorted(self.data.items(), + key=lambda t: tuple( + map(lambda r: tuple( + r.to_letters()), t[0])))), + '}' + ]) + + def __mul__(self, scalar): + if isinstance(scalar, (int, float, rp.Scalar)): + if scalar == 0: + data = {} + else: + data = {k: v * scalar for k, v in self.data.items()} + return TensorTensorProduct(data, self.ctx) + + return NotImplemented + + def __add__(self, other): + if isinstance(other, TensorTensorProduct): + assert self.ctx == other.ctx + new_data = defaultdict(lambda: 0) + + # Do this to force a deep copy of the values from self + for k, v in self.data.items(): + new_data[k] += v + + for k, v in other.data.items(): + new_data[k] += v + + return TensorTensorProduct( + {k: v for k, v in new_data.items() if v != 0}, self.ctx) + + return NotImplemented + + def add_scal_prod(self, other, scalar): + my_data = self.data + for k, v in other.data.items(): + + val = v * scalar + if k in self.data: + my_data[k] += val + else: + my_data[k] = val + + # print("asp", other, scalar, self) + self.data = {k: v for k, v in self.data.items() if v != 0} + return self + + def add_scal_div(self, other, scalar): + my_data = self.data + for k, v in other.data.items(): + if k in self.data: + my_data[k] += v / scalar + else: + my_data[k] = v / scalar + + return self + + def sub_scal_div(self, other, scalar): + my_data = self.data + for k, v in other.data.items(): + if k in self.data: + my_data[k] -= v / scalar + else: + my_data[k] = -v / scalar + + return self + + +def _concat_product(a, b): + out = defaultdict(lambda: 0) + + for k1, v1 in a.data.items(): + for k2, v2 in b.data.items(): + out[tuple(i * j for i, j in zip(k1, k2))] += v1 * v2 + + return TensorTensorProduct({k: v for k, v in out.items() if v != 0}, a.ctx) + + +# noinspection PyUnresolvedReferences +def _adjoint_of_word(word: rp.TensorKey, ctx: rp.Context) \ + -> TensorTensorProduct: + word_factory = tensor_word_factory(word.basis()) + letters = word.to_letters() + if not letters: + return TensorTensorProduct({(word_factory(),) * 2: 1}, ctx) + + letters_adj = [ + TensorTensorProduct({(word_factory(letter), word_factory()): 1, + (word_factory(), word_factory(letter)): 1}, ctx) + for letter in word.to_letters()] + return functools.reduce(_concat_product, letters_adj) + + +def _adjoint_of_shuffle( + tensor: Union[rp.FreeTensor, rp.ShuffleTensor]) -> TensorTensorProduct: + # noinspection PyUnresolvedReferences + ctx = tensor.context + out = TensorTensorProduct(defaultdict(lambda: 0), ctx) + + for item in tensor: + out.add_scal_prod(_adjoint_of_word(item.key(), ctx), item.value()) + + return out + + +def _concatenate(a: TensorTensorProduct, otype=rp.FreeTensor): + """ + Perform an elementwise reduction induced on A \\otimes B by the + concatenation of words. + + :param a: External tensor product of tensors + :return: tensor obtained by reducing all pairs of words + """ + + data = defaultdict(lambda: 0) + for (l, r), v in a.data.items(): + data[l * r] += v + + # noinspection PyArgumentList + result = otype(data, ctx=a.ctx) + return result + + +def _tensor_product_functions(f, g): + # noinspection PyArgumentList + def function_product(x: TensorTensorProduct) -> TensorTensorProduct: + ctx = x.ctx + + result = TensorTensorProduct({}, ctx) + for (k1, k2), v in x.data.items(): + tk1 = f(rp.FreeTensor((k1, 1), ctx=ctx)) + tk2 = g(rp.FreeTensor((k2, 1), ctx=ctx)) + result.add_scal_prod(TensorTensorProduct((tk1, tk2), ctx), v) + + return result + + return function_product + + +def _convolve(f, g): + func = _tensor_product_functions(f, g) + + def convolved(x): + return _concatenate(func(_adjoint_of_shuffle(x)), otype=type(x)) + + return convolved + + +def _remove_constant(x): + ctx = x.context + # noinspection PyArgumentList + empty_word = rp.TensorKey(width=ctx.width, depth=ctx.depth) + remover = type(x)((empty_word, x[empty_word]), ctx=ctx) + return x - remover + + +Tensor = TypeVar('Tensor') + + +# noinspection PyPep8Naming +def Log(x: Tensor) -> Tensor: + """ + Linear function on tensors that agrees with log on the group-like elements. + + This function is the linear extension of the log function defined on the + group-like elements of the free tensor algebra (or the corresponding subset + of the shuffle tensor algebra) to the whole algebra. This implementation is + far from optimal. + + :param x: Tensor (either a shuffle tensor or free tensor) + :return: Log(x) with the same type as the input. + """ + ctx = x.context + fn = _remove_constant + + out = fn(x) + sign = False + for i in range(2, ctx.depth + 1): + sign = not sign + fn = _convolve(_remove_constant, fn) + + tmp = fn(x) / i + if sign: + out -= tmp + else: + out += tmp + + return out diff --git a/scalars/CMakeLists.txt b/scalars/CMakeLists.txt index 1da4f339..fe2d3567 100644 --- a/scalars/CMakeLists.txt +++ b/scalars/CMakeLists.txt @@ -40,6 +40,9 @@ add_roughpy_component(Scalars src/scalar/raw_bytes.cpp src/scalar/raw_bytes.h src/scalar/serialization.cpp + src/scalar/serialization.h + src/scalar/type_promotion.cpp + src/scalar/type_promotion.h src/scalar_helpers/standard_scalar_type.h src/types/aprational/ap_rational_type.cpp src/types/aprational/ap_rational_type.h diff --git a/scalars/include/roughpy/scalars/scalar.h b/scalars/include/roughpy/scalars/scalar.h index 41b6e06c..51115aeb 100644 --- a/scalars/include/roughpy/scalars/scalar.h +++ b/scalars/include/roughpy/scalars/scalar.h @@ -136,13 +136,16 @@ content_type_of(PackedScalarTypePointer ptype) noexcept return content_type_of(ptype.get_type_info()); } - template -struct can_be_scalar : conditional_t::value, std::true_type, std::false_type> {}; +struct can_be_scalar : conditional_t< + !is_base_of::value, + std::true_type, + std::false_type> { +}; template -struct can_be_scalar> : std::false_type {}; - +struct can_be_scalar> : std::false_type { +}; }// namespace dtl @@ -206,20 +209,18 @@ class ROUGHPY_SCALARS_EXPORT Scalar template ::value>> explicit Scalar(const T& value) : p_type_and_content_type( - devices::type_info>(), - dtl::ScalarContentType::TrivialBytes - ), + devices::type_info>(), + dtl::ScalarContentType::TrivialBytes + ), integer_for_convenience(0) { - if constexpr (is_standard_layout::value - && is_trivially_copyable::value - && is_trivially_destructible::value - && sizeof(T) <= sizeof(void*)) { + if constexpr (is_standard_layout::value && is_trivially_copyable::value && is_trivially_destructible::value && sizeof(T) <= sizeof(void*)) { std::memcpy(trivial_bytes, &value, sizeof(T)); } else { allocate_data(); construct_inplace(static_cast(opaque_pointer), value); } + RPY_DBG_ASSERT(!p_type_and_content_type.is_null()); } template < @@ -277,12 +278,14 @@ class ROUGHPY_SCALARS_EXPORT Scalar explicit Scalar(devices::TypeInfo info, void* ptr); explicit Scalar(devices::TypeInfo info, const void* ptr); - template ::value>> + template < + typename I, + typename = enable_if_t::value>> explicit Scalar(std::unique_ptr&& iface) : p_type_and_content_type( - nullptr, - dtl::ScalarContentType::OwnedInterface - ), + nullptr, + dtl::ScalarContentType::OwnedInterface + ), interface(std::move(iface)) {} @@ -318,7 +321,7 @@ class ROUGHPY_SCALARS_EXPORT Scalar enable_if_t::value, Scalar&> operator=(const T& value ) { - if (fast_is_zero()) { + if (p_type_and_content_type.is_null()) { construct_inplace(this, value); } else { switch (p_type_and_content_type.get_enumeration()) { @@ -417,6 +420,8 @@ class ROUGHPY_SCALARS_EXPORT Scalar */ void* mut_pointer(); + bool is_mutable() const noexcept; + /** * @brief Get a pointer to the scalar type representing this value. * @@ -430,6 +435,19 @@ class ROUGHPY_SCALARS_EXPORT Scalar */ devices::TypeInfo type_info() const noexcept; + type_pointer packed_type_info() const noexcept + { + return p_type_and_content_type; + } + + void change_type(devices::TypeInfo new_type_info) + { + Scalar tmp(new_type_info); + dtl::scalar_convert_copy(tmp.mut_pointer(), new_type_info, *this); + this->~Scalar(); + construct_inplace(this, std::move(tmp)); + } + friend std::ostream& operator<<(std::ostream& os, const Scalar& value); /** @@ -504,24 +522,28 @@ const T& Scalar::as_type() const noexcept inline Scalar operator+(const Scalar& lhs, const Scalar& rhs) { Scalar result(lhs); + RPY_DBG_ASSERT(result.is_mutable()); result += rhs; return result; } inline Scalar operator-(const Scalar& lhs, const Scalar& rhs) { Scalar result(lhs); + RPY_DBG_ASSERT(result.is_mutable()); result -= rhs; return result; } inline Scalar operator*(const Scalar& lhs, const Scalar& rhs) { Scalar result(lhs); + RPY_DBG_ASSERT(result.is_mutable()); result *= rhs; return result; } inline Scalar operator/(const Scalar& lhs, const Scalar& rhs) { Scalar result(lhs); + RPY_DBG_ASSERT(result.is_mutable()); result /= rhs; return result; } diff --git a/scalars/include/roughpy/scalars/scalars_fwd.h b/scalars/include/roughpy/scalars/scalars_fwd.h index 665bd2f7..d65479b5 100644 --- a/scalars/include/roughpy/scalars/scalars_fwd.h +++ b/scalars/include/roughpy/scalars/scalars_fwd.h @@ -31,6 +31,7 @@ #include #include +#include #include "roughpy_scalars_export.h" diff --git a/scalars/include/roughpy/scalars/traits.h b/scalars/include/roughpy/scalars/traits.h index 104f0c26..e7ffd35f 100644 --- a/scalars/include/roughpy/scalars/traits.h +++ b/scalars/include/roughpy/scalars/traits.h @@ -83,6 +83,7 @@ constexpr bool is_signed(const devices::TypeInfo& info) noexcept case devices::TypeCode::ArbitraryPrecisionFloat: case devices::TypeCode::ArbitraryPrecisionRational: return true; } + RPY_UNREACHABLE_RETURN(false); } constexpr bool is_unsigned(const devices::TypeInfo& info) noexcept diff --git a/scalars/src/scalar.cpp b/scalars/src/scalar.cpp index 284ae903..8bdda3c3 100644 --- a/scalars/src/scalar.cpp +++ b/scalars/src/scalar.cpp @@ -2,8 +2,8 @@ // Copyright (c) 2023 the RoughPy Developers. All rights reserved. // -// Redistribution and use in source and binary forms, with or without modification, -// are permitted provided that the following conditions are met: +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: // // 1. Redistributions of source code must retain the above copyright notice, // this list of conditions and the following disclaimer. @@ -20,43 +20,45 @@ // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. #include #include #include #include -#include #include "scalar/arithmetic.h" #include "scalar/casts.h" #include "scalar/comparison.h" #include "scalar/print.h" #include "scalar/raw_bytes.h" +#include #include - using namespace rpy; using namespace rpy::scalars; -void Scalar::allocate_data() { +void Scalar::allocate_data() +{ if (!p_type_and_content_type.is_pointer()) { auto tp_o = scalar_type_of(p_type_and_content_type.get_type_info()); if (!tp_o) { RPY_THROW(std::runtime_error, "unable to allocate scalar"); } - } RPY_DBG_ASSERT(p_type_and_content_type.is_pointer()); opaque_pointer = p_type_and_content_type->allocate_single(); - p_type_and_content_type.update_enumeration(dtl::ScalarContentType::OwnedPointer); + p_type_and_content_type.update_enumeration( + dtl::ScalarContentType::OwnedPointer + ); } Scalar::Scalar(Scalar::type_pointer type) : integer_for_convenience(0) @@ -71,7 +73,7 @@ Scalar::Scalar(Scalar::type_pointer type) : integer_for_convenience(0) mode = dtl::content_type_of(info); if (mode == dtl::ScalarContentType::OwnedPointer) { p_type_and_content_type - = type_pointer(*scalar_type_of(info), mode); + = type_pointer(*scalar_type_of(info), mode); } else { p_type_and_content_type = type_pointer(info, mode); } @@ -92,6 +94,7 @@ Scalar::Scalar(const ScalarType* type) opaque_pointer = type->allocate_single(); } } + Scalar::Scalar(devices::TypeInfo info) : p_type_and_content_type(info, dtl::content_type_of(info)), integer_for_convenience(0) @@ -102,33 +105,35 @@ Scalar::Scalar(devices::TypeInfo info) opaque_pointer = p_type_and_content_type->allocate_single(); } } + Scalar::Scalar() : p_type_and_content_type(), integer_for_convenience(0) {} Scalar::Scalar(const ScalarType* type, void* ptr) : p_type_and_content_type(type, dtl::ScalarContentType::OpaquePointer), - opaque_pointer(ptr) -{} + opaque_pointer(ptr) {} + Scalar::Scalar(const ScalarType* type, const void* ptr) : p_type_and_content_type(type, dtl::ScalarContentType::ConstOpaquePointer), - opaque_pointer(const_cast(ptr)) -{} + opaque_pointer(const_cast(ptr)) {} Scalar::Scalar(devices::TypeInfo info, void* ptr) : p_type_and_content_type(info, dtl::ScalarContentType::OpaquePointer), - opaque_pointer(ptr) -{} + opaque_pointer(ptr) {} + Scalar::Scalar(devices::TypeInfo info, const void* ptr) : p_type_and_content_type(info, dtl::ScalarContentType::ConstOpaquePointer), - opaque_pointer(const_cast(ptr)) -{} + opaque_pointer(const_cast(ptr)) {} + Scalar::Scalar(const ScalarType* type, int64_t num, int64_t denom) { auto info = type->type_info(); if (traits::is_arithmetic(info) && info.bytes <= sizeof(void*)) { - p_type_and_content_type = type_pointer(type, dtl::ScalarContentType::TrivialBytes); + p_type_and_content_type + = type_pointer(type, dtl::ScalarContentType::TrivialBytes); dtl::scalar_assign_rational(trivial_bytes, info, num, denom); } else { - p_type_and_content_type = type_pointer(type, dtl::ScalarContentType::OwnedPointer); + p_type_and_content_type + = type_pointer(type, dtl::ScalarContentType::OwnedPointer); opaque_pointer = type->allocate_single(); dtl::scalar_assign_rational(opaque_pointer, info, num, denom); } @@ -136,43 +141,69 @@ Scalar::Scalar(const ScalarType* type, int64_t num, int64_t denom) void Scalar::copy_from_opaque_pointer(devices::TypeInfo info, const void* src) { - if (traits::is_fundamental(info) && info.bytes <= sizeof(void*)) { - std::memcpy(trivial_bytes, src, info.bytes); - p_type_and_content_type.update_enumeration(dtl::ScalarContentType::TrivialBytes); + if (p_type_and_content_type.is_null()) { + p_type_and_content_type = + type_pointer(info, dtl::ScalarContentType::TrivialBytes); + } + + const auto dst_info = type_info(); + if (traits::is_fundamental(dst_info) && dst_info.bytes <= sizeof(void*)) { + if (info == dst_info) { + std::memcpy(trivial_bytes, src, info.bytes); + } else { + auto successful = + dtl::scalar_convert_copy(trivial_bytes, dst_info, src, info); + RPY_DBG_ASSERT(successful); + } + p_type_and_content_type.update_enumeration( + dtl::ScalarContentType::TrivialBytes + ); } else { allocate_data(); - dtl::scalar_convert_copy(opaque_pointer, info, src, info); + auto successful + = dtl::scalar_convert_copy(opaque_pointer, dst_info, src, info); + RPY_DBG_ASSERT(successful); } } - -Scalar::Scalar(const Scalar& other) +Scalar::Scalar(const Scalar& other) : integer_for_convenience(0) { + p_type_and_content_type = other.p_type_and_content_type; if (!other.fast_is_zero()) { - p_type_and_content_type = other.p_type_and_content_type; const auto info = type_info(); - switch (p_type_and_content_type.get_enumeration()) { - case dtl::ScalarContentType::TrivialBytes: + switch (p_type_and_content_type.get_enumeration()) { + case dtl::ScalarContentType::TrivialBytes: case dtl::ScalarContentType::ConstTrivialBytes: - std::memcpy(trivial_bytes, other.trivial_bytes, sizeof(void*)); - break; - case dtl::ScalarContentType::OpaquePointer: - case dtl::ScalarContentType::ConstOpaquePointer: - case dtl::ScalarContentType::Interface: - case dtl::ScalarContentType::OwnedInterface: - copy_from_opaque_pointer(info, other.pointer()); - break; - case dtl::ScalarContentType::OwnedPointer: - // This should only happen if the data type is too large to fit in - // inline storage or has a non-trivial constructor/destructor. - allocate_data(); - auto successful = dtl::scalar_convert_copy(opaque_pointer, info, other.opaque_pointer, info, 1); - RPY_DBG_ASSERT(successful); - break; + std::memcpy(trivial_bytes, + other.trivial_bytes, + sizeof(void*)); + break; + case dtl::ScalarContentType::OpaquePointer: + case dtl::ScalarContentType::ConstOpaquePointer: + copy_from_opaque_pointer(info, other.opaque_pointer); + break; + case dtl::ScalarContentType::Interface: + case dtl::ScalarContentType::OwnedInterface: + copy_from_opaque_pointer(info, other.interface->pointer()); + break; + case dtl::ScalarContentType::OwnedPointer: + // This should only happen if the data type is too large to fit + // in inline storage or has a non-trivial + // constructor/destructor. + allocate_data(); + auto successful = dtl::scalar_convert_copy( + opaque_pointer, + info, + other.opaque_pointer, + info, + 1 + ); + RPY_DBG_ASSERT(successful); + break; } } - } + Scalar::Scalar(Scalar&& other) noexcept : p_type_and_content_type(other.p_type_and_content_type), integer_for_convenience(0) @@ -208,11 +239,13 @@ Scalar::~Scalar() } break; case dtl::ScalarContentType::Interface: - case dtl::ScalarContentType::OwnedInterface: interface.~unique_ptr(); + case dtl::ScalarContentType::OwnedInterface: + interface.~unique_ptr(); case dtl::ScalarContentType::TrivialBytes: case dtl::ScalarContentType::OpaquePointer: case dtl::ScalarContentType::ConstTrivialBytes: - case dtl::ScalarContentType::ConstOpaquePointer: break; + case dtl::ScalarContentType::ConstOpaquePointer: + break; } }; @@ -222,36 +255,47 @@ Scalar& Scalar::operator=(const Scalar& other) if (this->fast_is_zero()) { p_type_and_content_type = other.p_type_and_content_type; + const auto info = type_info(); switch (p_type_and_content_type.get_enumeration()) { case dtl::ScalarContentType::TrivialBytes: case dtl::ScalarContentType::ConstTrivialBytes: std::memcpy( - trivial_bytes, - other.trivial_bytes, - sizeof(interface_pointer_t) + trivial_bytes, + other.trivial_bytes, + sizeof(interface_pointer_t) ); break; case dtl::ScalarContentType::OpaquePointer: case dtl::ScalarContentType::ConstOpaquePointer: - opaque_pointer = other.opaque_pointer; + copy_from_opaque_pointer(info, other.opaque_pointer); break; - case dtl::ScalarContentType::OwnedPointer: - // TODO: implement copy for owned pointers + case dtl::ScalarContentType::OwnedPointer: { + allocate_data(); + auto successful = dtl::scalar_convert_copy( + opaque_pointer, + p_type_and_content_type->type_info(), + other.opaque_pointer, + p_type_and_content_type->type_info(), + 1 + ); + RPY_DBG_ASSERT(successful); + } break; case dtl::ScalarContentType::Interface: case dtl::ScalarContentType::OwnedInterface: + copy_from_opaque_pointer(info, other.pointer()); // Copying of interface pointers is disallowed. - RPY_THROW( - std::runtime_error, - "copying of interface pointers " - "is not allowed" - ); +// RPY_THROW( +// std::runtime_error, +// "copying of interface pointers " +// "is not allowed" +// ); } } else { dtl::scalar_convert_copy( - mut_pointer(), - p_type_and_content_type.get_type_info(), - other + mut_pointer(), + type_info(), + other ); } } @@ -261,44 +305,16 @@ Scalar& Scalar::operator=(const Scalar& other) Scalar& Scalar::operator=(Scalar&& other) noexcept { if (&other != this) { - this->~Scalar(); - p_type_and_content_type = other.p_type_and_content_type; - switch (p_type_and_content_type.get_enumeration()) { - case dtl::ScalarContentType::TrivialBytes: - case dtl::ScalarContentType::ConstTrivialBytes: - std::memcpy( - trivial_bytes, - other.trivial_bytes, - sizeof(interface_pointer_t) - ); - break; - case dtl::ScalarContentType::OpaquePointer: - case dtl::ScalarContentType::ConstOpaquePointer: - case dtl::ScalarContentType::OwnedPointer: - opaque_pointer = other.opaque_pointer; - other.opaque_pointer = nullptr; - break; - case dtl::ScalarContentType::Interface: - case dtl::ScalarContentType::OwnedInterface: - interface = std::move(other.interface); - other.interface = nullptr; - break; + if (p_type_and_content_type.is_null()) { + construct_inplace(this, std::move(other)); + } else { + dtl::scalar_convert_copy(mut_pointer(), type_info(), other); } + } return *this; } - -#define DO_OP_FOR_ALL_TYPES(INFO) \ - switch(INFO.code) {\ - casse devices::TypeCode::Int: \ - switch (INFO.bytes) { \ - case \ - - - - - bool Scalar::is_zero() const noexcept { if (fast_is_zero()) { return true; } @@ -314,8 +330,8 @@ bool Scalar::is_zero() const noexcept case dtl::ScalarContentType::Interface: case dtl::ScalarContentType::OwnedInterface: return is_pointer_zero( - interface->pointer(), - p_type_and_content_type + interface->pointer(), + p_type_and_content_type ); } @@ -328,11 +344,13 @@ bool Scalar::is_reference() const noexcept switch (p_type_and_content_type.get_enumeration()) { case dtl::ScalarContentType::TrivialBytes: case dtl::ScalarContentType::ConstTrivialBytes: - case dtl::ScalarContentType::OwnedPointer: return false; + case dtl::ScalarContentType::OwnedPointer: + return false; case dtl::ScalarContentType::OpaquePointer: case dtl::ScalarContentType::ConstOpaquePointer: case dtl::ScalarContentType::Interface: - case dtl::ScalarContentType::OwnedInterface: return true; + case dtl::ScalarContentType::OwnedInterface: + return true; } RPY_UNREACHABLE_RETURN(false); } @@ -341,13 +359,19 @@ bool Scalar::is_const() const noexcept { if (fast_is_zero()) { return true; } switch (p_type_and_content_type.get_enumeration()) { - case dtl::ScalarContentType::TrivialBytes: return false; - case dtl::ScalarContentType::ConstTrivialBytes: return true; - case dtl::ScalarContentType::OpaquePointer: return false; - case dtl::ScalarContentType::ConstOpaquePointer: return true; + case dtl::ScalarContentType::TrivialBytes: + return false; + case dtl::ScalarContentType::ConstTrivialBytes: + return true; + case dtl::ScalarContentType::OpaquePointer: + return false; + case dtl::ScalarContentType::ConstOpaquePointer: + return true; case dtl::ScalarContentType::Interface: - case dtl::ScalarContentType::OwnedInterface: return false; - case dtl::ScalarContentType::OwnedPointer: return false; + case dtl::ScalarContentType::OwnedInterface: + return false; + case dtl::ScalarContentType::OwnedPointer: + return false; } RPY_UNREACHABLE_RETURN(false); } @@ -356,45 +380,64 @@ const void* Scalar::pointer() const noexcept { switch (p_type_and_content_type.get_enumeration()) { case dtl::ScalarContentType::TrivialBytes: - case dtl::ScalarContentType::ConstTrivialBytes: return &trivial_bytes; + case dtl::ScalarContentType::ConstTrivialBytes: + return trivial_bytes; case dtl::ScalarContentType::OpaquePointer: case dtl::ScalarContentType::ConstOpaquePointer: - case dtl::ScalarContentType::OwnedPointer: return opaque_pointer; + case dtl::ScalarContentType::OwnedPointer: + return opaque_pointer; case dtl::ScalarContentType::Interface: case dtl::ScalarContentType::OwnedInterface: return interface->pointer(); } RPY_UNREACHABLE_RETURN(nullptr); } + void* Scalar::mut_pointer() { - if (fast_is_zero()) { + if (p_type_and_content_type.is_null()) { RPY_THROW( - std::runtime_error, - "cannot get mutable pointer to constant" - " value zero" + std::runtime_error, + "cannot get mutable pointer to constant value zero" ); } switch (p_type_and_content_type.get_enumeration()) { - case dtl::ScalarContentType::TrivialBytes: return &trivial_bytes; + case dtl::ScalarContentType::TrivialBytes: + return trivial_bytes; case dtl::ScalarContentType::OpaquePointer: - case dtl::ScalarContentType::OwnedPointer: return opaque_pointer; + case dtl::ScalarContentType::OwnedPointer: + if (opaque_pointer == nullptr) { + RPY_THROW( + std::runtime_error, + "cannot get mutable pointer to constant value zero" + ); + } + return opaque_pointer; case dtl::ScalarContentType::ConstTrivialBytes: case dtl::ScalarContentType::ConstOpaquePointer: RPY_THROW( - std::runtime_error, - "cannot get mutable pointer to constant value" + std::runtime_error, + "cannot get mutable pointer to constant value" ); case dtl::ScalarContentType::Interface: case dtl::ScalarContentType::OwnedInterface: RPY_THROW( - std::runtime_error, - "cannot get mutable pointer to special scalar references" + std::runtime_error, + "cannot get mutable pointer to special scalar references" ); } RPY_UNREACHABLE_RETURN(nullptr); } + +bool Scalar::is_mutable() const noexcept +{ + return !static_cast( + static_cast(p_type_and_content_type.get_enumeration()) + & static_cast(dtl::ScalarContentType::IsConst) + ); +} + optional Scalar::type() const noexcept { if (p_type_and_content_type.is_pointer()) { @@ -404,6 +447,7 @@ optional Scalar::type() const noexcept auto info = p_type_and_content_type.get_type_info(); return scalar_type_of(info); } + devices::TypeInfo Scalar::type_info() const noexcept { if (p_type_and_content_type.is_pointer()) { @@ -413,18 +457,16 @@ devices::TypeInfo Scalar::type_info() const noexcept } static void print_trivial_bytes( - std::ostream& os, - const byte* bytes, - PackedScalarTypePointer p_type -) -{} + std::ostream& os, + const byte* bytes, + PackedScalarTypePointer p_type +) {} static void print_opaque_pointer( - std::ostream& os, - const void* opaque, - PackedScalarTypePointer p_type -) -{} + std::ostream& os, + const void* opaque, + PackedScalarTypePointer p_type +) {} std::ostream& rpy::scalars::operator<<(std::ostream& os, const Scalar& value) { @@ -435,7 +477,6 @@ std::ostream& rpy::scalars::operator<<(std::ostream& os, const Scalar& value) dtl::print_scalar_val(os, value.pointer(), value.type_info()); - return os; } @@ -443,15 +484,18 @@ std::vector Scalar::to_raw_bytes() const { return dtl::to_raw_bytes(pointer(), 1, type_info()); } -void Scalar::from_raw_bytes(devices::TypeInfo info, Slice bytes) + +void Scalar::from_raw_bytes(devices::TypeInfo info, Slice bytes) { // Make sure it is clear this->~Scalar(); auto tp_o = scalar_type_of(info); if (tp_o) { - p_type_and_content_type = type_pointer(*tp_o, dtl::ScalarContentType::TrivialBytes); + p_type_and_content_type + = type_pointer(*tp_o, dtl::ScalarContentType::TrivialBytes); } else { - p_type_and_content_type = type_pointer(info, dtl::ScalarContentType::TrivialBytes); + p_type_and_content_type + = type_pointer(info, dtl::ScalarContentType::TrivialBytes); } void* ptr = nullptr; diff --git a/scalars/src/scalar/arithmetic.cpp b/scalars/src/scalar/arithmetic.cpp index e8c6e571..8f456127 100644 --- a/scalars/src/scalar/arithmetic.cpp +++ b/scalars/src/scalar/arithmetic.cpp @@ -33,13 +33,14 @@ #include "arithmetic.h" #include "casts.h" #include "do_macro.h" +#include "type_promotion.h" +#include "traits.h" #include #include #include -#include #include -#include "traits.h" +#include using namespace rpy; using namespace rpy::scalars; @@ -47,7 +48,8 @@ using namespace rpy::scalars; using rpy::scalars::dtl::ScalarContentType; static inline void -uminus_impl(void* dst, const void* ptr, const devices::TypeInfo& info) {} +uminus_impl(void* dst, const void* ptr, const devices::TypeInfo& info) +{} Scalar Scalar::operator-() const { @@ -57,30 +59,31 @@ Scalar Scalar::operator-() const if (!fast_is_zero()) { switch (p_type_and_content_type.get_enumeration()) { case ScalarContentType::TrivialBytes: - case ScalarContentType::ConstTrivialBytes: result. - p_type_and_content_type = type_pointer( - info, - dtl::ScalarContentType::TrivialBytes - ); + case ScalarContentType::ConstTrivialBytes: + result.p_type_and_content_type = type_pointer( + info, + dtl::ScalarContentType::TrivialBytes + ); uminus_impl(result.trivial_bytes, trivial_bytes, info); break; case ScalarContentType::OpaquePointer: case ScalarContentType::ConstOpaquePointer: - case ScalarContentType::OwnedPointer: stype = type(); + case ScalarContentType::OwnedPointer: + stype = type(); RPY_CHECK(stype); result.p_type_and_content_type = type_pointer( - *stype, - dtl::ScalarContentType::OwnedPointer + *stype, + dtl::ScalarContentType::OwnedPointer ); result.allocate_data(); uminus_impl(result.opaque_pointer, opaque_pointer, info); break; case ScalarContentType::Interface: - case ScalarContentType::OwnedInterface: result. - p_type_and_content_type = type_pointer( - *stype, - dtl::ScalarContentType::OwnedPointer - ); + case ScalarContentType::OwnedInterface: + result.p_type_and_content_type = type_pointer( + *stype, + dtl::ScalarContentType::OwnedPointer + ); result.allocate_data(); uminus_impl(result.opaque_pointer, interface->pointer(), info); break; @@ -91,37 +94,45 @@ Scalar Scalar::operator-() const namespace { -struct RPY_LOCAL AddInplace -{ +struct RPY_LOCAL AddInplace { template - constexpr void operator()(T& lhs, const T& rhs) noexcept { lhs += rhs; } + constexpr void operator()(T& lhs, const T& rhs) noexcept + { + lhs += rhs; + } }; -struct RPY_LOCAL SubInplace -{ +struct RPY_LOCAL SubInplace { template - constexpr void operator()(T& lhs, const T& rhs) noexcept { lhs -= rhs; } + constexpr void operator()(T& lhs, const T& rhs) noexcept + { + lhs -= rhs; + } }; -struct RPY_LOCAL MulInplace -{ +struct RPY_LOCAL MulInplace { template - constexpr void operator()(T& lhs, const T& rhs) noexcept { lhs *= rhs; } + constexpr void operator()(T& lhs, const T& rhs) noexcept + { + lhs *= rhs; + } }; -struct RPY_LOCAL DivInplace -{ +struct RPY_LOCAL DivInplace { template constexpr enable_if_t< - !is_same::value && - is_same< - decltype(std::declval() /= std::declval()), T&>::value> - operator()(T& lhs, const R& rhs) noexcept { lhs /= rhs; } + !is_same::value + && is_same< + decltype(std::declval() /= std::declval()), + T&>::value> + operator()(T& lhs, const R& rhs) noexcept + { + lhs /= rhs; + } void operator()(...) { RPY_THROW(std::domain_error, "invalid division"); } }; - template static inline void do_op(void* dst, const void* src, devices::TypeInfo type, Op&& op) @@ -134,11 +145,11 @@ do_op(void* dst, const void* src, devices::TypeInfo type, Op&& op) template inline void scalar_inplace_arithmetic( - void* dst, - PackedScalarTypePointer dst_type, - const void* src, - PackedScalarTypePointer src_type, - Op&& op + void* dst, + PackedScalarTypePointer dst_type, + const void* src, + PackedScalarTypePointer src_type, + Op&& op ) { auto src_info = type_info_from(src_type); @@ -147,54 +158,65 @@ inline void scalar_inplace_arithmetic( if (src_info == dst_info) { do_op(dst, src, dst_info, std::forward(op)); } else { - Scalar tmp(dst_type); + Scalar tmp(dst_info); void* tmp_ptr = tmp.mut_pointer(); scalars::dtl::scalar_convert_copy(tmp_ptr, dst_info, src, src_info, 1); do_op(dst, tmp_ptr, dst_info, std::forward(op)); } } +template +inline void scalar_inplace_arithmetic(Scalar& dst, const Scalar& src, Op&& op) +{ + const auto dst_info = dst.type_info(); + const auto src_info = src.type_info(); + + if (src_info == dst_info) { + do_op(dst.mut_pointer(), src.pointer(), dst_info, std::forward(op)); + } else { + auto out_type = scalars::dtl::compute_dest_type( + dst.packed_type_info(), + src.packed_type_info() + ); + if (out_type != dst_info) { dst.change_type(out_type); } + scalar_inplace_arithmetic( + dst.mut_pointer(), + dst.packed_type_info(), + src.pointer(), + src.packed_type_info(), + std::forward(op) + ); + } +} + }// namespace Scalar& Scalar::operator+=(const Scalar& other) { + RPY_DBG_ASSERT(!p_type_and_content_type.is_null()); if (!other.fast_is_zero()) { - scalar_inplace_arithmetic( - mut_pointer(), - p_type_and_content_type, - other.pointer(), - other.p_type_and_content_type, - AddInplace() - ); + scalar_inplace_arithmetic(*this, other, AddInplace()); } return *this; } Scalar& Scalar::operator-=(const Scalar& other) { + RPY_DBG_ASSERT(!p_type_and_content_type.is_null()); if (!other.fast_is_zero()) { - scalar_inplace_arithmetic( - mut_pointer(), - p_type_and_content_type, - other.pointer(), - other.p_type_and_content_type, - SubInplace() - ); + scalar_inplace_arithmetic(*this, other, SubInplace()); } return *this; } Scalar& Scalar::operator*=(const Scalar& other) { - if (!other.fast_is_zero()) { - scalar_inplace_arithmetic( - mut_pointer(), - p_type_and_content_type, - other.pointer(), - other.p_type_and_content_type, - MulInplace() - ); - } else { operator=(0); } + RPY_DBG_ASSERT(!p_type_and_content_type.is_null()); + if (!fast_is_zero() && !other.fast_is_zero()) { + scalar_inplace_arithmetic(*this, other, MulInplace()); + } else { + operator=(0); + } return *this; } @@ -209,14 +231,16 @@ Scalar& Scalar::operator*=(const Scalar& other) * to eliminate these as problems. */ template -static inline -enable_if_t() /= std::declval()), T&>::value> +static inline enable_if_t< + is_same() /= std::declval()), + T&>::value> do_divide_impl(T* dst, const R* divisor) { *dst /= *divisor; } -static inline void do_divide_impl(rational_poly_scalar* dst, const rational_scalar_type* divisor) +static inline void +do_divide_impl(rational_poly_scalar* dst, const rational_scalar_type* divisor) { *dst /= *divisor; } @@ -232,18 +256,14 @@ static inline void do_divide_impl(...) RPY_THROW(std::domain_error, "invalid division"); } - template -static inline void do_divide(T* dst, - const void* divisor, - const devices::TypeInfo& info) -{ +static inline void +do_divide(T* dst, const void* divisor, const devices::TypeInfo& info){ #define X(TP) return do_divide_impl(dst, (const TP*) divisor) - DO_FOR_EACH_X(info) + DO_FOR_EACH_X(info) #undef X } - Scalar& Scalar::operator/=(const Scalar& other) { if (other.fast_is_zero()) { @@ -258,41 +278,32 @@ Scalar& Scalar::operator/=(const Scalar& other) * type - as is the case for polynomials. * We need to handle both of these cases gracefully. */ - - type_pointer tmp_type; - if (other.p_type_and_content_type.is_pointer()) { - tmp_type = type_pointer(other.p_type_and_content_type->rational_type(), - other.p_type_and_content_type. - get_enumeration()); - } else { - tmp_type = type_pointer( - traits::rational_type_of( - other.p_type_and_content_type.get_type_info()), - other.p_type_and_content_type.get_enumeration()); + const auto num_info = type_info_from(p_type_and_content_type); + const auto true_denom_info = type_info_from(other.p_type_and_content_type); + const auto rat_denom_info = traits::rational_type_of(true_denom_info); + + // Set up the actual denominator that we will divide by + Scalar rational_denom(true_denom_info, other.pointer()); + if (rat_denom_info != true_denom_info) { + rational_denom.change_type(rat_denom_info); } - const auto num_type = type_info(); - if (tmp_type == other.p_type_and_content_type) { -#define X(TP) do_divide((TP*) mut_pointer(),\ - other.pointer(),\ - tmp_type.get_type_info());\ - break + /* + * The type resolution needs to be done with the true denominator type, + * rather than whatever the rational type is. + */ + if (num_info != true_denom_info) { + change_type(dtl::compute_dest_type( + p_type_and_content_type, + other.p_type_and_content_type + )); + } - DO_FOR_EACH_X(num_type) -#undef X - } else { - Scalar tmp(tmp_type); - // hopefully converts - tmp = other; -#define X(TP) do_divide((TP*) mut_pointer(),\ - tmp.pointer(),\ - tmp_type.get_type_info());\ - break - - DO_FOR_EACH_X(num_type) + // Now we should be ready +#define X(tp) do_divide((tp*) mut_pointer(), rational_denom.pointer(), rat_denom_info);\ + break + DO_FOR_EACH_X(num_info) #undef X - } - return *this; } diff --git a/scalars/src/scalar/casts.cpp b/scalars/src/scalar/casts.cpp index 362627c4..c1d5f957 100644 --- a/scalars/src/scalar/casts.cpp +++ b/scalars/src/scalar/casts.cpp @@ -92,8 +92,6 @@ write_result(D*, const T*, dimn_t) noexcept static inline bool write_single_poly(rational_poly_scalar* dst, const rational_poly_scalar& value) { - // std::cout << "writing " << value << " to dst " << *dst << '\n'; - *dst = value; return true; } diff --git a/scalars/src/scalar/comparison.cpp b/scalars/src/scalar/comparison.cpp index 57ead28b..f87f2a34 100644 --- a/scalars/src/scalar/comparison.cpp +++ b/scalars/src/scalar/comparison.cpp @@ -296,8 +296,10 @@ bool scalars::dtl::is_pointer_zero( const PackedScalarTypePointer& p_type ) noexcept { + const auto info = (p_type.is_pointer() ? p_type.get_pointer()->type_info() : p_type.get_type_info()); + #define X(TYP) return is_zero(*((const TYP*) ptr)) - DO_FOR_EACH_X(p_type.get_type_info()) + DO_FOR_EACH_X(info) #undef X return true; } diff --git a/scalars/src/scalar/type_promotion.cpp b/scalars/src/scalar/type_promotion.cpp new file mode 100644 index 00000000..3ce5e9f9 --- /dev/null +++ b/scalars/src/scalar/type_promotion.cpp @@ -0,0 +1,170 @@ +// +// Created by sam on 1/19/24. +// + +#include "type_promotion.h" +#include + +using namespace rpy; +using namespace rpy::scalars; + +using scalars::dtl::ScalarContentType; + +namespace { + +devices::TypeInfo compute_int_promotion( + const devices::TypeInfo& dst_info, + const devices::TypeInfo& src_info +) +{ + if (traits::is_integral(src_info)) { + devices::TypeInfo info{ + dst_info.code, + std::max(dst_info.bytes, src_info.bytes), + std::max(dst_info.alignment, src_info.alignment), + 1 + }; + + if (src_info.code != info.code) { + // One is signed, the other is unsigned + if (traits::is_signed(src_info)) { + // If src is signed the we must make sure that the dst is signed + info.code = src_info.code; + } else { + // If src is unsigned then dst might not be large enough + if (info.bytes <= src_info.bytes + && info.bytes < sizeof(int64_t)) { + info.bytes *= 2; + info.alignment = info.bytes; + } + } + } + + return info; + } + + // In almost all other circumstances, the correct promotion is to simply use + // the src type. + return src_info; +} + +devices::TypeInfo compute_float_promotion( + const devices::TypeInfo& dst_info, + const devices::TypeInfo& src_info +) +{ + // Floats should promote to anything that isn't an int + if (traits::is_integral(src_info)) { + // TODO: Make sure dst is large enough + return dst_info; + } + + if (traits::is_floating_point(src_info)) { + if (src_info.bytes > dst_info.bytes) { + return {devices::TypeCode::Float, + src_info.bytes, + src_info.alignment, + 1}; + } + + return dst_info; + } + + return src_info; +} + +devices::TypeInfo compute_bfloat_promotion( + const devices::TypeInfo& dst_info, + const devices::TypeInfo& src_info +) +{ + if (traits::is_integral(src_info)) { + /* + * bfloats cannot hold anything larger than int8_t faithfully, so for + * anything larger than int8_t we'll promote to a float or double. + */ + if (src_info.bytes == 1) { return dst_info; } + if (src_info.bytes <= 4) { return devices::type_info(); } + + return devices::type_info(); + } + + return src_info; +} + +devices::TypeInfo compute_aprational_promotion( + const devices::TypeInfo& dst_info, + const devices::TypeInfo& src_info +) +{ + // Rationals only promote to rational polys + if (src_info.code == devices::TypeCode::APRationalPolynomial) { + return src_info; + } + + return dst_info; +} + +devices::TypeInfo compute_aprpoly_promotion( + const devices::TypeInfo& dst_info, + const devices::TypeInfo& RPY_UNUSED_VAR src_info +) +{ + return dst_info; +} + +devices::TypeInfo compute_promotion( + PackedScalarTypePointer dst_type, + PackedScalarTypePointer src_type +) +{ + const auto dst_info = type_info_from(dst_type); + const auto src_info = type_info_from(src_type); + + switch (dst_info.code) { + case devices::TypeCode::Int: + case devices::TypeCode::UInt: + return compute_int_promotion(dst_info, src_info); + case devices::TypeCode::Float: + return compute_float_promotion(dst_info, src_info); + case devices::TypeCode::BFloat: + return compute_bfloat_promotion(dst_info, src_info); + case devices::TypeCode::ArbitraryPrecisionRational: + return compute_aprational_promotion(dst_info, src_info); + case devices::TypeCode::APRationalPolynomial: + return compute_aprpoly_promotion(dst_info, src_info); + case devices::TypeCode::Bool: + case devices::TypeCode::OpaqueHandle: + case devices::TypeCode::Complex: + case devices::TypeCode::Rational: + case devices::TypeCode::ArbitraryPrecisionInt: + case devices::TypeCode::ArbitraryPrecisionUInt: + case devices::TypeCode::ArbitraryPrecisionFloat: + case devices::TypeCode::ArbitraryPrecisionComplex: break; + } + + RPY_THROW(std::runtime_error, "cannot find suitable promotion"); +} + +}// namespace + +devices::TypeInfo scalars::dtl::compute_dest_type( + PackedScalarTypePointer dst_type, + PackedScalarTypePointer src_type +) +{ + RPY_DBG_ASSERT(!(dst_type == src_type)); + + switch (dst_type.get_enumeration()) { + case ScalarContentType::TrivialBytes: + case ScalarContentType::ConstTrivialBytes: + case ScalarContentType::OwnedPointer: + return compute_promotion(dst_type, src_type); + case ScalarContentType::OpaquePointer: + case ScalarContentType::ConstOpaquePointer: + case ScalarContentType::Interface: + case ScalarContentType::OwnedInterface: return dst_type; + } + + RPY_UNREACHABLE_RETURN(dst_type); +} diff --git a/scalars/src/scalar/type_promotion.h b/scalars/src/scalar/type_promotion.h new file mode 100644 index 00000000..0cf3da61 --- /dev/null +++ b/scalars/src/scalar/type_promotion.h @@ -0,0 +1,23 @@ +// +// Created by sam on 1/19/24. +// + +#ifndef ROUGHPY_TYPE_PROMOTION_H +#define ROUGHPY_TYPE_PROMOTION_H + +#include + +namespace rpy { +namespace scalars { +namespace dtl { + +devices::TypeInfo compute_dest_type( + PackedScalarTypePointer dst_type, + PackedScalarTypePointer src_type +); + +} +}// namespace scalars +}// namespace rpy + +#endif// ROUGHPY_TYPE_PROMOTION_H diff --git a/streams/include/roughpy/streams/stream_base.h b/streams/include/roughpy/streams/stream_base.h index 332bbc10..dd11eabc 100644 --- a/streams/include/roughpy/streams/stream_base.h +++ b/streams/include/roughpy/streams/stream_base.h @@ -36,6 +36,7 @@ #include #include #include +#include #include "schema.h" diff --git a/streams/src/dynamically_constructed_stream.cpp b/streams/src/dynamically_constructed_stream.cpp index 0070fa08..16162396 100644 --- a/streams/src/dynamically_constructed_stream.cpp +++ b/streams/src/dynamically_constructed_stream.cpp @@ -194,11 +194,9 @@ void streams::DynamicallyConstructedStream::update_parents( { // data_increment below(current); // data_increment top = update_parent_accuracy(below); - // std::cout << top->first << ' ' << below->first << '\n'; // while (top != below) { // below = top; // top = update_parent_accuracy(below); - // std::cout << top->first << ' ' << below->first << '\n'; // } const auto& md = metadata(); auto root = m_data_tree.begin(); diff --git a/tests/scalars/test_scalars.py b/tests/scalars/test_scalars.py new file mode 100644 index 00000000..d5874d6c --- /dev/null +++ b/tests/scalars/test_scalars.py @@ -0,0 +1,124 @@ + + + +import pytest + +import roughpy as rp + + +def test_construct_rational_from_float(): + x = rp.Rational(1.5232) + assert x.to_float() == pytest.approx(1.5232) + + +def test_construct_rational_num_denom_positional(): + x = rp.Rational(513452, 21536344) + assert x.to_float() == pytest.approx(513452 / 21536344) + + +def test_scalar_equals_zero(): + x = rp.Rational(0) + assert x == 0 + assert x == 0.0 + assert x == rp.Rational(0) + + +def test_scalar_not_equals_zero(): + x = rp.Rational(1) + assert not x == 0 + assert not x == 0.0 + assert not x == rp.Rational(0) + + +def test_scalar_equals_nonzero(): + x = rp.Rational(3.141527) + assert x == 3.141527 + assert x == rp.Rational(3.141527) + + +def test_inplace_scalar_add_int(): + x = rp.Rational(1) + x += 1 + assert x == 2 + + +def test_inplace_scalar_sub_int(): + x = rp.Rational(1) + x -= 1 + assert x == 0 + + +def test_inplace_scalar_mul_int(): + x = rp.Rational(1) + x *= 2 + assert x == 2 + +@pytest.fixture +def borrowed_scalar(): + tensor = rp.FreeTensor([1.0, 2.0, 3.0], width=2, depth=1) + yield tensor[rp.TensorKey(width=2,depth=1)] + + +def test_add_borrowed_scalar_to_borrowed_scalar(borrowed_scalar): + result = borrowed_scalar + borrowed_scalar + assert result == 2.0 + + +def test_add_borrowed_scalar_to_int(borrowed_scalar): + result = borrowed_scalar + 1 + assert result == 2.0 + + +def test_add_int_to_borrowed_scalar(borrowed_scalar): + result = 1 + borrowed_scalar + assert result == 2.0 + + +def test_add_borrowed_scalar_to_float(borrowed_scalar): + result = borrowed_scalar + 1. + assert result == 2.0 + + +def test_add_borrowed_to_int_0(borrowed_scalar): + result = borrowed_scalar + 0 + assert result == borrowed_scalar + + +def test_add_int_0_to_borrowed_scalar(borrowed_scalar): + result = 0 + borrowed_scalar + assert result == borrowed_scalar + + +def test_add_float_to_borrowed_scalar(borrowed_scalar): + result = 1. + borrowed_scalar + assert result == 2.0 + + +def test_multiply_borrowed_scalar_by_float(borrowed_scalar): + result = borrowed_scalar * 2. + assert result == 2.0 + + +def test_multiply_borrowed_scalar_by_borrowed_scalar(borrowed_scalar): + result = borrowed_scalar * borrowed_scalar + assert result == 1.0 + + +def test_multiply_borrowed_scalar_by_int(borrowed_scalar): + result = borrowed_scalar * 2 + assert result == 2.0 + + +def test_multiply_borrowed_scalar_by_float(borrowed_scalar): + result = borrowed_scalar * 2.0 + assert result == 2.0 + + +def test_multiply_int_by_borrowed_scalar(borrowed_scalar): + result = 2 * borrowed_scalar + assert result == 2.0 + + +def test_mulitiply_float_by_borrowed_scalar(borrowed_scalar): + result = 2. * borrowed_scalar + assert result == 2.0 diff --git a/tests/streams/test_lie_increment_path.py b/tests/streams/test_lie_increment_path.py index 6bd785dd..d9598646 100644 --- a/tests/streams/test_lie_increment_path.py +++ b/tests/streams/test_lie_increment_path.py @@ -335,7 +335,6 @@ def test_log_sigature_interval_alignment_set_resolution(): stream = LieIncrementStream.from_increments(np.array([[0., 1., 2.],[3., 4., 5.]]), depth=2, resolution=4) zero = stream.ctx.zero_lie() - print(f"{stream.resolution=}") result = stream.log_signature(rp.RealInterval(0.0, 0.1), resolution=4) expected = rp.Lie([0., 1., 2.], ctx=stream.ctx) diff --git a/tests/test_tensor_functions.py b/tests/test_tensor_functions.py new file mode 100644 index 00000000..df5b2f06 --- /dev/null +++ b/tests/test_tensor_functions.py @@ -0,0 +1,51 @@ + +import roughpy as rp +from roughpy import tensor_functions + + +def test_adjoint_of_free_tensor(): + + tensor = rp.FreeTensor([1., 2., 3., 4., 5., 6., 7.], width=2, depth=2) + ad_tensor = tensor_functions._adjoint_of_shuffle(tensor) + + key_ = tensor_functions.tensor_word_factory(tensor.context.tensor_basis) + data = { + (key_(), key_()): 1.0, + (key_(), key_(1)): 2.0, + (key_(), key_(2)): 3.0, + (key_(), key_(1, 1)): 4.0, + (key_(), key_(1, 2)): 6.0, + (key_(), key_(2, 1)): 5.0, + (key_(), key_(2, 2)): 7.0, + (key_(1), key_()): 2.0, + (key_(1), key_(1)): 8.0, + (key_(1), key_(2)): 11.0, + (key_(2), key_()): 3.0, + (key_(2), key_(1)): 11.0, + (key_(2), key_(2)): 14.0, + (key_(1, 1), key_()): 4.0, + (key_(1, 2), key_()): 6.0, + (key_(2, 1), key_()): 5.0, + (key_(2, 2), key_()): 7.0 + } + + for k, v in data.items(): + assert k in ad_tensor.data + assert v == ad_tensor.data[k] + del ad_tensor.data[k] + + assert not ad_tensor.data + +def test_Log_on_grouplike(): + tensor_data = [ + 0, + 1 * rp.Monomial("x1"), + 1 * rp.Monomial("x2"), + 1 * rp.Monomial("x3") + ] + x = rp.FreeTensor(tensor_data, width=3, depth=2, + dtype=rp.RationalPoly).exp() + + result = tensor_functions.Log(x) + expected = x.log() + assert result == expected, f"{result} != {expected}" diff --git a/vcpkg.json b/vcpkg.json index 3167870e..932d0a71 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,104 +1,83 @@ { - "name": "roughpy", - "version-string": "1.0.0", - "dependencies": [ - { - "name": "libsndfile", - "version>=": "1.2.0" - }, - { - "name": "eigen3", - "version>=": "3.4.0#2" - }, - { - "name": "boost-align", - "version>=": "1.83.0" - }, - { - "name": "boost-container", - "version>=": "1.83.0" - }, - { - "name": "boost-core", - "version>=": "1.83.0" - }, - { - "name": "boost-url", - "version>=": "1.83.0" - }, - { - "name": "boost-multiprecision", - "version>=": "1.83.0" - }, - { - "name": "boost-type-traits", - "version>=": "1.83.0" - }, - { - "name": "tomlplusplus", - "version>=": "3.1.0", - "$comment": " # this is heuristically generated, and may not be correct\n\n find_package(tomlplusplus CONFIG REQUIRED)\n\n target_link_libraries(main PRIVATE tomlplusplus::tomlplusplus)\n" - }, - { - "name": "boost-thread", - "version>=": "1.83.0" - }, - { - "name": "boost-endian", - "version>=": "1.83.0" - }, - { - "name": "boost-dll", - "version>=": "1.83.0" - }, - { - "name": "gmp", - "version>=": "6.2.1#21", - "platform": "linux" - }, - { - "name": "mpir", - "version>=": "2022-03-02", - "platform": "windows" - }, - { - "name": "pcg", - "version>=": "2021-04-06" - }, - { - "name": "cereal", - "version>=": "1.3.2#1" - }, - { - "name": "boost-smart-ptr", - "version>=": "1.83.0" - }, - { - "name": "boost-uuid", - "version>=": "1.83.0" - }, - { - "name": "boost-interprocess", - "version>=": "1.83.0" - }, - { - "name": "opencl", - "version>=": "v2023.02.06#2" - }, - { - "name": "libalgebra-lite", - "version>=": "1.0.0" - } - ], - "features": { - "tests": { - "description": "dependencies for the unit tests", - "dependencies": [ - { - "name": "gtest", - "version>=": "1.13.0" - } - ] + "name" : "roughpy", + "version-string" : "1.0.0", + "dependencies" : [ { + "name" : "libsndfile", + "version>=" : "1.2.0" + }, { + "name" : "eigen3", + "version>=" : "3.4.0#2" + }, { + "name" : "boost-align", + "version>=" : "1.83.0" + }, { + "name" : "boost-container", + "version>=" : "1.83.0" + }, { + "name" : "boost-core", + "version>=" : "1.83.0" + }, { + "name" : "boost-url", + "version>=" : "1.83.0" + }, { + "name" : "boost-multiprecision", + "version>=" : "1.83.0" + }, { + "name" : "boost-type-traits", + "version>=" : "1.83.0" + }, { + "name" : "tomlplusplus", + "version>=" : "3.1.0", + "$comment" : " # this is heuristically generated, and may not be correct\n\n find_package(tomlplusplus CONFIG REQUIRED)\n\n target_link_libraries(main PRIVATE tomlplusplus::tomlplusplus)\n" + }, { + "name" : "boost-thread", + "version>=" : "1.83.0" + }, { + "name" : "boost-endian", + "version>=" : "1.83.0" + }, { + "name" : "boost-dll", + "version>=" : "1.83.0" + }, { + "name" : "gmp", + "version>=" : "6.2.1#21", + "platform" : "linux" + }, { + "name" : "mpir", + "version>=" : "2022-03-02", + "platform" : "windows" + }, { + "name" : "pcg", + "version>=" : "2021-04-06" + }, { + "name" : "cereal", + "version>=" : "1.3.2#1" + }, { + "name" : "boost-smart-ptr", + "version>=" : "1.83.0" + }, { + "name" : "boost-uuid", + "version>=" : "1.83.0" + }, { + "name" : "boost-interprocess", + "version>=" : "1.83.0" + }, { + "name" : "opencl", + "version>=" : "v2023.02.06#2" + }, { + "name" : "libalgebra-lite", + "version>=" : "1.0.0" + }, { + "name" : "boost-stacktrace", + "version>=" : "1.83.0" + } ], + "features" : { + "tests" : { + "description" : "dependencies for the unit tests", + "dependencies" : [ { + "name" : "gtest", + "version>=" : "1.13.0" + } ] } } -} +} \ No newline at end of file