Forráskód Böngészése

Add plot_surface

It probably hard-codes too many defaults, but it works well enough for
my purposes.
Austin Schuh 6 éve
szülő
commit
69e58fe25b
2 módosított fájl, 138 hozzáadás és 2 törlés
  1. 26 0
      README.md
  2. 112 2
      matplotlibcpp.h

+ 26 - 0
README.md

@@ -159,6 +159,32 @@ int main()
 
 ![quiver example](./examples/quiver.png)
 
+When working with 3d functions, you might be interested in 3d plots:
+```cpp
+#include "../matplotlibcpp.h"
+
+namespace plt = matplotlibcpp;
+
+int main()
+{
+    std::vector<std::vector<double>> x, y, z;
+    for (double i = -5; i <= 5;  i += 0.25) {
+        std::vector<double> x_row, y_row, z_row;
+        for (double j = -5; j <= 5; j += 0.25) {
+            x_row.push_back(i);
+            y_row.push_back(j);
+            z_row.push_back(::std::sin(::std::hypot(x, y)));
+        }
+        x.push_back(x_row);
+        y.push_back(y_row);
+        z.push_back(z_row);
+    }
+
+    plt::plot_surface(x, y, z);
+    plt::show();
+}
+```
+
 Installation
 ------------
 

+ 112 - 2
matplotlibcpp.h

@@ -59,6 +59,7 @@ struct _interpreter {
     PyObject *s_python_function_errorbar;
     PyObject *s_python_function_annotate;
     PyObject *s_python_function_tight_layout;
+    PyObject *s_python_colormap;
     PyObject *s_python_empty_tuple;
     PyObject *s_python_function_stem;
     PyObject *s_python_function_xkcd;
@@ -115,9 +116,13 @@ private:
 
         PyObject* matplotlibname = PyString_FromString("matplotlib");
         PyObject* pyplotname = PyString_FromString("matplotlib.pyplot");
+        PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
+        PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
         PyObject* pylabname  = PyString_FromString("pylab");
-        if (!pyplotname || !pylabname || !matplotlibname) {
-            throw std::runtime_error("couldnt create string");
+        PyObject* cmname  = PyString_FromString("matplotlib.cm");
+        if (!pyplotname || !pylabname || !matplotlibname || !mpl_toolkits ||
+            !axis3d || !cmname) {
+          throw std::runtime_error("couldnt create string");
         }
 
         PyObject* matplotlib = PyImport_Import(matplotlibname);
@@ -134,11 +139,22 @@ private:
         Py_DECREF(pyplotname);
         if (!pymod) { throw std::runtime_error("Error loading module matplotlib.pyplot!"); }
 
+        s_python_colormap = PyImport_Import(cmname);
+        Py_DECREF(cmname);
+        if (!s_python_colormap) { throw std::runtime_error("Error loading module matplotlib.cm!"); }
 
         PyObject* pylabmod = PyImport_Import(pylabname);
         Py_DECREF(pylabname);
         if (!pylabmod) { throw std::runtime_error("Error loading module pylab!"); }
 
+        PyObject* mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
+        Py_DECREF(mpl_toolkitsmod);
+        if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
+
+        PyObject* axis3dmod = PyImport_Import(axis3d);
+        Py_DECREF(axis3dmod);
+        if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
+
         s_python_function_show = PyObject_GetAttrString(pymod, "show");
         s_python_function_close = PyObject_GetAttrString(pymod, "close");
         s_python_function_draw = PyObject_GetAttrString(pymod, "draw");
@@ -325,6 +341,30 @@ PyObject* get_array(const std::vector<Numeric>& v)
     return varray;
 }
 
+template<typename Numeric>
+PyObject* get_2darray(const std::vector<::std::vector<Numeric>>& v)
+{
+    detail::_interpreter::get();    //interpreter needs to be initialized for the numpy commands to work
+    if (v.size() < 1) throw std::runtime_error("get_2d_array v too small");
+
+    npy_intp vsize[2] = {static_cast<npy_intp>(v.size()),
+                         static_cast<npy_intp>(v[0].size())};
+
+    PyArrayObject *varray =
+        (PyArrayObject *)PyArray_SimpleNew(2, vsize, NPY_DOUBLE);
+
+    double *vd_begin = static_cast<double *>(PyArray_DATA(varray));
+
+    for (const ::std::vector<Numeric> &v_row : v) {
+      if (v_row.size() != static_cast<size_t>(vsize[1]))
+        throw std::runtime_error("Missmatched array size");
+      std::copy(v_row.begin(), v_row.end(), vd_begin);
+      vd_begin += vsize[1];
+    }
+
+    return reinterpret_cast<PyObject *>(varray);
+}
+
 #else // fallback if we don't have numpy: copy every element of the given vector
 
 template<typename Numeric>
@@ -369,6 +409,76 @@ bool plot(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const st
     return res;
 }
 
+template <typename Numeric>
+void plot_surface(const std::vector<::std::vector<Numeric>> &x,
+                  const std::vector<::std::vector<Numeric>> &y,
+                  const std::vector<::std::vector<Numeric>> &z,
+                  const std::map<std::string, std::string> &keywords =
+                      std::map<std::string, std::string>()) {
+  assert(x.size() == y.size());
+  assert(y.size() == z.size());
+
+  // using numpy arrays
+  PyObject *xarray = get_2darray(x);
+  PyObject *yarray = get_2darray(y);
+  PyObject *zarray = get_2darray(z);
+
+  // construct positional args
+  PyObject *args = PyTuple_New(3);
+  PyTuple_SetItem(args, 0, xarray);
+  PyTuple_SetItem(args, 1, yarray);
+  PyTuple_SetItem(args, 2, zarray);
+
+  // Build up the kw args.
+  PyObject *kwargs = PyDict_New();
+  PyDict_SetItemString(kwargs, "rstride", PyInt_FromLong(1));
+  PyDict_SetItemString(kwargs, "cstride", PyInt_FromLong(1));
+
+  PyObject *python_colormap_coolwarm = PyObject_GetAttrString(
+      detail::_interpreter::get().s_python_colormap, "coolwarm");
+
+  PyDict_SetItemString(kwargs, "cmap", python_colormap_coolwarm);
+
+  for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
+       it != keywords.end(); ++it) {
+    PyDict_SetItemString(kwargs, it->first.c_str(),
+                         PyString_FromString(it->second.c_str()));
+  }
+
+
+  PyObject *fig =
+      PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+                          detail::_interpreter::get().s_python_empty_tuple);
+  if (!fig) throw std::runtime_error("Call to figure() failed.");
+
+  PyObject *gca_kwargs = PyDict_New();
+  PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d"));
+
+  PyObject *gca = PyObject_GetAttrString(fig, "gca");
+  if (!gca) throw std::runtime_error("No gca");
+  Py_INCREF(gca);
+  PyObject *axis = PyObject_Call(
+      gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs);
+
+  if (!axis) throw std::runtime_error("No axis");
+  Py_INCREF(axis);
+
+  Py_DECREF(gca);
+  Py_DECREF(gca_kwargs);
+
+  PyObject *plot_surface = PyObject_GetAttrString(axis, "plot_surface");
+  if (!plot_surface) throw std::runtime_error("No surface");
+  Py_INCREF(plot_surface);
+  PyObject *res = PyObject_Call(plot_surface, args, kwargs);
+  if (!res) throw std::runtime_error("failed surface");
+  Py_DECREF(plot_surface);
+
+  Py_DECREF(axis);
+  Py_DECREF(args);
+  Py_DECREF(kwargs);
+  if (res) Py_DECREF(res);
+}
+
 template<typename Numeric>
 bool stem(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const std::map<std::string, std::string>& keywords)
 {