depthpro-wrapper / scripts /image_to_pointcloud.py
bdck's picture
Upload scripts/image_to_pointcloud.py
808837f verified
#!/usr/bin/env python3
"""
CLI script: image β†’ metric depth β†’ 3D point cloud using DepthPro.
Usage
-----
python image_to_pointcloud.py photo.jpg output.ply --device cuda:0 --sample-step 2
python image_to_pointcloud.py photo.jpg output.ply --colored --save-depth depth.png
"""
import argparse
import sys
from pathlib import Path
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent))
from depthpro_wrapper import (
DepthProEstimator,
depth_to_point_cloud,
rgbd_to_point_cloud,
normals_from_depth,
load_image,
save_point_cloud,
)
def _subsample_normals(normals, depth, sample_step):
    """Select the normals belonging to the sub-sampled, positive-depth pixels.

    Takes every ``sample_step``-th row and column of ``normals`` and keeps
    only the entries whose depth is strictly positive.

    NOTE(review): this assumes the back-projection helpers sample the same
    grid and use the same ``depth > 0`` validity rule; their implementation
    is not visible from this script -- confirm against ``depthpro_wrapper``.
    """
    height, width = depth.shape
    rows = np.arange(0, height, sample_step)[:, None]
    cols = np.arange(0, width, sample_step)[None, :]
    keep = depth[rows, cols] > 0
    return normals[rows, cols][keep]


def _save_npy(path, array):
    """Save ``array`` with ``np.save`` and return the path actually written.

    ``np.save`` appends ``.npy`` when the target name does not already end
    with it, so the file on disk may differ from the path the user typed.
    """
    np.save(path, array)
    name = str(path)
    return path if name.endswith(".npy") else Path(name + ".npy")


def main() -> None:
    """Run the CLI: estimate metric depth for one image and write a point cloud.

    Parses the command line, runs ``DepthProEstimator`` on the input image,
    optionally saves the depth / confidence maps as ``.npy`` files,
    back-projects the depth map to 3D points (optionally with colours and
    normals) and writes the result with ``save_point_cloud``.

    Exits through ``parser.error`` (status 2) when the input image does not
    exist or ``--sample-step`` is not a positive integer, and with status 1
    when the number of normals does not match the number of points.
    """
    parser = argparse.ArgumentParser(
        description="Apple DepthPro: image → metric depth → 3D point cloud"
    )
    parser.add_argument("image", type=Path, help="Input RGB image")
    parser.add_argument("output", type=Path, help="Output point cloud (.ply)")
    parser.add_argument("--device", default="cuda:0", help="PyTorch device")
    parser.add_argument("--dtype", choices=["float16", "float32"], default="float16", help="Inference dtype")
    parser.add_argument("--colored", action="store_true", help="Include per-point RGB colours")
    parser.add_argument("--normals", action="store_true", help="Include per-point normals")
    parser.add_argument("--sample-step", type=int, default=1, help="Spatial sub-sampling (1 = full res, 2 = 1/4 points)")
    parser.add_argument("--save-depth", type=Path, default=None, help="Also save depth map as .npy")
    parser.add_argument("--save-confidence", type=Path, default=None, help="Also save confidence map as .npy")
    args = parser.parse_args()

    if not args.image.exists():
        parser.error(f"Input image not found: {args.image}")
    # A step of 0 makes np.arange fail and a negative step yields an empty
    # grid, so reject both before the (expensive) model load.
    if args.sample_step < 1:
        parser.error(f"--sample-step must be a positive integer, got {args.sample_step}")

    # ---- depth estimation -----------------------------------------------
    print(f"Loading DepthPro on {args.device} (dtype={args.dtype}) ...")
    # Imported lazily so that --help and argument errors do not pay for torch.
    import torch

    torch_dtype = {"float16": torch.float16, "float32": torch.float32}[args.dtype]
    estimator = DepthProEstimator(device=args.device, dtype=torch_dtype)

    print(f"Estimating depth for {args.image} ...")
    result = estimator.estimate(
        args.image,
        return_confidence=args.save_confidence is not None,
    )
    print(f" Image size: {result.width}×{result.height}")
    print(f" Estimated focal length: {result.focal_length:.1f} px")
    print(f" Estimated FOV: {result.field_of_view:.1f}°")
    print(f" Depth range: {result.depth.min():.2f} m – {result.depth.max():.2f} m")

    # ---- optional saves -------------------------------------------------
    if args.save_depth:
        written = _save_npy(args.save_depth, result.depth)
        print(f" Saved depth map → {written}")
    if args.save_confidence:
        if result.confidence is not None:
            written = _save_npy(args.save_confidence, result.confidence)
            print(f" Saved confidence map → {written}")
        else:
            # Previously skipped silently; tell the user nothing was written.
            print(
                " Warning: estimator returned no confidence map; "
                f"{args.save_confidence} was not written",
                file=sys.stderr,
            )

    # ---- back-projection ------------------------------------------------
    print("\nBack-projecting to 3D point cloud ...")
    extras = {}
    if args.colored:
        points, colors = rgbd_to_point_cloud(
            result.depth,
            result.image,
            result.focal_length,
            sample_step=args.sample_step,
        )
        extras["colors"] = colors
        label = "Colored point cloud"
    else:
        points = depth_to_point_cloud(
            result.depth,
            result.focal_length,
            sample_step=args.sample_step,
        )
        label = "Point cloud"

    normals = None
    if args.normals:
        # Normals are computed at full resolution, then sampled on the same
        # grid as the points.
        full_normals = normals_from_depth(result.depth, result.focal_length)
        normals = _subsample_normals(full_normals, result.depth, args.sample_step)
        if len(normals) != len(points):
            sys.exit(
                f"error: {len(normals):,} normals do not match "
                f"{len(points):,} points; refusing to write an inconsistent point cloud"
            )
    extras["normals"] = normals

    print(f" {label}: {len(points):,} points")
    save_point_cloud(args.output, points, **extras)
    print(f"\nDone — saved to {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()