File size: 23,797 Bytes
505eceb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
VOCAB = ["-", ".", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "B", "K", "N", "P", "Q", "R", "a", "b", "c", "d", "e", "f", "g", "h", "k", "n", "p", "q", "r", "w"]
ACTION_SPACE = ["a1a2", "a1a3", "a1a4", "a1a5", "a1a6", "a1a7", "a1a8", "a1b1", "a1b2", "a1b3", "a1c1", "a1c2", "a1c3", "a1d1", "a1d4", "a1e1", "a1e5", "a1f1", "a1f6", "a1g1", "a1g7", "a1h1", "a1h8", "a2a1", "a2a1b", "a2a1n", "a2a1q", "a2a1r", "a2a3", "a2a4", "a2a5", "a2a6", "a2a7", "a2a8", "a2b1", "a2b1b", "a2b1n", "a2b1q", "a2b1r", "a2b2", "a2b3", "a2b4", "a2c1", "a2c2", "a2c3", "a2c4", "a2d2", "a2d5", "a2e2", "a2e6", "a2f2", "a2f7", "a2g2", "a2g8", "a2h2", "a3a1", "a3a2", "a3a4", "a3a5", "a3a6", "a3a7", "a3a8", "a3b1", "a3b2", "a3b3", "a3b4", "a3b5", "a3c1", "a3c2", "a3c3", "a3c4", "a3c5", "a3d3", "a3d6", "a3e3", "a3e7", "a3f3", "a3f8", "a3g3", "a3h3", "a4a1", "a4a2", "a4a3", "a4a5", "a4a6", "a4a7", "a4a8", "a4b2", "a4b3", "a4b4", "a4b5", "a4b6", "a4c2", "a4c3", "a4c4", "a4c5", "a4c6", "a4d1", "a4d4", "a4d7", "a4e4", "a4e8", "a4f4", "a4g4", "a4h4", "a5a1", "a5a2", "a5a3", "a5a4", "a5a6", "a5a7", "a5a8", "a5b3", "a5b4", "a5b5", "a5b6", "a5b7", "a5c3", "a5c4", "a5c5", "a5c6", "a5c7", "a5d2", "a5d5", "a5d8", "a5e1", "a5e5", "a5f5", "a5g5", "a5h5", "a6a1", "a6a2", "a6a3", "a6a4", "a6a5", "a6a7", "a6a8", "a6b4", "a6b5", "a6b6", "a6b7", "a6b8", "a6c4", "a6c5", "a6c6", "a6c7", "a6c8", "a6d3", "a6d6", "a6e2", "a6e6", "a6f1", "a6f6", "a6g6", "a6h6", "a7a1", "a7a2", "a7a3", "a7a4", "a7a5", "a7a6", "a7a8", "a7a8b", "a7a8n", "a7a8q", "a7a8r", "a7b5", "a7b6", "a7b7", "a7b8", "a7b8b", "a7b8n", "a7b8q", "a7b8r", "a7c5", "a7c6", "a7c7", "a7c8", "a7d4", "a7d7", "a7e3", "a7e7", "a7f2", "a7f7", "a7g1", "a7g7", "a7h7", "a8a1", "a8a2", "a8a3", "a8a4", "a8a5", "a8a6", "a8a7", "a8b6", "a8b7", "a8b8", "a8c6", "a8c7", "a8c8", "a8d5", "a8d8", "a8e4", "a8e8", "a8f3", "a8f8", "a8g2", "a8g8", "a8h1", "a8h8", "b1a1", "b1a2", "b1a3", "b1b2", "b1b3", "b1b4", "b1b5", "b1b6", "b1b7", "b1b8", "b1c1", "b1c2", "b1c3", "b1d1", "b1d2", "b1d3", "b1e1", "b1e4", "b1f1", "b1f5", "b1g1", "b1g6", "b1h1", "b1h7", "b2a1", "b2a1b", "b2a1n", "b2a1q", "b2a1r", "b2a2", "b2a3", "b2a4", "b2b1", "b2b1b", "b2b1n", "b2b1q", "b2b1r", "b2b3", "b2b4", "b2b5", "b2b6", "b2b7", "b2b8", "b2c1", "b2c1b", "b2c1n", "b2c1q", "b2c1r", "b2c2", "b2c3", "b2c4", "b2d1", "b2d2", "b2d3", "b2d4", "b2e2", "b2e5", "b2f2", "b2f6", "b2g2", "b2g7", "b2h2", "b2h8", "b3a1", "b3a2", "b3a3", "b3a4", "b3a5", "b3b1", "b3b2", "b3b4", "b3b5", "b3b6", "b3b7", "b3b8", "b3c1", "b3c2", "b3c3", "b3c4", "b3c5", "b3d1", "b3d2", "b3d3", "b3d4", "b3d5", "b3e3", "b3e6", "b3f3", "b3f7", "b3g3", "b3g8", "b3h3", "b4a2", "b4a3", "b4a4", "b4a5", "b4a6", "b4b1", "b4b2", "b4b3", "b4b5", "b4b6", "b4b7", "b4b8", "b4c2", "b4c3", "b4c4", "b4c5", "b4c6", "b4d2", "b4d3", "b4d4", "b4d5", "b4d6", "b4e1", "b4e4", "b4e7", "b4f4", "b4f8", "b4g4", "b4h4", "b5a3", "b5a4", "b5a5", "b5a6", "b5a7", "b5b1", "b5b2", "b5b3", "b5b4", "b5b6", "b5b7", "b5b8", "b5c3", "b5c4", "b5c5", "b5c6", "b5c7", "b5d3", "b5d4", "b5d5", "b5d6", "b5d7", "b5e2", "b5e5", "b5e8", "b5f1", "b5f5", "b5g5", "b5h5", "b6a4", "b6a5", "b6a6", "b6a7", "b6a8", "b6b1", "b6b2", "b6b3", "b6b4", "b6b5", "b6b7", "b6b8", "b6c4", "b6c5", "b6c6", "b6c7", "b6c8", "b6d4", "b6d5", "b6d6", "b6d7", "b6d8", "b6e3", "b6e6", "b6f2", "b6f6", "b6g1", "b6g6", "b6h6", "b7a5", "b7a6", "b7a7", "b7a8", "b7a8b", "b7a8n", "b7a8q", "b7a8r", "b7b1", "b7b2", "b7b3", "b7b4", "b7b5", "b7b6", "b7b8", "b7b8b", "b7b8n", "b7b8q", "b7b8r", "b7c5", "b7c6", "b7c7", "b7c8", "b7c8b", "b7c8n", "b7c8q", "b7c8r", "b7d5", "b7d6", "b7d7", "b7d8", "b7e4", "b7e7", "b7f3", "b7f7", "b7g2", "b7g7", "b7h1", "b7h7", "b8a6", "b8a7", "b8a8", "b8b1", "b8b2", "b8b3", "b8b4", "b8b5", "b8b6", "b8b7", "b8c6", "b8c7", "b8c8", "b8d6", "b8d7", "b8d8", "b8e5", "b8e8", "b8f4", "b8f8", "b8g3", "b8g8", "b8h2", "b8h8", "c1a1", "c1a2", "c1a3", "c1b1", "c1b2", "c1b3", "c1c2", "c1c3", "c1c4", "c1c5", "c1c6", "c1c7", "c1c8", "c1d1", "c1d2", "c1d3", "c1e1", "c1e2", "c1e3", "c1f1", "c1f4", "c1g1", "c1g5", "c1h1", "c1h6", "c2a1", "c2a2", "c2a3", "c2a4", "c2b1", "c2b1b", "c2b1n", "c2b1q", "c2b1r", "c2b2", "c2b3", "c2b4", "c2c1", "c2c1b", "c2c1n", "c2c1q", "c2c1r", "c2c3", "c2c4", "c2c5", "c2c6", "c2c7", "c2c8", "c2d1", "c2d1b", "c2d1n", "c2d1q", "c2d1r", "c2d2", "c2d3", "c2d4", "c2e1", "c2e2", "c2e3", "c2e4", "c2f2", "c2f5", "c2g2", "c2g6", "c2h2", "c2h7", "c3a1", "c3a2", "c3a3", "c3a4", "c3a5", "c3b1", "c3b2", "c3b3", "c3b4", "c3b5", "c3c1", "c3c2", "c3c4", "c3c5", "c3c6", "c3c7", "c3c8", "c3d1", "c3d2", "c3d3", "c3d4", "c3d5", "c3e1", "c3e2", "c3e3", "c3e4", "c3e5", "c3f3", "c3f6", "c3g3", "c3g7", "c3h3", "c3h8", "c4a2", "c4a3", "c4a4", "c4a5", "c4a6", "c4b2", "c4b3", "c4b4", "c4b5", "c4b6", "c4c1", "c4c2", "c4c3", "c4c5", "c4c6", "c4c7", "c4c8", "c4d2", "c4d3", "c4d4", "c4d5", "c4d6", "c4e2", "c4e3", "c4e4", "c4e5", "c4e6", "c4f1", "c4f4", "c4f7", "c4g4", "c4g8", "c4h4", "c5a3", "c5a4", "c5a5", "c5a6", "c5a7", "c5b3", "c5b4", "c5b5", "c5b6", "c5b7", "c5c1", "c5c2", "c5c3", "c5c4", "c5c6", "c5c7", "c5c8", "c5d3", "c5d4", "c5d5", "c5d6", "c5d7", "c5e3", "c5e4", "c5e5", "c5e6", "c5e7", "c5f2", "c5f5", "c5f8", "c5g1", "c5g5", "c5h5", "c6a4", "c6a5", "c6a6", "c6a7", "c6a8", "c6b4", "c6b5", "c6b6", "c6b7", "c6b8", "c6c1", "c6c2", "c6c3", "c6c4", "c6c5", "c6c7", "c6c8", "c6d4", "c6d5", "c6d6", "c6d7", "c6d8", "c6e4", "c6e5", "c6e6", "c6e7", "c6e8", "c6f3", "c6f6", "c6g2", "c6g6", "c6h1", "c6h6", "c7a5", "c7a6", "c7a7", "c7a8", "c7b5", "c7b6", "c7b7", "c7b8", "c7b8b", "c7b8n", "c7b8q", "c7b8r", "c7c1", "c7c2", "c7c3", "c7c4", "c7c5", "c7c6", "c7c8", "c7c8b", "c7c8n", "c7c8q", "c7c8r", "c7d5", "c7d6", "c7d7", "c7d8", "c7d8b", "c7d8n", "c7d8q", "c7d8r", "c7e5", "c7e6", "c7e7", "c7e8", "c7f4", "c7f7", "c7g3", "c7g7", "c7h2", "c7h7", "c8a6", "c8a7", "c8a8", "c8b6", "c8b7", "c8b8", "c8c1", "c8c2", "c8c3", "c8c4", "c8c5", "c8c6", "c8c7", "c8d6", "c8d7", "c8d8", "c8e6", "c8e7", "c8e8", "c8f5", "c8f8", "c8g4", "c8g8", "c8h3", "c8h8", "d1a1", "d1a4", "d1b1", "d1b2", "d1b3", "d1c1", "d1c2", "d1c3", "d1d2", "d1d3", "d1d4", "d1d5", "d1d6", "d1d7", "d1d8", "d1e1", "d1e2", "d1e3", "d1f1", "d1f2", "d1f3", "d1g1", "d1g4", "d1h1", "d1h5", "d2a2", "d2a5", "d2b1", "d2b2", "d2b3", "d2b4", "d2c1", "d2c1b", "d2c1n", "d2c1q", "d2c1r", "d2c2", "d2c3", "d2c4", "d2d1", "d2d1b", "d2d1n", "d2d1q", "d2d1r", "d2d3", "d2d4", "d2d5", "d2d6", "d2d7", "d2d8", "d2e1", "d2e1b", "d2e1n", "d2e1q", "d2e1r", "d2e2", "d2e3", "d2e4", "d2f1", "d2f2", "d2f3", "d2f4", "d2g2", "d2g5", "d2h2", "d2h6", "d3a3", "d3a6", "d3b1", "d3b2", "d3b3", "d3b4", "d3b5", "d3c1", "d3c2", "d3c3", "d3c4", "d3c5", "d3d1", "d3d2", "d3d4", "d3d5", "d3d6", "d3d7", "d3d8", "d3e1", "d3e2", "d3e3", "d3e4", "d3e5", "d3f1", "d3f2", "d3f3", "d3f4", "d3f5", "d3g3", "d3g6", "d3h3", "d3h7", "d4a1", "d4a4", "d4a7", "d4b2", "d4b3", "d4b4", "d4b5", "d4b6", "d4c2", "d4c3", "d4c4", "d4c5", "d4c6", "d4d1", "d4d2", "d4d3", "d4d5", "d4d6", "d4d7", "d4d8", "d4e2", "d4e3", "d4e4", "d4e5", "d4e6", "d4f2", "d4f3", "d4f4", "d4f5", "d4f6", "d4g1", "d4g4", "d4g7", "d4h4", "d4h8", "d5a2", "d5a5", "d5a8", "d5b3", "d5b4", "d5b5", "d5b6", "d5b7", "d5c3", "d5c4", "d5c5", "d5c6", "d5c7", "d5d1", "d5d2", "d5d3", "d5d4", "d5d6", "d5d7", "d5d8", "d5e3", "d5e4", "d5e5", "d5e6", "d5e7", "d5f3", "d5f4", "d5f5", "d5f6", "d5f7", "d5g2", "d5g5", "d5g8", "d5h1", "d5h5", "d6a3", "d6a6", "d6b4", "d6b5", "d6b6", "d6b7", "d6b8", "d6c4", "d6c5", "d6c6", "d6c7", "d6c8", "d6d1", "d6d2", "d6d3", "d6d4", "d6d5", "d6d7", "d6d8", "d6e4", "d6e5", "d6e6", "d6e7", "d6e8", "d6f4", "d6f5", "d6f6", "d6f7", "d6f8", "d6g3", "d6g6", "d6h2", "d6h6", "d7a4", "d7a7", "d7b5", "d7b6", "d7b7", "d7b8", "d7c5", "d7c6", "d7c7", "d7c8", "d7c8b", "d7c8n", "d7c8q", "d7c8r", "d7d1", "d7d2", "d7d3", "d7d4", "d7d5", "d7d6", "d7d8", "d7d8b", "d7d8n", "d7d8q", "d7d8r", "d7e5", "d7e6", "d7e7", "d7e8", "d7e8b", "d7e8n", "d7e8q", "d7e8r", "d7f5", "d7f6", "d7f7", "d7f8", "d7g4", "d7g7", "d7h3", "d7h7", "d8a5", "d8a8", "d8b6", "d8b7", "d8b8", "d8c6", "d8c7", "d8c8", "d8d1", "d8d2", "d8d3", "d8d4", "d8d5", "d8d6", "d8d7", "d8e6", "d8e7", "d8e8", "d8f6", "d8f7", "d8f8", "d8g5", "d8g8", "d8h4", "d8h8", "e1a1", "e1a5", "e1b1", "e1b4", "e1c1", "e1c2", "e1c3", "e1d1", "e1d2", "e1d3", "e1e2", "e1e3", "e1e4", "e1e5", "e1e6", "e1e7", "e1e8", "e1f1", "e1f2", "e1f3", "e1g1", "e1g2", "e1g3", "e1h1", "e1h4", "e2a2", "e2a6", "e2b2", "e2b5", "e2c1", "e2c2", "e2c3", "e2c4", "e2d1", "e2d1b", "e2d1n", "e2d1q", "e2d1r", "e2d2", "e2d3", "e2d4", "e2e1", "e2e1b", "e2e1n", "e2e1q", "e2e1r", "e2e3", "e2e4", "e2e5", "e2e6", "e2e7", "e2e8", "e2f1", "e2f1b", "e2f1n", "e2f1q", "e2f1r", "e2f2", "e2f3", "e2f4", "e2g1", "e2g2", "e2g3", "e2g4", "e2h2", "e2h5", "e3a3", "e3a7", "e3b3", "e3b6", "e3c1", "e3c2", "e3c3", "e3c4", "e3c5", "e3d1", "e3d2", "e3d3", "e3d4", "e3d5", "e3e1", "e3e2", "e3e4", "e3e5", "e3e6", "e3e7", "e3e8", "e3f1", "e3f2", "e3f3", "e3f4", "e3f5", "e3g1", "e3g2", "e3g3", "e3g4", "e3g5", "e3h3", "e3h6", "e4a4", "e4a8", "e4b1", "e4b4", "e4b7", "e4c2", "e4c3", "e4c4", "e4c5", "e4c6", "e4d2", "e4d3", "e4d4", "e4d5", "e4d6", "e4e1", "e4e2", "e4e3", "e4e5", "e4e6", "e4e7", "e4e8", "e4f2", "e4f3", "e4f4", "e4f5", "e4f6", "e4g2", "e4g3", "e4g4", "e4g5", "e4g6", "e4h1", "e4h4", "e4h7", "e5a1", "e5a5", "e5b2", "e5b5", "e5b8", "e5c3", "e5c4", "e5c5", "e5c6", "e5c7", "e5d3", "e5d4", "e5d5", "e5d6", "e5d7", "e5e1", "e5e2", "e5e3", "e5e4", "e5e6", "e5e7", "e5e8", "e5f3", "e5f4", "e5f5", "e5f6", "e5f7", "e5g3", "e5g4", "e5g5", "e5g6", "e5g7", "e5h2", "e5h5", "e5h8", "e6a2", "e6a6", "e6b3", "e6b6", "e6c4", "e6c5", "e6c6", "e6c7", "e6c8", "e6d4", "e6d5", "e6d6", "e6d7", "e6d8", "e6e1", "e6e2", "e6e3", "e6e4", "e6e5", "e6e7", "e6e8", "e6f4", "e6f5", "e6f6", "e6f7", "e6f8", "e6g4", "e6g5", "e6g6", "e6g7", "e6g8", "e6h3", "e6h6", "e7a3", "e7a7", "e7b4", "e7b7", "e7c5", "e7c6", "e7c7", "e7c8", "e7d5", "e7d6", "e7d7", "e7d8", "e7d8b", "e7d8n", "e7d8q", "e7d8r", "e7e1", "e7e2", "e7e3", "e7e4", "e7e5", "e7e6", "e7e8", "e7e8b", "e7e8n", "e7e8q", "e7e8r", "e7f5", "e7f6", "e7f7", "e7f8", "e7f8b", "e7f8n", "e7f8q", "e7f8r", "e7g5", "e7g6", "e7g7", "e7g8", "e7h4", "e7h7", "e8a4", "e8a8", "e8b5", "e8b8", "e8c6", "e8c7", "e8c8", "e8d6", "e8d7", "e8d8", "e8e1", "e8e2", "e8e3", "e8e4", "e8e5", "e8e6", "e8e7", "e8f6", "e8f7", "e8f8", "e8g6", "e8g7", "e8g8", "e8h5", "e8h8", "f1a1", "f1a6", "f1b1", "f1b5", "f1c1", "f1c4", "f1d1", "f1d2", "f1d3", "f1e1", "f1e2", "f1e3", "f1f2", "f1f3", "f1f4", "f1f5", "f1f6", "f1f7", "f1f8", "f1g1", "f1g2", "f1g3", "f1h1", "f1h2", "f1h3", "f2a2", "f2a7", "f2b2", "f2b6", "f2c2", "f2c5", "f2d1", "f2d2", "f2d3", "f2d4", "f2e1", "f2e1b", "f2e1n", "f2e1q", "f2e1r", "f2e2", "f2e3", "f2e4", "f2f1", "f2f1b", "f2f1n", "f2f1q", "f2f1r", "f2f3", "f2f4", "f2f5", "f2f6", "f2f7", "f2f8", "f2g1", "f2g1b", "f2g1n", "f2g1q", "f2g1r", "f2g2", "f2g3", "f2g4", "f2h1", "f2h2", "f2h3", "f2h4", "f3a3", "f3a8", "f3b3", "f3b7", "f3c3", "f3c6", "f3d1", "f3d2", "f3d3", "f3d4", "f3d5", "f3e1", "f3e2", "f3e3", "f3e4", "f3e5", "f3f1", "f3f2", "f3f4", "f3f5", "f3f6", "f3f7", "f3f8", "f3g1", "f3g2", "f3g3", "f3g4", "f3g5", "f3h1", "f3h2", "f3h3", "f3h4", "f3h5", "f4a4", "f4b4", "f4b8", "f4c1", "f4c4", "f4c7", "f4d2", "f4d3", "f4d4", "f4d5", "f4d6", "f4e2", "f4e3", "f4e4", "f4e5", "f4e6", "f4f1", "f4f2", "f4f3", "f4f5", "f4f6", "f4f7", "f4f8", "f4g2", "f4g3", "f4g4", "f4g5", "f4g6", "f4h2", "f4h3", "f4h4", "f4h5", "f4h6", "f5a5", "f5b1", "f5b5", "f5c2", "f5c5", "f5c8", "f5d3", "f5d4", "f5d5", "f5d6", "f5d7", "f5e3", "f5e4", "f5e5", "f5e6", "f5e7", "f5f1", "f5f2", "f5f3", "f5f4", "f5f6", "f5f7", "f5f8", "f5g3", "f5g4", "f5g5", "f5g6", "f5g7", "f5h3", "f5h4", "f5h5", "f5h6", "f5h7", "f6a1", "f6a6", "f6b2", "f6b6", "f6c3", "f6c6", "f6d4", "f6d5", "f6d6", "f6d7", "f6d8", "f6e4", "f6e5", "f6e6", "f6e7", "f6e8", "f6f1", "f6f2", "f6f3", "f6f4", "f6f5", "f6f7", "f6f8", "f6g4", "f6g5", "f6g6", "f6g7", "f6g8", "f6h4", "f6h5", "f6h6", "f6h7", "f6h8", "f7a2", "f7a7", "f7b3", "f7b7", "f7c4", "f7c7", "f7d5", "f7d6", "f7d7", "f7d8", "f7e5", "f7e6", "f7e7", "f7e8", "f7e8b", "f7e8n", "f7e8q", "f7e8r", "f7f1", "f7f2", "f7f3", "f7f4", "f7f5", "f7f6", "f7f8", "f7f8b", "f7f8n", "f7f8q", "f7f8r", "f7g5", "f7g6", "f7g7", "f7g8", "f7g8b", "f7g8n", "f7g8q", "f7g8r", "f7h5", "f7h6", "f7h7", "f7h8", "f8a3", "f8a8", "f8b4", "f8b8", "f8c5", "f8c8", "f8d6", "f8d7", "f8d8", "f8e6", "f8e7", "f8e8", "f8f1", "f8f2", "f8f3", "f8f4", "f8f5", "f8f6", "f8f7", "f8g6", "f8g7", "f8g8", "f8h6", "f8h7", "f8h8", "g1a1", "g1a7", "g1b1", "g1b6", "g1c1", "g1c5", "g1d1", "g1d4", "g1e1", "g1e2", "g1e3", "g1f1", "g1f2", "g1f3", "g1g2", "g1g3", "g1g4", "g1g5", "g1g6", "g1g7", "g1g8", "g1h1", "g1h2", "g1h3", "g2a2", "g2a8", "g2b2", "g2b7", "g2c2", "g2c6", "g2d2", "g2d5", "g2e1", "g2e2", "g2e3", "g2e4", "g2f1", "g2f1b", "g2f1n", "g2f1q", "g2f1r", "g2f2", "g2f3", "g2f4", "g2g1", "g2g1b", "g2g1n", "g2g1q", "g2g1r", "g2g3", "g2g4", "g2g5", "g2g6", "g2g7", "g2g8", "g2h1", "g2h1b", "g2h1n", "g2h1q", "g2h1r", "g2h2", "g2h3", "g2h4", "g3a3", "g3b3", "g3b8", "g3c3", "g3c7", "g3d3", "g3d6", "g3e1", "g3e2", "g3e3", "g3e4", "g3e5", "g3f1", "g3f2", "g3f3", "g3f4", "g3f5", "g3g1", "g3g2", "g3g4", "g3g5", "g3g6", "g3g7", "g3g8", "g3h1", "g3h2", "g3h3", "g3h4", "g3h5", "g4a4", "g4b4", "g4c4", "g4c8", "g4d1", "g4d4", "g4d7", "g4e2", "g4e3", "g4e4", "g4e5", "g4e6", "g4f2", "g4f3", "g4f4", "g4f5", "g4f6", "g4g1", "g4g2", "g4g3", "g4g5", "g4g6", "g4g7", "g4g8", "g4h2", "g4h3", "g4h4", "g4h5", "g4h6", "g5a5", "g5b5", "g5c1", "g5c5", "g5d2", "g5d5", "g5d8", "g5e3", "g5e4", "g5e5", "g5e6", "g5e7", "g5f3", "g5f4", "g5f5", "g5f6", "g5f7", "g5g1", "g5g2", "g5g3", "g5g4", "g5g6", "g5g7", "g5g8", "g5h3", "g5h4", "g5h5", "g5h6", "g5h7", "g6a6", "g6b1", "g6b6", "g6c2", "g6c6", "g6d3", "g6d6", "g6e4", "g6e5", "g6e6", "g6e7", "g6e8", "g6f4", "g6f5", "g6f6", "g6f7", "g6f8", "g6g1", "g6g2", "g6g3", "g6g4", "g6g5", "g6g7", "g6g8", "g6h4", "g6h5", "g6h6", "g6h7", "g6h8", "g7a1", "g7a7", "g7b2", "g7b7", "g7c3", "g7c7", "g7d4", "g7d7", "g7e5", "g7e6", "g7e7", "g7e8", "g7f5", "g7f6", "g7f7", "g7f8", "g7f8b", "g7f8n", "g7f8q", "g7f8r", "g7g1", "g7g2", "g7g3", "g7g4", "g7g5", "g7g6", "g7g8", "g7g8b", "g7g8n", "g7g8q", "g7g8r", "g7h5", "g7h6", "g7h7", "g7h8", "g7h8b", "g7h8n", "g7h8q", "g7h8r", "g8a2", "g8a8", "g8b3", "g8b8", "g8c4", "g8c8", "g8d5", "g8d8", "g8e6", "g8e7", "g8e8", "g8f6", "g8f7", "g8f8", "g8g1", "g8g2", "g8g3", "g8g4", "g8g5", "g8g6", "g8g7", "g8h6", "g8h7", "g8h8", "h1a1", "h1a8", "h1b1", "h1b7", "h1c1", "h1c6", "h1d1", "h1d5", "h1e1", "h1e4", "h1f1", "h1f2", "h1f3", "h1g1", "h1g2", "h1g3", "h1h2", "h1h3", "h1h4", "h1h5", "h1h6", "h1h7", "h1h8", "h2a2", "h2b2", "h2b8", "h2c2", "h2c7", "h2d2", "h2d6", "h2e2", "h2e5", "h2f1", "h2f2", "h2f3", "h2f4", "h2g1", "h2g1b", "h2g1n", "h2g1q", "h2g1r", "h2g2", "h2g3", "h2g4", "h2h1", "h2h1b", "h2h1n", "h2h1q", "h2h1r", "h2h3", "h2h4", "h2h5", "h2h6", "h2h7", "h2h8", "h3a3", "h3b3", "h3c3", "h3c8", "h3d3", "h3d7", "h3e3", "h3e6", "h3f1", "h3f2", "h3f3", "h3f4", "h3f5", "h3g1", "h3g2", "h3g3", "h3g4", "h3g5", "h3h1", "h3h2", "h3h4", "h3h5", "h3h6", "h3h7", "h3h8", "h4a4", "h4b4", "h4c4", "h4d4", "h4d8", "h4e1", "h4e4", "h4e7", "h4f2", "h4f3", "h4f4", "h4f5", "h4f6", "h4g2", "h4g3", "h4g4", "h4g5", "h4g6", "h4h1", "h4h2", "h4h3", "h4h5", "h4h6", "h4h7", "h4h8", "h5a5", "h5b5", "h5c5", "h5d1", "h5d5", "h5e2", "h5e5", "h5e8", "h5f3", "h5f4", "h5f5", "h5f6", "h5f7", "h5g3", "h5g4", "h5g5", "h5g6", "h5g7", "h5h1", "h5h2", "h5h3", "h5h4", "h5h6", "h5h7", "h5h8", "h6a6", "h6b6", "h6c1", "h6c6", "h6d2", "h6d6", "h6e3", "h6e6", "h6f4", "h6f5", "h6f6", "h6f7", "h6f8", "h6g4", "h6g5", "h6g6", "h6g7", "h6g8", "h6h1", "h6h2", "h6h3", "h6h4", "h6h5", "h6h7", "h6h8", "h7a7", "h7b1", "h7b7", "h7c2", "h7c7", "h7d3", "h7d7", "h7e4", "h7e7", "h7f5", "h7f6", "h7f7", "h7f8", "h7g5", "h7g6", "h7g7", "h7g8", "h7g8b", "h7g8n", "h7g8q", "h7g8r", "h7h1", "h7h2", "h7h3", "h7h4", "h7h5", "h7h6", "h7h8", "h7h8b", "h7h8n", "h7h8q", "h7h8r", "h8a1", "h8a8", "h8b2", "h8b8", "h8c3", "h8c8", "h8d4", "h8d8", "h8e5", "h8e8", "h8f6", "h8f7", "h8f8", "h8g6", "h8g7", "h8g8", "h8h1", "h8h2", "h8h3", "h8h4", "h8h5", "h8h6", "h8h7"]
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from transformers import (
LlamaConfig, LlamaForSequenceClassification, LlamaForCausalLM,
GPT2Config, GPT2ForSequenceClassification, GPT2LMHeadModel,
PreTrainedTokenizerFast
)
from tokenizers import Tokenizer
from tokenizers.models import BPE
class RookTokenizer(PreTrainedTokenizerFast):
# TODO: make it easier to use checkpoints from the hub
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
def __call__(self, *args, **kwargs):
kwargs["return_token_type_ids"] = False
return super().__call__(*args, **kwargs)
def make_model(config_dict, arch="llama"):
if config_dict["finetuning_task"] == "text-classification":
return make_model_clf(config_dict, arch=arch)
elif config_dict["finetuning_task"] == "text-generation":
return make_model_lm(config_dict, arch=arch)
else:
raise ValueError(f"Unknown config finetuning_task: {config_dict['finetuning_task']}")
def make_model_clf(config_dict, arch):
if arch == "llama":
Config = LlamaConfig
Model = LlamaForSequenceClassification
if arch == "gpt2":
Config = GPT2Config
Model = GPT2ForSequenceClassification
# pad to multiple of 128
config_dict["vocab_size"] = ((len(VOCAB) + 127) // 128) * 128
config = Config(**config_dict)
label_to_id = {v: i for i, v in enumerate(ACTION_SPACE)}
config.num_labels = len(ACTION_SPACE)
config.label2id = label_to_id
config.id2label = {id: label for label, id in label_to_id.items()}
model = Model(config=config)
return model
def make_model_lm(config_dict, arch):
if arch == "llama":
Config = LlamaConfig
Model = LlamaForCausalLM
if arch == "gpt2":
Config = GPT2Config
Model = GPT2LMHeadModel
# pad to multiple of 128
config_dict["vocab_size"] = ((len(VOCAB) + len(ACTION_SPACE) + 4 + 127) // 128) * 128
config = Config(**config_dict)
model = Model(config=config)
return model
def make_tokenizer(task="clf"):
if task == "clf":
return make_tokenizer_clf(model_max_length=78)
elif task == "lm":
return make_tokenizer_lm(model_max_length=79)
elif task == "lm-cot":
return make_tokenizer_lm(model_max_length=116)
else:
raise ValueError(f"Unknown task: {task}")
def make_tokenizer_clf(model_max_length):
single_char_vocab = [e for e in VOCAB if len(e) == 1]
multi_char_vocab = [e for e in VOCAB if len(e) > 1]
merges = [tuple(e) for e in multi_char_vocab]
print(merges[:5])
tokenizer = Tokenizer(BPE(
vocab=dict(zip(single_char_vocab, range(len(single_char_vocab)))),
merges=merges)
)
fast_tokenizer = RookTokenizer(
tokenizer_object=tokenizer,
model_max_length=model_max_length,
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
clean_up_tokenization_spaces=False
)
return fast_tokenizer
def make_tokenizer_lm(model_max_length):
vocab = VOCAB + ACTION_SPACE
vocab += ["[OPTIONS]", "[VALUES]", "[ACTION]", "0000"]
single_char_vocab = [e for e in vocab if len(e) == 1]
multi_char_vocab = [e for e in vocab if len(e) > 1]
merges = []
tokenizer = Tokenizer(BPE(
vocab=dict(zip(single_char_vocab, range(len(single_char_vocab)))),
merges=merges)
)
tokenizer.add_special_tokens(multi_char_vocab)
fast_tokenizer = RookTokenizer(
tokenizer_object=tokenizer,
model_max_length=model_max_length,
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
clean_up_tokenization_spaces=False
)
return fast_tokenizer
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from transformers import pipeline
import torch
import chess
def process_fen(**kwargs):
pass
def unprocess_fen(**kwargs):
pass
class BCChessPolicy():
def __init__(self, model, tokenizer, batch_size=1, train_task="clf", filter_illegal=False):
print(f"Loading ROOK Chess Policy (Behavior Cloning)")
pipeline_type = "text-classification" if train_task == "clf" else "text-generation"
self.pipeline_type = pipeline_type
if isinstance(tokenizer, str) and os.path.exists(tokenizer):
tokenizer = RookTokenizer.from_pretrained(tokenizer)
if train_task == "clf":
top_k = None if filter_illegal else 1
max_new_tokens = None
else:
top_k = None
max_new_tokens = 2 # increase for lm-cot
self._pipeline = pipeline(
pipeline_type,
model=model,
tokenizer=tokenizer,
device="cuda" if torch.cuda.is_available() else "cpu",
torch_dtype=torch.bfloat16,
batch_size=batch_size,
top_k=top_k,
)
self._filter_illegal = filter_illegal
print(f"Total parameters: {self._pipeline.model.num_parameters():,}")
def _convert_fen(self, fen):
fen = process_fen(fen) if "/" in fen else fen
end_token = "[CLS]" if self.pipeline_type == "text-classification" else "[ACTION]"
fen += end_token if not fen.endswith(end_token) else ""
return fen
def _normalize_fen(self, fen):
try:
board = chess.Board(fen)
return fen
except ValueError:
if fen[-5:] == "[CLS]":
fen = fen[:-5]
return unprocess_fen(fen)
def play(self, fen):
if isinstance(fen, str):
fen = [fen]
inputs = [self._convert_fen(f) for f in fen]
if not self._filter_illegal:
if self.pipeline_type == "text-classification":
predictions = self._pipeline(inputs)
return [p[0]["label"] for p in predictions]
else:
predictions = self._pipeline(inputs, max_new_tokens=2)
actions = []
for p in predictions:
try:
actions.append(p[0]["generated_text"].split("[ACTION]")[-1])
except:
print("failed extracting [ACTION] from", p[0]["generated_text"])
actions.append("0000")
return actions
else:
# TODO vectorize
# TODO add pipeline_type text-generation
if self.pipeline_type == "text-generation":
raise NotImplementedError("text-generation pipeline is not yet implemented for filter_illegal")
predictions = self._pipeline(inputs)
boards = [chess.Board(self._normalize_fen(f)) for f in fen]
legal_moves = [[m.uci() for m in board.legal_moves] for board in boards]
scores = [[p["score"] for p in pred] for pred in predictions]
labels = [[p["label"] for p in pred] for pred in predictions]
dropped_labels = []
for i, label in enumerate(labels):
for j, l in enumerate(label):
if l not in legal_moves[i]:
dropped_labels.append((i, j))
labels[i] = [l for j, l in enumerate(label) if (i, j) not in dropped_labels]
scores[i] = [s for j, s in enumerate(scores[i]) if (i, j) not in dropped_labels]
best_moves = [max(zip(label, score), key=lambda x: x[1])[0] for label, score in zip(labels, scores)]
return best_moves
|