36 #ifndef VIGRA_NUMPY_ARRAY_HXX
37 #define VIGRA_NUMPY_ARRAY_HXX
42 #include <numpy/arrayobject.h>
43 #include "multi_array.hxx"
44 #include "array_vector.hxx"
45 #include "python_utility.hxx"
46 #include "numpy_array_traits.hxx"
47 #include "numpy_array_taggedshape.hxx"
// Initialize the NumPy C API (_import_array) and import the Python module
// "vigra.vigranumpycore". Failures are reported through pythonToCppException(),
// presumably by throwing a C++ exception carrying the pending Python error.
53 inline void import_vigranumpy()
55 if(_import_array() < 0)
56 pythonToCppException(0);
// python_ptr::keep_count adopts the new reference returned by
// PyImport_ImportModule (presumably without an extra INCREF — confirm in
// python_utility.hxx); a null result is turned into an exception below.
57 python_ptr module(PyImport_ImportModule(
"vigra.vigranumpycore"), python_ptr::keep_count);
58 pythonToCppException(module);
// Accessor that exposes the bands of a multiband pixel as vector components:
// value_type is Multiband<T>, each component is a T, and component idx lives
// idx*stride_ elements after the element the iterator points to.
68 class MultibandVectorAccessor
79 typedef Multiband<T> value_type;
83 typedef T component_type;
85 typedef VectorElementAccessor<MultibandVectorAccessor<T> > ElementAccessor;
// Read component idx of the pixel at iterator i (offset idx*stride_ from &*i).
90 template <
class ITERATOR>
91 component_type
const & getComponent(ITERATOR
const & i,
int idx)
const
93 return *(&*i+idx*stride_);
// Write component idx of the pixel at iterator i; the value is converted via
// detail::RequiresExplicitCast<component_type>::cast.
101 template <
class V,
class ITERATOR>
102 void setComponent(V
const & value, ITERATOR
const & i,
int idx)
const
104 *(&*i+idx*stride_) = detail::RequiresExplicitCast<component_type>::cast(value);
// Read component idx of the pixel at i[diff] (iterator plus offset diff).
110 template <
class ITERATOR,
class DIFFERENCE>
111 component_type
const & getComponent(ITERATOR
const & i, DIFFERENCE
const & diff,
int idx)
const
113 return *(&i[diff]+idx*stride_);
// Write component idx of the pixel at i[diff], converting value like above.
// NOTE(review): the return-type line of this overload is not visible here.
121 template <
class V,
class ITERATOR,
class DIFFERENCE>
123 setComponent(V
const & value, ITERATOR
const & i, DIFFERENCE
const & diff,
int idx)
const
125 *(&i[diff]+idx*stride_) = detail::RequiresExplicitCast<component_type>::cast(value);
// Forward declaration of constructArray(); the definition appears further
// below. arraytype defaults to an empty python_ptr, in which case the
// definition selects the array type itself.
137 template <
class TYPECODE>
140 constructArray(TaggedShape tagged_shape, TYPECODE typeCode,
bool init,
141 python_ptr arraytype = python_ptr());
// Thin static forwarders to the detail:: functions of the same names, so that
// callers can reach them through the NumpyAnyArray class scope.
170 static python_ptr getArrayTypeObject()
172 return detail::getArrayTypeObject();
// defaultValue ("C" unless given) is forwarded to detail::defaultOrder().
175 static std::string defaultOrder(std::string defaultValue =
"C")
177 return detail::defaultOrder(defaultValue);
// Axistags for an ndim-dimensional array; order defaults to "".
180 static python_ptr defaultAxistags(
int ndim, std::string order =
"")
182 return detail::defaultAxistags(ndim, order);
185 static python_ptr emptyAxistags(
int ndim)
187 return detail::emptyAxistags(ndim);
// Construct from a Python object, optionally copying it and optionally
// viewing it as the given array type. A non-null type must be numpy.ndarray
// or a subclass of it.
197 explicit NumpyAnyArray(PyObject * obj = 0,
bool createCopy =
false, PyTypeObject * type = 0)
201 vigra_precondition(type == 0 || PyType_IsSubtype(type, &PyArray_Type),
202 "NumpyAnyArray(obj, createCopy, type): type must be numpy.ndarray or a subclass thereof.");
// makeReference() must succeed, i.e. obj has to be a numpy array.
206 vigra_precondition(
makeReference(obj, type),
"NumpyAnyArray(obj): obj isn't a numpy array.");
// Same type check, apparently in a second constructor whose signature is not
// visible in this view.
218 vigra_precondition(type == 0 || PyType_IsSubtype(type, &PyArray_Type),
219 "NumpyAnyArray(obj, createCopy, type): type must be numpy.ndarray or a subclass thereof.");
// Value assignment into an array that already holds data: 'other' must
// hold data too.
242 vigra_precondition(other.
hasData(),
243 "NumpyArray::operator=(): Cannot assign from empty array.");
// Prefer the array type's _copyValuesImpl(target, source) when it exists.
// NOTE(review): PyString_FromString is Python 2 C API; Python 3 builds would
// need PyUnicode_FromString.
245 python_ptr arraytype = getArrayTypeObject();
246 python_ptr f(PyString_FromString(
"_copyValuesImpl"), python_ptr::keep_count);
247 if(PyObject_HasAttr(arraytype, f))
// Called as arraytype._copyValuesImpl(this array, other array).
249 python_ptr res(PyObject_CallMethodObjArgs(arraytype, f.get(),
250 pyArray_.get(), other.pyArray_.get(), NULL),
251 python_ptr::keep_count);
252 vigra_postcondition(res.get() != 0,
253 "NumpyArray::operator=(): VigraArray._copyValuesImpl() failed.");
257 PyArrayObject * sarray = (PyArrayObject *)pyArray_.get();
258 PyArrayObject * tarray = (PyArrayObject *)other.pyArray_.get();
260 if(PyArray_CopyInto(tarray, sarray) == -1)
261 pythonToCppException(0);
// Rebind to the same Python array object as 'other' (no element copy);
// presumably the branch taken when this array holds no data yet.
266 pyArray_ = other.pyArray_;
// Number of spatial dimensions as reported by the Python object's
// "spatialDimensions" attribute; ndim() is passed as the fallback value.
291 return pythonGetAttr(
pyObject(),
"spatialDimensions",
ndim());
294 bool hasChannelAxis()
const
298 return channelIndex() ==
ndim();
// Index of the channel axis from the "channelIndex" attribute; ndim() is the
// fallback, meaning "no channel axis".
305 return pythonGetAttr(
pyObject(),
"channelIndex",
ndim());
// Index of the innermost non-channel axis from "innerNonchannelIndex";
// ndim() is the fallback value.
312 return pythonGetAttr(
pyObject(),
"innerNonchannelIndex",
ndim());
// Looks like a selection sort: strides are ordered ascending while
// 'permutation' is swapped in lockstep so it records the original axis of
// each sorted stride.
344 if(stride[j] < stride[smallest])
349 std::swap(stride[k], stride[smallest]);
350 std::swap(permutation[k], permutation[smallest]);
// Invert the permutation: ordering[axis] = rank of that axis's stride.
355 ordering[permutation[k]] = k;
// NumPy type number of the array's dtype.
// NOTE(review): PyArray_DESCR expects a PyArrayObject*; with the non-deprecated
// NumPy API, pyObject() may need a cast to (PyArrayObject *) — confirm.
388 return PyArray_DESCR(
pyObject())->type_num;
// The "axistags" key string is created once and cached in a function-local
// static. NOTE(review): it is released during static destruction, possibly
// after the Python interpreter has been finalized.
398 static python_ptr key(PyString_FromString(
"axistags"), python_ptr::keep_count);
399 python_ptr
axistags(PyObject_GetAttr(
pyObject(), key), python_ptr::keep_count);
// Raw views of the held Python array (no reference count change).
410 return (PyArrayObject *)pyArray_.get();
419 return pyArray_.get();
// Reject null pointers and non-arrays.
432 if(obj == 0 || !PyArray_Check(obj))
// When a specific array type is requested, it must derive from ndarray and
// obj is replaced by a view of that type.
436 vigra_precondition(PyType_IsSubtype(type, &PyArray_Type) != 0,
437 "NumpyAnyArray::makeReference(obj, type): type must be numpy.ndarray or a subclass thereof.");
// NOTE(review): PyArray_View returns a new reference; the code that stores obj
// (not visible here) must adopt it without an extra INCREF.
438 obj = PyArray_View((PyArrayObject*)obj, 0, type);
439 pythonToCppException(obj);
// Deep-copy the numpy array obj (keeping its memory order, NPY_ANYORDER),
// optionally as the given ndarray subtype.
451 void makeCopy(PyObject * obj, PyTypeObject * type = 0)
453 vigra_precondition(obj && PyArray_Check(obj),
454 "NumpyAnyArray::makeCopy(obj): obj is not an array.");
455 vigra_precondition(type == 0 || PyType_IsSubtype(type, &PyArray_Type),
456 "NumpyAnyArray::makeCopy(obj, type): type must be numpy.ndarray or a subclass thereof.");
457 python_ptr array(PyArray_NewCopy((PyArrayObject*)obj, NPY_ANYORDER), python_ptr::keep_count);
458 pythonToCppException(array);
// hasData(): true when a Python array is held.
467 return pyArray_ != 0;
// Inspects permutation p element by element; presumably reports whether p
// differs from the identity (loop body not visible here — confirm).
480 nontrivialPermutation(ArrayVector<npy_intp>
const & p)
482 for(
unsigned int k=0; k<p.size(); ++k)
// Allocate a new numpy array for tagged_shape with element type typeCode.
// Steps: finalize the shape, choose the array type, allocate, transpose
// according to the axistags, attach axistags, optionally zero-fill.
490 template <
class TYPECODE>
493 constructArray(TaggedShape tagged_shape, TYPECODE typeCode,
bool init, python_ptr arraytype)
495 ArrayVector<npy_intp> shape = finalizeTaggedShape(tagged_shape);
496 PyAxisTags axistags(tagged_shape.axistags);
498 int ndim = (int)shape.size();
499 ArrayVector<npy_intp> inverse_permutation;
// The conditions selecting between the VIGRA array type and plain
// numpy.ndarray are on lines not visible in this view.
505 arraytype = NumpyAnyArray::getArrayTypeObject();
507 inverse_permutation = axistags.permutationFromNormalOrder();
508 vigra_precondition(ndim == (
int)inverse_permutation.size(),
509 "axistags.permutationFromNormalOrder(): permutation has wrong size.");
513 arraytype = python_ptr((PyObject*)&PyArray_Type);
// 'order' is defined on a line not visible here.
519 python_ptr array(PyArray_New((PyTypeObject *)arraytype.get(), ndim, shape.begin(),
520 typeCode, 0, 0, 0, order, 0),
521 python_ptr::keep_count);
522 pythonToCppException(array);
// Reorder the axes when the axistags require a non-identity permutation.
524 if(detail::nontrivialPermutation(inverse_permutation))
526 PyArray_Dims permute = { inverse_permutation.begin(), ndim };
527 array = python_ptr(PyArray_Transpose((PyArrayObject*)array.get(), &permute),
528 python_ptr::keep_count);
529 pythonToCppException(array);
// Only non-plain ndarray types receive an "axistags" attribute.
532 if(arraytype != (PyObject*)&PyArray_Type && axistags)
533 pythonToCppException(PyObject_SetAttrString(array,
"axistags", axistags.axistags) != -1);
// Zero every byte (presumably guarded by 'init' on a line not visible here).
536 PyArray_FILLWBYTE((PyArrayObject *)array.get(), 0);
// Ownership of the new reference passes to the caller (e.g. wrapped with
// python_ptr::keep_count in NumpyArray::init).
538 return array.release();
// Wrap existing memory 'data' (with byte strides 'strides') in a plain
// numpy.ndarray without copying. NumPy does not set OWNDATA for user-supplied
// data, so the caller must keep 'data' alive as long as the array is used.
// NOTE(review): NPY_WRITEABLE is the deprecated spelling of NPY_ARRAY_WRITEABLE.
542 template <
class TINY_VECTOR>
544 python_ptr constructNumpyArrayFromData(
545 TINY_VECTOR
const & shape, npy_intp *strides,
546 NPY_TYPES typeCode,
void *data)
548 ArrayVector<npy_intp> pyShape(shape.begin(), shape.end());
550 python_ptr array(PyArray_New(&PyArray_Type, shape.size(), pyShape.begin(),
551 typeCode, strides, data, 0, NPY_WRITEABLE, 0),
552 python_ptr::keep_count);
553 pythonToCppException(array);
// NumpyArray<N, T, Stride>: a MultiArrayView over the memory of a numpy array,
// with element type and compatibility rules taken from NumpyArrayTraits.
572 template <
unsigned int N,
class T,
class Str
ide = Str
idedArrayTag>
574 :
public MultiArrayView<N, typename NumpyArrayTraits<N, T, Stride>::value_type, Stride>,
578 typedef NumpyArrayTraits<N, T, Stride> ArrayTraits;
579 typedef typename ArrayTraits::dtype
dtype;
580 typedef T pseudo_value_type;
// NumPy type code for this array's dtype, taken from the traits.
582 static NPY_TYPES
const typeCode = ArrayTraits::typeCode;
588 enum { actual_dimension = view_type::actual_dimension };
// Sets the view's shape, strides and data pointer from the held Python array.
649 void setupArrayView();
// Create a new numpy array of the given shape for this NumpyArray type.
// order must be one of "", "C", "F", "V", "A"; init requests zero-filling.
// NOTE(review): the parameter 'init' shadows the function name 'init'.
651 static python_ptr init(
difference_type const & shape,
bool init =
true,
652 std::string
const & order =
"")
654 vigra_precondition(order ==
"" || order ==
"C" || order ==
"F" ||
655 order ==
"V" || order ==
"A",
656 "NumpyArray.init(): order must be in ['C', 'F', 'V', 'A', ''].");
// constructArray returns a new reference, adopted here via keep_count.
657 return python_ptr(constructArray(ArrayTraits::taggedShape(shape, order), typeCode, init),
658 python_ptr::keep_count);
// Constructor from a Python object (optionally copying it). The lines below
// are the precondition messages of several constructors, operator= and
// permuteLikewise overloads whose bodies are only partially visible here.
678 explicit NumpyArray(PyObject *obj = 0,
bool createCopy =
false)
686 "NumpyArray(obj): Cannot construct from incompatible array.");
716 "NumpyArray(view_type): Python constructor did not produce a compatible array.");
730 "NumpyArray(shape): Python constructor did not produce a compatible array.");
742 "NumpyArray(tagged_shape): Python constructor did not produce a compatible array.");
757 "NumpyArray(NumpyAnyArray): Cannot construct from incompatible or empty array.");
// Assignment from a MultiArrayView<..., U, S> with possibly different element
// type and stride tag.
785 template <
class U,
class S>
791 "NumpyArray::operator=(): shape mismatch.");
798 "NumpyArray::operator=(): reshape failed unexpectedly.");
825 vigra_precondition(
false,
826 "NumpyArray::operator=(): Cannot assign from incompatible array.");
840 "NumpyArray::permuteLikewise(): array has no data.");
851 template<
class U,
int K>
856 "NumpyArray::permuteLikewise(): array has no data.");
// Debug tracing (only under VIGRA_CONVERTER_DEBUG) of the compatibility test.
871 #if VIGRA_CONVERTER_DEBUG
872 std::cerr <<
"class " <<
typeid(
NumpyArray).name() <<
" got " << obj->ob_type->tp_name <<
"\n";
873 std::cerr <<
"using traits " <<
typeid(ArrayTraits).name() <<
"\n";
874 std::cerr<<
"isArray: "<< ArrayTraits::isArray(obj)<<std::endl;
875 std::cerr<<
"isShapeCompatible: "<< ArrayTraits::isShapeCompatible((PyArrayObject *)obj)<<std::endl;
// Compatible = is a numpy array whose shape fits this NumpyArray type.
878 return ArrayTraits::isArray(obj) &&
879 ArrayTraits::isShapeCompatible((PyArrayObject *)obj);
// Reference-compatible = is a numpy array whose properties (as judged by
// ArrayTraits::isPropertyCompatible) fit this type.
890 return ArrayTraits::isArray(obj) &&
891 ArrayTraits::isPropertyCompatible((PyArrayObject *)obj);
// Start from the identity ordering of the N axes.
910 for(
unsigned int k=0; k<N; ++k)
911 strideOrdering[k] = k;
965 "makeUnsafeReference(): cannot replace existing view with given buffer");
// Wrap the view's memory in a numpy array without copying; the view's memory
// must outlive the array. NOTE(review): no keep_count here — presumably
// unsafeConstructorFromData returns a python_ptr (as constructNumpyArrayFromData
// does); confirm, otherwise a reference would leak.
968 python_ptr array(ArrayTraits::unsafeConstructorFromData(multiArrayView.
shape(),
969 multiArrayView.
data(), multiArrayView.
stride()));
// Debug output (only under VIGRA_CONVERTER_DEBUG) of the source array's rank
// and dimensions.
983 #if VIGRA_CONVERTER_DEBUG
984 int ndim = PyArray_NDIM((PyArrayObject *)obj);
985 npy_intp * s = PyArray_DIMS((PyArrayObject *)obj);
988 std::cerr <<
"for " <<
typeid(*this).name() <<
"\n";
991 "NumpyArray::makeCopy(obj): Cannot copy an incompatible array.");
1009 "NumpyArray.reshape(shape): Python constructor did not produce a compatible array.");
// reshapeIfEmpty: finalize the requested shape; if data already exists it
// must be compatible with it, otherwise a new zero-initialized array is made.
1031 ArrayTraits::finalizeTaggedShape(tagged_shape);
1035 vigra_precondition(tagged_shape.compatible(taggedShape()), message.c_str());
1039 python_ptr array(constructArray(tagged_shape, typeCode,
true),
1040 python_ptr::keep_count);
1042 "NumpyArray.reshapeIfEmpty(): Python constructor did not produce a compatible array.");
// Current shape combined with this array's axistags.
1046 TaggedShape taggedShape()
const
1048 return ArrayTraits::taggedShape(this->
shape(), PyAxisTags(this->
axistags(),
true));
// Initialize the MultiArrayView members (shape, stride, data pointer) from the
// held numpy array, reordering axes according to the traits' setup order.
1053 template <
unsigned int N,
class T,
class Str
ide>
1054 void NumpyArray<N, T, Stride>::setupArrayView()
1058 permutation_type permute;
1059 ArrayTraits::permutationToSetupOrder(this->pyArray_, permute);
// The numpy rank may differ from the view's dimension by at most one (a
// missing channel axis is added below).
1061 vigra_precondition(
abs((
int)permute.size() - actual_dimension) <= 1,
1062 "NumpyArray::setupArrayView(): got array of incompatible shape (should never happen).");
// Copy numpy dimensions and byte strides in permuted order (the calls'
// leading lines are not visible here).
// NOTE(review): direct ->dimensions / ->strides / ->data access is deprecated
// in NumPy >= 1.7; PyArray_DIMS / PyArray_STRIDES / PyArray_DATA are preferred.
1065 pyArray()->dimensions, this->m_shape.begin());
1067 pyArray()->strides, this->m_stride.begin());
// Missing last axis: singleton extent, stride of one element (in bytes here,
// converted to elements below).
1069 if((
int)permute.size() == actual_dimension - 1)
1071 this->m_shape[actual_dimension-1] = 1;
1072 this->m_stride[actual_dimension-1] =
sizeof(value_type);
// NumPy strides are in bytes; the view stores strides in elements.
1075 this->m_stride /=
sizeof(value_type);
1076 this->m_ptr =
reinterpret_cast<pointer
>(pyArray()->data);
1077 vigra_precondition(this->checkInnerStride(Stride()),
1078 "NumpyArray<..., UnstridedArrayTag>::setupArrayView(): First dimension of given array is not unstrided (should never happen).");
1088 typedef NumpyArray<2, float > NumpyFArray2;
1089 typedef NumpyArray<3, float > NumpyFArray3;
1090 typedef NumpyArray<4, float > NumpyFArray4;
1091 typedef NumpyArray<2, Singleband<float> > NumpyFImage;
1092 typedef NumpyArray<3, Singleband<float> > NumpyFVolume;
1093 typedef NumpyArray<2, RGBValue<float> > NumpyFRGBImage;
1094 typedef NumpyArray<3, RGBValue<float> > NumpyFRGBVolume;
1095 typedef NumpyArray<3, Multiband<float> > NumpyFMultibandImage;
1096 typedef NumpyArray<4, Multiband<float> > NumpyFMultibandVolume;
1104 template <
class PixelType,
class Str
ide>
1105 inline triple<ConstStridedImageIterator<PixelType>,
1106 ConstStridedImageIterator<PixelType>,
1107 MultibandVectorAccessor<PixelType> >
1108 srcImageRange(NumpyArray<3, Multiband<PixelType>, Stride>
const & img)
1110 ConstStridedImageIterator<PixelType>
1111 ul(img.data(), 1, img.stride(0), img.stride(1));
1112 return triple<ConstStridedImageIterator<PixelType>,
1113 ConstStridedImageIterator<PixelType>,
1114 MultibandVectorAccessor<PixelType> >
1115 (ul, ul + Size2D(img.shape(0), img.shape(1)), MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
1118 template <
class PixelType,
class Str
ide>
1119 inline pair< ConstStridedImageIterator<PixelType>,
1120 MultibandVectorAccessor<PixelType> >
1121 srcImage(NumpyArray<3, Multiband<PixelType>, Stride>
const & img)
1123 ConstStridedImageIterator<PixelType>
1124 ul(img.data(), 1, img.stride(0), img.stride(1));
1125 return pair<ConstStridedImageIterator<PixelType>, MultibandVectorAccessor<PixelType> >
1126 (ul, MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
1129 template <
class PixelType,
class Str
ide>
1130 inline triple< StridedImageIterator<PixelType>,
1131 StridedImageIterator<PixelType>,
1132 MultibandVectorAccessor<PixelType> >
1133 destImageRange(NumpyArray<3, Multiband<PixelType>, Stride> & img)
1135 StridedImageIterator<PixelType>
1136 ul(img.data(), 1, img.stride(0), img.stride(1));
1137 typedef typename AccessorTraits<PixelType>::default_accessor Accessor;
1138 return triple<StridedImageIterator<PixelType>,
1139 StridedImageIterator<PixelType>,
1140 MultibandVectorAccessor<PixelType> >
1141 (ul, ul + Size2D(img.shape(0), img.shape(1)),
1142 MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
1145 template <
class PixelType,
class Str
ide>
1146 inline pair< StridedImageIterator<PixelType>,
1147 MultibandVectorAccessor<PixelType> >
1148 destImage(NumpyArray<3, Multiband<PixelType>, Stride> & img)
1150 StridedImageIterator<PixelType>
1151 ul(img.data(), 1, img.stride(0), img.stride(1));
1152 return pair<StridedImageIterator<PixelType>, MultibandVectorAccessor<PixelType> >
1153 (ul, MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
1156 template <
class PixelType,
class Str
ide>
1157 inline pair< ConstStridedImageIterator<PixelType>,
1158 MultibandVectorAccessor<PixelType> >
1159 maskImage(NumpyArray<3, Multiband<PixelType>, Stride>
const & img)
1161 ConstStridedImageIterator<PixelType>
1162 ul(img.data(), 1, img.stride(0), img.stride(1));
1163 typedef typename AccessorTraits<PixelType>::default_accessor Accessor;
1164 return pair<ConstStridedImageIterator<PixelType>, MultibandVectorAccessor<PixelType> >
1165 (ul, MultibandVectorAccessor<PixelType>(img.shape(2), img.stride(2)));
1170 #endif // VIGRA_NUMPY_ARRAY_HXX