"""Gradio demo: design an RNA sequence for a target secondary structure.

The user enters a target structure string; the OmniGenome design model
proposes a sequence, and ViennaRNA renders the sequence/structure pair as an
SVG plot that is shown next to the designed sequence.
"""

import multiprocessing
import os  # For file operations
import tempfile  # For handling temporary files

import gradio as gr
import RNA  # ViennaRNA library for RNA structure plotting
import torch
from omnigenome import OmniGenomeModelForRNADesign  # Assuming this is where the model class is defined

# Longest target structure accepted by the online demo.  The validation check,
# the error message and the UI text all derive from this single value so they
# cannot drift apart.
MAX_STRUCTURE_LENGTH = 50

# Example structures shown in the UI.  Built line by line so the Markdown list
# keeps its line breaks regardless of how this file is indented.
_EXAMPLES_MARKDOWN = "\n".join(
    [
        "### Example RNA Structures:",
        "",
        "- `(((((......)))))`",
        "- `((((((.((((....))))))).)))..........`",
        "- `((....)).((....))`",
        "- `.(((((((((((...)))))....)))))).`",
        "- `..((((((((.....))))((((.....))))))))..`",
    ]
)

# Initialize the model for RNA design
model = OmniGenomeModelForRNADesign(model_path="anonymous8/OmniGenome-186M")
model.to("cuda" if torch.cuda.is_available() else "cpu")


def design_rna(target_structure):
    """Design an RNA sequence for ``target_structure`` and plot the result.

    Args:
        target_structure: Target secondary structure as entered by the user.
            Leading and trailing whitespace is ignored.

    Returns:
        A ``(text, plot_path)`` tuple.  On success ``text`` is the designed
        sequence and ``plot_path`` is the path of an SVG plot.  When the input
        is rejected, or no sequence is produced, ``text`` is an explanatory
        message and ``plot_path`` is ``None``.
    """
    # Normalise once so validation, design and plotting all see the same
    # string; stray whitespace must not count towards the length limit nor
    # reach the plotter.
    structure = (target_structure or "").strip()
    if not 0 < len(structure) <= MAX_STRUCTURE_LENGTH:
        return (
            "The online demo only supports RNA structures with 1 to "
            f"{MAX_STRUCTURE_LENGTH} characters.",
            None,
        )

    # Run the genetic algorithm to design RNA sequences
    best_sequences = model.run_rna_design(
        structure=structure,
        mutation_ratio=0.5,
        num_population=50,
        num_generation=100,
    )

    # NOTE(review): the result is presumably a list of candidate sequences,
    # but a bare string would make ``[0]`` silently return a single character,
    # so that case is handled explicitly -- confirm against the model's API.
    if isinstance(best_sequences, str):
        best_sequence = best_sequences
    elif not best_sequences:
        return "No RNA sequence could be designed for this structure.", None
    else:
        # Select the best sequence (assuming it's the first one)
        best_sequence = best_sequences[0]

    # Generate the RNA secondary structure plot
    plot_path = plot_rna_structure(best_sequence, structure)
    return best_sequence, plot_path


def plot_rna_structure(sequence, structure):
    """Render ``sequence`` folded as ``structure`` to an SVG file.

    Args:
        sequence: RNA sequence to draw.
        structure: Secondary structure string for ``sequence``.

    Returns:
        Path of the SVG file that was written.  The file is created with
        ``delete=False`` and is not removed here, because the caller still
        needs it after this function returns.
    """
    # Only the unique file name is needed; the handle is closed on leaving the
    # ``with`` block so ViennaRNA can write to the path itself.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".svg") as tmpfile:
        plot_path = tmpfile.name

    # Plot RNA structure using ViennaRNA
    RNA.svg_rna_plot(sequence, structure, plot_path)
    return plot_path


# Launch the app
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    # Gradio Interface with vertical layout
    with gr.Blocks() as iface:
        gr.Markdown("# RNA Design with OmniGenome")
        gr.Markdown(
            "Enter a target RNA secondary structure to generate a designed RNA sequence and visualize its structure. "
            f"Please note that the online demo only supports RNA structures with 1 to {MAX_STRUCTURE_LENGTH} bases "
            "due to computational resource shortage. "
            "For larger structures, please run the model locally."
        )
        gr.Markdown(_EXAMPLES_MARKDOWN)

        with gr.Column():
            target_structure_input = gr.Textbox(
                label="Target RNA Secondary Structure",
                placeholder="Enter RNA structure here, e.g., (((((......)))))",
            )
            output_sequence = gr.Textbox(label="Designed RNA Sequence")
            output_plot = gr.Image(type="filepath", label="RNA Structure Plot")

            # Defining the function call on input
            submit_button = gr.Button("Submit")
            submit_button.click(
                design_rna,
                inputs=target_structure_input,
                outputs=[output_sequence, output_plot],
            )

    iface.launch()