Sfoglia il codice sorgente

Load mplot3d module lazily.

Load the mplot3d module lazily when `plot_surface()` is called,
to avoid introducing accidental dependencies.
Benno Evers 6 anni fa
parent
commit
0e39f8f56b
1 ha cambiato i file con 25 aggiunte e 14 eliminazioni
  1. 25 14
      matplotlibcpp.h

+ 25 - 14
matplotlibcpp.h

@@ -116,12 +116,9 @@ private:
 
         PyObject* matplotlibname = PyString_FromString("matplotlib");
         PyObject* pyplotname = PyString_FromString("matplotlib.pyplot");
-        PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
-        PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
-        PyObject* pylabname  = PyString_FromString("pylab");
         PyObject* cmname  = PyString_FromString("matplotlib.cm");
-        if (!pyplotname || !pylabname || !matplotlibname || !mpl_toolkits ||
-            !axis3d || !cmname) {
+        PyObject* pylabname  = PyString_FromString("pylab");
+        if (!pyplotname || !pylabname || !matplotlibname || !cmname) {
           throw std::runtime_error("couldnt create string");
         }
 
@@ -147,14 +144,6 @@ private:
         Py_DECREF(pylabname);
         if (!pylabmod) { throw std::runtime_error("Error loading module pylab!"); }
 
-        PyObject* mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
-        Py_DECREF(mpl_toolkitsmod);
-        if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
-
-        PyObject* axis3dmod = PyImport_Import(axis3d);
-        Py_DECREF(axis3dmod);
-        if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
-
         s_python_function_show = PyObject_GetAttrString(pymod, "show");
         s_python_function_close = PyObject_GetAttrString(pymod, "close");
         s_python_function_draw = PyObject_GetAttrString(pymod, "draw");
@@ -414,7 +403,29 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
                   const std::vector<::std::vector<Numeric>> &y,
                   const std::vector<::std::vector<Numeric>> &z,
                   const std::map<std::string, std::string> &keywords =
-                      std::map<std::string, std::string>()) {
+                      std::map<std::string, std::string>())
+{
+  // We lazily load the modules here the first time this function is called
+  // because I'm not sure that we can assume "matplotlib installed" implies
+  // "mpl_toolkits installed" on all platforms, and we don't want to require
+  // it for people who don't need 3d plots.
+  static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr;
+  if (!mpl_toolkitsmod) {
+    detail::_interpreter::get();
+
+    PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
+    PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
+    if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); }
+
+    mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
+    Py_DECREF(mpl_toolkits);
+    if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
+
+    axis3dmod = PyImport_Import(axis3d);
+    Py_DECREF(axis3d);
+    if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
+  }
+
   assert(x.size() == y.size());
   assert(y.size() == z.size());