Cryoris hace 5 años
padre
commit
9ff7a4b29d
Se han modificado 2 ficheros con 67 adiciones y 7 borrados
  1. 29 0
      examples/spy.cpp
  2. 38 7
      matplotlibcpp.h

+ 29 - 0
examples/spy.cpp

@@ -0,0 +1,29 @@
+#import <iostream>
+#import <vector>
+#import "../matplotlibcpp.h"
+
+namespace plt = matplotlibcpp;
+
+int main()
+{
+    const int n = 20;
+    std::vector<std::vector<double>> matrix;
+
+    for (int i = 0; i < n; ++i) {
+        std::vector<double> row;
+        for (int j = 0; j < n; ++j) {
+            if (i == j)
+                row.push_back(-2);
+            else if (j == i - 1 || j == i + 1)
+                row.push_back(1);
+            else
+                row.push_back(0);
+        }
+        matrix.push_back(row);
+    }
+
+    plt::spy(matrix, 5, {{"marker", "o"}});
+    plt::show();
+
+    return 0;
+}

+ 38 - 7
matplotlibcpp.h

@@ -102,6 +102,7 @@ struct _interpreter {
     PyObject *s_python_function_colorbar;
     PyObject *s_python_function_subplots_adjust;
     PyObject *s_python_function_rcparams;
+    PyObject *s_python_function_spy;
 
     /* For now, _interpreter is implemented as a singleton since its currently not possible to have
        multiple independent embedded python interpreters without patching the python source code
@@ -276,6 +277,7 @@ private:
         s_python_function_colorbar = PyObject_GetAttrString(pymod, "colorbar");
         s_python_function_subplots_adjust = safe_import(pymod,"subplots_adjust");
         s_python_function_rcparams = PyObject_GetAttrString(pymod, "rcParams");
+	s_python_function_spy = PyObject_GetAttrString(pymod, "spy");
 #ifndef WITHOUT_NUMPY
         s_python_function_imshow = safe_import(pymod, "imshow");
 #endif
@@ -348,11 +350,11 @@ template <> struct select_npy_type<uint64_t> { const static NPY_TYPES type = NPY
 
 // Sanity checks; comment them out or change the numpy type below if you're compiling on
 // a platform where they don't apply
-// static_assert(sizeof(long long) == 8);
-// template <> struct select_npy_type<long long> { const static NPY_TYPES type = NPY_INT64; };
-// static_assert(sizeof(unsigned long long) == 8);
-// template <> struct select_npy_type<unsigned long long> { const static NPY_TYPES type = NPY_UINT64; };
-// TODO: add int, long, etc.
+static_assert(sizeof(long long) == 8);
+template <> struct select_npy_type<long long> { const static NPY_TYPES type = NPY_INT64; };
+static_assert(sizeof(unsigned long long) == 8);
+template <> struct select_npy_type<unsigned long long> { const static NPY_TYPES type = NPY_UINT64; };
+TODO: add int, long, etc.
 
 template<typename Numeric>
 PyObject* get_array(const std::vector<Numeric>& v)
@@ -621,8 +623,37 @@ void contour(const std::vector<::std::vector<Numeric>> &x,
 
   Py_DECREF(args);
   Py_DECREF(kwargs);
-  if (res)
-    Py_DECREF(res);
+  if (res) Py_DECREF(res);
+}
+
+template <typename Numeric>
+void spy(const std::vector<::std::vector<Numeric>> &x,
+         const double markersize = -1,  // -1 for default matplotlib size
+         const std::map<std::string, std::string> &keywords = {})
+{
+  detail::_interpreter::get();
+
+  PyObject *xarray = detail::get_2darray(x);
+
+  PyObject *kwargs = PyDict_New();
+  if (markersize != -1) {
+    PyDict_SetItemString(kwargs, "markersize", PyFloat_FromDouble(markersize));
+  }
+  for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
+       it != keywords.end(); ++it) {
+    PyDict_SetItemString(kwargs, it->first.c_str(),
+                         PyString_FromString(it->second.c_str()));
+  }
+
+  PyObject *plot_args = PyTuple_New(1);
+  PyTuple_SetItem(plot_args, 0, xarray);
+
+  PyObject *res = PyObject_Call(
+      detail::_interpreter::get().s_python_function_spy, plot_args, kwargs);
+
+  Py_DECREF(plot_args);
+  Py_DECREF(kwargs);
+  if (res) Py_DECREF(res);
 }
 #endif // WITHOUT_NUMPY