Spaces:
Runtime error
Runtime error
File size: 5,156 Bytes
a1d409e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once and for all.
For instance, once a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
as a standard :class:`~transformers.BertForSequenceClassification`.
"""
import argparse
import os
import shutil
import torch
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
# Substrings identifying tensors that are copied to the output unchanged.
_COPIED_NAME_MARKERS = ("embeddings", "LayerNorm", "pooler", "classifier", "qa_output", "bias")
# Every pruning method handled by `main`.
_KNOWN_PRUNING_METHODS = ("magnitude", "topK", "sigmoied_threshold", "l0")
# Suffix stripped from a parameter name to locate its sibling `mask_scores` entry.
_WEIGHT_SUFFIX = "weight"


def _mask_from_scores(pruning_method, scores, threshold):
    """Convert the stored mask scores of one layer into a multiplicative mask."""
    if pruning_method == "topK":
        return TopKBinarizer.apply(scores, threshold)
    if pruning_method == "sigmoied_threshold":
        return ThresholdBinarizer.apply(scores, threshold, True)
    # "l0": sigmoid of the scores, stretched to the interval (lower, upper)
    # and clamped back into [0, 1].
    lower, upper = -0.1, 1.1
    stretched = torch.sigmoid(scores) * (upper - lower) + lower
    return stretched.clamp(min=0.0, max=1.0)


def main(args):
    """Apply the pruning masks of a fine-pruned checkpoint and save the result.

    Reads ``pytorch_model.bin`` from ``args.model_name_or_path``, copies the
    tensors whose name contains one of ``_COPIED_NAME_MARKERS`` untouched,
    multiplies every other tensor by its mask, drops the ``mask_scores``
    entries (for the score-based methods) and writes the resulting dictionary
    to ``pytorch_model.bin`` in the target folder.

    Args:
        args: Object exposing the attributes ``pruning_method``, ``threshold``,
            ``model_name_or_path`` and ``target_model_path``. When
            ``target_model_path`` is ``None`` the output folder is the input
            folder's sibling named ``bertarized_<input folder name>``.

    Raises:
        ValueError: If the pruning method is unknown, or if ``threshold`` is
            missing for a method other than ``l0``.
        KeyError: If a tensor to prune has no matching ``mask_scores`` entry.
    """
    pruning_method = args.pruning_method
    threshold = args.threshold
    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    # Fail fast, before the (potentially large) checkpoint is loaded.
    if pruning_method not in _KNOWN_PRUNING_METHODS:
        raise ValueError("Unknown pruning method")
    if threshold is None and pruning_method != "l0":
        raise ValueError(f"`threshold` is required for pruning method `{pruning_method}`")

    print(f"Load fine-pruned model from {model_name_or_path}")
    # Load on CPU so the conversion also works on a machine that lacks the
    # device the checkpoint was saved from.
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"), map_location="cpu")
    pruned_model = {}

    for name, tensor in model.items():
        if any(marker in name for marker in _COPIED_NAME_MARKERS):
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
            continue

        if pruning_method == "magnitude":
            # `torch.autograd.Function.apply` rejects keyword arguments, so the
            # inputs are passed positionally.
            # NOTE(review): assumes `MagnitudeBinarizer.forward` takes
            # (inputs, threshold) in that order, as the former keyword names suggest.
            mask = MagnitudeBinarizer.apply(tensor, threshold)
        else:
            if "mask_scores" in name:
                # Scores are only needed to build the masks; they are not saved.
                continue
            scores_key = f"{name[:-len(_WEIGHT_SUFFIX)]}mask_scores"
            if scores_key not in model:
                raise KeyError(f"No mask scores `{scores_key}` found for layer `{name}`")
            mask = _mask_from_scores(pruning_method, model[scores_key], threshold)

        pruned_model[name] = tensor * mask
        print(f"Pruned layer {name}")

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
        )

    # Copy the whole source folder (config, vocabulary, ...) the first time;
    # the weights file is overwritten just below.
    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
def _build_parser():
    """Build the command-line parser of the script.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--pruning_method``,
        ``--threshold``, ``--model_name_or_path`` and ``--target_model_path``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pruning_method",
        choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
            " sigmoied_threshold = Soft movement pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        # argparse %-formats help strings, so a literal percent sign must be
        # written `%%`; the backslash is escaped so `\tau` is not read as a TAB.
        help=(
            "For `magnitude` and `topK`, it is the level of remaining weights (in %%) in the fine-pruned model. "
            "For `sigmoied_threshold`, it is the threshold \\tau against which the (sigmoied) scores are compared. "
            "Not needed for `l0`."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    parser.add_argument(
        "--target_model_path",
        default=None,
        type=str,
        required=False,
        help=(
            "Folder in which the pruned model is saved. Defaults to a sibling of the input folder"
            " named `bertarized_<input folder name>`"
        ),
    )
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()
    main(args)
|