# utils_qwen.py
# Author: Yaning
from collections import deque
from string import punctuation
from transformers import AutoTokenizer, AddedToken
from functools import partial
from numpy.random import default_rng
from nltk.tree import ParentedTree
import torch
##############################################################################
# CONSTANTS
##############################################################################
BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
# Yj: Used to select the dataset split during argument parsing and data loading.
# Affects dataset preprocessing, e.g. generating the train, dev, test, and unittest splits.
SEEDS = [21, 57, 84]
CHECKPOINTS = list(range(50, 501, 50))
GENRES = {
"aochildes": "CHILDES",
"bnc_spoken": "British National Corpus (BNC)",
"cbt": "Children’s Book Test",
"children_stories": "Children’s Stories Text Corpus",
"gutenberg": "Standardized Project Gutenberg Corpus",
"open_subtitles": "OpenSubtitles",
"qed": "QCRI Educational Domain Corpus",
"simple_wikipedia": "Simple Wikipedia",
"switchboard": "Switchboard Dialog Act Corpus",
"wikipedia": "Wikipedia"
}
CHECKPOINT_WRITE_PATH = "/nlp/scr3/nlp/llms-in-llms/babylm_models"
CHECKPOINT_READ_PATH = "/nlp/scr3/nlp/llms-in-llms/babylm_models"
# BABYLM_DATA_PATH = "/nlp/scr3/nlp/llms-in-llms/babylm_data"
BABYLM_DATA_PATH = "."
MARKER_HOP_SING = "🅂"
MARKER_HOP_PLUR = "🄿"
MARKER_REV = "🅁"
BOS_TOKEN = "<BOS_TOKEN>"
PART_TOKENS = set(["n't", "'ll", "'s", "'re", "'ve", "'m"])
PUNCT_TOKENS = set(punctuation)
MODEL_NAME = "gpt2"
##############################################################################
# PARENS MODELS (Structurally-pretrained)
##############################################################################
PAREN_MODEL_PATH = "/u/scr/isabelvp//tilt-stuff/tilt-finetuning/pretrained_checkpoints/"
PAREN_MODELS = {
"CROSS": "flat-parens_vocab500-uniform_deplength-nesting-nolimit",
"NEST": "nested-parens0.49_vocab500-uniform",
"RAND": "random_vocab500-uniform",
}
##############################################################################
# HELPER FUNCTIONS
##############################################################################
def write_file(directory, filename, lines):
    with open(directory + filename, "w") as f:
        f.writelines(lines)
def get_qwen_tokenizer_with_markers(marker_list):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# If no new markers to add, return normal tokenizer
if len(marker_list) == 0:
return tokenizer
# Create tokens and return modified tokenizer
new_tokens = []
for marker in marker_list:
new_tokens.append(AddedToken(marker, lstrip=True, rstrip=False))
tokenizer.add_tokens(new_tokens)
return tokenizer
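# Usage sketch (illustrative only, not part of the original pipeline): adding a
# hypothetical marker and looking up its id. lstrip=True on AddedToken makes
# the tokenizer absorb a leading space before the marker.
#
#   demo_tokenizer = get_qwen_tokenizer_with_markers(["<MY_MARKER>"])
#   demo_marker_id = demo_tokenizer.get_added_vocab()["<MY_MARKER>"]
#   demo_tokenizer.decode([demo_marker_id])  # -> "<MY_MARKER>"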
qwen_original_tokenizer = get_qwen_tokenizer_with_markers([])
# Qwen hop tokenization (singular/plural marker tokens added)
qwen_hop_tokenizer = get_qwen_tokenizer_with_markers(
[MARKER_HOP_SING, MARKER_HOP_PLUR])
# Get ids of marker tokens
marker_sg_token = qwen_hop_tokenizer.get_added_vocab()[
MARKER_HOP_SING]
# Yj: get_added_vocab() returns all custom-added tokens and their corresponding token IDs
marker_pl_token = qwen_hop_tokenizer.get_added_vocab()[
MARKER_HOP_PLUR]
# Qwen reverse tokenization
qwen_rev_tokenizer = get_qwen_tokenizer_with_markers(
[MARKER_REV])
# Get ids of marker tokens
marker_rev_token = qwen_rev_tokenizer.get_added_vocab()[
MARKER_REV]
# Qwen determiner tokenization
qwen_det_tokenizer = get_qwen_tokenizer_with_markers(
[BOS_TOKEN])
# Get id of BOS token
bos_token_id = qwen_det_tokenizer.get_added_vocab()[BOS_TOKEN]
MARKER_TOKEN_IDS = [marker_sg_token, marker_pl_token, marker_rev_token]
def compute_surprisals(model, input_ids):
# Get the log probabilities from the model
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits[:, :-1]
shifted_input_ids = input_ids[:, 1:]
# Get the log probabilities for the actual next tokens
log_probs = torch.log2(torch.nn.functional.softmax(logits, dim=-1))
true_log_probs = log_probs.gather(
2, shifted_input_ids.unsqueeze(-1)).squeeze(-1)
# Get the negative log probabilities
neg_log_probs = (-true_log_probs).tolist()
surprisals = [[None] + probs for probs in neg_log_probs]
return surprisals
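# Hedged usage sketch for compute_surprisals (illustrative only; the helper
# name and example text are assumptions, not part of the original pipeline).
# Any causal LM whose vocabulary matches the tokenizers above should work.
def demo_token_surprisals(model, tokenizer, text="the cat sits on the mat"):
    # Batch of one sentence; the surprisal at position 0 is None by construction
    input_ids = torch.tensor([tokenizer.encode(text)])
    surprisals = compute_surprisals(model, input_ids)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    return list(zip(tokens, surprisals[0]))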
def compute_token_probabilities(model, input_ids, token_id, pad_token_id):
# Get the log probabilities from the model
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits[:, :-1]
probs = torch.nn.functional.softmax(logits, dim=-1)
# Get the probabilities for the specified token at each position
token_probs = probs[:, :, token_id]
# Convert to list and add None at the beginning to align with input tokens
# Put null probability for instances of pad token
token_probs_list = []
for batch_i, probs in enumerate(token_probs):
input_ids_seq = input_ids[batch_i].tolist() + [pad_token_id]
        filtered = [p if input_ids_seq[pos_i + 1] != pad_token_id else None
                    for pos_i, p in enumerate(probs.tolist())]
token_probs_list.append([None] + filtered)
return token_probs_list
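# Hedged sketch: probability of the singular hop marker at every position of a
# single unpadded sentence (illustrative only). Assumes the model's embedding
# matrix was resized to include the added marker tokens. The pad_token_id of -1
# is an arbitrary placeholder so that no real token is treated as padding.
def demo_marker_probabilities(model, text="the cat sits on the mat"):
    input_ids = torch.tensor([qwen_hop_tokenizer.encode(text)])
    probs = compute_token_probabilities(
        model, input_ids, token_id=marker_sg_token, pad_token_id=-1)
    return probs[0]  # None at position 0, then P(marker | prefix) per position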
def merge_part_tokens(words):
result = []
for s in words:
        if result and s in PART_TOKENS:
result[-1] += s
else:
result.append(s)
return result
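# Example of the behavior above: clitic/contraction pieces in PART_TOKENS are
# glued back onto the preceding word before detokenization, e.g.
#   merge_part_tokens(["do", "n't", "you", "'ll"])  ->  ["don't", "you'll"]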
def __affect_hop_word(word):
return word["feats"] and "Person=3" in word["feats"] \
and "Tense=Pres" in word["feats"] \
and "VerbForm=Fin" in word["feats"] \
and "Number" in word["feats"]
def __perturb_hop_words(sent, num_hops, marker_sg, marker_pl):
perturbed_tokens, _ = __perturb_hop_words_complete_hops(
sent, num_hops, marker_sg, marker_pl)
return perturbed_tokens
def check_word_hops_completed(sent, num_hops=4, marker=MARKER_HOP_SING):
_, hops_completed = __perturb_hop_words_complete_hops(
sent, num_hops, marker, marker)
return hops_completed
def __perturb_hop_words_complete_hops(sent, num_hops, marker_sg, marker_pl):
word_annotations = sent["word_annotations"].copy()
word_annotations.reverse()
hop_completed = []
new_sent = []
for word in word_annotations:
# Identify 3.pres verbs
if __affect_hop_word(word):
# Lemmatize verb if possible
new_sent.append(
word["lemma"] if word["lemma"] is not None else word["text"])
# Marker hopping logic
insert_index = len(new_sent)-1
skipped_words = 0
while skipped_words < num_hops and insert_index > 0:
# Handle edge case when punctuation (or sequence of
# punctuation) begin the sentence
if (not any([c.isalnum() for c in
"".join(new_sent[:insert_index])])):
break
                # Yj: true when the string contains no alphanumeric characters
                # (i.e., it is all punctuation, whitespace, etc.)
# Count word as skipped if it is not a special token
if (new_sent[insert_index] not in PART_TOKENS) and \
(not set(new_sent[insert_index]).issubset(PUNCT_TOKENS)):
skipped_words += 1
insert_index -= 1
# Handle edge case when insert index is punctuation (and this is not
# sentence-initial punctuation)
if any([c.isalnum() for c in
"".join(new_sent[:insert_index])]):
while insert_index != 0 and (new_sent[insert_index] in PART_TOKENS
or set(new_sent[insert_index]).issubset(PUNCT_TOKENS)):
insert_index -= 1
# Handle edge case when token before insert index is part/aux token
if insert_index != 0 and new_sent[insert_index-1] in PART_TOKENS:
insert_index -= 1
# Log if this sentence had all full hops
hop_completed.append(skipped_words == num_hops)
# Use correct marker for singular vs. plural
if "Number=Sing" in word["feats"]:
new_sent.insert(insert_index, marker_sg)
elif "Number=Plur" in word["feats"]:
new_sent.insert(insert_index, marker_pl)
else:
raise Exception(
"Number not in verb features\n" + sent["sent_text"])
else:
new_sent.append(word["text"])
new_sent.reverse()
sent_string = " ".join(merge_part_tokens(new_sent))
tokens = qwen_hop_tokenizer.encode(sent_string)
return tokens, all(hop_completed) and len(hop_completed) > 0
def __perturb_hop_tokens(sent, num_hops):
word_annotations = sent["word_annotations"].copy()
word_annotations.reverse()
new_sent = deque()
tokens = []
for word in word_annotations:
# Identify 3.pres verbs
if __affect_hop_word(word):
# Lemmatize verb if possible
lemma = word["lemma"] if word["lemma"] is not None else word["text"]
if len(new_sent) > 0 and new_sent[0] in PART_TOKENS:
lemma = lemma + new_sent[0]
new_sent.popleft()
if len(new_sent) > 0:
sent_string = " ".join(merge_part_tokens(new_sent))
tokens = qwen_hop_tokenizer.encode(
" " + sent_string) + tokens
# Use correct marker for singular vs. plural
if "Number=Sing" in word["feats"]:
tokens.insert(num_hops, marker_sg_token)
elif "Number=Plur" in word["feats"]:
tokens.insert(num_hops, marker_pl_token)
else:
raise Exception(
"Number not in verb features\n" + sent["sent_text"])
new_sent = deque()
new_sent.append(lemma)
else:
new_sent.appendleft(word["text"])
if len(new_sent) > 0:
sent_string = " ".join(merge_part_tokens(new_sent))
tokens = qwen_hop_tokenizer.encode(sent_string) + tokens
return tokens
def __perturb_reverse(sent, rng, reverse, full):
    # Tokenize the sentence text with qwen_rev_tokenizer
tokens = qwen_rev_tokenizer.encode(sent["sent_text"])
# Pick random index to insert REV token
i = rng.choice(len(tokens)+1)
tokens.insert(i, marker_rev_token)
# Extract tokens before/after the marker, and reverse tokens after
tokens_before = tokens[:i+1]
tokens_after = tokens[i+1:]
if reverse:
tokens_after.reverse()
new_tokens = tokens_before + tokens_after
if full:
assert not reverse
new_tokens.reverse()
return new_tokens
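# Worked example of the reverse logic with schematic token ids (not real
# vocabulary ids). With tokens = [A, B, C, D] and a sampled index i = 2:
#   after insertion:           [A, B, R, C, D]   (R = marker_rev_token)
#   reverse=True, full=False:  [A, B, R, D, C]   (suffix after R reversed)
#   reverse=False, full=True:  [D, C, R, B, A]   (whole marked sequence reversed)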
def __perturb_shuffle_deterministic(sent, seed, shuffle):
    # Tokenize the sentence text with the base tokenizer
tokens = qwen_original_tokenizer.encode(sent["sent_text"])
if shuffle:
default_rng(seed).shuffle(tokens)
return tokens
def __perturb_shuffle_nondeterministic(sent, rng):
    # Tokenize the sentence text with the base tokenizer
tokens = qwen_original_tokenizer.encode(sent["sent_text"])
rng.shuffle(tokens)
return tokens
def __perturb_shuffle_local(sent, seed, window=5):
    # Tokenize the sentence text with the base tokenizer
tokens = qwen_original_tokenizer.encode(sent["sent_text"])
# Shuffle tokens in batches of size window
shuffled_tokens = []
for i in range(0, len(tokens), window):
batch = tokens[i:i+window].copy()
default_rng(seed).shuffle(batch)
shuffled_tokens += batch
return shuffled_tokens
def __perturb_shuffle_even_odd(sent):
    # Tokenize the sentence text with the base tokenizer
tokens = qwen_original_tokenizer.encode(sent["sent_text"])
even = [tok for i, tok in enumerate(tokens) if i % 2 == 0]
odd = [tok for i, tok in enumerate(tokens) if i % 2 != 0]
return even + odd
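# Worked examples with schematic token ids t0..t5 (illustrative only):
#   even/odd shuffle:   [t0, t1, t2, t3, t4, t5] -> [t0, t2, t4, t1, t3, t5]
#   local shuffle (window=3): the windows [t0, t1, t2] and [t3, t4, t5] are
#   shuffled independently, each with a freshly seeded default_rng(seed).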
##############################################################################
# AFFECT FUNCTIONS
# These functions define whether a perturbation has been applied to a sentence
# or not. They are used to identify which test sentences have been altered, so
# that affected and unaffected sentences can be separated. Affect functions
# take the input sentence object and return a boolean.
##############################################################################
def affect_hop(sent):
return any([__affect_hop_word(word) for word in sent['word_annotations']]) \
and sent["constituency_parse"] is not None
def affect_reverse(sent):
return True
def affect_shuffle(sent):
return True
def affect_none(sent):
return False
##############################################################################
# FILTER FUNCTIONS
# These functions define when an affected sentence should be included in the
# final dataset. For instance, hop perturbations where the marker is placed
# at the end of the sentence should be excluded. A filter function returns
# True if an affected sentence should be included in the dataset.
##############################################################################
def filter_hop(sent):
# Assertion needed since filter function is only defined for affected
# sentences
assert (affect_hop(sent))
return check_word_hops_completed(sent, 4)
def filter_reverse(sent):
return True
def filter_shuffle(sent):
tokens = qwen_original_tokenizer.encode(sent["sent_text"])
return len(tokens) > 1 and len(tokens) <= 350
def filter_none(sent):
return False
##############################################################################
# PERTURBATION FUNCTIONS
# These functions define how a perturbation transforms a sentence. They take
# in a sentence object (and, for verb transformations, optional markers) and
# return a list of token ids representing the transformed sentence.
##############################################################################
def perturb_hop_words4(sent):
return __perturb_hop_words(sent, 4, MARKER_HOP_SING, MARKER_HOP_PLUR)
def perturb_hop_tokens4(sent):
return __perturb_hop_tokens(sent, 4)
def perturb_hop_control(sent):
return __perturb_hop_tokens(sent, 0)
def perturb_reverse(sent, rng, reverse=True, full=False):
return __perturb_reverse(sent, rng, reverse, full)
def perturb_shuffle_deterministic(sent, seed=None, shuffle=True):
return __perturb_shuffle_deterministic(sent, seed, shuffle)
def perturb_shuffle_nondeterministic(sent, rng):
return __perturb_shuffle_nondeterministic(sent, rng)
def perturb_shuffle_local(sent, seed, window):
return __perturb_shuffle_local(sent, seed, window)
def perturb_shuffle_even_odd(sent):
return __perturb_shuffle_even_odd(sent)
##############################################################################
# PERTURBATIONS
# This dict maps the name of each perturbation to its perturbation, affect,
# and filter functions, plus the tokenizer and plot color used throughout the
# repo.
##############################################################################
PERTURBATIONS = {
"shuffle_control": {
"perturbation_function": partial(perturb_shuffle_deterministic, seed=None, shuffle=False),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#606060",
},
"shuffle_nondeterministic": {
"perturbation_function": partial(perturb_shuffle_nondeterministic, rng=default_rng(0)),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#E8384F",
},
"shuffle_deterministic21": {
"perturbation_function": partial(perturb_shuffle_deterministic, seed=21, shuffle=True),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#FFB000",
},
"shuffle_deterministic57": {
"perturbation_function": partial(perturb_shuffle_deterministic, seed=57, shuffle=True),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#8db000",
},
"shuffle_deterministic84": {
"perturbation_function": partial(perturb_shuffle_deterministic, seed=84, shuffle=True),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#62BB35",
},
"shuffle_local3": {
"perturbation_function": partial(perturb_shuffle_local, seed=0, window=3),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#208EA3",
},
"shuffle_local5": {
"perturbation_function": partial(perturb_shuffle_local, seed=0, window=5),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#4178BC",
},
"shuffle_local10": {
"perturbation_function": partial(perturb_shuffle_local, seed=0, window=10),
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#AA71FF",
},
"shuffle_even_odd": {
"perturbation_function": perturb_shuffle_even_odd,
"affect_function": affect_shuffle,
"filter_function": filter_shuffle,
"qwen_tokenizer": qwen_original_tokenizer,
"color": "#E37CFF",
},
"reverse_control": {
"perturbation_function": partial(perturb_reverse, rng=default_rng(21), reverse=False, full=False),
"affect_function": affect_reverse,
"filter_function": filter_reverse,
"qwen_tokenizer": qwen_rev_tokenizer,
"color": "#606060",
},
"reverse_partial": {
"perturbation_function": partial(perturb_reverse, rng=default_rng(21), reverse=True, full=False),
"affect_function": affect_reverse,
"filter_function": filter_reverse,
"qwen_tokenizer": qwen_rev_tokenizer,
"color": "#E5A836",
},
"reverse_full": {
"perturbation_function": partial(perturb_reverse, rng=default_rng(21), reverse=False, full=True),
"affect_function": affect_reverse,
"filter_function": filter_reverse,
"qwen_tokenizer": qwen_rev_tokenizer,
"color": "#A348A6",
},
"hop_control": {
"perturbation_function": perturb_hop_control,
"affect_function": affect_hop,
"filter_function": filter_hop,
"qwen_tokenizer": qwen_hop_tokenizer,
"color": "#606060",
},
"hop_tokens4": {
"perturbation_function": perturb_hop_tokens4,
"affect_function": affect_hop,
"filter_function": filter_hop,
"qwen_tokenizer": qwen_hop_tokenizer,
"color": "#fa8128",
},
"hop_words4": {
"perturbation_function": perturb_hop_words4,
"affect_function": affect_hop,
"filter_function": filter_hop,
"qwen_tokenizer": qwen_hop_tokenizer,
"color": "#03a0ff",
},
}
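# Hedged end-to-end sketch (not part of the original pipeline): applying one
# named perturbation to a sentence object. The sentence dict keys assumed here
# ("sent_text", "word_annotations", "constituency_parse") are inferred from how
# the functions above consume them.
def demo_apply_perturbation(sent, name="shuffle_even_odd"):
    entry = PERTURBATIONS[name]
    # Skip sentences the perturbation does not affect or that the filter rejects
    if not entry["affect_function"](sent) or not entry["filter_function"](sent):
        return None
    token_ids = entry["perturbation_function"](sent)
    return entry["qwen_tokenizer"].decode(token_ids)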