Skip to content

Commit

Permalink
[pipeline] Add pool option to image feature extraction pipeline (hu…
Browse files Browse the repository at this point in the history
…ggingface#28985)

* Add pool option

* PR comments - error message and exact outputs check
  • Loading branch information
amyeroberts authored Feb 20, 2024
1 parent c47576c commit e770f03
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 8 deletions.
32 changes: 25 additions & 7 deletions src/transformers/pipelines/image_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
image_processor_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the image processor e.g.
{"size": {"height": 100, "width": 100}}
pool (`bool`, *optional*, defaults to `False`):
Whether or not to return the pooled output. If `False`, the model will return the raw hidden states.
""",
)
class ImageFeatureExtractionPipeline(Pipeline):
Expand Down Expand Up @@ -41,9 +43,14 @@ class ImageFeatureExtractionPipeline(Pipeline):
[huggingface.co/models](https://huggingface.co/models).
"""

def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, **kwargs):
def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
postprocess_params = {"return_tensors": return_tensors} if return_tensors is not None else {}

postprocess_params = {}
if pool is not None:
postprocess_params["pool"] = pool
if return_tensors is not None:
postprocess_params["return_tensors"] = return_tensors

if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]
Expand All @@ -59,14 +66,25 @@ def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs

def postprocess(self, model_outputs, return_tensors=False):
# [0] is the first available tensor, logits or last_hidden_state.
def postprocess(self, model_outputs, pool=None, return_tensors=False):
pool = pool if pool is not None else False

if pool:
if "pooler_output" not in model_outputs:
raise ValueError(
"No pooled output was returned. Make sure the model has a `pooler` layer when using the `pool` option."
)
outputs = model_outputs["pooler_output"]
else:
# [0] is the first available tensor, logits or last_hidden_state.
outputs = model_outputs[0]

if return_tensors:
return model_outputs[0]
return outputs
if self.framework == "pt":
return model_outputs[0].tolist()
return outputs.tolist()
elif self.framework == "tf":
return model_outputs[0].numpy().tolist()
return outputs.numpy().tolist()

def __call__(self, *args, **kwargs):
"""
Expand Down
32 changes: 31 additions & 1 deletion tests/pipelines/test_pipelines_image_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,39 @@ def test_small_model_pt(self):
nested_simplify(outputs[0][0]),
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip

@require_torch
def test_small_model_w_pooler_pt(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="pt"
)
img = prepare_img()
outputs = feature_extractor(img, pool=True)
self.assertEqual(
nested_simplify(outputs[0]),
[-0.056, 0.083, 0.021, 0.038, 0.242, -0.279, -0.033, -0.003, 0.200, -0.192, 0.045, -0.095, -0.077, 0.017, -0.058, -0.063, -0.029, -0.204, 0.014, 0.042, 0.305, -0.205, -0.099, 0.146, -0.287, 0.020, 0.168, -0.052, 0.046, 0.048, -0.156, 0.093]) # fmt: skip

@require_tf
def test_small_model_tf(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="tf"
)
img = prepare_img()
outputs = feature_extractor(img)
self.assertEqual(
nested_simplify(outputs[0][0]),
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip

@require_tf
def test_small_model_w_pooler_tf(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="tf"
)
img = prepare_img()
outputs = feature_extractor(img, pool=True)
self.assertEqual(
nested_simplify(outputs[0]),
[-0.056, 0.083, 0.021, 0.038, 0.242, -0.279, -0.033, -0.003, 0.200, -0.192, 0.045, -0.095, -0.077, 0.017, -0.058, -0.063, -0.029, -0.204, 0.014, 0.042, 0.305, -0.205, -0.099, 0.146, -0.287, 0.020, 0.168, -0.052, 0.046, 0.048, -0.156, 0.093]) # fmt: skip

@require_torch
def test_image_processing_small_model_pt(self):
feature_extractor = pipeline(
Expand All @@ -91,6 +113,10 @@ def test_image_processing_small_model_pt(self):
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
self.assertEqual(np.squeeze(outputs).shape, (226, 32))

# Test pooling option
outputs = feature_extractor(img, pool=True)
self.assertEqual(np.squeeze(outputs).shape, (32,))

@require_tf
def test_image_processing_small_model_tf(self):
feature_extractor = pipeline(
Expand All @@ -109,6 +135,10 @@ def test_image_processing_small_model_tf(self):
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
self.assertEqual(np.squeeze(outputs).shape, (226, 32))

# Test pooling option
outputs = feature_extractor(img, pool=True)
self.assertEqual(np.squeeze(outputs).shape, (32,))

@require_torch
def test_return_tensors_pt(self):
feature_extractor = pipeline(
Expand Down

0 comments on commit e770f03

Please sign in to comment.