System Info
Info:
Framework: TensorFlow 2 (Keras)
Version: 2.6
OS: Kaggle
Who can help?
Swin Model Card @amyeroberts
TensorFlow: @Rocketknight1
Vision: @NielsRogge, @sgugger
Reproduction
In a recent Kaggle competition (hosted by Google), I tried to use the pretrained TF Swin Transformer model from Hugging Face, but even with the base model I consistently received an out-of-memory (OOM) error. Below is the submission status with a base_tf_swin model.
Some notes:
- Other frameworks, such as PyTorch, work fine here.
- Apart from this model, much larger models such as tf_convnext_xlarge run without OOM.
So I suspect there might be a memory leak in the tf_swin implementation. Below is the code I use to build the complete model.
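import tensorflow as tf
from tensorflow import keras
from transformers import AutoFeatureExtractor, TFSwinModel

model_id = "microsoft/swin-base-patch4-window7-224-in22k"
# (height, width); assumed value, matching the 224x224 resolution in the checkpoint name.
INPUT_SHAPE = (224, 224)

feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# Accept raw uint8 images of any size and preprocess them inside the graph.
inputs = keras.Input(shape=(None, None, 3), dtype='uint8')
model_inputs = tf.cast(inputs, tf.float32)
model_inputs = keras.layers.Resizing(*INPUT_SHAPE)(model_inputs)
model_inputs = keras.layers.Rescaling(scale=1.0 / 255)(model_inputs)
# Normalize with the checkpoint's mean/std (the Keras layer takes variance = std ** 2).
model_inputs = keras.layers.Normalization(
    mean=feature_extractor.image_mean,
    variance=[x ** 2 for x in feature_extractor.image_std],
    axis=3,
)(model_inputs)
# TFSwinModel expects channels-first pixel_values: (batch, channels, height, width).
model_inputs = keras.layers.Permute(dims=(3, 1, 2))(model_inputs)

# Frozen Swin backbone.
tf_huggingface_module = TFSwinModel.from_pretrained(model_id)
tf_huggingface_module.trainable = False
logits = tf_huggingface_module(model_inputs)

# 64-d projection head on the pooled output, followed by L2 normalization.
adv_logits = keras.layers.Dense(64)(logits.pooler_output)
outputs = keras.layers.Lambda(
    lambda x: tf.math.l2_normalize(x, axis=-1), name='embedding_norm'
)(adv_logits)
tf_huggingface_classifier = keras.Model(inputs, outputs)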
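The in-graph preprocessing above is meant to reproduce what the feature extractor does: scale to [0, 1], normalize per channel, then move channels first. Below is a minimal NumPy sketch of those steps, useful for checking the Keras layers against a reference. It assumes the ImageNet statistics (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225); in practice, read them from feature_extractor.image_mean and feature_extractor.image_std. The function name preprocess_reference is hypothetical. The sketch checks that normalizing with variance = std ** 2, i.e. (x - mean) / sqrt(variance), matches (x - mean) / std, and that Permute(dims=(3, 1, 2)) corresponds to an NHWC to NCHW transpose.

```python
import numpy as np

# Assumed ImageNet statistics; the real values come from
# feature_extractor.image_mean / feature_extractor.image_std.
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_reference(images_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for Rescaling -> Normalization -> Permute.

    images_uint8: (batch, height, width, 3) uint8 array, already resized.
    Returns float32 pixel values in (batch, 3, height, width) layout.
    """
    x = images_uint8.astype(np.float32) / 255.0   # Rescaling(scale=1/255)
    variance = IMAGE_STD ** 2                      # what the Normalization layer receives
    x = (x - IMAGE_MEAN) / np.sqrt(variance)       # Normalization(axis=3), per channel
    return np.transpose(x, (0, 3, 1, 2))           # Permute((3, 1, 2)) -> NCHW


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(2, 224, 224, 3)).astype(np.uint8)
out = preprocess_reference(batch)

# Dividing by sqrt(std ** 2) is the same as dividing by std.
direct = (batch.astype(np.float32) / 255.0 - IMAGE_MEAN) / IMAGE_STD
print(out.shape)  # (2, 3, 224, 224)
print(np.allclose(out, np.transpose(direct, (0, 3, 1, 2)), atol=1e-5))  # True
```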
Expected behavior
It should work like the other models. To reproduce the issue exactly, you may, in the worst case, need to run it on the Kaggle platform. The Kaggle submission status (as shown in the image above) is not very descriptive; it only shows the submission status :(. Mainly, I would like to know what the cause could be and whether there is a possible solution.
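One way to tell a genuine leak from a single oversized allocation is to call the model repeatedly on the same batch and watch device memory. A leak keeps growing step after step, whereas a model that is simply too large fails on the first call or plateaus. Below is a minimal, framework-agnostic sketch. The helper names track_memory and looks_like_leak are hypothetical, and the growth heuristic is an assumption, not a definitive test.

```python
from typing import Callable, List


def track_memory(step: Callable[[], object], probe: Callable[[], int], n_steps: int = 20) -> List[int]:
    """Call `step` n_steps times and record `probe()` (bytes in use) after each call."""
    readings = []
    for _ in range(n_steps):
        step()
        readings.append(probe())
    return readings


def looks_like_leak(readings: List[int], warmup: int = 3, tolerance: int = 0) -> bool:
    """Heuristic: after `warmup` steps, memory never drops and grows by more than `tolerance`."""
    steady = readings[warmup:]
    if len(steady) < 2:
        return False
    never_drops = all(b >= a for a, b in zip(steady, steady[1:]))
    return never_drops and steady[-1] - steady[0] > tolerance


# Fake probes standing in for real device readings.
leaky = iter(range(0, 10_000, 100))      # +100 bytes per call
flat = iter([500, 900] + [1_000] * 98)   # warms up, then plateaus

print(looks_like_leak(track_memory(lambda: None, lambda: next(leaky))))  # True
print(looks_like_leak(track_memory(lambda: None, lambda: next(flat))))   # False
```

With TensorFlow 2.5 or later, the probe could read tf.config.experimental.get_memory_info('GPU:0')['current'], and the step could be lambda: tf_huggingface_classifier(batch, training=False), where batch is a uint8 image batch you supply.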