poem_analysis / graph_guided_learning.py
esocoder's picture
first commit
996aa19
raw
history blame
No virus
3.03 kB
import os
import streamlit as st
import networkx as nx
import torch
import matplotlib.pyplot as plt
from torch_geometric.utils import from_networkx
from vae_model import create_vae
from utils import read_poems_from_directory
from sklearn.preprocessing import StandardScaler
import numpy as np
from individual_analyzes import analyze_sentiment # Assuming the sentiment analysis function is defined here
def build_poem_graph(poems, sentiment_labels):
    """Build an undirected graph linking poems that share a sentiment label.

    Every poem becomes a node identified by its index in ``poems``. An edge
    joins each pair of distinct poems whose entries in ``sentiment_labels``
    compare equal, so each sentiment class forms a clique.

    Args:
        poems: Sequence of poems; only its length is used here.
        sentiment_labels: Indexable labels aligned with ``poems``
            (``sentiment_labels[i]`` belongs to ``poems[i]``).

    Returns:
        A ``networkx.Graph`` whose nodes are ``0 .. len(poems) - 1``.
    """
    from itertools import combinations

    poem_graph = nx.Graph()
    node_ids = range(len(poems))
    poem_graph.add_nodes_from(node_ids)
    # Edges come purely from equal sentiment labels (not from text
    # similarity). Labels are compared pairwise with ``==`` rather than
    # grouped in a dict so that unhashable label objects keep working.
    poem_graph.add_edges_from(
        (i, j)
        for i, j in combinations(node_ids, 2)
        if sentiment_labels[i] == sentiment_labels[j]
    )
    return poem_graph
def visualize_poem_graph(poem_graph, sentiment_labels):
    """Render the poem graph in the Streamlit app, colouring nodes by sentiment.

    Nodes whose label equals ``'positive'`` are drawn sky-blue; every other
    label is drawn light-coral. Node ids are drawn as text labels.

    Args:
        poem_graph: ``networkx`` graph, e.g. as built by ``build_poem_graph``.
        sentiment_labels: Labels aligned with the graph's node order.
    """
    colors = ['skyblue' if label == 'positive' else 'lightcoral' for label in sentiment_labels]
    # Draw on a dedicated figure rather than pyplot's implicit global one.
    # Streamlit re-executes the script on every interaction and does not
    # clear a figure it was handed, so reusing the global figure lets drawings
    # pile up across reruns (and is shared between sessions).
    fig, ax = plt.subplots()
    try:
        pos = nx.spring_layout(poem_graph)
        nx.draw_networkx_nodes(poem_graph, pos, ax=ax, node_size=200, node_color=colors)
        nx.draw_networkx_edges(poem_graph, pos, ax=ax, edge_color='gray')
        nx.draw_networkx_labels(poem_graph, pos, ax=ax, font_size=10)
        ax.axis('off')
        st.pyplot(fig)
    finally:
        # Release the figure so repeated reruns do not leak memory.
        plt.close(fig)
def _extract_features(poems):
    """Return a placeholder feature matrix with one row per poem.

    Each poem is represented by a single feature, ``len(poem)``, so the
    result has shape ``(len(poems), 1)``.
    """
    return np.array([[len(poem)] for poem in poems])


def graph_guided_learning_page():
    """Render the "Graph Guided Learning" Streamlit page.

    Reads poems from ``./poems`` and labels them with ``analyze_sentiment``.
    Trains the model from ``create_vae`` on simple scaled features, then draws
    a graph linking poems with equal sentiment labels. Finally it displays the
    encoder's latent features.

    Shows an error if ``./poems`` is not a directory, and a warning if no
    poems were read from it.
    """
    st.header("Graph Guided Learning")

    poems_directory = "./poems"
    if not os.path.isdir(poems_directory):
        st.error("The specified path is not a valid directory.")
        return

    poems = read_poems_from_directory(poems_directory)
    if not poems:
        st.warning("No poems found in the specified directory.")
        return

    sentiment_labels = analyze_sentiment(poems)

    # Placeholder features, standardised per column by StandardScaler.
    features = _extract_features(poems)
    scaled_features = StandardScaler().fit_transform(features)

    # Train the VAE to reconstruct its own input, then encode every poem.
    input_dim = scaled_features.shape[1]
    latent_dim = 16
    vae, encoder = create_vae(input_dim, latent_dim)
    # Keras-style fit() raises ValueError when validation_split would leave
    # the training set empty (which happens with a single sample). Skip the
    # split in that case instead of crashing the page.
    validation_split = 0.2 if len(scaled_features) > 1 else 0.0
    vae.fit(scaled_features, scaled_features, epochs=50, batch_size=256, validation_split=validation_split)
    # NOTE(review): assumes encoder.predict returns a single array of shape
    # (n_poems, latent_dim). An encoder with several outputs (e.g. mean,
    # log-variance, sample) would return a list instead; confirm against
    # create_vae.
    latent_features = encoder.predict(scaled_features)

    # Build and draw a graph that links poems with equal sentiment labels.
    poem_graph = build_poem_graph(poems, sentiment_labels)
    visualize_poem_graph(poem_graph, sentiment_labels)

    # Convert to PyTorch Geometric format, with latent features as node features.
    # TODO(review): `data` is not used further on this page yet.
    data = from_networkx(poem_graph)
    data.x = torch.tensor(latent_features, dtype=torch.float32)

    st.write("Latent Features:")
    st.write(latent_features)