From 28ac02fd58ebefd9947f987cb2567c005386d7c2 Mon Sep 17 00:00:00 2001
From: HenrikAsmuth <henrik.asmuth@geo.uu.se>
Date: Mon, 14 Nov 2022 16:53:31 +0100
Subject: [PATCH] add bindings for multi-GPU setup

---
 .../src/gpu/submodules/grid_generator.cpp       | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/pythonbindings/src/gpu/submodules/grid_generator.cpp b/pythonbindings/src/gpu/submodules/grid_generator.cpp
index a62247aa9..a16b84935 100644
--- a/pythonbindings/src/gpu/submodules/grid_generator.cpp
+++ b/pythonbindings/src/gpu/submodules/grid_generator.cpp
@@ -1,4 +1,5 @@
 #include <pybind11/pybind11.h>
+#include "gpu/GridGenerator/utilities/communication.h"
 #include "gpu/GridGenerator/geometries/Object.h"
 #include "gpu/GridGenerator/geometries/BoundingBox/BoundingBox.h"
 #include "gpu/GridGenerator/geometries/Conglomerate/Conglomerate.h"
@@ -17,10 +18,19 @@ namespace grid_generator
     {  
         py::module gridGeneratorModule = parentModule.def_submodule("grid_generator");
 
+        //TODO:
+        // py::enum_<CommunicationDirections>(gridGeneratorModule, "CommunicationDirections")
+        // .value("MX", CommunicationDirections::MX)
+        // .value("PX", CommunicationDirections::PX)
+        // .value("MY", CommunicationDirections::MY)
+        // .value("PY", CommunicationDirections::PY)
+        // .value("MZ", CommunicationDirections::MZ)
+        // .value("PZ", CommunicationDirections::PZ);
+
         py::class_<GridFactory, std::shared_ptr<GridFactory>>(gridGeneratorModule, "GridFactory")
         .def("make", &GridFactory::make, py::return_value_policy::reference);
 
-        py::class_<BoundingBox>(gridGeneratorModule, "BoundingBox")
+        py::class_<BoundingBox, std::shared_ptr<BoundingBox>>(gridGeneratorModule, "BoundingBox")
         .def(py::init<real, real, real, real, real, real>(),"min_x","max_x","min_y","max_y","min_z","max_z");
 
         py::class_<Object, std::shared_ptr<Object>>(gridGeneratorModule, "Object");
@@ -62,7 +72,10 @@ namespace grid_generator
         .def("add_geometry", py::overload_cast<Object*>(&MultipleGridBuilder::addGeometry))
         .def("add_geometry", py::overload_cast<Object*, uint>(&MultipleGridBuilder::addGeometry))
         .def("get_number_of_levels", &MultipleGridBuilder::getNumberOfLevels)
-        .def("build_grids", &MultipleGridBuilder::buildGrids);
+        .def("build_grids", &MultipleGridBuilder::buildGrids)
+        .def("set_subdomain_box", &MultipleGridBuilder::setSubDomainBox)
+        .def("find_communication_indices", &MultipleGridBuilder::findCommunicationIndices)
+        .def("set_communication_process", &MultipleGridBuilder::setCommunicationProcess);
 
         return gridGeneratorModule;
     }
-- 
GitLab