ed_sensor_integration
sam_seg_module.cpp
Go to the documentation of this file.
2 
3 #include <cv_bridge/cv_bridge.h>
4 #include <filesystem>
5 #include <pcl/point_cloud.h>
6 #include <pcl/point_types.h>
7 #include <pcl_conversions/pcl_conversions.h>
8 #include <sam_onnx_ros/segmentation.hpp>
9 #include <sensor_msgs/Image.h>
10 #include <sensor_msgs/PointCloud2.h>
11 #include <yolo_onnx_ros/detection.hpp>
12 
13 
15 {
// Pipeline: YOLO detection on `img` -> every detected box is appended to `res.boxes`
// -> SegmentAnything() (SAM encoder/decoder) -> the masks SAM wrote into `res.masks`
// are returned. Presumably the boxes act as SAM box prompts — TODO confirm in sam_onnx_ros.
//
// Config keys read via tue::config (REQUIRED by default): "yolo_model", "sam_encoder",
// "sam_decoder". The bool returned by config.value() is ignored here, so a missing key
// presumably leaves the corresponding path string empty.
// NOTE(review): confirm how Initialize() behaves when given an empty model path.
//
// NOTE(review): both the YOLO and the SAM models are (re)initialized through Initialize()
// on every call. If this runs per frame, consider creating the sessions once and reusing them.
17  std::unique_ptr<YOLO_V8> yoloDetector;
18  DL_INIT_PARAM params;
19  std::string yolo_model, sam_encoder, sam_decoder;
20  config.value("yolo_model", yolo_model);
21  std::tie(yoloDetector, params) = Initialize(yolo_model);
22 
// SAM setup: Initialize() fills the segmentor sessions, their parameters, and the
// `res` / `resSam` result holders that SegmentAnything() consumes further below.
24  std::vector<std::unique_ptr<SAM>> samSegmentors;
25  SEG::DL_INIT_PARAM params_encoder;
26  SEG::DL_INIT_PARAM params_decoder;
28  SEG::DL_RESULT res;
29  config.value("sam_encoder", sam_encoder);
30  config.value("sam_decoder", sam_decoder);
31  std::tie(samSegmentors, params_encoder, params_decoder, res, resSam) = Initialize(sam_encoder, sam_decoder);
32 
// Run object detection on the input image.
34  std::vector<DL_RESULT> resYolo;
35  resYolo = Detector(yoloDetector, img);
36 
// Forward every detection's bounding box to SAM. No confidence or class filtering is
// applied here: the table-skipping branch below is commented out.
38  for (const auto& result : resYolo)
39  {
40  // (here we are skipping the table object but it should happen only on the rosservice scenario: on_top_of dinner_table )
41  // int table_classification = 60;
42  // if (result.classId == table_classification)
43  // {
44  // ROS_DEBUG_STREAM("Class ID is: " << yoloDetector->classes[result.classId] << " So we don't append");
45  // continue;
46  // }
47  res.boxes.push_back(result.box);
48  ROS_DEBUG_STREAM("Confidence is OKOK: " << result.confidence);
49  ROS_DEBUG_STREAM("Class is: " << yoloDetector->classes[result.classId]);
50  ROS_DEBUG_STREAM("Class ID is: " << result.classId);
51  }
52 
// Segment the image using the boxes collected in `res`; the masks are written back into `res`.
53  SegmentAnything(samSegmentors, params_encoder, params_decoder, img, resSam, res);
54 
// `res.masks` is a member of a local, not a local itself. Implicit move/NRVO therefore
// does not apply, and the explicit std::move avoids copying the mask vector.
55  return std::move(res.masks);
56 }
57 
58 
59 void overlayMasksOnImage_(cv::Mat& rgb, const std::vector<cv::Mat>& masks)
60 {
61  // Define colors in BGR format for OpenCV (high contrast)
62  std::vector<cv::Scalar> colors = {
63  cv::Scalar(0, 0, 255), // Red
64  cv::Scalar(0, 255, 0), // Green
65  cv::Scalar(255, 0, 0), // Blue
66  cv::Scalar(0, 255, 255), // Yellow
67  cv::Scalar(255, 0, 255), // Magenta
68  cv::Scalar(255, 255, 0), // Cyan
69  cv::Scalar(128, 0, 128), // Purple
70  cv::Scalar(0, 128, 128) // Brown
71  };
72 
73  // Create a copy for the overlay (preserves original for contours)
74  cv::Mat overlay = rgb.clone();
75 
76  for (size_t i = 0; i < masks.size(); i++)
77  {
78  // Get a working copy of the mask
79  cv::Mat working_mask = masks[i].clone();
80 
81  // Check if mask needs resizing
82  if (working_mask.rows != rgb.rows || working_mask.cols != rgb.cols)
83  cv::resize(working_mask, working_mask, rgb.size(), 0, 0, cv::INTER_NEAREST);
84 
85  // Ensure the mask is binary (values 0 or 255)
86  if (cv::countNonZero((working_mask > 0) & (working_mask < 255)) > 0)
87  cv::threshold(working_mask, working_mask, 127, 255, cv::THRESH_BINARY);
88 
89  // Use a different color for each mask
90  cv::Scalar color = colors[i % colors.size()];
91 
92  // Create the colored overlay with this mask's specific color
93  cv::Mat colorMask = cv::Mat::zeros(rgb.size(), CV_8UC3);
94  colorMask.setTo(color, working_mask);
95 
96  // Add this mask's overlay to the combined overlay
97  cv::addWeighted(overlay, 1.0, colorMask, 0.2, 0, overlay);
98 
99  // Find contours of the mask - use CHAIN_APPROX_NONE for most accurate contours
101  cv::findContours(working_mask, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_NONE);
102 
103  // Draw double contours for better visibility (outer black, inner colored)
104  cv::drawContours(rgb, contours, -1, cv::Scalar(0, 0, 0), 2); // Outer black border
105  cv::drawContours(rgb, contours, -1, color, 1); // Inner colored line
106  }
107 
108  // Apply the semi-transparent overlay with all masks
109  cv::addWeighted(rgb, 0.7, overlay, 0.3, 0, rgb);
110 }
111 
112 void publishSegmentationResults(const cv::Mat& filtered_depth_image, const cv::Mat& rgb,
113  const geo::Pose3D& sensor_pose, std::vector<cv::Mat>& clustered_images,
114  ros::Publisher& mask_pub_, ros::Publisher& cloud_pub_, std::vector<EntityUpdate>& res_updates)
115 {
116  // Overlay masks on the RGB image
117  cv::Mat visualization = rgb.clone();
118 
119  // Create a path to save the image using platform-independent temp directory
121  cv::imwrite((temp_dir / "visualization.png").string(), visualization);
122 
123  // Create a properly normalized depth visualization
124  cv::Mat depth_vis;
125  double min_val, max_val;
126  cv::minMaxLoc(filtered_depth_image, &min_val, &max_val);
127 
128  // Handle empty depth image case
129  if (max_val == 0)
130  {
131  depth_vis = cv::Mat::zeros(filtered_depth_image.size(), CV_8UC1);
132  }
133  else
134  {
135  // Scale to full 8-bit range and convert to 8-bit
136  filtered_depth_image.convertTo(depth_vis, CV_8UC1, 255.0 / max_val);
137 
138  // Apply a colormap for better visibility
139  cv::Mat depth_color;
140  cv::applyColorMap(depth_vis, depth_color, cv::COLORMAP_JET);
141  cv::imwrite((temp_dir / "visualization_depth_color.png").string(), depth_color);
142  }
143 
144  // Save both grayscale and color versions
145  cv::imwrite((temp_dir / "visualization_depth.png").string(), depth_vis);
146  overlayMasksOnImage_(visualization, clustered_images);
147  // save after overlaying masks
148  cv::imwrite((temp_dir / "visualization_with_masks.png").string(), visualization);
149 
150  // Convert to ROS message
151  sensor_msgs::ImagePtr msg = cv_bridge::CvImage(std_msgs::Header(), "bgr8", visualization).toImageMsg();
152  msg->header.stamp = ros::Time::now();
153 
154  typedef pcl::PointCloud<pcl::PointXYZRGB> PointCloud;
155  PointCloud::Ptr combined_cloud (new PointCloud);
156 
157  combined_cloud->header.frame_id = "map";
158 
159  // Add points from all entity updates
160  for (const EntityUpdate& update : res_updates)
161  {
162  for (const geo::Vec3& point : update.points)
163  {
164  // Transform from camera to map frame
165  geo::Vec3 p_map = sensor_pose * point;
166  pcl::PointXYZRGB pcl_point;
167  pcl_point.x = p_map.x;
168  pcl_point.y = p_map.y;
169  pcl_point.z = p_map.z;
170  pcl_point.r = 255; // White
171  pcl_point.g = 255;
172  pcl_point.b = 255;
173  combined_cloud->push_back(pcl_point);
174  }
175 
176  // Add outlier points (red)
177  for (const geo::Vec3& point : update.outlier_points) {
178  geo::Vec3 p_map = sensor_pose * point;
179  pcl::PointXYZRGB pcl_point;
180  pcl_point.x = p_map.x;
181  pcl_point.y = p_map.y;
182  pcl_point.z = p_map.z;
183  pcl_point.r = 255; // Red
184  pcl_point.g = 0;
185  pcl_point.b = 0;
186  combined_cloud->push_back(pcl_point);
187  }
188  }
189 
190  sensor_msgs::PointCloud2 cloud_msg;
191  pcl::toROSMsg(*combined_cloud, cloud_msg);
192  cloud_msg.header.stamp = ros::Time::now();
193  cloud_msg.header.frame_id = "map"; // Use appropriate frame ID
194 
195  // Publish
196  mask_pub_.publish(msg);
197  cloud_pub_.publish(cloud_msg);
198 }
std::string
EntityUpdate
collection structure for laser entities
Definition: kinect/entity_update.h:9
sam_seg_module.h
std::filesystem::temp_directory_path
T temp_directory_path(T... args)
std::vector
std::vector::size
T size(T... args)
geo::Vec3T
update
update
geo::Transform3T
filesystem
std::filesystem::path
publishSegmentationResults
void publishSegmentationResults(const cv::Mat &filtered_depth_image, const cv::Mat &rgb, const geo::Pose3D &sensor_pose, std::vector< cv::Mat > &clustered_images, ros::Publisher &mask_pub_, ros::Publisher &cloud_pub_, std::vector< EntityUpdate > &res_updates)
Publish segmentation results and pointcloud estimation as ROS messages.
Definition: sam_seg_module.cpp:112
std::tie
T tie(T... args)
geo::Vec3T::y
T y
tue::config::ReaderWriter
overlayMasksOnImage_
void overlayMasksOnImage_(cv::Mat &rgb, const std::vector< cv::Mat > &masks)
Overlay segmentation masks on the RGB image for visualization purposes.
Definition: sam_seg_module.cpp:59
SegmentationPipeline
std::vector< cv::Mat > SegmentationPipeline(const cv::Mat &img, tue::Configuration &config)
Segmentation pipeline that processes the input image and generates segmentation masks.
Definition: sam_seg_module.cpp:14
tue::config::ReaderWriter::value
bool value(const std::string &name, T &value, RequiredOrOptional opt=REQUIRED)
geo::Vec3T::z
T z
std::unique_ptr
geo::Vec3T::x
T x
config
tue::config::ReaderWriter config