Преглед на файлове

Add 3D line plot and zlabel function.

Brian Phung преди 5 години
родител
ревизия
811ebfb2c9
променени са 4 файла, в които са добавени 166 реда и са изтрити 1 реда
  1. 1 1
      Makefile
  2. 30 0
      examples/lines3d.cpp
  3. BIN
      examples/lines3d.png
  4. 135 0
      matplotlibcpp.h

+ 1 - 1
Makefile

@@ -15,7 +15,7 @@ WITHOUT_NUMPY   := $(findstring $(CXXFLAGS), WITHOUT_NUMPY)
 # Examples requiring numpy support to compile
 EXAMPLES_NUMPY  := surface
 EXAMPLES        := minimal basic modern animation nonblock xkcd quiver bar \
-	           fill_inbetween fill update subplot2grid colorbar \
+	           fill_inbetween fill update subplot2grid colorbar lines3d \
                    $(if WITHOUT_NUMPY,,$(EXAMPLES_NUMPY))
 
 # Prefix every example with 'examples/build/'

+ 30 - 0
examples/lines3d.cpp

@@ -0,0 +1,30 @@
+#include "../matplotlibcpp.h"
+
+#include <cmath>
+
+namespace plt = matplotlibcpp;
+
+int main()
+{
+    std::vector<double> x, y, z;
+    double theta, r;
+    double z_inc = 4.0/99.0; double theta_inc = (8.0 * M_PI)/99.0;
+    
+    for (double i = 0; i < 100; i += 1) {
+        theta = -4.0 * M_PI + theta_inc*i;
+        z.push_back(-2.0 + z_inc*i);
+        r = z[i]*z[i] + 1;
+        x.push_back(r * sin(theta));
+        y.push_back(r * cos(theta));
+    }
+
+    std::map<std::string, std::string> keywords;
+    keywords.insert(std::pair<std::string, std::string>("label", "parametric curve") );
+
+    plt::plot3(x, y, z, keywords);
+    plt::xlabel("x label");
+    plt::ylabel("y label");
+    plt::set_zlabel("z label"); // set_zlabel rather than just zlabel, in accordance with the Axes3D method
+    plt::legend();
+    plt::show();
+}

BIN
examples/lines3d.png


+ 135 - 0
matplotlibcpp.h

@@ -74,6 +74,7 @@ struct _interpreter {
     PyObject *s_python_function_axvline;
     PyObject *s_python_function_xlabel;
     PyObject *s_python_function_ylabel;
+    PyObject *s_python_function_gca;
     PyObject *s_python_function_xticks;
     PyObject *s_python_function_yticks;
     PyObject *s_python_function_tick_params;
@@ -208,6 +209,7 @@ private:
         s_python_function_axvline = safe_import(pymod, "axvline");
         s_python_function_xlabel = safe_import(pymod, "xlabel");
         s_python_function_ylabel = safe_import(pymod, "ylabel");
+        s_python_function_gca = safe_import(pymod, "gca");
         s_python_function_xticks = safe_import(pymod, "xticks");
         s_python_function_yticks = safe_import(pymod, "yticks");
         s_python_function_tick_params = safe_import(pymod, "tick_params");
@@ -489,6 +491,88 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
 }
 #endif // WITHOUT_NUMPY
 
+template <typename Numeric>
+void plot3(const std::vector<Numeric> &x,
+                  const std::vector<Numeric> &y,
+                  const std::vector<Numeric> &z,
+                  const std::map<std::string, std::string> &keywords =
+                      std::map<std::string, std::string>())
+{
+  // Same as with plot_surface: We lazily load the modules here the first time 
+  // this function is called because I'm not sure that we can assume "matplotlib 
+  // installed" implies "mpl_toolkits installed" on all platforms, and we don't 
+  // want to require it for people who don't need 3d plots.
+  static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr;
+  if (!mpl_toolkitsmod) {
+    detail::_interpreter::get();
+
+    PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
+    PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
+    if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); }
+
+    mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
+    Py_DECREF(mpl_toolkits);
+    if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
+
+    axis3dmod = PyImport_Import(axis3d);
+    Py_DECREF(axis3d);
+    if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
+  }
+
+  assert(x.size() == y.size());
+  assert(y.size() == z.size());
+
+  PyObject *xarray = get_array(x);
+  PyObject *yarray = get_array(y);
+  PyObject *zarray = get_array(z);
+
+  // construct positional args
+  PyObject *args = PyTuple_New(3);
+  PyTuple_SetItem(args, 0, xarray);
+  PyTuple_SetItem(args, 1, yarray);
+  PyTuple_SetItem(args, 2, zarray);
+
+  // Build up the kw args.
+  PyObject *kwargs = PyDict_New();
+
+  for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
+       it != keywords.end(); ++it) {
+    PyDict_SetItemString(kwargs, it->first.c_str(),
+                         PyString_FromString(it->second.c_str()));
+  }
+
+  PyObject *fig =
+      PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
+                          detail::_interpreter::get().s_python_empty_tuple);
+  if (!fig) throw std::runtime_error("Call to figure() failed.");
+
+  PyObject *gca_kwargs = PyDict_New();
+  PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d"));
+
+  PyObject *gca = PyObject_GetAttrString(fig, "gca");
+  if (!gca) throw std::runtime_error("No gca");
+  Py_INCREF(gca);
+  PyObject *axis = PyObject_Call(
+      gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs);
+
+  if (!axis) throw std::runtime_error("No axis");
+  Py_INCREF(axis);
+
+  Py_DECREF(gca);
+  Py_DECREF(gca_kwargs);
+
+  PyObject *plot3 = PyObject_GetAttrString(axis, "plot");
+  if (!plot3) throw std::runtime_error("No 3D line plot");
+  Py_INCREF(plot3);
+  PyObject *res = PyObject_Call(plot3, args, kwargs);
+  if (!res) throw std::runtime_error("Failed 3D line plot");
+  Py_DECREF(plot3);
+
+  Py_DECREF(axis);
+  Py_DECREF(args);
+  Py_DECREF(kwargs);
+  if (res) Py_DECREF(res);
+}
 
 template<typename Numeric>
 bool stem(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const std::map<std::string, std::string>& keywords)
@@ -1662,6 +1746,57 @@ inline void ylabel(const std::string &str, const std::map<std::string, std::stri
     Py_DECREF(res);
 }
 
+inline void set_zlabel(const std::string &str, const std::map<std::string, std::string>& keywords = {}) {
+    // Same as with plot_surface: We lazily load the modules here the first time 
+    // this function is called because I'm not sure that we can assume "matplotlib 
+    // installed" implies "mpl_toolkits installed" on all platforms, and we don't 
+    // want to require it for people who don't need 3d plots.
+    static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr;
+    if (!mpl_toolkitsmod) {
+        detail::_interpreter::get();
+
+        PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
+        PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
+        if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); }
+
+        mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
+        Py_DECREF(mpl_toolkits);
+        if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
+
+        axis3dmod = PyImport_Import(axis3d);
+        Py_DECREF(axis3d);
+        if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
+    }
+
+    PyObject* pystr = PyString_FromString(str.c_str());
+    PyObject* args = PyTuple_New(1);
+    PyTuple_SetItem(args, 0, pystr);
+
+    PyObject* kwargs = PyDict_New();
+    for (auto it = keywords.begin(); it != keywords.end(); ++it) {
+        PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str()));
+    }
+
+    PyObject *ax =
+    PyObject_CallObject(detail::_interpreter::get().s_python_function_gca,
+      detail::_interpreter::get().s_python_empty_tuple);
+    if (!ax) throw std::runtime_error("Call to gca() failed.");
+    Py_INCREF(ax);
+
+    PyObject *zlabel = PyObject_GetAttrString(ax, "set_zlabel");
+    if (!zlabel) throw std::runtime_error("Attribute set_zlabel not found.");
+    Py_INCREF(zlabel);
+
+    PyObject *res = PyObject_Call(zlabel, args, kwargs);
+    if (!res) throw std::runtime_error("Call to set_zlabel() failed.");
+    Py_DECREF(zlabel);
+
+    Py_DECREF(ax);
+    Py_DECREF(args);
+    Py_DECREF(kwargs);
+    if (res) Py_DECREF(res);
+}
+
 inline void grid(bool flag)
 {
     PyObject* pyflag = flag ? Py_True : Py_False;