poem_analysis / graph_guided_learning.py
esocoder's picture
first commit
996aa19
raw
history blame contribute delete
No virus
3.03 kB
import os
import streamlit as st
import networkx as nx
import torch
import matplotlib.pyplot as plt
from torch_geometric.utils import from_networkx
from vae_model import create_vae
from utils import read_poems_from_directory
from sklearn.preprocessing import StandardScaler
import numpy as np
from individual_analyzes import analyze_sentiment # Assuming the sentiment analysis function is defined here
def build_poem_graph(poems, sentiment_labels):
    """Build an undirected graph linking poems that share a sentiment label.

    One node is created per poem, identified by the poem's index. An edge
    joins every pair of distinct poems whose entries in ``sentiment_labels``
    compare equal. The poem texts themselves are not inspected; only
    ``len(poems)`` is used.

    Args:
        poems: Sized collection of poems.
        sentiment_labels: Indexable collection where position ``i`` holds
            the label for poem ``i``.

    Returns:
        A ``networkx.Graph`` with nodes ``0 .. len(poems) - 1``.
    """
    poem_count = len(poems)

    graph = nx.Graph()
    graph.add_nodes_from(range(poem_count))

    # Each unordered pair (first, second) with first < second is visited
    # once, in the same order a nested index loop would produce.
    same_label_pairs = [
        (first, second)
        for first in range(poem_count)
        for second in range(first + 1, poem_count)
        if sentiment_labels[first] == sentiment_labels[second]
    ]
    graph.add_edges_from(same_label_pairs)
    return graph
def visualize_poem_graph(poem_graph, sentiment_labels):
    """Draw the poem graph and render it in the Streamlit page.

    Nodes are coloured ``'skyblue'`` when their label equals ``'positive'``
    and ``'lightcoral'`` for any other label. Colours are assigned in the
    order of ``sentiment_labels``, so that order is expected to match the
    graph's node order.

    Args:
        poem_graph: The ``networkx`` graph to draw.
        sentiment_labels: Iterable of labels, one per node.

    Returns:
        None. The figure is emitted via ``st.pyplot``.
    """
    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure. The implicit figure is shared process-wide and is not cleared
    # by default when something is passed to st.pyplot, so reusing it lets
    # drawings from earlier reruns pile up underneath the new one.
    fig, ax = plt.subplots()
    try:
        pos = nx.spring_layout(poem_graph)
        colors = [
            'skyblue' if label == 'positive' else 'lightcoral'
            for label in sentiment_labels
        ]
        nx.draw_networkx_nodes(
            poem_graph, pos, node_size=200, node_color=colors, ax=ax
        )
        nx.draw_networkx_edges(poem_graph, pos, edge_color='gray', ax=ax)
        nx.draw_networkx_labels(poem_graph, pos, font_size=10, ax=ax)
        ax.axis('off')
        st.pyplot(fig)
    finally:
        # Release the figure so repeated page renders do not leak memory.
        plt.close(fig)
def graph_guided_learning_page():
    """Render the "Graph Guided Learning" Streamlit page.

    Steps, in order:
      1. Read poems from ``./poems``; show an error if that path is not a
         directory, or a warning if no poems are returned.
      2. Obtain sentiment labels via ``analyze_sentiment``.
      3. Build a one-column feature matrix (poem text length), standardise
         it, fit the model returned by ``create_vae`` on it and compute
         latent features with the returned encoder.
      4. Build and draw the sentiment graph.
      5. Convert the graph to a PyTorch Geometric object, attach the latent
         features as ``x``, and display the latent features.

    Returns:
        None.
    """
    st.header("Graph Guided Learning")

    poems_directory = "./poems"
    if not os.path.isdir(poems_directory):
        st.error("The specified path is not a valid directory.")
        return

    poems = read_poems_from_directory(poems_directory)
    if not poems:
        st.warning("No poems found in the specified directory.")
        return

    # Sentiment labels drive both the graph edges and the node colours.
    sentiment_labels = analyze_sentiment(poems)

    # Placeholder features: each poem is described only by its length.
    raw_features = np.array([[len(text)] for text in poems])
    scaled_features = StandardScaler().fit_transform(raw_features)

    # Train the autoencoder to reconstruct its own input, then encode.
    # NOTE(review): create_vae presumably returns Keras-style objects given
    # the fit/predict calls below; confirm against vae_model.
    latent_dim = 16
    input_dim = scaled_features.shape[1]
    autoencoder, encoder = create_vae(input_dim, latent_dim)
    autoencoder.fit(
        scaled_features,
        scaled_features,
        epochs=50,
        batch_size=256,
        validation_split=0.2,
    )
    latent_features = encoder.predict(scaled_features)

    # Graph construction and rendering.
    poem_graph = build_poem_graph(poems, sentiment_labels)
    visualize_poem_graph(poem_graph, sentiment_labels)

    # PyTorch Geometric representation with latent vectors as node features.
    # The resulting object is not used further within this function.
    graph_data = from_networkx(poem_graph)
    graph_data.x = torch.tensor(latent_features, dtype=torch.float32)

    st.write("Latent Features:")
    st.write(latent_features)