From c2ad6ef1a2c10f6225152f205ddb91fa93a352cb Mon Sep 17 00:00:00 2001 From: lxlxlxlxlxlx Date: Wed, 25 Jun 2025 12:44:20 +0000 Subject: [PATCH] upload model folder to repo --- .gitattributes | 4 +- transformer/config.json | 4 +- transformer/transformer_omnigen2.py | 1215 ++++++++++++++++++++++----- 3 files changed, 1009 insertions(+), 214 deletions(-) diff --git a/.gitattributes b/.gitattributes index 23a9f37..43598ff 100644 --- a/.gitattributes +++ b/.gitattributes @@ -50,4 +50,6 @@ processor/tokenizer.json filter=lfs diff=lfs merge=lfs -text tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text examples_edit.png filter=lfs diff=lfs merge=lfs -text -examples_subject.png filter=lfs diff=lfs merge=lfs -text \ No newline at end of file +examples_subject.png filter=lfs diff=lfs merge=lfs -text + +processor/tokenizer.json filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/transformer/config.json b/transformer/config.json index af61c77..d56d1bd 100644 --- a/transformer/config.json +++ b/transformer/config.json @@ -23,7 +23,5 @@ "out_channels": null, "patch_size": 2, "text_feat_dim": 2048, - "timestep_scale": 1000.0, - "use_fused_rms_norm": true, - "use_fused_swiglu": true + "timestep_scale": 1000.0 } diff --git a/transformer/transformer_omnigen2.py b/transformer/transformer_omnigen2.py index 8fd7926..a5ad125 100644 --- a/transformer/transformer_omnigen2.py +++ b/transformer/transformer_omnigen2.py @@ -20,30 +20,994 @@ from diffusers.models.embeddings import get_1d_rotary_pos_embed from diffusers.models.activations import get_activation from diffusers.models.embeddings import Timesteps -from flash_attn import flash_attn_varlen_func -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +import importlib.util +import sys -# try: -# from .triton.layer_norm import RMSNorm as FusedRMSNorm -# FUSEDRMSNORM_AVALIBLE = True -# except ImportError: -# FUSEDRMSNORM_AVALIBLE = False -# warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation") +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata -FUSEDRMSNORM_AVALIBLE = False +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" -try: - from flash_attn.ops.activations import swiglu as fused_swiglu - FUSEDSWIGLU_AVALIBLE = True -except ImportError: + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + +_triton_available, _triton_version = _is_package_available("triton") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") + +def is_triton_available(): + return _triton_available + +def is_flash_attn_available(): + return _flash_attn_available + +if is_triton_available(): + # from ...ops.triton.layer_norm import RMSNorm + import triton + import triton.language as tl + + + from typing import Callable + + + def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + + if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] + else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + + custom_fwd = custom_amp_decorator(custom_fwd, deprecated) + custom_bwd = custom_amp_decorator(custom_bwd, deprecated) + + + def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs=[] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block=1024 + # Default to warp size 32 if not defined by device + warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count=1 + while warp_count*warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count*=2 + return configs - FUSEDSWIGLU_AVALIBLE = False - warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + @triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], + ) + # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) + # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) + @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) + @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) + @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) + @triton.jit + def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + + def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + else: + assert out.shape == x.shape + assert out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + if residual_out is None: + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + out, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + @triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], + ) + # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) + # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) + # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) + @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) + @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) + @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) + @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) + @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) + @triton.jit + def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + zero_centered_weight, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + ): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + + def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + zero_centered_weight=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, + ): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + x_shape_og = x.shape + # Check for zero sequence length + if x.numel() == 0: + ctx.zero_seq_length = True + # Only save minimal required tensors for backward + # ctx.save_for_backward(weight, bias, weight1, bias1) + ctx.x_shape_og = x_shape_og + ctx.weight_shape = weight.shape + ctx.weight_dtype = weight.dtype + ctx.weight_device = weight.device + + ctx.has_bias = bias is not None + ctx.bias_shape = bias.shape if bias is not None else None + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.bias_device = bias.device if bias is not None else None + + ctx.has_weight1 = weight1 is not None + ctx.weight1_shape = weight1.shape if weight1 is not None else None + ctx.weight1_dtype = weight1.dtype if weight1 is not None else None + ctx.weight1_device = weight1.device if weight1 is not None else None + + ctx.has_bias1 = bias1 is not None + ctx.bias1_shape = bias1.shape if bias1 is not None else None + ctx.bias1_dtype = bias1.dtype if bias1 is not None else None + ctx.bias1_device = bias1.device if bias1 is not None else None + + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.dropout_p = dropout_p + + # Handle output tensors with correct dtype + y = x # Preserve input tensor properties + y1 = torch.empty_like(x) if x1 is not None else None + + # Only create residual_out if prenorm is True + residual_out = torch.empty(x.shape, + dtype=torch.float32 if residual_in_fp32 else x.dtype, + device=x.device) if prenorm else None + + # Handle dropout masks + dropout_mask = None + dropout_mask1 = None + if return_dropout_mask: + dropout_mask = torch.empty_like(x, dtype=torch.uint8) + if x1 is not None: + dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) + + # Return based on configuration + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ((y, dropout_mask, dropout_mask1) if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1)) + else: + return ((y, y1, dropout_mask, dropout_mask1) if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1)) + + ctx.zero_seq_length = False + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + if ctx.zero_seq_length: + return ( + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), + torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device), + torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None, + torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None, + torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out, + residual_out + ) + + class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.zero_centered_weight = zero_centered_weight + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, + ) +else: + from torch.nn import RMSNorm + warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance") + +def swiglu(x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y logger = logging.get_logger(__name__) -def swiglu(x, y): - return F.silu(x.float(), inplace=False).to(x.dtype) * y class TimestepEmbedding(nn.Module): def __init__( @@ -285,7 +1249,6 @@ class LuminaRMSNormZero(nn.Module): embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool, - use_fused_rms_norm: bool = False, ): super().__init__() self.silu = nn.SiLU() @@ -294,14 +1257,7 @@ class LuminaRMSNormZero(nn.Module): 4 * embedding_dim, bias=True, ) - if use_fused_rms_norm: - if FUSEDRMSNORM_AVALIBLE: - self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps) - else: - warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation") - self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps) - else: - self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps) + self.norm = RMSNorm(embedding_dim, eps=norm_eps) def forward( self, @@ -311,12 +1267,6 @@ class LuminaRMSNormZero(nn.Module): emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) - # x_norm = self.norm(x) - # print(f"{x.shape=} {x.dtype=} {x_norm.shape=} {x_norm.dtype=}") - # print(f"{scale_msa.shape=} {scale_msa.dtype=}") - # print(f"{scale_msa[:, None].shape=} {scale_msa[:, None].dtype=}") - # x = x_norm * (1 + scale_msa[:, None]) - return x, gate_msa, scale_mlp, gate_mlp @@ -335,7 +1285,6 @@ class LuminaLayerNormContinuous(nn.Module): bias=True, norm_type="layer_norm", out_dim: Optional[int] = None, - use_fused_rms_norm: bool = False ): super().__init__() @@ -346,14 +1295,7 @@ class LuminaLayerNormContinuous(nn.Module): if norm_type == "layer_norm": self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) elif norm_type == "rms_norm": - if use_fused_rms_norm: - if FUSEDRMSNORM_AVALIBLE: - self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - else: - warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation") - self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - else: - self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") @@ -398,16 +1340,10 @@ class LuminaFeedForward(nn.Module): inner_dim: int, multiple_of: Optional[int] = 256, ffn_dim_multiplier: Optional[float] = None, - use_fused_swiglu: bool = False ): super().__init__() - self.use_fused_swiglu = use_fused_swiglu - if use_fused_swiglu: - assert FUSEDSWIGLU_AVALIBLE - self.swiglu = fused_swiglu - else: - self.swiglu = swiglu + self.swiglu = swiglu # custom hidden_size factor multiplier if ffn_dim_multiplier is not None: @@ -443,7 +1379,6 @@ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, - use_fused_rms_norm: bool = False ) -> None: super().__init__() @@ -455,15 +1390,6 @@ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) ) - if use_fused_rms_norm: - if FUSEDRMSNORM_AVALIBLE: - RMSNorm = FusedRMSNorm - else: - warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation") - RMSNorm = nn.RMSNorm - else: - RMSNorm = nn.RMSNorm - self.caption_embedder = nn.Sequential( RMSNorm(text_feat_dim, eps=norm_eps), nn.Linear(text_feat_dim, hidden_size, bias=True), @@ -484,9 +1410,9 @@ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): return time_embed, caption_embed -class OmniGen2AttnProcessorFlash2Varlen: +class OmniGen2AttnProcessor: """ - Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + Processor for implementing scaled dot-product attention. This processor is optimized for PyTorch 2.0 and implements: - Flash attention with variable length sequences @@ -509,85 +1435,6 @@ class OmniGen2AttnProcessorFlash2Varlen: "Please upgrade PyTorch to version 2.0 or later." ) - def _upad_input( - self, - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - num_heads: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: - """ - Unpad the input tensors for flash attention. - - Args: - query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim) - key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) - value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) - attention_mask: Attention mask tensor of shape (batch_size, seq_len) - query_length: Length of the query sequence - num_heads: Number of attention heads - - Returns: - Tuple containing: - - Unpadded query tensor - - Unpadded key tensor - - Unpadded value tensor - - Query indices - - Tuple of cumulative sequence lengths for query and key - - Tuple of maximum sequence lengths for query and key - """ - def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: - """Helper function to get unpadding data from attention mask.""" - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return indices, cu_seqlens, max_seqlen_in_batch - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - # Unpad key and value layers - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - - # Handle different query length cases - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), - indices_k, - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - def __call__( self, attn: Attention, @@ -650,41 +1497,23 @@ class OmniGen2AttnProcessorFlash2Varlen: else: softmax_scale = attn.scale - # Unpad input for flash attention - ( - query_states, - key_states, - value_states, - indices_q, - cu_seq_lens, - max_seq_lens, - ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + if attention_mask is not None: + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) - # Handle different number of heads - if kv_heads < attn.heads: - key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) - value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) - - # Apply flash attention - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=0.0, - causal=False, - softmax_scale=softmax_scale, + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale ) - - # Pad output and apply final transformations - hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length) - hidden_states = hidden_states.flatten(-2) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) # Apply output projection @@ -724,8 +1553,6 @@ class OmniGen2TransformerBlock(nn.Module): ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, - use_fused_rms_norm: bool = True, - use_fused_swiglu: bool = True, ) -> None: """Initialize the transformer block.""" super().__init__() @@ -743,7 +1570,7 @@ class OmniGen2TransformerBlock(nn.Module): eps=1e-5, bias=False, out_bias=False, - processor=OmniGen2AttnProcessorFlash2Varlen(), + processor=OmniGen2AttnProcessor(), ) # Initialize feed-forward network @@ -752,7 +1579,6 @@ class OmniGen2TransformerBlock(nn.Module): inner_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, - use_fused_swiglu=use_fused_swiglu, ) # Initialize normalization layers @@ -761,32 +1587,13 @@ class OmniGen2TransformerBlock(nn.Module): embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True, - use_fused_rms_norm=use_fused_rms_norm, ) else: - if use_fused_rms_norm: - if FUSEDRMSNORM_AVALIBLE: - self.norm1 = FusedRMSNorm(dim, eps=norm_eps) - else: - warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation") - self.norm1 = nn.RMSNorm(dim, eps=norm_eps) - else: - self.norm1 = nn.RMSNorm(dim, eps=norm_eps) + self.norm1 = RMSNorm(dim, eps=norm_eps) - if use_fused_rms_norm: - if FUSEDRMSNORM_AVALIBLE: - self.ffn_norm1 = FusedRMSNorm(dim, eps=norm_eps) - self.norm2 = FusedRMSNorm(dim, eps=norm_eps) - self.ffn_norm2 = FusedRMSNorm(dim, eps=norm_eps) - else: - warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation") - self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps) - self.norm2 = nn.RMSNorm(dim, eps=norm_eps) - self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps) - else: - self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps) - self.norm2 = nn.RMSNorm(dim, eps=norm_eps) - self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) self.initialize_weights() @@ -909,8 +1716,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From axes_lens: Tuple[int, int, int] = (300, 512, 512), text_feat_dim: int = 1024, timestep_scale: float = 1.0, - use_fused_rms_norm: bool = True, - use_fused_swiglu: bool = True, ) -> None: """Initialize the OmniGen2 transformer model.""" super().__init__() @@ -947,7 +1752,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From text_feat_dim=text_feat_dim, norm_eps=norm_eps, timestep_scale=timestep_scale, - use_fused_rms_norm=use_fused_rms_norm, ) # Initialize transformer blocks @@ -960,8 +1764,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From ffn_dim_multiplier, norm_eps, modulation=True, - use_fused_rms_norm=use_fused_rms_norm, - use_fused_swiglu=use_fused_swiglu, ) for _ in range(num_refiner_layers) ]) @@ -975,8 +1777,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From ffn_dim_multiplier, norm_eps, modulation=True, - use_fused_rms_norm=use_fused_rms_norm, - use_fused_swiglu=use_fused_swiglu, ) for _ in range(num_refiner_layers) ]) @@ -991,8 +1791,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From ffn_dim_multiplier, norm_eps, modulation=False, - use_fused_rms_norm=use_fused_rms_norm, - use_fused_swiglu=use_fused_swiglu ) for _ in range(num_refiner_layers) ] @@ -1009,8 +1807,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From ffn_dim_multiplier, norm_eps, modulation=True, - use_fused_rms_norm=use_fused_rms_norm, - use_fused_swiglu=use_fused_swiglu ) for _ in range(num_layers) ] @@ -1024,7 +1820,6 @@ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From eps=1e-6, bias=True, out_dim=patch_size * patch_size * self.out_channels, - use_fused_rms_norm=use_fused_rms_norm, ) # Add learnable embeddings to distinguish different images