Forráskód Böngészése

replace python lists with numpy arrays and avoid copy of input data

Christos 8 éve
szülő
commit
ac61b3fabe
1 módosított fájl, 78 hozzáadás és 74 törlés
  1. 78 74
      matplotlibcpp.h

+ 78 - 74
matplotlibcpp.h

@@ -3,6 +3,7 @@
 #include <vector>
 #include <map>
 #include <numeric>
+#include <algorithm>
 #include <stdexcept>
 #include <iostream>
 
@@ -23,6 +24,8 @@
 #define PyString_FromString PyUnicode_FromString
 #endif
 
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+#include <numpy/arrayobject.h>
 
 namespace matplotlibcpp {
 
@@ -72,6 +75,8 @@ namespace matplotlibcpp {
 				Py_SetProgramName(name);
 				Py_Initialize();
 
+				import_array(); // initialize numpy C-API
+
 				PyObject* pyplotname = PyString_FromString("matplotlib.pyplot");
 				PyObject* pylabname  = PyString_FromString("pylab");
 				if(!pyplotname || !pylabname) { throw std::runtime_error("couldnt create string"); }
@@ -177,25 +182,52 @@ namespace matplotlibcpp {
 
 		return res;
 	}
+	// Type selector for numpy array conversion
+	template <typename T> struct select_npy_type { const static NPY_TYPES type = NPY_NOTYPE; }; //Default
+	template <> struct select_npy_type<double> { const static NPY_TYPES type = NPY_DOUBLE; };
+	template <> struct select_npy_type<float> { const static NPY_TYPES type = NPY_FLOAT; };
+	template <> struct select_npy_type<bool> { const static NPY_TYPES type = NPY_BOOL; };
+	template <> struct select_npy_type<std::int8_t> { const static NPY_TYPES type = NPY_INT8; };
+	template <> struct select_npy_type<std::int16_t> { const static NPY_TYPES type = NPY_SHORT; };
+	template <> struct select_npy_type<std::int32_t> { const static NPY_TYPES type = NPY_INT; };
+	template <> struct select_npy_type<std::int64_t> { const static NPY_TYPES type = NPY_INT64; };
+	template <> struct select_npy_type<std::uint8_t> { const static NPY_TYPES type = NPY_UINT8; };
+	template <> struct select_npy_type<std::uint16_t> { const static NPY_TYPES type = NPY_USHORT; };
+	template <> struct select_npy_type<std::uint32_t> { const static NPY_TYPES type = NPY_ULONG; };
+	template <> struct select_npy_type<std::uint64_t> { const static NPY_TYPES type = NPY_UINT64; };
+
+	template<typename Numeric>
+	PyObject* get_array(const std::vector<Numeric>& v)
+	{
+		detail::_interpreter::get();	//interpreter needs to be initialized for the numpy commands to work
+		NPY_TYPES type = select_npy_type<Numeric>::type; 
+		if (type == NPY_NOTYPE)
+		{
+				std::vector<double> vd(v.size());
+				npy_intp vsize = v.size();
+				std::copy(v.begin(),v.end(),vd.begin());
+				PyObject* varray = PyArray_SimpleNewFromData(1, &vsize, NPY_DOUBLE, (void*)(vd.data()));
+				return varray;
+		}
+
+		npy_intp vsize = v.size();
+		PyObject* varray = PyArray_SimpleNewFromData(1, &vsize, type, (void*)(v.data()));
+		return varray;
+	}
 
 	template<typename Numeric>
 	bool plot(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const std::map<std::string, std::string>& keywords)
 	{
 		assert(x.size() == y.size());
 
-		// using python lists
-		PyObject* xlist = PyList_New(x.size());
-		PyObject* ylist = PyList_New(y.size());
-
-		for(size_t i = 0; i < x.size(); ++i) {
-			PyList_SetItem(xlist, i, PyFloat_FromDouble(x.at(i)));
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
+		// using numpy arrays
+		PyObject* xarray = get_array(x);
+		PyObject* yarray = get_array(y);
 
 		// construct positional args
 		PyObject* args = PyTuple_New(2);
-		PyTuple_SetItem(args, 0, xlist);
-		PyTuple_SetItem(args, 1, ylist);
+		PyTuple_SetItem(args, 0, xarray);
+		PyTuple_SetItem(args, 1, yarray);
 
 		// construct keyword args
 		PyObject* kwargs = PyDict_New();
@@ -219,22 +251,16 @@ namespace matplotlibcpp {
 		assert(x.size() == y1.size());
 		assert(x.size() == y2.size());
 
-		// using python lists
-		PyObject* xlist = PyList_New(x.size());
-		PyObject* y1list = PyList_New(y1.size());
-		PyObject* y2list = PyList_New(y2.size());
-
-		for(size_t i = 0; i < x.size(); ++i) {
-			PyList_SetItem(xlist, i, PyFloat_FromDouble(x.at(i)));
-			PyList_SetItem(y1list, i, PyFloat_FromDouble(y1.at(i)));
-			PyList_SetItem(y2list, i, PyFloat_FromDouble(y2.at(i)));
-		}
+		// using numpy arrays
+		PyObject* xarray = get_array(x);
+		PyObject* y1array = get_array(y1);
+		PyObject* y2array = get_array(y2);
 
 		// construct positional args
 		PyObject* args = PyTuple_New(3);
-		PyTuple_SetItem(args, 0, xlist);
-		PyTuple_SetItem(args, 1, y1list);
-		PyTuple_SetItem(args, 2, y2list);
+		PyTuple_SetItem(args, 0, xarray);
+		PyTuple_SetItem(args, 1, y1array);
+		PyTuple_SetItem(args, 2, y2array);
 
 		// construct keyword args
 		PyObject* kwargs = PyDict_New();
@@ -255,20 +281,18 @@ namespace matplotlibcpp {
 	template< typename Numeric>
 	bool hist(const std::vector<Numeric>& y, long bins=10,std::string color="b", double alpha=1.0)
 	{
-		PyObject* ylist = PyList_New(y.size());
-		
+
+		PyObject* yarray = get_array(y);
+
 		PyObject* kwargs = PyDict_New();
 		PyDict_SetItemString(kwargs, "bins", PyLong_FromLong(bins));
 		PyDict_SetItemString(kwargs, "color", PyString_FromString(color.c_str()));
 		PyDict_SetItemString(kwargs, "alpha", PyFloat_FromDouble(alpha));
 		
-		for(size_t i = 0; i < y.size(); ++i) {
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
 
 		PyObject* plot_args = PyTuple_New(1);
 
-		PyTuple_SetItem(plot_args, 0, ylist);
+		PyTuple_SetItem(plot_args, 0, yarray);
 
 
 		PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_hist, plot_args, kwargs);
@@ -284,19 +308,17 @@ namespace matplotlibcpp {
 	template< typename Numeric>
 	bool named_hist(std::string label,const std::vector<Numeric>& y, long bins=10, std::string color="b", double alpha=1.0)
 	{
-		PyObject* ylist = PyList_New(y.size());
+		PyObject* yarray = get_array(y);
+
 		PyObject* kwargs = PyDict_New();
 		PyDict_SetItemString(kwargs, "label", PyString_FromString(label.c_str()));
 		PyDict_SetItemString(kwargs, "bins", PyLong_FromLong(bins));
-		PyDict_SetItemString(kwargs, "color", PyString_FromString(color.c_str()));  
+		PyDict_SetItemString(kwargs, "color", PyString_FromString(color.c_str()));
 		PyDict_SetItemString(kwargs, "alpha", PyFloat_FromDouble(alpha));
-		
-		for(size_t i = 0; i < y.size(); ++i) {
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
+
 
 		PyObject* plot_args = PyTuple_New(1);
-		PyTuple_SetItem(plot_args, 0, ylist);
+		PyTuple_SetItem(plot_args, 0, yarray);
 
 		PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_hist, plot_args, kwargs);
 
@@ -312,18 +334,14 @@ namespace matplotlibcpp {
 	{
 		assert(x.size() == y.size());
 
-		PyObject* xlist = PyList_New(x.size());
-		PyObject* ylist = PyList_New(y.size());
-		PyObject* pystring = PyString_FromString(s.c_str());
+		PyObject* xarray = get_array(x);
+		PyObject* yarray = get_array(y);
 
-		for(size_t i = 0; i < x.size(); ++i) {
-			PyList_SetItem(xlist, i, PyFloat_FromDouble(x.at(i)));
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
+		PyObject* pystring = PyString_FromString(s.c_str());
 
 		PyObject* plot_args = PyTuple_New(3);
-		PyTuple_SetItem(plot_args, 0, xlist);
-		PyTuple_SetItem(plot_args, 1, ylist);
+		PyTuple_SetItem(plot_args, 0, xarray);
+		PyTuple_SetItem(plot_args, 1, yarray);
 		PyTuple_SetItem(plot_args, 2, pystring);
 
 		PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_plot, plot_args);
@@ -339,26 +357,19 @@ namespace matplotlibcpp {
 	{
 		assert(x.size() == y.size());
 
-		PyObject *kwargs = PyDict_New();
-		PyObject *xlist = PyList_New(x.size());
-		PyObject *ylist = PyList_New(y.size());
-		PyObject *yerrlist = PyList_New(yerr.size());
+		PyObject* xarray = get_array(x);
+		PyObject* yarray = get_array(y);
+		PyObject* yerrarray = get_array(yerr);
 
-		for (size_t i = 0; i < yerr.size(); ++i)
-			PyList_SetItem(yerrlist, i, PyFloat_FromDouble(yerr.at(i)));
+		PyObject *kwargs = PyDict_New();
 
-		PyDict_SetItemString(kwargs, "yerr", yerrlist);
+		PyDict_SetItemString(kwargs, "yerr", yerrarray);
 
 		PyObject *pystring = PyString_FromString(s.c_str());
 
-		for (size_t i = 0; i < x.size(); ++i) {
-			PyList_SetItem(xlist, i, PyFloat_FromDouble(x.at(i)));
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
-
 		PyObject *plot_args = PyTuple_New(2);
-		PyTuple_SetItem(plot_args, 0, xlist);
-		PyTuple_SetItem(plot_args, 1, ylist);
+		PyTuple_SetItem(plot_args, 0, xarray);
+		PyTuple_SetItem(plot_args, 1, yarray);
 
 		PyObject *res = PyObject_Call(detail::_interpreter::get().s_python_function_errorbar, plot_args, kwargs);
 
@@ -379,16 +390,13 @@ namespace matplotlibcpp {
 		PyObject* kwargs = PyDict_New();
 		PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str()));
 
-		PyObject* ylist = PyList_New(y.size());
-		PyObject* pystring = PyString_FromString(format.c_str());
+		PyObject* yarray = get_array(y);
 
-		for(size_t i = 0; i < y.size(); ++i) {
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
+		PyObject* pystring = PyString_FromString(format.c_str());
 
 		PyObject* plot_args = PyTuple_New(2);
 
-		PyTuple_SetItem(plot_args, 0, ylist);
+		PyTuple_SetItem(plot_args, 0, yarray);
 		PyTuple_SetItem(plot_args, 1, pystring);
 
 		PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_plot, plot_args, kwargs);
@@ -406,18 +414,14 @@ namespace matplotlibcpp {
 		PyObject* kwargs = PyDict_New();
 		PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str()));
 
-		PyObject* xlist = PyList_New(x.size());
-		PyObject* ylist = PyList_New(y.size());
-		PyObject* pystring = PyString_FromString(format.c_str());
+		PyObject* xarray = get_array(x);
+		PyObject* yarray = get_array(y);
 
-		for(size_t i = 0; i < x.size(); ++i) {
-			PyList_SetItem(xlist, i, PyFloat_FromDouble(x.at(i)));
-			PyList_SetItem(ylist, i, PyFloat_FromDouble(y.at(i)));
-		}
+		PyObject* pystring = PyString_FromString(format.c_str());
 
 		PyObject* plot_args = PyTuple_New(3);
-		PyTuple_SetItem(plot_args, 0, xlist);
-		PyTuple_SetItem(plot_args, 1, ylist);
+		PyTuple_SetItem(plot_args, 0, xarray);
+		PyTuple_SetItem(plot_args, 1, yarray);
 		PyTuple_SetItem(plot_args, 2, pystring);
 
 		PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_plot, plot_args, kwargs);