friedrichor commited on
Commit
b03a999
1 Parent(s): a6d3762
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import argparse
2
  import gradio as gr
3
 
@@ -90,15 +92,19 @@ def main(args):
90
 
91
 
92
  if __name__ == "__main__":
 
 
 
 
93
  parser = argparse.ArgumentParser()
94
 
95
  parser.add_argument('--intent_predict_model_name', type=str, default="t5-base")
96
- parser.add_argument('--intent_predict_model_weights_path', type=str, default="model_weights/Tiger_t5_base_encoder.pth")
97
 
98
  parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium")
99
- parser.add_argument('--text_dialog_model_weights_path', type=str, default="model_weights/Tiger_DialoGPT_medium.pth")
100
 
101
- parser.add_argument('--text2image_model_weights_path', type=str, default="model_weights/stable-diffusion-2-1-realistic")
102
 
103
  parser.add_argument('--device', default="cuda:6")
104
 
 
1
+ import os
2
+ import sys
3
  import argparse
4
  import gradio as gr
5
 
 
92
 
93
 
94
  if __name__ == "__main__":
95
+ intent_predict_model_weights_path = os.path.join(sys.path[0], "model_weights/Tiger_t5_base_encoder.pth")
96
+ text_dialog_model_weights_path = os.path.join(sys.path[0], "model_weights/Tiger_DialoGPT_medium.pth")
97
+ text2image_model_weights_path = os.path.join(sys.path[0], "model_weights/stable-diffusion-2-1-realistic")
98
+
99
  parser = argparse.ArgumentParser()
100
 
101
  parser.add_argument('--intent_predict_model_name', type=str, default="t5-base")
102
+ parser.add_argument('--intent_predict_model_weights_path', type=str, default=intent_predict_model_weights_path)
103
 
104
  parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium")
105
+ parser.add_argument('--text_dialog_model_weights_path', type=str, default=text_dialog_model_weights_path)
106
 
107
+ parser.add_argument('--text2image_model_weights_path', type=str, default=text2image_model_weights_path)
108
 
109
  parser.add_argument('--device', default="cuda:6")
110