diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index f11276a..9695dd8 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -834,6 +834,7 @@ def _rasterize_to_pixels( rolling_shutter_time: Tensor, # [C] backgrounds: Optional[Tensor] = None, # [C, channels] batch_per_iter: int = 100, + rolling_shutter_direction: int = 1, ): """Pytorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels()`. diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 1f375d0..a37641d 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -820,6 +820,7 @@ def rasterize_to_pixels( isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] rolling_shutter_time: Optional[Tensor] = None, # [C] + rolling_shutter_direction: Optional[int] = None, # 1: top2bottom, 2: left2right, 3: bottom2top, 4: right2left, 5: no rolling shutter backgrounds: Optional[Tensor] = None, # [C, channels] packed: bool = False, absgrad: bool = False, @@ -869,6 +870,10 @@ def rasterize_to_pixels( assert rolling_shutter_time.shape == (C,), rolling_shutter_time.shape else: rolling_shutter_time = torch.zeros(C, device=device) + if rolling_shutter_direction is not None: + assert rolling_shutter_direction in (1,2,3,4,5), f"rolling_shutter_direction must be one of (1, 2, 3, 4, 5), but got {rolling_shutter_direction}" + else: + rolling_shutter_direction = 1 if backgrounds is not None: assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape backgrounds = backgrounds.contiguous() @@ -935,6 +940,7 @@ def rasterize_to_pixels( opacities.contiguous(), pix_vels.contiguous(), rolling_shutter_time.contiguous(), + rolling_shutter_direction, backgrounds, image_width, image_height, @@ -1126,6 +1132,7 @@ def rasterize_to_indices_in_range( isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] rolling_shutter_time: Optional[Tensor] = None, # [C] + rolling_shutter_direction: int = 1, # 1: 
top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global ) -> Tuple[Tensor, Tensor, Tensor]: """Rasterizes a batch of Gaussians to images but only returns the indices. @@ -1188,6 +1195,7 @@ def rasterize_to_indices_in_range( image_width, image_height, tile_size, + rolling_shutter_direction, isect_offsets.contiguous(), flatten_ids.contiguous(), ) @@ -1805,6 +1813,7 @@ def forward( opacities: Tensor, # [C, N] pix_vels: Tensor, # [C, N, 2] rolling_shutter_time: Tensor, # [C] + rolling_shutter_direction: int, # <- should probably make this per cam backgrounds: Tensor, # [C, D], Optional width: int, height: int, @@ -1826,6 +1835,7 @@ def forward( width, height, tile_size, + rolling_shutter_direction, isect_offsets, flatten_ids, ) @@ -1847,6 +1857,7 @@ def forward( ctx.height = height ctx.tile_size = tile_size ctx.absgrad = absgrad + ctx.rolling_shutter_direction = rolling_shutter_direction # double to float render_alphas = render_alphas.float() @@ -1875,6 +1886,7 @@ def backward( height = ctx.height tile_size = ctx.tile_size absgrad = ctx.absgrad + rolling_shutter_direction = ctx.rolling_shutter_direction ( v_means2d_abs, @@ -1894,6 +1906,7 @@ def backward( width, height, tile_size, + rolling_shutter_direction, isect_offsets, flatten_ids, render_alphas, @@ -1920,6 +1933,7 @@ def backward( v_opacities, v_pix_vels, None, + None, v_backgrounds, None, None, @@ -1996,7 +2010,9 @@ def forward( ctx.tile_height = tile_height ctx.absgrad = absgrad ctx.compute_alpha_sum_until_points = compute_alpha_sum_until_points - ctx.compute_alpha_sum_until_points_threshold = compute_alpha_sum_until_points_threshold + ctx.compute_alpha_sum_until_points_threshold = ( + compute_alpha_sum_until_points_threshold + ) ctx.depth_channel_idx = depth_channel_idx # double to float diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 59058bc..0649d97 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -271,6 +271,8 @@ std::tuple 
rasterize_to_pixels_fwd_ const at::optional &backgrounds, // [C, D] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // rolling shutter direction + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids // [n_isects] @@ -323,6 +325,8 @@ rasterize_to_pixels_bwd_tensor( const at::optional &backgrounds, // [C, 3] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // shutter direction + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] @@ -380,6 +384,8 @@ std::tuple rasterize_to_indices_in_range_tensor( const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // shutter direction + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids // [n_isects] diff --git a/gsplat/cuda/csrc/rasterization.cu b/gsplat/cuda/csrc/rasterization.cu index bff205d..fec681c 100644 --- a/gsplat/cuda/csrc/rasterization.cu +++ b/gsplat/cuda/csrc/rasterization.cu @@ -849,6 +849,7 @@ __global__ void rasterize_to_indices_in_range_kernel( const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] const float *__restrict__ transmittances, // [C, image_height, image_width] @@ -873,7 +874,28 @@ 
__global__ void rasterize_to_indices_in_range_kernel( float px = (float)j + 0.5f; float py = (float)i + 0.5f; int32_t pix_id = i * image_width + j; - float roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + // Calculate rolling shutter time based on shutter direction + float roll_time = 0.0f; + switch (rolling_shutter_direction) { + case 1: // top2bot + roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + break; + case 2: // left2right + roll_time = rolling_shutter_time[camera_id] * ((px - 0.5f) / (image_width - 1) - 0.5f); + break; + case 3: // bot2top + roll_time = rolling_shutter_time[camera_id] * ((-py + image_height - 0.5f) / (image_height - 1) - 0.5f); + break; + case 4: // right2left + roll_time = rolling_shutter_time[camera_id] * ((-px + image_width - 0.5f) / (image_width - 1) - 0.5f); + break; + case 5: // global + roll_time = 0.0f; // No rolling shutter effect + break; + default: + // Default to top2bot behavior if an invalid value is provided + roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + } // return if out of bounds // keep not rasterizing threads around for reading data @@ -1013,6 +1035,8 @@ std::tuple rasterize_to_indices_in_range_tensor( const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // shutter direction + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids // [n_isects] @@ -1058,7 +1082,7 @@ std::tuple rasterize_to_indices_in_range_tensor( (float3 *)conics.data_ptr(), opacities.data_ptr(), (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, 
tile_offsets.data_ptr(), flatten_ids.data_ptr(), transmittances.data_ptr(), nullptr, chunk_cnts.data_ptr(), @@ -1082,7 +1106,7 @@ std::tuple rasterize_to_indices_in_range_tensor( (float3 *)conics.data_ptr(), opacities.data_ptr(), (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), transmittances.data_ptr(), chunk_starts.data_ptr(), nullptr, @@ -1372,6 +1396,7 @@ __global__ void rasterize_to_pixels_fwd_kernel( const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] float *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] @@ -1398,7 +1423,28 @@ __global__ void rasterize_to_pixels_fwd_kernel( float px = (float)j + 0.5f; float py = (float)i + 0.5f; int32_t pix_id = i * image_width + j; - float roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + // Calculate rolling shutter time based on shutter direction + float roll_time = 0.0f; + switch (rolling_shutter_direction) { + case 1: // top2bot + roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + break; + case 2: // left2right + roll_time = rolling_shutter_time[camera_id] * ((px - 0.5f) / (image_width - 1) - 0.5f); + break; + case 3: // bot2top + roll_time = rolling_shutter_time[camera_id] * ((-py + image_height - 0.5f) / (image_height - 1) - 0.5f); + break; + case 4: // right2left + roll_time = rolling_shutter_time[camera_id] * ((-px + image_width - 0.5f) / (image_width - 1) - 0.5f); + break; + case 5: // global + roll_time = 0.0f; // No rolling 
shutter effect + break; + default: + // Default to top2bot behavior if an invalid value is provided + roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + } // return if out of bounds // keep not rasterizing threads around for reading data @@ -1530,6 +1576,7 @@ std::tuple rasterize_to_pixels_fwd_ const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids // [n_isects] @@ -1591,7 +1638,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1614,7 +1661,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1637,7 +1684,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1660,7 +1707,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1683,7 +1730,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1706,7 +1753,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1729,7 +1776,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1752,7 +1799,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1775,7 +1822,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1798,7 +1845,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1821,7 +1868,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1844,7 +1891,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1867,7 +1914,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1890,7 +1937,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1913,7 +1960,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1936,7 +1983,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1959,7 +2006,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -1982,7 +2029,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -2005,7 +2052,7 @@ std::tuple rasterize_to_pixels_fwd_ (float2 *)pix_vels.data_ptr(), rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), @@ -2846,6 +2893,7 @@ __global__ void rasterize_to_pixels_bwd_kernel( const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] // fwd outputs @@ -2882,7 +2930,28 @@ __global__ void rasterize_to_pixels_bwd_kernel( const float py = (float)i + 0.5f; // clamp this value to the last pixel const int32_t pix_id = min(i * image_width + j, image_width * image_height - 1); - float roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + // Calculate rolling shutter time based on shutter direction + float roll_time = 0.0f; + switch (rolling_shutter_direction) { + case 1: // top2bot + roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + break; + case 2: // left2right + roll_time = rolling_shutter_time[camera_id] * ((px - 0.5f) / (image_width - 1) - 0.5f); + break; + case 3: // bot2top + roll_time = rolling_shutter_time[camera_id] * ((-py + image_height - 0.5f) / (image_height - 1) - 0.5f); + break; + case 4: // right2left + roll_time = rolling_shutter_time[camera_id] * ((-px + image_width - 0.5f) / (image_width - 1) - 0.5f); + break; + case 5: // global + roll_time = 0.0f; // No rolling shutter effect + break; + default: + // Default to top2bot behavior if an invalid value is provided + roll_time = rolling_shutter_time[camera_id] * ((py - 0.5f) / (image_height - 1) - 0.5f); + } // keep not rasterizing threads around for reading data bool inside = (i < image_height && j < image_width); 
@@ -3102,6 +3171,7 @@ rasterize_to_pixels_bwd_tensor( const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const uint32_t rolling_shutter_direction, // 1: top2bot, 2: left2right, 3: bot2top, 4: right2left, 5: global // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] @@ -3177,7 +3247,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3206,7 +3276,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3235,7 +3305,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3264,7 +3334,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3293,7 +3363,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3322,7 +3392,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3351,7 +3421,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3380,7 +3450,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3409,7 +3479,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3438,7 +3508,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3467,7 +3537,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3496,7 +3566,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3525,7 +3595,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3555,7 +3625,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3585,7 +3655,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3615,7 +3685,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3645,7 +3715,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3675,7 +3745,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -3705,7 +3775,7 @@ rasterize_to_pixels_bwd_tensor( rolling_shutter_time.data_ptr(), backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, - image_width, image_height, tile_size, tile_width, tile_height, + image_width, image_height, tile_size, tile_width, tile_height, rolling_shutter_direction, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), diff --git a/gsplat/rendering.py b/gsplat/rendering.py index dc7e1b0..69bf95c 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -33,6 +33,7 @@ def rasterization( linear_velocity: Optional[Tensor] = None, # [C, 3] angular_velocity: Optional[Tensor] = None, # [C, 3] rolling_shutter_time: Optional[Tensor] = None, # [C] + rolling_shutter_direction: Optional[int] = 1, # 1: top2bottom, 2: left2right, 3: bottom2top, 4: right2left, 5: no rolling shutter near_plane: float = 0.01, far_plane: float = 1e10, radius_clip: float = 0.0, @@ -224,6 +225,10 @@ def rasterization( assert rolling_shutter_time.shape == (C,), rolling_shutter_time.shape else: rolling_shutter_time = torch.zeros(C, device=means.device) + if rolling_shutter_direction is not None: + assert rolling_shutter_direction in (1,2,3,4,5), rolling_shutter_direction + else: + rolling_shutter_direction = 1 assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode if sh_degree is None: @@ -382,6 +387,7 @@ def rasterization( isect_offsets, flatten_ids, rolling_shutter_time, + rolling_shutter_direction=rolling_shutter_direction, backgrounds=backgrounds_chunk, packed=packed, absgrad=absgrad, @@ -403,6 +409,7 @@ def rasterization( isect_offsets, flatten_ids, rolling_shutter_time, + rolling_shutter_direction=rolling_shutter_direction, backgrounds=backgrounds, packed=packed, absgrad=absgrad,