{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
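"import os\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from tensorflow.keras.utils import pad_sequences\n",
"\n",
"file_path = \"words_new.txt\"\n",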
"with open(file_path) as f:\n",
" lines = f.readlines()\n",
"\n",
"label_raw = lines[18:] # Skipping the initial lines that are comments\n",
"\n",
"image_texts = []\n",
"image_paths = []\n",
"default_path = \"iam_words/words/\"\n",
"for label in label_raw:\n",
" parts = label.strip().split() # Using strip() to remove any leading/trailing whitespaces\n",
" if len(parts) < 9: # Check if the line has fewer parts than expected\n",
" print(f\"Skipping line due to unexpected format: {label}\")\n",
" continue # Skip this iteration and move to the next line\n",
" if parts[1] == \"ok\":\n",
" image_texts.append(parts[-1])\n",
" image_id = parts[0]\n",
" subdir1 = image_id.split(\"-\")[0]\n",
" subdir2 = f\"{subdir1}-{image_id.split('-')[1]}\"\n",
" image_paths.append(os.path.join(default_path, subdir1, subdir2, f\"{image_id}.png\"))\n"
]
},
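{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parsing loop above can be packaged as a small function that is easy to test on single lines. This is a sketch, not the notebook's code: it assumes the 9-field IAM layout the loop relies on (word id, segmentation result, graylevel, four bounding-box values, grammatical tag, transcription), `parse_word_line` is a hypothetical helper name, and the sample line is illustrative.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"def parse_word_line(line, default_path='iam_words/words/'):\n",
"    # Return (transcription, image_path) for a word whose segmentation is 'ok', else None\n",
"    if line.startswith('#'):\n",
"        return None  # header comment line\n",
"    parts = line.strip().split()\n",
"    if len(parts) < 9 or parts[1] != 'ok':\n",
"        return None  # malformed line, or segmentation not marked 'ok'\n",
"    image_id = parts[0]                          # e.g. a01-000u-00-00\n",
"    subdir1, form_id = image_id.split('-')[:2]   # e.g. a01 and 000u\n",
"    subdir2 = f'{subdir1}-{form_id}'             # e.g. a01-000u\n",
"    path = os.path.join(default_path, subdir1, subdir2, f'{image_id}.png')\n",
"    return parts[-1], path\n",
"\n",
"print(parse_word_line('a01-000u-00-00 ok 154 408 768 27 51 AT A'))\n",
"```\n",
"\n",
"Skipping lines that start with `#` removes the need to hard-code the number of header lines in the `lines[18:]` slice."
]
},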
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"image_texts=image_texts\n",
"image_paths=image_paths"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"word_lengths = [len(word) for word in image_texts]\n",
"plt.figure(figsize=(10, 6))\n",
"plt.hist(word_lengths, bins=range(1, max(word_lengths) + 2))  # +2 so the longest length gets its own bin\n",
"plt.title('Distribution of Word Lengths')\n",
"plt.xlabel('Word Length')\n",
"plt.ylabel('Frequency')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from collections import Counter\n",
"word_counts = Counter(image_texts)\n",
"common_words = word_counts.most_common(10)\n",
"words, counts = zip(*common_words)\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"plt.bar(words, counts)\n",
"plt.title('Top 10 Most Common Words')\n",
"plt.xlabel('Words')\n",
"plt.ylabel('Frequency')\n",
"plt.xticks(rotation=45)\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The presence of \".\", \",\", \"of\", \"to\", \"and\", \"a\", \"in\" and \"is\" among the top words reflects ordinary written English: punctuation and short function words carry sentence structure, so they occur far more often than content words.\n",
"\n",
"The word-length distribution is right-skewed: most words are short (the median is 4 characters), but a long tail extends to 19 characters. Short words dominate because function words and punctuation recur constantly, while longer words are individually rarer. For handwriting recognition, this means your model needs to handle a wide range of word lengths. Long words are harder for two reasons: every character must be recognized correctly, and the model has to track spatial relationships across a wider image."
]
},
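{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to back up the right-skew observation numerically is to compare the mean with the median. The sketch below uses Pearson's second skewness coefficient, 3 * (mean - median) / std, computed with Python's standard `statistics` module; a positive value indicates a right skew. `skew_summary` is a hypothetical helper, and in the notebook its input would be the `word_lengths` list from the histogram cell.\n",
"\n",
"```python\n",
"import statistics\n",
"\n",
"def skew_summary(lengths):\n",
"    # Pearson's second skewness coefficient: 3 * (mean - median) / standard deviation\n",
"    mean = statistics.mean(lengths)\n",
"    median = statistics.median(lengths)\n",
"    stdev = statistics.pstdev(lengths)  # population standard deviation\n",
"    return mean, median, 3 * (mean - median) / stdev\n",
"\n",
"# Toy lengths with a long right tail; in the notebook, pass word_lengths instead\n",
"print(skew_summary([1, 2, 2, 3, 4, 4, 5, 19]))\n",
"```\n",
"\n",
"With the statistics reported later in the notebook (mean about 4.30, median 4, std about 2.68), the coefficient is roughly 0.34, a mild right skew."
]
},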
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['!', '\"', '#', \"'\", '(', ')', '*', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
]
}
],
"source": [
"\n",
"### get vocabulary for the current dataset\n",
"vocab = set(\"\".join(map(str, image_texts)))\n",
"print(sorted(vocab))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"19"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"max_label_len = max([len(str(text)) for text in image_texts])\n",
"max_label_len\n",
"\n",
"#Output = 19"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The maximum label length among your images' text labels is 19 characters. This matters for several reasons:\n",
"\n",
"Model Input and Output: The maximum label length constrains the output side of your network. For a handwriting recognition model trained with CTC (Connectionist Temporal Classification) loss, the model must emit at least as many time steps as the longest label, plus one extra step for each pair of identical adjacent characters, because CTC needs a blank between them.\n",
"\n",
"Padding Sequences: Labels must be padded to a common length so they can be stacked into a single batch tensor, since every tensor in a batch must share the same shape. Knowing the maximum label length tells you how much padding each label needs.\n",
"\n",
"Performance Considerations: Longer sequences are harder to predict accurately, because there are more positions where an error can occur and because the model must capture longer-range dependencies between characters. Recurrent layers such as LSTM or GRU are a common choice for modeling these dependencies."
]
},
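{
"cell_type": "markdown",
"metadata": {},
"source": [
"The CTC length constraint described above can be sketched in a few lines, assuming a model trained with CTC loss (the model built later in this notebook is not). Each label needs one time step per character plus one blank step between identical adjacent characters. `min_ctc_steps` and `required_output_steps` are hypothetical helpers.\n",
"\n",
"```python\n",
"def min_ctc_steps(label):\n",
"    # One step per character, plus a blank between each pair of equal neighbours\n",
"    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)\n",
"    return len(label) + repeats\n",
"\n",
"def required_output_steps(labels):\n",
"    # The model's output sequence must be at least this long for every label\n",
"    return max(min_ctc_steps(label) for label in labels)\n",
"\n",
"print(min_ctc_steps('hello'))  # 6: the double 'l' needs an extra blank\n",
"print(required_output_steps(['A', 'book', 'letter']))  # 7\n",
"```\n",
"\n",
"Applied to `image_texts`, this gives the minimum output length; it can exceed `max_label_len` when the longest words contain doubled letters."
]
},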
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([24, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76,\n",
" 76, 76])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
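"char_list = sorted(vocab)\n",
"char_to_index = {char: idx for idx, char in enumerate(char_list)}\n",
"\n",
"def encode_to_labels(txt):\n",
"    # Encode each character of the word as its index in char_list\n",
"    dig_lst = []\n",
"    for char in txt:\n",
"        if char in char_to_index:\n",
"            dig_lst.append(char_to_index[char])\n",
"        else:\n",
"            print(f'Unknown character skipped: {char!r}')\n",
"    # Right-pad ('post') to max_label_len; the pad value len(char_list) is one past the last character index\n",
"    return pad_sequences([dig_lst], maxlen=max_label_len, padding='post', value=len(char_list))[0]\n",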
"\n",
"padded_image_texts = list(map(encode_to_labels, image_texts))\n",
"\n",
"padded_image_texts[0]\n",
"\n",
"#Output : array([24, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This step is essential for a few reasons:\n",
"\n",
"Numerical Representation: Deep learning models work with numerical data, so converting your textual labels into a numerical format is a necessary preprocessing step. Encoding each character as its index in a sorted list of all characters (char_list) is a common and effective approach.\n",
"\n",
"Sequence Padding: Padding gives every label sequence the same length, which is required for stacking labels into batches. 'post' padding appends the padding value at the end, so every label starts at position 0 and the real characters stay aligned at the beginning.\n",
"\n",
"Handling Unknown Characters: encode_to_labels handles characters that are not in char_list by printing them and dropping them from the encoding, so the resulting label is shorter than the word. Because char_list is built from the same labels, this cannot happen for the current dataset, but it can for new data. Depending on your application, you might log these instances or reserve a special token in char_list to represent unknown characters.\n",
"\n",
"Preparation for Model Training: This encoding and padding process produces fixed-length integer labels, which is what models with fixed-size label tensors expect, such as Convolutional Neural Networks (CNNs) combined with Recurrent Neural Networks (RNNs) for sequence prediction tasks like handwriting recognition.\n",
"\n"
]
},
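{
"cell_type": "markdown",
"metadata": {},
"source": [
"The unknown-character suggestion above can be sketched as follows: reserve one index for unknown characters and another for padding, and decode by stopping at the first padding index. This is a pure-Python sketch; `encode`, `decode`, `UNK_INDEX` and `PAD_INDEX` are hypothetical names, and the toy vocabulary stands in for `char_list`.\n",
"\n",
"```python\n",
"# Toy vocabulary standing in for char_list\n",
"char_list = sorted(set('helloworld'))\n",
"char_to_index = {char: idx for idx, char in enumerate(char_list)}\n",
"UNK_INDEX = len(char_list)      # reserved for characters outside char_list\n",
"PAD_INDEX = len(char_list) + 1  # reserved for 'post' padding\n",
"\n",
"def encode(text, max_len):\n",
"    # Map characters to indices, truncate at the end, then right-pad\n",
"    codes = [char_to_index.get(char, UNK_INDEX) for char in text][:max_len]\n",
"    return codes + [PAD_INDEX] * (max_len - len(codes))\n",
"\n",
"def decode(codes):\n",
"    # Stop at the first padding index and show unknown characters as '<UNK>'\n",
"    chars = []\n",
"    for code in codes:\n",
"        if code == PAD_INDEX:\n",
"            break\n",
"        chars.append('<UNK>' if code == UNK_INDEX else char_list[code])\n",
"    return ''.join(chars)\n",
"\n",
"encoded = encode('hex', max_len=5)\n",
"print(encoded, decode(encoded))  # [2, 1, 7, 8, 8] he<UNK>\n",
"```\n",
"\n",
"This sketch truncates at the end, whereas `pad_sequences` truncates from the start by default (`truncating='pre'`); with `maxlen` equal to the longest label, as in the notebook, no truncation happens either way."
]
},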
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1500x700 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Word Lengths Statistics:\n",
"count 38305.000000\n",
"mean 4.302415\n",
"std 2.679270\n",
"min 1.000000\n",
"25% 2.000000\n",
"50% 4.000000\n",
"75% 6.000000\n",
"max 19.000000\n",
"dtype: float64\n"
]
}
],
"source": [
"word_lengths = [len(text) for text in image_texts]\n",
"\n",
"# Plotting the distribution of word lengths\n",
"plt.figure(figsize=(15, 7))\n",
"plt.hist(word_lengths, bins=range(1, max(word_lengths) + 2), alpha=0.7)  # one bin per integer length\n",
"plt.title('Distribution of Word Lengths')\n",
"plt.xlabel('Word Length')\n",
"plt.ylabel('Frequency')\n",
"plt.grid(True)\n",
"plt.show()\n",
"\n",
"# Display basic statistics\n",
"print(\"Word Lengths Statistics:\")\n",
"print(pd.Series(word_lengths).describe())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 2000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from collections import Counter\n",
"\n",
"# Flatten the list of all characters in image_texts\n",
"all_chars = [char for text in image_texts for char in text]\n",
"\n",
"# Calculate the frequency of each character\n",
"char_counter = Counter(all_chars)\n",
"\n",
"# Sort characters by frequency\n",
"sorted_chars = sorted(char_counter.items(), key=lambda pair: pair[1], reverse=True)\n",
"\n",
"# Plotting\n",
"plt.figure(figsize=(20, 10))\n",
"plt.bar([pair[0] for pair in sorted_chars], [pair[1] for pair in sorted_chars])\n",
"plt.xlabel('Characters')\n",
"plt.ylabel('Frequency')\n",
"plt.title('Character Frequency Distribution')\n",
"plt.xticks(rotation=90)\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Most common characters:\n",
"[('e', 19616), ('t', 14237), ('a', 12299), ('o', 11704), ('i', 11218), ('n', 11037), ('s', 9880), ('r', 9870), ('h', 8403), ('l', 6383)]\n",
"\n",
"Least common characters:\n",
"[('#', 36), ('6', 28), ('?', 28), ('7', 24), ('!', 21), ('/', 9), ('Q', 6), ('Z', 5), ('X', 4), ('*', 3)]\n"
]
}
],
"source": [
"# Displaying the most and least common characters\n",
"print(\"Most common characters:\")\n",
"print(sorted_chars[:10]) # Top 10\n",
"\n",
"print(\"\\nLeast common characters:\")\n",
"print(sorted_chars[-10:]) # Bottom 10"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 1 corrupt images during re-check.\n"
]
}
],
"source": [
"# Re-checking for any corrupt images after the initial cleanup\n",
"recheck_corrupt_images = []\n",
"\n",
"for path in image_paths:\n",
"    try:\n",
"        # cv2.imread does not raise on unreadable files; it returns None instead\n",
"        img = cv2.imread(path)\n",
"        if img is None:\n",
"            raise ValueError(\"Image not readable\")\n",
"    except Exception:\n",
"        recheck_corrupt_images.append(path)\n",
"\n",
"print(f\"Found {len(recheck_corrupt_images)} corrupt images during re-check.\")\n"
]
},
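{
"cell_type": "markdown",
"metadata": {},
"source": [
"The re-check above only reports unreadable files. A minimal sketch of removing them while keeping `image_paths` and `image_texts` aligned by position; `drop_corrupt` is a hypothetical helper, and it assumes the two lists are parallel, as built in the first cell.\n",
"\n",
"```python\n",
"def drop_corrupt(image_paths, image_texts, corrupt_paths):\n",
"    # Filter both lists together so each remaining path keeps its own label\n",
"    corrupt = set(corrupt_paths)\n",
"    kept = [(path, text) for path, text in zip(image_paths, image_texts) if path not in corrupt]\n",
"    return [path for path, _ in kept], [text for _, text in kept]\n",
"\n",
"paths, texts = drop_corrupt(['a.png', 'b.png', 'c.png'], ['A', 'move', 'to'], ['b.png'])\n",
"print(paths, texts)  # ['a.png', 'c.png'] ['A', 'to']\n",
"```\n",
"\n",
"In the notebook this would be called as `image_paths, image_texts = drop_corrupt(image_paths, image_texts, recheck_corrupt_images)`."
]
},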
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Observations:\n",
"\n",
"Word Lengths Statistics:\n",
"\n",
"| Statistic | Value |\n",
"|---|---|\n",
"| count | 38304 |\n",
"| mean | 4.302371 |\n",
"| std | 2.679291 |\n",
"| min | 1 |\n",
"| 25% | 2 |\n",
"| 50% | 4 |\n",
"| 75% | 6 |\n",
"| max | 19 |\n",
"\n",
"Most common characters:\n",
"\n",
"`[('e', 19615), ('t', 14237), ('a', 12299), ('o', 11703), ('i', 11218), ('n', 11037), ('s', 9880), ('r', 9870), ('h', 8403), ('l', 6381)]`\n",
"\n",
"Least common characters:\n",
"\n",
"`[('#', 36), ('6', 28), ('?', 28), ('7', 24), ('!', 21), ('/', 9), ('Q', 6), ('Z', 5), ('X', 4), ('*', 3)]`\n",
"\n",
"1. Character Distribution: The most frequent characters are common English letters ('e', 't', 'a'), which is expected for English text. The least common characters are special symbols, some digits and rare capital letters ('Q', 'Z', 'X'). This is not surprising, but it warrants attention when designing your model, which must still recognize these infrequent characters adequately.\n",
"2. Word Lengths: The average word length in your dataset is about 4.3 characters, with a standard deviation of about 2.68, so your model must handle a wide range of label lengths. The maximum word length of 19 sets the padded label length and the minimum number of output time steps the model must produce.\n",
"3. Class Imbalance: There is a severe imbalance between the most and least common characters: 'e' appears 19,615 times while '*' appears only 3 times, a ratio of more than 6,000 to 1. This imbalance can lead to a model that performs well on frequent characters but struggles with rare characters and symbols.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Approach to Address Class Imbalance\n",
"One effective strategy is to use a weighted loss function during training, where less frequent classes (characters) are given higher weights, encouraging the model to pay more attention to these classes. This section will outline how to calculate class weights and apply them in a training loop, assuming a hypothetical neural network model setup for context."
]
},
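{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn's `'balanced'` mode, used in the next cell, assigns each class a weight inversely proportional to its frequency. With $N$ character occurrences in total, $K$ distinct characters and $n_c$ occurrences of character $c$:\n",
"\n",
"$$w_c = \\frac{N}{K \\, n_c}$$\n",
"\n",
"As a worked check against the printed weights, assuming $N = 164{,}804$ (the 38,305 words times the mean length 4.302415 reported by the word-length statistics cell) and $K = 76$: for `'*'` with $n_c = 3$, $w = 164804 / (76 \\cdot 3) \\approx 722.82$, which matches index 6; for `'e'` with $n_c = 19616$, $w \\approx 0.1105$, which matches index 54."
]
},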
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 103.26065162907268, 1: 4.93957559045678, 2: 60.235380116959064, 3: 8.571042230081131, 4: 30.978195488721806, 5: 33.88240131578947, 6: 722.8245614035088, 7: 1.43607528755664, 8: 3.8448114968271745, 9: 1.0425354251012147, 10: 240.94152046783626, 11: 11.473405736563631, 12: 15.489097744360903, 13: 34.4202172096909, 14: 37.38747731397459, 15: 58.607396870554766, 16: 38.043397968605724, 17: 77.44548872180451, 18: 90.3530701754386, 19: 60.235380116959064, 20: 44.254564983888294, 21: 28.91298245614035, 22: 42.51909184726522, 23: 77.44548872180451, 24: 4.683528475616687, 25: 6.213391645302368, 26: 7.451799602098029, 27: 13.46878064727035, 28: 8.997816117056125, 29: 10.180627625401533, 30: 7.228245614035088, 31: 6.950236167341431, 32: 5.5037403152551425, 33: 29.303698435277383, 34: 24.92498487598306, 35: 8.121624285432683, 36: 3.700467037901922, 37: 8.960635058721182, 38: 15.379245987308698, 39: 6.905967147167281, 40: 361.4122807017544, 41: 7.972329721362229, 42: 4.93957559045678, 43: 3.3056001283697047, 44: 23.31692133559706, 45: 43.369473684210526, 46: 9.073111649416427, 47: 542.1184210526316, 48: 52.889602053915276, 49: 433.69473684210527, 50: 0.17631300790393742, 51: 0.9149677992449479, 52: 0.4972423031897561, 53: 0.372461986295178, 54: 0.11054617068773075, 55: 0.6296381196894676, 56: 0.7722484630379367, 57: 0.25805946497804666, 58: 0.1933030561785101, 59: 14.456491228070176, 60: 2.3596013973999197, 61: 0.33972641143827764, 62: 0.5829230333899265, 63: 0.19647310720399805, 64: 0.18527628880814476, 65: 0.7657039845376152, 66: 14.751521661296097, 67: 0.21970351410440997, 68: 0.2194811421265715, 69: 0.15231254366864694, 70: 0.5459400010600519, 71: 1.383837705303463, 72: 0.7624731660374565, 73: 8.605054302422724, 74: 0.7879628212974297, 75: 29.70511896178803}\n"
]
}
],
"source": [
"from sklearn.utils.class_weight import compute_class_weight\n",
"import numpy as np\n",
"\n",
"# One label per character occurrence across all words in the dataset\n",
"y = np.array([char for word in image_texts for char in word])\n",
"\n",
"# np.unique returns the classes in sorted order, which matches char_list = sorted(vocab)\n",
"classes = np.unique(y)\n",
"class_weights = compute_class_weight('balanced', classes=classes, y=y)\n",
"\n",
"# Key each weight by the character's index in char_list, the same index encode_to_labels uses\n",
"char_to_index = {char: idx for idx, char in enumerate(char_list)}\n",
"class_weight_dict = {char_to_index[char]: float(weight) for char, weight in zip(classes, class_weights)}\n",
"\n",
"print(class_weight_dict)\n"
]
},
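{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras's `class_weight` argument is designed for one class label per example, so it does not map cleanly onto sequence labels such as the padded arrays from `encode_to_labels`. One way to apply the character weights in a sequence setting, sketched here under that assumption, is to build a per-position weight array from the padded labels, with zero weight on padding; such an array could then scale a per-character loss. `label_position_weights` is a hypothetical helper.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def label_position_weights(padded_labels, class_weight_dict, pad_index):\n",
"    # padded_labels: integer array of shape (batch, max_label_len)\n",
"    labels = np.asarray(padded_labels)\n",
"    weights = np.zeros(labels.shape, dtype=np.float32)\n",
"    for class_index, weight in class_weight_dict.items():\n",
"        weights[labels == class_index] = weight\n",
"    weights[labels == pad_index] = 0.0  # padding never contributes to the loss\n",
"    return weights\n",
"\n",
"toy_labels = np.array([[0, 1, 2, 2], [1, 1, 0, 2]])  # 2 is the padding index here\n",
"print(label_position_weights(toy_labels, {0: 0.5, 1: 2.0}, pad_index=2))\n",
"```\n",
"\n",
"In the notebook, the inputs would be `np.stack(padded_image_texts)`, `class_weight_dict` and `len(char_list)` as the padding index."
]
},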
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\n# Load class weights from JSON file\\nwith open('class_weights.json', 'r') as infile:\\n class_weight_dict_loaded = json.load(infile)\\n\\n# Convert keys back to integers if necessary\\nclass_weight_dict_loaded = {int(k): v for k, v in class_weight_dict_loaded.items()}\\n\""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import json\n",
"\n",
"# Assuming 'class_weight_dict' is your dictionary of class weights\n",
"class_weight_dict_str = {str(k): v for k, v in class_weight_dict.items()}\n",
"\n",
"# Save to JSON file\n",
"with open('class_weights.json', 'w') as outfile:\n",
" json.dump(class_weight_dict_str, outfile)\n",
"\n",
"'''\n",
"# Load class weights from JSON file\n",
"with open('class_weights.json', 'r') as infile:\n",
" class_weight_dict_loaded = json.load(infile)\n",
"\n",
"# Convert keys back to integers if necessary\n",
"class_weight_dict_loaded = {int(k): v for k, v in class_weight_dict_loaded.items()}\n",
"'''"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining the Model Architecture\n",
"\n",
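"Let's define a function to create a VGG-like model, reduced to four convolutional blocks (VGG16 has five). Each block stacks 3x3 convolutional layers followed by max pooling, and the network ends with a Flatten layer, two fully connected layers and a softmax over `num_classes`. Note that this is an image classifier that predicts one class per image: it contains no LSTM layers and does not predict character sequences. With a 224x224 input, the four pooling layers leave a 14x14x512 feature map, so the first `Dense(4096)` layer alone holds about 411 million weights."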
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
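"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout\n",
"\n",
"def build_vgg_model(input_shape=(224, 224, 3), num_classes=76):\n",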
" model = Sequential([\n",
" # First Conv Block\n",
" Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=input_shape),\n",
" Conv2D(64, (3, 3), activation='relu', padding='same'),\n",
" MaxPooling2D((2, 2), strides=(2, 2)),\n",
" \n",
" # Second Conv Block\n",
" Conv2D(128, (3, 3), activation='relu', padding='same'),\n",
" Conv2D(128, (3, 3), activation='relu', padding='same'),\n",
" MaxPooling2D((2, 2), strides=(2, 2)),\n",
" \n",
" # Third Conv Block\n",
" Conv2D(256, (3, 3), activation='relu', padding='same'),\n",
" Conv2D(256, (3, 3), activation='relu', padding='same'),\n",
" Conv2D(256, (3, 3), activation='relu', padding='same'),\n",
" MaxPooling2D((2, 2), strides=(2, 2)),\n",
" \n",
" # Fourth Conv Block\n",
" Conv2D(512, (3, 3), activation='relu', padding='same'),\n",
" Conv2D(512, (3, 3), activation='relu', padding='same'),\n",
" Conv2D(512, (3, 3), activation='relu', padding='same'),\n",
" MaxPooling2D((2, 2), strides=(2, 2)),\n",
" \n",
" # Flatten and Fully Connected Layers\n",
" Flatten(),\n",
" Dense(4096, activation='relu'),\n",
" Dropout(0.5),\n",
" Dense(4096, activation='relu'),\n",
" Dropout(0.5),\n",
" Dense(num_classes, activation='softmax')\n",
" ])\n",
" \n",
" return model\n"
]
},
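{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see where the model's size comes from, the parameter count of `build_vgg_model` can be reproduced by hand. A Conv2D layer with a k x k kernel has k*k*in*out weights plus out biases, and a Dense layer has in*out weights plus out biases; pooling, Flatten and Dropout add none. This pure-Python sketch assumes the default 224x224x3 input and 76 classes; `count_vgg_params` is a hypothetical helper, and its result should agree with `model.summary()`.\n",
"\n",
"```python\n",
"def conv_params(k, c_in, c_out):\n",
"    return k * k * c_in * c_out + c_out  # kernel weights + biases\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    return n_in * n_out + n_out  # weight matrix + biases\n",
"\n",
"def count_vgg_params(height=224, width=224, channels=3, num_classes=76):\n",
"    # Conv blocks from build_vgg_model: (filters, number of conv layers)\n",
"    blocks = [(64, 2), (128, 2), (256, 3), (512, 3)]\n",
"    total, c_in = 0, channels\n",
"    for filters, n_layers in blocks:\n",
"        for _ in range(n_layers):\n",
"            total += conv_params(3, c_in, filters)\n",
"            c_in = filters\n",
"        height, width = height // 2, width // 2  # 2x2 max pooling with stride 2\n",
"    flat = height * width * c_in  # size of the Flatten output\n",
"    first_dense = dense_params(flat, 4096)\n",
"    total += first_dense + dense_params(4096, 4096) + dense_params(4096, num_classes)\n",
"    return total, first_dense\n",
"\n",
"total, first_dense = count_vgg_params()\n",
"print(f'{total:,} parameters, {first_dense:,} in the first Dense layer')\n",
"# 435,773,836 parameters, 411,045,888 in the first Dense layer\n",
"```\n",
"\n",
"About 94% of the parameters sit in the first Dense layer, so replacing Flatten with global pooling, or shrinking that layer, is the most direct way to reduce the model's size."
]
},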
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout\n",
"from tensorflow.keras.optimizers import Adam\n",
"\n",
"\n",
"model = build_vgg_model(input_shape=(224, 224, 3), num_classes=76)\n",
"optimizer = Adam(learning_rate=0.0001)\n",
"model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 115320 images belonging to 76 classes.\n",
"Found 2146 images belonging to 30 classes.\n"
]
}
],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"\n",
"train_datagen = ImageDataGenerator(rescale=1./255)\n",
"val_datagen = ImageDataGenerator(rescale=1./255)\n",
"\n",
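"# flow_from_directory infers one class per immediate subdirectory (a01, a02, ...), so these\n",
"# labels are folder names, not the words or characters written in each image\n",
"train_generator = train_datagen.flow_from_directory(\n",
"    'iam_words/words/',\n",
"    target_size=(224, 224),\n",
"    batch_size=16,\n",
"    class_mode='categorical')\n",
"\n",
"# Caution: 'iam_words/words/a02/' lies inside the training directory, so every validation image\n",
"# is also a training image, and its per-form subfolders define a different class set\n",
"# (30 classes) from the 76 training classes\n",
"validation_generator = val_datagen.flow_from_directory(\n",
"    'iam_words/words/a02/',\n",
"    target_size=(224, 224),\n",
"    batch_size=16,\n",
"    class_mode='categorical')\n"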
]
},
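{
"cell_type": "markdown",
"metadata": {},
"source": [
"The class-inference rule behind the 76 classes reported for the training generator can be mimicked with the standard library. When `classes` is not given, `flow_from_directory` treats each immediate subdirectory as one class and assigns indices in sorted order. This sketch applies a hypothetical helper, `infer_class_indices`, to a toy directory tree (the folder names are illustrative) to show why the two generators above disagree: their class names come from different directory levels.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"def infer_class_indices(directory):\n",
"    # Each immediate subdirectory becomes a class; indices follow sorted order\n",
"    names = sorted(\n",
"        name for name in os.listdir(directory)\n",
"        if os.path.isdir(os.path.join(directory, name))\n",
"    )\n",
"    return {name: index for index, name in enumerate(names)}\n",
"\n",
"# Toy tree mirroring words/<prefix>/<prefix-form>/ from the first cell\n",
"with tempfile.TemporaryDirectory() as root:\n",
"    for sub in ['a02/a02-000', 'a01/a01-000u', 'a01/a01-003']:\n",
"        os.makedirs(os.path.join(root, sub))\n",
"    print(infer_class_indices(root))                       # {'a01': 0, 'a02': 1}\n",
"    print(infer_class_indices(os.path.join(root, 'a02')))  # {'a02-000': 0}\n",
"```\n",
"\n",
"In the toy tree, index 0 means `a01` for the training-style call but `a02-000` for the validation-style call, so validation predictions are scored against unrelated labels."
]
},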
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\DELL\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\trainers\\data_adapters\\py_dataset_adapter.py:120: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
" self._warn_if_super_not_called()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"End of batch 0, loss=178.15127563476562\n",
"\u001b[1m 1/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m251:06:00\u001b[0m 125s/step - accuracy: 0.0000e+00 - loss: 178.1513End of batch 1, loss=207.32989501953125\n",
"\u001b[1m 2/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m591:50:01\u001b[0m 296s/step - accuracy: 0.0000e+00 - loss: 192.7406End of batch 2, loss=199.74803161621094\n",
"\u001b[1m 3/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m636:42:38\u001b[0m 318s/step - accuracy: 0.0000e+00 - loss: 195.0764End of batch 3, loss=174.29595947265625\n",
"\u001b[1m 4/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m622:20:36\u001b[0m 311s/step - accuracy: 0.0039 - loss: 189.8813 End of batch 4, loss=192.43576049804688\n",
"\u001b[1m 5/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m606:30:57\u001b[0m 303s/step - accuracy: 0.0081 - loss: 190.3922End of batch 5, loss=175.5897979736328\n",
"\u001b[1m 6/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m603:38:21\u001b[0m 302s/step - accuracy: 0.0102 - loss: 187.9251End of batch 6, loss=185.0625\n",
"\u001b[1m 7/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m607:14:08\u001b[0m 304s/step - accuracy: 0.0113 - loss: 187.5162End of batch 7, loss=180.98373413085938\n",
"\u001b[1m 8/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m604:02:07\u001b[0m 302s/step - accuracy: 0.0119 - loss: 186.6996End of batch 8, loss=172.88616943359375\n",
"\u001b[1m 9/7208\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m615:07:47\u001b[0m 308s/step - accuracy: 0.0121 - loss: 185.1648"
]
}
],
"source": [
"# Per-character weights computed in the class-weight cell (keys are indices into char_list).\n",
"# Caution: the generators label images by folder name, so these keys do not refer to characters during this fit.\n",
"class_weights = class_weight_dict\n",
"\n",
"from tensorflow.keras.callbacks import LambdaCallback\n",
"\n",
"print_callback = LambdaCallback(\n",
" on_epoch_end=lambda epoch, logs: print(f'End of epoch {epoch+1}, val_loss={logs[\"val_loss\"]}, val_accuracy={logs[\"val_accuracy\"]}'),\n",
" on_batch_end=lambda batch, logs: print(f'End of batch {batch}, loss={logs[\"loss\"]}'),\n",
")\n",
"\n",
"model.fit(\n",
" train_generator,\n",
" epochs=1,\n",
" validation_data=validation_generator,\n",
" class_weight=class_weights,\n",
" verbose=1,\n",
" callbacks=[print_callback]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save('vgg_model.h5')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_loss, val_accuracy = model.evaluate(validation_generator)\n",
"print(f'Validation loss: {val_loss}')\n",
"print(f'Validation accuracy: {val_accuracy}')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.models import load_model\n",
"\n",
"# Load the model\n",
"loaded_model = load_model('vgg_model.h5')\n",
"\n",
"# If you want to continue training (additional_epochs is a placeholder value; set it as needed)\n",
"additional_epochs = 1\n",
"loaded_model.fit(train_generator, epochs=additional_epochs, validation_data=validation_generator)\n",
"\n",
"# For inference\n",
"predictions = loaded_model.predict(test_data) # Assuming test_data is prepared\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
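"predictions = model.predict(test_data)\n",
"# predictions has one softmax row per image; argmax gives the most likely class index\n",
"predicted_classes = predictions.argmax(axis=-1)\n",
"\n",
"# Map class indices back to the subdirectory names flow_from_directory used as classes\n",
"index_to_class = {idx: name for name, idx in train_generator.class_indices.items()}\n",
"predicted_labels = [index_to_class[idx] for idx in predicted_classes]\n"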
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}