【深度学习】用LSTM写诗,生成式的方式写诗系列之一
Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=18.5, loss=5.8]
[5] loss: 5.828, accuracy: 18.389 , lr:0.001000
Epoch 5: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=19.2, loss=5.68]
[6] loss: 5.739, accuracy: 18.732 , lr:0.001000
Epoch 6: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=19.6, loss=5.57]
[7] loss: 5.629, accuracy: 19.197 , lr:0.001000
Epoch 7: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=19.9, loss=5.45]
[8] loss: 5.517, accuracy: 19.745 , lr:0.001000
Epoch 8: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=20.4, loss=5.38]
[9] loss: 5.402, accuracy: 20.316 , lr:0.001000
Epoch 9: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=21, loss=5.27]
[10] loss: 5.299, accuracy: 20.887 , lr:0.001000
Epoch 10: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=21.8, loss=5.16]
[11] loss: 5.210, accuracy: 21.427 , lr:0.001000
Epoch 11: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=22.2, loss=5.1]
[12] loss: 5.136, accuracy: 21.923 , lr:0.001000
Epoch 12: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=22.5, loss=5.06]
[13] loss: 5.071, accuracy: 22.379 , lr:0.001000
Epoch 13: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=22.9, loss=5.02]
[14] loss: 5.011, accuracy: 22.819 , lr:0.001000
Epoch 14: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=23.3, loss=4.94]
[15] loss: 4.959, accuracy: 23.212 , lr:0.001000
Epoch 15: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=23.7, loss=4.91]
[16] loss: 4.910, accuracy: 23.564 , lr:0.001000
Epoch 16: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=24.3, loss=4.82]
[17] loss: 4.862, accuracy: 23.914 , lr:0.001000
Epoch 17: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=24, loss=4.83]
[18] loss: 4.818, accuracy: 24.228 , lr:0.001000
Epoch 18: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=24.7, loss=4.77]
[19] loss: 4.775, accuracy: 24.523 , lr:0.001000
Epoch 19: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=24.6, loss=4.73]
[20] loss: 4.734, accuracy: 24.808 , lr:0.001000
Epoch 20: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=25, loss=4.69]
[21] loss: 4.694, accuracy: 25.090 , lr:0.001000
Epoch 21: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=25, loss=4.71]
[22] loss: 4.657, accuracy: 25.346 , lr:0.001000
Epoch 22: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=25.8, loss=4.62]
[23] loss: 4.619, accuracy: 25.587 , lr:0.001000
Epoch 23: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=25.9, loss=4.59]
[24] loss: 4.584, accuracy: 25.825 , lr:0.001000
Epoch 24: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.3, loss=4.52]
[25] loss: 4.549, accuracy: 26.078 , lr:0.001000
Epoch 25: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=26.3, loss=4.53]
[26] loss: 4.516, accuracy: 26.280 , lr:0.001000
Epoch 26: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.6, loss=4.49]
[27] loss: 4.483, accuracy: 26.517 , lr:0.001000
Epoch 27: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=26.8, loss=4.46]
[28] loss: 4.451, accuracy: 26.746 , lr:0.001000
Epoch 28: 100%|███████████████████████████████████████████████████████| 63/63 [1:00:07<00:00, 57.26s/batch, acc=27.1, loss=4.41]
[29] loss: 4.422, accuracy: 26.937 , lr:0.001000
Epoch 29: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=27.2, loss=4.38]
[30] loss: 4.389, accuracy: 27.182 , lr:0.001000
Epoch 30: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=27, loss=4.4]
[31] loss: 4.361, accuracy: 27.371 , lr:0.001000
Epoch 31: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=27.5, loss=4.34]
[32] loss: 4.332, accuracy: 27.589 , lr:0.001000
Epoch 32: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=27.6, loss=4.31]
[33] loss: 4.304, accuracy: 27.791 , lr:0.001000
Epoch 33: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.89batch/s, acc=27.9, loss=4.28]
[34] loss: 4.277, accuracy: 28.014 , lr:0.001000
Epoch 34: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=28, loss=4.26]
[35] loss: 4.248, accuracy: 28.200 , lr:0.001000
Epoch 35: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=28.6, loss=4.22]
[36] loss: 4.222, accuracy: 28.433 , lr:0.001000
Epoch 36: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=28.3, loss=4.21]
[37] loss: 4.196, accuracy: 28.625 , lr:0.001000
Epoch 37: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=29.1, loss=4.16]
[38] loss: 4.169, accuracy: 28.858 , lr:0.001000
Epoch 38: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29.2, loss=4.13]
[39] loss: 4.142, accuracy: 29.056 , lr:0.001000
Epoch 39: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29, loss=4.13]
[40] loss: 4.116, accuracy: 29.282 , lr:0.001000
Epoch 40: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=29.5, loss=4.12]
[41] loss: 4.092, accuracy: 29.477 , lr:0.001000
Epoch 41: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=29.7, loss=4.08]
[42] loss: 4.066, accuracy: 29.716 , lr:0.001000
Epoch 42: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=29.8, loss=4.06]
[43] loss: 4.042, accuracy: 29.918 , lr:0.001000
Epoch 43: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.5, loss=3.99]
[44] loss: 4.016, accuracy: 30.146 , lr:0.001000
Epoch 44: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.2, loss=4.01]
[45] loss: 3.990, accuracy: 30.398 , lr:0.001000
Epoch 45: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=30.6, loss=3.96]
[46] loss: 3.968, accuracy: 30.607 , lr:0.001000
Epoch 46: 100%|█████████████████████████████████████████████████████████| 63/63 [40:05<00:00, 38.19s/batch, acc=30.6, loss=3.96]
[47] loss: 3.945, accuracy: 30.814 , lr:0.001000
Epoch 47: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=30.9, loss=3.94]
[48] loss: 3.918, accuracy: 31.073 , lr:0.001000
Epoch 48: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=31.1, loss=3.91]
[49] loss: 3.893, accuracy: 31.322 , lr:0.001000
Epoch 49: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=32, loss=3.86]
[50] loss: 3.869, accuracy: 31.574 , lr:0.001000
Epoch 50: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=31.2, loss=3.9]
[51] loss: 3.846, accuracy: 31.811 , lr:0.001000
Epoch 51: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=31.7, loss=3.85]
[52] loss: 3.823, accuracy: 32.042 , lr:0.001000
Epoch 52: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=32.4, loss=3.8]
[53] loss: 3.798, accuracy: 32.325 , lr:0.001000
Epoch 53: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.2, loss=3.8]
[54] loss: 3.776, accuracy: 32.552 , lr:0.001000
Epoch 54: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.88batch/s, acc=32.4, loss=3.79]
[55] loss: 3.755, accuracy: 32.794 , lr:0.001000
Epoch 55: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.8, loss=3.75]
[56] loss: 3.729, accuracy: 33.081 , lr:0.001000
Epoch 56: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=32.8, loss=3.74]
[57] loss: 3.708, accuracy: 33.301 , lr:0.001000
Epoch 57: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=33.8, loss=3.68]
[58] loss: 3.683, accuracy: 33.597 , lr:0.001000
Epoch 58: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=33.5, loss=3.67]
[59] loss: 3.661, accuracy: 33.838 , lr:0.001000
Epoch 59: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=34, loss=3.65]
[60] loss: 3.639, accuracy: 34.106 , lr:0.001000
Epoch 60: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=34, loss=3.65]
[61] loss: 3.619, accuracy: 34.350 , lr:0.001000
Epoch 61: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=34.1, loss=3.64]
[62] loss: 3.595, accuracy: 34.632 , lr:0.001000
Epoch 62: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=34.6, loss=3.57]
[63] loss: 3.573, accuracy: 34.872 , lr:0.001000
Epoch 63: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=34.8, loss=3.58]
[64] loss: 3.553, accuracy: 35.140 , lr:0.001000
Epoch 64: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=35.1, loss=3.53]
[65] loss: 3.531, accuracy: 35.394 , lr:0.001000
Epoch 65: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=34.8, loss=3.56]
[66] loss: 3.512, accuracy: 35.636 , lr:0.001000
Epoch 66: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=35.1, loss=3.55]
[67] loss: 3.490, accuracy: 35.896 , lr:0.001000
Epoch 67: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=36.1, loss=3.49]
[68] loss: 3.471, accuracy: 36.147 , lr:0.001000
Epoch 68: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36, loss=3.48]
[69] loss: 3.451, accuracy: 36.413 , lr:0.001000
Epoch 69: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=36.5, loss=3.44]
[70] loss: 3.436, accuracy: 36.595 , lr:0.001000
Epoch 70: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36.5, loss=3.45]
[71] loss: 3.412, accuracy: 36.873 , lr:0.001000
Epoch 71: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=36.2, loss=3.44]
[72] loss: 3.393, accuracy: 37.130 , lr:0.001000
Epoch 72: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=36.4, loss=3.44]
[73] loss: 3.375, accuracy: 37.342 , lr:0.001000
Epoch 73: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=37.1, loss=3.4]
[74] loss: 3.355, accuracy: 37.608 , lr:0.001000
Epoch 74: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=37.2, loss=3.37]
[75] loss: 3.337, accuracy: 37.853 , lr:0.001000
Epoch 75: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=37.9, loss=3.35]
[76] loss: 3.318, accuracy: 38.105 , lr:0.001000
Epoch 76: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=37.5, loss=3.35]
[77] loss: 3.303, accuracy: 38.282 , lr:0.001000
Epoch 77: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=37.9, loss=3.31]
[78] loss: 3.285, accuracy: 38.523 , lr:0.001000
Epoch 78: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.1, loss=3.3]
[79] loss: 3.267, accuracy: 38.738 , lr:0.001000
Epoch 79: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=38.9, loss=3.28]
[80] loss: 3.250, accuracy: 38.972 , lr:0.001000
Epoch 80: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.6, loss=3.27]
[81] loss: 3.230, accuracy: 39.248 , lr:0.001000
Epoch 81: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=39.1, loss=3.22]
[82] loss: 3.216, accuracy: 39.435 , lr:0.001000
Epoch 82: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=38.8, loss=3.25]
[83] loss: 3.197, accuracy: 39.675 , lr:0.001000
Epoch 83: 100%|██████████████████████████████████████████████████████████| 63/63 [00:38<00:00, 1.62batch/s, acc=39.7, loss=3.2]
[84] loss: 3.180, accuracy: 39.914 , lr:0.001000
Epoch 84: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=39.4, loss=3.2]
[85] loss: 3.165, accuracy: 40.108 , lr:0.001000
Epoch 85: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.87batch/s, acc=40.1, loss=3.17]
[86] loss: 3.152, accuracy: 40.277 , lr:0.001000
Epoch 86: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=39.9, loss=3.18]
[87] loss: 3.135, accuracy: 40.508 , lr:0.001000
Epoch 87: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=40.4, loss=3.14]
[88] loss: 3.118, accuracy: 40.736 , lr:0.001000
Epoch 88: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.84batch/s, acc=40.5, loss=3.14]
[89] loss: 3.104, accuracy: 40.918 , lr:0.001000
Epoch 89: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=40.7, loss=3.11]
[90] loss: 3.093, accuracy: 41.061 , lr:0.001000
Epoch 90: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=40.8, loss=3.1]
[91] loss: 3.074, accuracy: 41.315 , lr:0.001000
Epoch 91: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.5, loss=3.06]
[92] loss: 3.057, accuracy: 41.559 , lr:0.001000
Epoch 92: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.1, loss=3.09]
[93] loss: 3.043, accuracy: 41.745 , lr:0.001000
Epoch 93: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=41.5, loss=3.06]
[94] loss: 3.029, accuracy: 41.924 , lr:0.001000
Epoch 94: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=42, loss=3.03]
[95] loss: 3.015, accuracy: 42.133 , lr:0.001000
Epoch 95: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc=41.6, loss=3.04]
[96] loss: 3.001, accuracy: 42.302 , lr:0.001000
Epoch 96: 100%|██████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.82batch/s, acc=42, loss=3]
[97] loss: 2.988, accuracy: 42.483 , lr:0.001000
Epoch 97: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=42.9, loss=2.96]
[98] loss: 2.972, accuracy: 42.694 , lr:0.001000
Epoch 98: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.86batch/s, acc=42.1, loss=3.01]
[99] loss: 2.964, accuracy: 42.804 , lr:0.001000
Epoch 99: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=42.2, loss=3.01]
[100] loss: 2.953, accuracy: 42.973 , lr:0.001000
Finished Training using %.3f seconds 6896.93013882637
先训练了100轮次,后面应该还能增长,但是不等了
数据初探:
class DictObj(object):def __init__(self, map):self.map = mapdef __getattr__(self, attr):if attr in self.map:return self.map[attr]else:raise AttributeError("No such attribute: " + attr)Config = DictObj({'poem_path':os.path.join(base_dir, "tang.npz"),"tensorboard_path":os.path.join(base_dir, "tensorboard"),"model_save_path":os.path.join(base_dir,"modelDict"),"embedding_dim":100,"hidden_dim":1024,"lr":0.001,"LSTM_layers":2,'batch_size':512,'epochs':500,'dropout':0.2,'ealier_stop':10,'device':torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})
def view_data(poem_path):datas = np.load(poem_path, allow_pickle=True)data = datas['data'] #(57580,125)ix2word = datas['ix2word'].item() # datas['word2ix'].item() 8293word2ix = datas['word2ix'].item() # datas['word2ix'].item() 8293word_data = np.zeros((1,data.shape[1]), dtype = str) # 将所有的0 转化成 ''# 看一下其中一行的数据是什么?row = np.random.randint(0, data.shape[0]) # 随机选一行,左闭右开没问题print(data[row])for i in range(data.shape[1]):word_data[0][i] = ix2word[data[row][i]]print(word_data)view_data(Config.poem_path)
数据处理:
class PoemDataset(Dataset):def __init__(self, poem_path, seq_len):super().__init__()# np 文件的地址self.poem_path = poem_path# 序列长度,48 是认为规定的,也可以是其它值,因为大部分是5言或者7言,加上表达就是 6,或8, 取48确保是整句话self.seq_len = seq_lenself.poem_data, self.ix2word, self.word2ix = self.get_raw_data()self.no_space_data = self.filter_space()print("no_space_data len:", self.no_space_data[0:200])def __len__(self):return len(self.no_space_data)//(self.seq_len)def __getitem__(self, idx):txt = self.no_space_data[idx*self.seq_len:(idx+1)*self.seq_len]label = self.no_space_data[idx*self.seq_len+1:(idx+1)*self.seq_len+1]return torch.LongTensor(txt), torch.LongTensor(label)def filter_space(self):# 7197500 个文本tensor_data = torch.from_numpy(self.poem_data).view(-1)no_space_data = [] for i in range(tensor_data.shape[0]):word_idx = tensor_data[i].item()if word_idx!= 8292:no_space_data.append(word_idx)return no_space_datadef get_raw_data(self):datas = np.load(self.poem_path, allow_pickle=True)data = datas['data']ix2word = datas['ix2word'].item()word2ix = datas['word2ix'].item()return data, ix2word, word2ix
poem_dataset = PoemDataset(Config.poem_path, 96)
[8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 82928292 8292 8292 8292 8292 8292 8292 8291 5428 6933 3469 7066 3465 64078248 7009 82 7435 925 3469 3576 232 786 5272 2296 7066 4807 61036663 2958 2003 2173 28 7066 1987 8061 4299 848 4874 7435 8290]
[['<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<''<' '<' '冬' '月' '内' ',' '无' '叶' '艾' '枝' '枯' '。' '草' '内' '急' '寻' '蛇' '床''子' ',' '烧' '烟' '入' '中' '自' '消' '除' ',' '速' '救' '免' '灾' '虞' '。' '<']]
no_space_data len: [8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817, 648, 7121, 1542, 6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534, 70, 3788, 3823, 7435, 4907, 5567, 201, 2834, 1519, 7066, 782, 782, 2063, 2031, 846, 7435, 8290, 8291, 2309, 2596, 6483, 2260, 7316, 7066, 6332, 5274, 2125, 5029, 7792, 7435, 4186, 8087, 7047, 6622, 6933, 7066, 6134, 3564, 3766, 6920, 6157, 7435, 7086, 4770, 5849, 4776, 4981, 7066, 4857, 2649, 3020, 332, 1727, 7435, 7458, 7294, 3465, 5149, 1671, 7066, 2834, 6000, 3942, 3534, 1534, 7435, 4102, 7460, 758, 3961, 3374, 7066, 7904, 6811, 4449, 2121, 6802, 7435, 6182, 27, 7912, 1756, 7440, 7066, 201, 7909, 8118, 201, 4662, 7435, 7824, 1508, 3154, 152, 5862, 7066, 7976, 6043, 258, 47, 7878, 7435, 8290, 8291, 3495, 70, 7113, 4839, 5237, 7066, 65, 3941, 2031, 2260, 5418, 7435, 411, 6773, 2878, 4686, 482, 7066, 1989, 5617, 4992, 8245, 676, 7435, 4236, 1418, 4915, 7686, 7363, 7066, 5708, 7541, 7440, 5237, 2192, 7435, 3114, 5913, 7989, 3069, 1845, 7066, 7047, 3534, 4921, 6622, 6933, 7435, 1664, 2260, 2003, 4816, 7151, 7066, 5036, 2219, 5849, 4898, 174, 7435, 201, 7228, 222]
因为有空格,啥的,要先吧空格之类的去掉。
def show_dataset():idx,label = poem_dataset[0]for id in idx:print(poem_dataset.ix2word[id.item()], end=' ')print("\n")for la in label:print(poem_dataset.ix2word[la.item()], end=' ')'''<START> 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 云'''
# 构建模型
class PoemModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)self.dropout = nn.Dropout(dropout)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, input, hidden=None):embeds = self.embedding(input)batch_size,seq_len,embedding_dim = embeds.shapeif hidden is None:h0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)c0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)else:h0,c0 = hiddenoutput, hidden = self.lstm(embeds, (h0, c0))# output = torch.tanh(self.dropout(self.fc1(output)))output = self.fc(output)return output, hiddenvocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
input_data, label_data = next(iter(dataloader))
print(input_data.shape, label_data.shape)
output, hidden = model(input_data.to(Config.device))
# output.shape torch.Size([1024, 96, 8293]) hidden[0].shape torch.Size([3, 1024, 1024]) hidden[1].shape torch.Size([3, 1024, 1024]) label_data.shape torch.Size([1024, 96])
a = 1def accuracy(output, label_data):pred = output.argmax(dim=2)correct = (pred == label_data).sum().item()total = label_data.numel()return correct / total * 100
# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)
def train(model, dataloader, criterion, optimizer, scheduler, epochs):if not os.path.exists(Config.model_save_path):os.makedirs(Config.model_save_path)best_acc = 0.0early_stop = 0start_time = time.time()for epoch in range(epochs):model.train()running_loss = 0.0running_acc = 0.0last_acc = 0.0with tqdm(dataloader, unit="batch") as tepoch:for input_data, label_data in tepoch:tepoch.set_description(f"Epoch {epoch}")input_data, label_data = input_data.to(Config.device), label_data.to(Config.device)optimizer.zero_grad()output, hidden = model(input_data)current_acc = accuracy(output, label_data)running_acc += current_accloss = criterion(output.view(-1, vocab_size), label_data.view(-1))loss.backward()optimizer.step()running_loss += loss.item()tepoch.set_postfix(loss=loss.item(), acc=current_acc)scheduler.step()last_acc = running_acc / len(dataloader)if last_acc > best_acc:best_acc = last_acctorch.save(model.state_dict(), os.path.join(Config.model_save_path, "best_model.pth"))else:early_stop += 1torch.save(model.state_dict(), os.path.join(Config.model_save_path, "last_model.pth"))print('[%d] loss: %.3f, accuracy: %.3f , lr:%.6f' % (epoch + 1, running_loss / len(dataloader), last_acc,scheduler.get_last_lr()[0]))if early_stop >= Config.ealier_stop:print("Early Stop")print("Best Accuracy: %.3f" % best_acc)breakprint('Finished Training using %.3f seconds', time.time() - start_time)
train(model, dataloader, criterion, optimizer, scheduler, Config.epochs)
模型构建以及训练如上,
现在看 500轮次,50个忍耐度的效果比较好
Epoch 145: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.81batch/s, acc=61.6, loss=1.62]
[146] loss: 1.580, accuracy: 62.374 , lr:0.001000
Epoch 146: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.80batch/s, acc=62.5, loss=1.57]
[147] loss: 1.581, accuracy: 62.360 , lr:0.001000
Epoch 147: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.83batch/s, acc=61.9, loss=1.61]
[148] loss: 1.582, accuracy: 62.324 , lr:0.001000
Early Stop
Best Accuracy: 62.397
Finished Training using %.3f seconds 1205.000694513321
使用效果
不满意的地方,写死96=seq_len 是不对的。
应该是 配合 padding 使用,并mask padding来指导损失 @todo, 下一篇文章我会搞定!
import torch
from train03 import Config
from train03 import PoemModelfrom train03 import PoemDataset
import ospoem_dataset = PoemDataset(Config.poem_path, 96)vocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
model.load_state_dict(torch.load(os.path.join(Config.model_save_path, "best_model.pth")))def generate(model, start_words, ix2word, word2ix, device):results = list(start_words)start_words_len = len(start_words)# 第一个词语是<START>input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()# 最开始的隐状态初始为0矩阵# torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim)hidden = torch.zeros((2,Config.LSTM_layers * 1, 1, Config.hidden_dim), dtype=torch.float32).to(Config.device)input = input.to(Config.device)hidden = hidden.to(Config.device)model.eval()with torch.no_grad():for i in range(48):output, hidden = model(input, hidden)# 如果在给定的句首中,input为句首中的下一个字if i < start_words_len:w = results[i]input = input.data.new([word2ix[w]]).view(1, 1)else:top_index = output.data[0].topk(1)[1][0].item()w = ix2word[top_index]results.append(w)input = input.data.new([top_index]).view(1, 1)if w == '<EOP>':del results[-1]breakreturn results
雨 余 虚 馆 竹 阴 清 , 独 坐 寒 窗 昼 未 醒 。 云 布 远 村 红 叶 返 , 水 深 秋 竹 翠 梢 寒 。 泉 声 入 阁 慙 嘉 石 , 山 色 题 诗 好 赋 诗 。
但是我有一点不太理解。
他是输入一个字,输出一个字,这一点好像不妥。不应该是 输入一个 生成1个, 然后输入两个,生成1个,然后输入3个生成1个么。。。 大神请指教一下吧。
相关文章:

【深度学习】用LSTM写诗,生成式的方式写诗系列之一
Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00, 8.85batch/s, acc18.5, loss5.8] [5] loss: 5.828, accuracy: 18.389 , lr:0.001000 Epoch 5: 100%|███…...

HomeAssistant自定义组件学习-【二】
#要说的话# 前面把中盛科技的控制器组件写完了。稍稍熟悉了一些HA,现在准备写窗帘控制组件,构想的东西会比较多,估计有些难度,过程会比较长,边写边记录吧! #设备和场景环境# 使用的是Novo的电机…...

如何看待AI技术的应用前景?
文章目录 如何看待AI技术的应用前景引言AI技术的现状1. AI的定义与分类2. 当前AI技术的应用领域 AI技术的应用前景1. 经济效益2. 社会影响3. 技术进步 AI技术应用面临的挑战1. 数据隐私与安全2. 可解释性与信任3. 技能短缺与就业影响 AI技术的未来发展方向1. 人工智能的伦理与法…...

Unity中的屏幕坐标系
获得视口宽高 拖动视口会改变屏幕宽高数值 MousePosition 屏幕坐标系的原点在左下角,MousePosition返回Z为0也就是纵深为0的Vector3 但是如果鼠标超出屏幕范围不会做限制,所以可能出现负数或者大于屏幕宽高的情况,做鼠标拖拽物体时需要注…...

标题点击可跳转网页
要实现点击标题跳转到网页的功能,你可以在Vue组件中使用<a>标签(锚点标签)并设置href属性为网页的URL。如果你希望使用uni-app的特性来控制页面跳转,可以使用uni.navigateTo方法(这适用于uni-app环境,…...

易语言模拟真人动态生成鼠标滑动路径
一.简介 鼠标轨迹算法是一种模拟人类鼠标操作的程序,它能够模拟出自然而真实的鼠标移动路径。 鼠标轨迹算法的底层实现采用C/C语言,原因在于C/C提供了高性能的执行能力和直接访问操作系统底层资源的能力。 鼠标轨迹算法具有以下优势: 模拟…...

Linux:生态与软件安装
文章目录 前言一、Linux下安装软件的方案二、包管理器是什么?三、生态问题相关的理解1. 什么操作系统是好的操作系统?2. 什么是生态?3. 软件包是谁写的?这些工程师为什么要写?钱的问题怎么解决? 四、我的服务器怎么知…...

R 语言与其他编程语言的区别
R 语言与其他编程语言的区别 R 语言作为一种专门用于统计计算和图形的编程语言,与其他编程语言相比有一些独特的特点和区别。本文将详细介绍这些区别,帮助你更好地理解 R 语言的优势和适用场景。 1. 专为统计和数据分析设计 统计功能 内置统计函数&…...

RC低通滤波器Bode图分析(传递函数零极点)
RC低通滤波器 我们使得R1K,C1uF;电容C的阻抗为Xc; 传递函数 H ( s ) u o u i X C X C R 1 s C 1 s C R 1 1 s R C (其中 s j ω ) H(s)\frac{u_{o} }{u_{i} } \frac{X_{C} }{X_{C}R} \frac{\frac{1}{sC} }{\…...

基于深度学习的网络入侵检测
基于深度学习的网络入侵检测是一种利用深度学习技术对网络流量进行实时监测与分析的方法,旨在识别并防范网络攻击和恶意活动。随着网络环境日益复杂,传统的入侵检测系统(IDS)在面对不断变化的攻击模式时,往往难以保持高…...

《构建一个具备从后端数据库获取数据并再前端显示的内容页面:前后端实现解析》
一、前端页面:布局与功能 1. 页面结构 我们先来看前端页面的 HTML 结构,它主要由以下几个部分组成: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewp…...

Rust 力扣 - 59. 螺旋矩阵 II
文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 使用一个全局变量current记录当前遍历到的元素的值 我们只需要一圈一圈的从外向内遍历矩阵,每一圈遍历顺序为上边、右边、下边、左边,每遍历完一个元素后current 我们需要注意的是如果上…...

2.4w字 —TS入门教程
目录 1. 什么是TS 2. TS基本使用 3 TS基础语法 3.1 基础类型约束 3.11 string,number,boolean, null和undefined 3.12 any 3.13 unknown 3.14 void 3.15 数组 3.16 对象 3.2 函数的约束 3.21 普通写法 3.22 函数表达式 3.22 可选…...

java: 未结束的字符文字 报错及解决:将编码全部改为UTF-8或者GBK
报错: 解决: 将编码都改成UTF-8或者GBK:...

Android平台RTSP转RTMP推送之采集麦克风音频转发
技术背景 RTSP转RTMP推送,好多开发者第一想到的是采用ffmpeg命令行的形式,如果对ffmpeg比较熟,而且产品不要额外的定制和更高阶的要求,未尝不可,如果对产品稳定性、时延、断网重连等有更高的技术诉求,比较…...

认证鉴权框架之—sa-token
一、概述 Satoken 是一个 Java 实现的权限认证框架,它主要用于 Web 应用程序的权限控制。Satoken 提供了丰富的功能来简化权限管理的过程,使得开发者可以更加专注于业务逻辑的开发。 二、逻辑流程 1、登录认证 (1)、创建token …...

Spring源码(十一):Spring MVC之DispatchServlet
本篇重点在于分析Spring MVC与Servlet标准的整合,下节将详细讨论Spring MVC的启动/加载流程、处理请求的具体流程。 一、介绍 Spring框架提供了构建Web应用程序的全功能MVC模块。通过策略接口 ,Spring框架是高度可配置的,而且支持多种视图技…...

gitbash简单操作
https://blog.csdn.net/qq_42363495/article/details/104878170 工作区(空间)--暂存区--本地仓库--远程仓库 方法一:创建一个新的分支master,且远程库里没有该分支 只要将.gitignore文件放在文件夹下就可以,.gitignore是文本文档形式的文件…...

pnpm install安装element-plus的版本跟package.json指定的版本不一样
pnpm安装的版本不同于package.json中指定的版本可能是由于以下几种情况导致的: 依赖项冲突:当项目依赖的不同模块或库之间存在版本冲突时,pnpm可能会安装与package.json中指定的版本不同的版本。这可能是因为其他依赖项指定了不同的版本&…...

Java线程池的核心内容详解
文章内容已经收录在《面试进阶之路》,从原理出发,直击面试难点,实现更高维度的降维打击! 目录 文章目录 目录Java线程池的核心内容详解线程池的优势什么场景下要用到线程池呢?线程池中重要的参数【掌握】新加入一个任…...

学习笔记——三小时玩转JQuery
也可以使用在线版,不过在线版需要有网络,网不好的情况下加载也不好 取值的时候也是只会取到有样式的纯文本,不会取到标签,会取到标签效果 prepend和append这两个方法用的比较多,before和affter用的比较少 想要把代码写…...

word试题转excel(最简单的办法,无格式要求)
分享早下班的终极秘诀~ 今天本来是个愉快的周五,心里想着周末的聚会和各种安排,然而突然一个加急任务砸了过来——要求在下周一提交一份精细整理的Excel表格! 打开Word文件一看,成堆的试题内容需要整理到Excel里。看着满屏的题目…...

基于web的中小学成绩管理系统的设计与实现
目录 第一章 研究背景与意义 1.1 研究背景 1.2 研究意义 1.3 研究目的 第二章 关于系统的设计 2.1系统总体架构设计 2.2功能模块设计 2.3数据存储与管理 第三章 系统功能介绍 3.1成绩录入及发布 3.2班级管理和学生管理 3.3成绩分析结果展示 3.4用户反馈与改进 …...

Conmi的正确答案——在Kibana中进入Elasticsearch的索引管理页面
Elasticsearch版本:7.17.25 Kibana版本:7.17.25 注:索引即类似mysql的表。 0、进入首页 1、未创建任何“索引模式”时: 1.1、点击左边的三横菜单; 1.2、点击“Discover”,进入“发现”页面; 2…...

【JavaEE】【多线程】进阶知识
目录 一、常见的锁策略1.1 悲观锁 vs 乐观锁1.2 重量级锁 vs 轻量级锁1.3 挂起等待锁 vs 自旋锁1.4 普通互斥锁 vs 读写锁1.5 可重入锁 vs 不可重入锁1.6 不公平锁 vs 公平锁 二、synchronized特性2.1 synchronized的锁策略2.2 synchronized加锁过程2.3 其它优化措施 三、CAS3.…...

LeetCode100之三数之和(15)--Java
1.问题描述 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。 注意 答案中不可以包含重复的三元组 示例1 输入&…...

并发编程三大特性--可见性和有序性
可见性: 什么是可见性: 可见性是指在数据在收到一个线程的修改时,其他的线程也可以得知并获取修改后的值的属性。这是并发编程的三大特性之一。 为了提高cpu的利用率,cpu在获取数据时,不是直接在主内存读取数据&…...

Android 使用ninja加速编译的方法
ninja的简介 随着Android版本的更迭,makefile体系逐渐增多,导致make单编模块的时间越来越长,每次都需要半个小时甚至更长时间,其原因为每次make都会重新加载所有mk文件,再生成ninja编译,此完整过程十分耗时…...

《Java 实现选择排序:原理剖析与代码详解》
目录 一、引言 二、选择排序原理 三、代码分析 1. 代码整体结构 2. main方法 3. sort方法(选择排序核心逻辑) 四、测试结果 一、引言 排序算法在计算机科学领域中是非常重要的一部分,它能够帮助我们将无序的数据按照特定的顺序进行排列…...

数据结构之双链表——考研笔记
文章目录 一.单链表VS双链表二.创建双链表(带头结点)三.双链表的插入四.双链表删除五.销毁双链表六.双链表遍历七. 循环链表八.静态链表1.用代码定义一个静态链表 一.单链表VS双链表 单链表中只包含指向它后继结点的指针,所以给定一个结点p找…...