File size: 3,869 Bytes
db26c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import sys
import time
import torch
import argparse
import utils
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor

from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2


def solve(problem, i, max_time):
    """Solve instance ``i`` of a problem batch in a separate solver process.

    The instance data is pulled out of ``problem`` as plain Python lists and
    handed to ``do_solve`` running in a fresh ``multiprocessing.Process``.
    Distances are scaled by ``scale`` and rounded to int32 because the
    OR-Tools routing solver works with integer arc costs.

    Args:
        problem: Batched problem object exposing ``task_demand``,
            ``worker_weight_limit`` and ``distance_matrix`` tensors. Assumes
            ``task_demand`` is (batch, n_customers) -- TODO confirm.
        i: Index of the instance within the batch.
        max_time: Solver time limit in whole seconds.

    Returns:
        ``(cost, seconds)``: the objective divided by ``scale`` (approximately
        the original distance units, since arcs were rounded), and the solver
        wall-clock time reported by the child.

    Raises:
        RuntimeError: if the child process exits without reporting both
            result items.
    """
    import queue as queue_lib  # only for the queue.Empty exception type

    scale = 100000
    size = problem.task_demand.size(1)
    # Node 0 is the depot and carries no demand.
    demand = [0] + problem.task_demand[i].tolist()
    capacity = problem.worker_weight_limit[i].tolist()
    distance = (problem.distance_matrix[i] * scale + 0.5).to(torch.int32).tolist()

    result_queue = mp.Queue()
    worker = mp.Process(
        target=do_solve,
        args=(size, demand, capacity, distance, max_time, result_queue),
    )
    worker.start()
    try:
        # Drain the queue BEFORE joining: a process that has put items on a
        # multiprocessing.Queue may not terminate until they are flushed and
        # consumed, so join-then-get can deadlock.
        results = []
        while len(results) < 2:
            try:
                results.append(result_queue.get(timeout=1.0))
            except queue_lib.Empty:
                if worker.is_alive():
                    continue
                # The child has exited, so anything it put is already in the
                # pipe; one last short wait separates a late result from a
                # crash that would otherwise make get() block forever.
                try:
                    results.append(result_queue.get(timeout=1.0))
                except queue_lib.Empty:
                    raise RuntimeError(
                        "OR-Tools worker for instance {} exited with code {} "
                        "without reporting a result".format(i, worker.exitcode)
                    ) from None
    except BaseException:
        # Don't leave an orphaned solver running (e.g. on KeyboardInterrupt).
        if worker.is_alive():
            worker.terminate()
        raise
    finally:
        worker.join()

    cost, spent_time = results
    return cost / scale, spent_time


def do_solve(size, demand, capacity, distance, max_time, queue):
    """Solve one capacitated routing instance with OR-Tools; report on ``queue``.

    Exactly two items are always put on ``queue``, in this order:
    1. the routing objective (sum of arc costs taken from ``distance``), or
       ``float('inf')`` if the solver found no solution;
    2. the wall-clock seconds spent inside ``SolveWithParameters``.

    Args:
        size: Number of customer nodes. The model has ``size + 1`` nodes
            (node 0 is the depot) and ``size`` vehicles.
        demand: Per-node demand indexed by node; ``demand[0]`` is the depot.
        capacity: Vehicle capacity as a list. It is repeated ``size`` times,
            so presumably the caller passes a single-element list giving one
            entry per vehicle -- TODO confirm.
        distance: Integer arc costs, indexed ``distance[from_node][to_node]``.
        max_time: Search time limit in whole seconds.
        queue: ``multiprocessing.Queue`` that receives the two result items.
    """
    # AddDimensionWithVehicleCapacity takes one capacity per vehicle.
    capacity = capacity * size

    # size + 1 nodes (depot + customers), size vehicles, depot at node 0.
    manager = pywrapcp.RoutingIndexManager(size + 1, size, 0)
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        # The solver passes internal routing indices; map them back to nodes.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return distance[from_node][to_node]

    distance_callback_index = routing.RegisterTransitCallback(distance_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(distance_callback_index)

    def demand_callback(from_index):
        from_node = manager.IndexToNode(from_index)
        return demand[from_node]

    demand_callback_index = routing.RegisterUnaryTransitCallback(demand_callback)
    # Zero slack; cumulative load starts at zero (fix_start_cumul_to_zero=True).
    routing.AddDimensionWithVehicleCapacity(demand_callback_index, 0, capacity, True, 'capacity')

    params = pywrapcp.DefaultRoutingSearchParameters()
    params.first_solution_strategy = (routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC)
    params.local_search_metaheuristic = (routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH)
    params.time_limit.seconds = max_time

    start_time = time.time()
    solution = routing.SolveWithParameters(params)
    spent_time = time.time() - start_time

    if solution is None:
        # OR-Tools returns None when no assignment was found (infeasible demand,
        # or the time limit expired before a first solution). Raising here
        # would kill this process before anything is put on the queue, leaving
        # a caller blocked in queue.get(); report an infinite cost instead.
        print("warning: OR-Tools found no solution (size={}, time limit {}s)".format(size, max_time),
              file=sys.stderr)
        sys.stderr.flush()
        cost = float('inf')
    else:
        cost = solution.ObjectiveValue()

    queue.put(cost)
    queue.put(spent_time)


def run_orts(task, max_time):
    """Adapter for ``executor.map``: unpack a ``(problem, index)`` task and solve it.

    Returns exactly what ``solve`` returns for that instance.
    """
    batch_problem, instance_index = task
    outcome = solve(batch_problem, instance_index, max_time)
    return outcome


def main(args):
    """Generate random problems, solve each with OR-Tools, and print averages.

    Args:
        args: Parsed command-line namespace providing ``problem_size``,
            ``problem_count``, ``batch_size``, ``threads`` and ``max_time``.

    Raises:
        ValueError: if ``problem_count`` or ``batch_size`` is not positive, or
            ``problem_count`` is not a multiple of ``batch_size``.
    """
    print("args: {}".format(vars(args)))
    problem_size = args.problem_size
    problem_count = args.problem_count
    batch_size = args.batch_size

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    # Non-positive values would otherwise fail later with ZeroDivisionError.
    if problem_count <= 0 or batch_size <= 0:
        raise ValueError(
            "problem_count and batch_size must be positive, got {} and {}".format(
                problem_count, batch_size))
    if problem_count % batch_size != 0:
        raise ValueError(
            "problem_count ({}) must be a multiple of batch_size ({})".format(
                problem_count, batch_size))

    batch_count = problem_count // batch_size
    problem_list = utils.make_problem(batch_count, batch_size, problem_size)

    # One task per instance: (batched problem, index within the batch).
    task_list = [(p, i) for p in problem_list for i in range(batch_size)]

    total_cost = 0
    total_time = 0
    # Each task starts its own solver subprocess; the thread pool bounds how many
    # run at once. The `with` block shuts the pool down once all results are in.
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        for cost, elapse in executor.map(run_orts, task_list, [args.max_time] * len(task_list)):
            total_cost += cost
            total_time += elapse

    avg_cost = total_cost / problem_count
    avg_time = total_time / problem_count
    print()
    print("-----------------------------------------------------")
    print("avg_cost: {:.4f}".format(avg_cost))
    print("avg_time: {:.6f}s".format(avg_time))
    print("total_count: {}".format(problem_count))
    print("-----------------------------------------------------\n")
    sys.stdout.flush()


if __name__ == '__main__':
    # Command-line entry point: solver settings first, then the options that
    # control how many random instances are generated and how they are batched.
    cli = argparse.ArgumentParser()

    # Solver settings.
    cli.add_argument('--threads', type=int, default=20,
                     help='number of threads')
    cli.add_argument('--max_time', type=int, default=60,
                     help='the time limit for the search in seconds')

    # Problem generation.
    cli.add_argument('--problem_size', type=int, default=100,
                     choices=[100, 1000, 2000, 5000], help='problem size')
    cli.add_argument('--problem_count', type=int, default=128,
                     help='total number of generated problem instances')
    cli.add_argument('--batch_size', type=int, default=128,
                     help='batch size for feedforwarding')

    main(cli.parse_args())