# rag_highlights / app.py
# Author: m-ric (HF staff)
# Commit 38fce31: "Fix color theme"
import torch
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForCausalLM, attnlrp
from lxt.utils import clean_tokens
import gradio as gr
import numpy as np
import spaces
from scipy.signal import convolve2d
from huggingface_hub import login
import os
from dotenv import load_dotenv
from transformers import BitsAndBytesConfig
# 8-bit quantization settings.
# NOTE(review): this config is not passed to `from_pretrained` below, so the
# model is loaded unquantized in bfloat16. `bnb_8bit_compute_dtype` also does
# not appear to be a documented BitsAndBytesConfig argument (the documented one
# is `bnb_4bit_compute_dtype`) — confirm before wiring this config in.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.bfloat16,
)

# Pull variables such as HF_TOKEN from a local .env file, if one exists.
load_dotenv()
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    # `login(None)` would fall back to an interactive prompt, which blocks a
    # headless deployment; without a token, rely on any cached credentials.
    login(_hf_token)

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
print(f"Loading model {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda", use_safetensors=True)
# model.gradient_checkpointing_enable()
# Register lxt's AttnLRP rules on the model so that its backward passes
# compute AttnLRP relevances (see generate_and_visualize).
attnlrp.register(model)
print("Loaded model.")
def really_clean_tokens(tokens):
    """Turn raw tokenizer pieces into display strings.

    The tokens are first passed through lxt's ``clean_tokens``. In every piece
    it returns, each occurrence of ``_``, ``▁``, ``<s>``, ``Ċ`` and ``Ġ`` is then
    replaced by a single space. Byte-fallback pieces of the form ``<0xNN>`` are
    converted to ``chr(0xNN)``.

    NOTE(review): ``chr`` of one byte value is only the intended character for
    ASCII (< 0x80). Multi-byte UTF-8 characters split across several
    ``<0x..>`` pieces come out as unrelated Latin-1 characters. Literal
    underscores in the text are also turned into spaces.

    Returns:
        One cleaned string per piece returned by ``clean_tokens``, in order.
    """
    space_like = ("_", "▁", "<s>", "Ċ", "Ġ")

    def _display(piece):
        for marker in space_like:
            piece = piece.replace(marker, " ")
        is_byte_piece = piece.startswith("<0x") and piece.endswith(">")
        # Byte-fallback piece: "<0x41>" -> "A".
        return chr(int(piece[3:-1], 16)) if is_byte_piece else piece

    return [_display(piece) for piece in clean_tokens(tokens)]
@spaces.GPU
def generate_and_visualize(prompt, num_tokens=10):
    """Greedily generate up to ``num_tokens`` tokens, recording an input-relevance vector per step.

    At every step the model is run on the embeddings of the whole sequence so
    far. The highest next-token logit is then back-propagated through the model,
    which has had lxt's AttnLRP rules registered at load time. The gradient with
    respect to the input embeddings, summed over its last axis, is stored as
    that step's relevance vector.

    Args:
        prompt: Prompt text; it is tokenized with special tokens added.
        num_tokens: Maximum number of tokens to generate. Generation stops early
            once ``tokenizer.eos_token_id`` is produced.

    Returns:
        ``(input_tokens, all_relevances, generated_tokens)``:
        ``input_tokens`` are the cleaned prompt tokens.
        ``all_relevances`` holds one 1-D CPU float tensor per generated token.
        Each tensor has one entry per token of the sequence at that step
        (prompt plus tokens generated so far), so later entries are longer.
        ``generated_tokens`` are the cleaned generated tokens, including the EOS
        token when it was produced.
    """
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
    # Feed embeddings (not ids) to the model so gradients can be read w.r.t. the inputs.
    input_embeds = model.get_input_embeddings()(input_ids)
    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
    generated_tokens_ids = []
    all_relevances = []
    for _ in range(num_tokens):
        output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
        # Greedy choice: highest logit at the last position of batch item 0.
        max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
        # Seed the backward pass with the logit value itself rather than 1.
        max_logits.backward(max_logits)
        # NOTE(review): `.grad` is only populated on a leaf tensor. If the
        # embedding weights require grad, `input_embeds` is a non-leaf and
        # `.grad` would be None here — confirm parameters have requires_grad
        # disabled, or call `retain_grad()`. If parameters do require grad,
        # their `.grad` buffers also accumulate, since nothing here zeroes them.
        relevance = input_embeds.grad.float().sum(-1).cpu()[0]
        all_relevances.append(relevance)
        # max_indices is 0-d; make it shape (1,) here and (1, 1) for the concat below.
        next_token = max_indices.unsqueeze(0)
        generated_tokens_ids.append(next_token.item())
        # Append the chosen id and re-embed the full sequence. No KV cache is
        # passed, so every step re-runs the forward pass over all tokens.
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        input_embeds = model.get_input_embeddings()(input_ids)
        # The EOS id was already appended above, so it is part of the output.
        if next_token.item() == tokenizer.eos_token_id:
            print("EOS token generated, stopping generation.")
            break
    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
    return input_tokens, all_relevances, generated_tokens
def _largest_true_run(flags):
    """Return ``(width, start)`` of the longest run of truthy entries in *flags*.

    Ties keep the earliest run; returns ``(None, None)`` when no entry is truthy.
    """
    best_width, best_start = None, None
    run_start, run_width = None, 0
    for idx, flag in enumerate(flags):
        if not flag:
            run_width = 0
            continue
        if run_width == 0:
            run_start = idx
        run_width += 1
        # Checked against None, so a run starting at column 0 is also recorded.
        if best_width is None or run_width > best_width:
            best_width, best_start = run_width, run_start
    return best_width, best_start


def _fuse_overlapping_notes(notes, kernel_width):
    """Merge notes of nearby output rows whose input spans overlap or touch.

    Args:
        notes: ``(token, coords)`` pairs. ``coords`` is None or ``(width, start)``
            in convolution-column units.
        kernel_width: Side of the square window used to build the notes.

    Returns:
        ``(token, coords, n_generated)`` triples. A surviving note absorbs later
        notes (up to ``2 * kernel_width - 1`` rows ahead) whose input span
        overlaps or touches its own. Its input span becomes the union of the
        absorbed spans, and ``n_generated`` counts the generated tokens it now
        covers. Absorbed rows and rows without a note become
        ``(token, None, None)``.
    """
    pending = list(notes)  # later rows are cleared in place once absorbed
    fused = []
    # enumerate re-reads each slot, so rows cleared below are seen as None.
    for i, (token, coords) in enumerate(pending):
        if coords is None:
            fused.append((token, None, None))
            continue
        width, start = coords
        end = start + width
        n_generated = kernel_width
        for j in range(i + 1, min(i + 2 * kernel_width, len(pending))):
            next_token, next_coords = pending[j]
            if next_coords is None:
                continue
            next_width, next_start = next_coords
            next_end = next_start + next_width
            # Real interval overlap/adjacency test, in either direction.
            if next_start <= end and next_end >= start:
                pending[j] = (next_token, None)
                start, end = min(start, next_start), max(end, next_end)
                n_generated = kernel_width + (j - i)
        fused.append((token, (end - start, start), n_generated))
    return fused


def _note_text_slices(input_tokens, start, width, kernel_width, context_width):
    """Return ``(before, significant, after)`` prompt text for one note.

    Convolution columns ``[start, start + width)`` each cover ``kernel_width``
    prompt tokens. Together they therefore span tokens
    ``[start, start + width + kernel_width - 1)``. Up to ``context_width``
    tokens on each side are returned as surrounding context.
    """
    sig_start = max(0, start)
    sig_end = start + width + kernel_width - 1
    ctx_start = max(0, sig_start - context_width)
    ctx_end = min(len(input_tokens), sig_end + context_width)
    return (
        "".join(input_tokens[ctx_start:sig_start]),
        "".join(input_tokens[sig_start:sig_end]),
        "".join(input_tokens[sig_end:ctx_end]),
    )


def process_relevances(input_tokens, all_relevances, generated_tokens):
    """Find prompt regions that strongly drove groups of generated tokens.

    The per-step relevance vectors are stacked into a (generated x prompt)
    matrix. The matrix is averaged over ``kernel_width x kernel_width`` windows
    and thresholded. For each window row, the longest run of significant
    columns becomes a note. Notes whose prompt spans overlap are fused, and each
    note is turned into prompt text slices.

    Args:
        input_tokens: Cleaned prompt tokens.
        all_relevances: One 1-D relevance sequence per generated token. Only the
            first ``len(all_relevances[0])`` entries of each are used (the
            prompt length, when produced by ``generate_and_visualize``).
        generated_tokens: Cleaned generated tokens, one per relevance row.

    Returns:
        A list of ``(token, notes, width)``, one per generated token. ``notes`` is
        None or ``(before, significant, after)`` prompt text. ``width`` is the
        number of generated tokens (starting at this one) the note covers, or
        None.
    """
    ### FIND ZONES OF INTEREST
    threshold_per_token = 0.2
    kernel_width = 6
    context_width = 20  # Number of tokens to include as context on each side
    no_notes = [(token, None, None) for token in generated_tokens]

    # Not enough generated tokens to fill a single window.
    if len(generated_tokens) < kernel_width:
        return no_notes
    prompt_len = len(all_relevances[0])
    # convolve2d's 'valid' mode raises when the matrix is narrower than the kernel.
    if prompt_len < kernel_width:
        return no_notes

    # Keep only the prompt columns; later steps also score generated tokens.
    attention_matrix = np.array([rel[:prompt_len] for rel in all_relevances])
    kernel = np.ones((kernel_width, kernel_width))
    # Mean relevance over each kernel_width x kernel_width (generated x prompt) window.
    rolled_sum = convolve2d(attention_matrix, kernel, mode='valid') / kernel_width**2
    # Find where the rolled sum is greater than the threshold
    significant_areas = rolled_sum > threshold_per_token
    print(f"Found {significant_areas.sum()} relevant tokens: lower threshold to find more. Max was {rolled_sum.max()}")
    print("LENGTHS:", len(input_tokens), significant_areas.shape, len(generated_tokens))

    notes = []
    for row in range(len(generated_tokens) - kernel_width + 1):
        width, start = _largest_true_run(significant_areas[row])
        notes.append((generated_tokens[row], None if width is None else (width, start)))
    # The last kernel_width - 1 tokens never start a full window.
    notes += [(token, None) for token in generated_tokens[len(notes):]]

    # Fuse overlapping notes, then convert their column spans to text slices.
    output_with_notes = []
    for token, coords, n_generated in _fuse_overlapping_notes(notes, kernel_width):
        if coords is not None:
            width, start = coords
            coords = _note_text_slices(input_tokens, start, width, kernel_width, context_width)
        output_with_notes.append((token, coords, n_generated))
    return output_with_notes
def create_html_with_hover(output_with_notes):
    """Render generated tokens as HTML, attaching a hover note to each noted span.

    Args:
        output_with_notes: ``(token, notes, width)`` triples, as produced by
            ``process_relevances``. When ``notes`` is None the token is emitted
            as plain text. Otherwise ``notes`` is ``(before, significant, after)``
            text. The ``width`` tokens starting here are rendered as one
            ``hoverable`` span with a numbered superscript and a nested
            ``hover-note`` span, which shows the note with the significant part
            in bold.

    Returns:
        An HTML string wrapped in ``<div id='output-container'>``. All token and
        note text is HTML-escaped, so prompt or model text cannot inject markup.
    """
    from html import escape  # local import; the module name would clash with markup vars

    parts = ["<div id='output-container'>"]
    note_number = 0
    i = 0
    while i < len(output_with_notes):
        token, notes, width = output_with_notes[i]
        if notes is None:
            parts.append(escape(token))
            i += 1
            continue
        # Re-clean the grouped tokens before joining them into the span text.
        text = "".join(really_clean_tokens([element[0] for element in output_with_notes[i:i + width]]))
        print(text)
        first_part, significant_part, final_part = notes
        formatted_note = f'{escape(first_part)}<strong>{escape(significant_part)}</strong>{escape(final_part)}'
        parts.append(f'<span class="hoverable" data-note-id="note-{note_number}">{escape(text)}<sup>[{note_number+1}]</sup>')
        parts.append(f'<span class="hover-note">{formatted_note}</span></span>')
        note_number += 1
        # Always advance, so a malformed zero/negative width cannot loop forever.
        i += max(width, 1)
    parts.append("</div>")
    return "".join(parts)
@spaces.GPU
def on_generate(prompt, num_tokens):
    """Gradio click handler: generate an answer, attribute it and render it as hover-annotated HTML."""
    # (input_tokens, all_relevances, generated_tokens) in process_relevances' parameter order.
    generation = generate_and_visualize(prompt, num_tokens)
    annotated = process_relevances(*generation)
    return create_html_with_hover(annotated)
css = """
#output-container {
font-size: 18px;
line-height: 1.5;
position: relative;
}
.hoverable {
color: var(--primary-500);
position: relative;
display: inline-block;
}
.hover-note {
display: none;
position: absolute;
padding: 5px;
border-radius: 5px;
bottom: 100%;
left: 0;
white-space: normal;
background-color: var(--input-background-fill);
max-width: 600px;
width: 500px;
word-wrap: break-word;
z-index: 100;
}
.hoverable:hover .hover-note {
display: block;
}
"""
# Example prompts offered under the input box. Each follows a
# "Context: ... / Question: ... / Answer:" layout, and the first one is also
# the textbox's default value.
examples = [
"""Context:
The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960.
Question: How many meters above 8000 did the 1922 expedition go?
Answer:""",
"""Context:
Hurricane Katrina killed hundreds of people as it made landfall on New Orleans in 2005 - many of these deaths could have been avoided if alerts had been given one day earlier. Accurate weather forecasts are really life-saving.
🔥 Now, NASA and IBM just dropped a game-changing new model: the first ever foundation model for weather! This means, it's the first time we have a generalist model not restricted to one task, but able to predict 160 weather variables!
Prithvi WxC (Prithvi, "पृथ्वी", is the Sanskrit name for Earth) - is a 2.3 billion parameter model, with an architecture close to previous vision transformers like Hiera.
💡 But it comes with some important tweaks: under the hood, Prithvi WxC uses a clever transformer-based architecture with 25 encoder and 5 decoder blocks. It alternates between "local" and "global" attention to capture both regional and global weather patterns.
Question: How many weather variables can Prithvi predict?
Answer:""",
"""Context:
Transformers v4.45.0 released: includes a lightning-fast method to build tools! ⚡️
During user research with colleagues @MoritzLaurer and @Jofthomas , we discovered that the class definition currently in used to define a Tool in transformers.agents is a bit tedious to use, because it goes in great detail.
➡️ So I've made an easier way to build tools: just make a function with type hints + a docstring, and add a @tool decorator in front.
✅ Voilà, you're good to go!
Question: How can you build tools simply in transformers?
Answer:""",
]
# Build the Gradio UI: a prompt box, a token-count slider and a button that
# renders the annotated answer into an HTML component, plus clickable examples.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RAG with source linking using Source attribution with [LXT](https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama)")
    input_text = gr.Textbox(label="Enter your prompt:", lines=10, value=examples[0])
    num_tokens = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Number of tokens to generate (while no EOS token)")
    generate_button = gr.Button("Generate")
    output_html = gr.HTML(label="Generated Output")
    # On click: generate + attribute via on_generate and show the resulting HTML.
    generate_button.click(
        on_generate,
        inputs=[input_text, num_tokens],
        outputs=[output_html]
    )
    gr.Markdown("Hover over the blue text with superscript numbers to see the important input tokens for that group.")
    # Add clickable examples
    gr.Examples(
        examples=examples,
        inputs=[input_text],
    )
if __name__ == "__main__":
demo.launch()