# Page header captured together with this file: author dpfried, commit 78c9772
# ("Duplicate from facebook/incoder-demo").
import sys
from typing import List
import traceback
import os
import base64
import logging
logging.basicConfig(level=logging.INFO)
import modules.cloud_logging
import tokenizers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import pprint
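# NOTE(review): this comment is stranded -- transformers is already imported above.
# It presumably refers to `modules.cloud_logging`, which needs to be imported
# *before* transformers; keep that import order.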
# A file named 'debug' in the working directory switches to the small,
# CPU-only configuration; without it the big model is used on the GPU.
# Both flags are always equal: True in production, False in debug mode.
BIG_MODEL = CUDA = not os.path.exists('debug')
# from flask import Flask, request, render_template
# from flask_cors import CORS
# app = Flask(__name__, static_folder='static')
# app.config['TEMPLATES_AUTO_RELOAD'] = Tru
# CORS(app, resources= {
# r"/generate": {"origins": origins},
# r"/infill": {"origins": origins},
# })
# origins=[f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"]
PORT = 7860
VERBOSE = False
# Total token budget (prompt + generation) per document. A file named 'unlock'
# in the working directory lifts the demo's small default cap.
MAX_LENGTH = 2048 if os.path.exists('unlock') else 256 + 64
TRUNCATION_MESSAGE = (
    f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
)
# Pick the checkpoint and its from_pretrained() keyword arguments. The 6B
# model is loaded from its "float16" revision in half precision, with
# low_cpu_mem_usage to limit peak host memory while loading.
model_name = "facebook/incoder-6B" if BIG_MODEL else "facebook/incoder-1B"
kwargs = (
    {
        "revision": "float16",
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
    }
    if BIG_MODEL
    else {}
)
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
# ASGI application. The auto-generated API docs (/docs, /redoc) are disabled,
# and the front-end assets in ./static are served under /static.
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load the model and tokenizer once, at import time; the request handlers
# below all use these module-level objects.
logging.info("loading model")
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
logging.info("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
logging.info("loading complete")
# Move the model to the GPU in half precision. In this file CUDA is only
# enabled together with BIG_MODEL, whose kwargs already request float16
# weights, so .half() is presumably a harmless safeguard there.
# NOTE(review): "loading complete" is logged before this GPU transfer runs.
if CUDA:
    model = model.half().cuda()
# Special strings of the InCoder vocabulary.
BOS = "<|endoftext|>"  # leading token that generate() strips from decoded output
EOM = "<|endofmask|>"  # terminates an infilled span; infill() cuts generations here


def make_sentinel(i):
    """Return the text of mask sentinel number ``i``, e.g. ``<|mask:0|>``."""
    return f"<|mask:{i}|>"


# Mask sentinels 0..255, followed by the end-of-mask marker.
SPECIAL_TOKENS = list(map(make_sentinel, range(256))) + [EOM]
def generate(input, length_limit=None, temperature=None):
    """Sample a left-to-right continuation of ``input`` from the model.

    Args:
        input: Prompt text (may contain InCoder sentinel tokens).
        length_limit: Maximum number of new tokens to generate. ``None`` means
            "as many as MAX_LENGTH allows".
        temperature: Sampling temperature forwarded to ``model.generate``
            (nucleus sampling with top_p=0.95).

    Returns:
        ``(text, truncated)``. ``text`` is the decoded prompt followed by the
        continuation, with a leading BOS removed. ``truncated`` is True when
        the requested length had to be capped at MAX_LENGTH. If the prompt
        alone already fills the budget, ``input`` is returned unchanged with
        ``truncated=True``.
    """
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    current_length = input_ids.flatten().size(0)
    # max_length counts prompt + new tokens, as model.generate expects.
    if length_limit is None:
        max_length = MAX_LENGTH
    else:
        max_length = current_length + length_limit
    truncated = False
    if max_length > MAX_LENGTH:
        max_length = MAX_LENGTH
        truncated = True
    # `<=` (not `==`): a prompt longer than MAX_LENGTH would otherwise ask
    # model.generate for a max_length shorter than the prompt itself.
    if max_length <= current_length:
        return input, True
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=max_length)
    # Keep tokenizer spacing exactly as generated: the default clean-up removes
    # spaces before punctuation (e.g. "from ." -> "from."), which corrupts code
    # and breaks callers that slice the decoded text at len(prompt).
    detok_hypo_str = tokenizer.decode(output.flatten(), clean_up_tokenization_spaces=False)
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str, truncated
def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
    """Generate text for each gap between consecutive ``parts``.

    The prompt is the parts joined by mask sentinels
    (``part0 <|mask:0|> part1 <|mask:1|> ...``), plus a trailing sentinel after
    the last part when ``extra_sentinel`` is set and there is more than one
    part. Each gap ``k`` is then filled in order. The gap's sentinel is
    appended to the prompt and the model is sampled. The continuation, up to
    and including EOM, is taken as infill ``k`` and appended to the prompt
    before the next gap.

    Args:
        parts: Non-empty list of text fragments. One infill is generated
            between each adjacent pair.
        length_limit: Maximum new tokens per infill, forwarded to ``generate``.
        temperature: Sampling temperature, forwarded to ``generate``.
        extra_sentinel: Also put a sentinel after the last part in the prompt.
        max_retries: Maximum number of attempts. A new attempt is made when
            some infill did not end with EOM. Values below 1 are treated as 1.

    Returns:
        A dict from the final attempt with these keys:
        'text' (parts and infills stitched together), 'parts', 'infills',
        'retries_attempted', and 'truncated' (True if any ``generate`` call
        hit the MAX_LENGTH cap).

    Raises:
        AssertionError: if ``parts`` is not a list.
        ValueError: if ``parts`` is empty.
    """
    assert isinstance(parts, list)
    if not parts:
        raise ValueError("infill() needs at least one part")
    # Always make at least one attempt: with max_retries < 1 the loop would
    # never run and the result variables below would be unbound.
    max_retries = max(1, max_retries)
    retries_attempted = 0
    done = False
    while not done and retries_attempted < max_retries:
        any_truncated = False
        retries_attempted += 1
        if VERBOSE:
            logging.info("retry %d", retries_attempted)
        if len(parts) == 1:
            prompt = parts[0]
        else:
            # encode parts separated by sentinels
            pieces = []
            for sentinel_ix, part in enumerate(parts):
                pieces.append(part)
                if extra_sentinel or sentinel_ix < len(parts) - 1:
                    pieces.append(make_sentinel(sentinel_ix))
            prompt = ''.join(pieces)
        infills = []
        complete = []
        done = True
        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            # Re-emit this gap's sentinel; the model continues with its content.
            prompt += make_sentinel(sentinel_ix)
            completion, this_truncated = generate(prompt, length_limit, temperature)
            any_truncated |= this_truncated
            # generate() returns prompt + continuation; keep the continuation.
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    logging.info("warning: %s not found", EOM)
                completion += EOM
                # TODO: break inner loop here
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)
    if VERBOSE:
        # logging.info() requires a message; log each section as one record.
        logging.info("generated text:\n%s", prompt)
        logging.info("parts:\n%s", parts)
        logging.info("infills:\n%s", infills)
        logging.info("restitched text:\n%s", text)
    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
        'truncated': any_truncated,
    }
@app.head("/")
@app.get("/")
def index() -> FileResponse:
    """Serve the demo's single-page UI for GET and HEAD requests on the root URL."""
    page = "static/index.html"
    return FileResponse(path=page, media_type="text/html")
@app.get('/generate')
async def generate_maybe(info: str):
    """Handle ``GET /generate?info=...``: left-to-right generation from a prompt.

    ``info`` is a base64url-encoded JSON object. GET doesn't support a body,
    and POST led to CORS issues. The object has the keys ``prompt``,
    ``length`` (maximum new tokens) and ``temperature``.

    Returns a JSON-serializable dict. ``result`` is ``'success'`` (with
    ``text`` and a possibly-empty truncation ``message``) or ``'error'``.
    Undecodable or incomplete requests also yield an ``'error'`` dict rather
    than an unhandled exception.
    """
    prompt = None
    try:
        # Restore base64 '=' padding the client may have stripped; see
        # https://stackoverflow.com/a/9956217/1319683
        decoded = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
        form = json.loads(decoded)
        prompt = form['prompt']
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
    except (ValueError, KeyError, TypeError) as e:
        # ValueError covers bad base64 (binascii.Error), bad UTF-8, bad JSON
        # and non-numeric fields; KeyError/TypeError cover missing keys or a
        # non-object payload.
        logging.exception("could not parse /generate request: %s", e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'prompt': prompt,
    }))
    # NOTE(review): generate() is synchronous, so it blocks the event loop for
    # the whole generation and requests are served one at a time. That may be
    # intended for a single shared GPU model -- confirm before making this a
    # plain `def` (which FastAPI would run in a thread pool).
    try:
        generation, truncated = generate(prompt, length_limit, temperature)
    except Exception as e:
        logging.exception("generation failed: %s", e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
    message = TRUNCATION_MESSAGE if truncated else ''
    return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
@app.get('/infill')
async def infill_maybe(info: str):
    """Handle ``GET /infill?info=...``: fill the ``<infill>`` gaps of a document.

    ``info`` is a base64url-encoded JSON object. GET doesn't support a body,
    and POST led to CORS issues. The object has the keys ``parts`` (the
    document split at each ``<infill>`` marker), ``length`` (maximum new
    tokens per infill) and ``temperature``.

    On success, returns the dict produced by ``infill`` plus ``result``,
    ``type`` and ``message`` keys. Too many gaps, undecodable or incomplete
    requests, and generation failures yield a dict with
    ``result == 'error'``.
    """
    max_retries = 1
    extra_sentinel = True
    max_infills = 3  # demo limit, for efficiency
    try:
        # Restore base64 '=' padding the client may have stripped; see
        # https://stackoverflow.com/a/9956217/1319683
        decoded = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
        form = json.loads(decoded)
        parts = form['parts']
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
        parts_joined = '<infill>'.join(parts)
    except (ValueError, KeyError, TypeError) as e:
        # ValueError covers bad base64 (binascii.Error), bad UTF-8, bad JSON
        # and non-numeric fields; KeyError/TypeError cover missing keys,
        # non-string parts or a non-object payload.
        logging.exception("could not parse /infill request: %s", e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'parts_joined': parts_joined,
    }))
    # N <infill> markers split the document into N + 1 parts.
    if len(parts) > max_infills + 1:
        return {'result': 'error', 'text': ''.join(parts), 'type': 'infill', 'message': f"error: Can't use more than {max_infills} <infill> tokens in this demo (for efficiency)."}
    try:
        generation = infill(parts, length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
    except Exception as e:
        logging.exception("infill failed: %s", e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
    generation['result'] = 'success'
    generation['type'] = 'infill'
    generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
    return generation
if __name__ == "__main__":
    # The previous `app.run(host=..., port=PORT, threaded=False)` is Flask API,
    # left over from the commented-out Flask version of this server. FastAPI
    # (and Starlette underneath) has no `run` method, so it failed with
    # AttributeError. A FastAPI app must be served by an ASGI server; exit with
    # an actionable message instead of a confusing traceback.
    sys.exit(
        "This is a FastAPI (ASGI) app and cannot be started with app.run(); "
        f"serve it with an ASGI server, e.g. `uvicorn <module>:app --host 0.0.0.0 --port {PORT}`."
    )