Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ def _rasterize_to_pixels(
rolling_shutter_time: Tensor, # [C]
backgrounds: Optional[Tensor] = None, # [C, channels]
batch_per_iter: int = 100,
rolling_shutter_direction: int = 1,
):
"""Pytorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels()`.

Expand Down
18 changes: 17 additions & 1 deletion gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ def rasterize_to_pixels(
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
rolling_shutter_time: Optional[Tensor] = None, # [C]
rolling_shutter_direction: Optional[int] = None, # 1: top2bottom, 2: bottom2top, 3: left2right, 4: right2left, 5: no rolling shutter
backgrounds: Optional[Tensor] = None, # [C, channels]
packed: bool = False,
absgrad: bool = False,
Expand Down Expand Up @@ -869,6 +870,10 @@ def rasterize_to_pixels(
assert rolling_shutter_time.shape == (C,), rolling_shutter_time.shape
else:
rolling_shutter_time = torch.zeros(C, device=device)
if rolling_shutter_direction is not None:
assert rolling_shutter_direction in (1,2,3,4,5), f"rolling_shutter_direction must be one of (1, 2, 3, 4, 5), but got {rolling_shutter_direction}"
else:
rolling_shutter_direction = 1
if backgrounds is not None:
assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape
backgrounds = backgrounds.contiguous()
Expand Down Expand Up @@ -935,6 +940,7 @@ def rasterize_to_pixels(
opacities.contiguous(),
pix_vels.contiguous(),
rolling_shutter_time.contiguous(),
rolling_shutter_direction,
backgrounds,
image_width,
image_height,
Expand Down Expand Up @@ -1126,6 +1132,7 @@ def rasterize_to_indices_in_range(
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
rolling_shutter_time: Optional[Tensor] = None, # [C]
rolling_shutter_direction: int = 1, # 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global
) -> Tuple[Tensor, Tensor, Tensor]:
"""Rasterizes a batch of Gaussians to images but only returns the indices.

Expand Down Expand Up @@ -1188,6 +1195,7 @@ def rasterize_to_indices_in_range(
image_width,
image_height,
tile_size,
rolling_shutter_direction,
isect_offsets.contiguous(),
flatten_ids.contiguous(),
)
Expand Down Expand Up @@ -1805,6 +1813,7 @@ def forward(
opacities: Tensor, # [C, N]
pix_vels: Tensor, # [C, N, 2]
rolling_shutter_time: Tensor, # [C]
rolling_shutter_direction: int, # <- should probably make this per cam
backgrounds: Tensor, # [C, D], Optional
width: int,
height: int,
Expand All @@ -1826,6 +1835,7 @@ def forward(
width,
height,
tile_size,
rolling_shutter_direction,
isect_offsets,
flatten_ids,
)
Expand All @@ -1847,6 +1857,7 @@ def forward(
ctx.height = height
ctx.tile_size = tile_size
ctx.absgrad = absgrad
ctx.rolling_shutter_direction = rolling_shutter_direction

# double to float
render_alphas = render_alphas.float()
Expand Down Expand Up @@ -1875,6 +1886,7 @@ def backward(
height = ctx.height
tile_size = ctx.tile_size
absgrad = ctx.absgrad
rolling_shutter_direction = ctx.rolling_shutter_direction

(
v_means2d_abs,
Expand All @@ -1894,6 +1906,7 @@ def backward(
width,
height,
tile_size,
rolling_shutter_direction,
isect_offsets,
flatten_ids,
render_alphas,
Expand All @@ -1920,6 +1933,7 @@ def backward(
v_opacities,
v_pix_vels,
None,
None,
v_backgrounds,
None,
None,
Expand Down Expand Up @@ -1996,7 +2010,9 @@ def forward(
ctx.tile_height = tile_height
ctx.absgrad = absgrad
ctx.compute_alpha_sum_until_points = compute_alpha_sum_until_points
ctx.compute_alpha_sum_until_points_threshold = compute_alpha_sum_until_points_threshold
ctx.compute_alpha_sum_until_points_threshold = (
compute_alpha_sum_until_points_threshold
)
ctx.depth_channel_idx = depth_channel_idx

# double to float
Expand Down
6 changes: 6 additions & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
const at::optional<torch::Tensor> &backgrounds, // [C, D]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// rolling shutter direction
const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
Expand Down Expand Up @@ -323,6 +325,8 @@ rasterize_to_pixels_bwd_tensor(
const at::optional<torch::Tensor> &backgrounds, // [C, 3]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// shutter direction
const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
Expand Down Expand Up @@ -380,6 +384,8 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
const uint32_t image_width,
const uint32_t image_height,
const uint32_t tile_size,
// shutter direction
const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
Expand Down
Loading
Loading