@@ -23,6 +23,87 @@ namespace pybind11
2323{
2424 namespace detail
2525 {
26+ template <typename T, xt::layout_type L>
27+ struct xtensor_get_buffer
28+ {
29+ template <typename H>
30+ static auto get (H src)
31+ {
32+ return array_t <T, array::c_style | array::forcecast>::ensure (src);
33+ }
34+ };
35+
36+ template <typename T>
37+ struct xtensor_get_buffer <T, xt::layout_type::column_major>
38+ {
39+ template <typename H>
40+ static auto get (H src)
41+ {
42+ return array_t <T, array::f_style>::ensure (src);
43+ }
44+ };
45+
46+ template <class T >
47+ struct xtensor_check_buffer
48+ {
49+ };
50+
51+ template <class T , xt::layout_type L>
52+ struct xtensor_check_buffer <xt::xarray<T, L>>
53+ {
54+ template <typename H>
55+ static auto get (H src)
56+ {
57+ auto buf = xtensor_get_buffer<T, L>::get (src);
58+ return buf;
59+ }
60+ };
61+
62+ template <class T , std::size_t N, xt::layout_type L>
63+ struct xtensor_check_buffer <xt::xtensor<T, N, L>>
64+ {
65+ template <typename H>
66+ static auto get (H src)
67+ {
68+ auto buf = xtensor_get_buffer<T, L>::get (src);
69+ if (buf.ndim () != N) {
70+ return false ;
71+ }
72+ return buf;
73+ }
74+ };
75+
76+ template <class CT , class S , xt::layout_type L, class FST >
77+ struct xtensor_check_buffer <xt::xstrided_view<CT, S, L, FST>>
78+ {
79+ template <typename H>
80+ static auto get (H /* src*/ )
81+ {
82+ return false ;
83+ }
84+ };
85+
86+ template <class EC , xt::layout_type L, class SC , class Tag >
87+ struct xtensor_check_buffer <xt::xarray_adaptor<EC, L, SC, Tag>>
88+ {
89+ template <typename H>
90+ static auto get (H /* src*/ )
91+ {
92+ return false ;
93+ }
94+ };
95+
96+ template <class EC , std::size_t N, xt::layout_type L, class Tag >
97+ struct xtensor_check_buffer <xt::xtensor_adaptor<EC, N, L, Tag>>
98+ {
99+ template <typename H>
100+ static auto get (H /* src*/ )
101+ {
102+ return false ;
103+ }
104+ };
105+
106+
26107 // Casts a strided expression type to numpy array.If given a base,
27108 // the numpy array references the src data, otherwise it'll make a copy.
28109 // The writeable attributes lets you specify writeable flag for the array.
@@ -74,10 +155,6 @@ namespace pybind11
74155 template <class Type >
75156 struct xtensor_type_caster_base
76157 {
77- bool load (handle /* src*/ , bool )
78- {
79- return false ;
80- }
81158
82159 private:
83160
@@ -106,6 +183,30 @@ namespace pybind11
106183
107184 public:
108185
186+ PYBIND11_TYPE_CASTER (Type, _(" numpy.ndarray[" ) + npy_format_descriptor<typename Type::value_type>::name + _(" ]" ));
187+
188+ bool load (handle src, bool convert)
189+ {
190+ using T = typename Type::value_type;
191+
192+ if (!convert && !array_t <T>::check_ (src)) {
193+ return false ;
194+ }
195+
196+ auto buf = xtensor_check_buffer<Type>::get (src);
197+
198+ if (!buf) {
199+ return false ;
200+ }
201+
202+ std::vector<size_t > shape (buf.ndim ());
203+ std::copy (buf.shape (), buf.shape () + buf.ndim (), shape.begin ());
204+ value = Type (shape);
205+ std::copy (buf.data (), buf.data () + buf.size (), value.begin ());
206+
207+ return true ;
208+ }
209+
109210 // Normal returned non-reference, non-const value:
110211 static handle cast (Type&& src, return_value_policy /* policy */ , handle parent)
111212 {
@@ -151,18 +252,6 @@ namespace pybind11
151252 {
152253 return cast_impl (src, policy, parent);
153254 }
154-
155- #ifdef PYBIND11_DESCR // The macro is removed from pybind11 since 2.3
156- static PYBIND11_DESCR name ()
157- {
158- return _ (" xt::xtensor" );
159- }
160- #else
161- static constexpr auto name = _(" xt::xtensor" );
162- #endif
163-
164- template <typename T>
165- using cast_op_type = cast_op_type<T>;
166255 };
167256 }
168257}
0 commit comments