{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Incident risk ranker demo\n",
    "\n",
    "Assigns a risk label (e.g. `high risk`) to a free-text incident description using the AutoTrain\n",
    "DistilBERT model `mrosinski/autotrain-distilbert-risk-ranker-1593356256`, and formats the\n",
    "prediction as a pair of display strings via `predict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr  # not used below yet; presumably for a demo UI around `predict`\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_ID = \"mrosinski/autotrain-distilbert-risk-ranker-1593356256\"\n",
    "\n",
    "# Shortest description `predict` will classify: the length of the opening clause\n",
    "# \"A truck narrowly missed a person on a bicycle when they were\" (60 characters).\n",
    "MIN_TEXT_LENGTH = 60"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = pipeline(\"text-classification\", model=MODEL_ID)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample incident"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_text = (\n",
    "    \"A truck narrowly missed a person on a bicycle when they were reversing out of the depot on Friday. \"\n",
    "    \"It was early morning before the sun was up and the cyclist did not have a light. \"\n",
    "    \"Fortunately the driver spotted the rider and braked heavily to avoid a collision.\"\n",
    ")\n",
    "assert len(sample_text) >= MIN_TEXT_LENGTH  # long enough for `predict` to classify"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pipeline returns one `{'label': ..., 'score': ...}` dict per input. When this notebook was\n",
    "previously run, the sample (then wrapped in a list, with slightly different whitespace) was\n",
    "labelled `high risk` with a score of about 0.718."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_pred = pipe(sample_text)[0]\n",
    "raw_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Formatting predictions for display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(text):\n",
    "    \"\"\"Classify an incident description and format the result for display.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    text : str\n",
    "        Free-text incident description.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of (str, str)\n",
    "        For valid input: the title-cased model label (e.g. ``'High Risk'``) and a\n",
    "        string such as ``'Confidence Score: 71.8%'``. For empty/``None`` input, or\n",
    "        input shorter than ``MIN_TEXT_LENGTH`` characters once surrounding\n",
    "        whitespace is stripped: ``('Invalid entry', <hint>)``, without calling\n",
    "        the model.\n",
    "    \"\"\"\n",
    "    # `not text` rejects None and \"\" before .strip() is called\n",
    "    if not text or len(text.strip()) < MIN_TEXT_LENGTH:\n",
    "        return 'Invalid entry', 'Try adding more information to describe the incident'\n",
    "    # truncation=True clips descriptions longer than the model's maximum\n",
    "    # sequence length instead of failing on them\n",
    "    pred = pipe(text, truncation=True)[0]\n",
    "    # `:.1%` scales by 100 and rounds in one step; `round(score, 3) * 100`\n",
    "    # can show float noise, e.g. 56.99999999999999% for a score of 0.57\n",
    "    return pred[\"label\"].title(), f'Confidence Score: {pred[\"score\"]:.1%}'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A full description is classified; one shorter than `MIN_TEXT_LENGTH` is rejected without\n",
    "calling the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict(sample_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "short_result = predict(\"some text\")\n",
    "assert short_result[0] == \"Invalid entry\"\n",
    "short_result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b29df53a373a75f04ac216b720f486bfd73e41a5a0018838dedd490de94cf09c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}