Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions frontend/src/components/SettingsPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import type {
SettingsState,
InputMode,
PipelineInfo,
VaeType,
} from "../types";
import { LoRAManager } from "./LoRAManager";

Expand Down Expand Up @@ -87,6 +88,11 @@ interface SettingsPanelProps {
onVaceEnabledChange?: (enabled: boolean) => void;
vaceContextScale?: number;
onVaceContextScaleChange?: (scale: number) => void;
// VAE type selection
vaeType?: VaeType;
onVaeTypeChange?: (vaeType: VaeType) => void;
// Available VAE types from backend registry
vaeTypes?: string[];
}

export function SettingsPanel({
Expand Down Expand Up @@ -126,6 +132,9 @@ export function SettingsPanel({
onVaceEnabledChange,
vaceContextScale = 1.0,
onVaceContextScaleChange,
vaeType = "wan",
onVaeTypeChange,
vaeTypes = ["wan"],
}: SettingsPanelProps) {
// Local slider state management hooks
const noiseScaleSlider = useLocalSliderValue(noiseScale, onNoiseScaleChange);
Expand Down Expand Up @@ -389,6 +398,37 @@ export function SettingsPanel({
</div>
)}

{/* VAE Type Selection */}
{vaeTypes && vaeTypes.length > 0 && (
<div className="space-y-2">
<div className="flex items-center justify-between gap-2">
<LabelWithTooltip
label={PARAMETER_METADATA.vaeType.label}
tooltip={PARAMETER_METADATA.vaeType.tooltip}
className="text-sm text-foreground"
/>
<Select
value={vaeType}
onValueChange={value => {
onVaeTypeChange?.(value as VaeType);
}}
disabled={isStreaming}
>
<SelectTrigger className="w-[140px] h-7">
<SelectValue />
</SelectTrigger>
<SelectContent>
{vaeTypes.map(type => (
<SelectItem key={type} value={type}>
{type}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
)}

{currentPipeline?.supportsLoRA && (
<div className="space-y-4">
<LoRAManager
Expand Down
5 changes: 5 additions & 0 deletions frontend/src/data/parameterMetadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,9 @@ export const PARAMETER_METADATA: Record<string, ParameterMetadata> = {
tooltip:
"The configuration of the sender that will send video to Spout-compatible apps like TouchDesigner, Resolume, OBS.",
},
vaeType: {
label: "VAE:",
tooltip:
"VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality.",
},
};
14 changes: 14 additions & 0 deletions frontend/src/hooks/usePipelines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ export function usePipelines() {
// Transform to camelCase for TypeScript conventions
const transformed: Record<string, PipelineInfo> = {};
for (const [id, schema] of Object.entries(schemas.pipelines)) {
// Extract VAE types from JSON schema if vae_type field exists
// Pydantic v2 represents enum fields using $ref to definitions
let vaeTypes: string[] | undefined = undefined;
const vaeTypeProperty = schema.config_schema?.properties?.vae_type;
if (vaeTypeProperty?.$ref && schema.config_schema?.$defs) {
const refPath = vaeTypeProperty.$ref;
const defName = refPath.split("/").pop();
const definition = schema.config_schema.$defs[defName || ""];
if (definition && Array.isArray(definition.enum)) {
vaeTypes = definition.enum as string[];
}
}

transformed[id] = {
name: schema.name,
about: schema.description,
Expand All @@ -45,6 +58,7 @@ export function usePipelines() {
recommendedQuantizationVramThreshold:
schema.recommended_quantization_vram_threshold ?? undefined,
modified: schema.modified,
vaeTypes,
};
}

Expand Down
3 changes: 3 additions & 0 deletions frontend/src/lib/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,16 @@ export interface PipelineSchemaProperty {
maximum?: number;
items?: unknown;
anyOf?: unknown[];
enum?: unknown[];
$ref?: string;
}

export interface PipelineConfigSchema {
type: string;
properties: Record<string, PipelineSchemaProperty>;
required?: string[];
title?: string;
$defs?: Record<string, { enum?: unknown[] }>;
}

// Mode-specific default overrides
Expand Down
10 changes: 10 additions & 0 deletions frontend/src/pages/StreamPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import type {
LoRAConfig,
LoraMergeStrategy,
DownloadProgress,
VaeType,
} from "../types";
import type { PromptItem, PromptTransition } from "../lib/api";
import { checkModelStatus, downloadPipelineModels } from "../lib/api";
Expand Down Expand Up @@ -453,6 +454,11 @@ export function StreamPage() {
// Note: This setting requires pipeline reload, so we don't send parameter update here
};

const handleVaeTypeChange = (vaeType: VaeType) => {
updateSettings({ vaeType });
// Note: This setting requires pipeline reload, so we don't send parameter update here
};

const handleKvCacheAttentionBiasChange = (bias: number) => {
updateSettings({ kvCacheAttentionBias: bias });
// Send KV cache attention bias update to backend
Expand Down Expand Up @@ -725,6 +731,7 @@ export function StreamPage() {
if (currentPipeline?.supportsQuantization) {
loadParams.seed = settings.seed ?? 42;
loadParams.quantization = settings.quantization ?? null;
loadParams.vae_type = settings.vaeType ?? "wan";
}

// Add LoRA parameters if pipeline supports LoRA
Expand Down Expand Up @@ -1117,6 +1124,9 @@ export function StreamPage() {
onVaceEnabledChange={handleVaceEnabledChange}
vaceContextScale={settings.vaceContextScale ?? 1.0}
onVaceContextScaleChange={handleVaceContextScaleChange}
vaeType={settings.vaeType ?? "wan"}
onVaeTypeChange={handleVaeTypeChange}
vaeTypes={pipelines?.[settings.pipelineId]?.vaeTypes ?? ["wan"]}
/>
</div>
</div>
Expand Down
7 changes: 7 additions & 0 deletions frontend/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ export type PipelineId = string;
// Input mode for pipeline operation
export type InputMode = "text" | "video";

// VAE type for model selection (dynamic from backend registry)
export type VaeType = string;

// WebRTC ICE server configuration
export interface IceServerConfig {
urls: string | string[];
Expand Down Expand Up @@ -73,6 +76,8 @@ export interface SettingsState {
vaceEnabled?: boolean;
refImages?: string[];
vaceContextScale?: number;
// VAE type selection
vaeType?: VaeType;
}

export interface PipelineInfo {
Expand Down Expand Up @@ -100,6 +105,8 @@ export interface PipelineInfo {
supportsQuantization?: boolean;
minDimension?: number;
recommendedQuantizationVramThreshold?: number | null;
// Available VAE types from config schema enum (derived from vae_type field presence)
vaeTypes?: string[];
}

export interface DownloadProgress {
Expand Down
14 changes: 9 additions & 5 deletions src/scope/core/pipelines/krea_realtime_video/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..utils import Quantization, load_model_config, validate_resolution
from ..wan2_1.components import WanDiffusionWrapper, WanTextEncoderWrapper
from ..wan2_1.lora.mixin import LoRAEnabledPipeline
from ..wan2_1.vae import WanVAEWrapper
from ..wan2_1.vae import create_vae
from .modular_blocks import KreaRealtimeVideoBlocks
from .schema import KreaRealtimeVideoConfig

Expand Down Expand Up @@ -128,12 +128,16 @@ def __init__(
# Move text encoder to target device but use dtype of weights
text_encoder = text_encoder.to(device=device)

# Load vae
# Load VAE using create_vae factory (supports multiple VAE types)
vae_type = getattr(config, "vae_type", "wan")
start = time.time()
vae = WanVAEWrapper(
model_name=base_model_name, model_dir=model_dir, vae_path=vae_path
vae = create_vae(
model_dir=model_dir,
model_name=base_model_name,
vae_type=vae_type,
vae_path=vae_path,
)
print(f"Loaded VAE in {time.time() - start:.3f}s")
print(f"Loaded VAE (type={vae_type}) in {time.time() - start:.3f}s")
# Move VAE to target device and use target dtype
vae = vae.to(device=device, dtype=dtype)

Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/krea_realtime_video/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydantic import Field

from ..base_schema import BasePipelineConfig, ModeDefaults
from ..utils import VaeType


class KreaRealtimeVideoConfig(BasePipelineConfig):
Expand Down Expand Up @@ -26,6 +29,10 @@ class KreaRealtimeVideoConfig(BasePipelineConfig):
height: int = 320
width: int = 576
denoising_steps: list[int] = [1000, 750, 500, 250]
vae_type: VaeType = Field(
default=VaeType.WAN,
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
)

modes = {
"text": ModeDefaults(default=True),
Expand Down
11 changes: 7 additions & 4 deletions src/scope/core/pipelines/longlive/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..wan2_1.lora.mixin import LoRAEnabledPipeline
from ..wan2_1.lora.strategies.module_targeted_lora import ModuleTargetedLoRAStrategy
from ..wan2_1.vace import VACEEnabledPipeline
from ..wan2_1.vae import WanVAEWrapper
from ..wan2_1.vae import create_vae
from .modular_blocks import LongLiveBlocks
from .schema import LongLiveConfig

Expand Down Expand Up @@ -145,10 +145,13 @@ def __init__(
# Move text encoder to target device but use dtype of weights
text_encoder = text_encoder.to(device=device)

# Load VAE using unified WanVAEWrapper
# Load VAE using create_vae factory (supports multiple VAE types)
vae_type = getattr(config, "vae_type", "wan")
start = time.time()
vae = WanVAEWrapper(model_dir=model_dir, model_name=base_model_name)
print(f"Loaded VAE in {time.time() - start:.3f}s")
vae = create_vae(
model_dir=model_dir, model_name=base_model_name, vae_type=vae_type
)
print(f"Loaded VAE (type={vae_type}) in {time.time() - start:.3f}s")
# Move VAE to target device and use target dtype
vae = vae.to(device=device, dtype=dtype)

Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/longlive/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydantic import Field

from ..base_schema import BasePipelineConfig, ModeDefaults
from ..utils import VaeType


class LongLiveConfig(BasePipelineConfig):
Expand All @@ -23,6 +26,10 @@ class LongLiveConfig(BasePipelineConfig):
height: int = 320
width: int = 576
denoising_steps: list[int] = [1000, 750, 500, 250]
vae_type: VaeType = Field(
default=VaeType.WAN,
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
)

modes = {
"text": ModeDefaults(default=True),
Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/streamdiffusionv2/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydantic import Field

from ..base_schema import BasePipelineConfig, ModeDefaults
from ..utils import VaeType


class StreamDiffusionV2Config(BasePipelineConfig):
Expand All @@ -24,6 +27,10 @@ class StreamDiffusionV2Config(BasePipelineConfig):
noise_scale: float = 0.7
noise_controller: bool = True
input_size: int = 4
vae_type: VaeType = Field(
default=VaeType.WAN,
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
)

modes = {
"text": ModeDefaults(
Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ class Quantization(str, Enum):
FP8_E4M3FN = "fp8_e4m3fn"


class VaeType(str, Enum):
"""VAE type enumeration."""

WAN = "wan"
LIGHTVAE = "lightvae"


def load_state_dict(weights_path: str) -> dict:
"""Load weights with automatic format detection."""
if not os.path.exists(weights_path):
Expand Down
Loading
Loading