From 594ad89c942c7a3b2a64b843cb5c611f5f1da58a Mon Sep 17 00:00:00 2001 From: huiwenshi Date: Sat, 14 Jun 2025 06:29:55 +0000 Subject: [PATCH] Delete hunyuan3d-paintpbr-v2-1/attn_processor.py --- hunyuan3d-paintpbr-v2-1/attn_processor.py | 839 ---------------------- 1 file changed, 839 deletions(-) delete mode 100644 hunyuan3d-paintpbr-v2-1/attn_processor.py diff --git a/hunyuan3d-paintpbr-v2-1/attn_processor.py b/hunyuan3d-paintpbr-v2-1/attn_processor.py deleted file mode 100644 index 37ebde0..0000000 --- a/hunyuan3d-paintpbr-v2-1/attn_processor.py +++ /dev/null @@ -1,839 +0,0 @@ -# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT -# except for the third-party components listed below. -# Hunyuan 3D does not impose any additional limitations beyond what is outlined -# in the repsective licenses of these third-party components. -# Users must comply with all terms and conditions of original licenses of these third-party -# components and must ensure that the usage of the third party components adheres to -# all relevant laws and regulations. - -# For avoidance of doubts, Hunyuan 3D means the large language models and -# their software and algorithms, including trained model weights, parameters (including -# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, -# fine-tuning enabling code and other elements of the foregoing made publicly available -# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Optional, Dict, Tuple, Union, Literal, List, Callable -from einops import rearrange -from diffusers.utils import deprecate -from diffusers.models.attention_processor import Attention, AttnProcessor - - -class AttnUtils: - """ - Shared utility functions for attention processing. - - This class provides common operations used across different attention processors - to eliminate code duplication and improve maintainability. - """ - - @staticmethod - def check_pytorch_compatibility(): - """ - Check PyTorch compatibility for scaled_dot_product_attention. - - Raises: - ImportError: If PyTorch version doesn't support scaled_dot_product_attention - """ - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - @staticmethod - def handle_deprecation_warning(args, kwargs): - """ - Handle deprecation warning for the 'scale' argument. - - Args: - args: Positional arguments passed to attention processor - kwargs: Keyword arguments passed to attention processor - """ - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = ( - "The `scale` argument is deprecated and will be ignored." - "Please remove it, as passing it will raise an error in the future." - "`scale` should directly be passed while calling the underlying pipeline component" - "i.e., via `cross_attention_kwargs`." - ) - deprecate("scale", "1.0.0", deprecation_message) - - @staticmethod - def prepare_hidden_states( - hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm" - ): - """ - Common preprocessing of hidden states for attention computation. - - Args: - hidden_states: Input hidden states tensor - attn: Attention module instance - temb: Optional temporal embedding tensor - spatial_norm_attr: Attribute name for spatial normalization - group_norm_attr: Attribute name for group normalization - - Returns: - Tuple of (processed_hidden_states, residual, input_ndim, shape_info) - """ - residual = hidden_states - - spatial_norm = getattr(attn, spatial_norm_attr, None) - if spatial_norm is not None: - hidden_states = spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - else: - batch_size, channel, height, width = None, None, None, None - - group_norm = getattr(attn, group_norm_attr, None) - if group_norm is not None: - hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - return hidden_states, residual, input_ndim, (batch_size, channel, height, width) - - @staticmethod - def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size): - """ - Prepare attention mask for scaled_dot_product_attention. - - Args: - attention_mask: Input attention mask tensor or None - attn: Attention module instance - sequence_length: Length of the sequence - batch_size: Batch size - - Returns: - Prepared attention mask tensor reshaped for multi-head attention - """ - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - return attention_mask - - @staticmethod - def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim): - """ - Reshape Q/K/V tensors for multi-head attention computation. - - Args: - tensor: Input tensor to reshape - batch_size: Batch size - attn_heads: Number of attention heads - head_dim: Dimension per attention head - - Returns: - Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim] - """ - return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - - @staticmethod - def apply_norms(query, key, norm_q, norm_k): - """ - Apply Q/K normalization layers if available. - - Args: - query: Query tensor - key: Key tensor - norm_q: Query normalization layer (optional) - norm_k: Key normalization layer (optional) - - Returns: - Tuple of (normalized_query, normalized_key) - """ - if norm_q is not None: - query = norm_q(query) - if norm_k is not None: - key = norm_k(key) - return query, key - - @staticmethod - def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out): - """ - Common output processing including projection, dropout, reshaping, and residual connection. - - Args: - hidden_states: Processed hidden states from attention - input_ndim: Original input tensor dimensions - shape_info: Tuple containing original shape information - attn: Attention module instance - residual: Residual connection tensor - to_out: Output projection layers [linear, dropout] - - Returns: - Final output tensor after all processing steps - """ - batch_size, channel, height, width = shape_info - - # Apply output projection and dropout - hidden_states = to_out[0](hidden_states) - hidden_states = to_out[1](hidden_states) - - # Reshape back if needed - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # Apply residual connection - if attn.residual_connection: - hidden_states = hidden_states + residual - - # Apply rescaling - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - - -# Base class for attention processors (eliminating initialization duplication) -class BaseAttnProcessor(nn.Module): - """ - Base class for attention processors with common initialization. - - This base class provides shared parameter initialization and module registration - functionality to reduce code duplication across different attention processor types. - """ - - def __init__( - self, - query_dim: int, - pbr_setting: List[str] = ["albedo", "mr"], - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - out_dim: int = None, - out_context_dim: int = None, - context_pre_only=None, - pre_only=False, - elementwise_affine: bool = True, - is_causal: bool = False, - **kwargs, - ): - """ - Initialize base attention processor with common parameters. - - Args: - query_dim: Dimension of query features - pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"]) - cross_attention_dim: Dimension of cross-attention features (optional) - heads: Number of attention heads - kv_heads: Number of key-value heads for grouped query attention (optional) - dim_head: Dimension per attention head - dropout: Dropout rate - bias: Whether to use bias in linear projections - upcast_attention: Whether to upcast attention computation to float32 - upcast_softmax: Whether to upcast softmax computation to float32 - cross_attention_norm: Type of cross-attention normalization (optional) - cross_attention_norm_num_groups: Number of groups for cross-attention norm - qk_norm: Type of query-key normalization (optional) - added_kv_proj_dim: Dimension for additional key-value projections (optional) - added_proj_bias: Whether to use bias in additional projections - norm_num_groups: Number of groups for normalization (optional) - spatial_norm_dim: Dimension for spatial normalization (optional) - out_bias: Whether to use bias in output projection - scale_qk: Whether to scale query-key products - only_cross_attention: Whether to only perform cross-attention - eps: Small epsilon value for numerical stability - rescale_output_factor: Factor to rescale output values - residual_connection: Whether to use residual connections - _from_deprecated_attn_block: Flag for deprecated attention blocks - processor: Optional attention processor instance - out_dim: Output dimension (optional) - out_context_dim: Output context dimension (optional) - context_pre_only: Whether to only process context in pre-processing - pre_only: Whether to only perform pre-processing - elementwise_affine: Whether to use element-wise affine transformations - is_causal: Whether to use causal attention masking - **kwargs: Additional keyword arguments - """ - super().__init__() - AttnUtils.check_pytorch_compatibility() - - # Store common attributes - self.pbr_setting = pbr_setting - self.n_pbr_tokens = len(self.pbr_setting) - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.is_causal = is_causal - self._from_deprecated_attn_block = _from_deprecated_attn_block - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - self.heads = out_dim // dim_head if out_dim is not None else heads - self.sliceable_head_dim = heads - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - self.added_proj_bias = added_proj_bias - - # Validation - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None." - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - def register_pbr_modules(self, module_types: List[str], **kwargs): - """ - Generic PBR module registration to eliminate code repetition. - - Dynamically registers PyTorch modules for different PBR material types - based on the specified module types and PBR settings. - - Args: - module_types: List of module types to register ("qkv", "v_only", "out", "add_kv") - **kwargs: Additional arguments for module configuration - """ - for pbr_token in self.pbr_setting: - if pbr_token == "albedo": - continue - - for module_type in module_types: - if module_type == "qkv": - self.register_module( - f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias) - ) - self.register_module( - f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias) - ) - self.register_module( - f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias) - ) - elif module_type == "v_only": - self.register_module( - f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias) - ) - elif module_type == "out": - if not self.pre_only: - self.register_module( - f"to_out_{pbr_token}", - nn.ModuleList( - [ - nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)), - nn.Dropout(self.dropout), - ] - ), - ) - else: - self.register_module(f"to_out_{pbr_token}", None) - elif module_type == "add_kv": - if self.added_kv_proj_dim is not None: - self.register_module( - f"add_k_proj_{pbr_token}", - nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias), - ) - self.register_module( - f"add_v_proj_{pbr_token}", - nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias), - ) - else: - self.register_module(f"add_k_proj_{pbr_token}", None) - self.register_module(f"add_v_proj_{pbr_token}", None) - - -# Rotary Position Embedding utilities (specialized for PoseRoPE) -class RotaryEmbedding: - """ - Rotary position embedding utilities for 3D spatial attention. - - Provides functions to compute and apply rotary position embeddings (RoPE) - for 1D, 3D spatial coordinates used in 3D-aware attention mechanisms. - """ - - @staticmethod - def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0): - """ - Compute 1D rotary position embeddings. - - Args: - dim: Embedding dimension (must be even) - pos: Position tensor - theta: Base frequency for rotary embeddings - linear_factor: Linear scaling factor - ntk_factor: NTK (Neural Tangent Kernel) scaling factor - - Returns: - Tuple of (cos_embeddings, sin_embeddings) - """ - assert dim % 2 == 0 - theta = theta * ntk_factor - freqs = ( - 1.0 - / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim)) - / linear_factor - ) - freqs = torch.outer(pos, freqs) - freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() - freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() - return freqs_cos, freqs_sin - - @staticmethod - def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000): - """ - Compute 3D rotary position embeddings for spatial coordinates. - - Args: - position: 3D position tensor with shape [..., 3] - embed_dim: Embedding dimension - voxel_resolution: Resolution of the voxel grid - theta: Base frequency for rotary embeddings - - Returns: - Tuple of (cos_embeddings, sin_embeddings) for 3D positions - """ - assert position.shape[-1] == 3 - dim_xy = embed_dim // 8 * 3 - dim_z = embed_dim // 8 * 2 - - grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device) - freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta) - freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta) - - xy_cos, xy_sin = freqs_xy - z_cos, z_sin = freqs_z - - embed_flattn = position.view(-1, position.shape[-1]) - x_cos = xy_cos[embed_flattn[:, 0], :] - x_sin = xy_sin[embed_flattn[:, 0], :] - y_cos = xy_cos[embed_flattn[:, 1], :] - y_sin = xy_sin[embed_flattn[:, 1], :] - z_cos = z_cos[embed_flattn[:, 2], :] - z_sin = z_sin[embed_flattn[:, 2], :] - - cos = torch.cat((x_cos, y_cos, z_cos), dim=-1) - sin = torch.cat((x_sin, y_sin, z_sin), dim=-1) - - cos = cos.view(*position.shape[:-1], embed_dim) - sin = sin.view(*position.shape[:-1], embed_dim) - return cos, sin - - @staticmethod - def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]): - """ - Apply rotary position embeddings to input tensor. - - Args: - x: Input tensor to apply rotary embeddings to - freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor - - Returns: - Tensor with rotary position embeddings applied - """ - cos, sin = freqs_cis - cos, sin = cos.to(x.device), sin.to(x.device) - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out - - -# Core attention processing logic (eliminating major duplication) -class AttnCore: - """ - Core attention processing logic shared across processors. - - This class provides the fundamental attention computation pipeline - that can be reused across different attention processor implementations. - """ - - @staticmethod - def process_attention_base( - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - get_qkv_fn: Callable = None, - apply_rope_fn: Optional[Callable] = None, - **kwargs, - ): - """ - Generic attention processing core shared across different processors. - - This function implements the common attention computation pipeline including: - 1. Hidden state preprocessing - 2. Attention mask preparation - 3. Q/K/V computation via provided function - 4. Tensor reshaping for multi-head attention - 5. Optional normalization and RoPE application - 6. Scaled dot-product attention computation - - Args: - attn: Attention module instance - hidden_states: Input hidden states tensor - encoder_hidden_states: Optional encoder hidden states for cross-attention - attention_mask: Optional attention mask tensor - temb: Optional temporal embedding tensor - get_qkv_fn: Function to compute Q, K, V tensors - apply_rope_fn: Optional function to apply rotary position embeddings - **kwargs: Additional keyword arguments passed to subfunctions - - Returns: - Tuple containing (attention_output, residual, input_ndim, shape_info, - batch_size, num_heads, head_dim) - """ - # Prepare hidden states - hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - # Prepare attention mask - attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size) - - # Get Q, K, V - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs) - - # Reshape for attention - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim) - key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim) - value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads) - - # Apply normalization - query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None)) - - # Apply RoPE if provided - if apply_rope_fn is not None: - query, key = apply_rope_fn(query, key, head_dim, **kwargs) - - # Compute attention - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim - - -# Specific processor implementations (minimal unique code) -class PoseRoPEAttnProcessor2_0: - """ - Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness. - - This processor extends standard attention with 3D rotary position embeddings - to provide spatial awareness for 3D scene understanding tasks. - """ - - def __init__(self): - """Initialize the RoPE attention processor.""" - AttnUtils.check_pytorch_compatibility() - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_indices: Dict = None, - temb: Optional[torch.Tensor] = None, - n_pbrs=1, - *args, - **kwargs, - ) -> torch.Tensor: - """ - Apply RoPE-enhanced attention computation. - - Args: - attn: Attention module instance - hidden_states: Input hidden states tensor - encoder_hidden_states: Optional encoder hidden states for cross-attention - attention_mask: Optional attention mask tensor - position_indices: Dictionary containing 3D position information for RoPE - temb: Optional temporal embedding tensor - n_pbrs: Number of PBR material types - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - Attention output tensor with applied rotary position encodings - """ - AttnUtils.handle_deprecation_warning(args, kwargs) - - def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs): - return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states) - - def apply_rope(query, key, head_dim, **kwargs): - if position_indices is not None: - if head_dim in position_indices: - image_rotary_emb = position_indices[head_dim] - else: - image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed( - rearrange( - position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1), - "b n_pbrs l c -> (b n_pbrs) l c", - ), - head_dim, - voxel_resolution=position_indices["voxel_resolution"], - ) - position_indices[head_dim] = image_rotary_emb - - query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb) - key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb) - return query, key - - # Core attention processing - hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base( - attn, - hidden_states, - encoder_hidden_states, - attention_mask, - temb, - get_qkv_fn=get_qkv, - apply_rope_fn=apply_rope, - position_indices=position_indices, - n_pbrs=n_pbrs, - ) - - # Finalize output - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim) - hidden_states = hidden_states.to(hidden_states.dtype) - - return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out) - - -class SelfAttnProcessor2_0(BaseAttnProcessor): - """ - Self-attention processor with PBR (Physically Based Rendering) material support. - - This processor handles multiple PBR material types (e.g., albedo, metallic-roughness) - with separate attention computation paths for each material type. - """ - - def __init__(self, **kwargs): - """ - Initialize self-attention processor with PBR support. - - Args: - **kwargs: Arguments passed to BaseAttnProcessor initialization - """ - super().__init__(**kwargs) - self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs) - - def process_single( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - token: Literal["albedo", "mr"] = "albedo", - multiple_devices=False, - *args, - **kwargs, - ): - """ - Process attention for a single PBR material type. - - Args: - attn: Attention module instance - hidden_states: Input hidden states tensor - encoder_hidden_states: Optional encoder hidden states for cross-attention - attention_mask: Optional attention mask tensor - temb: Optional temporal embedding tensor - token: PBR material type to process ("albedo", "mr", etc.) - multiple_devices: Whether to use multiple GPU devices - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - Processed attention output for the specified PBR material type - """ - target = attn if token == "albedo" else attn.processor - token_suffix = "" if token == "albedo" else "_" + token - - # Device management (if needed) - if multiple_devices: - device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1") - for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]: - getattr(target, attr).to(device) - - def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs): - return ( - getattr(target, f"to_q{token_suffix}")(hidden_states), - getattr(target, f"to_k{token_suffix}")(encoder_hidden_states), - getattr(target, f"to_v{token_suffix}")(encoder_hidden_states), - ) - - # Core processing using shared logic - hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base( - attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv - ) - - # Finalize - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim) - hidden_states = hidden_states.to(hidden_states.dtype) - - return AttnUtils.finalize_output( - hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}") - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - """ - Apply self-attention with PBR material processing. - - Processes multiple PBR material types sequentially, applying attention - computation for each material type separately and combining results. - - Args: - attn: Attention module instance - hidden_states: Input hidden states tensor with PBR dimension - encoder_hidden_states: Optional encoder hidden states for cross-attention - attention_mask: Optional attention mask tensor - temb: Optional temporal embedding tensor - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - Combined attention output for all PBR material types - """ - AttnUtils.handle_deprecation_warning(args, kwargs) - - B = hidden_states.size(0) - pbr_hidden_states = torch.split(hidden_states, 1, dim=1) - - # Process each PBR setting - results = [] - for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states): - processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0") - result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False) - results.append(result) - - outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results] - return torch.cat(outputs, dim=1) - - -class RefAttnProcessor2_0(BaseAttnProcessor): - """ - Reference attention processor with shared value computation across PBR materials. - - This processor computes query and key once, but uses separate value projections - for different PBR material types, enabling efficient multi-material processing. - """ - - def __init__(self, **kwargs): - """ - Initialize reference attention processor. - - Args: - **kwargs: Arguments passed to BaseAttnProcessor initialization - """ - super().__init__(**kwargs) - self.pbr_settings = self.pbr_setting # Alias for compatibility - self.register_pbr_modules(["v_only", "out"], **kwargs) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - """ - Apply reference attention with shared Q/K and separate V projections. - - This method computes query and key tensors once and reuses them across - all PBR material types, while using separate value projections for each - material type to maintain material-specific information. - - Args: - attn: Attention module instance - hidden_states: Input hidden states tensor - encoder_hidden_states: Optional encoder hidden states for cross-attention - attention_mask: Optional attention mask tensor - temb: Optional temporal embedding tensor - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - Stacked attention output for all PBR material types - """ - AttnUtils.handle_deprecation_warning(args, kwargs) - - def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs): - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - - # Concatenate values from all PBR settings - value_list = [attn.to_v(encoder_hidden_states)] - for token in ["_" + token for token in self.pbr_settings if token != "albedo"]: - value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states)) - value = torch.cat(value_list, dim=-1) - - return query, key, value - - # Core processing - hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base( - attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv - ) - - # Split and process each PBR setting output - hidden_states_list = torch.split(hidden_states, head_dim, dim=-1) - output_hidden_states_list = [] - - for i, hs in enumerate(hidden_states_list): - hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype) - token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else "" - target = attn if self.pbr_settings[i] == "albedo" else attn.processor - - hs = AttnUtils.finalize_output( - hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}") - ) - output_hidden_states_list.append(hs) - - return torch.stack(output_hidden_states_list, dim=1)