
mmedit.models.editors.ddpm.unet_blocks

Module Contents

Classes

UNetMidBlock2DCrossAttn

UNet mid block built with cross-attention.

CrossAttnDownBlock2D

Down block built with cross-attention.

DownBlock2D

Down block built with ResNet blocks.

CrossAttnUpBlock2D

Up block built with cross-attention.

UpBlock2D

Up block built with ResNet blocks.

Functions

get_down_block(down_block_type, num_layers, ...[, ...])

Get a UNet down-path block.

get_up_block(up_block_type, num_layers, in_channels, ...)

Get a UNet up-path block.

mmedit.models.editors.ddpm.unet_blocks.get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_act_fn, attn_num_head_channels, resnet_eps=1e-05, resnet_groups=32, cross_attention_dim=1280, downsample_padding=1, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Build a UNet down-path block of the type named by down_block_type.
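A factory like this typically dispatches on the block-type string. The sketch below assumes the accepted strings match the class names documented on this page ('DownBlock2D', 'CrossAttnDownBlock2D'); the Fake* classes and get_down_block_sketch are hypothetical stand-ins that only record their arguments, not the mmedit implementation.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the real block classes.
@dataclass
class FakeDownBlock2D:
    in_channels: int
    out_channels: int
    num_layers: int


@dataclass
class FakeCrossAttnDownBlock2D:
    in_channels: int
    out_channels: int
    num_layers: int
    cross_attention_dim: int


def get_down_block_sketch(down_block_type, num_layers, in_channels,
                          out_channels, cross_attention_dim=1280):
    """Map a block-type string to a block instance (illustrative only)."""
    if down_block_type == 'DownBlock2D':
        return FakeDownBlock2D(in_channels, out_channels, num_layers)
    if down_block_type == 'CrossAttnDownBlock2D':
        return FakeCrossAttnDownBlock2D(in_channels, out_channels,
                                        num_layers, cross_attention_dim)
    # Unknown names fail loudly instead of silently building a default.
    raise ValueError(f'{down_block_type} does not exist.')


block = get_down_block_sketch('CrossAttnDownBlock2D', 2, 320, 320,
                              cross_attention_dim=768)
print(type(block).__name__, block.cross_attention_dim)
```

Keeping the dispatch in one function lets a UNet config describe its down path as a plain list of type strings.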

mmedit.models.editors.ddpm.unet_blocks.get_up_block(up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_act_fn, attn_num_head_channels, resnet_eps=1e-05, resnet_groups=32, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Build a UNet up-path block of the type named by up_block_type.
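Unlike get_down_block, this factory also takes prev_output_channel, the channel count produced by the previous (deeper) up block. The sketch below shows one way in_channels, prev_output_channel and out_channels can be chained across the up path, assuming a diffusers-style UNet that walks the down-path channel list in reverse; up_path_channels is a hypothetical helper, not part of mmedit.

```python
def up_path_channels(block_out_channels):
    """Return (in_channels, prev_output_channel, out_channels) per up block.

    Assumes the up path mirrors the down path in reverse order.
    """
    reversed_channels = list(reversed(block_out_channels))
    last = len(reversed_channels) - 1
    plans = []
    output_channel = reversed_channels[0]
    for i in range(len(reversed_channels)):
        prev_output_channel = output_channel
        output_channel = reversed_channels[i]
        # The block's final skip connection comes from the next-shallower level.
        input_channel = reversed_channels[min(i + 1, last)]
        plans.append((input_channel, prev_output_channel, output_channel))
    return plans


for plan in up_path_channels([320, 640, 1280, 1280]):
    print(plan)
```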

class mmedit.models.editors.ddpm.unet_blocks.UNetMidBlock2DCrossAttn(in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, attention_type='default', output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False)

Bases: torch.nn.Module

UNet mid block built with cross-attention.

set_attention_slice(slice_size)

Set the attention slice size.
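Attention slicing, as implemented in diffusers (which this module appears to be ported from), computes attention over the batch x heads dimension in chunks of slice_size rather than all at once, trading some speed for lower peak memory. Whether mmedit slices exactly this dimension is an assumption. The pure-Python sketch below (all function names hypothetical) shows that the sliced result matches full attention while only slice_size score matrices exist at a time.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scores(q, k):
    """Scaled dot-product scores for one head: (seq_q x seq_k)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    return [[scale * sum(a * b for a, b in zip(qr, kr)) for kr in k]
            for qr in q]


def apply_weights(score_matrix, v):
    """Softmax each score row and take the weighted sum of value rows."""
    out = []
    for row in score_matrix:
        w = softmax(row)
        out.append([sum(wi * vr[j] for wi, vr in zip(w, v))
                    for j in range(len(v[0]))])
    return out


def full_attention(q, k, v):
    """Materialise score matrices for every (batch * head) entry at once."""
    all_scores = [scores(qh, kh) for qh, kh in zip(q, k)]
    return [apply_weights(s, vh) for s, vh in zip(all_scores, v)]


def sliced_attention(q, k, v, slice_size):
    """Same result, but only slice_size score matrices exist at a time."""
    out = []
    for start in range(0, len(q), slice_size):
        stop = start + slice_size
        chunk = [scores(qh, kh) for qh, kh in zip(q[start:stop], k[start:stop])]
        out.extend(apply_weights(s, vh) for s, vh in zip(chunk, v[start:stop]))
    return out


# Toy tensors shaped (batch * heads, seq_len, head_dim) = (4, 3, 2).
q = [[[0.1 * (h + i + j) for j in range(2)] for i in range(3)] for h in range(4)]
k = [[[0.2 * (h - i + j) for j in range(2)] for i in range(3)] for h in range(4)]
v = [[[float(h * i + j) for j in range(2)] for i in range(3)] for h in range(4)]

print(sliced_attention(q, k, v, slice_size=2) == full_attention(q, k, v))
```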

forward(hidden_states, temb=None, encoder_hidden_states=None)

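Run the forward pass on hidden_states, optionally conditioned on the time embedding temb and, through cross-attention, on encoder_hidden_states.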
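A minimal sketch of the layer ordering, assuming the diffusers-style mid block layout: one leading ResNet block, then num_layers (attention, ResNet) pairs, so there is one more ResNet than attention layer. The toy layers operate on plain numbers and log their names so the call order is visible; mid_block_forward and the make_* helpers are hypothetical.

```python
def mid_block_forward(hidden_states, resnets, attentions, temb=None,
                      encoder_hidden_states=None):
    """Assumed ordering: resnet, then (attention, resnet) pairs."""
    if len(resnets) != len(attentions) + 1:
        raise ValueError('expected one more resnet than attention layer')
    hidden_states = resnets[0](hidden_states, temb)
    for attn, resnet in zip(attentions, resnets[1:]):
        # Attention reads the conditioning context; resnets read temb.
        hidden_states = attn(hidden_states, encoder_hidden_states)
        hidden_states = resnet(hidden_states, temb)
    return hidden_states


trace = []


def make_resnet(name):
    def resnet(h, temb):
        trace.append(name)
        return h + 1
    return resnet


def make_attn(name):
    def attn(h, context):
        trace.append(name)
        return h * 2
    return attn


num_layers = 2
resnets = [make_resnet(f'resnet{i}') for i in range(num_layers + 1)]
attentions = [make_attn(f'attn{i}') for i in range(num_layers)]
out = mid_block_forward(0, resnets, attentions)
print(out, trace)
```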

class mmedit.models.editors.ddpm.unet_blocks.CrossAttnDownBlock2D(in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, attention_type='default', output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Bases: torch.nn.Module

Down block built with cross-attention.

set_attention_slice(slice_size)

Set the attention slice size.

forward(hidden_states, temb=None, encoder_hidden_states=None)

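Run the forward pass on hidden_states, optionally conditioned on the time embedding temb and, through cross-attention, on encoder_hidden_states.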
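Down blocks in a UNet usually return not only the final hidden states but also every intermediate output, which the up path later consumes as skip connections. The sketch below assumes that diffusers-style (hidden_states, output_states) return convention; for brevity, each layer is a one-argument callable on a number, standing in for a resnet that would also read temb and an attention layer that would also read encoder_hidden_states. down_block_forward is hypothetical.

```python
def down_block_forward(hidden_states, layers, downsampler=None):
    """Run (resnet, attention) pairs and collect every intermediate output.

    Returns (hidden_states, output_states); output_states are meant to be
    reused by the up path as skip connections.
    """
    output_states = ()
    for resnet, attn in layers:
        hidden_states = attn(resnet(hidden_states))
        output_states += (hidden_states,)
    if downsampler is not None:
        hidden_states = downsampler(hidden_states)
        output_states += (hidden_states,)
    return hidden_states, output_states


# Toy layers: resnet adds 1, attention multiplies by 10, downsample halves.
layers = [(lambda h: h + 1, lambda h: h * 10)] * 2
h, skips = down_block_forward(1, layers, downsampler=lambda h: h // 2)
print(h, skips)
```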

class mmedit.models.editors.ddpm.unet_blocks.DownBlock2D(in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1)

Bases: torch.nn.Module

Down block built with ResNet blocks.

forward(hidden_states, temb=None)

Run the forward pass on hidden_states, optionally conditioned on the time embedding temb.
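When add_downsample is True, the block shrinks the spatial size. Assuming the downsampler is a 3x3 convolution with stride 2 and padding equal to downsample_padding (an assumption; check the source), the standard convolution output-size formula gives the new height or width. downsampled_size is a hypothetical helper.

```python
def downsampled_size(size, kernel_size=3, stride=2, padding=1):
    """Output length of a strided convolution along one spatial axis."""
    return (size + 2 * padding - kernel_size) // stride + 1


for size in (64, 63, 7):
    print(size, '->', downsampled_size(size))
```

Note that 64 and 63 both map to 32. Doubling 32 on the way back up gives 64, so an odd input size cannot be recovered by a plain 2x upsample.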

class mmedit.models.editors.ddpm.unet_blocks.CrossAttnUpBlock2D(in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, attention_type='default', output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Bases: torch.nn.Module

Up block built with cross-attention.

set_attention_slice(slice_size)

Set the attention slice size.

forward(hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None)

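Run the forward pass on hidden_states and the skip-connection states in res_hidden_states_tuple, optionally conditioned on temb and encoder_hidden_states; upsample_size optionally fixes the spatial size of the upsampled output.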
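Each ResNet in an up block sees its input concatenated (along channels) with one skip state, so its input width is the incoming channels plus the skip channels. The sketch below assumes the diffusers-style layout: the first ResNet receives prev_output_channel, the last skip has in_channels channels, and skips are taken from the end of res_hidden_states_tuple (last in, first out). Both helpers are hypothetical.

```python
def up_block_resnet_in_channels(in_channels, prev_output_channel,
                                out_channels, num_layers):
    """Input channels of each resnet after concatenating a skip state."""
    channels = []
    for i in range(num_layers):
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        channels.append(resnet_in_channels + res_skip_channels)
    return channels


def consume_skips(res_hidden_states_tuple, num_layers):
    """Pop one skip state per resnet from the end of the tuple."""
    remaining = tuple(res_hidden_states_tuple)
    used = []
    for _ in range(num_layers):
        used.append(remaining[-1])
        remaining = remaining[:-1]
    return used, remaining


print(up_block_resnet_in_channels(320, 640, 320, 3))
print(consume_skips(('s0', 's1', 's2', 's3'), 3))
```

Popping from the end pairs the deepest remaining down-path output with the first ResNet of the current up block.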

class mmedit.models.editors.ddpm.unet_blocks.UpBlock2D(in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True)

Bases: torch.nn.Module

Up block built with ResNet blocks.

forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None)

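Run the forward pass on hidden_states and the skip-connection states in res_hidden_states_tuple, optionally conditioned on temb; upsample_size optionally fixes the spatial size of the upsampled output.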
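A plain 2x upsample cannot restore an odd spatial size (for example 7 -> 4 -> 8). Passing an explicit upsample_size lets the upsampler target the skip connection's size instead. The pure-Python sketch below mimics nearest-neighbour interpolation on a 2-D list, assuming the upsampler uses nearest interpolation as in diffusers; nearest_upsample is hypothetical.

```python
def nearest_upsample(grid, size=None):
    """Nearest-neighbour upsampling of a 2-D list.

    Without size, each axis is doubled; with size=(h, w), the output is
    resized to exactly that shape.
    """
    in_h, in_w = len(grid), len(grid[0])
    out_h, out_w = size if size is not None else (2 * in_h, 2 * in_w)
    # Each output pixel copies the source pixel at floor(index * in / out).
    return [[grid[(r * in_h) // out_h][(c * in_w) // out_w]
             for c in range(out_w)] for r in range(out_h)]


print(nearest_upsample([[1, 2], [3, 4]]))

# A 4x4 map that came from a 7x7 skip: doubling gives 8x8, size gives 7x7.
grid4 = [[r * 4 + c for c in range(4)] for r in range(4)]
up = nearest_upsample(grid4, size=(7, 7))
print(len(up), len(up[0]))
```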