【深度学习】用LSTM写诗,生成式的方式写诗系列之一
Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=18.5, loss=5.8]
[5] loss: 5.828, accuracy: 18.389 , lr:0.001000
Epoch 5: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=19.2, loss=5.68]
[6] loss: 5.739, accuracy: 18.732 , lr:0.001000
Epoch 6: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=19.6, loss=5.57]
[7] loss: 5.629, accuracy: 19.197 , lr:0.001000
Epoch 7: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=19.9, loss=5.45]
[8] loss: 5.517, accuracy: 19.745 , lr:0.001000
Epoch 8: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=20.4, loss=5.38]
[9] loss: 5.402, accuracy: 20.316 , lr:0.001000
Epoch 9: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=21, loss=5.27]
[10] loss: 5.299, accuracy: 20.887 , lr:0.001000
Epoch 10: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=21.8, loss=5.16]
[11] loss: 5.210, accuracy: 21.427 , lr:0.001000
Epoch 11: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=22.2, loss=5.1]
[12] loss: 5.136, accuracy: 21.923 , lr:0.001000
Epoch 12: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=22.5, loss=5.06]
[13] loss: 5.071, accuracy: 22.379 , lr:0.001000
Epoch 13: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=22.9, loss=5.02]
[14] loss: 5.011, accuracy: 22.819 , lr:0.001000
Epoch 14: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=23.3, loss=4.94]
[15] loss: 4.959, accuracy: 23.212 , lr:0.001000
Epoch 15: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=23.7, loss=4.91]
[16] loss: 4.910, accuracy: 23.564 , lr:0.001000
Epoch 16: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=24.3, loss=4.82]
[17] loss: 4.862, accuracy: 23.914 , lr:0.001000
Epoch 17: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=24, loss=4.83]
[18] loss: 4.818, accuracy: 24.228 , lr:0.001000
Epoch 18: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=24.7, loss=4.77]
[19] loss: 4.775, accuracy: 24.523 , lr:0.001000
Epoch 19: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=24.6, loss=4.73]
[20] loss: 4.734, accuracy: 24.808 , lr:0.001000
Epoch 20: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=25, loss=4.69]
[21] loss: 4.694, accuracy: 25.090 , lr:0.001000
Epoch 21: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=25, loss=4.71]
[22] loss: 4.657, accuracy: 25.346 , lr:0.001000
Epoch 22: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=25.8, loss=4.62]
[23] loss: 4.619, accuracy: 25.587 , lr:0.001000
Epoch 23: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=25.9, loss=4.59]
[24] loss: 4.584, accuracy: 25.825 , lr:0.001000
Epoch 24: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.3, loss=4.52]
[25] loss: 4.549, accuracy: 26.078 , lr:0.001000
Epoch 25: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=26.3, loss=4.53]
[26] loss: 4.516, accuracy: 26.280 , lr:0.001000
Epoch 26: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.6, loss=4.49]
[27] loss: 4.483, accuracy: 26.517 , lr:0.001000
Epoch 27: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.8, loss=4.46]
[28] loss: 4.451, accuracy: 26.746 , lr:0.001000
Epoch 28: 100%|███████████████████████████████████████████████████████| 63/63 [1:00:07<00:00, 57.26s/batch, acc=27.1, loss=4.41]
[29] loss: 4.422, accuracy: 26.937 , lr:0.001000
Epoch 29: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=27.2, loss=4.38]
[30] loss: 4.389, accuracy: 27.182 , lr:0.001000
Epoch 30: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=27, loss=4.4]
[31] loss: 4.361, accuracy: 27.371 , lr:0.001000
Epoch 31: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=27.5, loss=4.34]
[32] loss: 4.332, accuracy: 27.589 , lr:0.001000
Epoch 32: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=27.6, loss=4.31]
[33] loss: 4.304, accuracy: 27.791 , lr:0.001000
Epoch 33: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.89batch/s, acc=27.9, loss=4.28]
[34] loss: 4.277, accuracy: 28.014 , lr:0.001000
Epoch 34: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=28, loss=4.26]
[35] loss: 4.248, accuracy: 28.200 , lr:0.001000
Epoch 35: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=28.6, loss=4.22]
[36] loss: 4.222, accuracy: 28.433 , lr:0.001000
Epoch 36: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=28.3, loss=4.21]
[37] loss: 4.196, accuracy: 28.625 , lr:0.001000
Epoch 37: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=29.1, loss=4.16]
[38] loss: 4.169, accuracy: 28.858 , lr:0.001000
Epoch 38: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29.2, loss=4.13]
[39] loss: 4.142, accuracy: 29.056 , lr:0.001000
Epoch 39: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29, loss=4.13]
[40] loss: 4.116, accuracy: 29.282 , lr:0.001000
Epoch 40: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=29.5, loss=4.12]
[41] loss: 4.092, accuracy: 29.477 , lr:0.001000
Epoch 41: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=29.7, loss=4.08]
[42] loss: 4.066, accuracy: 29.716 , lr:0.001000
Epoch 42: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29.8, loss=4.06]
[43] loss: 4.042, accuracy: 29.918 , lr:0.001000
Epoch 43: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.5, loss=3.99]
[44] loss: 4.016, accuracy: 30.146 , lr:0.001000
Epoch 44: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.2, loss=4.01]
[45] loss: 3.990, accuracy: 30.398 , lr:0.001000
Epoch 45: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.6, loss=3.96]
[46] loss: 3.968, accuracy: 30.607 , lr:0.001000
Epoch 46: 100%|█████████████████████████████████████████████████████████| 63/63 [40:05<00:00, 38.19s/batch, acc=30.6, loss=3.96]
[47] loss: 3.945, accuracy: 30.814 , lr:0.001000
Epoch 47: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=30.9, loss=3.94]
[48] loss: 3.918, accuracy: 31.073 , lr:0.001000
Epoch 48: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=31.1, loss=3.91]
[49] loss: 3.893, accuracy: 31.322 , lr:0.001000
Epoch 49: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=32, loss=3.86]
[50] loss: 3.869, accuracy: 31.574 , lr:0.001000
Epoch 50: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=31.2, loss=3.9]
[51] loss: 3.846, accuracy: 31.811 , lr:0.001000
Epoch 51: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=31.7, loss=3.85]
[52] loss: 3.823, accuracy: 32.042 , lr:0.001000
Epoch 52: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=32.4, loss=3.8]
[53] loss: 3.798, accuracy: 32.325 , lr:0.001000
Epoch 53: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.2, loss=3.8]
[54] loss: 3.776, accuracy: 32.552 , lr:0.001000
Epoch 54: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=32.4, loss=3.79]
[55] loss: 3.755, accuracy: 32.794 , lr:0.001000
Epoch 55: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.8, loss=3.75]
[56] loss: 3.729, accuracy: 33.081 , lr:0.001000
Epoch 56: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.8, loss=3.74]
[57] loss: 3.708, accuracy: 33.301 , lr:0.001000
Epoch 57: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=33.8, loss=3.68]
[58] loss: 3.683, accuracy: 33.597 , lr:0.001000
Epoch 58: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=33.5, loss=3.67]
[59] loss: 3.661, accuracy: 33.838 , lr:0.001000
Epoch 59: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=34, loss=3.65]
[60] loss: 3.639, accuracy: 34.106 , lr:0.001000
Epoch 60: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=34, loss=3.65]
[61] loss: 3.619, accuracy: 34.350 , lr:0.001000
Epoch 61: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=34.1, loss=3.64]
[62] loss: 3.595, accuracy: 34.632 , lr:0.001000
Epoch 62: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=34.6, loss=3.57]
[63] loss: 3.573, accuracy: 34.872 , lr:0.001000
Epoch 63: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=34.8, loss=3.58]
[64] loss: 3.553, accuracy: 35.140 , lr:0.001000
Epoch 64: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=35.1, loss=3.53]
[65] loss: 3.531, accuracy: 35.394 , lr:0.001000
Epoch 65: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=34.8, loss=3.56]
[66] loss: 3.512, accuracy: 35.636 , lr:0.001000
Epoch 66: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=35.1, loss=3.55]
[67] loss: 3.490, accuracy: 35.896 , lr:0.001000
Epoch 67: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=36.1, loss=3.49]
[68] loss: 3.471, accuracy: 36.147 , lr:0.001000
Epoch 68: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36, loss=3.48]
[69] loss: 3.451, accuracy: 36.413 , lr:0.001000
Epoch 69: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=36.5, loss=3.44]
[70] loss: 3.436, accuracy: 36.595 , lr:0.001000
Epoch 70: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36.5, loss=3.45]
[71] loss: 3.412, accuracy: 36.873 , lr:0.001000
Epoch 71: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36.2, loss=3.44]
[72] loss: 3.393, accuracy: 37.130 , lr:0.001000
Epoch 72: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=36.4, loss=3.44]
[73] loss: 3.375, accuracy: 37.342 , lr:0.001000
Epoch 73: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=37.1, loss=3.4]
[74] loss: 3.355, accuracy: 37.608 , lr:0.001000
Epoch 74: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=37.2, loss=3.37]
[75] loss: 3.337, accuracy: 37.853 , lr:0.001000
Epoch 75: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=37.9, loss=3.35]
[76] loss: 3.318, accuracy: 38.105 , lr:0.001000
Epoch 76: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=37.5, loss=3.35]
[77] loss: 3.303, accuracy: 38.282 , lr:0.001000
Epoch 77: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=37.9, loss=3.31]
[78] loss: 3.285, accuracy: 38.523 , lr:0.001000
Epoch 78: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.1, loss=3.3]
[79] loss: 3.267, accuracy: 38.738 , lr:0.001000
Epoch 79: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=38.9, loss=3.28]
[80] loss: 3.250, accuracy: 38.972 , lr:0.001000
Epoch 80: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.6, loss=3.27]
[81] loss: 3.230, accuracy: 39.248 , lr:0.001000
Epoch 81: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=39.1, loss=3.22]
[82] loss: 3.216, accuracy: 39.435 , lr:0.001000
Epoch 82: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.8, loss=3.25]
[83] loss: 3.197, accuracy: 39.675 , lr:0.001000
Epoch 83: 100%|██████████████████████████████████████████████████████████| 63/63 [00:38<00:00, 1.62batch/s, acc=39.7, loss=3.2]
[84] loss: 3.180, accuracy: 39.914 , lr:0.001000
Epoch 84: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=39.4, loss=3.2]
[85] loss: 3.165, accuracy: 40.108 , lr:0.001000
Epoch 85: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=40.1, loss=3.17]
[86] loss: 3.152, accuracy: 40.277 , lr:0.001000
Epoch 86: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=39.9, loss=3.18]
[87] loss: 3.135, accuracy: 40.508 , lr:0.001000
Epoch 87: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=40.4, loss=3.14]
[88] loss: 3.118, accuracy: 40.736 , lr:0.001000
Epoch 88: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=40.5, loss=3.14]
[89] loss: 3.104, accuracy: 40.918 , lr:0.001000
Epoch 89: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=40.7, loss=3.11]
[90] loss: 3.093, accuracy: 41.061 , lr:0.001000
Epoch 90: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=40.8, loss=3.1]
[91] loss: 3.074, accuracy: 41.315 , lr:0.001000
Epoch 91: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.5, loss=3.06]
[92] loss: 3.057, accuracy: 41.559 , lr:0.001000
Epoch 92: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.1, loss=3.09]
[93] loss: 3.043, accuracy: 41.745 , lr:0.001000
Epoch 93: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.5, loss=3.06]
[94] loss: 3.029, accuracy: 41.924 , lr:0.001000
Epoch 94: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=42, loss=3.03]
[95] loss: 3.015, accuracy: 42.133 , lr:0.001000
Epoch 95: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=41.6, loss=3.04]
[96] loss: 3.001, accuracy: 42.302 , lr:0.001000
Epoch 96: 100%|██████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=42, loss=3]
[97] loss: 2.988, accuracy: 42.483 , lr:0.001000
Epoch 97: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=42.9, loss=2.96]
[98] loss: 2.972, accuracy: 42.694 , lr:0.001000
Epoch 98: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=42.1, loss=3.01]
[99] loss: 2.964, accuracy: 42.804 , lr:0.001000
Epoch 99: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=42.2, loss=3.01]
[100] loss: 2.953, accuracy: 42.973 , lr:0.001000
Finished Training using %.3f seconds 6896.93013882637
先训练了100轮次,后面应该还能增长,但是不等了
数据初探:
class DictObj(object):def __init__(self, map):self.map = mapdef __getattr__(self, attr):if attr in self.map:return self.map[attr]else:raise AttributeError("No such attribute: " + attr)Config = DictObj({'poem_path':os.path.join(base_dir, "tang.npz"),"tensorboard_path":os.path.join(base_dir, "tensorboard"),"model_save_path":os.path.join(base_dir,"modelDict"),"embedding_dim":100,"hidden_dim":1024,"lr":0.001,"LSTM_layers":2,'batch_size':512,'epochs':500,'dropout':0.2,'ealier_stop':10,'device':torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})
def view_data(poem_path):datas = np.load(poem_path, allow_pickle=True)data = datas['data'] #(57580,125)ix2word = datas['ix2word'].item() # datas['word2ix'].item() 8293word2ix = datas['word2ix'].item() # datas['word2ix'].item() 8293word_data = np.zeros((1,data.shape[1]), dtype = str) # 将所有的0 转化成 ''# 看一下其中一行的数据是什么?row = np.random.randint(0, data.shape[0]) # 随机选一行,左闭右开没问题print(data[row])for i in range(data.shape[1]):word_data[0][i] = ix2word[data[row][i]]print(word_data)view_data(Config.poem_path)
数据处理:
class PoemDataset(Dataset):def __init__(self, poem_path, seq_len):super().__init__()# np 文件的地址self.poem_path = poem_path# 序列长度,48 是认为规定的,也可以是其它值,因为大部分是5言或者7言,加上表达就是 6,或8, 取48确保是整句话self.seq_len = seq_lenself.poem_data, self.ix2word, self.word2ix = self.get_raw_data()self.no_space_data = self.filter_space()print("no_space_data len:", self.no_space_data[0:200])def __len__(self):return len(self.no_space_data)//(self.seq_len)def __getitem__(self, idx):txt = self.no_space_data[idx*self.seq_len:(idx+1)*self.seq_len]label = self.no_space_data[idx*self.seq_len+1:(idx+1)*self.seq_len+1]return torch.LongTensor(txt), torch.LongTensor(label)def filter_space(self):# 7197500 个文本tensor_data = torch.from_numpy(self.poem_data).view(-1)no_space_data = [] for i in range(tensor_data.shape[0]):word_idx = tensor_data[i].item()if word_idx!= 8292:no_space_data.append(word_idx)return no_space_datadef get_raw_data(self):datas = np.load(self.poem_path, allow_pickle=True)data = datas['data']ix2word = datas['ix2word'].item()word2ix = datas['word2ix'].item()return data, ix2word, word2ix
poem_dataset = PoemDataset(Config.poem_path, 96)
[8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8291 5428 6933 3469 7066 3465 64078248 7009 82 7435 925 3469 3576 232 786 5272 2296 7066 4807 61036663 2958 2003 2173 28 7066 1987 8061 4299 848 4874 7435 8290]
[['<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '冬' '月' '内' ',' '无' '叶' '艾' '枝' '枯' '。' '草' '内' '急' '寻' '蛇' '床''子' ',' '烧' '烟' '入' '中' '自' '消' '除' ',' '速' '救' '免' '灾' '虞' '。' '<']]
no_space_data len: [8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817, 648, 7121, 1542, 6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534, 70, 3788, 3823, 7435, 4907, 5567, 201, 2834, 1519, 7066, 782, 782, 2063, 2031, 846, 7435, 8290, 8291, 2309, 2596, 6483, 2260, 7316, 7066, 6332, 5274, 2125, 5029, 7792, 7435, 4186, 8087, 7047, 6622, 6933, 7066, 6134, 3564, 3766, 6920, 6157, 7435, 7086, 4770, 5849, 4776, 4981, 7066, 4857, 2649, 3020, 332, 1727, 7435, 7458, 7294, 3465, 5149, 1671, 7066, 2834, 6000, 3942, 3534, 1534, 7435, 4102, 7460, 758, 3961, 3374, 7066, 7904, 6811, 4449, 2121, 6802, 7435, 6182, 27, 7912, 1756, 7440, 7066, 201, 7909, 8118, 201, 4662, 7435, 7824, 1508, 3154, 152, 5862, 7066, 7976, 6043, 258, 47, 7878, 7435, 8290, 8291, 3495, 70, 7113, 4839, 5237, 7066, 65, 3941, 2031, 2260, 5418, 7435, 411, 6773, 2878, 4686, 482, 7066, 1989, 5617, 4992, 8245, 676, 7435, 4236, 1418, 4915, 7686, 7363, 7066, 5708, 7541, 7440, 5237, 2192, 7435, 3114, 5913, 7989, 3069, 1845, 7066, 7047, 3534, 4921, 6622, 6933, 7435, 1664, 2260, 2003, 4816, 7151, 7066, 5036, 2219, 5849, 4898, 174, 7435, 201, 7228, 222]
因为有空格,啥的,要先吧空格之类的去掉。
def show_dataset():idx,label = poem_dataset[0]for id in idx:print(poem_dataset.ix2word[id.item()], end=' ')print("\n")for la in label:print(poem_dataset.ix2word[la.item()], end=' ')'''<START> 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 云'''
# 构建模型
class PoemModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)self.dropout = nn.Dropout(dropout)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, input, hidden=None):embeds = self.embedding(input)batch_size,seq_len,embedding_dim = embeds.shapeif hidden is None:h0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)c0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)else:h0,c0 = hiddenoutput, hidden = self.lstm(embeds, (h0, c0))# output = torch.tanh(self.dropout(self.fc1(output)))output = self.fc(output)return output, hiddenvocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
input_data, label_data = next(iter(dataloader))
print(input_data.shape, label_data.shape)
output, hidden = model(input_data.to(Config.device))
# output.shape torch.Size([1024, 96, 8293]) hidden[0].shape torch.Size([3, 1024, 1024]) hidden[1].shape torch.Size([3, 1024, 1024]) label_data.shape torch.Size([1024, 96])
a = 1def accuracy(output, label_data):pred = output.argmax(dim=2)correct = (pred == label_data).sum().item()total = label_data.numel()return correct / total * 100
# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)
def train(model, dataloader, criterion, optimizer, scheduler, epochs):if not os.path.exists(Config.model_save_path):os.makedirs(Config.model_save_path)best_acc = 0.0early_stop = 0start_time = time.time()for epoch in range(epochs):model.train()running_loss = 0.0running_acc = 0.0last_acc = 0.0with tqdm(dataloader, unit="batch") as tepoch:for input_data, label_data in tepoch:tepoch.set_description(f"Epoch {epoch}")input_data, label_data = input_data.to(Config.device), label_data.to(Config.device)optimizer.zero_grad()output, hidden = model(input_data)current_acc = accuracy(output, label_data)running_acc += current_accloss = criterion(output.view(-1, vocab_size), label_data.view(-1))loss.backward()optimizer.step()running_loss += loss.item()tepoch.set_postfix(loss=loss.item(), acc=current_acc)scheduler.step()last_acc = running_acc / len(dataloader)if last_acc > best_acc:best_acc = last_acctorch.save(model.state_dict(), os.path.join(Config.model_save_path, "best_model.pth"))else:early_stop += 1torch.save(model.state_dict(), os.path.join(Config.model_save_path, "last_model.pth"))print('[%d] loss: %.3f, accuracy: %.3f , lr:%.6f' % (epoch + 1, running_loss / len(dataloader), last_acc,scheduler.get_last_lr()[0]))if early_stop >= Config.ealier_stop:print("Early Stop")print("Best Accuracy: %.3f" % best_acc)breakprint('Finished Training using %.3f seconds', time.time() - start_time)
train(model, dataloader, criterion, optimizer, scheduler, Config.epochs)
模型构建以及训练如上,
现在看 500轮次,50个忍耐度的效果比较好
Epoch 145: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=61.6, loss=1.62]
[146] loss: 1.580, accuracy: 62.374 , lr:0.001000
Epoch 146: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.80batch/s, acc=62.5, loss=1.57]
[147] loss: 1.581, accuracy: 62.360 , lr:0.001000
Epoch 147: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=61.9, loss=1.61]
[148] loss: 1.582, accuracy: 62.324 , lr:0.001000
Early Stop
Best Accuracy: 62.397
Finished Training using %.3f seconds 1205.000694513321
使用效果
不满意的地方,写死96=seq_len 是不对的。
应该是 配合 padding 使用,并mask padding来指导损失 @todo, 下一篇文章我会搞定!
import torch
from train03 import Config
from train03 import PoemModelfrom train03 import PoemDataset
import ospoem_dataset = PoemDataset(Config.poem_path, 96)vocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
model.load_state_dict(torch.load(os.path.join(Config.model_save_path, "best_model.pth")))def generate(model, start_words, ix2word, word2ix, device):results = list(start_words)start_words_len = len(start_words)# 第一个词语是<START>input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()# 最开始的隐状态初始为0矩阵# torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim)hidden = torch.zeros((2,Config.LSTM_layers * 1, 1, Config.hidden_dim), dtype=torch.float32).to(Config.device)input = input.to(Config.device)hidden = hidden.to(Config.device)model.eval()with torch.no_grad():for i in range(48):output, hidden = model(input, hidden)# 如果在给定的句首中,input为句首中的下一个字if i < start_words_len:w = results[i]input = input.data.new([word2ix[w]]).view(1, 1)else:top_index = output.data[0].topk(1)[1][0].item()w = ix2word[top_index]results.append(w)input = input.data.new([top_index]).view(1, 1)if w == '<EOP>':del results[-1]breakreturn results
雨 余 虚 馆 竹 阴 清 , 独 坐 寒 窗 昼 未 醒 。 云 布 远 村 红 叶 返 , 水 深 秋 竹 翠 梢 寒 。 泉 声 入 阁 慙 嘉 石 , 山 色 题 诗 好 赋 诗 。
但是我有一点不太理解。
他是输入一个字,输出一个字,这一点好像不妥。不应该是 输入一个 生成1个, 然后输入两个,生成1个,然后输入3个生成1个么。。。 大神请指教一下吧。
相关文章:
【深度学习】用LSTM写诗,生成式的方式写诗系列之一
Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc18.5, loss5.8] [5] loss: 5.828, accuracy: 18.389 , lr:0.001000 Epoch 5: 100%|███…...

HomeAssistant自定义组件学习-【二】
#要说的话# 前面把中盛科技的控制器组件写完了。稍稍熟悉了一些HA,现在准备写窗帘控制组件,构想的东西会比较多,估计有些难度,过程会比较长,边写边记录吧! #设备和场景环境# 使用的是Novo的电机…...

如何看待AI技术的应用前景?
文章目录 如何看待AI技术的应用前景引言AI技术的现状1. AI的定义与分类2. 当前AI技术的应用领域 AI技术的应用前景1. 经济效益2. 社会影响3. 技术进步 AI技术应用面临的挑战1. 数据隐私与安全2. 可解释性与信任3. 技能短缺与就业影响 AI技术的未来发展方向1. 人工智能的伦理与法…...

Unity中的屏幕坐标系
获得视口宽高 拖动视口会改变屏幕宽高数值 MousePosition 屏幕坐标系的原点在左下角,MousePosition返回Z为0也就是纵深为0的Vector3 但是如果鼠标超出屏幕范围不会做限制,所以可能出现负数或者大于屏幕宽高的情况,做鼠标拖拽物体时需要注…...
标题点击可跳转网页
要实现点击标题跳转到网页的功能,你可以在Vue组件中使用<a>标签(锚点标签)并设置href属性为网页的URL。如果你希望使用uni-app的特性来控制页面跳转,可以使用uni.navigateTo方法(这适用于uni-app环境,…...

易语言模拟真人动态生成鼠标滑动路径
一.简介 鼠标轨迹算法是一种模拟人类鼠标操作的程序,它能够模拟出自然而真实的鼠标移动路径。 鼠标轨迹算法的底层实现采用C/C语言,原因在于C/C提供了高性能的执行能力和直接访问操作系统底层资源的能力。 鼠标轨迹算法具有以下优势: 模拟…...

Linux:生态与软件安装
文章目录 前言一、Linux下安装软件的方案二、包管理器是什么?三、生态问题相关的理解1. 什么操作系统是好的操作系统?2. 什么是生态?3. 软件包是谁写的?这些工程师为什么要写?钱的问题怎么解决? 四、我的服务器怎么知…...
R 语言与其他编程语言的区别
R 语言与其他编程语言的区别 R 语言作为一种专门用于统计计算和图形的编程语言,与其他编程语言相比有一些独特的特点和区别。本文将详细介绍这些区别,帮助你更好地理解 R 语言的优势和适用场景。 1. 专为统计和数据分析设计 统计功能 内置统计函数&…...

RC低通滤波器Bode图分析(传递函数零极点)
RC低通滤波器 我们使得R1K,C1uF;电容C的阻抗为Xc; 传递函数 H ( s ) u o u i X C X C R 1 s C 1 s C R 1 1 s R C (其中 s j ω ) H(s)\frac{u_{o} }{u_{i} } \frac{X_{C} }{X_{C}R} \frac{\frac{1}{sC} }{\…...
基于深度学习的网络入侵检测
基于深度学习的网络入侵检测是一种利用深度学习技术对网络流量进行实时监测与分析的方法,旨在识别并防范网络攻击和恶意活动。随着网络环境日益复杂,传统的入侵检测系统(IDS)在面对不断变化的攻击模式时,往往难以保持高…...
《构建一个具备从后端数据库获取数据并再前端显示的内容页面:前后端实现解析》
一、前端页面:布局与功能 1. 页面结构 我们先来看前端页面的 HTML 结构,它主要由以下几个部分组成: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewp…...

Rust 力扣 - 59. 螺旋矩阵 II
文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 使用一个全局变量current记录当前遍历到的元素的值 我们只需要一圈一圈的从外向内遍历矩阵,每一圈遍历顺序为上边、右边、下边、左边,每遍历完一个元素后current 我们需要注意的是如果上…...

2.4w字 —TS入门教程
目录 1. 什么是TS 2. TS基本使用 3 TS基础语法 3.1 基础类型约束 3.11 string,number,boolean, null和undefined 3.12 any 3.13 unknown 3.14 void 3.15 数组 3.16 对象 3.2 函数的约束 3.21 普通写法 3.22 函数表达式 3.22 可选…...

java: 未结束的字符文字 报错及解决:将编码全部改为UTF-8或者GBK
报错: 解决: 将编码都改成UTF-8或者GBK:...

Android平台RTSP转RTMP推送之采集麦克风音频转发
技术背景 RTSP转RTMP推送,好多开发者第一想到的是采用ffmpeg命令行的形式,如果对ffmpeg比较熟,而且产品不要额外的定制和更高阶的要求,未尝不可,如果对产品稳定性、时延、断网重连等有更高的技术诉求,比较…...

认证鉴权框架之—sa-token
一、概述 Satoken 是一个 Java 实现的权限认证框架,它主要用于 Web 应用程序的权限控制。Satoken 提供了丰富的功能来简化权限管理的过程,使得开发者可以更加专注于业务逻辑的开发。 二、逻辑流程 1、登录认证 (1)、创建token …...
Spring源码(十一):Spring MVC之DispatchServlet
本篇重点在于分析Spring MVC与Servlet标准的整合,下节将详细讨论Spring MVC的启动/加载流程、处理请求的具体流程。 一、介绍 Spring框架提供了构建Web应用程序的全功能MVC模块。通过策略接口 ,Spring框架是高度可配置的,而且支持多种视图技…...
gitbash简单操作
https://blog.csdn.net/qq_42363495/article/details/104878170 工作区(空间)--暂存区--本地仓库--远程仓库 方法一:创建一个新的分支master,且远程库里没有该分支 只要将.gitignore文件放在文件夹下就可以,.gitignore是文本文档形式的文件…...
pnpm install安装element-plus的版本跟package.json指定的版本不一样
pnpm安装的版本不同于package.json中指定的版本可能是由于以下几种情况导致的: 依赖项冲突:当项目依赖的不同模块或库之间存在版本冲突时,pnpm可能会安装与package.json中指定的版本不同的版本。这可能是因为其他依赖项指定了不同的版本&…...

Java线程池的核心内容详解
文章内容已经收录在《面试进阶之路》,从原理出发,直击面试难点,实现更高维度的降维打击! 目录 文章目录 目录Java线程池的核心内容详解线程池的优势什么场景下要用到线程池呢?线程池中重要的参数【掌握】新加入一个任…...

通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
线程与协程
1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指:像函数调用/返回一样轻量地完成任务切换。 举例说明: 当你在程序中写一个函数调用: funcA() 然后 funcA 执行完后返回&…...
django filter 统计数量 按属性去重
在Django中,如果你想要根据某个属性对查询集进行去重并统计数量,你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求: 方法1:使用annotate()和Count 假设你有一个模型Item,并且你想…...

有限自动机到正规文法转换器v1.0
1 项目简介 这是一个功能强大的有限自动机(Finite Automaton, FA)到正规文法(Regular Grammar)转换器,它配备了一个直观且完整的图形用户界面,使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...
Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?
Redis 的发布订阅(Pub/Sub)模式与专业的 MQ(Message Queue)如 Kafka、RabbitMQ 进行比较,核心的权衡点在于:简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!
简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求,并检查收到的响应。它以以下模式之一…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...
掌握 HTTP 请求:理解 cURL GET 语法
cURL 是一个强大的命令行工具,用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中,cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...

华为OD机试-最短木板长度-二分法(A卷,100分)
此题是一个最大化最小值的典型例题, 因为搜索范围是有界的,上界最大木板长度补充的全部木料长度,下界最小木板长度; 即left0,right10^6; 我们可以设置一个候选值x(mid),将木板的长度全部都补充到x,如果成功…...