36 changes: 36 additions & 0 deletions configs/motus/motus_i2v.json
@@ -0,0 +1,36 @@
{
"wan_path": "/path/to/Wan2.2-TI2V-5B",
"vlm_path": "/path/to/Qwen3-VL-2B-Instruct",
"infer_steps": 10,
"text_len": 512,
"target_video_length": 9,
"target_height": 384,
"target_width": 320,
"num_channels_latents": 48,
"sample_guide_scale": 1.0,
"patch_size": [1, 2, 2],
"vae_stride": [4, 16, 16],
"sample_shift": 5.0,
"feature_caching": "NoCaching",
"use_image_encoder": false,
"enable_cfg": false,
"attention_type": "flash_attn2",
"self_joint_attn_type": "flash_attn2",
"cross_attn_type": "flash_attn2",
"global_downsample_rate": 3,
"video_action_freq_ratio": 2,
"num_video_frames": 8,
"video_height": 384,
"video_width": 320,
"fps": 4,
"motus_quantized": false,
"motus_quant_scheme": "Default",
"load_pretrained_backbones": false,
"training_mode": "finetune",
"action_state_dim": 14,
"action_dim": 14,
"action_expert_dim": 1024,
"action_expert_ffn_dim_multiplier": 4,
"und_expert_hidden_size": 512,
"und_expert_ffn_dim_multiplier": 4
}
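For orientation, the geometry fields above are interdependent: target_height and target_width must be divisible by the spatial VAE stride, and the resulting latent map by the spatial patch size. A sanity-check sketch, assuming the usual Wan-style latent arithmetic (the temporal stride keeps the conditioning frame, hence the (T - 1) // 4 + 1 formula):

# Sanity check of the shipped geometry, assuming Wan-style latent arithmetic.
target_video_length, target_height, target_width = 9, 384, 320
vae_stride = (4, 16, 16)
patch_size = (1, 2, 2)

latent_t = (target_video_length - 1) // vae_stride[0] + 1  # 3
latent_h = target_height // vae_stride[1]                  # 24
latent_w = target_width // vae_stride[2]                   # 20

# After 1x2x2 patchification, the video token grid is:
grid = (latent_t // patch_size[0], latent_h // patch_size[1], latent_w // patch_size[2])
print(grid)  # (3, 12, 10) -> 3 * 12 * 10 = 360 video tokens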
4 changes: 4 additions & 0 deletions lightx2v/infer.py
@@ -7,6 +7,7 @@

from lightx2v.common.ops import *
from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401
from lightx2v.models.runners.motus.motus_runner import MotusRunner # noqa: F401
from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
@@ -82,6 +83,7 @@ def main():
"bagel",
"seedvr2",
"neopp",
"motus",
"lingbot_world_fast",
"worldmirror",
],
@@ -104,6 +106,7 @@
default="",
help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'",
)
parser.add_argument("--state_path", type=str, default="", help="The path to input robot state file for Motus i2v inference.")
parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
parser.add_argument(
"--audio_path",
@@ -191,6 +194,7 @@ def main():
parser.add_argument("--wm_ckpt_path", type=str, default=None, help="(worldmirror/recon) Optional .ckpt/.safetensors (pair with --wm_config_path).")

parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
parser.add_argument("--save_action_path", type=str, default=None, help="The path to save action predictions for Motus.")
parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
parser.add_argument("--target_shape", type=int, nargs="+", default=[], help="Set return video or image shape")
parser.add_argument("--target_video_length", type=int, default=81, help="The target video length for each generated clip")
7 changes: 7 additions & 0 deletions lightx2v/models/networks/motus/__init__.py
@@ -0,0 +1,7 @@
from .model import MotusModel
from .primitives import sinusoidal_embedding_1d

__all__ = [
"MotusModel",
"sinusoidal_embedding_1d",
]
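The package re-exports sinusoidal_embedding_1d from primitives, whose body is not part of this diff. For reference, a minimal sketch of the conventional 1-D sinusoidal embedding it presumably implements (half cosine, half sine, log-spaced frequencies), not the PR's actual code:

import torch

def sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
    # Conventional formulation: half the channels get cos, half get sin,
    # with frequencies log-spaced from 1 down to 1/10000.
    assert dim % 2 == 0
    half = dim // 2
    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float64) / half)
    args = torch.outer(position.to(torch.float64), freqs)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=1)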
19 changes: 19 additions & 0 deletions lightx2v/models/networks/motus/image_utils.py
@@ -0,0 +1,19 @@
import cv2
import numpy as np


def resize_with_padding(frame: np.ndarray, target_size: tuple[int, int]) -> np.ndarray:
    """Resize an HWC frame to fit within target_size (height, width) while
    preserving aspect ratio, then zero-pad to exactly target_size (letterbox)."""
target_height, target_width = target_size
original_height, original_width = frame.shape[:2]

scale = min(target_height / original_height, target_width / original_width)
new_height = int(original_height * scale)
new_width = int(original_width * scale)

resized_frame = cv2.resize(frame, (new_width, new_height))
padded_frame = np.zeros((target_height, target_width, frame.shape[2]), dtype=frame.dtype)

y_offset = (target_height - new_height) // 2
x_offset = (target_width - new_width) // 2
padded_frame[y_offset : y_offset + new_height, x_offset : x_offset + new_width] = resized_frame
return padded_frame
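A usage sketch (the file name is hypothetical). Note the helper expects an HWC frame with an explicit channel axis, since it reads frame.shape[2]; a 2-D grayscale array would need a channel dimension added first:

import cv2
from lightx2v.models.networks.motus.image_utils import resize_with_padding

frame = cv2.imread("first_frame.jpg")            # HWC, BGR
# Letterbox to the 384x320 frame size used in configs/motus/motus_i2v.json.
padded = resize_with_padding(frame, (384, 320))  # (target_height, target_width)
assert padded.shape[:2] == (384, 320)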
5 changes: 5 additions & 0 deletions lightx2v/models/networks/motus/infer/__init__.py
@@ -0,0 +1,5 @@
from .post_infer import MotusPostInfer
from .pre_infer import MotusPreInfer
from .transformer_infer import MotusTransformerInfer

__all__ = ["MotusPreInfer", "MotusTransformerInfer", "MotusPostInfer"]
20 changes: 20 additions & 0 deletions lightx2v/models/networks/motus/infer/module_io.py
@@ -0,0 +1,20 @@
from dataclasses import dataclass, field
from typing import Any

import torch

from lightx2v.models.networks.wan.infer.module_io import WanPreInferModuleOutput


@dataclass(kw_only=True)
class MotusPreInferModuleOutput(WanPreInferModuleOutput):
state: torch.Tensor
first_frame: torch.Tensor
instruction: str
t5_embeddings: list[torch.Tensor]
vlm_inputs: list[dict[str, Any]]
image_context: torch.Tensor | None
und_tokens: torch.Tensor
condition_frame_latent: torch.Tensor
adapter_args: dict[str, Any] = field(default_factory=dict)
conditional_dict: dict[str, Any] = field(default_factory=dict)
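The kw_only=True here is load-bearing: it lets the subclass add required fields on top of a dataclass base that may declare defaulted fields, which would otherwise trip Python's "non-default argument follows default argument" rule. A minimal standalone illustration (Python 3.10+, where kw_only was introduced):

from dataclasses import dataclass

@dataclass
class Base:
    x: int = 0  # a defaulted field in the base

# Without kw_only=True, the required field below would raise:
#   TypeError: non-default argument 'y' follows default argument
@dataclass(kw_only=True)
class Child(Base):
    y: int  # required, but keyword-only, so field ordering no longer matters

c = Child(y=1)  # x falls back to its default of 0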
16 changes: 16 additions & 0 deletions lightx2v/models/networks/motus/infer/post_infer.py
@@ -0,0 +1,16 @@
import torch


class MotusPostInfer:
def __init__(self, model, config):
self.model = model
self.config = config
self.scheduler = None

def set_scheduler(self, scheduler):
self.scheduler = scheduler

@torch.no_grad()
def infer(self, action_latents: torch.Tensor, pre_infer_out):
del pre_infer_out
return self.model.denormalize_actions(action_latents.float()).squeeze(0)
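denormalize_actions is defined on the model rather than in this file, so its body is not shown here. A purely hypothetical sketch of the convention it likely follows, mapping action latents normalized to [-1, 1] back to per-dimension physical ranges (the names action_min and action_max are assumptions):

import torch

# Hypothetical sketch only -- not the PR's implementation.
def denormalize_actions(latents: torch.Tensor, action_min: torch.Tensor, action_max: torch.Tensor) -> torch.Tensor:
    # Invert min-max normalization from [-1, 1] back to physical units.
    return (latents + 1.0) * 0.5 * (action_max - action_min) + action_min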
70 changes: 70 additions & 0 deletions lightx2v/models/networks/motus/infer/pre_infer.py
@@ -0,0 +1,70 @@
import torch

from lightx2v.models.networks.wan.infer.module_io import GridOutput
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer

from .module_io import MotusPreInferModuleOutput


class MotusPreInfer(WanPreInfer):
def __init__(self, model, config):
super().__init__(config)
self.model = model
self.scheduler = None

def set_scheduler(self, scheduler):
self.scheduler = scheduler

@torch.no_grad()
def infer(self, weights, inputs, kv_start=0, kv_end=0):
del weights, kv_start, kv_end
if self.scheduler is None:
raise RuntimeError("MotusPreInfer requires a scheduler before infer().")

first_frame = inputs["motus_first_frame"]
state = inputs["motus_state"]
instruction = inputs["motus_instruction"]
t5_context = inputs["motus_t5_embeddings"]
processed_t5_context = inputs["motus_processed_t5_context"]
vlm_inputs = inputs["motus_vlm_inputs"]
image_context = inputs["motus_image_context"]
und_tokens = inputs["motus_und_tokens"]

video_latents = self.scheduler.video_latents
if video_latents.dim() != 5:
raise RuntimeError(f"Expected video latents with shape [B, C, T, H, W], got {tuple(video_latents.shape)}")
batch_size = state.shape[0]
_, _, latent_t, latent_h, latent_w = video_latents.shape
grid_sizes = torch.tensor(
[[latent_t, latent_h // self.model.video_backbone.patch_size[1], latent_w // self.model.video_backbone.patch_size[2]]],
dtype=torch.long,
device=state.device,
).expand(batch_size, -1)
grid_output = GridOutput(
tensor=grid_sizes,
tuple=tuple(int(v) for v in grid_sizes[0].tolist()),
)

if self.cos_sin is None or self.grid_sizes != grid_output.tuple:
self.grid_sizes = grid_output.tuple
self.cos_sin = self.prepare_cos_sin(grid_output.tuple, self.freqs.clone())

dummy_embed = torch.empty(0, device=state.device, dtype=processed_t5_context.dtype)

return MotusPreInferModuleOutput(
embed=dummy_embed,
grid_sizes=grid_output,
x=self.scheduler.video_latents,
embed0=dummy_embed,
context=processed_t5_context,
cos_sin=self.cos_sin,
first_frame=first_frame,
state=state,
instruction=instruction,
t5_embeddings=t5_context,
vlm_inputs=vlm_inputs,
image_context=image_context,
und_tokens=und_tokens,
condition_frame_latent=self.scheduler.condition_frame_latent,
adapter_args={"instruction": instruction},
)