Jelajahi Sumber

Add 3D scatter plots, allow more than one 3d plot on the same figure and make rcparams changeable.

Ruan Luies 4 tahun lalu
induk
melakukan
9d19657a36
1 mengubah file dengan 167 tambahan dan 18 penghapusan
  1. 167 18
      matplotlibcpp.h

+ 167 - 18
matplotlibcpp.h

@@ -99,6 +99,7 @@ struct _interpreter {
     PyObject *s_python_function_barh;
     PyObject *s_python_function_colorbar;
     PyObject *s_python_function_subplots_adjust;
+    PyObject *s_python_function_rcparams;
 
 
     /* For now, _interpreter is implemented as a singleton since its currently not possible to have
@@ -189,6 +190,7 @@ private:
         }
 
         PyObject* matplotlib = PyImport_Import(matplotlibname);
+
         Py_DECREF(matplotlibname);
         if (!matplotlib) {
             PyErr_Print();
@@ -201,6 +203,8 @@ private:
             PyObject_CallMethod(matplotlib, const_cast<char*>("use"), const_cast<char*>("s"), s_backend.c_str());
         }
 
+
+
         PyObject* pymod = PyImport_Import(pyplotname);
         Py_DECREF(pyplotname);
         if (!pymod) { throw std::runtime_error("Error loading module matplotlib.pyplot!"); }
@@ -264,6 +268,7 @@ private:
         s_python_function_barh = safe_import(pymod, "barh");
         s_python_function_colorbar = PyObject_GetAttrString(pymod, "colorbar");
         s_python_function_subplots_adjust = safe_import(pymod,"subplots_adjust");
+        s_python_function_rcparams = PyObject_GetAttrString(pymod, "rcParams");
 #ifndef WITHOUT_NUMPY
         s_python_function_imshow = safe_import(pymod, "imshow");
 #endif
@@ -464,6 +469,7 @@ template <typename Numeric>
 void plot_surface(const std::vector<::std::vector<Numeric>> &x,
                   const std::vector<::std::vector<Numeric>> &y,
                   const std::vector<::std::vector<Numeric>> &z,
+                  const long fig_number=0,
                   const std::map<std::string, std::string> &keywords =
                       std::map<std::string, std::string>())
 {
@@ -516,14 +522,29 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
 
   for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
        it != keywords.end(); ++it) {
-    PyDict_SetItemString(kwargs, it->first.c_str(),
-                         PyString_FromString(it->second.c_str()));
+    if (it->first == "linewidth" || it->first == "alpha") {
+      PyDict_SetItemString(kwargs, it->first.c_str(),
+        PyFloat_FromDouble(std::stod(it->second)));
+    } else {
+      PyDict_SetItemString(kwargs, it->first.c_str(),
+        PyString_FromString(it->second.c_str()));
+    }
   }
 
-
-  PyObject *fig =
-      PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
-                          detail::_interpreter::get().s_python_empty_tuple);
+  PyObject *fig_args = PyTuple_New(1);
+  PyObject* fig = nullptr;
+  PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
+  PyObject *fig_exists =
+    PyObject_CallObject(
+    detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
+  if (!PyObject_IsTrue(fig_exists)) {
+    fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+      detail::_interpreter::get().s_python_empty_tuple);
+  } else {
+    fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+      fig_args);
+  }
+  Py_DECREF(fig_exists);
   if (!fig) throw std::runtime_error("Call to figure() failed.");
 
   PyObject *gca_kwargs = PyDict_New();
@@ -559,6 +580,7 @@ template <typename Numeric>
 void plot3(const std::vector<Numeric> &x,
                   const std::vector<Numeric> &y,
                   const std::vector<Numeric> &z,
+                  const long fig_number=0,
                   const std::map<std::string, std::string> &keywords =
                       std::map<std::string, std::string>())
 {
@@ -607,9 +629,18 @@ void plot3(const std::vector<Numeric> &x,
                          PyString_FromString(it->second.c_str()));
   }
 
-  PyObject *fig =
-      PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
-                          detail::_interpreter::get().s_python_empty_tuple);
+  PyObject *fig_args = PyTuple_New(1);
+  PyObject* fig = nullptr;
+  PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
+  PyObject *fig_exists =
+    PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
+  if (!PyObject_IsTrue(fig_exists)) {
+    fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+      detail::_interpreter::get().s_python_empty_tuple);
+  } else {
+    fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+      fig_args);
+  }
   if (!fig) throw std::runtime_error("Call to figure() failed.");
 
   PyObject *gca_kwargs = PyDict_New();
@@ -911,6 +942,103 @@ bool scatter(const std::vector<NumericX>& x,
     return res;
 }
 
+template<typename NumericX, typename NumericY, typename NumericZ>
+bool scatter(const std::vector<NumericX>& x,
+             const std::vector<NumericY>& y,
+             const std::vector<NumericZ>& z,
+             const double s=1.0, // The marker size in points**2
+             const long fig_number=0,
+             const std::map<std::string, std::string> & keywords = {}) {
+  detail::_interpreter::get();
+
+  // Same as with plot_surface: We lazily load the modules here the first time 
+  // this function is called because I'm not sure that we can assume "matplotlib 
+  // installed" implies "mpl_toolkits installed" on all platforms, and we don't 
+  // want to require it for people who don't need 3d plots.
+  static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr;
+  if (!mpl_toolkitsmod) {
+    detail::_interpreter::get();
+
+    PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
+    PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
+    if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); }
+
+    mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
+    Py_DECREF(mpl_toolkits);
+    if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
+
+    axis3dmod = PyImport_Import(axis3d);
+    Py_DECREF(axis3d);
+    if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
+  }
+
+  assert(x.size() == y.size());
+  assert(y.size() == z.size());
+
+  PyObject *xarray = detail::get_array(x);
+  PyObject *yarray = detail::get_array(y);
+  PyObject *zarray = detail::get_array(z);
+
+  // construct positional args
+  PyObject *args = PyTuple_New(3);
+  PyTuple_SetItem(args, 0, xarray);
+  PyTuple_SetItem(args, 1, yarray);
+  PyTuple_SetItem(args, 2, zarray);
+
+  // Build up the kw args.
+  PyObject *kwargs = PyDict_New();
+
+  for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
+       it != keywords.end(); ++it) {
+    PyDict_SetItemString(kwargs, it->first.c_str(),
+                         PyString_FromString(it->second.c_str()));
+  }
+  PyObject *fig_args = PyTuple_New(1);
+  PyObject* fig = nullptr;
+  PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
+  PyObject *fig_exists =
+    PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
+  if (!PyObject_IsTrue(fig_exists)) {
+    fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+      detail::_interpreter::get().s_python_empty_tuple);
+  } else {
+    fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+      fig_args);
+  }
+  Py_DECREF(fig_exists);
+  if (!fig) throw std::runtime_error("Call to figure() failed.");
+
+  PyObject *gca_kwargs = PyDict_New();
+  PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d"));
+
+  PyObject *gca = PyObject_GetAttrString(fig, "gca");
+  if (!gca) throw std::runtime_error("No gca");
+  Py_INCREF(gca);
+  PyObject *axis = PyObject_Call(
+      gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs);
+
+  if (!axis) throw std::runtime_error("No axis");
+  Py_INCREF(axis);
+
+  Py_DECREF(gca);
+  Py_DECREF(gca_kwargs);
+
+  PyObject *plot3 = PyObject_GetAttrString(axis, "scatter");
+  if (!plot3) throw std::runtime_error("No 3D line plot");
+  Py_INCREF(plot3);
+  PyObject *res = PyObject_Call(plot3, args, kwargs);
+  if (!res) throw std::runtime_error("Failed 3D line plot");
+  Py_DECREF(plot3);
+
+  Py_DECREF(axis);
+  Py_DECREF(args);
+  Py_DECREF(kwargs);
+  Py_DECREF(fig);
+  if (res) Py_DECREF(res);
+  return res;
+
+}
+
 template<typename Numeric>
 bool boxplot(const std::vector<std::vector<Numeric>>& data,
              const std::vector<std::string>& labels = {},
@@ -1139,9 +1267,9 @@ bool contour(const std::vector<NumericX>& x, const std::vector<NumericY>& y,
              const std::map<std::string, std::string>& keywords = {}) {
     assert(x.size() == y.size() && x.size() == z.size());
 
-    PyObject* xarray = get_array(x);
-    PyObject* yarray = get_array(y);
-    PyObject* zarray = get_array(z);
+    PyObject* xarray = detail::get_array(x);
+    PyObject* yarray = detail::get_array(y);
+    PyObject* zarray = detail::get_array(z);
 
     PyObject* plot_args = PyTuple_New(3);
     PyTuple_SetItem(plot_args, 0, xarray);
@@ -2094,12 +2222,14 @@ inline void axvspan(double xmin, double xmax, double ymin = 0., double ymax = 1.
 
     // construct keyword args
     PyObject* kwargs = PyDict_New();
-    for(std::map<std::string, std::string>::const_iterator it = keywords.begin(); it != keywords.end(); ++it)
-    {
-    if (it->first == "linewidth" || it->first == "alpha")
-            PyDict_SetItemString(kwargs, it->first.c_str(), PyFloat_FromDouble(std::stod(it->second)));
-    else
-            PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str()));
+    for (auto it = keywords.begin(); it != keywords.end(); ++it) {
+      if (it->first == "linewidth" || it->first == "alpha") {
+        PyDict_SetItemString(kwargs, it->first.c_str(),
+          PyFloat_FromDouble(std::stod(it->second)));
+      } else {
+        PyDict_SetItemString(kwargs, it->first.c_str(),
+          PyString_FromString(it->second.c_str()));
+      }
     }
 
     PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_axvspan, args, kwargs);
@@ -2319,6 +2449,25 @@ inline void save(const std::string& filename)
     Py_DECREF(res);
 }
 
+inline void rcparams(const std::map<std::string, std::string>& keywords = {}) {
+    detail::_interpreter::get();
+    PyObject* args = PyTuple_New(0);
+    PyObject* kwargs = PyDict_New();
+    for (auto it = keywords.begin(); it != keywords.end(); ++it) {
+        if ("text.usetex" == it->first)
+          PyDict_SetItemString(kwargs, it->first.c_str(), PyLong_FromLong(std::stoi(it->second.c_str())));
+        else PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str()));
+    }
+    
+    PyObject * update = PyObject_GetAttrString(detail::_interpreter::get().s_python_function_rcparams, "update");
+    PyObject * res = PyObject_Call(update, args, kwargs);
+    if(!res) throw std::runtime_error("Call to rcParams.update() failed.");
+    Py_DECREF(args);
+    Py_DECREF(kwargs);
+    Py_DECREF(update);
+    Py_DECREF(res);
+}
+
 inline void clf() {
     detail::_interpreter::get();