36 #ifndef VIGRA_NUMPY_ARRAY_HXX
37 #define VIGRA_NUMPY_ARRAY_HXX
42 #include <numpy/arrayobject.h>
43 #include "multi_array.hxx"
44 #include "array_vector.hxx"
45 #include "python_utility.hxx"
46 #include "numpy_array_traits.hxx"
47 #include "numpy_array_taggedshape.hxx"
// Initializes the numpy C API (via _import_array()) and then imports the
// vigra.vigranumpycore extension module. A failure in either step is turned
// into a C++ exception by pythonToCppException().
54 static inline void import_vigranumpy()
57 if(_import_array() < 0)
58 pythonToCppException(0);
// PyImport_ImportModule returns a new reference (or NULL on error);
// keep_count presumably adopts it without adding another reference.
61 python_ptr module(PyImport_ImportModule(
"vigra.vigranumpycore"), python_ptr::keep_count);
62 pythonToCppException(module);
// Accessor that exposes the bands of a Multiband<T> array as the components
// of one "vector" pixel. The iterator addresses component 0 of a pixel;
// component idx lives at (&*i) + idx*stride_, so stride_ is counted in
// elements (pointer arithmetic). In this file it is constructed as
// MultibandVectorAccessor(img.shape(2), img.stride(2)), i.e. presumably
// (number of bands, band stride) -- constructor elided from this excerpt.
72 class MultibandVectorAccessor
83 typedef Multiband<T> value_type;
87 typedef T component_type;
89 typedef VectorElementAccessor<MultibandVectorAccessor<T> > ElementAccessor;
// Read component `idx` of the pixel at iterator position `i`.
94 template <
class ITERATOR>
95 component_type
const & getComponent(ITERATOR
const & i,
int idx)
const
97 return *(&*i+idx*stride_);
// Write `value` (converted via RequiresExplicitCast) into component `idx`
// of the pixel at iterator position `i`.
105 template <
class V,
class ITERATOR>
106 void setComponent(V
const & value, ITERATOR
const & i,
int idx)
const
108 *(&*i+idx*stride_) = detail::RequiresExplicitCast<component_type>::cast(value);
// Read component `idx` of the pixel at offset `diff` from iterator `i`.
114 template <
class ITERATOR,
class DIFFERENCE>
115 component_type
const & getComponent(ITERATOR
const & i, DIFFERENCE
const & diff,
int idx)
const
117 return *(&i[diff]+idx*stride_);
// Write `value` into component `idx` of the pixel at offset `diff` from `i`.
125 template <
class V,
class ITERATOR,
class DIFFERENCE>
127 setComponent(V
const & value, ITERATOR
const & i, DIFFERENCE
const & diff,
int idx)
const
129 *(&i[diff]+idx*stride_) = detail::RequiresExplicitCast<component_type>::cast(value);
// Forward declaration of constructArray (defined further below): creates a
// new array described by `tagged_shape` with numpy element type `typeCode`.
// `init` presumably controls zero-filling, and an empty `arraytype` lets the
// definition choose the array type. The result is handed to callers that wrap
// it with python_ptr::keep_count, i.e. it is a new reference.
141 template <
class TYPECODE>
144 constructArray(TaggedShape tagged_shape, TYPECODE typeCode,
bool init,
145 python_ptr arraytype = python_ptr());
// Static wrappers that forward to the detail:: helpers of the same name.
// Type object used for newly created arrays
// (NOTE(review): presumably vigra.VigraArray when available -- confirm in detail::getArrayTypeObject).
174 static python_ptr getArrayTypeObject()
176 return detail::getArrayTypeObject();
// Default memory-order string; `defaultValue` ("C") is forwarded unchanged.
179 static std::string defaultOrder(std::string defaultValue =
"C")
181 return detail::defaultOrder(defaultValue);
// Default axistags for an array with `ndim` axes and the given `order`.
184 static python_ptr defaultAxistags(
int ndim, std::string order =
"")
186 return detail::defaultAxistags(ndim, order);
// Axistags object for `ndim` axes (presumably without axis-type information).
189 static python_ptr emptyAxistags(
int ndim)
191 return detail::emptyAxistags(ndim);
// Wraps `obj`, either by reference or, when createCopy is true, as a copy
// (branching elided from this excerpt). If `type` is given it must be
// numpy.ndarray or a subclass thereof.
201 explicit NumpyAnyArray(PyObject * obj = 0,
bool createCopy =
false, PyTypeObject * type = 0)
205 vigra_precondition(type == 0 || PyType_IsSubtype(type, &PyArray_Type),
206 "NumpyAnyArray(obj, createCopy, type): type must be numpy.ndarray or a subclass thereof.");
// Reference path (presumably createCopy == false): makeReference() returns
// false for non-arrays, which becomes a precondition failure here.
210 vigra_precondition(
makeReference(obj, type),
"NumpyAnyArray(obj): obj isn't a numpy array.");
// Second constructor (header elided): same validation of `type`.
222 vigra_precondition(type == 0 || PyType_IsSubtype(type, &PyArray_Type),
223 "NumpyAnyArray(obj, createCopy, type): type must be numpy.ndarray or a subclass thereof.");
// NumpyAnyArray assignment (function header and branch lines elided from this
// excerpt). The source `other` must hold data. Presumably, when *this already
// holds data, other's values are copied into the existing array; otherwise
// *this simply shares other's array object (last statement below).
246 vigra_precondition(other.
hasData(),
247 "NumpyArray::operator=(): Cannot assign from empty array.");
// Preferred path: the array type's _copyValuesImpl hook, called with
// (pyArray_, other.pyArray_), i.e. (target, source).
249 python_ptr arraytype = getArrayTypeObject();
250 python_ptr f(PyString_FromString(
"_copyValuesImpl"), python_ptr::keep_count);
251 if(PyObject_HasAttr(arraytype, f))
253 python_ptr res(PyObject_CallMethodObjArgs(arraytype, f.get(),
254 pyArray_.get(), other.pyArray_.get(), NULL),
255 python_ptr::keep_count);
256 vigra_postcondition(res.get() != 0,
257 "NumpyArray::operator=(): VigraArray._copyValuesImpl() failed.");
// Fallback: plain numpy copy. PyArray_CopyInto(dest, src) writes src into
// dest, so *this must be the destination and `other` the source. The former
// argument order overwrote the const right-hand side with *this's values.
261 PyArrayObject * target = (PyArrayObject *)pyArray_.get();
262 PyArrayObject * source = (PyArrayObject *)other.pyArray_.get();
264 if(PyArray_CopyInto(target, source) == -1)
265 pythonToCppException(0);
// No data yet (presumably): share other's array object instead of copying.
270 pyArray_ = other.pyArray_;
// Number of spatial axes, read from the object's `spatialDimensions`
// attribute. ndim() is the fallback value passed to pythonGetAttr, meaning
// "all axes are spatial" when the object does not provide the attribute.
295 return pythonGetAttr(
pyObject(),
"spatialDimensions",
ndim());
// Returns true if the array has an explicit channel axis.
// channelIndex() falls back to ndim() (as does spatialDimensions(), meaning
// "all axes spatial"), so ndim() is the "no channel axis" value. A channel
// axis therefore exists exactly when channelIndex() != ndim(). The former
// '==' reported a channel axis precisely when there was none.
298 bool hasChannelAxis()
const
302 return channelIndex() !=
ndim();
// channelIndex(): index of the channel axis from the `channelIndex`
// attribute. The fallback ndim() acts as the "no channel axis" value,
// consistent with spatialDimensions() falling back to ndim().
309 return pythonGetAttr(
pyObject(),
"channelIndex",
ndim());
// innerNonchannelIndex(): index from the `innerNonchannelIndex` attribute
// (presumably the innermost non-channel axis); the fallback is ndim().
316 return pythonGetAttr(
pyObject(),
"innerNonchannelIndex",
ndim());
// Fragment of what appears to be a selection sort over the strides. It finds
// the remaining axis with the smallest stride and swaps it, together with its
// axis index in `permutation`, into position k. It then inverts the
// permutation so that ordering[axis] is that axis's stride rank
// (0 = smallest stride).
348 if(stride[j] < stride[smallest])
353 std::swap(stride[k], stride[smallest]);
354 std::swap(permutation[k], permutation[smallest]);
359 ordering[permutation[k]] = k;
// dtype(): numpy type number (type_num) of the array's descriptor.
392 return PyArray_DESCR(
pyObject())->type_num;
// axistags(): reads the object's `axistags` attribute. The attribute-name
// string is created once (function-local static) and reused.
// NOTE(review): a static python_ptr is released during static destruction,
// possibly after interpreter shutdown -- confirm this is acceptable.
402 static python_ptr key(PyString_FromString(
"axistags"), python_ptr::keep_count);
407 axistags.reset(PyObject_GetAttr(
pyObject(), key), python_ptr::keep_count);
// Raw access to the wrapped object, as PyArrayObject* and as PyObject*.
419 return (PyArrayObject *)pyArray_.get();
428 return pyArray_.get();
// makeReference fragment: null and non-array objects are rejected. When a
// `type` is requested, the array is re-viewed as that numpy.ndarray subclass.
441 if(obj == 0 || !PyArray_Check(obj))
445 vigra_precondition(PyType_IsSubtype(type, &PyArray_Type) != 0,
446 "NumpyAnyArray::makeReference(obj, type): type must be numpy.ndarray or a subclass thereof.");
// PyArray_View returns a new reference (NULL on error, converted to an
// exception on the next line).
// NOTE(review): confirm the subsequent pyArray_.reset(obj) (elided here)
// adopts this reference instead of adding another one.
447 obj = PyArray_View((PyArrayObject*)obj, 0, type);
448 pythonToCppException(obj);
// Deep-copies `obj`, which must be a numpy array. An optional `type` must be
// numpy.ndarray or a subclass.
460 void makeCopy(PyObject * obj, PyTypeObject * type = 0)
462 vigra_precondition(obj && PyArray_Check(obj),
463 "NumpyAnyArray::makeCopy(obj): obj is not an array.");
464 vigra_precondition(type == 0 || PyType_IsSubtype(type, &PyArray_Type),
465 "NumpyAnyArray::makeCopy(obj, type): type must be numpy.ndarray or a subclass thereof.");
// NPY_ANYORDER: the copy keeps Fortran order if the input is
// Fortran-contiguous, and uses C order otherwise.
466 python_ptr array(PyArray_NewCopy((PyArrayObject*)obj, NPY_ANYORDER), python_ptr::keep_count);
467 pythonToCppException(array);
// hasData(): true iff a Python array object is currently wrapped.
476 return pyArray_ != 0;
// Loops over permutation `p`; presumably returns true when some p[k] != k,
// i.e. when p is not the identity (comparison elided). constructArray uses
// it to skip the transpose for identity permutations.
489 nontrivialPermutation(ArrayVector<npy_intp>
const & p)
491 for(
unsigned int k=0; k<p.size(); ++k)
// Definition of constructArray: builds a new array object and returns it
// as a new reference (array.release()); callers wrap it with keep_count.
// Branch conditions are elided from this excerpt.
499 template <
class TYPECODE>
502 constructArray(TaggedShape tagged_shape, TYPECODE typeCode,
bool init, python_ptr arraytype)
504 ArrayVector<npy_intp> shape = finalizeTaggedShape(tagged_shape);
505 PyAxisTags axistags(tagged_shape.axistags);
507 int ndim = (int)shape.size();
508 ArrayVector<npy_intp> inverse_permutation;
// Presumably the "has axistags" branch: default the array type and compute
// the axis permutation that the axistags require.
514 arraytype = NumpyAnyArray::getArrayTypeObject();
516 inverse_permutation = axistags.permutationFromNormalOrder();
517 vigra_precondition(ndim == (
int)inverse_permutation.size(),
518 "axistags.permutationFromNormalOrder(): permutation has wrong size.");
// Presumably the "no axistags" branch: create a plain numpy.ndarray.
522 arraytype = python_ptr((PyObject*)&PyArray_Type);
// `order` (set in elided lines) is PyArray_New's flags argument; for
// data == NULL, nonzero requests Fortran-contiguous memory.
528 python_ptr array(PyArray_New((PyTypeObject *)arraytype.get(), ndim, shape.begin(),
529 typeCode, 0, 0, 0, order, 0),
530 python_ptr::keep_count);
531 pythonToCppException(array);
// Transpose only when the permutation is not the identity.
533 if(detail::nontrivialPermutation(inverse_permutation))
535 PyArray_Dims permute = { inverse_permutation.begin(), ndim };
536 array = python_ptr(PyArray_Transpose((PyArrayObject*)array.get(), &permute),
537 python_ptr::keep_count);
538 pythonToCppException(array);
// Attach axistags only to non-ndarray results; plain numpy.ndarray instances
// do not accept arbitrary attributes.
541 if(arraytype != (PyObject*)&PyArray_Type && axistags)
542 pythonToCppException(PyObject_SetAttrString(array,
"axistags", axistags.axistags) != -1);
// Zero the data bytes (presumably guarded by `init` in an elided line).
545 PyArray_FILLWBYTE((PyArrayObject *)array.get(), 0);
547 return array.release();
// Wraps existing memory `data` in a writeable plain numpy.ndarray without
// copying. `strides` follow numpy's convention (bytes). Because `data` is
// supplied, numpy does not set NPY_OWNDATA, so the caller must keep the
// buffer alive while the array is in use.
551 template <
class TINY_VECTOR>
553 python_ptr constructNumpyArrayFromData(
554 TINY_VECTOR
const & shape, npy_intp *strides,
555 NPY_TYPES typeCode,
void *data)
557 ArrayVector<npy_intp> pyShape(shape.begin(), shape.end());
559 python_ptr array(PyArray_New(&PyArray_Type, shape.size(), pyShape.begin(),
560 typeCode, strides, data, 0, NPY_WRITEABLE, 0),
561 python_ptr::keep_count);
562 pythonToCppException(array);
// NumpyArray<N, T, Stride>: a MultiArrayView onto the memory of a numpy
// array. NumpyArrayTraits<N, T, Stride> supplies the view's value_type, the
// dtype and the numpy type code. (Further base classes and most members are
// elided from this excerpt.)
581 template <
unsigned int N,
class T,
class Str
ide = Str
idedArrayTag>
583 :
public MultiArrayView<N, typename NumpyArrayTraits<N, T, Stride>::value_type, Stride>,
587 typedef NumpyArrayTraits<N, T, Stride> ArrayTraits;
588 typedef typename ArrayTraits::dtype
dtype;
// The user-supplied element tag (e.g. Multiband<float>), as opposed to the
// view's value_type.
589 typedef T pseudo_value_type;
// numpy type code used when this class creates arrays.
591 static NPY_TYPES
const typeCode = ArrayTraits::typeCode;
597 enum { actual_dimension = view_type::actual_dimension };
// Fills the MultiArrayView members from pyArray_ (defined after the class).
658 void setupArrayView();
// Creates a new array of the given shape. `order` must be '' or one of
// C, F, V, A. `init` is forwarded to constructArray. The new reference is
// returned as an owning python_ptr.
660 static python_ptr init(
difference_type const & shape,
bool init =
true,
661 std::string
const & order =
"")
663 vigra_precondition(order ==
"" || order ==
"C" || order ==
"F" ||
664 order ==
"V" || order ==
"A",
665 "NumpyArray.init(): order must be in ['C', 'F', 'V', 'A', ''].");
666 return python_ptr(constructArray(ArrayTraits::taggedShape(shape, order), typeCode, init),
667 python_ptr::keep_count);
// Constructor fragments (bodies elided). The messages below are the
// precondition failures raised when the source -- a PyObject, MultiArrayView,
// shape, tagged_shape or NumpyAnyArray -- cannot yield a compatible array.
687 explicit NumpyArray(PyObject *obj = 0,
bool createCopy =
false)
695 "NumpyArray(obj): Cannot construct from incompatible array.");
720 template <
class U,
class S>
726 "NumpyArray(MultiArrayView): Python constructor did not produce a compatible array.");
740 "NumpyArray(shape): Python constructor did not produce a compatible array.");
752 "NumpyArray(tagged_shape): Python constructor did not produce a compatible array.");
767 "NumpyArray(NumpyAnyArray): Cannot construct from incompatible or empty array.");
// Assignment-operator fragments (bodies elided). The messages indicate:
// - assigning from a MultiArrayView<N, U, S> rejects a shape mismatch;
// - a reshape path exists that should not fail;
// - assigning from an incompatible array is rejected outright.
795 template <
class U,
class S>
801 "NumpyArray::operator=(): shape mismatch.");
808 "NumpyArray::operator=(): reshape failed unexpectedly.");
821 template <
class U,
class S>
827 "NumpyArray::operator=(): shape mismatch.");
834 "NumpyArray::operator=(): reshape failed unexpectedly.");
861 vigra_precondition(
false,
862 "NumpyArray::operator=(): Cannot assign from incompatible array.");
// permuteLikewise overloads (bodies partly elided): each requires data and
// delegates to ArrayTraits::permuteLikewise to reorder `data` into `res`
// like this array's axes.
876 "NumpyArray::permuteLikewise(): array has no data.");
879 ArrayTraits::permuteLikewise(this->pyArray_, data, res);
887 template <
class U,
int K>
892 "NumpyArray::permuteLikewise(): array has no data.");
895 ArrayTraits::permuteLikewise(this->pyArray_, data, res);
908 "NumpyArray::permuteLikewise(): array has no data.");
912 ArrayTraits::permuteLikewise(this->pyArray_, data, res);
// Compatibility predicates: `obj` must be a numpy array (ArrayTraits::isArray)
// and shape-compatible with this NumpyArray type. Diagnostic output is
// printed only when VIGRA_CONVERTER_DEBUG is enabled.
924 #if VIGRA_CONVERTER_DEBUG
925 std::cerr <<
"class " <<
typeid(
NumpyArray).name() <<
" got " << obj->ob_type->tp_name <<
"\n";
926 std::cerr <<
"using traits " <<
typeid(ArrayTraits).name() <<
"\n";
927 std::cerr<<
"isArray: "<< ArrayTraits::isArray(obj)<<std::endl;
928 std::cerr<<
"isShapeCompatible: "<< ArrayTraits::isShapeCompatible((PyArrayObject *)obj)<<std::endl;
931 return ArrayTraits::isArray(obj) &&
932 ArrayTraits::isShapeCompatible((PyArrayObject *)obj);
// Stricter variant: requires ArrayTraits::isPropertyCompatible instead.
943 return ArrayTraits::isArray(obj) &&
944 ArrayTraits::isPropertyCompatible((PyArrayObject *)obj);
// Start from the identity ordering (axis k -> k).
963 for(
unsigned int k=0; k<N; ++k)
964 strideOrdering[k] = k;
// makeUnsafeReference fragment: only allowed while this NumpyArray is empty.
// It wraps the MultiArrayView's buffer (data pointer, shape, stride) via
// ArrayTraits::unsafeConstructorFromData -- presumably without copying, so
// the view's memory must outlive the resulting array.
1017 vigra_precondition(!
hasData(),
1018 "makeUnsafeReference(): cannot replace existing view with given buffer");
1021 python_ptr array(ArrayTraits::unsafeConstructorFromData(multiArrayView.
shape(),
1022 multiArrayView.
data(), multiArrayView.
stride()));
// NumpyArray::makeCopy fragment: debug output under VIGRA_CONVERTER_DEBUG;
// copying an incompatible array is rejected.
1036 #if VIGRA_CONVERTER_DEBUG
1037 int ndim = PyArray_NDIM((PyArrayObject *)obj);
1038 npy_intp * s = PyArray_DIMS((PyArrayObject *)obj);
1041 std::cerr <<
"for " <<
typeid(*this).name() <<
"\n";
1044 "NumpyArray::makeCopy(obj): Cannot copy an incompatible array.");
// reshape(): raised when the newly created array is not compatible.
1062 "NumpyArray.reshape(shape): Python constructor did not produce a compatible array.");
// reshapeIfEmpty fragment: finalize the requested tagged shape. Presumably,
// if data already exists, the shape must be compatible with taggedShape();
// otherwise a new array is created with init = true.
1084 ArrayTraits::finalizeTaggedShape(tagged_shape);
1088 vigra_precondition(tagged_shape.compatible(taggedShape()), message.c_str());
1092 python_ptr array(constructArray(tagged_shape, typeCode,
true),
1093 python_ptr::keep_count);
1095 "NumpyArray.reshapeIfEmpty(): Python constructor did not produce a compatible array.");
// TaggedShape built from this array's current shape and axistags.
1099 TaggedShape taggedShape()
const
1101 return ArrayTraits::taggedShape(this->
shape(), PyAxisTags(this->
axistags(),
true));
// Copies shape and strides from pyArray_ into the MultiArrayView members,
// in the axis order given by ArrayTraits::permutationToSetupOrder. It then
// converts strides from bytes to elements and sets the data pointer.
1106 template <
unsigned int N,
class T,
class Str
ide>
1107 void NumpyArray<N, T, Stride>::setupArrayView()
1111 permutation_type permute;
1112 ArrayTraits::permutationToSetupOrder(this->pyArray_, permute);
// The permutation may differ from the view's dimension by at most one.
// NOTE(review): permute.size() == actual_dimension + 1 also passes; its
// handling is not visible in this excerpt.
1114 vigra_precondition(
abs((
int)permute.size() - actual_dimension) <= 1,
1115 "NumpyArray::setupArrayView(): got array of incompatible shape (should never happen).");
1118 pyArray()->dimensions, this->m_shape.begin());
1120 pyArray()->strides, this->m_stride.begin());
// One axis short (presumably a missing singleton channel axis): give it
// extent 1 and a byte stride of one element, which becomes 1 after the
// division below.
1122 if((
int)permute.size() == actual_dimension - 1)
1124 this->m_shape[actual_dimension-1] = 1;
1125 this->m_stride[actual_dimension-1] =
sizeof(value_type);
// numpy strides are in bytes; MultiArrayView strides are in elements.
1128 this->m_stride /=
sizeof(value_type);
1129 this->m_ptr =
reinterpret_cast<pointer
>(pyArray()->data);
1130 vigra_precondition(this->checkInnerStride(Stride()),
1131 "NumpyArray<..., UnstridedArrayTag>::setupArrayView(): First dimension of given array is not unstrided (should never happen).");
// Convenience aliases for float NumpyArrays: plain 2-4D arrays, single-band
// images/volumes, RGB images/volumes, and multiband images/volumes. The
// multiband forms carry one extra axis for the bands.
1141 typedef NumpyArray<2, float > NumpyFArray2;
1142 typedef NumpyArray<3, float > NumpyFArray3;
1143 typedef NumpyArray<4, float > NumpyFArray4;
1144 typedef NumpyArray<2, Singleband<float> > NumpyFImage;
1145 typedef NumpyArray<3, Singleband<float> > NumpyFVolume;
1146 typedef NumpyArray<2, RGBValue<float> > NumpyFRGBImage;
1147 typedef NumpyArray<3, RGBValue<float> > NumpyFRGBVolume;
1148 typedef NumpyArray<3, Multiband<float> > NumpyFMultibandImage;
1149 typedef NumpyArray<4, Multiband<float> > NumpyFMultibandVolume;
// Adapts a 3D Multiband NumpyArray to VIGRA's read-only image-range
// interface. Axes 0 and 1 are the image x/y extents; axis 2 holds the bands,
// exposed per pixel through MultibandVectorAccessor(shape(2), stride(2)).
// Returns (upper-left, lower-right, accessor).
1157 template <
class PixelType,
class Str
ide>
1158 inline triple<ConstStridedImageIterator<PixelType>,
1159 ConstStridedImageIterator<PixelType>,
1160 MultibandVectorAccessor<PixelType> >
1161 srcImageRange(NumpyArray<3, Multiband<PixelType>, Stride>
const & img)
1163 ConstStridedImageIterator<PixelType>
1164 ul(img.data(), 1, img.stride(0), img.stride(1));
1165 return triple<ConstStridedImageIterator<PixelType>,
1166 ConstStridedImageIterator<PixelType>,
1167 MultibandVectorAccessor<PixelType> >
1168 (ul, ul + Size2D(img.shape(0), img.shape(1)), MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
// Read-only (upper-left, accessor) pair for a 3D Multiband NumpyArray.
// Axes 0/1 are x/y; axis 2 holds the bands, accessed through
// MultibandVectorAccessor(shape(2), stride(2)).
1171 template <
class PixelType,
class Str
ide>
1172 inline pair< ConstStridedImageIterator<PixelType>,
1173 MultibandVectorAccessor<PixelType> >
1174 srcImage(NumpyArray<3, Multiband<PixelType>, Stride>
const & img)
1176 ConstStridedImageIterator<PixelType>
1177 ul(img.data(), 1, img.stride(0), img.stride(1));
1178 return pair<ConstStridedImageIterator<PixelType>, MultibandVectorAccessor<PixelType> >
1179 (ul, MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
// Writable image-range adapter for a 3D Multiband NumpyArray. Axes 0/1 are
// x/y; axis 2 holds the bands, accessed through
// MultibandVectorAccessor(shape(2), stride(2)).
// Returns (upper-left, lower-right, accessor). The previously declared but
// never used local `typedef ... Accessor` has been removed.
1182 template <
class PixelType,
class Str
ide>
1183 inline triple< StridedImageIterator<PixelType>,
1184 StridedImageIterator<PixelType>,
1185 MultibandVectorAccessor<PixelType> >
1186 destImageRange(NumpyArray<3, Multiband<PixelType>, Stride> & img)
1188 StridedImageIterator<PixelType>
1189 ul(img.data(), 1, img.stride(0), img.stride(1));
1191 return triple<StridedImageIterator<PixelType>,
1192 StridedImageIterator<PixelType>,
1193 MultibandVectorAccessor<PixelType> >
1194 (ul, ul + Size2D(img.shape(0), img.shape(1)),
1195 MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
// Writable (upper-left, accessor) pair for a 3D Multiband NumpyArray.
// Axes 0/1 are x/y; axis 2 holds the bands, accessed through
// MultibandVectorAccessor(shape(2), stride(2)).
1198 template <
class PixelType,
class Str
ide>
1199 inline pair< StridedImageIterator<PixelType>,
1200 MultibandVectorAccessor<PixelType> >
1201 destImage(NumpyArray<3, Multiband<PixelType>, Stride> & img)
1203 StridedImageIterator<PixelType>
1204 ul(img.data(), 1, img.stride(0), img.stride(1));
1205 return pair<StridedImageIterator<PixelType>, MultibandVectorAccessor<PixelType> >
1206 (ul, MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
// Read-only (upper-left, accessor) pair for using a 3D Multiband NumpyArray
// as a mask. Axes 0/1 are x/y; axis 2 holds the bands, accessed through
// MultibandVectorAccessor(shape(2), stride(2)). The previously declared but
// never used local `typedef ... Accessor` has been removed.
1209 template <
class PixelType,
class Str
ide>
1210 inline pair< ConstStridedImageIterator<PixelType>,
1211 MultibandVectorAccessor<PixelType> >
1212 maskImage(NumpyArray<3, Multiband<PixelType>, Stride>
const & img)
1214 ConstStridedImageIterator<PixelType>
1215 ul(img.data(), 1, img.stride(0), img.stride(1));
1217 return pair<ConstStridedImageIterator<PixelType>, MultibandVectorAccessor<PixelType> >
1218 (ul, MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
1223 #endif // VIGRA_NUMPY_ARRAY_HXX