devStorm committed on
Commit
25e32a2
1 Parent(s): d123e86

feat: 🎨 color metadata

Browse files
src/benchmarks/get_semistruct.py CHANGED
@@ -8,7 +8,6 @@ def get_semistructured_data(name, root='data/', download_processed=True, **kwarg
8
  categories = ['Sports_and_Outdoors']
9
  kb = AmazonSemiStruct(root=data_root,
10
  categories=categories,
11
- meta_link_types=['brand', 'category'],
12
  download_processed=download_processed,
13
  **kwargs
14
  )
 
8
  categories = ['Sports_and_Outdoors']
9
  kb = AmazonSemiStruct(root=data_root,
10
  categories=categories,
 
11
  download_processed=download_processed,
12
  **kwargs
13
  )
src/benchmarks/semistruct/amazon.py CHANGED
@@ -6,6 +6,7 @@ import json
6
  import torch
7
  import pandas as pd
8
  import numpy as np
 
9
  from tqdm import tqdm
10
  from huggingface_hub import hf_hub_download
11
  import zipfile
@@ -51,19 +52,20 @@ class AmazonSemiStruct(SemiStructureKB):
51
  sub_category = 'data/amazon/stats/category_list.json'
52
  SUB_CATEGORIES = set(json.load(open(sub_category, 'r')))
53
  link_columns = ['also_buy', 'also_view']
54
- review_columns = ['reviewerID', 'summary', 'reviewText', 'vote', 'overall', 'verified', 'reviewTime']
55
  qa_columns = ['questionType', 'answerType', 'question', 'answer', 'answerTime']
56
  meta_columns = ['asin', 'title', 'global_category', 'category', 'price', 'brand', 'feature',
57
  'rank', 'details', 'description']
58
  candidate_types = ['product']
59
  node_attr_dict = {'product': ['title', 'dimensions', 'weight', 'description', 'features', 'reviews', 'Q&A'],
60
  'brand': ['brand_name'],
61
- 'category': ['category_name']}
 
62
 
63
  def __init__(self,
64
  root,
65
  categories: list,
66
- meta_link_types=['category'],
67
  max_entries=25,
68
  download_processed=True,
69
  **kwargs):
@@ -117,10 +119,6 @@ class AmazonSemiStruct(SemiStructureKB):
117
  def __getitem__(self, idx):
118
  idx = int(idx)
119
  node_info = self.node_info[idx]
120
- # try:
121
- # dimensions, weight = node.details.dictionary.product_dimensions.split(' ; ')
122
- # node_info['dimensions'], node_info['weight'] = dimensions, weight
123
- # except: pass
124
  node = Node()
125
  register_node(node, node_info)
126
  return node
@@ -173,6 +171,8 @@ class AmazonSemiStruct(SemiStructureKB):
173
  return f'brand name: {self[idx].brand_name}'
174
  if self.node_type_dict[int(self.node_types[idx])] == 'category':
175
  return f'category name: {self[idx].category_name}'
 
 
176
 
177
  node = self[idx]
178
  doc = f'- product: {node.title}\n'
@@ -370,9 +370,9 @@ class AmazonSemiStruct(SemiStructureKB):
370
  n_e_types, n_n_types = len(edge_type_dict), len(node_type_dict)
371
  for i, link_type in enumerate(meta_link_types):
372
  if link_type == 'brand':
373
- values = np.array([self._process_brand(node_info_i[link_type]) for node_info_i in node_info.values() if link_type in node_info_i.keys()])
374
  indices = np.array([idx for idx, node_info_i in enumerate(node_info.values()) if link_type in node_info_i.keys()])
375
- elif link_type == 'category':
376
  value_list = []
377
  indice_list = []
378
  for idx, node_info_i in enumerate(node_info.values()):
@@ -381,9 +381,6 @@ class AmazonSemiStruct(SemiStructureKB):
381
  indice_list.extend([idx for _ in range(len(node_info_i[link_type]))])
382
  values = np.array(value_list)
383
  indices = np.array(indice_list)
384
- print(f'{link_type=}, {len(values)=}, {len(indices)=}')
385
- # print(values[:50])
386
- print(indices[:50])
387
  else:
388
  raise Exception(f'Invalid meta link type {link_type}')
389
 
@@ -391,13 +388,15 @@ class AmazonSemiStruct(SemiStructureKB):
391
  node_type_dict[n_n_types + i] = link_type
392
  edge_type_dict[n_e_types + i] = "has_" + link_type
393
  unique = np.unique(values)
394
- for j, unique_j in enumerate(unique):
395
  node_info[cur_n_nodes + j] = {link_type + '_name': unique_j}
396
  ids = indices[np.array(values == unique_j)]
397
  edge_index[0].extend(list(ids))
398
  edge_index[1].extend([cur_n_nodes + j for _ in range(len(ids))])
399
  edge_types.extend([i + n_e_types for _ in range(len(ids))])
400
  node_types.extend([n_n_types + i for _ in range(len(unique))])
 
 
401
  edge_index = torch.LongTensor(edge_index)
402
  edge_types = torch.LongTensor(edge_types)
403
  node_types = torch.LongTensor(node_types)
@@ -431,6 +430,72 @@ class AmazonSemiStruct(SemiStructureKB):
431
  node_info[idx]['review'] = []
432
  node_info[idx]['qa'] = []
433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  for i in tqdm(range(len(df_meta))):
435
  df_meta_i = df_meta.iloc[i]
436
  asin = df_meta_i['asin']
@@ -450,7 +515,7 @@ class AmazonSemiStruct(SemiStructureKB):
450
  node_info[idx]['category'] = category_list
451
  else:
452
  node_info[idx][column] = clean_data(df_meta_i[column])
453
-
454
  for name, df in zip(['review', 'qa'], [df_review, df_qa]):
455
  for i in tqdm(range(len(df))):
456
  df_i = df.iloc[i]
@@ -459,6 +524,7 @@ class AmazonSemiStruct(SemiStructureKB):
459
  node_info[idx][name].append(
460
  df_row_to_dict(df_i, colunm_names=self.review_columns \
461
  if name == 'review' else self.qa_columns))
 
462
  return node_info
463
 
464
  def create_raw_product_graph(self, df, columns):
 
6
  import torch
7
  import pandas as pd
8
  import numpy as np
9
+ from collections import Counter
10
  from tqdm import tqdm
11
  from huggingface_hub import hf_hub_download
12
  import zipfile
 
52
  sub_category = 'data/amazon/stats/category_list.json'
53
  SUB_CATEGORIES = set(json.load(open(sub_category, 'r')))
54
  link_columns = ['also_buy', 'also_view']
55
+ review_columns = ['reviewerID', 'summary', 'style', 'reviewText', 'vote', 'overall', 'verified', 'reviewTime']
56
  qa_columns = ['questionType', 'answerType', 'question', 'answer', 'answerTime']
57
  meta_columns = ['asin', 'title', 'global_category', 'category', 'price', 'brand', 'feature',
58
  'rank', 'details', 'description']
59
  candidate_types = ['product']
60
  node_attr_dict = {'product': ['title', 'dimensions', 'weight', 'description', 'features', 'reviews', 'Q&A'],
61
  'brand': ['brand_name'],
62
+ 'category': ['category_name'],
63
+ 'color': ['color_name']}
64
 
65
  def __init__(self,
66
  root,
67
  categories: list,
68
+ meta_link_types=['brand', 'category', 'color'],
69
  max_entries=25,
70
  download_processed=True,
71
  **kwargs):
 
119
  def __getitem__(self, idx):
120
  idx = int(idx)
121
  node_info = self.node_info[idx]
 
 
 
 
122
  node = Node()
123
  register_node(node, node_info)
124
  return node
 
171
  return f'brand name: {self[idx].brand_name}'
172
  if self.node_type_dict[int(self.node_types[idx])] == 'category':
173
  return f'category name: {self[idx].category_name}'
174
+ if self.node_type_dict[int(self.node_types[idx])] == 'color':
175
+ return f'color name: {self[idx].color_name}'
176
 
177
  node = self[idx]
178
  doc = f'- product: {node.title}\n'
 
370
  n_e_types, n_n_types = len(edge_type_dict), len(node_type_dict)
371
  for i, link_type in enumerate(meta_link_types):
372
  if link_type == 'brand':
373
+ values = np.array([node_info_i[link_type] for node_info_i in node_info.values() if link_type in node_info_i.keys()])
374
  indices = np.array([idx for idx, node_info_i in enumerate(node_info.values()) if link_type in node_info_i.keys()])
375
+ elif link_type in ['category', 'color']:
376
  value_list = []
377
  indice_list = []
378
  for idx, node_info_i in enumerate(node_info.values()):
 
381
  indice_list.extend([idx for _ in range(len(node_info_i[link_type]))])
382
  values = np.array(value_list)
383
  indices = np.array(indice_list)
 
 
 
384
  else:
385
  raise Exception(f'Invalid meta link type {link_type}')
386
 
 
388
  node_type_dict[n_n_types + i] = link_type
389
  edge_type_dict[n_e_types + i] = "has_" + link_type
390
  unique = np.unique(values)
391
+ for j, unique_j in tqdm(enumerate(unique)):
392
  node_info[cur_n_nodes + j] = {link_type + '_name': unique_j}
393
  ids = indices[np.array(values == unique_j)]
394
  edge_index[0].extend(list(ids))
395
  edge_index[1].extend([cur_n_nodes + j for _ in range(len(ids))])
396
  edge_types.extend([i + n_e_types for _ in range(len(ids))])
397
  node_types.extend([n_n_types + i for _ in range(len(unique))])
398
+ print(f'finished adding {link_type}')
399
+
400
  edge_index = torch.LongTensor(edge_index)
401
  edge_types = torch.LongTensor(edge_types)
402
  node_types = torch.LongTensor(node_types)
 
430
  node_info[idx]['review'] = []
431
  node_info[idx]['qa'] = []
432
 
433
+ ###################### Assign color ########################
434
def assign_colors(df_review, lower_limit=20):
    """Map each product ASIN to the frequent color names seen in its reviews.

    Every review's ``style`` mapping is scanned for keys containing
    "color" (case-insensitive). The matching values are normalised
    (stripped and lower-cased), de-duplicated per product, and then
    filtered so that only colors attached to enough products and shaped
    like a short color name survive.

    Args:
        df_review: DataFrame with at least the ``asin`` and ``style``
            columns. Rows whose ``style`` is missing are dropped.
            NOTE(review): each remaining ``style`` is assumed to be a dict
            whose values are strings or sequences of strings - confirm
            against the raw review schema.
        lower_limit: a color is kept only when it is linked to strictly
            more than this many distinct products.

    Returns:
        dict mapping ASIN -> list of selected color names. Every ASIN that
        had at least one raw color value gets an entry, so a list may be
        empty when none of that product's colors were selected.
    """
    df_review = df_review[['asin', 'style']].dropna(subset=['style'])

    # Gather the normalised color values mentioned for each product.
    raw_color_dict = {}
    for asin, style in tqdm(zip(df_review['asin'], df_review['style'])):
        for key, raw_value in style.items():
            if 'color' not in key.lower():
                continue
            if not isinstance(raw_value, str):
                # Sequence-valued styles contribute their first element;
                # an empty sequence carries no color at all.
                if not raw_value:
                    continue
                raw_value = raw_value[0]
            # Lower-case every value so the same color is never counted
            # twice under different capitalisation.
            raw_color_dict.setdefault(asin, []).append(raw_value.strip().lower())

    # De-duplicate per product, keeping first-seen order so the result
    # (and any cache pickled from it) is the same on every run.
    all_color_values = []
    for asin, colors in raw_color_dict.items():
        unique_colors = list(dict.fromkeys(colors))
        raw_color_dict[asin] = unique_colors
        all_color_values.extend(unique_colors)

    print('number of all colors', len(all_color_values))
    # Values were de-duplicated per product above, so each count is the
    # number of products carrying that color.
    color_counter = Counter(all_color_values)
    print('number of unique colors', len(color_counter))

    # Keep colors that are frequent enough, longer than two characters,
    # fewer than five space-separated words, and not purely numeric.
    selected_colors = [
        color
        for color, number in color_counter.most_common()
        if number > lower_limit
        and len(color) > 2
        and len(color.split(' ')) < 5
        and not color.isnumeric()
    ]
    print('number of selected colors', len(selected_colors))

    # Set lookup keeps the filtering pass linear in the number of values.
    selected_lookup = set(selected_colors)
    filtered_color_dict = {
        asin: [color for color in colors if color in selected_lookup]
        for asin, colors in raw_color_dict.items()
    }
    total_color_connections = sum(
        len(colors) for colors in filtered_color_dict.values())

    # This count includes products whose colors were all filtered out
    # (they map to an empty list).
    print('number of linked products', len(filtered_color_dict))
    print('number of total connections', total_color_connections)
    return filtered_color_dict
476
+
477
+ filtered_color_dict_path = os.path.join('data/amazon/intermediate',
478
+ 'filtered_color_dict.pkl')
479
+ if os.path.exists(filtered_color_dict_path):
480
+ with open(filtered_color_dict_path, 'rb') as f:
481
+ filtered_color_dict = pickle.load(f)
482
+ else:
483
+ filtered_color_dict = assign_colors(df_review)
484
+ with open(filtered_color_dict_path, 'wb') as f:
485
+ pickle.dump(filtered_color_dict, f)
486
+
487
+ for i in tqdm(range(len(df_meta))):
488
+ df_meta_i = df_meta.iloc[i]
489
+ asin = df_meta_i['asin']
490
+ idx = self.asin2id[asin]
491
+ try:
492
+ color = filtered_color_dict[asin]
493
+ if len(color):
494
+ node_info[idx]['color'] = color
495
+ except: pass
496
+ print('loaded color')
497
+ ####################################################################
498
+
499
  for i in tqdm(range(len(df_meta))):
500
  df_meta_i = df_meta.iloc[i]
501
  asin = df_meta_i['asin']
 
515
  node_info[idx]['category'] = category_list
516
  else:
517
  node_info[idx][column] = clean_data(df_meta_i[column])
518
+
519
  for name, df in zip(['review', 'qa'], [df_review, df_qa]):
520
  for i in tqdm(range(len(df))):
521
  df_i = df.iloc[i]
 
524
  node_info[idx][name].append(
525
  df_row_to_dict(df_i, colunm_names=self.review_columns \
526
  if name == 'review' else self.qa_columns))
527
+ import pdb; pdb.set_trace()
528
  return node_info
529
 
530
  def create_raw_product_graph(self, df, columns):