HUANG-Stephanie commited on
Commit
23477cc
1 Parent(s): aef2bf6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -40
app.py CHANGED
@@ -41,47 +41,13 @@ mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
41
  ds = []
42
  images = []
43
 
44
- @app.post("/index")
45
- def index(files: List[UploadFile] = File(...)):
46
- global ds, images
47
- images = []
48
- ds = []
49
- for file in files:
50
- content = file.read()
51
- pdf_image_list = convert_from_path(io.BytesIO(content))
52
- images.extend(pdf_image_list)
53
-
54
- dataloader = DataLoader(
55
- images,
56
- batch_size=4,
57
- shuffle=False,
58
- collate_fn=lambda x: process_images(processor, x),
59
- )
60
- for batch_doc in dataloader:
61
- with torch.no_grad():
62
- batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
63
- embeddings_doc = model(**batch_doc)
64
- ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
65
-
66
- return {"message": f"Uploaded and converted {len(images)} pages"}
67
 
68
- @app.post("/search")
69
- def search(query: str, k: int):
70
- qs = []
71
- with torch.no_grad():
72
- batch_query = process_queries(processor, [query], mock_image)
73
- batch_query = {k: v.to(device) for k, v in batch_query.items()}
74
- embeddings_query = model(**batch_query)
75
- qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
76
-
77
- retriever_evaluator = CustomEvaluator(is_multi_vector=True)
78
- scores = retriever_evaluator.evaluate(qs, ds)
79
-
80
- top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
81
-
82
- results = [{"page": idx, "image": "image_placeholder"} for idx in top_k_indices]
83
-
84
- return {"results": results}
85
 
86
  # Rediriger la racine vers /docs
87
  @app.get("/")
 
41
  ds = []
42
  images = []
43
 
44
+ # Initialiser le pipeline de génération de texte
45
+ pipe = pipeline("text2text-generation", model="google/flan-t5-small")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ @app.get("/generate")
48
+ def generate(text: str):
49
+ output = pipe(text)
50
+ return {"output": output[0]["generated_text"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Rediriger la racine vers /docs
53
  @app.get("/")