【深度学习】用LSTM写诗,生成式的方式写诗系列之一
Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=18.5, loss=5.8]
[5] loss: 5.828, accuracy: 18.389 , lr:0.001000
Epoch 5: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=19.2, loss=5.68]
[6] loss: 5.739, accuracy: 18.732 , lr:0.001000
Epoch 6: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=19.6, loss=5.57]
[7] loss: 5.629, accuracy: 19.197 , lr:0.001000
Epoch 7: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=19.9, loss=5.45]
[8] loss: 5.517, accuracy: 19.745 , lr:0.001000
Epoch 8: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=20.4, loss=5.38]
[9] loss: 5.402, accuracy: 20.316 , lr:0.001000
Epoch 9: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=21, loss=5.27]
[10] loss: 5.299, accuracy: 20.887 , lr:0.001000
Epoch 10: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=21.8, loss=5.16]
[11] loss: 5.210, accuracy: 21.427 , lr:0.001000
Epoch 11: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=22.2, loss=5.1]
[12] loss: 5.136, accuracy: 21.923 , lr:0.001000
Epoch 12: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=22.5, loss=5.06]
[13] loss: 5.071, accuracy: 22.379 , lr:0.001000
Epoch 13: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=22.9, loss=5.02]
[14] loss: 5.011, accuracy: 22.819 , lr:0.001000
Epoch 14: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=23.3, loss=4.94]
[15] loss: 4.959, accuracy: 23.212 , lr:0.001000
Epoch 15: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=23.7, loss=4.91]
[16] loss: 4.910, accuracy: 23.564 , lr:0.001000
Epoch 16: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=24.3, loss=4.82]
[17] loss: 4.862, accuracy: 23.914 , lr:0.001000
Epoch 17: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=24, loss=4.83]
[18] loss: 4.818, accuracy: 24.228 , lr:0.001000
Epoch 18: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=24.7, loss=4.77]
[19] loss: 4.775, accuracy: 24.523 , lr:0.001000
Epoch 19: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=24.6, loss=4.73]
[20] loss: 4.734, accuracy: 24.808 , lr:0.001000
Epoch 20: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=25, loss=4.69]
[21] loss: 4.694, accuracy: 25.090 , lr:0.001000
Epoch 21: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=25, loss=4.71]
[22] loss: 4.657, accuracy: 25.346 , lr:0.001000
Epoch 22: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=25.8, loss=4.62]
[23] loss: 4.619, accuracy: 25.587 , lr:0.001000
Epoch 23: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=25.9, loss=4.59]
[24] loss: 4.584, accuracy: 25.825 , lr:0.001000
Epoch 24: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.3, loss=4.52]
[25] loss: 4.549, accuracy: 26.078 , lr:0.001000
Epoch 25: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=26.3, loss=4.53]
[26] loss: 4.516, accuracy: 26.280 , lr:0.001000
Epoch 26: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.6, loss=4.49]
[27] loss: 4.483, accuracy: 26.517 , lr:0.001000
Epoch 27: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.8, loss=4.46]
[28] loss: 4.451, accuracy: 26.746 , lr:0.001000
Epoch 28: 100%|███████████████████████████████████████████████████████| 63/63 [1:00:07<00:00, 57.26s/batch, acc=27.1, loss=4.41]
[29] loss: 4.422, accuracy: 26.937 , lr:0.001000
Epoch 29: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=27.2, loss=4.38]
[30] loss: 4.389, accuracy: 27.182 , lr:0.001000
Epoch 30: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=27, loss=4.4]
[31] loss: 4.361, accuracy: 27.371 , lr:0.001000
Epoch 31: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=27.5, loss=4.34]
[32] loss: 4.332, accuracy: 27.589 , lr:0.001000
Epoch 32: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=27.6, loss=4.31]
[33] loss: 4.304, accuracy: 27.791 , lr:0.001000
Epoch 33: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.89batch/s, acc=27.9, loss=4.28]
[34] loss: 4.277, accuracy: 28.014 , lr:0.001000
Epoch 34: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=28, loss=4.26]
[35] loss: 4.248, accuracy: 28.200 , lr:0.001000
Epoch 35: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=28.6, loss=4.22]
[36] loss: 4.222, accuracy: 28.433 , lr:0.001000
Epoch 36: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=28.3, loss=4.21]
[37] loss: 4.196, accuracy: 28.625 , lr:0.001000
Epoch 37: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=29.1, loss=4.16]
[38] loss: 4.169, accuracy: 28.858 , lr:0.001000
Epoch 38: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29.2, loss=4.13]
[39] loss: 4.142, accuracy: 29.056 , lr:0.001000
Epoch 39: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29, loss=4.13]
[40] loss: 4.116, accuracy: 29.282 , lr:0.001000
Epoch 40: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=29.5, loss=4.12]
[41] loss: 4.092, accuracy: 29.477 , lr:0.001000
Epoch 41: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=29.7, loss=4.08]
[42] loss: 4.066, accuracy: 29.716 , lr:0.001000
Epoch 42: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29.8, loss=4.06]
[43] loss: 4.042, accuracy: 29.918 , lr:0.001000
Epoch 43: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.5, loss=3.99]
[44] loss: 4.016, accuracy: 30.146 , lr:0.001000
Epoch 44: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.2, loss=4.01]
[45] loss: 3.990, accuracy: 30.398 , lr:0.001000
Epoch 45: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.6, loss=3.96]
[46] loss: 3.968, accuracy: 30.607 , lr:0.001000
Epoch 46: 100%|█████████████████████████████████████████████████████████| 63/63 [40:05<00:00, 38.19s/batch, acc=30.6, loss=3.96]
[47] loss: 3.945, accuracy: 30.814 , lr:0.001000
Epoch 47: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=30.9, loss=3.94]
[48] loss: 3.918, accuracy: 31.073 , lr:0.001000
Epoch 48: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=31.1, loss=3.91]
[49] loss: 3.893, accuracy: 31.322 , lr:0.001000
Epoch 49: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=32, loss=3.86]
[50] loss: 3.869, accuracy: 31.574 , lr:0.001000
Epoch 50: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=31.2, loss=3.9]
[51] loss: 3.846, accuracy: 31.811 , lr:0.001000
Epoch 51: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=31.7, loss=3.85]
[52] loss: 3.823, accuracy: 32.042 , lr:0.001000
Epoch 52: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=32.4, loss=3.8]
[53] loss: 3.798, accuracy: 32.325 , lr:0.001000
Epoch 53: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.2, loss=3.8]
[54] loss: 3.776, accuracy: 32.552 , lr:0.001000
Epoch 54: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=32.4, loss=3.79]
[55] loss: 3.755, accuracy: 32.794 , lr:0.001000
Epoch 55: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.8, loss=3.75]
[56] loss: 3.729, accuracy: 33.081 , lr:0.001000
Epoch 56: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.8, loss=3.74]
[57] loss: 3.708, accuracy: 33.301 , lr:0.001000
Epoch 57: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=33.8, loss=3.68]
[58] loss: 3.683, accuracy: 33.597 , lr:0.001000
Epoch 58: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=33.5, loss=3.67]
[59] loss: 3.661, accuracy: 33.838 , lr:0.001000
Epoch 59: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=34, loss=3.65]
[60] loss: 3.639, accuracy: 34.106 , lr:0.001000
Epoch 60: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=34, loss=3.65]
[61] loss: 3.619, accuracy: 34.350 , lr:0.001000
Epoch 61: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=34.1, loss=3.64]
[62] loss: 3.595, accuracy: 34.632 , lr:0.001000
Epoch 62: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=34.6, loss=3.57]
[63] loss: 3.573, accuracy: 34.872 , lr:0.001000
Epoch 63: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=34.8, loss=3.58]
[64] loss: 3.553, accuracy: 35.140 , lr:0.001000
Epoch 64: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=35.1, loss=3.53]
[65] loss: 3.531, accuracy: 35.394 , lr:0.001000
Epoch 65: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=34.8, loss=3.56]
[66] loss: 3.512, accuracy: 35.636 , lr:0.001000
Epoch 66: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=35.1, loss=3.55]
[67] loss: 3.490, accuracy: 35.896 , lr:0.001000
Epoch 67: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=36.1, loss=3.49]
[68] loss: 3.471, accuracy: 36.147 , lr:0.001000
Epoch 68: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36, loss=3.48]
[69] loss: 3.451, accuracy: 36.413 , lr:0.001000
Epoch 69: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=36.5, loss=3.44]
[70] loss: 3.436, accuracy: 36.595 , lr:0.001000
Epoch 70: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36.5, loss=3.45]
[71] loss: 3.412, accuracy: 36.873 , lr:0.001000
Epoch 71: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36.2, loss=3.44]
[72] loss: 3.393, accuracy: 37.130 , lr:0.001000
Epoch 72: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=36.4, loss=3.44]
[73] loss: 3.375, accuracy: 37.342 , lr:0.001000
Epoch 73: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=37.1, loss=3.4]
[74] loss: 3.355, accuracy: 37.608 , lr:0.001000
Epoch 74: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=37.2, loss=3.37]
[75] loss: 3.337, accuracy: 37.853 , lr:0.001000
Epoch 75: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=37.9, loss=3.35]
[76] loss: 3.318, accuracy: 38.105 , lr:0.001000
Epoch 76: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=37.5, loss=3.35]
[77] loss: 3.303, accuracy: 38.282 , lr:0.001000
Epoch 77: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=37.9, loss=3.31]
[78] loss: 3.285, accuracy: 38.523 , lr:0.001000
Epoch 78: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.1, loss=3.3]
[79] loss: 3.267, accuracy: 38.738 , lr:0.001000
Epoch 79: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=38.9, loss=3.28]
[80] loss: 3.250, accuracy: 38.972 , lr:0.001000
Epoch 80: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.6, loss=3.27]
[81] loss: 3.230, accuracy: 39.248 , lr:0.001000
Epoch 81: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=39.1, loss=3.22]
[82] loss: 3.216, accuracy: 39.435 , lr:0.001000
Epoch 82: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.8, loss=3.25]
[83] loss: 3.197, accuracy: 39.675 , lr:0.001000
Epoch 83: 100%|██████████████████████████████████████████████████████████| 63/63 [00:38<00:00, 1.62batch/s, acc=39.7, loss=3.2]
[84] loss: 3.180, accuracy: 39.914 , lr:0.001000
Epoch 84: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=39.4, loss=3.2]
[85] loss: 3.165, accuracy: 40.108 , lr:0.001000
Epoch 85: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=40.1, loss=3.17]
[86] loss: 3.152, accuracy: 40.277 , lr:0.001000
Epoch 86: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=39.9, loss=3.18]
[87] loss: 3.135, accuracy: 40.508 , lr:0.001000
Epoch 87: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=40.4, loss=3.14]
[88] loss: 3.118, accuracy: 40.736 , lr:0.001000
Epoch 88: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=40.5, loss=3.14]
[89] loss: 3.104, accuracy: 40.918 , lr:0.001000
Epoch 89: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=40.7, loss=3.11]
[90] loss: 3.093, accuracy: 41.061 , lr:0.001000
Epoch 90: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=40.8, loss=3.1]
[91] loss: 3.074, accuracy: 41.315 , lr:0.001000
Epoch 91: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.5, loss=3.06]
[92] loss: 3.057, accuracy: 41.559 , lr:0.001000
Epoch 92: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.1, loss=3.09]
[93] loss: 3.043, accuracy: 41.745 , lr:0.001000
Epoch 93: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.5, loss=3.06]
[94] loss: 3.029, accuracy: 41.924 , lr:0.001000
Epoch 94: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=42, loss=3.03]
[95] loss: 3.015, accuracy: 42.133 , lr:0.001000
Epoch 95: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=41.6, loss=3.04]
[96] loss: 3.001, accuracy: 42.302 , lr:0.001000
Epoch 96: 100%|██████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=42, loss=3]
[97] loss: 2.988, accuracy: 42.483 , lr:0.001000
Epoch 97: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=42.9, loss=2.96]
[98] loss: 2.972, accuracy: 42.694 , lr:0.001000
Epoch 98: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=42.1, loss=3.01]
[99] loss: 2.964, accuracy: 42.804 , lr:0.001000
Epoch 99: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=42.2, loss=3.01]
[100] loss: 2.953, accuracy: 42.973 , lr:0.001000
Finished Training using %.3f seconds 6896.93013882637
先训练了100轮次,后面应该还能增长,但是不等了
数据初探:
class DictObj(object):def __init__(self, map):self.map = mapdef __getattr__(self, attr):if attr in self.map:return self.map[attr]else:raise AttributeError("No such attribute: " + attr)Config = DictObj({'poem_path':os.path.join(base_dir, "tang.npz"),"tensorboard_path":os.path.join(base_dir, "tensorboard"),"model_save_path":os.path.join(base_dir,"modelDict"),"embedding_dim":100,"hidden_dim":1024,"lr":0.001,"LSTM_layers":2,'batch_size':512,'epochs':500,'dropout':0.2,'ealier_stop':10,'device':torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})
def view_data(poem_path):datas = np.load(poem_path, allow_pickle=True)data = datas['data'] #(57580,125)ix2word = datas['ix2word'].item() # datas['word2ix'].item() 8293word2ix = datas['word2ix'].item() # datas['word2ix'].item() 8293word_data = np.zeros((1,data.shape[1]), dtype = str) # 将所有的0 转化成 ''# 看一下其中一行的数据是什么?row = np.random.randint(0, data.shape[0]) # 随机选一行,左闭右开没问题print(data[row])for i in range(data.shape[1]):word_data[0][i] = ix2word[data[row][i]]print(word_data)view_data(Config.poem_path)
数据处理:
class PoemDataset(Dataset):def __init__(self, poem_path, seq_len):super().__init__()# np 文件的地址self.poem_path = poem_path# 序列长度,48 是认为规定的,也可以是其它值,因为大部分是5言或者7言,加上表达就是 6,或8, 取48确保是整句话self.seq_len = seq_lenself.poem_data, self.ix2word, self.word2ix = self.get_raw_data()self.no_space_data = self.filter_space()print("no_space_data len:", self.no_space_data[0:200])def __len__(self):return len(self.no_space_data)//(self.seq_len)def __getitem__(self, idx):txt = self.no_space_data[idx*self.seq_len:(idx+1)*self.seq_len]label = self.no_space_data[idx*self.seq_len+1:(idx+1)*self.seq_len+1]return torch.LongTensor(txt), torch.LongTensor(label)def filter_space(self):# 7197500 个文本tensor_data = torch.from_numpy(self.poem_data).view(-1)no_space_data = [] for i in range(tensor_data.shape[0]):word_idx = tensor_data[i].item()if word_idx!= 8292:no_space_data.append(word_idx)return no_space_datadef get_raw_data(self):datas = np.load(self.poem_path, allow_pickle=True)data = datas['data']ix2word = datas['ix2word'].item()word2ix = datas['word2ix'].item()return data, ix2word, word2ix
poem_dataset = PoemDataset(Config.poem_path, 96)
[8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8291 5428 6933 3469 7066 3465 64078248 7009 82 7435 925 3469 3576 232 786 5272 2296 7066 4807 61036663 2958 2003 2173 28 7066 1987 8061 4299 848 4874 7435 8290]
[['<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '冬' '月' '内' ',' '无' '叶' '艾' '枝' '枯' '。' '草' '内' '急' '寻' '蛇' '床''子' ',' '烧' '烟' '入' '中' '自' '消' '除' ',' '速' '救' '免' '灾' '虞' '。' '<']]
no_space_data len: [8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817, 648, 7121, 1542, 6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534, 70, 3788, 3823, 7435, 4907, 5567, 201, 2834, 1519, 7066, 782, 782, 2063, 2031, 846, 7435, 8290, 8291, 2309, 2596, 6483, 2260, 7316, 7066, 6332, 5274, 2125, 5029, 7792, 7435, 4186, 8087, 7047, 6622, 6933, 7066, 6134, 3564, 3766, 6920, 6157, 7435, 7086, 4770, 5849, 4776, 4981, 7066, 4857, 2649, 3020, 332, 1727, 7435, 7458, 7294, 3465, 5149, 1671, 7066, 2834, 6000, 3942, 3534, 1534, 7435, 4102, 7460, 758, 3961, 3374, 7066, 7904, 6811, 4449, 2121, 6802, 7435, 6182, 27, 7912, 1756, 7440, 7066, 201, 7909, 8118, 201, 4662, 7435, 7824, 1508, 3154, 152, 5862, 7066, 7976, 6043, 258, 47, 7878, 7435, 8290, 8291, 3495, 70, 7113, 4839, 5237, 7066, 65, 3941, 2031, 2260, 5418, 7435, 411, 6773, 2878, 4686, 482, 7066, 1989, 5617, 4992, 8245, 676, 7435, 4236, 1418, 4915, 7686, 7363, 7066, 5708, 7541, 7440, 5237, 2192, 7435, 3114, 5913, 7989, 3069, 1845, 7066, 7047, 3534, 4921, 6622, 6933, 7435, 1664, 2260, 2003, 4816, 7151, 7066, 5036, 2219, 5849, 4898, 174, 7435, 201, 7228, 222]
因为有空格,啥的,要先吧空格之类的去掉。
def show_dataset():idx,label = poem_dataset[0]for id in idx:print(poem_dataset.ix2word[id.item()], end=' ')print("\n")for la in label:print(poem_dataset.ix2word[la.item()], end=' ')'''<START> 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 云'''
# 构建模型
class PoemModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)self.dropout = nn.Dropout(dropout)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, input, hidden=None):embeds = self.embedding(input)batch_size,seq_len,embedding_dim = embeds.shapeif hidden is None:h0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)c0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)else:h0,c0 = hiddenoutput, hidden = self.lstm(embeds, (h0, c0))# output = torch.tanh(self.dropout(self.fc1(output)))output = self.fc(output)return output, hiddenvocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
input_data, label_data = next(iter(dataloader))
print(input_data.shape, label_data.shape)
output, hidden = model(input_data.to(Config.device))
# output.shape torch.Size([1024, 96, 8293]) hidden[0].shape torch.Size([3, 1024, 1024]) hidden[1].shape torch.Size([3, 1024, 1024]) label_data.shape torch.Size([1024, 96])
a = 1def accuracy(output, label_data):pred = output.argmax(dim=2)correct = (pred == label_data).sum().item()total = label_data.numel()return correct / total * 100
# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)
def train(model, dataloader, criterion, optimizer, scheduler, epochs):if not os.path.exists(Config.model_save_path):os.makedirs(Config.model_save_path)best_acc = 0.0early_stop = 0start_time = time.time()for epoch in range(epochs):model.train()running_loss = 0.0running_acc = 0.0last_acc = 0.0with tqdm(dataloader, unit="batch") as tepoch:for input_data, label_data in tepoch:tepoch.set_description(f"Epoch {epoch}")input_data, label_data = input_data.to(Config.device), label_data.to(Config.device)optimizer.zero_grad()output, hidden = model(input_data)current_acc = accuracy(output, label_data)running_acc += current_accloss = criterion(output.view(-1, vocab_size), label_data.view(-1))loss.backward()optimizer.step()running_loss += loss.item()tepoch.set_postfix(loss=loss.item(), acc=current_acc)scheduler.step()last_acc = running_acc / len(dataloader)if last_acc > best_acc:best_acc = last_acctorch.save(model.state_dict(), os.path.join(Config.model_save_path, "best_model.pth"))else:early_stop += 1torch.save(model.state_dict(), os.path.join(Config.model_save_path, "last_model.pth"))print('[%d] loss: %.3f, accuracy: %.3f , lr:%.6f' % (epoch + 1, running_loss / len(dataloader), last_acc,scheduler.get_last_lr()[0]))if early_stop >= Config.ealier_stop:print("Early Stop")print("Best Accuracy: %.3f" % best_acc)breakprint('Finished Training using %.3f seconds', time.time() - start_time)
train(model, dataloader, criterion, optimizer, scheduler, Config.epochs)
模型构建以及训练如上,
现在看 500轮次,50个忍耐度的效果比较好
Epoch 145: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=61.6, loss=1.62]
[146] loss: 1.580, accuracy: 62.374 , lr:0.001000
Epoch 146: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.80batch/s, acc=62.5, loss=1.57]
[147] loss: 1.581, accuracy: 62.360 , lr:0.001000
Epoch 147: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=61.9, loss=1.61]
[148] loss: 1.582, accuracy: 62.324 , lr:0.001000
Early Stop
Best Accuracy: 62.397
Finished Training using %.3f seconds 1205.000694513321
使用效果
不满意的地方,写死96=seq_len 是不对的。
应该是 配合 padding 使用,并mask padding来指导损失 @todo, 下一篇文章我会搞定!
import torch
from train03 import Config
from train03 import PoemModelfrom train03 import PoemDataset
import ospoem_dataset = PoemDataset(Config.poem_path, 96)vocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
model.load_state_dict(torch.load(os.path.join(Config.model_save_path, "best_model.pth")))def generate(model, start_words, ix2word, word2ix, device):results = list(start_words)start_words_len = len(start_words)# 第一个词语是<START>input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()# 最开始的隐状态初始为0矩阵# torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim)hidden = torch.zeros((2,Config.LSTM_layers * 1, 1, Config.hidden_dim), dtype=torch.float32).to(Config.device)input = input.to(Config.device)hidden = hidden.to(Config.device)model.eval()with torch.no_grad():for i in range(48):output, hidden = model(input, hidden)# 如果在给定的句首中,input为句首中的下一个字if i < start_words_len:w = results[i]input = input.data.new([word2ix[w]]).view(1, 1)else:top_index = output.data[0].topk(1)[1][0].item()w = ix2word[top_index]results.append(w)input = input.data.new([top_index]).view(1, 1)if w == '<EOP>':del results[-1]breakreturn results
雨 余 虚 馆 竹 阴 清 , 独 坐 寒 窗 昼 未 醒 。 云 布 远 村 红 叶 返 , 水 深 秋 竹 翠 梢 寒 。 泉 声 入 阁 慙 嘉 石 , 山 色 题 诗 好 赋 诗 。
但是我有一点不太理解。
他是输入一个字,输出一个字,这一点好像不妥。不应该是 输入一个 生成1个, 然后输入两个,生成1个,然后输入3个生成1个么。。。 大神请指教一下吧。
相关文章:
【深度学习】用LSTM写诗,生成式的方式写诗系列之一
Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc18.5, loss5.8] [5] loss: 5.828, accuracy: 18.389 , lr:0.001000 Epoch 5: 100%|███…...

HomeAssistant自定义组件学习-【二】
#要说的话# 前面把中盛科技的控制器组件写完了。稍稍熟悉了一些HA,现在准备写窗帘控制组件,构想的东西会比较多,估计有些难度,过程会比较长,边写边记录吧! #设备和场景环境# 使用的是Novo的电机…...

如何看待AI技术的应用前景?
文章目录 如何看待AI技术的应用前景引言AI技术的现状1. AI的定义与分类2. 当前AI技术的应用领域 AI技术的应用前景1. 经济效益2. 社会影响3. 技术进步 AI技术应用面临的挑战1. 数据隐私与安全2. 可解释性与信任3. 技能短缺与就业影响 AI技术的未来发展方向1. 人工智能的伦理与法…...

Unity中的屏幕坐标系
获得视口宽高 拖动视口会改变屏幕宽高数值 MousePosition 屏幕坐标系的原点在左下角,MousePosition返回Z为0也就是纵深为0的Vector3 但是如果鼠标超出屏幕范围不会做限制,所以可能出现负数或者大于屏幕宽高的情况,做鼠标拖拽物体时需要注…...
标题点击可跳转网页
要实现点击标题跳转到网页的功能,你可以在Vue组件中使用<a>标签(锚点标签)并设置href属性为网页的URL。如果你希望使用uni-app的特性来控制页面跳转,可以使用uni.navigateTo方法(这适用于uni-app环境,…...

易语言模拟真人动态生成鼠标滑动路径
一.简介 鼠标轨迹算法是一种模拟人类鼠标操作的程序,它能够模拟出自然而真实的鼠标移动路径。 鼠标轨迹算法的底层实现采用C/C语言,原因在于C/C提供了高性能的执行能力和直接访问操作系统底层资源的能力。 鼠标轨迹算法具有以下优势: 模拟…...

Linux:生态与软件安装
文章目录 前言一、Linux下安装软件的方案二、包管理器是什么?三、生态问题相关的理解1. 什么操作系统是好的操作系统?2. 什么是生态?3. 软件包是谁写的?这些工程师为什么要写?钱的问题怎么解决? 四、我的服务器怎么知…...
R 语言与其他编程语言的区别
R 语言与其他编程语言的区别 R 语言作为一种专门用于统计计算和图形的编程语言,与其他编程语言相比有一些独特的特点和区别。本文将详细介绍这些区别,帮助你更好地理解 R 语言的优势和适用场景。 1. 专为统计和数据分析设计 统计功能 内置统计函数&…...

RC低通滤波器Bode图分析(传递函数零极点)
RC低通滤波器 我们使得R1K,C1uF;电容C的阻抗为Xc; 传递函数 H ( s ) u o u i X C X C R 1 s C 1 s C R 1 1 s R C (其中 s j ω ) H(s)\frac{u_{o} }{u_{i} } \frac{X_{C} }{X_{C}R} \frac{\frac{1}{sC} }{\…...
基于深度学习的网络入侵检测
基于深度学习的网络入侵检测是一种利用深度学习技术对网络流量进行实时监测与分析的方法,旨在识别并防范网络攻击和恶意活动。随着网络环境日益复杂,传统的入侵检测系统(IDS)在面对不断变化的攻击模式时,往往难以保持高…...
《构建一个具备从后端数据库获取数据并再前端显示的内容页面:前后端实现解析》
一、前端页面:布局与功能 1. 页面结构 我们先来看前端页面的 HTML 结构,它主要由以下几个部分组成: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewp…...

Rust 力扣 - 59. 螺旋矩阵 II
文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 使用一个全局变量current记录当前遍历到的元素的值 我们只需要一圈一圈的从外向内遍历矩阵,每一圈遍历顺序为上边、右边、下边、左边,每遍历完一个元素后current 我们需要注意的是如果上…...

2.4w字 —TS入门教程
目录 1. 什么是TS 2. TS基本使用 3 TS基础语法 3.1 基础类型约束 3.11 string,number,boolean, null和undefined 3.12 any 3.13 unknown 3.14 void 3.15 数组 3.16 对象 3.2 函数的约束 3.21 普通写法 3.22 函数表达式 3.22 可选…...

java: 未结束的字符文字 报错及解决:将编码全部改为UTF-8或者GBK
报错: 解决: 将编码都改成UTF-8或者GBK:...

Android平台RTSP转RTMP推送之采集麦克风音频转发
技术背景 RTSP转RTMP推送,好多开发者第一想到的是采用ffmpeg命令行的形式,如果对ffmpeg比较熟,而且产品不要额外的定制和更高阶的要求,未尝不可,如果对产品稳定性、时延、断网重连等有更高的技术诉求,比较…...

认证鉴权框架之—sa-token
一、概述 Satoken 是一个 Java 实现的权限认证框架,它主要用于 Web 应用程序的权限控制。Satoken 提供了丰富的功能来简化权限管理的过程,使得开发者可以更加专注于业务逻辑的开发。 二、逻辑流程 1、登录认证 (1)、创建token …...
Spring源码(十一):Spring MVC之DispatchServlet
本篇重点在于分析Spring MVC与Servlet标准的整合,下节将详细讨论Spring MVC的启动/加载流程、处理请求的具体流程。 一、介绍 Spring框架提供了构建Web应用程序的全功能MVC模块。通过策略接口 ,Spring框架是高度可配置的,而且支持多种视图技…...
gitbash简单操作
https://blog.csdn.net/qq_42363495/article/details/104878170 工作区(空间)--暂存区--本地仓库--远程仓库 方法一:创建一个新的分支master,且远程库里没有该分支 只要将.gitignore文件放在文件夹下就可以,.gitignore是文本文档形式的文件…...
pnpm install安装element-plus的版本跟package.json指定的版本不一样
pnpm安装的版本不同于package.json中指定的版本可能是由于以下几种情况导致的: 依赖项冲突:当项目依赖的不同模块或库之间存在版本冲突时,pnpm可能会安装与package.json中指定的版本不同的版本。这可能是因为其他依赖项指定了不同的版本&…...

Java线程池的核心内容详解
文章内容已经收录在《面试进阶之路》,从原理出发,直击面试难点,实现更高维度的降维打击! 目录 文章目录 目录Java线程池的核心内容详解线程池的优势什么场景下要用到线程池呢?线程池中重要的参数【掌握】新加入一个任…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?
编辑:陈萍萍的公主一点人工一点智能 未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战,在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...
vscode里如何用git
打开vs终端执行如下: 1 初始化 Git 仓库(如果尚未初始化) git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
linux 下常用变更-8
1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行,YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID: YW3…...

C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

自然语言处理——Transformer
自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效,它能挖掘数据中的时序信息以及语义信息,但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN,但是…...

【JavaWeb】Docker项目部署
引言 之前学习了Linux操作系统的常见命令,在Linux上安装软件,以及如何在Linux上部署一个单体项目,大多数同学都会有相同的感受,那就是麻烦。 核心体现在三点: 命令太多了,记不住 软件安装包名字复杂&…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

OPENCV形态学基础之二腐蚀
一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...