diff --git a/processing_florence2.py b/processing_florence2.py index 42d24ca..dcb7451 100644 --- a/processing_florence2.py +++ b/processing_florence2.py @@ -20,6 +20,7 @@ import re import logging from typing import List, Optional, Union import numpy as np +import math import torch @@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import ( TextInput, TruncationStrategy, ) +from transformers import BartTokenizer, BartTokenizerFast from transformers.utils import TensorType @@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin): image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) - def post_process_generation(self, text, task, image_size): + def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None): """ Post-process the output of the model to each of the task outputs. @@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin): task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text') task_answer = self.post_processor( text=text, + sequence=sequence, + transition_beam_score=transition_beam_score, image_size=image_size, parse_tasks=task_answer_post_processing_type, )[task_answer_post_processing_type] @@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin): bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances] labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances] final_answer = {'bboxes': bboxes_od, 'labels': labels_od} + if len(od_instances) and 'score' in od_instances[0]: + scores_od = [_od_instance['score'] for _od_instance in od_instances] + final_answer['scores'] = scores_od elif task_answer_post_processing_type in ['ocr']: bboxes = [_od_instance['quad_box'] for _od_instance in task_answer] labels = [str(_od_instance['text']) for _od_instance in task_answer] @@ -591,7 +598,8 @@ class Florence2PostProcesser(object): 'PARSE_TASKS': [ { 'TASK_NAME': 'od', - 'PATTERN': r'([a-zA-Z0-9 ]+)' + 'PATTERN': r'([a-zA-Z0-9 ]+)', + 'SCORE_MODE': 'avg_loc_scores' }, { 'TASK_NAME': 'ocr', @@ -607,6 +615,7 @@ class Florence2PostProcesser(object): }, { 'TASK_NAME': 'description_with_bboxes', + 'SCORE_MODE': 'avg_loc_scores' }, { 'TASK_NAME': 'description_with_polygons', @@ -647,10 +656,6 @@ class Florence2PostProcesser(object): filtered_tokens = tokenizer.convert_ids_to_tokens( token_ids, skip_special_tokens=False) assert len(filtered_tokens) == len(token_ids) - - # To avoid mixing byte-level and unicode for byte-level BPT - # we need to build string separately for added tokens and byte-level tokens - # cf. https://github.com/huggingface/transformers/issues/1133 sub_texts = [] for token in filtered_tokens: if token in self.all_special_tokens: @@ -658,10 +663,6 @@ class Florence2PostProcesser(object): else: if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)): sub_text = tokenizer.convert_tokens_to_string([token]) - elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)): - # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol - # Note: Do not strip sub_text as it may have functional whitespace - sub_text = token.replace('▁', ' ') else: raise ValueError(f'type {type(tokenizer)} not supported') sub_texts.append(sub_text) @@ -672,14 +673,6 @@ class Florence2PostProcesser(object): span = (len(text), len(text) + len(sub_text)) # [start index, end index). text += sub_text spans.append(span) - - # Text format: - # 1. T5Tokenizer/T5TokenizerFast: - # " transplanting dog cat" - # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False) - # 2. BartTokenizer (need to double check): - # "transplanting dogcat" - # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False) return text, spans def parse_od_from_text_and_spans( @@ -714,7 +707,7 @@ class Florence2PostProcesser(object): return instances def parse_ocr_from_text_and_spans(self, - text, + text, pattern, image_size, area_threshold=-1.0, @@ -818,9 +811,26 @@ class Florence2PostProcesser(object): return instances - def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False): - # temporary parse solution, split by '.' - # ignore and + def parse_description_with_bboxes_from_text_and_spans( + self, + text, + spans=None, + scores=None, + score_mode=None, + pattern=None, + image_size=None, + allow_empty_phrase=False + ): + def find_matched_token_indices(cur_span, token_spans): + inds = [] + for i, token_span in enumerate(token_spans): + if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]): + inds.append(i) + return inds + + cur_span = 0 + if text.startswith(''): + cur_span += 3 text = text.replace('', '') text = text.replace('', '') @@ -842,13 +852,16 @@ class Florence2PostProcesser(object): phrase_text_strip = pharse_text.replace('', '', 1) if phrase_text_strip == '' and not allow_empty_phrase: + cur_span += len(pharse_text) continue # parse phrase, get string phrase = re.search(pattern, phrase_text_strip) if phrase is None: + cur_span += len(pharse_text) continue + phrase_span = phrase.span() phrase = phrase.group() # remove leading and trailing spaces phrase = phrase.strip() @@ -856,6 +869,7 @@ class Florence2PostProcesser(object): # parse bboxes by box_pattern bboxes_parsed = list(re.finditer(box_pattern, pharse_text)) if len(bboxes_parsed) == 0: + cur_span += len(pharse_text) continue # a list of list @@ -866,14 +880,42 @@ class Florence2PostProcesser(object): size=image_size ).tolist() + if score_mode == 'avg_loc_scores': + if spans is None or scores is None: + all_scores = None + else: + bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed] + all_scores = [] + for _spans in bbox_end_spans: + token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans) + loc_scores = [scores[token_i] for token_i in token_inds] + score = sum(loc_scores) / len(loc_scores) + all_scores.append(score) + elif score_mode == 'avg_cat_name_scores': + if spans is None or scores is None: + all_scores = None + else: + cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans) + cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds] + score = sum(cat_name_scores) / len(cat_name_scores) + all_scores = [score] * len(bboxes) + elif score_mode is None: + all_scores = None + else: + raise ValueError('Unknown score mode: {}'.format(score_mode)) + phrase = phrase.encode('ascii',errors='ignore').decode('ascii') - for _bboxes in bboxes: + for _idx, _bboxes in enumerate(bboxes): # Prepare instance. instance = {} instance['bbox'] = _bboxes # exclude non-ascii characters instance['cat_name'] = phrase + if all_scores is not None: + instance['score'] = math.exp(all_scores[_idx]) instances.append(instance) + + cur_span += len(pharse_text) return instances @@ -991,6 +1033,8 @@ class Florence2PostProcesser(object): def __call__( self, text=None, + sequence=None, + transition_beam_score=None, image_size=None, parse_tasks=None, ): @@ -1008,7 +1052,18 @@ class Florence2PostProcesser(object): assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported' # sequence or text should be provided - assert text is not None, 'text should be provided' + assert sequence is not None or text is not None, 'sequence or text should be provided' + assert sequence is None or text is None, 'only one of sequence and text should be provided' + + if sequence is not None: + sequence = sequence.tolist()[1:] + text, spans = self.decode_with_spans(self.tokenizer, sequence) + if transition_beam_score is not None: + transition_beam_score = transition_beam_score.tolist() + assert len(sequence) == len(transition_beam_score) + else: + spans = None + transition_beam_score = None parsed_dict = { 'text': text @@ -1019,6 +1074,7 @@ class Florence2PostProcesser(object): continue pattern = self.parse_tasks_configs[task].get('PATTERN', None) + score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None) if task == 'ocr': instances = self.parse_ocr_from_text_and_spans( @@ -1040,6 +1096,9 @@ class Florence2PostProcesser(object): elif task == 'description_with_bboxes': instances = self.parse_description_with_bboxes_from_text_and_spans( text, + spans=spans, + scores=transition_beam_score, + score_mode=score_mode, pattern=pattern, image_size=image_size, )