mirror of
https://www.modelscope.cn/AI-ModelScope/Florence-2-large.git
synced 2026-04-02 21:52:53 +08:00
add_confidence_score (#56)
- add confidence score parsing (2ce7cd7837dfc8d93bf1d77dae95669ef1bcf0b3)
This commit is contained in:
@ -20,6 +20,7 @@ import re
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import (
|
||||
TextInput,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from transformers import BartTokenizer, BartTokenizerFast
|
||||
from transformers.utils import TensorType
|
||||
|
||||
|
||||
@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin):
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
def post_process_generation(self, text, task, image_size):
|
||||
def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
|
||||
"""
|
||||
Post-process the output of the model to each of the task outputs.
|
||||
|
||||
@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin):
|
||||
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
||||
task_answer = self.post_processor(
|
||||
text=text,
|
||||
sequence=sequence,
|
||||
transition_beam_score=transition_beam_score,
|
||||
image_size=image_size,
|
||||
parse_tasks=task_answer_post_processing_type,
|
||||
)[task_answer_post_processing_type]
|
||||
@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin):
|
||||
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
||||
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
||||
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
||||
if len(od_instances) and 'score' in od_instances[0]:
|
||||
scores_od = [_od_instance['score'] for _od_instance in od_instances]
|
||||
final_answer['scores'] = scores_od
|
||||
elif task_answer_post_processing_type in ['ocr']:
|
||||
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
||||
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
||||
@ -591,7 +598,8 @@ class Florence2PostProcesser(object):
|
||||
'PARSE_TASKS': [
|
||||
{
|
||||
'TASK_NAME': 'od',
|
||||
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
|
||||
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
|
||||
'SCORE_MODE': 'avg_loc_scores'
|
||||
},
|
||||
{
|
||||
'TASK_NAME': 'ocr',
|
||||
@ -607,6 +615,7 @@ class Florence2PostProcesser(object):
|
||||
},
|
||||
{
|
||||
'TASK_NAME': 'description_with_bboxes',
|
||||
'SCORE_MODE': 'avg_loc_scores'
|
||||
},
|
||||
{
|
||||
'TASK_NAME': 'description_with_polygons',
|
||||
@ -647,10 +656,6 @@ class Florence2PostProcesser(object):
|
||||
filtered_tokens = tokenizer.convert_ids_to_tokens(
|
||||
token_ids, skip_special_tokens=False)
|
||||
assert len(filtered_tokens) == len(token_ids)
|
||||
|
||||
# To avoid mixing byte-level and unicode for byte-level BPT
|
||||
# we need to build string separately for added tokens and byte-level tokens
|
||||
# cf. https://github.com/huggingface/transformers/issues/1133
|
||||
sub_texts = []
|
||||
for token in filtered_tokens:
|
||||
if token in self.all_special_tokens:
|
||||
@ -658,10 +663,6 @@ class Florence2PostProcesser(object):
|
||||
else:
|
||||
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
||||
sub_text = tokenizer.convert_tokens_to_string([token])
|
||||
elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
|
||||
# Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
|
||||
# Note: Do not strip sub_text as it may have functional whitespace
|
||||
sub_text = token.replace('▁', ' ')
|
||||
else:
|
||||
raise ValueError(f'type {type(tokenizer)} not supported')
|
||||
sub_texts.append(sub_text)
|
||||
@ -672,14 +673,6 @@ class Florence2PostProcesser(object):
|
||||
span = (len(text), len(text) + len(sub_text)) # [start index, end index).
|
||||
text += sub_text
|
||||
spans.append(span)
|
||||
|
||||
# Text format:
|
||||
# 1. T5Tokenizer/T5TokenizerFast:
|
||||
# "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
|
||||
# Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
||||
# 2. BartTokenizer (need to double check):
|
||||
# "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
|
||||
# Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
||||
return text, spans
|
||||
|
||||
def parse_od_from_text_and_spans(
|
||||
@ -818,9 +811,26 @@ class Florence2PostProcesser(object):
|
||||
|
||||
return instances
|
||||
|
||||
def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
|
||||
# temporary parse solution, split by '.'
|
||||
# ignore <s> </s> and <pad>
|
||||
def parse_description_with_bboxes_from_text_and_spans(
|
||||
self,
|
||||
text,
|
||||
spans=None,
|
||||
scores=None,
|
||||
score_mode=None,
|
||||
pattern=None,
|
||||
image_size=None,
|
||||
allow_empty_phrase=False
|
||||
):
|
||||
def find_matched_token_indices(cur_span, token_spans):
|
||||
inds = []
|
||||
for i, token_span in enumerate(token_spans):
|
||||
if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
|
||||
inds.append(i)
|
||||
return inds
|
||||
|
||||
cur_span = 0
|
||||
if text.startswith('<s>'):
|
||||
cur_span += 3
|
||||
|
||||
text = text.replace('<s>', '')
|
||||
text = text.replace('</s>', '')
|
||||
@ -842,13 +852,16 @@ class Florence2PostProcesser(object):
|
||||
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
||||
|
||||
if phrase_text_strip == '' and not allow_empty_phrase:
|
||||
cur_span += len(pharse_text)
|
||||
continue
|
||||
|
||||
# parse phrase, get string
|
||||
phrase = re.search(pattern, phrase_text_strip)
|
||||
if phrase is None:
|
||||
cur_span += len(pharse_text)
|
||||
continue
|
||||
|
||||
phrase_span = phrase.span()
|
||||
phrase = phrase.group()
|
||||
# remove leading and trailing spaces
|
||||
phrase = phrase.strip()
|
||||
@ -856,6 +869,7 @@ class Florence2PostProcesser(object):
|
||||
# parse bboxes by box_pattern
|
||||
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
||||
if len(bboxes_parsed) == 0:
|
||||
cur_span += len(pharse_text)
|
||||
continue
|
||||
|
||||
# a list of list
|
||||
@ -866,15 +880,43 @@ class Florence2PostProcesser(object):
|
||||
size=image_size
|
||||
).tolist()
|
||||
|
||||
if score_mode == 'avg_loc_scores':
|
||||
if spans is None or scores is None:
|
||||
all_scores = None
|
||||
else:
|
||||
bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
|
||||
all_scores = []
|
||||
for _spans in bbox_end_spans:
|
||||
token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
|
||||
loc_scores = [scores[token_i] for token_i in token_inds]
|
||||
score = sum(loc_scores) / len(loc_scores)
|
||||
all_scores.append(score)
|
||||
elif score_mode == 'avg_cat_name_scores':
|
||||
if spans is None or scores is None:
|
||||
all_scores = None
|
||||
else:
|
||||
cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
|
||||
cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
|
||||
score = sum(cat_name_scores) / len(cat_name_scores)
|
||||
all_scores = [score] * len(bboxes)
|
||||
elif score_mode is None:
|
||||
all_scores = None
|
||||
else:
|
||||
raise ValueError('Unknown score mode: {}'.format(score_mode))
|
||||
|
||||
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
||||
for _bboxes in bboxes:
|
||||
for _idx, _bboxes in enumerate(bboxes):
|
||||
# Prepare instance.
|
||||
instance = {}
|
||||
instance['bbox'] = _bboxes
|
||||
# exclude non-ascii characters
|
||||
instance['cat_name'] = phrase
|
||||
if all_scores is not None:
|
||||
instance['score'] = math.exp(all_scores[_idx])
|
||||
instances.append(instance)
|
||||
|
||||
cur_span += len(pharse_text)
|
||||
|
||||
return instances
|
||||
|
||||
def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
|
||||
@ -991,6 +1033,8 @@ class Florence2PostProcesser(object):
|
||||
def __call__(
|
||||
self,
|
||||
text=None,
|
||||
sequence=None,
|
||||
transition_beam_score=None,
|
||||
image_size=None,
|
||||
parse_tasks=None,
|
||||
):
|
||||
@ -1008,7 +1052,18 @@ class Florence2PostProcesser(object):
|
||||
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
||||
|
||||
# sequence or text should be provided
|
||||
assert text is not None, 'text should be provided'
|
||||
assert sequence is not None or text is not None, 'sequence or text should be provided'
|
||||
assert sequence is None or text is None, 'only one of sequence and text should be provided'
|
||||
|
||||
if sequence is not None:
|
||||
sequence = sequence.tolist()[1:]
|
||||
text, spans = self.decode_with_spans(self.tokenizer, sequence)
|
||||
if transition_beam_score is not None:
|
||||
transition_beam_score = transition_beam_score.tolist()
|
||||
assert len(sequence) == len(transition_beam_score)
|
||||
else:
|
||||
spans = None
|
||||
transition_beam_score = None
|
||||
|
||||
parsed_dict = {
|
||||
'text': text
|
||||
@ -1019,6 +1074,7 @@ class Florence2PostProcesser(object):
|
||||
continue
|
||||
|
||||
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
||||
score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
|
||||
|
||||
if task == 'ocr':
|
||||
instances = self.parse_ocr_from_text_and_spans(
|
||||
@ -1040,6 +1096,9 @@ class Florence2PostProcesser(object):
|
||||
elif task == 'description_with_bboxes':
|
||||
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
||||
text,
|
||||
spans=spans,
|
||||
scores=transition_beam_score,
|
||||
score_mode=score_mode,
|
||||
pattern=pattern,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user