add_confidence_score (#56)

- add confidence score parsing (2ce7cd7837dfc8d93bf1d77dae95669ef1bcf0b3)
This commit is contained in:
ai-modelscope
2024-11-15 20:09:08 +08:00
parent 18c21287ed
commit 41c72b788b

View File

@ -20,6 +20,7 @@ import re
import logging
from typing import List, Optional, Union
import numpy as np
import math
import torch
@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import (
TextInput,
TruncationStrategy,
)
from transformers import BartTokenizer, BartTokenizerFast
from transformers.utils import TensorType
@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def post_process_generation(self, text, task, image_size):
def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
"""
Post-process the output of the model to each of the task outputs.
@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin):
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
task_answer = self.post_processor(
text=text,
sequence=sequence,
transition_beam_score=transition_beam_score,
image_size=image_size,
parse_tasks=task_answer_post_processing_type,
)[task_answer_post_processing_type]
@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin):
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
if len(od_instances) and 'score' in od_instances[0]:
scores_od = [_od_instance['score'] for _od_instance in od_instances]
final_answer['scores'] = scores_od
elif task_answer_post_processing_type in ['ocr']:
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
labels = [str(_od_instance['text']) for _od_instance in task_answer]
@ -591,7 +598,8 @@ class Florence2PostProcesser(object):
'PARSE_TASKS': [
{
'TASK_NAME': 'od',
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
'SCORE_MODE': 'avg_loc_scores'
},
{
'TASK_NAME': 'ocr',
@ -607,6 +615,7 @@ class Florence2PostProcesser(object):
},
{
'TASK_NAME': 'description_with_bboxes',
'SCORE_MODE': 'avg_loc_scores'
},
{
'TASK_NAME': 'description_with_polygons',
@ -647,10 +656,6 @@ class Florence2PostProcesser(object):
filtered_tokens = tokenizer.convert_ids_to_tokens(
token_ids, skip_special_tokens=False)
assert len(filtered_tokens) == len(token_ids)
# To avoid mixing byte-level and unicode for byte-level BPT
# we need to build string separately for added tokens and byte-level tokens
# cf. https://github.com/huggingface/transformers/issues/1133
sub_texts = []
for token in filtered_tokens:
if token in self.all_special_tokens:
@ -658,10 +663,6 @@ class Florence2PostProcesser(object):
else:
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
sub_text = tokenizer.convert_tokens_to_string([token])
elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
# Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
# Note: Do not strip sub_text as it may have functional whitespace
sub_text = token.replace('', ' ')
else:
raise ValueError(f'type {type(tokenizer)} not supported')
sub_texts.append(sub_text)
@ -672,14 +673,6 @@ class Florence2PostProcesser(object):
span = (len(text), len(text) + len(sub_text)) # [start index, end index).
text += sub_text
spans.append(span)
# Text format:
# 1. T5Tokenizer/T5TokenizerFast:
# "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
# Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
# 2. BartTokenizer (need to double check):
# "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
# Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
return text, spans
def parse_od_from_text_and_spans(
@ -714,7 +707,7 @@ class Florence2PostProcesser(object):
return instances
def parse_ocr_from_text_and_spans(self,
text,
text,
pattern,
image_size,
area_threshold=-1.0,
@ -818,9 +811,26 @@ class Florence2PostProcesser(object):
return instances
def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
# temporary parse solution, split by '.'
# ignore <s> </s> and <pad>
def parse_description_with_bboxes_from_text_and_spans(
self,
text,
spans=None,
scores=None,
score_mode=None,
pattern=None,
image_size=None,
allow_empty_phrase=False
):
def find_matched_token_indices(cur_span, token_spans):
inds = []
for i, token_span in enumerate(token_spans):
if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
inds.append(i)
return inds
cur_span = 0
if text.startswith('<s>'):
cur_span += 3
text = text.replace('<s>', '')
text = text.replace('</s>', '')
@ -842,13 +852,16 @@ class Florence2PostProcesser(object):
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
if phrase_text_strip == '' and not allow_empty_phrase:
cur_span += len(pharse_text)
continue
# parse phrase, get string
phrase = re.search(pattern, phrase_text_strip)
if phrase is None:
cur_span += len(pharse_text)
continue
phrase_span = phrase.span()
phrase = phrase.group()
# remove leading and trailing spaces
phrase = phrase.strip()
@ -856,6 +869,7 @@ class Florence2PostProcesser(object):
# parse bboxes by box_pattern
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
if len(bboxes_parsed) == 0:
cur_span += len(pharse_text)
continue
# a list of list
@ -866,14 +880,42 @@ class Florence2PostProcesser(object):
size=image_size
).tolist()
if score_mode == 'avg_loc_scores':
if spans is None or scores is None:
all_scores = None
else:
bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
all_scores = []
for _spans in bbox_end_spans:
token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
loc_scores = [scores[token_i] for token_i in token_inds]
score = sum(loc_scores) / len(loc_scores)
all_scores.append(score)
elif score_mode == 'avg_cat_name_scores':
if spans is None or scores is None:
all_scores = None
else:
cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
score = sum(cat_name_scores) / len(cat_name_scores)
all_scores = [score] * len(bboxes)
elif score_mode is None:
all_scores = None
else:
raise ValueError('Unknown score mode: {}'.format(score_mode))
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
for _bboxes in bboxes:
for _idx, _bboxes in enumerate(bboxes):
# Prepare instance.
instance = {}
instance['bbox'] = _bboxes
# exclude non-ascii characters
instance['cat_name'] = phrase
if all_scores is not None:
instance['score'] = math.exp(all_scores[_idx])
instances.append(instance)
cur_span += len(pharse_text)
return instances
@ -991,6 +1033,8 @@ class Florence2PostProcesser(object):
def __call__(
self,
text=None,
sequence=None,
transition_beam_score=None,
image_size=None,
parse_tasks=None,
):
@ -1008,7 +1052,18 @@ class Florence2PostProcesser(object):
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
# sequence or text should be provided
assert text is not None, 'text should be provided'
assert sequence is not None or text is not None, 'sequence or text should be provided'
assert sequence is None or text is None, 'only one of sequence and text should be provided'
if sequence is not None:
sequence = sequence.tolist()[1:]
text, spans = self.decode_with_spans(self.tokenizer, sequence)
if transition_beam_score is not None:
transition_beam_score = transition_beam_score.tolist()
assert len(sequence) == len(transition_beam_score)
else:
spans = None
transition_beam_score = None
parsed_dict = {
'text': text
@ -1019,6 +1074,7 @@ class Florence2PostProcesser(object):
continue
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
if task == 'ocr':
instances = self.parse_ocr_from_text_and_spans(
@ -1040,6 +1096,9 @@ class Florence2PostProcesser(object):
elif task == 'description_with_bboxes':
instances = self.parse_description_with_bboxes_from_text_and_spans(
text,
spans=spans,
scores=transition_beam_score,
score_mode=score_mode,
pattern=pattern,
image_size=image_size,
)