mmedit.models.editors.stable_diffusion.clip_wrapper 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import os
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmedit.utils import try_import
if transformers is not None:
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from transformers.models.clip.feature_extraction_clip import \
CLIPFeatureExtractor # noqa
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.clip.tokenization_clip import CLIPTokenizer
def cosine_distance(image_embeds, text_embeds):
    """Return the pairwise cosine similarity of two sets of embeddings.

    Both inputs are L2-normalised along dim 1 and then multiplied, so
    entry ``[i, j]`` of the result is the cosine of the angle between
    ``image_embeds[i]`` and ``text_embeds[j]``.

    Args:
        image_embeds(torch.Tensor): 2-D tensor, one embedding per row.
        text_embeds(torch.Tensor): 2-D tensor, one embedding per row.

    Returns:
        torch.Tensor: matrix with one row per image embedding and one
            column per text embedding.
    """
    unit_image = nn.functional.normalize(image_embeds, dim=1)
    unit_text = nn.functional.normalize(text_embeds, dim=1)
    return unit_image.mm(unit_text.t())
class StableDiffusionSafetyChecker(PreTrainedModel):
    """CLIP-based checker that blacks out generated images flagged as NSFW.

    Image embeddings are compared against two sets of concept embeddings
    (17 regular concepts and 3 "special care" concepts). An image is
    flagged when its similarity to any regular concept exceeds that
    concept's threshold.
    """
    config_class = CLIPConfig
    _no_split_modules = ['CLIPEncoderLayer']

    def __init__(self, config: CLIPConfig):
        """Check result image for stable diffusion to prevent NSFW content
        generated.

        Args:
            config(CLIPConfig): config for transformers clip.
        """
        super().__init__(config)
        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(
            config.vision_config.hidden_size,
            config.projection_dim,
            bias=False)

        # Frozen concept embeddings and per-concept thresholds. They are
        # initialised to ones here; presumably the real values come from a
        # checkpoint loaded through ``from_pretrained``.
        self.concept_embeds = nn.Parameter(
            torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(
            torch.ones(3, config.projection_dim), requires_grad=False)
        self.concept_embeds_weights = nn.Parameter(
            torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(
            torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        """Return black image if input image has nsfw content.

        Args:
            clip_input(torch.Tensor):
                image feature extracted by clip feature extractor.
            images(torch.Tensor):
                image generated by stable diffusion. Flagged entries are
                overwritten in place.

        Returns:
            images(torch.Tensor):
                black images if input images have nsfw content,
                otherwise return input images.
            has_nsfw_concepts(list[bool]):
                flag list to indicate whether input images have
                nsfw content.
        """
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause
        # significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(
            image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(
            image_embeds, self.concept_embeds).cpu().float().numpy()

        # Read the thresholds once instead of calling ``.item()`` for every
        # (image, concept) pair inside the loops below.
        special_thresholds = self.special_care_embeds_weights.tolist()
        concept_thresholds = self.concept_embeds_weights.tolist()

        result = []
        for special_row, concept_row in zip(special_cos_dist, cos_dist):
            result_img = {
                'special_scores': {},
                'special_care': [],
                'concept_scores': {},
                'bad_concepts': []
            }

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of
            # filtering benign images
            adjustment = 0.0

            for concept_idx, (concept_cos, concept_threshold) in enumerate(
                    zip(special_row, special_thresholds)):
                score = round(concept_cos - concept_threshold + adjustment,
                              3)
                result_img['special_scores'][concept_idx] = score
                if score > 0:
                    # An ordered (index, score) pair; a set literal would
                    # lose the order and collapse when index == score.
                    result_img['special_care'].append((concept_idx, score))
                    # Once a special-care concept fires, every later
                    # comparison for this image becomes stricter.
                    adjustment = 0.01

            for concept_idx, (concept_cos, concept_threshold) in enumerate(
                    zip(concept_row, concept_thresholds)):
                score = round(concept_cos - concept_threshold + adjustment,
                              3)
                result_img['concept_scores'][concept_idx] = score
                if score > 0:
                    result_img['bad_concepts'].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [bool(res['bad_concepts']) for res in result]

        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if not has_nsfw_concept:
                continue
            flagged = images[idx]
            # black image; keep dtype and device when the entry is a tensor
            if isinstance(flagged, torch.Tensor):
                images[idx] = torch.zeros_like(flagged)
            else:
                images[idx] = torch.zeros(flagged.shape)

        if any(has_nsfw_concepts):
            # Fetch the logger here so the warning does not depend on a
            # module-level ``logger`` variable being defined.
            MMLogger.get_current_instance().warning(
                'NSFW content was detected in one or more images.'
                ' A black image will be returned instead.'
                ' Try again with a different prompt and/or seed.')

        return images, has_nsfw_concepts
def load_clip_submodels(init_cfg, submodels, requires_safety_checker):
    """Load the CLIP-related submodels from a pretrained checkpoint folder.

    Each submodel is read from a sub-directory of
    ``init_cfg['pretrained_model_path']`` (``tokenizer``,
    ``feature_extractor``, ``text_encoder`` and, optionally,
    ``safety_checker``). When no path is configured nothing is loaded and
    every returned value is ``None``.

    Args:
        init_cfg (dict | None):
            ckpt path of clip models. ``None`` is treated like a config
            without ``pretrained_model_path``.
        submodels (List):
            list of stable diffusion submodels. ``'safety_checker'`` is
            appended in place when the safety checker is loaded.
        requires_safety_checker (bool):
            whether to load safety checker

    Returns:
        tokenizer(CLIPTokenizer):
            tokenizer with ckpt loaded.
        feature_extractor(CLIPFeatureExtractor):
            feature_extractor with ckpt loaded.
        text_encoder(CLIPTextModel):
            text_encoder with ckpt loaded.
        safety_checker(StableDiffusionSafetyChecker):
            safety_checker with ckpt loaded.
    """
    # A missing config must not crash on ``.get``; it simply means that
    # there is no checkpoint to load.
    pretrained_model_path = (init_cfg or {}).get('pretrained_model_path',
                                                 None)

    tokenizer, feature_extractor, text_encoder, safety_checker = \
        None, None, None, None
    if pretrained_model_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            os.path.join(pretrained_model_path, 'tokenizer'))

        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            os.path.join(pretrained_model_path, 'feature_extractor'))

        text_encoder = CLIPTextModel.from_pretrained(
            os.path.join(pretrained_model_path, 'text_encoder'))

        if requires_safety_checker:
            submodels.append('safety_checker')
            safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                os.path.join(pretrained_model_path, 'safety_checker'))

    return tokenizer, feature_extractor, text_encoder, safety_checker
else:
def load_clip_submodels(init_cfg, submodels, requires_safety_checker):
    """Fallback used when ``transformers`` is not installed.

    Keeps the same signature as the real loader so that callers fail with
    an actionable message instead of a missing-name error.

    Raises:
        ImportError: always, asking the user to install ``transformers``.
    """
    raise ImportError('Please install transformers via '
                      '\'pip install transformers\'')