# elon_py/api/main.py
# 2025-02-25 17:56:23 +08:00
#
# 233 lines
# 8.6 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from datetime import datetime
from urllib.parse import quote_plus

import dash
import pandas as pd
import plotly.graph_objs as go
import pytz
from dash import dcc, html
from dash.dependencies import Input, Output
from sqlalchemy import create_engine
# Database connection settings. Each value can be overridden with an
# environment variable so deployments do not need to edit the source.
# NOTE(review): a real host and password are committed below as fallbacks;
# they should be rotated and the fallbacks removed so the secret lives only
# in the environment.
DB_CONFIG = {
    'host': os.environ.get('ELON_DB_HOST', '8.155.23.172'),
    'port': int(os.environ.get('ELON_DB_PORT', '3306')),
    'user': os.environ.get('ELON_DB_USER', 'root2'),
    'password': os.environ.get('ELON_DB_PASSWORD', 'tG0f6PVYh18le41BCb'),
    'database': os.environ.get('ELON_DB_NAME', 'elonX'),
}
TABLE_NAME = 'elon_tweets'
# Build the SQLAlchemy connection URL. User and password are percent-encoded
# so characters such as '@', ':' or '/' in them cannot corrupt the URI.
db_uri = (
    f"mysql+pymysql://{quote_plus(DB_CONFIG['user'])}:{quote_plus(DB_CONFIG['password'])}"
    f"@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
)
engine = create_engine(db_uri)
# Load every tweet timestamp from the database.
df = pd.read_sql(f'SELECT timestamp FROM {TABLE_NAME}', con=engine)
# Time zones. The analysis is done in US Eastern time; Pacific and Central are
# used by the daily view to place 2 AM marker lines.
eastern = pytz.timezone('America/New_York')  # EST
pacific = pytz.timezone('America/Los_Angeles')  # PST
central = pytz.timezone('America/Chicago')  # CST
# Timestamps are read as epoch seconds, treated as UTC, then converted to Eastern.
df['datetime'] = pd.to_datetime(df['timestamp'], unit='s')
est_times = df['datetime'].dt.tz_localize('UTC').dt.tz_convert(eastern)
df['datetime_est'] = est_times
df['date'] = est_times.dt.date
df['minute_of_day'] = est_times.dt.hour * 60 + est_times.dt.minute
# Tweet counts per (Eastern date, minute of day); only minutes with tweets appear.
agg_df = (
    df.groupby(['date', 'minute_of_day'])
    .size()
    .reset_index(name='tweet_count')
)
# Every date that has data, ascending, for the date selectors.
all_dates = sorted(agg_df['date'].unique())
default_dates = all_dates[-4:]  # the four most recent dates are pre-selected
# Initialize the Dash application; callbacks below are registered on it.
app = dash.Dash(__name__)
# Bucket-width choices (in minutes) shared by both views' interval dropdowns.
interval_options = [
    {'label': f'{minutes}分钟', 'value': minutes}
    for minutes in (1, 5, 10, 30, 60)
]
# Dash page layout: a title plus two tabs (single-day view, multi-day comparison).
_latest_date = max(all_dates)

# Tab 1: single-day line chart with a date picker and interval dropdown.
_daily_tab = dcc.Tab(label='Daily View', value='daily-view', children=[
    dcc.DatePickerSingle(
        id='date-picker',
        min_date_allowed=min(all_dates),
        max_date_allowed=_latest_date,
        initial_visible_month=_latest_date,
        date=_latest_date,
    ),
    dcc.Dropdown(
        id='daily-interval-picker',
        options=interval_options,
        value=10,  # default bucket: 10 minutes
        style={'width': '50%'},
    ),
    # Single-day total.
    html.Div(id='daily-tweet-summary', style={'fontSize': 20, 'margin': '10px'}),
    dcc.Graph(id='daily-tweet-graph'),
])

# Tab 2: one line per selected date, overlaid for comparison.
_multi_tab = dcc.Tab(label='Multi-Day View', value='multi-day-view', children=[
    dcc.Checklist(
        id='multi-date-picker',
        options=[{'label': str(day), 'value': str(day)} for day in all_dates],
        value=[str(day) for day in default_dates],
        style={'height': '200px', 'overflow': 'auto'},
    ),
    dcc.Dropdown(
        id='multi-interval-picker',
        options=interval_options,
        value=10,  # default bucket: 10 minutes
        style={'width': '50%'},
    ),
    html.Div(id='multi-day-warning', style={'color': 'red'}),
    # Total across the selected dates.
    html.Div(id='multi-tweet-summary', style={'fontSize': 20, 'margin': '10px'}),
    dcc.Graph(id='multi-tweet-graph'),
])

app.layout = html.Div([
    html.H1("Elon Musk 发帖时间分析 (EST)"),
    dcc.Tabs(id='tabs', value='daily-view', children=[_daily_tab, _multi_tab]),
])
# Aggregation helper: bucket counts into fixed-width intervals and zero-fill gaps.
def aggregate_data(data, interval):
    """Bucket tweets into ``interval``-minute slots per date, filling empty slots with 0.

    Args:
        data: DataFrame with ``date`` and ``minute_of_day`` columns. If it has a
            ``tweet_count`` column (as rows of ``agg_df`` do), each row
            contributes that many tweets; otherwise each row counts as one tweet.
        interval: Bucket width in minutes; must be positive.

    Returns:
        DataFrame with columns ``interval_group`` (bucket start minute),
        ``tweet_count`` and ``date``, holding one row per bucket in [0, 1440)
        for every distinct date in ``data``, dates in first-appearance order.
        When ``data`` is empty, an empty DataFrame with those columns is returned.

    Raises:
        ValueError: If ``interval`` is not positive.
    """
    if interval <= 0:
        raise ValueError(f"interval must be a positive number of minutes, got {interval!r}")
    if data.empty:
        # pd.concat([]) would raise; hand back an empty frame with the usual columns.
        return pd.DataFrame({
            'interval_group': pd.Series(dtype='int64'),
            'tweet_count': pd.Series(dtype='float64'),
            'date': pd.Series(dtype='object'),
        })
    all_slots = pd.DataFrame({'interval_group': range(0, 1440, interval)})
    binned = pd.DataFrame({
        'date': data['date'],
        'interval_group': (data['minute_of_day'] // interval) * interval,
        # Sum pre-aggregated per-minute counts when present; merely counting
        # rows would report active minutes instead of tweets.
        'tweet_count': data['tweet_count'] if 'tweet_count' in data.columns else 1,
    })
    result = []
    for date in binned['date'].unique():
        day = binned[binned['date'] == date]
        counts = day.groupby('interval_group', as_index=False)['tweet_count'].sum()
        complete = all_slots.merge(counts, on='interval_group', how='left').fillna({'tweet_count': 0})
        complete['date'] = date
        result.append(complete)
    return pd.concat(result, ignore_index=True)
# X-axis tick helper (Eastern-time HH:MM labels).
def generate_xticks(interval):
    """Return ``(positions, labels)`` for ticks every ``interval`` minutes across one day.

    ``positions`` are minute-of-day values starting at 0 and below 1440;
    ``labels`` are the matching zero-padded ``HH:MM`` strings.
    """
    positions = list(range(0, 1440, interval))
    labels = []
    for minute in positions:
        hours, mins = divmod(minute, 60)
        labels.append(f"{hours:02d}:{mins:02d}")
    return positions, labels
# Callback 1: refresh the Daily View chart and its summary line.
def _eastern_minute_of_day(tz, day, hour):
    """Return the Eastern-time minute of day of ``hour``:00 local time in ``tz`` on ``day``.

    ``tz`` is a pytz time zone; ``localize`` applies the UTC offset (standard
    or daylight time) in effect on ``day``.
    """
    local_time = tz.localize(datetime(day.year, day.month, day.day, hour))
    eastern_time = local_time.astimezone(eastern)
    return eastern_time.hour * 60 + eastern_time.minute


@app.callback(
    [Output('daily-tweet-graph', 'figure'),
     Output('daily-tweet-summary', 'children')],
    [Input('date-picker', 'date'),
     Input('daily-interval-picker', 'value'),
     Input('tabs', 'value')]
)
def update_daily_graph(selected_date, interval, tab):
    """Build the single-day tweet-frequency chart for the Daily View tab.

    Args:
        selected_date: Date from the date picker, either a ``datetime.date`` or
            a string beginning with ``YYYY-MM-DD``.
        interval: Bucket width in minutes from the dropdown; ``None`` when the
            user clears the dropdown.
        tab: Value of the currently active tab.

    Returns:
        ``(figure, summary_text)``. An empty figure and empty text are returned
        when the Daily View tab is inactive or no date is selected.
    """
    if tab != 'daily-view' or not selected_date:
        return go.Figure(), ""
    if not interval:
        interval = 10  # cleared dropdown: fall back to the layout's default
    if isinstance(selected_date, str):
        selected_date = datetime.strptime(selected_date[:10], '%Y-%m-%d').date()
    day_data = agg_df[agg_df['date'] == selected_date]
    if day_data.empty:
        # No tweets that day: one zero-count row still yields a full, flat series.
        day_data = pd.DataFrame({'date': [selected_date], 'minute_of_day': [0], 'tweet_count': [0]})
    tweet_count_total = int(day_data['tweet_count'].sum())
    agg_data = aggregate_data(day_data, interval)
    # Fine-grained intervals get hourly ticks to keep the axis readable.
    xticks, xtick_labels = generate_xticks(interval if interval >= 30 else 60)
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=agg_data['interval_group'],
        y=agg_data['tweet_count'],
        mode='lines',
        name='推文数量',
        line=dict(color='blue')
    ))
    # Mark 2 AM local time in New York, California and Central time, each
    # positioned on the Eastern-time axis. (The original used datetime.time(2, 0),
    # which calls the datetime.time() *method* and raises TypeError.)
    markers = [
        (eastern, "red", "纽约 2AM"),
        (pacific, "blue", "加州 2AM"),
        (central, "green", "新奥尔良 2AM"),
    ]
    for tz, color, label in markers:
        fig.add_vline(
            x=_eastern_minute_of_day(tz, selected_date, 2),
            line_dash="dash",
            line_color=color,
            annotation_text=label,
        )
    fig.update_layout(
        title=f'{selected_date} 的推文频率(间隔 {interval} 分钟EST',
        xaxis_title='东部时间 (HH:MM)',
        yaxis_title='推文数量',
        xaxis=dict(range=[0, 1440], tickvals=xticks, ticktext=xtick_labels, tickangle=45),
        height=600
    )
    summary = f"单日推文总数: {tweet_count_total}"
    return fig, summary
# Callback 2: refresh the Multi-Day View chart, the truncation warning and the summary.
@app.callback(
    [Output('multi-tweet-graph', 'figure'),
     Output('multi-day-warning', 'children'),
     Output('multi-tweet-summary', 'children')],
    [Input('multi-date-picker', 'value'),
     Input('multi-interval-picker', 'value'),
     Input('tabs', 'value')]
)
def update_multi_graph(selected_dates, interval, tab):
    """Build the overlaid per-day tweet-frequency chart for the Multi-Day View tab.

    Args:
        selected_dates: List of ``'YYYY-MM-DD'`` strings from the checklist, in
            the checklist's value order; may be empty or ``None``.
        interval: Bucket width in minutes from the dropdown; ``None`` when the
            user clears the dropdown.
        tab: Value of the currently active tab.

    Returns:
        ``(figure, warning_text, summary_text)``. Only the first 10 selected
        dates are plotted, and the warning says so when more were chosen. Only
        the first four traces are visible initially; the rest are legend-only.
    """
    if tab != 'multi-day-view':
        return go.Figure(), "", ""
    if not interval:
        interval = 10  # cleared dropdown: fall back to the layout's default
    max_days = 10
    selected_dates = list(selected_dates or [])
    if len(selected_dates) > max_days:
        selected_dates = selected_dates[:max_days]
        warning = "最多只能选择10天已自动截取前10天。"
    else:
        warning = ""
    selected_dates = [datetime.strptime(day, '%Y-%m-%d').date() for day in selected_dates]
    multi_data = agg_df[agg_df['date'].isin(selected_dates)]
    tweet_count_total = int(multi_data['tweet_count'].sum())  # 0 for an empty selection
    xticks, xtick_labels = generate_xticks(interval if interval >= 30 else 60)
    fig = go.Figure()
    # With nothing checked there is nothing to aggregate (aggregate_data would
    # otherwise be handed an empty frame); just draw the empty axes.
    if selected_dates:
        if multi_data.empty:
            # None of the dates have tweets: zero-count rows yield flat series.
            multi_data = pd.DataFrame({
                'date': selected_dates,
                'minute_of_day': [0] * len(selected_dates),
                'tweet_count': [0] * len(selected_dates),
            })
        agg_data = aggregate_data(multi_data, interval)
        for i, date in enumerate(selected_dates):
            day_data = agg_data[agg_data['date'] == date]
            fig.add_trace(go.Scatter(
                x=day_data['interval_group'],
                y=day_data['tweet_count'],
                mode='lines',
                name=str(date),
                visible=True if i < 4 else 'legendonly'
            ))
    fig.update_layout(
        title=f'多日推文频率对比(间隔 {interval} 分钟EST',
        xaxis_title='东部时间 (HH:MM)',
        yaxis_title='推文数量',
        xaxis=dict(range=[0, 1440], tickvals=xticks, ticktext=xtick_labels, tickangle=45),
        height=600,
        showlegend=True
    )
    summary = f"所选日期推文总数: {tweet_count_total}"
    return fig, warning, summary
# Run the Dash development server when this module is executed directly.
if __name__ == '__main__':
    # `run_server` is deprecated and was removed in Dash 3; `run` is the
    # supported entry point and takes the same arguments.
    app.run(debug=True)