import logging
import sys
from contextlib import contextmanager
from io import StringIO
from threading import current_thread

import boto3
import sagemaker
import streamlit as st
from sagemaker.huggingface import HuggingFace
# REPORT_CONTEXT_ATTR_NAME comes from an internal module of older Streamlit
# releases; it is used below to detect writes made from the Streamlit thread.
from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME


@contextmanager
def st_redirect(src, dst):
    """Mirror writes to `src` (stdout/stderr) into a Streamlit placeholder."""
    placeholder = st.empty()
    output_func = getattr(placeholder, dst)

    with StringIO() as buffer:
        old_write = src.write

        def new_write(b):
            # Only capture output coming from the Streamlit script thread;
            # everything else keeps going to the original stream.
            if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):
                buffer.write(b)
                output_func(buffer.getvalue())
            else:
                old_write(b)

        try:
            src.write = new_write
            yield
        finally:
            src.write = old_write


@contextmanager
def st_stdout(dst):
    with st_redirect(sys.stdout, dst):
        yield


@contextmanager
def st_stderr(dst):
    with st_redirect(sys.stderr, dst):
        yield


task2script = {
    "text-classification": {
        "entry_point": "run_glue.py",
        "source_dir": "examples/text-classification",
    },
    "token-classification": {
        "entry_point": "run_ner.py",
        "source_dir": "examples/token-classification",
    },
    "question-answering": {
        "entry_point": "run_qa.py",
        "source_dir": "examples/question-answering",
    },
    "summarization": {
        "entry_point": "run_summarization.py",
        "source_dir": "examples/seq2seq",
    },
    "translation": {
        "entry_point": "run_translation.py",
        "source_dir": "examples/seq2seq",
    },
    "causal-language-modeling": {
        "entry_point": "run_clm.py",
        "source_dir": "examples/language-modeling",
    },
    "masked-language-modeling": {
        "entry_point": "run_mlm.py",
        "source_dir": "examples/language-modeling",
    },
}
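
# Example (sketch): for parameter["task"] == "summarization" the mapping above
# resolves to
#     entry_point = "run_summarization.py"
#     source_dir  = "examples/seq2seq"
# i.e. the example training script and its folder inside the transformers
# repository that the estimator checks out via git_config below.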


def train_estimtator(parameter, config):
    # Stream everything printed inside this block into a st.code element.
    with st_stdout("code"):
        logger = logging.getLogger(__name__)
        logging.basicConfig(
            level=logging.getLevelName("INFO"),
            handlers=[logging.StreamHandler(sys.stdout)],
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
        # route logger.info through print so it is captured by the redirect above
        logger.info = print

        # git configuration to download the fine-tuning script
        git_config = {"repo": "https://github.com/huggingface/transformers.git", "branch": "v4.4.2"}

        # pick the fine-tuning script for the selected task
        entry_point = task2script[parameter["task"]]["entry_point"]
        source_dir = task2script[parameter["task"]]["source_dir"]

        # AWS session and IAM configuration
        session = boto3.session.Session(
            aws_access_key_id=config["aws_access_key_id"],
            aws_secret_access_key=config["aws_secret_accesskey"],
            region_name=config["region"],
        )
        sess = sagemaker.Session(boto_session=session)
        iam = session.client(
            "iam", aws_access_key_id=config["aws_access_key_id"], aws_secret_access_key=config["aws_secret_accesskey"]
        )
        role = iam.get_role(RoleName=config["aws_sagemaker_role"])["Role"]["Arn"]
        logger.info(f"role: {role}")

        # the instance type is the middle field of a pipe-separated selection string
        instance_type = config["instance_type"].split("|")[1].strip()
        logger.info(f"instance_type: {instance_type}")

        # hyperparameters passed to the training script; "task" is only needed for the lookup above
        hyperparameters = {
            "output_dir": "/opt/ml/model",
            "do_train": True,
            "do_eval": True,
            "do_predict": True,
            **parameter,
        }
        del hyperparameters["task"]

        # create the Hugging Face estimator
        huggingface_estimator = HuggingFace(
            entry_point=entry_point,
            source_dir=source_dir,
            git_config=git_config,
            base_job_name=config["job_name"],
            instance_type=instance_type,
            sagemaker_session=sess,
            instance_count=config["instance_count"],
            role=role,
            transformers_version="4.4",
            pytorch_version="1.6",
            py_version="py36",
            hyperparameters=hyperparameters,
        )

        # start the training job
        huggingface_estimator.fit()
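

# Minimal calling sketch (hypothetical values, not part of the original app):
# the Streamlit front end would collect these dictionaries from its widgets
# and then invoke train_estimtator. All keys below are taken from the code
# above; the values are placeholders.
#
#     parameter = {
#         "task": "text-classification",
#         "model_name_or_path": "distilbert-base-uncased",
#         "per_device_train_batch_size": 32,
#         "num_train_epochs": 3,
#     }
#     config = {
#         "aws_access_key_id": "...",                        # user-supplied
#         "aws_secret_accesskey": "...",                     # user-supplied
#         "region": "us-east-1",
#         "aws_sagemaker_role": "sagemaker_execution_role",  # hypothetical role name
#         "instance_type": "GPU | ml.p3.2xlarge | ...",      # parsed via split("|")
#         "instance_count": 1,
#         "job_name": "huggingface-sagemaker-train",
#     }
#     train_estimtator(parameter, config)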