Spaces:
Runtime error
Runtime error
friedrichor
commited on
Commit
•
b03a999
1
Parent(s):
a6d3762
update
Browse files
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import argparse
|
2 |
import gradio as gr
|
3 |
|
@@ -90,15 +92,19 @@ def main(args):
|
|
90 |
|
91 |
|
92 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
93 |
parser = argparse.ArgumentParser()
|
94 |
|
95 |
parser.add_argument('--intent_predict_model_name', type=str, default="t5-base")
|
96 |
-
parser.add_argument('--intent_predict_model_weights_path', type=str, default=
|
97 |
|
98 |
parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium")
|
99 |
-
parser.add_argument('--text_dialog_model_weights_path', type=str, default=
|
100 |
|
101 |
-
parser.add_argument('--text2image_model_weights_path', type=str, default=
|
102 |
|
103 |
parser.add_argument('--device', default="cuda:6")
|
104 |
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
import argparse
|
4 |
import gradio as gr
|
5 |
|
|
|
92 |
|
93 |
|
94 |
if __name__ == "__main__":
|
95 |
+
intent_predict_model_weights_path = os.path.join(sys.path[0], "model_weights/Tiger_t5_base_encoder.pth")
|
96 |
+
text_dialog_model_weights_path = os.path.join(sys.path[0], "model_weights/Tiger_DialoGPT_medium.pth")
|
97 |
+
text2image_model_weights_path = os.path.join(sys.path[0], "model_weights/stable-diffusion-2-1-realistic")
|
98 |
+
|
99 |
parser = argparse.ArgumentParser()
|
100 |
|
101 |
parser.add_argument('--intent_predict_model_name', type=str, default="t5-base")
|
102 |
+
parser.add_argument('--intent_predict_model_weights_path', type=str, default=intent_predict_model_weights_path)
|
103 |
|
104 |
parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium")
|
105 |
+
parser.add_argument('--text_dialog_model_weights_path', type=str, default=text_dialog_model_weights_path)
|
106 |
|
107 |
+
parser.add_argument('--text2image_model_weights_path', type=str, default=text2image_model_weights_path)
|
108 |
|
109 |
parser.add_argument('--device', default="cuda:6")
|
110 |
|