diff --git a/frontend/src/components/DownloadDialog.tsx b/frontend/src/components/DownloadDialog.tsx index d94362e8..556b7c26 100644 --- a/frontend/src/components/DownloadDialog.tsx +++ b/frontend/src/components/DownloadDialog.tsx @@ -9,11 +9,12 @@ import { DialogTitle, } from "./ui/dialog"; import { Progress } from "./ui/progress"; -import { PIPELINES } from "../data/pipelines"; import type { PipelineId, DownloadProgress } from "../types"; +import type { PipelineInfo } from "../hooks/usePipelines"; interface DownloadDialogProps { open: boolean; + pipelines: Record | null; pipelineId: PipelineId; onClose: () => void; onDownload: () => void; @@ -23,13 +24,14 @@ interface DownloadDialogProps { export function DownloadDialog({ open, + pipelines, pipelineId, onClose, onDownload, isDownloading = false, progress = null, }: DownloadDialogProps) { - const pipelineInfo = PIPELINES[pipelineId]; + const pipelineInfo = pipelines?.[pipelineId]; if (!pipelineInfo) return null; return ( diff --git a/frontend/src/components/InputAndControlsPanel.tsx b/frontend/src/components/InputAndControlsPanel.tsx index ca3a0e06..8afc57c4 100644 --- a/frontend/src/components/InputAndControlsPanel.tsx +++ b/frontend/src/components/InputAndControlsPanel.tsx @@ -12,18 +12,32 @@ import { Badge } from "./ui/badge"; import { Input } from "./ui/input"; import { Upload, ArrowUp } from "lucide-react"; import { LabelWithTooltip } from "./ui/label-with-tooltip"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "./ui/tooltip"; import type { VideoSourceMode } from "../hooks/useVideoSource"; import type { PromptItem, PromptTransition } from "../lib/api"; import type { InputMode } from "../types"; -import { pipelineIsMultiMode } from "../data/pipelines"; +import type { PipelineInfo } from "../hooks/usePipelines"; +import { + pipelineRequiresReferenceImage, + pipelineShowsPromptInput, + pipelineCanChangeReferenceWhileStreaming, + getPipelineReferenceImageDescription, +} from "../data/pipelines"; import { PromptInput } from "./PromptInput"; import { TimelinePromptEditor } from "./TimelinePromptEditor"; import type { TimelinePrompt } from "./PromptTimeline"; import { ImageManager } from "./ImageManager"; import { Button } from "./ui/button"; +import { ImageIcon } from "lucide-react"; interface InputAndControlsPanelProps { className?: string; + pipelines: Record | null; localStream: MediaStream | null; isInitializing: boolean; error: string | null; @@ -63,16 +77,21 @@ interface InputAndControlsPanelProps { onInputModeChange: (mode: InputMode) => void; // Whether Spout is available (server-side detection for native Windows, not WSL) spoutAvailable?: boolean; + // PersonaLive reference image + referenceImageUrl?: string | null; + onReferenceImageUpload?: (file: File) => void; + isUploadingReference?: boolean; // VACE reference images (only shown when VACE is enabled) vaceEnabled?: boolean; refImages?: string[]; onRefImagesChange?: (images: string[]) => void; onSendHints?: (imagePaths: string[]) => void; - isDownloading?: boolean; + isLoading?: boolean; } export function InputAndControlsPanel({ className = "", + pipelines, localStream, isInitializing, error, @@ -109,11 +128,14 @@ export function InputAndControlsPanel({ inputMode, onInputModeChange, spoutAvailable = false, + referenceImageUrl = null, + onReferenceImageUpload, + isUploadingReference = false, vaceEnabled = true, refImages = [], onRefImagesChange, onSendHints, - isDownloading = false, + isLoading = false, }: InputAndControlsPanelProps) { // 
Helper function to determine if playhead is at the end of timeline
  const isAtEndOfTimeline = () => {
@@ -128,7 +150,22 @@ export function InputAndControlsPanel({
   const videoRef = useRef<HTMLVideoElement>(null);

   // Check if this pipeline supports multiple input modes
-  const isMultiMode = pipelineIsMultiMode(pipelineId);
+  const pipeline = pipelines?.[pipelineId];
+  const isMultiMode = (pipeline?.supportedModes?.length ?? 0) > 1;
+
+  // Check if this pipeline requires a reference image (PersonaLive)
+  const needsReferenceImage = pipelineRequiresReferenceImage(pipelineId);
+
+  const handleReferenceImageUpload = (
+    event: React.ChangeEvent<HTMLInputElement>
+  ) => {
+    const file = event.target.files?.[0];
+    if (file && onReferenceImageUpload) {
+      onReferenceImageUpload(file);
+    }
+    // Reset the input value so the same file can be selected again
+    event.target.value = "";
+  };

   useEffect(() => {
     if (videoRef.current && localStream) {
@@ -183,6 +220,80 @@ export function InputAndControlsPanel({
         )}

+        {/* Reference Image upload - only show for pipelines that require it */}
+        {needsReferenceImage && (
+
+

Reference Portrait

+
+ {referenceImageUrl ? ( + Reference portrait + ) : ( +
+ + Upload a portrait image +
+ )} + + {/* Upload button with tooltip when disabled during streaming */} + {(isStreaming || isConnecting) && + !pipelineCanChangeReferenceWhileStreaming(pipelineId) ? ( + + + +
+ +
+
+ +

Reference image is processed when the pipeline loads.
+                      Stop the stream to change it.

+
+
+
+ ) : ( + + )} +
+ {isUploadingReference && ( +

+ Uploading reference image... +

+ )} + {!referenceImageUrl && ( +

+ {getPipelineReferenceImageDescription(pipelineId) || + "This pipeline requires a reference image."} +

+ )} +
+ )} + {/* Video Source toggle - only show when in video input mode */} {inputMode === "video" && (
@@ -299,7 +410,7 @@ export function InputAndControlsPanel({ {})} - disabled={isDownloading} + disabled={isLoading} /> {onSendHints && refImages && refImages.length > 0 && (
@@ -308,7 +419,7 @@ export function InputAndControlsPanel({ e.preventDefault(); onSendHints(refImages.filter(img => img)); }} - disabled={isDownloading || !isStreaming} + disabled={isLoading || !isStreaming} size="sm" className="rounded-full w-8 h-8 p-0 bg-black hover:bg-gray-800 text-white disabled:opacity-50 disabled:cursor-not-allowed" title={ @@ -324,67 +435,76 @@ export function InputAndControlsPanel({
)} -
- {(() => { - // The Input can have two states: Append (default) and Edit (when a prompt is selected and the video is paused) - const isEditMode = selectedTimelinePrompt && isVideoPaused; + {/* Prompts section - only show for pipelines that support text prompts */} + {pipelineShowsPromptInput(pipelineId) && ( +
+ {(() => { + // The Input can have two states: Append (default) and Edit (when a prompt is selected and the video is paused) + const isEditMode = selectedTimelinePrompt && isVideoPaused; - return ( -
-
-

Prompts

- {isEditMode && ( - - Editing - - )} -
+ // Hide prompts section if pipeline doesn't support prompts + if (pipeline?.supportsPrompts === false) { + return null; + } - {selectedTimelinePrompt ? ( - p.id === selectedTimelinePrompt.id + return ( +
+
+

Prompts

+ {isEditMode && ( + + Editing + )} - /> - ) : ( - - )} -
- ); - })()} -
+
+ + {selectedTimelinePrompt ? ( + p.id === selectedTimelinePrompt.id + )} + /> + ) : ( + + )} +
+ ); + })()} +
+ )} ); diff --git a/frontend/src/components/SettingsPanel.tsx b/frontend/src/components/SettingsPanel.tsx index d72e6768..2ec8188e 100644 --- a/frontend/src/components/SettingsPanel.tsx +++ b/frontend/src/components/SettingsPanel.tsx @@ -19,11 +19,15 @@ import { Input } from "./ui/input"; import { Button } from "./ui/button"; import { Toggle } from "./ui/toggle"; import { SliderWithInput } from "./ui/slider-with-input"; -import { Hammer, Info, Minus, Plus, RotateCcw } from "lucide-react"; +import { Info, Minus, Plus, RotateCcw } from "lucide-react"; import { - PIPELINES, - pipelineSupportsLoRA, - pipelineSupportsVACE, + pipelineShowsResolutionControl, + pipelineShowsSeedControl, + pipelineShowsDenoisingSteps, + pipelineShowsCacheManagement, + pipelineShowsQuantization, + pipelineShowsKvCacheAttentionBias, + pipelineShowsNoiseControls, } from "../data/pipelines"; import { PARAMETER_METADATA } from "../data/parameterMetadata"; import { DenoisingStepsSlider } from "./DenoisingStepsSlider"; @@ -35,12 +39,14 @@ import type { SettingsState, InputMode, } from "../types"; +import type { PipelineInfo } from "../hooks/usePipelines"; import { LoRAManager } from "./LoRAManager"; const MIN_DIMENSION = 16; interface SettingsPanelProps { className?: string; + pipelines: Record | null; pipelineId: PipelineId; onPipelineIdChange?: (pipelineId: PipelineId) => void; isStreaming?: boolean; @@ -73,8 +79,6 @@ interface SettingsPanelProps { loraMergeStrategy?: LoraMergeStrategy; // Input mode for conditional rendering of noise controls inputMode?: InputMode; - // Whether this pipeline supports noise controls in video mode (schema-derived) - supportsNoiseControls?: boolean; // Spout settings spoutSender?: SettingsState["spoutSender"]; onSpoutSenderChange?: (spoutSender: SettingsState["spoutSender"]) => void; @@ -89,6 +93,7 @@ interface SettingsPanelProps { export function SettingsPanel({ className = "", + pipelines, pipelineId, onPipelineIdChange, isStreaming = false, @@ -115,7 +120,6 @@ export function SettingsPanel({ onLorasChange, loraMergeStrategy = "permanent_merge", inputMode, - supportsNoiseControls = false, spoutSender, onSpoutSenderChange, spoutAvailable = false, @@ -141,7 +145,7 @@ export function SettingsPanel({ const [seedError, setSeedError] = useState(null); const handlePipelineIdChange = (value: string) => { - if (value in PIPELINES) { + if (pipelines && value in pipelines) { onPipelineIdChange?.(value as PipelineId); } }; @@ -235,7 +239,7 @@ export function SettingsPanel({ handleSeedChange(newValue); }; - const currentPipeline = PIPELINES[pipelineId]; + const currentPipeline = pipelines?.[pipelineId]; return ( @@ -254,11 +258,12 @@ export function SettingsPanel({ - {Object.keys(PIPELINES).map(id => ( - - {id} - - ))} + {pipelines && + Object.keys(pipelines).map(id => ( + + {id} + + ))}
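Because the selector now iterates a runtime-fetched map instead of the static PIPELINES table, `pipelines` is null until the schemas request resolves. A minimal sketch of a consumer under that assumption (it uses only the `usePipelines` hook added later in this diff; the `PipelineSelector` component itself is illustrative, not part of the codebase):

import { usePipelines } from "../hooks/usePipelines";

function PipelineSelector({ onSelect }: { onSelect: (id: string) => void }) {
  const { pipelines, isLoading, error } = usePipelines();

  // Null until the fetch resolves; the real panel keeps the Select mounted
  // and simply renders no options in that window.
  if (isLoading || !pipelines) return null;
  if (error) return <span>{error}</span>;

  return (
    <select onChange={e => onSelect(e.target.value)}>
      {Object.keys(pipelines).map(id => (
        <option key={id} value={id}>
          {pipelines[id].name}
        </option>
      ))}
    </select>
  );
}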
@@ -273,9 +278,7 @@ export function SettingsPanel({
- {(currentPipeline.about || - currentPipeline.docsUrl || - currentPipeline.modified) && ( + {(currentPipeline.about || currentPipeline.docsUrl) && (
{currentPipeline.about && ( @@ -294,26 +297,6 @@ export function SettingsPanel({ )} - {currentPipeline.modified && ( - - - - - - - - -

- This pipeline contains modifications based on the - original project. -

-
-
-
- )} {currentPipeline.docsUrl && (
)} - {pipelineSupportsLoRA(pipelineId) && ( + {currentPipeline?.supportsLoRA && (
)} - {(pipelineId === "longlive" || - pipelineId === "streamdiffusionv2" || - pipelineId === "krea-realtime-video" || - pipelineId === "reward-forcing") && ( -
-
-
-
-
- -
- - { - const value = parseInt(e.target.value); - if (!isNaN(value)) { - handleResolutionChange("height", value); - } - }} - disabled={isStreaming} - className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none" - min={MIN_DIMENSION} - max={2048} - /> - -
-
- {heightError && ( -

{heightError}

- )} -
- -
-
- -
- - { - const value = parseInt(e.target.value); - if (!isNaN(value)) { - handleResolutionChange("width", value); - } - }} - disabled={isStreaming} - className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none" - min={MIN_DIMENSION} - max={2048} - /> - -
-
- {widthError && ( -

{widthError}

- )} -
- -
-
- -
- - { - const value = parseInt(e.target.value); - if (!isNaN(value)) { - handleSeedChange(value); - } - }} - disabled={isStreaming} - className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none" - min={0} - max={2147483647} - /> - -
-
- {seedError && ( -

{seedError}

- )} + {/* Resolution controls */} + {pipelineShowsResolutionControl(pipelineId) && ( +
+
+
+ +
+ + { + const value = parseInt(e.target.value); + if (!isNaN(value)) { + handleResolutionChange("height", value); + } + }} + disabled={isStreaming} + className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none" + min={MIN_DIMENSION} + max={2048} + /> +
+ {heightError && ( +

{heightError}

+ )}
-
- )} - - {(pipelineId === "longlive" || - pipelineId === "streamdiffusionv2" || - pipelineId === "krea-realtime-video" || - pipelineId === "reward-forcing") && ( -
-
-
- {pipelineId === "krea-realtime-video" && ( - parseFloat(v) || 1.0} - /> - )} -
- - {})} - variant="outline" - size="sm" - className="h-7" +
+
+ +
+
- -
- + + { + const value = parseInt(e.target.value); + if (!isNaN(value)) { + handleResolutionChange("width", value); + } + }} + disabled={isStreaming} + className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none" + min={MIN_DIMENSION} + max={2048} />
+ {widthError && ( +

{widthError}

+ )} +
+
+ )} + + {/* Seed control */} + {pipelineShowsSeedControl(pipelineId) && ( +
+
+ +
+ + { + const value = parseInt(e.target.value); + if (!isNaN(value)) { + handleSeedChange(value); + } + }} + disabled={isStreaming} + className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none" + min={0} + max={2147483647} + /> + +
+
+ {seedError && ( +

{seedError}

+ )} +
+ )} + + {/* KV Cache Attention Bias (krea-realtime-video specific) */} + {pipelineShowsKvCacheAttentionBias(pipelineId) && ( + parseFloat(v) || 1.0} + /> + )} + + {/* Cache management controls */} + {pipelineShowsCacheManagement(pipelineId) && ( +
+
+ + {})} + variant="outline" + size="sm" + className="h-7" + > + {manageCache ? "ON" : "OFF"} + +
+ +
+ +
)} - {(pipelineId === "longlive" || - pipelineId === "streamdiffusionv2" || - pipelineId === "krea-realtime-video" || - pipelineId === "reward-forcing") && ( + {/* Denoising steps slider */} + {pipelineShowsDenoisingSteps(pipelineId) && ( {})} @@ -629,84 +602,70 @@ export function SettingsPanel({ /> )} - {/* Noise controls - show for video mode on supported pipelines (schema-derived) */} - {inputMode === "video" && supportsNoiseControls && ( -
-
-
-
- - {})} - disabled={isStreaming} - variant="outline" - size="sm" - className="h-7" - > - {noiseController ? "ON" : "OFF"} - -
-
- - parseFloat(v) || 0.0} + {/* Noise controls - show for video mode on pipelines that support it */} + {inputMode === "video" && pipelineShowsNoiseControls(pipelineId) && ( +
+
+ + {})} + disabled={isStreaming} + variant="outline" + size="sm" + className="h-7" + > + {noiseController ? "ON" : "OFF"} +
+ + parseFloat(v) || 0.0} + />
)} - {(pipelineId === "longlive" || - pipelineId === "streamdiffusionv2" || - pipelineId === "krea-realtime-video" || - pipelineId === "reward-forcing") && ( -
-
-
-
- - -
-
-
+ {/* Quantization selector */} + {pipelineShowsQuantization(pipelineId) && ( +
+ +
)} diff --git a/frontend/src/data/pipelines.ts b/frontend/src/data/pipelines.ts index fae64ad7..c2737d9f 100644 --- a/frontend/src/data/pipelines.ts +++ b/frontend/src/data/pipelines.ts @@ -1,13 +1,27 @@ import type { InputMode } from "../types"; -// Unified default prompts by mode (not per-pipeline) -// These are used across all pipelines for consistency +// Default prompts by mode - used across all pipelines for consistency export const DEFAULT_PROMPTS: Record = { text: "A 3D animated scene. A **panda** walks along a path towards the camera in a park on a spring day.", video: "A 3D animated scene. A **panda** sitting in the grass, looking around.", }; +// UI capability flags - controls which UI elements are shown for each pipeline +// NOTE: These are used as fallbacks when pipeline info is not available from backend +export interface PipelineUICapabilities { + showTimeline?: boolean; // Show prompt timeline (default: true) + showPromptInput?: boolean; // Show prompt input controls (default: true) + showResolutionControl?: boolean; // Show resolution width/height controls + showSeedControl?: boolean; // Show seed control + showDenoisingSteps?: boolean; // Show denoising steps slider + showCacheManagement?: boolean; // Show cache management toggle and reset + showQuantization?: boolean; // Show quantization selector + showKvCacheAttentionBias?: boolean; // Show KV cache attention bias slider (krea-realtime-video) + showNoiseControls?: boolean; // Show noise scale and controller (video mode) + canChangeReferenceWhileStreaming?: boolean; // Allow changing reference image during stream +} + export interface PipelineInfo { name: string; about: string; @@ -18,20 +32,26 @@ export interface PipelineInfo { defaultTemporalInterpolationMethod?: "linear" | "slerp"; // Default method for temporal interpolation defaultTemporalInterpolationSteps?: number; // Default number of steps for temporal interpolation supportsLoRA?: boolean; // Whether this pipeline supports LoRA adapters - supportsVACE?: boolean; // Whether this pipeline supports VACE (Video All-In-One Creation and Editing) + supportsVACE?: boolean; // Whether this pipeline supports VACE + requiresReferenceImage?: boolean; // Whether this pipeline requires a reference image (e.g. PersonaLive) + referenceImageDescription?: string; // Description of what the reference image is for // Multi-mode support supportedModes: InputMode[]; defaultMode: InputMode; + + // UI capabilities - controls which settings/controls are shown + ui?: PipelineUICapabilities; } +// Static fallback pipeline info - used when backend is not available export const PIPELINES: Record = { streamdiffusionv2: { name: "StreamDiffusionV2", docsUrl: "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/streamdiffusionv2/docs/usage.md", about: - "A streaming pipeline and autoregressive video diffusion model from the creators of the original StreamDiffusion project. The model is trained using Self-Forcing on Wan2.1 1.3b with modifications to support streaming. Includes VACE (All-In-One Video Creation and Editing) for reference image conditioning and structural guidance (depth, flow, pose).", + "A streaming pipeline and autoregressive video diffusion model from the creators of the original StreamDiffusion project. 
The model is trained using Self-Forcing on Wan2.1 1.3b with modifications to support streaming.", modified: true, estimatedVram: 20, requiresModels: true, @@ -39,16 +59,25 @@ export const PIPELINES: Record = { defaultTemporalInterpolationSteps: 0, supportsLoRA: true, supportsVACE: true, - // Multi-mode support supportedModes: ["text", "video"], defaultMode: "video", + ui: { + showTimeline: true, + showPromptInput: true, + showResolutionControl: true, + showSeedControl: true, + showDenoisingSteps: true, + showCacheManagement: true, + showQuantization: true, + showNoiseControls: true, + }, }, longlive: { name: "LongLive", docsUrl: "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/longlive/docs/usage.md", about: - "A streaming pipeline and autoregressive video diffusion model from Nvidia, MIT, HKUST, HKU and THU. The model is trained using Self-Forcing on Wan2.1 1.3b with modifications to support smoother prompt switching and improved quality over longer time periods while maintaining fast generation. Includes VACE (All-In-One Video Creation and Editing) for reference image conditioning and structural guidance (depth, flow, pose).", + "A streaming pipeline and autoregressive video diffusion model from Nvidia, MIT, HKUST, HKU and THU. The model is trained using Self-Forcing on Wan2.1 1.3b with modifications to support smoother prompt switching and improved quality over longer time periods while maintaining fast generation.", modified: true, estimatedVram: 20, requiresModels: true, @@ -56,9 +85,18 @@ export const PIPELINES: Record = { defaultTemporalInterpolationSteps: 0, supportsLoRA: true, supportsVACE: true, - // Multi-mode support supportedModes: ["text", "video"], defaultMode: "text", + ui: { + showTimeline: true, + showPromptInput: true, + showResolutionControl: true, + showSeedControl: true, + showDenoisingSteps: true, + showCacheManagement: true, + showQuantization: true, + showNoiseControls: true, + }, }, "krea-realtime-video": { name: "Krea Realtime Video", @@ -72,16 +110,26 @@ export const PIPELINES: Record = { defaultTemporalInterpolationMethod: "linear", defaultTemporalInterpolationSteps: 4, supportsLoRA: true, - // Multi-mode support supportedModes: ["text", "video"], defaultMode: "text", + ui: { + showTimeline: true, + showPromptInput: true, + showResolutionControl: true, + showSeedControl: true, + showDenoisingSteps: true, + showCacheManagement: true, + showQuantization: true, + showKvCacheAttentionBias: true, + showNoiseControls: true, + }, }, "reward-forcing": { name: "RewardForcing", docsUrl: "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/reward_forcing/docs/usage.md", about: - "A streaming pipeline and autoregressive video diffusion model from ZJU, Ant Group, SIAS-ZJU, HUST and SJTU. The model is trained with Rewarded Distribution Matching Distillation using Wan2.1 1.3b as the base model. Includes VACE (All-In-One Video Creation and Editing) for reference image conditioning and structural guidance (depth, flow, pose).", + "A streaming pipeline and autoregressive video diffusion model from ZJU, Ant Group, SIAS-ZJU, HUST and SJTU. 
The model is trained with Rewarded Distribution Matching Distillation using Wan2.1 1.3b as the base model.", modified: true, estimatedVram: 20, requiresModels: true, @@ -89,45 +137,117 @@ export const PIPELINES: Record = { defaultTemporalInterpolationSteps: 0, supportsLoRA: true, supportsVACE: true, - // Multi-mode support supportedModes: ["text", "video"], defaultMode: "text", + ui: { + showTimeline: true, + showPromptInput: true, + showResolutionControl: true, + showSeedControl: true, + showDenoisingSteps: true, + showCacheManagement: true, + showQuantization: true, + showNoiseControls: true, + }, }, passthrough: { name: "Passthrough", about: "A pipeline that returns the input video without any processing that is useful for testing and debugging.", requiresModels: false, - // Video-only pipeline supportedModes: ["video"], defaultMode: "video", + ui: { + showTimeline: false, + showPromptInput: false, + showResolutionControl: false, + showSeedControl: false, + showDenoisingSteps: false, + showCacheManagement: false, + showQuantization: false, + showNoiseControls: false, + }, + }, + personalive: { + name: "PersonaLive", + docsUrl: "https://github.com/GVCLab/PersonaLive", + about: + "Real-time portrait animation pipeline from GVCLab. Animates a reference portrait image using driving video frames to transfer expressions and head movements.", + estimatedVram: 12, + requiresModels: true, + supportedModes: ["video"], + defaultMode: "video", + requiresReferenceImage: true, + referenceImageDescription: + "Portrait image to animate. Expressions and head movements from the driving video will be transferred to this image.", + ui: { + showTimeline: false, + showPromptInput: false, + showResolutionControl: true, + showSeedControl: true, + showDenoisingSteps: false, + showCacheManagement: false, + showQuantization: false, + showNoiseControls: false, + canChangeReferenceWhileStreaming: false, + }, }, }; -export function pipelineSupportsLoRA(pipelineId: string): boolean { - return PIPELINES[pipelineId]?.supportsLoRA === true; +export function getDefaultPromptForMode(mode: InputMode): string { + return DEFAULT_PROMPTS[mode]; } -export function pipelineSupportsVACE(pipelineId: string): boolean { - return PIPELINES[pipelineId]?.supportsVACE === true; +export function pipelineRequiresReferenceImage(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.requiresReferenceImage === true; } -export function pipelineSupportsMode( - pipelineId: string, - mode: InputMode -): boolean { - return PIPELINES[pipelineId]?.supportedModes?.includes(mode) ?? false; +export function getPipelineReferenceImageDescription( + pipelineId: string +): string | undefined { + return PIPELINES[pipelineId]?.referenceImageDescription; } -export function pipelineIsMultiMode(pipelineId: string): boolean { - const modes = PIPELINES[pipelineId]?.supportedModes ?? []; - return modes.length > 1; +// UI capability helper functions + +export function pipelineShowsTimeline(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showTimeline !== false; } -export function getPipelineDefaultMode(pipelineId: string): InputMode { - return PIPELINES[pipelineId]?.defaultMode ?? 
"text"; +export function pipelineShowsPromptInput(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showPromptInput !== false; } -export function getDefaultPromptForMode(mode: InputMode): string { - return DEFAULT_PROMPTS[mode]; +export function pipelineShowsResolutionControl(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showResolutionControl === true; +} + +export function pipelineShowsSeedControl(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showSeedControl === true; +} + +export function pipelineShowsDenoisingSteps(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showDenoisingSteps === true; +} + +export function pipelineShowsCacheManagement(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showCacheManagement === true; +} + +export function pipelineShowsQuantization(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showQuantization === true; +} + +export function pipelineShowsKvCacheAttentionBias(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showKvCacheAttentionBias === true; +} + +export function pipelineShowsNoiseControls(pipelineId: string): boolean { + return PIPELINES[pipelineId]?.ui?.showNoiseControls === true; +} + +export function pipelineCanChangeReferenceWhileStreaming( + pipelineId: string +): boolean { + return PIPELINES[pipelineId]?.ui?.canChangeReferenceWhileStreaming !== false; } diff --git a/frontend/src/hooks/usePipelines.ts b/frontend/src/hooks/usePipelines.ts new file mode 100644 index 00000000..c07d24f6 --- /dev/null +++ b/frontend/src/hooks/usePipelines.ts @@ -0,0 +1,82 @@ +import { useState, useEffect } from "react"; +import { getPipelineSchemas } from "../lib/api"; +import type { InputMode } from "../types"; + +export interface PipelineInfo { + name: string; + about: string; + docsUrl?: string | null; + estimatedVram?: number | null; + requiresModels?: boolean; + supportsPrompts?: boolean; + defaultTemporalInterpolationMethod?: "linear" | "slerp"; + defaultTemporalInterpolationSteps?: number; + supportsLoRA?: boolean; + supportsVACE?: boolean; + supportedModes: InputMode[]; + defaultMode: InputMode; +} + +export function usePipelines() { + const [pipelines, setPipelines] = useState | null>(null); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(null); + + useEffect(() => { + let mounted = true; + + async function fetchPipelines() { + try { + setIsLoading(true); + const schemas = await getPipelineSchemas(); + + if (!mounted) return; + + // Transform to camelCase for TypeScript conventions + const transformed: Record = {}; + for (const [id, schema] of Object.entries(schemas.pipelines)) { + transformed[id] = { + name: schema.name, + about: schema.description, + supportedModes: schema.supported_modes as InputMode[], + defaultMode: schema.default_mode as InputMode, + supportsPrompts: schema.supports_prompts, + defaultTemporalInterpolationMethod: + schema.default_temporal_interpolation_method, + defaultTemporalInterpolationSteps: + schema.default_temporal_interpolation_steps, + docsUrl: schema.docs_url, + estimatedVram: schema.estimated_vram_gb, + requiresModels: schema.requires_models, + supportsLoRA: schema.supports_lora, + supportsVACE: schema.supports_vace, + }; + } + + setPipelines(transformed); + setError(null); + } catch (err) { + if (!mounted) return; + const errorMessage = + err instanceof Error ? 
err.message : "Failed to fetch pipelines"; + setError(errorMessage); + console.error("Failed to fetch pipelines:", err); + } finally { + if (mounted) { + setIsLoading(false); + } + } + } + + fetchPipelines(); + + return () => { + mounted = false; + }; + }, []); + + return { pipelines, isLoading, error }; +} diff --git a/frontend/src/hooks/useStreamState.ts b/frontend/src/hooks/useStreamState.ts index 95c7e60c..c8c0fe31 100644 --- a/frontend/src/hooks/useStreamState.ts +++ b/frontend/src/hooks/useStreamState.ts @@ -13,11 +13,9 @@ import { type HardwareInfoResponse, type PipelineSchemasResponse, } from "../lib/api"; -import { getPipelineDefaultMode } from "../data/pipelines"; // Generic fallback defaults used before schemas are loaded. -// Resolution and denoising steps use conservative values; mode-specific -// values are derived from pipelines.ts when possible. +// Resolution and denoising steps use conservative values. const BASE_FALLBACK = { height: 512, width: 512, @@ -26,9 +24,9 @@ const BASE_FALLBACK = { }; // Get fallback defaults for a pipeline before schemas are loaded -// Derives mode from pipelines.ts to stay in sync with frontend definitions -function getFallbackDefaults(pipelineId: PipelineId, mode?: InputMode) { - const effectiveMode = mode ?? getPipelineDefaultMode(pipelineId); +function getFallbackDefaults(mode?: InputMode) { + // Default to text mode if no mode specified (will be corrected when schemas load) + const effectiveMode = mode ?? "text"; const isVideoMode = effectiveMode === "video"; // Video mode gets noise controls, text mode doesn't @@ -107,8 +105,7 @@ export function useStreamState() { }; } // Fallback to derived defaults if schemas not loaded - // Mode is derived from pipelines.ts to stay in sync - return getFallbackDefaults(pipelineId, mode); + return getFallbackDefaults(mode); }, [pipelineSchemas] ); @@ -137,7 +134,7 @@ export function useStreamState() { ); // Get initial defaults (use fallback since schemas haven't loaded yet) - const initialDefaults = getFallbackDefaults("streamdiffusionv2"); + const initialDefaults = getFallbackDefaults(); const [settings, setSettings] = useState({ pipelineId: "streamdiffusionv2", @@ -177,7 +174,27 @@ export function useStreamState() { ]); if (schemasResult.status === "fulfilled") { - setPipelineSchemas(schemasResult.value); + const schemas = schemasResult.value; + setPipelineSchemas(schemas); + + // Check if the default pipeline (streamdiffusionv2) is available + // If not, switch to the first available pipeline + const availablePipelines = Object.keys(schemas.pipelines); + const preferredPipeline = "streamdiffusionv2"; + + if ( + !availablePipelines.includes(preferredPipeline) && + availablePipelines.length > 0 + ) { + const firstPipelineId = availablePipelines[0] as PipelineId; + const firstPipelineSchema = schemas.pipelines[firstPipelineId]; + + setSettings(prev => ({ + ...prev, + pipelineId: firstPipelineId, + inputMode: firstPipelineSchema.default_mode, + })); + } } else { console.error( "useStreamState: Failed to fetch pipeline schemas:", @@ -201,6 +218,21 @@ export function useStreamState() { fetchInitialData(); }, []); + // Update inputMode when schemas load or pipeline changes + // This sets the correct default mode for the pipeline + useEffect(() => { + if (pipelineSchemas) { + const schema = pipelineSchemas.pipelines[settings.pipelineId]; + if (schema?.default_mode) { + setSettings(prev => ({ + ...prev, + inputMode: schema.default_mode, + })); + } + } + // Only run when schemas load or pipeline changes, NOT 
when inputMode changes + }, [pipelineSchemas, settings.pipelineId]); + // Set recommended quantization when krea-realtime-video is selected // Reset to null when switching to other pipelines useEffect(() => { diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index bdd19f7f..9bcf0807 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -75,6 +75,12 @@ export interface KreaRealtimeVideoLoadParams extends PipelineLoadParams { lora_merge_mode?: "permanent_merge" | "runtime_peft"; } +export interface PersonaLiveLoadParams extends PipelineLoadParams { + height?: number; + width?: number; + seed?: number; +} + export interface PipelineLoadRequest { pipeline_id?: string; load_params?: @@ -82,6 +88,7 @@ export interface PipelineLoadRequest { | StreamDiffusionV2LoadParams | LongLiveLoadParams | KreaRealtimeVideoLoadParams + | PersonaLiveLoadParams | null; } @@ -310,6 +317,37 @@ export const listLoRAFiles = async (): Promise => { return result; }; +// PersonaLive reference image upload +export interface PersonaLiveReferenceResponse { + success: boolean; + message: string; +} + +export const uploadPersonaLiveReference = async ( + imageFile: File | Blob +): Promise => { + // Read file as array buffer + const arrayBuffer = await imageFile.arrayBuffer(); + + const response = await fetch("/api/v1/personalive/reference", { + method: "POST", + headers: { + "Content-Type": imageFile.type || "image/jpeg", + }, + body: arrayBuffer, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `PersonaLive reference upload failed: ${response.status} ${response.statusText}: ${errorText}` + ); + } + + const result = await response.json(); + return result; +}; + export interface AssetFileInfo { name: string; path: string; @@ -423,10 +461,20 @@ export interface PipelineSchemaInfo { name: string; description: string; version: string; + docs_url: string | null; + estimated_vram_gb: number | null; + requires_models: boolean; + supports_lora: boolean; + supports_vace: boolean; + // Pipeline config schema config_schema: PipelineConfigSchema; // Mode support - comes from config class supported_modes: ("text" | "video")[]; default_mode: "text" | "video"; + // Prompt and temporal interpolation support + supports_prompts: boolean; + default_temporal_interpolation_method: "linear" | "slerp"; + default_temporal_interpolation_steps: number; // Mode-specific default overrides (optional) mode_defaults?: Record<"text" | "video", ModeDefaults>; } diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx index a583ec5c..24207f4f 100644 --- a/frontend/src/pages/StreamPage.tsx +++ b/frontend/src/pages/StreamPage.tsx @@ -12,11 +12,11 @@ import { useVideoSource } from "../hooks/useVideoSource"; import { useWebRTCStats } from "../hooks/useWebRTCStats"; import { usePipeline } from "../hooks/usePipeline"; import { useStreamState } from "../hooks/useStreamState"; +import { usePipelines } from "../hooks/usePipelines"; import { - PIPELINES, - getPipelineDefaultMode, getDefaultPromptForMode, - pipelineSupportsVACE, + pipelineRequiresReferenceImage, + pipelineShowsTimeline, } from "../data/pipelines"; import type { InputMode, @@ -26,8 +26,13 @@ import type { DownloadProgress, } from "../types"; import type { PromptItem, PromptTransition } from "../lib/api"; -import { checkModelStatus, downloadPipelineModels } from "../lib/api"; +import { + checkModelStatus, + downloadPipelineModels, + uploadPersonaLiveReference, +} from "../lib/api"; import { 
sendLoRAScaleUpdates } from "../utils/loraHelpers"; +import { toast } from "sonner"; // Delay before resetting video reinitialization flag (ms) // This allows useVideoSource to detect the flag change and trigger reinitialization @@ -66,14 +71,17 @@ function getVaceParams( } export function StreamPage() { + // Fetch available pipelines dynamically + const { pipelines } = usePipelines(); + + // Helper to get default mode for a pipeline + const getPipelineDefaultMode = (pipelineId: string): InputMode => { + return pipelines?.[pipelineId]?.defaultMode ?? "text"; + }; + // Use the stream state hook for settings management - const { - settings, - updateSettings, - getDefaults, - supportsNoiseControls, - spoutAvailable, - } = useStreamState(); + const { settings, updateSettings, getDefaults, spoutAvailable } = + useStreamState(); // Prompt state - use unified default prompts based on mode const initialMode = @@ -121,6 +129,13 @@ export function StreamPage() { null ); + // PersonaLive reference image state + const [referenceImage, setReferenceImage] = useState(null); + const [referenceImageUrl, setReferenceImageUrl] = useState( + null + ); + const [isUploadingReference, setIsUploadingReference] = useState(false); + // Ref to access timeline functions const timelineRef = useRef<{ getCurrentTimelinePrompt: () => string; @@ -248,12 +263,19 @@ export function StreamPage() { }; const handlePipelineIdChange = (pipelineId: PipelineId) => { + console.log( + `[handlePipelineIdChange] Switching to pipeline: ${pipelineId}` + ); + // Stop the stream if it's currently running if (isStreaming) { stopStream(); } - const newPipeline = PIPELINES[pipelineId]; + // Clear reference image when switching pipelines + clearReferenceImage(); + + const newPipeline = pipelines?.[pipelineId]; const modeToUse = newPipeline?.defaultMode || "text"; const currentMode = settings.inputMode || "text"; @@ -339,7 +361,7 @@ export function StreamPage() { // Preserve the current input mode that the user selected before download // Only fall back to pipeline's default mode if no mode is currently set - const newPipeline = PIPELINES[pipelineId]; + const newPipeline = pipelines?.[pipelineId]; const currentMode = settings.inputMode || newPipeline?.defaultMode || "text"; const defaults = getDefaults(pipelineId, currentMode); @@ -532,6 +554,25 @@ export function StreamPage() { }); }; + // Handle reference image upload for PersonaLive + const handleReferenceImageUpload = (file: File) => { + // Clean up old URL + if (referenceImageUrl) { + URL.revokeObjectURL(referenceImageUrl); + } + setReferenceImage(file); + setReferenceImageUrl(URL.createObjectURL(file)); + }; + + // Clear reference image (used when switching pipelines) + const clearReferenceImage = () => { + if (referenceImageUrl) { + URL.revokeObjectURL(referenceImageUrl); + } + setReferenceImage(null); + setReferenceImageUrl(null); + }; + // Sync spoutReceiver.enabled with mode changes const handleModeChange = (newMode: typeof mode) => { // When switching to spout mode, enable spout input @@ -610,9 +651,9 @@ export function StreamPage() { return () => document.removeEventListener("keydown", handleKeyDown); }, [selectedTimelinePrompt]); - // Update temporal interpolation defaults when pipeline changes + // Update temporal interpolation defaults and clear prompts when pipeline changes useEffect(() => { - const pipeline = PIPELINES[settings.pipelineId]; + const pipeline = pipelines?.[settings.pipelineId]; if (pipeline) { const defaultMethod = pipeline.defaultTemporalInterpolationMethod || 
"slerp"; @@ -620,8 +661,13 @@ export function StreamPage() { setTemporalInterpolationMethod(defaultMethod); setTransitionSteps(defaultSteps); + + // Clear prompts if pipeline doesn't support them + if (pipeline.supportsPrompts === false) { + setPromptItems([{ text: "", weight: 1.0 }]); + } } - }, [settings.pipelineId]); + }, [settings.pipelineId, pipelines]); const handlePlayPauseToggle = () => { const newPausedState = !settings.paused; @@ -659,14 +705,22 @@ export function StreamPage() { // Use override pipeline ID if provided, otherwise use current settings const pipelineIdToUse = overridePipelineId || settings.pipelineId; + // Debug: Log which pipeline is being started + console.log( + `[handleStartStream] Starting pipeline: ${pipelineIdToUse} (settings.pipelineId: ${settings.pipelineId}, override: ${overridePipelineId})` + ); + try { // Check if models are needed but not downloaded - const pipelineInfo = PIPELINES[pipelineIdToUse]; + const pipelineInfo = pipelines?.[pipelineIdToUse]; if (pipelineInfo?.requiresModels) { try { const status = await checkModelStatus(pipelineIdToUse); if (!status.downloaded) { // Show download dialog + console.log( + `[handleStartStream] Models not downloaded for: ${pipelineIdToUse}, showing download dialog` + ); setPipelineNeedsModels(pipelineIdToUse); setShowDownloadDialog(true); return false; // Stream did not start @@ -677,9 +731,6 @@ export function StreamPage() { } } - // Always load pipeline with current parameters - backend will handle the rest - console.log(`Loading ${pipelineIdToUse} pipeline...`); - // Determine current input mode const currentMode = settings.inputMode || getPipelineDefaultMode(pipelineIdToUse) || "text"; @@ -693,7 +744,7 @@ export function StreamPage() { // Compute VACE enabled state once - enabled by default for text mode on VACE-supporting pipelines const vaceEnabled = settings.vaceEnabled ?? - (pipelineSupportsVACE(pipelineIdToUse) && currentMode !== "video"); + (pipelines?.[pipelineIdToUse]?.supportsVACE && currentMode !== "video"); if (pipelineIdToUse === "streamdiffusionv2" && resolution) { loadParams = { @@ -756,6 +807,24 @@ export function StreamPage() { console.log( `Loading with resolution: ${resolution.width}x${resolution.height}, seed: ${loadParams.seed}, quantization: ${loadParams.quantization}, lora_merge_mode: ${loadParams.lora_merge_mode}` ); + } else if (pipelineIdToUse === "personalive" && resolution) { + // PersonaLive requires a reference image + if (!referenceImage) { + toast.error("Reference Image Required", { + description: + "Please upload a reference portrait image before starting PersonaLive.", + duration: 5000, + }); + return false; + } + loadParams = { + height: resolution.height, + width: resolution.width, + seed: settings.seed ?? 
42, + }; + console.log( + `Loading PersonaLive with resolution: ${resolution.width}x${resolution.height}, seed: ${loadParams.seed}` + ); } const loadSuccess = await loadPipeline( @@ -767,6 +836,22 @@ export function StreamPage() { return false; } + // For PersonaLive, upload reference image after pipeline is loaded + if (pipelineIdToUse === "personalive" && referenceImage) { + try { + setIsUploadingReference(true); + console.log("Uploading PersonaLive reference image..."); + const result = await uploadPersonaLiveReference(referenceImage); + console.log("Reference image uploaded:", result.message); + } catch (error) { + console.error("Failed to upload reference image:", error); + setIsUploadingReference(false); + return false; + } finally { + setIsUploadingReference(false); + } + } + // Check video requirements based on input mode const needsVideoInput = currentMode === "video"; const isSpoutMode = mode === "spout" && settings.spoutReceiver?.enabled; @@ -801,7 +886,12 @@ export function StreamPage() { }; // Common parameters for pipelines that support prompts - if (pipelineIdToUse !== "passthrough") { + // PersonaLive and Passthrough don't use text prompts + if ( + pipelineIdToUse !== "passthrough" && + pipelineIdToUse !== "personalive" && + pipelineInfo?.supportsPrompts !== false + ) { initialParameters.prompts = promptItems; initialParameters.prompt_interpolation_method = interpolationMethod; initialParameters.denoising_step_list = settings.denoisingSteps || [ @@ -872,6 +962,7 @@ export function StreamPage() {
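The gating above means prompt-related parameters are attached only when the pipeline actually consumes them. A condensed sketch of that rule (a hypothetical helper, not part of the codebase; field names follow the initialParameters object built in handleStartStream):

import type { PromptItem } from "../lib/api";

// PersonaLive and Passthrough are prompt-less, and any pipeline whose schema
// reports supports_prompts === false is treated the same way: no prompt
// fields are sent at all.
function buildPromptParameters(
  pipelineId: string,
  supportsPrompts: boolean | undefined,
  prompts: PromptItem[],
  denoisingSteps?: number[]
): Record<string, unknown> {
  const promptless =
    pipelineId === "passthrough" || pipelineId === "personalive";
  if (promptless || supportsPrompts === false) return {};
  return {
    prompts,
    prompt_interpolation_method: "linear", // placeholder; the real code passes interpolationMethod
    denoising_step_list: denoisingSteps ?? [700, 500],
  };
}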
@@ -941,15 +1041,27 @@ export function StreamPage() { isPlaying={!settings.paused} isDownloading={isDownloading} onPlayPauseToggle={() => { - // Use timeline's play/pause handler instead of direct video toggle - if (timelinePlayPauseRef.current) { + // For pipelines with timeline, use timeline's play/pause handler + // For pipelines without timeline (PersonaLive, Passthrough), toggle directly + if ( + pipelineShowsTimeline(settings.pipelineId) && + timelinePlayPauseRef.current + ) { timelinePlayPauseRef.current(); + } else { + handlePlayPauseToggle(); } }} onStartStream={() => { - // Use timeline's play/pause handler to start stream - if (timelinePlayPauseRef.current) { + // For pipelines with timeline, use timeline's play/pause handler + // For pipelines without timeline (PersonaLive, Passthrough), start directly + if ( + pipelineShowsTimeline(settings.pipelineId) && + timelinePlayPauseRef.current + ) { timelinePlayPauseRef.current(); + } else { + handleStartStream(); } }} onVideoPlaying={() => { @@ -961,116 +1073,120 @@ export function StreamPage() { }} />
- {/* Timeline area - compact, always visible */} -
- { - // Update the left panel's prompt state to reflect current timeline prompt - const prompts = [{ text, weight: 100 }]; - setPromptItems(prompts); - - // Send to backend - use transition if streaming and transition steps > 0 - if (isStreaming && transitionSteps > 0) { - sendParameterUpdate({ - transition: { - target_prompts: prompts, - num_steps: transitionSteps, - temporal_interpolation_method: - temporalInterpolationMethod, - }, - }); - } else { - // Send direct prompts without transition - sendParameterUpdate({ - prompts, - prompt_interpolation_method: interpolationMethod, - denoising_step_list: settings.denoisingSteps || [700, 500], - }); - } - }} - onPromptItemsSubmit={( - prompts, - blockTransitionSteps, - blockTemporalInterpolationMethod - ) => { - // Update the left panel's prompt state to reflect current timeline prompt blend - setPromptItems(prompts); - - // Use transition params from block if provided, otherwise use global settings - const effectiveTransitionSteps = - blockTransitionSteps ?? transitionSteps; - const effectiveTemporalInterpolationMethod = - blockTemporalInterpolationMethod ?? - temporalInterpolationMethod; - - // Update the left panel's transition settings to reflect current block's values - if (blockTransitionSteps !== undefined) { - setTransitionSteps(blockTransitionSteps); - } - if (blockTemporalInterpolationMethod !== undefined) { - setTemporalInterpolationMethod( - blockTemporalInterpolationMethod - ); - } - - // Send to backend - use transition if streaming and transition steps > 0 - if (isStreaming && effectiveTransitionSteps > 0) { - sendParameterUpdate({ - transition: { - target_prompts: prompts, - num_steps: effectiveTransitionSteps, - temporal_interpolation_method: - effectiveTemporalInterpolationMethod, - }, - }); - } else { - // Send direct prompts without transition - sendParameterUpdate({ - prompts, - prompt_interpolation_method: interpolationMethod, - denoising_step_list: settings.denoisingSteps || [700, 500], - }); + {/* Timeline area - only show for pipelines that support text prompts */} + {pipelineShowsTimeline(settings.pipelineId) && ( +
+ { + // Update the left panel's prompt state to reflect current timeline prompt + const prompts = [{ text, weight: 100 }]; + setPromptItems(prompts); + + // Send to backend - use transition if streaming and transition steps > 0 + if (isStreaming && transitionSteps > 0) { + sendParameterUpdate({ + transition: { + target_prompts: prompts, + num_steps: transitionSteps, + temporal_interpolation_method: + temporalInterpolationMethod, + }, + }); + } else { + // Send direct prompts without transition + sendParameterUpdate({ + prompts, + prompt_interpolation_method: interpolationMethod, + denoising_step_list: settings.denoisingSteps || [ + 700, 500, + ], + }); + } + }} + onPromptItemsSubmit={( + prompts, + blockTransitionSteps, + blockTemporalInterpolationMethod + ) => { + // Update the left panel's prompt state to reflect current timeline prompt blend + setPromptItems(prompts); + + // Use transition params from block if provided, otherwise use global settings + const effectiveTransitionSteps = + blockTransitionSteps ?? transitionSteps; + const effectiveTemporalInterpolationMethod = + blockTemporalInterpolationMethod ?? + temporalInterpolationMethod; + + // Update the left panel's transition settings to reflect current block's values + if (blockTransitionSteps !== undefined) { + setTransitionSteps(blockTransitionSteps); + } + if (blockTemporalInterpolationMethod !== undefined) { + setTemporalInterpolationMethod( + blockTemporalInterpolationMethod + ); + } + + // Send to backend - use transition if streaming and transition steps > 0 + if (isStreaming && effectiveTransitionSteps > 0) { + sendParameterUpdate({ + transition: { + target_prompts: prompts, + num_steps: effectiveTransitionSteps, + temporal_interpolation_method: + effectiveTemporalInterpolationMethod, + }, + }); + } else { + // Send direct prompts without transition + sendParameterUpdate({ + prompts, + prompt_interpolation_method: interpolationMethod, + denoising_step_list: settings.denoisingSteps || [ + 700, 500, + ], + }); + } + }} + disabled={ + isPipelineLoading || isConnecting || showDownloadDialog } - }} - disabled={ - settings.pipelineId === "passthrough" || - isPipelineLoading || - isConnecting || - showDownloadDialog - } - isStreaming={isStreaming} - isVideoPaused={settings.paused} - timelineRef={timelineRef} - onLiveStateChange={setIsLive} - onLivePromptSubmit={handleLivePromptSubmit} - onDisconnect={stopStream} - onStartStream={handleStartStream} - onVideoPlayPauseToggle={handlePlayPauseToggle} - onPromptEdit={handleTimelinePromptEdit} - isCollapsed={isTimelineCollapsed} - onCollapseToggle={setIsTimelineCollapsed} - externalSelectedPromptId={externalSelectedPromptId} - settings={settings} - onSettingsImport={updateSettings} - onPlayPauseRef={timelinePlayPauseRef} - onVideoPlayingCallbackRef={onVideoPlayingCallbackRef} - onResetCache={handleResetCache} - onTimelinePromptsChange={handleTimelinePromptsChange} - onTimelineCurrentTimeChange={handleTimelineCurrentTimeChange} - onTimelinePlayingChange={handleTimelinePlayingChange} - isLoading={isLoading} - /> -
+ isStreaming={isStreaming} + isVideoPaused={settings.paused} + timelineRef={timelineRef} + onLiveStateChange={setIsLive} + onLivePromptSubmit={handleLivePromptSubmit} + onDisconnect={stopStream} + onStartStream={handleStartStream} + onVideoPlayPauseToggle={handlePlayPauseToggle} + onPromptEdit={handleTimelinePromptEdit} + isCollapsed={isTimelineCollapsed} + onCollapseToggle={setIsTimelineCollapsed} + externalSelectedPromptId={externalSelectedPromptId} + settings={settings} + onSettingsImport={updateSettings} + onPlayPauseRef={timelinePlayPauseRef} + onVideoPlayingCallbackRef={onVideoPlayingCallbackRef} + onResetCache={handleResetCache} + onTimelinePromptsChange={handleTimelinePromptsChange} + onTimelineCurrentTimeChange={handleTimelineCurrentTimeChange} + onTimelinePlayingChange={handleTimelinePlayingChange} + isLoading={isLoading} + /> +
+ )}
{/* Right Panel - Settings */}
=0.6.2", "huggingface_hub>=0.25.0", + "pluggy>=1.5.0", + "click>=8.3.1", "peft>=0.17.1", "torchao==0.13.0", "kernels>=0.10.4", @@ -66,6 +68,9 @@ Homepage = "https://github.com/daydreamlive/scope" Repository = "https://github.com/daydreamlive/scope" Issues = "https://github.com/daydreamlive/scope/issues" +[tool.uv] +preview = true + [tool.uv.extra-build-dependencies] flash-attn = [{ requirement = "torch", match-runtime = true, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }] diff --git a/src/scope/core/__init__.py b/src/scope/core/__init__.py index e69de29b..38baa921 100644 --- a/src/scope/core/__init__.py +++ b/src/scope/core/__init__.py @@ -0,0 +1,5 @@ +"""Core functionality for Scope.""" + +from scope.core.plugins import hookimpl + +__all__ = ["hookimpl"] diff --git a/src/scope/core/pipelines/krea_realtime_video/pipeline.py b/src/scope/core/pipelines/krea_realtime_video/pipeline.py index 0db8065c..5ab26cca 100644 --- a/src/scope/core/pipelines/krea_realtime_video/pipeline.py +++ b/src/scope/core/pipelines/krea_realtime_video/pipeline.py @@ -21,7 +21,6 @@ from ..wan2_1.lora.mixin import LoRAEnabledPipeline from ..wan2_1.vae import WanVAEWrapper from .modular_blocks import KreaRealtimeVideoBlocks -from .modules.causal_model import CausalWanModel if TYPE_CHECKING: from ..schema import BasePipelineConfig @@ -49,6 +48,8 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16, ): + from .modules.causal_model import CausalWanModel + model_dir = getattr(config, "model_dir", None) generator_path = getattr(config, "generator_path", None) text_encoder_path = getattr(config, "text_encoder_path", None) diff --git a/src/scope/core/pipelines/longlive/pipeline.py b/src/scope/core/pipelines/longlive/pipeline.py index 67d76b9a..b2dfe99f 100644 --- a/src/scope/core/pipelines/longlive/pipeline.py +++ b/src/scope/core/pipelines/longlive/pipeline.py @@ -23,7 +23,6 @@ from ..wan2_1.vace import VACEEnabledPipeline from ..wan2_1.vae import WanVAEWrapper from .modular_blocks import LongLiveBlocks -from .modules.causal_model import CausalWanModel if TYPE_CHECKING: from ..schema import BasePipelineConfig @@ -45,6 +44,8 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16, ): + from .modules.causal_model import CausalWanModel + model_dir = getattr(config, "model_dir", None) generator_path = getattr(config, "generator_path", None) lora_path = getattr(config, "lora_path", None) diff --git a/src/scope/core/pipelines/registry.py b/src/scope/core/pipelines/registry.py index 493804be..8c43b949 100644 --- a/src/scope/core/pipelines/registry.py +++ b/src/scope/core/pipelines/registry.py @@ -5,12 +5,18 @@ metadata retrieval. """ +import importlib +import logging from typing import TYPE_CHECKING +import torch + if TYPE_CHECKING: from .interface import Pipeline from .schema import BasePipelineConfig +logger = logging.getLogger(__name__) + class PipelineRegistry: """Registry for managing available pipelines.""" @@ -64,27 +70,122 @@ def list_pipelines(cls) -> list[str]: return list(cls._pipelines.keys()) +def _get_gpu_vram_gb() -> float | None: + """Get total GPU VRAM in GB if available. 
+ + Returns: + Total VRAM in GB if GPU is available, None otherwise + """ + try: + if torch.cuda.is_available(): + _, total_mem = torch.cuda.mem_get_info(0) + return total_mem / (1024**3) + except Exception as e: + logger.warning(f"Failed to get GPU VRAM info: {e}") + return None + + +def _should_register_pipeline( + estimated_vram_gb: float | None, vram_gb: float | None +) -> bool: + """Determine if a pipeline should be registered based on GPU requirements. + + Args: + estimated_vram_gb: Estimated/required VRAM in GB from pipeline config, + or None if no requirement + vram_gb: Total GPU VRAM in GB, or None if no GPU + + Returns: + True if the pipeline should be registered, False otherwise + """ + return estimated_vram_gb is None or vram_gb is not None + + # Register all available pipelines def _register_pipelines(): - """Register all built-in pipelines.""" - # Import lazily to avoid circular imports and heavy dependencies - from .krea_realtime_video.pipeline import KreaRealtimeVideoPipeline - from .longlive.pipeline import LongLivePipeline - from .passthrough.pipeline import PassthroughPipeline - from .reward_forcing.pipeline import RewardForcingPipeline - from .streamdiffusionv2.pipeline import StreamDiffusionV2Pipeline - - # Register each pipeline with its ID from its config class - for pipeline_class in [ - LongLivePipeline, - KreaRealtimeVideoPipeline, - StreamDiffusionV2Pipeline, - PassthroughPipeline, - RewardForcingPipeline, - ]: - config_class = pipeline_class.get_config_class() - PipelineRegistry.register(config_class.pipeline_id, pipeline_class) + """Register pipelines based on GPU availability and requirements.""" + # Check GPU VRAM + vram_gb = _get_gpu_vram_gb() + + if vram_gb is not None: + logger.info(f"GPU detected with {vram_gb:.1f} GB VRAM") + else: + logger.info("No GPU detected") + + # Define pipeline imports with their module paths and class names + pipeline_configs = [ + ("passthrough", ".passthrough.pipeline", "PassthroughPipeline"), + ("longlive", ".longlive.pipeline", "LongLivePipeline"), + ( + "krea_realtime_video", + ".krea_realtime_video.pipeline", + "KreaRealtimeVideoPipeline", + ), + ( + "streamdiffusionv2", + ".streamdiffusionv2.pipeline", + "StreamDiffusionV2Pipeline", + ), + ( + "reward_forcing", + ".reward_forcing.pipeline", + "RewardForcingPipeline", + ), + ] + + # Try to import and register each pipeline + for pipeline_name, module_path, class_name in pipeline_configs: + # Try to import the pipeline first to get its config + try: + module = importlib.import_module(module_path, package=__package__) + pipeline_class = getattr(module, class_name) + + # Get the config class to check VRAM requirements + config_class = pipeline_class.get_config_class() + estimated_vram_gb = config_class.estimated_vram_gb + + # Check if pipeline meets GPU requirements + should_register = _should_register_pipeline(estimated_vram_gb, vram_gb) + if not should_register: + logger.debug( + f"Skipping {pipeline_name} pipeline - " + f"does not meet GPU requirements " + f"(required: {estimated_vram_gb} GB, " + f"available: {vram_gb} GB)" + ) + continue + + # Register the pipeline + PipelineRegistry.register(config_class.pipeline_id, pipeline_class) + logger.debug( + f"Registered {pipeline_name} pipeline (ID: {config_class.pipeline_id})" + ) + except ImportError as e: + logger.warning( + f"Could not import {pipeline_name} pipeline: {e}. " + f"This pipeline will not be available." + ) + except Exception as e: + logger.warning( + f"Error loading {pipeline_name} pipeline: {e}. 
" + f"This pipeline will not be available." + ) + + +def _initialize_registry(): + """Initialize registry with built-in pipelines and plugins.""" + # Register built-in pipelines first + _register_pipelines() + + # Load and register plugin pipelines + from scope.core.plugins import load_plugins, register_plugin_pipelines + + load_plugins() + register_plugin_pipelines(PipelineRegistry) + + pipeline_count = len(PipelineRegistry.list_pipelines()) + logger.info(f"Registry initialized with {pipeline_count} pipeline(s)") # Auto-register pipelines on module import -_register_pipelines() +_initialize_registry() diff --git a/src/scope/core/pipelines/reward_forcing/pipeline.py b/src/scope/core/pipelines/reward_forcing/pipeline.py index 094f40bc..ddb28192 100644 --- a/src/scope/core/pipelines/reward_forcing/pipeline.py +++ b/src/scope/core/pipelines/reward_forcing/pipeline.py @@ -22,7 +22,6 @@ from ..wan2_1.vace.mixin import VACEEnabledPipeline from ..wan2_1.vae import WanVAEWrapper from .modular_blocks import RewardForcingBlocks -from .modules.causal_model import CausalWanModel if TYPE_CHECKING: from ..schema import BasePipelineConfig @@ -44,6 +43,8 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16, ): + from .modules.causal_model import CausalWanModel + model_dir = getattr(config, "model_dir", None) generator_path = getattr(config, "generator_path", None) text_encoder_path = getattr(config, "text_encoder_path", None) diff --git a/src/scope/core/pipelines/schema.py b/src/scope/core/pipelines/schema.py index 73b9e68a..9cfec81a 100644 --- a/src/scope/core/pipelines/schema.py +++ b/src/scope/core/pipelines/schema.py @@ -59,11 +59,23 @@ class BasePipelineConfig(BaseModel): pipeline_name: ClassVar[str] = "Base Pipeline" pipeline_description: ClassVar[str] = "Base pipeline configuration" pipeline_version: ClassVar[str] = "1.0.0" + docs_url: ClassVar[str | None] = None + estimated_vram_gb: ClassVar[float | None] = None + requires_models: ClassVar[bool] = False + supports_lora: ClassVar[bool] = False + supports_vace: ClassVar[bool] = False # Mode support - override in subclasses supported_modes: ClassVar[list[InputMode]] = ["text"] default_mode: ClassVar[InputMode] = "text" + # Prompt and temporal interpolation support + supports_prompts: ClassVar[bool] = True + default_temporal_interpolation_method: ClassVar[Literal["linear", "slerp"]] = ( + "slerp" + ) + default_temporal_interpolation_steps: ClassVar[int] = 0 + # Resolution settings height: int = Field(default=512, ge=1, description="Output height in pixels") width: int = Field(default=512, ge=1, description="Output width in pixels") @@ -154,16 +166,23 @@ def get_schema_with_metadata(cls) -> dict[str, Any]: This is the primary method for API/UI schema generation. 
Returns: - Dict containing: - - Pipeline metadata (id, name, description, version) - - supported_modes: List of supported input modes - - default_mode: Default input mode - - mode_defaults: Dict of mode-specific default overrides - - config_schema: Full JSON schema for the config model + Dict containing pipeline metadata """ metadata = cls.get_pipeline_metadata() metadata["supported_modes"] = cls.supported_modes metadata["default_mode"] = cls.default_mode + metadata["supports_prompts"] = cls.supports_prompts + metadata["default_temporal_interpolation_method"] = ( + cls.default_temporal_interpolation_method + ) + metadata["default_temporal_interpolation_steps"] = ( + cls.default_temporal_interpolation_steps + ) + metadata["docs_url"] = cls.docs_url + metadata["estimated_vram_gb"] = cls.estimated_vram_gb + metadata["requires_models"] = cls.requires_models + metadata["supports_lora"] = cls.supports_lora + metadata["supports_vace"] = cls.supports_vace metadata["config_schema"] = cls.model_json_schema() # Include mode-specific defaults if defined @@ -198,8 +217,17 @@ class LongLiveConfig(BasePipelineConfig): pipeline_id: ClassVar[str] = "longlive" pipeline_name: ClassVar[str] = "LongLive" pipeline_description: ClassVar[str] = ( - "Long-form video generation with temporal consistency" + "A streaming pipeline and autoregressive video diffusion model from Nvidia, MIT, HKUST, HKU and THU. " + "The model is trained using Self-Forcing on Wan2.1 1.3b with modifications to support smoother prompt " + "switching and improved quality over longer time periods while maintaining fast generation." + ) + docs_url: ClassVar[str | None] = ( + "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/longlive/docs/usage.md" ) + estimated_vram_gb: ClassVar[float | None] = 20.0 + requires_models: ClassVar[bool] = True + supports_lora: ClassVar[bool] = True + supports_vace: ClassVar[bool] = True # Mode support supported_modes: ClassVar[list[InputMode]] = ["text", "video"] @@ -258,10 +286,19 @@ class StreamDiffusionV2Config(BasePipelineConfig): """ pipeline_id: ClassVar[str] = "streamdiffusionv2" - pipeline_name: ClassVar[str] = "StreamDiffusion V2" + pipeline_name: ClassVar[str] = "StreamDiffusionV2" pipeline_description: ClassVar[str] = ( - "Real-time video-to-video generation with temporal consistency" + "A streaming pipeline and autoregressive video diffusion model from the creators of the original " + "StreamDiffusion project. The model is trained using Self-Forcing on Wan2.1 1.3b with modifications " + "to support streaming." ) + docs_url: ClassVar[str | None] = ( + "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/streamdiffusionv2/docs/usage.md" + ) + estimated_vram_gb: ClassVar[float | None] = 20.0 + requires_models: ClassVar[bool] = True + supports_lora: ClassVar[bool] = True + supports_vace: ClassVar[bool] = True # Mode support supported_modes: ClassVar[list[InputMode]] = ["text", "video"] @@ -329,8 +366,20 @@ class KreaRealtimeVideoConfig(BasePipelineConfig): pipeline_id: ClassVar[str] = "krea-realtime-video" pipeline_name: ClassVar[str] = "Krea Realtime Video" pipeline_description: ClassVar[str] = ( - "High-quality real-time video generation with 14B model" + "A streaming pipeline and autoregressive video diffusion model from Krea. " + "The model is trained using Self-Forcing on Wan2.1 14b." 
+    )
+    docs_url: ClassVar[str | None] = (
+        "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/krea_realtime_video/docs/usage.md"
     )
+    estimated_vram_gb: ClassVar[float | None] = 32.0
+    requires_models: ClassVar[bool] = True
+    supports_lora: ClassVar[bool] = True
+
+    default_temporal_interpolation_method: ClassVar[Literal["linear", "slerp"]] = (
+        "linear"
+    )
+    default_temporal_interpolation_steps: ClassVar[int] = 4
 
     # Mode support
     supported_modes: ClassVar[list[InputMode]] = ["text", "video"]
@@ -379,8 +428,16 @@ class RewardForcingConfig(BasePipelineConfig):
     pipeline_id: ClassVar[str] = "reward-forcing"
     pipeline_name: ClassVar[str] = "RewardForcing"
     pipeline_description: ClassVar[str] = (
-        "Efficient streaming video generation with rewarded distribution matching distillation"
+        "A streaming pipeline and autoregressive video diffusion model from ZJU, Ant Group, SIAS-ZJU, HUST and SJTU. "
+        "The model is trained with Rewarded Distribution Matching Distillation using Wan2.1 1.3b as the base model."
+    )
+    docs_url: ClassVar[str | None] = (
+        "https://github.com/daydreamlive/scope/blob/main/src/scope/core/pipelines/reward_forcing/docs/usage.md"
     )
+    estimated_vram_gb: ClassVar[float | None] = 20.0
+    requires_models: ClassVar[bool] = True
+    supports_lora: ClassVar[bool] = True
+    supports_vace: ClassVar[bool] = True
 
     # Mode support
     supported_modes: ClassVar[list[InputMode]] = ["text", "video"]
@@ -439,12 +496,17 @@ class PassthroughConfig(BasePipelineConfig):
 
     pipeline_id: ClassVar[str] = "passthrough"
     pipeline_name: ClassVar[str] = "Passthrough"
-    pipeline_description: ClassVar[str] = "Passthrough pipeline for testing"
+    pipeline_description: ClassVar[str] = (
+        "A pipeline that returns the input video without any processing; useful for testing and debugging."
+ ) # Mode support - video only supported_modes: ClassVar[list[InputMode]] = ["video"] default_mode: ClassVar[InputMode] = "video" + # Does not support prompts + supports_prompts: ClassVar[bool] = False + # Passthrough defaults - requires video input (distinct from StreamDiffusionV2) height: int = Field(default=512, ge=1, description="Output height in pixels") width: int = Field(default=512, ge=1, description="Output width in pixels") diff --git a/src/scope/core/pipelines/streamdiffusionv2/pipeline.py b/src/scope/core/pipelines/streamdiffusionv2/pipeline.py index 55085eea..0055f96f 100644 --- a/src/scope/core/pipelines/streamdiffusionv2/pipeline.py +++ b/src/scope/core/pipelines/streamdiffusionv2/pipeline.py @@ -22,7 +22,6 @@ from ..wan2_1.vace import VACEEnabledPipeline from .components import StreamDiffusionV2WanVAEWrapper from .modular_blocks import StreamDiffusionV2Blocks -from .modules.causal_model import CausalWanModel if TYPE_CHECKING: from ..schema import BasePipelineConfig @@ -44,6 +43,8 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16, ): + from .modules.causal_model import CausalWanModel + model_dir = getattr(config, "model_dir", None) generator_path = getattr(config, "generator_path", None) text_encoder_path = getattr(config, "text_encoder_path", None) diff --git a/src/scope/core/plugins/__init__.py b/src/scope/core/plugins/__init__.py new file mode 100644 index 00000000..25ae657c --- /dev/null +++ b/src/scope/core/plugins/__init__.py @@ -0,0 +1,19 @@ +"""Plugin system for Scope.""" + +from .hookspecs import hookimpl +from .manager import ( + load_plugins, + pm, + register_plugin_artifacts, + register_plugin_pipelines, + register_plugin_routes, +) + +__all__ = [ + "hookimpl", + "load_plugins", + "pm", + "register_plugin_artifacts", + "register_plugin_pipelines", + "register_plugin_routes", +] diff --git a/src/scope/core/plugins/hookspecs.py b/src/scope/core/plugins/hookspecs.py new file mode 100644 index 00000000..4cb3f2e6 --- /dev/null +++ b/src/scope/core/plugins/hookspecs.py @@ -0,0 +1,69 @@ +"""Hook specifications for the Scope plugin system.""" + +import pluggy + +hookspec = pluggy.HookspecMarker("scope") +hookimpl = pluggy.HookimplMarker("scope") + + +class ScopeHookSpec: + """Hook specifications for Scope plugins.""" + + @hookspec + def register_pipelines(self, register): + """Register custom pipeline implementations. + + Args: + register: Callback to register pipeline classes. + Usage: register(PipelineClass) + + Example: + @scope.core.hookimpl + def register_pipelines(register): + register(MyPipeline) + """ + + @hookspec + def register_artifacts(self, register): + """Register model artifacts for download. + + Allows plugins to declare which model files need to be downloaded + for their pipelines. These artifacts will be downloaded when the + user requests model download for the pipeline. + + Args: + register: Callback to register artifacts. + Usage: register(pipeline_id, [Artifact, ...]) + + Example: + from scope.server.artifacts import HuggingfaceRepoArtifact + + @scope.core.hookimpl + def register_artifacts(register): + register("my-pipeline", [ + HuggingfaceRepoArtifact( + repo_id="user/model-repo", + files=["model.safetensors"], + ), + ]) + """ + + @hookspec + def register_routes(self, app): + """Register custom API routes with the FastAPI application. + + Allows plugins to add custom HTTP endpoints for pipeline-specific + functionality (e.g., uploading reference images, custom configuration). 
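+
+        Routes are mounted on the same FastAPI app as the core API, so plugins
+        should namespace their paths (e.g. under /api/v1/<plugin-name>/) to
+        avoid colliding with built-in endpoints.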
+ + Args: + app: FastAPI application instance + + Example: + from fastapi import HTTPException + + @scope.core.hookimpl + def register_routes(app): + @app.post("/api/v1/my-pipeline/custom-endpoint") + async def my_custom_endpoint(): + return {"status": "ok"} + """ diff --git a/src/scope/core/plugins/manager.py b/src/scope/core/plugins/manager.py new file mode 100644 index 00000000..c5069aae --- /dev/null +++ b/src/scope/core/plugins/manager.py @@ -0,0 +1,66 @@ +"""Plugin manager for discovering and loading Scope plugins.""" + +import logging + +import pluggy + +from .hookspecs import ScopeHookSpec + +logger = logging.getLogger(__name__) + +# Create the plugin manager singleton +pm = pluggy.PluginManager("scope") +pm.add_hookspecs(ScopeHookSpec) + + +def load_plugins(): + """Discover and load all plugins via entry points.""" + pm.load_setuptools_entrypoints("scope") + logger.info(f"Loaded {len(pm.get_plugins())} plugin(s)") + + +def register_plugin_pipelines(registry): + """Call register_pipelines hook for all plugins. + + Args: + registry: PipelineRegistry to register pipelines with + """ + + def register_callback(pipeline_class): + """Callback function passed to plugins.""" + config_class = pipeline_class.get_config_class() + pipeline_id = config_class.pipeline_id + registry.register(pipeline_id, pipeline_class) + logger.info(f"Registered plugin pipeline: {pipeline_id}") + + pm.hook.register_pipelines(register=register_callback) + + +def register_plugin_artifacts(artifacts_dict: dict): + """Call register_artifacts hook for all plugins. + + Args: + artifacts_dict: Dictionary to populate with pipeline_id -> [Artifact] mappings + """ + + def register_callback(pipeline_id: str, artifacts: list): + """Callback function passed to plugins.""" + if pipeline_id in artifacts_dict: + logger.warning( + f"Overwriting existing artifacts for pipeline: {pipeline_id}" + ) + artifacts_dict[pipeline_id] = artifacts + logger.info( + f"Registered {len(artifacts)} artifact(s) for plugin pipeline: {pipeline_id}" + ) + + pm.hook.register_artifacts(register=register_callback) + + +def register_plugin_routes(app): + """Call register_routes hook for all plugins. 
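+
+    Called from the server's lifespan startup, before any requests are served,
+    so plugin routes are available as soon as the app starts handling traffic.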
+ + Args: + app: FastAPI application instance + """ + pm.hook.register_routes(app=app) diff --git a/src/scope/server/app.py b/src/scope/server/app.py index 16f0e453..a30cbc01 100644 --- a/src/scope/server/app.py +++ b/src/scope/server/app.py @@ -1,18 +1,22 @@ -import argparse import asyncio +import contextlib +import io import logging import os import subprocess import sys import threading import time +import warnings import webbrowser from contextlib import asynccontextmanager from datetime import datetime +from functools import wraps from importlib.metadata import version from logging.handlers import RotatingFileHandler from pathlib import Path +import click import torch import uvicorn from fastapi import Depends, FastAPI, HTTPException, Query, Request @@ -117,6 +121,28 @@ def filter(self, record): logger = logging.getLogger(__name__) +def suppress_init_output(func): + """Decorator to suppress all initialization output (logging, warnings, stdout/stderr).""" + + @wraps(func) + def wrapper(*args, **kwargs): + with ( + contextlib.redirect_stdout(io.StringIO()), + contextlib.redirect_stderr(io.StringIO()), + warnings.catch_warnings(), + ): + warnings.simplefilter("ignore") + # Temporarily disable all logging + logging.disable(logging.CRITICAL) + try: + return func(*args, **kwargs) + finally: + # Re-enable logging + logging.disable(logging.NOTSET) + + return wrapper + + def get_git_commit_hash() -> str: """ Get the current git commit hash. @@ -218,6 +244,12 @@ async def lifespan(app: FastAPI): pipeline_manager = PipelineManager() logger.info("Pipeline manager initialized") + # Register plugin routes (must be done before server starts accepting requests) + from scope.core.plugins import register_plugin_routes + + register_plugin_routes(app) + logger.info("Plugin routes registered") + # Pre-warm the default pipeline if PIPELINE is not None: asyncio.create_task(prewarm_pipeline(PIPELINE)) @@ -346,15 +378,17 @@ async def get_pipeline_schemas(): The frontend should use this as the source of truth for parameter defaults. 
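+
+    Because the response is built from the PipelineRegistry rather than a
+    static config dict, plugin-registered pipelines are included automatically.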
""" - from scope.core.pipelines.schema import PIPELINE_CONFIGS + from scope.core.pipelines.registry import PipelineRegistry pipelines: dict = {} - for pipeline_id, config_class in PIPELINE_CONFIGS.items(): - # get_schema_with_metadata() now includes supported_modes, default_mode, - # and mode_defaults directly from the config class - schema_data = config_class.get_schema_with_metadata() - pipelines[pipeline_id] = schema_data + for pipeline_id in PipelineRegistry.list_pipelines(): + config_class = PipelineRegistry.get_config_class(pipeline_id) + if config_class: + # get_schema_with_metadata() includes supported_modes, default_mode, + # and mode_defaults directly from the config class + schema_data = config_class.get_schema_with_metadata() + pipelines[pipeline_id] = schema_data return PipelineSchemasResponse(pipelines=pipelines) @@ -845,40 +879,12 @@ def open_browser_when_ready(host: str, port: int, server): logger.info(f"🌐 UI is available at: {url}") -def main(): - """Main entry point for the daydream-scope command.""" - parser = argparse.ArgumentParser( - description="A tool for running and customizing real-time, interactive generative AI pipelines and models" - ) - parser.add_argument( - "--version", - action="store_true", - help="Show version information and exit", - ) - parser.add_argument( - "--reload", - action="store_true", - help="Enable auto-reload for development (default: False)", - ) - parser.add_argument( - "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=8000, help="Port to bind to (default: 8000)" - ) - parser.add_argument( - "-N", - "--no-browser", - action="store_true", - help="Do not automatically open a browser window after the server starts", - ) - - args = parser.parse_args() +def run_server(reload: bool, host: str, port: int, no_browser: bool): + """Run the Daydream Scope server.""" - # Handle version flag - if args.version: - print_version_info() - sys.exit(0) + from scope.core.pipelines.registry import ( + PipelineRegistry, # noqa: F401 - imported for side effects (registry initialization) + ) # Configure static file serving configure_static_files() @@ -891,18 +897,18 @@ def main(): # Create server instance for production mode config = uvicorn.Config( "scope.server.app:app", - host=args.host, - port=args.port, - reload=args.reload, + host=host, + port=port, + reload=reload, log_config=None, # Use our logging config, don't override it ) server = uvicorn.Server(config) # Start browser opening thread (unless disabled) - if not args.no_browser: + if not no_browser: browser_thread = threading.Thread( target=open_browser_when_ready, - args=(args.host, args.port, server), + args=(host, port, server), daemon=True, ) browser_thread.start() @@ -918,12 +924,133 @@ def main(): # Development mode - just run normally uvicorn.run( "scope.server.app:app", - host=args.host, - port=args.port, - reload=args.reload, + host=host, + port=port, + reload=reload, log_config=None, # Use our logging config, don't override it ) +@click.group(invoke_without_command=True) +@click.option("--version", is_flag=True, help="Show version information and exit") +@click.option( + "--reload", is_flag=True, help="Enable auto-reload for development (default: False)" +) +@click.option("--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)") +@click.option("--port", default=8000, help="Port to bind to (default: 8000)") +@click.option( + "-N", + "--no-browser", + is_flag=True, + help="Do not automatically open a 
browser window after the server starts",
+)
+@click.pass_context
+def main(ctx, version: bool, reload: bool, host: str, port: int, no_browser: bool):
+    """A tool for running and customizing real-time, interactive generative AI pipelines and models."""
+    # Handle version flag
+    if version:
+        print_version_info()
+        sys.exit(0)
+
+    # If no subcommand was invoked, run the server
+    if ctx.invoked_subcommand is None:
+        run_server(reload, host, port, no_browser)
+
+
+@main.command()
+def plugins():
+    """List all installed plugins."""
+
+    @suppress_init_output
+    def _load_plugins():
+        from scope.core.plugins import load_plugins, pm
+
+        load_plugins()
+        return pm.get_plugins()
+
+    plugin_list = _load_plugins()
+
+    if not plugin_list:
+        click.echo("No plugins installed.")
+        return
+
+    click.echo(f"{len(plugin_list)} plugin(s) installed:\n")
+
+    # List each plugin
+    for plugin in plugin_list:
+        plugin_name = plugin.__name__ if hasattr(plugin, "__name__") else str(plugin)
+        click.echo(f"  • {plugin_name}")
+
+
+@main.command()
+def pipelines():
+    """List all available pipelines."""
+
+    @suppress_init_output
+    def _load_pipelines():
+        from scope.core.pipelines.registry import PipelineRegistry
+
+        return PipelineRegistry.list_pipelines()
+
+    all_pipelines = _load_pipelines()
+
+    if not all_pipelines:
+        click.echo("No pipelines available.")
+        return
+
+    click.echo(f"{len(all_pipelines)} pipeline(s) available:\n")
+
+    # List all pipelines
+    for pipeline_id in all_pipelines:
+        click.echo(f"  • {pipeline_id}")
+
+
+@main.command()
+@click.argument("packages", nargs=-1, required=False)
+@click.option("--upgrade", is_flag=True, help="Upgrade packages to the latest version")
+@click.option(
+    "-e", "--editable", help="Install a project in editable mode from this path"
+)
+@click.option("--force-reinstall", is_flag=True, help="Force reinstall packages")
+@click.option("--no-cache-dir", is_flag=True, help="Disable the cache")
+@click.option(
+    "--pre", is_flag=True, help="Include pre-release and development versions"
+)
+def install(packages, upgrade, editable, force_reinstall, no_cache_dir, pre):
+    """Install a plugin."""
+    args = ["uv", "pip", "install"]
+    if upgrade:
+        args.append("--upgrade")
+    if editable:
+        args += ["--editable", editable]
+    if force_reinstall:
+        args.append("--force-reinstall")
+    if no_cache_dir:
+        args.append("--no-cache-dir")
+    if pre:
+        args.append("--pre")
+    args += list(packages)
+
+    result = subprocess.run(args, capture_output=False)
+
+    if result.returncode != 0:
+        sys.exit(result.returncode)
+
+
+@main.command()
+@click.argument("packages", nargs=-1, required=True)
+@click.option("-y", "--yes", is_flag=True, help="Don't ask for confirmation")
+def uninstall(packages, yes):
+    """Uninstall a plugin."""
+    args = ["uv", "pip", "uninstall"]
+    args += list(packages)
+    if yes:
+        args.append("-y")
+
+    result = subprocess.run(args, capture_output=False)
+
+    if result.returncode != 0:
+        sys.exit(result.returncode)
+
+
 if __name__ == "__main__":
     main()
diff --git a/src/scope/server/artifacts.py b/src/scope/server/artifacts.py
index 8ad1390e..28a67804 100644
--- a/src/scope/server/artifacts.py
+++ b/src/scope/server/artifacts.py
@@ -19,7 +19,10 @@ class HuggingfaceRepoArtifact(Artifact):
         repo_id: HuggingFace repository ID (e.g., "Wan-AI/Wan2.1-T2V-1.3B")
         files: List of files or directories to download
             Directories should be specified by their name (e.g., "google", "models")
+        local_dir: Optional custom local directory path relative to models_root.
+            If not specified, defaults to the last part of repo_id.
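+
+    Example (illustrative; the local_dir value is arbitrary):
+        HuggingfaceRepoArtifact(
+            repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+            files=["google", "models"],
+            local_dir="wan2.1-t2v-1.3b",
+        )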
""" repo_id: str files: list[str] + local_dir: str | None = None diff --git a/src/scope/server/download_models.py b/src/scope/server/download_models.py index 0ea4915b..f1dc9c2b 100644 --- a/src/scope/server/download_models.py +++ b/src/scope/server/download_models.py @@ -338,7 +338,11 @@ def download_hf_artifact( models_root: Root directory where models are stored pipeline_id: Pipeline ID to download models for """ - local_dir = models_root / artifact.repo_id.split("/")[-1] + # Use custom local_dir if specified, otherwise default to repo name + if artifact.local_dir: + local_dir = models_root / artifact.local_dir + else: + local_dir = models_root / artifact.repo_id.split("/")[-1] # Convert file/directory specifications to glob patterns allow_patterns = [] @@ -367,13 +371,26 @@ def download_models(pipeline_id: str) -> None: Args: pipeline_id: Pipeline ID to download models for. + + Raises: + KeyError: If pipeline_id has no registered artifacts """ - from .pipeline_artifacts import PIPELINE_ARTIFACTS + from .pipeline_artifacts import get_pipeline_artifacts models_root = ensure_models_dir() logger.info(f"Downloading models for pipeline: {pipeline_id}") - artifacts = PIPELINE_ARTIFACTS[pipeline_id] + + # Get all artifacts including plugin-registered ones + all_artifacts = get_pipeline_artifacts() + + if pipeline_id not in all_artifacts: + raise KeyError( + f"No artifacts registered for pipeline '{pipeline_id}'. " + f"Available pipelines: {list(all_artifacts.keys())}" + ) + + artifacts = all_artifacts[pipeline_id] # Download each artifact (progress tracking starts in set_download_context) for artifact in artifacts: @@ -399,6 +416,8 @@ def main(): python download_models.py --pipeline krea-realtime-video python download_models.py --pipeline reward-forcing python download_models.py -p streamdiffusionv2 + + # Plugin pipelines can also register artifacts for download """, ) parser.add_argument( @@ -407,7 +426,7 @@ def main(): type=str, default=None, required=True, - help="Pipeline ID (e.g., 'streamdiffusionv2', 'longlive', 'krea-realtime-video', 'reward-forcing').", + help="Pipeline ID to download models for. Use 'daydream-scope pipelines' to list available pipelines.", ) args = parser.parse_args() diff --git a/src/scope/server/models_config.py b/src/scope/server/models_config.py index ace85ffc..c33d6d3d 100644 --- a/src/scope/server/models_config.py +++ b/src/scope/server/models_config.py @@ -110,11 +110,15 @@ def get_required_model_files(pipeline_id: str | None = None) -> list[Path]: required_files = [] for artifact in artifacts: - local_dir_name = artifact.repo_id.split("/")[-1] + # Use custom local_dir if specified, otherwise default to repo name + if artifact.local_dir: + local_dir = artifact.local_dir + else: + local_dir = artifact.repo_id.split("/")[-1] # Add each file from the artifact's files list for file in artifact.files: - required_files.append(models_dir / local_dir_name / file) + required_files.append(models_dir / local_dir / file) return required_files diff --git a/src/scope/server/pipeline_artifacts.py b/src/scope/server/pipeline_artifacts.py index 9a0f750d..a40ba247 100644 --- a/src/scope/server/pipeline_artifacts.py +++ b/src/scope/server/pipeline_artifacts.py @@ -1,8 +1,11 @@ """ Defines which artifacts each pipeline requires. + +Built-in pipelines define their artifacts here. Plugin pipelines can register +their artifacts via the `register_artifacts` hook. 
""" -from .artifacts import HuggingfaceRepoArtifact +from .artifacts import Artifact, HuggingfaceRepoArtifact # Common artifacts shared across pipelines WAN_1_3B_ARTIFACT = HuggingfaceRepoArtifact( @@ -20,8 +23,8 @@ files=["Wan2_1-VACE_module_1_3B_bf16.safetensors"], ) -# Pipeline-specific artifacts -PIPELINE_ARTIFACTS = { +# Built-in pipeline artifacts +_BUILTIN_PIPELINE_ARTIFACTS: dict[str, list[Artifact]] = { "streamdiffusionv2": [ WAN_1_3B_ARTIFACT, UMT5_ENCODER_ARTIFACT, @@ -62,3 +65,27 @@ ), ], } + + +def get_pipeline_artifacts() -> dict[str, list[Artifact]]: + """Get all pipeline artifacts including those registered by plugins. + + Returns: + Dictionary mapping pipeline_id to list of artifacts + """ + from scope.core.plugins import load_plugins, register_plugin_artifacts + + # Ensure plugins are loaded (idempotent if already loaded) + load_plugins() + + # Start with built-in artifacts + all_artifacts = dict(_BUILTIN_PIPELINE_ARTIFACTS) + + # Register plugin artifacts + register_plugin_artifacts(all_artifacts) + + return all_artifacts + + +# Legacy alias for backward compatibility +PIPELINE_ARTIFACTS = _BUILTIN_PIPELINE_ARTIFACTS diff --git a/src/scope/server/pipeline_manager.py b/src/scope/server/pipeline_manager.py index 383eb995..9f5329b3 100644 --- a/src/scope/server/pipeline_manager.py +++ b/src/scope/server/pipeline_manager.py @@ -318,6 +318,65 @@ def _load_pipeline_implementation( self, pipeline_id: str, load_params: dict | None = None ): """Synchronous pipeline loading (runs in thread executor).""" + from scope.core.pipelines.registry import PipelineRegistry + + # Check if pipeline is in registry + pipeline_class = PipelineRegistry.get(pipeline_id) + + # List of built-in pipelines with custom initialization + BUILTIN_PIPELINES = { + "streamdiffusionv2", + "passthrough", + "longlive", + "krea-realtime-video", + "reward-forcing", + } + + if pipeline_class is not None and pipeline_id not in BUILTIN_PIPELINES: + # Plugin pipeline - instantiate with config object like built-in pipelines + logger.info(f"Loading plugin pipeline: {pipeline_id}") + from .models_config import get_models_dir + + models_dir = get_models_dir() + config = OmegaConf.create( + { + "model_dir": str(models_dir), + } + ) + + # Apply load parameters (resolution, seed, LoRAs) to config + # Get defaults from the pipeline's config class if available + config_class = None + if hasattr(pipeline_class, "get_config_class"): + config_class = pipeline_class.get_config_class() + + default_height = 512 + default_width = 512 + if config_class is not None: + default_height = getattr( + config_class.model_fields.get("height"), "default", 512 + ) or 512 + default_width = getattr( + config_class.model_fields.get("width"), "default", 512 + ) or 512 + + self._apply_load_params( + config, + load_params, + default_height=default_height, + default_width=default_width, + default_seed=42, + ) + + pipeline = pipeline_class( + config, + device=torch.device("cuda"), + dtype=torch.float16, + ) + logger.info(f"Plugin pipeline {pipeline_id} initialized") + return pipeline + + # Fall through to built-in pipeline initialization if pipeline_id == "streamdiffusionv2": from scope.core.pipelines import ( StreamDiffusionV2Pipeline, diff --git a/src/scope/server/schema.py b/src/scope/server/schema.py index ca0bc1f9..5640d76d 100644 --- a/src/scope/server/schema.py +++ b/src/scope/server/schema.py @@ -379,7 +379,12 @@ class KreaRealtimeVideoLoadParams(LoRAEnabledLoadParams): class PipelineLoadRequest(BaseModel): - """Pipeline load request schema.""" 
+ """Pipeline load request schema. + + Note: Plugin pipelines can define their own load parameters. The load_params + field accepts any dict-like structure that will be passed to the pipeline. + Built-in pipelines have typed load params for validation. + """ pipeline_id: str = Field( default="streamdiffusionv2", description="ID of pipeline to load" @@ -389,6 +394,7 @@ class PipelineLoadRequest(BaseModel): | PassthroughLoadParams | LongLiveLoadParams | KreaRealtimeVideoLoadParams + | PipelineLoadParams # Generic fallback for plugin pipelines | None ) = Field(default=None, description="Pipeline-specific load parameters") diff --git a/uv.lock b/uv.lock index 400751ba..a2952b5b 100644 --- a/uv.lock +++ b/uv.lock @@ -532,14 +532,14 @@ wheels = [ [[package]] name = "click" -version = "8.3.0" +version = "8.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/46/61/de6cd827efad202d7057d93e0fed9294b96952e188f7384832791c7b2254/click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4", size = 276943, upload-time = "2025-09-18T17:32:23.696Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" }, + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" }, ] [[package]] @@ -623,6 +623,7 @@ source = { editable = "." 
} dependencies = [ { name = "accelerate" }, { name = "aiortc" }, + { name = "click" }, { name = "diffusers" }, { name = "easydict" }, { name = "einops" }, @@ -636,6 +637,7 @@ dependencies = [ { name = "lmdb" }, { name = "omegaconf" }, { name = "peft" }, + { name = "pluggy" }, { name = "pyopengl", marker = "sys_platform == 'win32'" }, { name = "safetensors" }, { name = "sageattention", version = "2.2.0", source = { url = "https://github.com/daydreamlive/SageAttention/releases/download/v2.2.0-linux/sageattention-2.2.0-cp310-cp310-linux_x86_64.whl" }, marker = "sys_platform == 'linux'" }, @@ -668,6 +670,7 @@ dev = [ requires-dist = [ { name = "accelerate", specifier = ">=1.1.1" }, { name = "aiortc", specifier = ">=1.13.0" }, + { name = "click", specifier = ">=8.3.1" }, { name = "diffusers", specifier = ">=0.31.0" }, { name = "easydict", specifier = ">=1.13" }, { name = "einops", specifier = ">=0.8.1" }, @@ -681,6 +684,7 @@ requires-dist = [ { name = "lmdb", specifier = ">=1.7.3" }, { name = "omegaconf", specifier = ">=2.3.0" }, { name = "peft", specifier = ">=0.17.1" }, + { name = "pluggy", specifier = ">=1.5.0" }, { name = "pyopengl", marker = "sys_platform == 'win32'", specifier = ">=3.1.10" }, { name = "safetensors", specifier = ">=0.6.2" }, { name = "sageattention", marker = "sys_platform == 'linux'", url = "https://github.com/daydreamlive/SageAttention/releases/download/v2.2.0-linux/sageattention-2.2.0-cp310-cp310-linux_x86_64.whl" },
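
For reference, a minimal end-to-end plugin wired against the hooks introduced above might look like the following sketch. Everything named scope_hello / HelloPipeline is hypothetical; the hook names, the hookimpl marker exported from scope.core, the "scope" entry-point group consumed by load_setuptools_entrypoints("scope"), and HuggingfaceRepoArtifact come from this diff.

# pyproject.toml of the hypothetical plugin package:
#
#     [project.entry-points."scope"]
#     hello = "scope_hello"

# scope_hello/__init__.py
from scope.core import hookimpl


@hookimpl
def register_pipelines(register):
    # Deferred import keeps plugin discovery cheap, mirroring the lazy
    # model imports used by the built-in pipelines in this change.
    # HelloPipeline must implement get_config_class(), returning a
    # BasePipelineConfig subclass that defines pipeline_id.
    from scope_hello.pipeline import HelloPipeline

    register(HelloPipeline)


@hookimpl
def register_artifacts(register):
    from scope.server.artifacts import HuggingfaceRepoArtifact

    # "hello" must match the pipeline_id on HelloPipeline's config class
    register(
        "hello",
        [
            HuggingfaceRepoArtifact(
                repo_id="user/hello-model",  # hypothetical repo
                files=["model.safetensors"],
            )
        ],
    )


@hookimpl
def register_routes(app):
    @app.get("/api/v1/hello/status")
    async def hello_status():
        return {"status": "ok"}

After installing the package with `daydream-scope install <package>`, the pipeline should appear in `daydream-scope pipelines` and in the pipeline schemas response without any changes to core.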