SARIMA模型

使用Python实现SARIMA模型,根据已有季节性数据预测未来数据


日期

基本介绍

SARIMA(Seasonal Autoregressive Integrated Moving Average)模型是一种强大的时间序列分析工具,主要用于处理和预测具有季节性特征的时间序列数据。SARIMA 模型的灵活性体现在其参数设置上。用户可以根据数据的特性选择合适的自回归(AR)、差分(I)和移动平均(MA)项的数量,以及季节性成分的参数。这种灵活性使得 SARIMA 能够适应多种不同类型的时间序列数据。

SARIMA模型参数

组成部分:

Seasonal(S):季节性成分,表示数据中存在周期性波动。

Autoregressive (AR):自回归成分,表示当前值与其过去值之间的关系。

Integrated (I):差分成分,用于使非平稳时间序列变为平稳序列。

Moving Average (MA):移动平均成分,表示当前值与过去误差之间的关系。

SARIMA 模型通常用以下符号表示:SARIMA(p,d,q)(P,D,Q,s)

p:自回归项的数量

d:差分次数

q:移动平均项的数量

P:季节性自回归项的数量

D:季节性差分次数

Q:季节性移动平均项的数量

s:季节性周期的长度

SARIMA模型预测

在以下的代码中,我们使用了pmdarima包中的auto_arima函数进行自动寻找参数,它通过对给定时间序列数据进行分析,使用信息准则(如 AIC、BIC)来评估不同模型的拟合优度,并选择具有最低信息准则值的模型,确定模型的参数(p, d, q, P, D, Q, s),根据选出的参数预测未来的数据趋势,最后进行实际值与预测值的绘图。

以下代码所采用的示例文件是某地某种疾病的16年的月度发病数据,来预测未来4年的发病数据。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.seasonal import seasonal_decompose
from sklearn.model_selection import train_test_split
from pmdarima import auto_arima
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
 
# 读取CSV文件
df = pd.read_csv('practice.csv')
# 确保 'time' 列转换为日期类型
df['time'] = pd.to_datetime(df['time'], format='%Y/%m/%d')
df['time'] = df['time'].dt.to_period('M')

# 切分训练集和测试集
train_data = df['cases'].iloc[:144]
test_data = df['cases'].iloc[144:191]

# ADF 测试
adf_result = adfuller(train_data)
print('ADF Statistic:', adf_result[0])
print('p-value:', adf_result[1])
print('Critical Values:')
for key, value in adf_result[4].items():
    print(f'   {key}: {value}')

# 检查 p-value 是否小于显著性水平
if adf_result[1] < 0.05:
    print("数据是平稳的")
else:
    print("数据是非平稳的,但不依赖模型进行差分")

    
# 使用auto_arima自动寻找最佳参数
model = auto_arima(train_data, seasonal=True, m=12, stepwise=True,
                   suppress_warnings=True, error_action="ignore", trace=True,
                   start_p=1, start_q=1, start_P=1, start_Q=1)

# 使用最佳参数拟合模型
sarima_result = model.fit(train_data)

# 使用拟合的模型进行预测
forecast = sarima_result.predict(n_periods=len(test_data))
train_forecast = sarima_result.predict_in_sample() 
future_forecast = sarima_result.predict(n_periods=48)

# 计算R方值
train_r2 = r2_score(train_data, train_forecast)
test_r2 = r2_score(test_data, forecast)
train_MASE = np.sqrt(mean_squared_error(train_data, train_forecast))
test_MASE = np.sqrt(mean_squared_error(test_data, forecast))

# 输出R方值
print(f"训练集 R2 Score: {train_r2:.4f}")
print(f"测试集 R2 Score: {test_r2:.4f}")
print(f"训练集 MASE: {train_MASE:.4f}")
print(f"测试集 MASE: {test_MASE:.4f}")

# 创建未来日期的 DataFrame
last_period = df['time'].iloc[-1].to_timestamp()
future_dates = pd.date_range(start=last_period + pd.Timedelta(days=1), periods=48, freq='ME')
future_dates_df = pd.DataFrame(future_dates, columns=['time'])

ADF Statistic: 0.5540951020672727
p-value: 0.9864277674889156
Critical Values:
   1%: -3.4816817173418295
   5%: -2.8840418343195267
   10%: -2.578770059171598
数据是非平稳的,但不依赖模型进行差分
Performing stepwise search to minimize aic
 ARIMA(1,1,1)(1,0,1)[12] intercept   : AIC=inf, Time=1.74 sec
 ARIMA(0,1,0)(0,0,0)[12] intercept   : AIC=1398.881, Time=0.02 sec
 ARIMA(1,1,0)(1,0,0)[12] intercept   : AIC=1374.684, Time=0.29 sec
 ARIMA(0,1,1)(0,0,1)[12] intercept   : AIC=1386.400, Time=0.31 sec
 ARIMA(0,1,0)(0,0,0)[12]             : AIC=1397.141, Time=0.02 sec
 ARIMA(1,1,0)(0,0,0)[12] intercept   : AIC=1398.306, Time=0.09 sec
 ARIMA(1,1,0)(2,0,0)[12] intercept   : AIC=1367.389, Time=0.88 sec
 ARIMA(1,1,0)(2,0,1)[12] intercept   : AIC=inf, Time=3.19 sec
 ARIMA(1,1,0)(1,0,1)[12] intercept   : AIC=inf, Time=0.94 sec
 ARIMA(0,1,0)(2,0,0)[12] intercept   : AIC=1391.574, Time=0.61 sec
 ARIMA(2,1,0)(2,0,0)[12] intercept   : AIC=1362.015, Time=1.17 sec
 ARIMA(2,1,0)(1,0,0)[12] intercept   : AIC=1374.310, Time=0.48 sec
 ARIMA(2,1,0)(2,0,1)[12] intercept   : AIC=inf, Time=2.58 sec
 ARIMA(2,1,0)(1,0,1)[12] intercept   : AIC=inf, Time=1.62 sec
 ARIMA(3,1,0)(2,0,0)[12] intercept   : AIC=1363.447, Time=1.49 sec
 ARIMA(2,1,1)(2,0,0)[12] intercept   : AIC=inf, Time=2.64 sec
 ARIMA(1,1,1)(2,0,0)[12] intercept   : AIC=inf, Time=2.26 sec
 ARIMA(3,1,1)(2,0,0)[12] intercept   : AIC=inf, Time=2.83 sec
 ARIMA(2,1,0)(2,0,0)[12]             : AIC=1360.099, Time=0.44 sec
 ARIMA(2,1,0)(1,0,0)[12]             : AIC=1372.452, Time=0.16 sec
 ARIMA(2,1,0)(2,0,1)[12]             : AIC=inf, Time=2.12 sec
 ARIMA(2,1,0)(1,0,1)[12]             : AIC=inf, Time=1.00 sec
 ARIMA(1,1,0)(2,0,0)[12]             : AIC=1365.481, Time=0.32 sec
 ARIMA(3,1,0)(2,0,0)[12]             : AIC=1361.530, Time=0.59 sec
 ARIMA(2,1,1)(2,0,0)[12]             : AIC=1338.023, Time=1.20 sec
 ARIMA(2,1,1)(1,0,0)[12]             : AIC=1354.069, Time=0.36 sec
 ARIMA(2,1,1)(2,0,1)[12]             : AIC=inf, Time=2.55 sec
 ARIMA(2,1,1)(1,0,1)[12]             : AIC=inf, Time=1.05 sec
 ARIMA(1,1,1)(2,0,0)[12]             : AIC=1336.395, Time=0.82 sec
 ARIMA(1,1,1)(1,0,0)[12]             : AIC=1353.806, Time=0.31 sec
 ARIMA(1,1,1)(2,0,1)[12]             : AIC=inf, Time=2.31 sec
 ARIMA(1,1,1)(1,0,1)[12]             : AIC=inf, Time=0.78 sec
 ARIMA(0,1,1)(2,0,0)[12]             : AIC=1334.406, Time=0.48 sec
 ARIMA(0,1,1)(1,0,0)[12]             : AIC=1352.555, Time=0.15 sec
 ARIMA(0,1,1)(2,0,1)[12]             : AIC=inf, Time=1.45 sec
 ARIMA(0,1,1)(1,0,1)[12]             : AIC=inf, Time=0.49 sec
 ARIMA(0,1,0)(2,0,0)[12]             : AIC=1389.700, Time=0.20 sec
 ARIMA(0,1,2)(2,0,0)[12]             : AIC=1336.396, Time=0.63 sec
 ARIMA(1,1,2)(2,0,0)[12]             : AIC=1331.492, Time=1.61 sec
 ARIMA(1,1,2)(1,0,0)[12]             : AIC=1350.035, Time=0.44 sec
 ARIMA(1,1,2)(2,0,1)[12]             : AIC=1316.748, Time=2.65 sec
 ARIMA(1,1,2)(1,0,1)[12]             : AIC=inf, Time=0.83 sec
 ARIMA(1,1,2)(2,0,2)[12]             : AIC=inf, Time=4.20 sec
 ARIMA(1,1,2)(1,0,2)[12]             : AIC=1316.606, Time=3.24 sec
 ARIMA(1,1,2)(0,0,2)[12]             : AIC=1383.563, Time=1.10 sec
 ARIMA(1,1,2)(0,0,1)[12]             : AIC=1385.024, Time=0.33 sec
 ARIMA(0,1,2)(1,0,2)[12]             : AIC=inf, Time=2.46 sec
 ARIMA(1,1,1)(1,0,2)[12]             : AIC=inf, Time=2.46 sec
 ARIMA(2,1,2)(1,0,2)[12]             : AIC=1317.407, Time=6.18 sec
 ARIMA(1,1,3)(1,0,2)[12]             : AIC=1318.564, Time=2.14 sec
 ARIMA(0,1,1)(1,0,2)[12]             : AIC=inf, Time=1.72 sec
 ARIMA(0,1,3)(1,0,2)[12]             : AIC=inf, Time=2.91 sec
 ARIMA(2,1,1)(1,0,2)[12]             : AIC=inf, Time=3.07 sec
 ARIMA(2,1,3)(1,0,2)[12]             : AIC=inf, Time=2.81 sec
 ARIMA(1,1,2)(1,0,2)[12] intercept   : AIC=inf, Time=3.25 sec

Best model:  ARIMA(1,1,2)(1,0,2)[12]          
Total fit time: 82.083 seconds
训练集 R2 Score: 0.8863
测试集 R2 Score: 0.7486
训练集 MASE: 24.4406
测试集 MASE: 19.1417
# 绘制实际数据与拟合数据的对比图
plt.figure(figsize=(10, 6))

# 设置全局字体为新罗马
plt.rcParams['font.family'] = 'Times New Roman'

# 绘制原始数据
plt.plot(df['cases'].values, label='Original Data', color='grey', alpha=0.7)

# 绘制训练集预测数据
plt.plot(range(len(train_data)), train_forecast, label='Train Forecast', color='blue', alpha=0.5)

# 绘制测试集预测数据
plt.plot(range(len(train_data), len(train_data) + len(test_data)), forecast, label='Test Forecast', color='red', linestyle='--')

# 绘制未来预测数据
plt.plot(range(len(train_data) + len(test_data), len(train_data) + len(test_data) + 48), future_forecast, label='Future Forecast', color='purple', linestyle=':')

# 设置 x 轴范围
plt.xlim(0, len(df) + 48) 
# 设置 x 轴的日期标签
date_range = pd.date_range(start='2008-01-01', periods=len(df)+48, freq='ME')  
ticks = range(0, len(df) + 48, 12)  
labels = date_range.strftime('%Y-%m')[ticks]  

plt.xticks(ticks=ticks, labels=labels, rotation=45) 

# 添加图例和标签
plt.legend()
plt.xlabel('Time')
plt.ylabel('Cases')
plt.title('SARIMA Model Result')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig("SARIMA.png")  
plt.show()

png

图1 SARIMA模型预测结果

图1展示了SARIMA模型的预测结果,训练集和测试集的拟合优度R方值都达到了0.70以上,可以认为是模型预测效果较好,对未来四年的预测也较为可信。

若R方值较低,可能是数据不太适合用于SARIMA模型进行预测,可以尝试其他数据驱动模型如ARIMA、季节性分解模型(STL)、GARCH模型或机器学习模型(如随机森林、XGBoost等),以提高预测能力。另外,SARIMA模型在最前期的数据预测上效果较差,计算拟合优度会较低,若数据后期预测效果较好,也可以继续使用该模型。

# 计算残差
residuals = train_data - train_forecast

# 绘制残差图
plt.figure(figsize=(10, 6))
plt.subplot(211)
plt.plot(residuals, label='Residuals', color='blue')
plt.axhline(0, color='red', linestyle='--')
plt.title('Residuals of the Model')
plt.legend()

# 绘制残差的自相关图
plt.subplot(212)
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
plot_acf(residuals, lags=20, ax=plt.gca())
plt.title('ACF of Residuals')

plt.tight_layout()
plt.show()

# 正态性检验
from scipy import stats
shapiro_test = stats.shapiro(residuals)
print(f'Shapiro-Wilk Test Statistic: {shapiro_test.statistic}, p-value: {shapiro_test.pvalue}')

print(f"残差正态性检验值: {shapiro_test.pvalue.4f}")
if shapiro_test.pvalue > 0.05:
    print("残差符合正态分布")
else:
    print("残差不符合正态分布")

png

图2 SARIMA模型预测的残差图与残差自相关图

Shapiro-Wilk Test Statistic: 0.9771958732822187, p-value: 0.016648007947939247

残差不符合正态分布

从残差图和残差自相关图上可以看到,第一年(前12个数据)的残差较大,但整体模型的预测能力和其他残差的分布良好,可以考虑继续使用该模型。

结果分析

通过仿真结果,可以观察到随着时间的推移,该病的病例数呈现明显的季节性特征,训练集的原始数据和预测数据较为符合,预测效果较为可信。

样本量较少、数据噪声较大、数据非线性等原因均可能影响模型的预测结果,可以从以下方面改进:

  1. 对数据进行适当的转换(如对数变换、差分等),以消除趋势和季节性,提高数据的平稳性;
  2. 识别并处理数据中的异常值,以减少对模型拟合的影响;
  3. 考虑使用模型集成方法(如Bagging、Boosting等),结合多个模型的预测结果,以提高整体预测性能。

和越
犯困嫌疑人