stock_system/fetch_history_v2.py
hubian 23860c3f8c feat: optimized V2 - per-stock file storage + batch merging + SQLite progress DB
Optimizations:
- Each stock is stored in its own small file, avoiding a read/write of the whole 203 MB file on every update
- Temp files are merged in batches of 50 stocks, cutting IO round-trips
- SQLite progress database for more reliable checkpoint/resume
- Request interval reduced from 5 seconds to 2 seconds

New files:
- fetch_history_v2.py: main script for the optimized V2 version
- run_v2.sh: launch script
- .gitignore: now ignores config.txt and data/*.db
2026-04-09 12:11:04 +08:00

"""
A股历史数据获取系统 V2 - 性能优化版
优化点:
1. 分文件存储 - 每只股票单独存小文件,避免每次读写整个大文件
2. 批量合并 - 每100只股票合并一次减少IO次数
3. SQLite进度记录 - 更可靠的断点续传
"""
import tushare as ts
import pandas as pd
import os
import time
import sqlite3
from datetime import datetime
from pathlib import Path

# Configuration
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR / 'data'
TEMP_DIR = DATA_DIR / 'temp'  # per-stock temp shard directory
LOGS_DIR = BASE_DIR / 'logs'
STOCK_LIST_FILE = BASE_DIR / 'A股股票列表.csv'
DB_FILE = DATA_DIR / 'progress.db'  # SQLite progress database

# Create directories
DATA_DIR.mkdir(exist_ok=True)
TEMP_DIR.mkdir(exist_ok=True)
LOGS_DIR.mkdir(exist_ok=True)

# Date range
START_DATE = '20100101'
END_DATE = datetime.now().strftime('%Y%m%d')

# Request interval (seconds) - Tushare rate-limits by credit tier
REQUEST_INTERVAL = 2  # shorter interval; the lightweight per-stock writes compensate

# Batch-merge threshold
MERGE_BATCH_SIZE = 50  # merge temp files into the master file every 50 stocks
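
# Tuning note: a larger MERGE_BATCH_SIZE means fewer full rewrites of the
# master file, at the cost of more unmerged temp files sitting in data/temp
# between merges.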

def setup_tushare(token=None):
    """Initialize the Tushare client."""
    if not token:
        token = os.environ.get('TUSHARE_TOKEN', '')
    if not token:
        config_file = BASE_DIR / 'config.txt'
        if config_file.exists():
            token = config_file.read_text().strip()
    if not token:
        raise ValueError("Missing Tushare token")
    ts.set_token(token)
    return ts.pro_api()

def init_progress_db():
    """Initialize the SQLite progress database."""
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    # Progress table: one row per stock
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS progress (
            ts_code TEXT PRIMARY KEY,
            status TEXT DEFAULT 'pending',
            record_count INTEGER DEFAULT 0,
            updated_at TEXT,
            error_msg TEXT
        )
    ''')
    # Merge-history table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS merge_log (
            batch_id INTEGER PRIMARY KEY AUTOINCREMENT,
            stock_count INTEGER,
            merged_at TEXT,
            file_size INTEGER
        )
    ''')
    conn.commit()
    conn.close()
    print(f"Progress database: {DB_FILE}")

def load_stock_list():
    """Load the stock list CSV."""
    df = pd.read_csv(STOCK_LIST_FILE)
    df.columns = df.columns.str.strip()  # strip stray whitespace from headers
    print(f"Loaded stock list: {len(df)} stocks")
    return df

def get_stock_codes_with_suffix(df):
    """Convert raw stock codes to Tushare's suffixed format."""
    codes = []
    for code in df['code']:
        code = str(code).zfill(6)
        first_digit = code[0]
        if first_digit == '6':
            ts_code = f"{code}.SH"
        elif first_digit in ('0', '3'):
            ts_code = f"{code}.SZ"
        elif first_digit in ('4', '8'):
            ts_code = f"{code}.BJ"
        else:
            ts_code = f"{code}.SZ"
        codes.append(ts_code)
    return codes

def init_stock_progress(codes):
    """Initialize progress rows for every stock."""
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    # Bulk insert (skip rows that already exist)
    for ts_code in codes:
        cursor.execute('''
            INSERT OR IGNORE INTO progress (ts_code, status, updated_at)
            VALUES (?, 'pending', ?)
        ''', (ts_code, datetime.now().isoformat()))
    conn.commit()
    # Status summary
    cursor.execute('SELECT status, COUNT(*) FROM progress GROUP BY status')
    stats = cursor.fetchall()
    conn.close()
    print("\nCurrent progress:")
    for status, count in stats:
        print(f"  {status}: {count}")
    return stats

def get_pending_stocks():
    """Return the list of stocks still to be processed."""
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    # 'error' rows are included so failed fetches are retried on the next run
    cursor.execute('''
        SELECT ts_code FROM progress
        WHERE status = 'pending' OR status = 'error'
        ORDER BY ts_code
    ''')
    pending = [row[0] for row in cursor.fetchall()]
    conn.close()
    return pending

def save_stock_temp(df, ts_code):
    """Save a single stock to a temp file (a fast, small write)."""
    temp_file = TEMP_DIR / f"{ts_code.replace('.', '_')}.parquet"
    df.to_parquet(temp_file, index=False, compression='snappy')
    # Mark the stock as completed
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    cursor.execute('''
        UPDATE progress
        SET status = 'completed', record_count = ?, updated_at = ?
        WHERE ts_code = ?
    ''', (len(df), datetime.now().isoformat(), ts_code))
    conn.commit()
    conn.close()
    return temp_file.stat().st_size
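
# Design note: a parquet file cannot be appended to in place with pandas, so
# every merge rewrites the whole master file. Batching keeps that full
# rewrite to roughly once per MERGE_BATCH_SIZE stocks instead of once per
# stock, which is the main IO saving of V2.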

def merge_batch_to_main():
    """Merge the accumulated temp files into the master file."""
    temp_files = list(TEMP_DIR.glob('*.parquet'))
    if not temp_files:
        return 0
    print(f"\nMerging {len(temp_files)} temp files...")
    # Read all temp files
    batch_data = []
    for tf in temp_files:
        try:
            df = pd.read_parquet(tf)
            batch_data.append(df)
        except Exception as e:
            print(f"  Warning: failed to read {tf.name}: {e}")
    if not batch_data:
        return 0
    # Combine the batch
    new_data = pd.concat(batch_data, ignore_index=True)
    # Load the master file and merge
    main_file = DATA_DIR / 'stock_daily_data.parquet'
    if main_file.exists():
        existing = pd.read_parquet(main_file)
        # Codes already present in the master file
        existing_codes = set(existing['ts_code'].unique())
        # Only merge stocks not yet in the master file; rows for codes that
        # are already there are dropped, so a re-fetched stock keeps its
        # previously merged history instead of being duplicated
        truly_new = new_data[~new_data['ts_code'].isin(existing_codes)]
        if len(truly_new) > 0:
            combined = pd.concat([existing, truly_new], ignore_index=True)
        else:
            combined = existing
        added = len(truly_new)
    else:
        combined = new_data
        added = len(new_data)
    # Sort and save
    combined = combined.sort_values(['ts_code', 'trade_date']).reset_index(drop=True)
    combined.to_parquet(main_file, index=False, compression='snappy')
    # Log the merge
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    cursor.execute('''
        INSERT INTO merge_log (stock_count, merged_at, file_size)
        VALUES (?, ?, ?)
    ''', (len(temp_files), datetime.now().isoformat(), main_file.stat().st_size))
    conn.commit()
    conn.close()
    # Remove the merged temp files
    for tf in temp_files:
        tf.unlink()
    print(f"  Merge done: {added} new rows")
    print(f"  Master file size: {main_file.stat().st_size / 1024 / 1024:.2f} MB")
    return len(temp_files)
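
# Fetch errors below are recorded in the progress table rather than raised,
# so one bad stock cannot abort the whole run; rows marked 'error' are
# retried on the next start (see get_pending_stocks()).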

def fetch_stock_data(pro, ts_code):
    """Fetch daily bars for a single stock."""
    try:
        df = pro.daily(ts_code=ts_code, start_date=START_DATE, end_date=END_DATE)
        if df is not None and len(df) > 0:
            # Save to a temp file
            file_size = save_stock_temp(df, ts_code)
            return True, len(df), file_size
        else:
            # No data; mark as done
            conn = sqlite3.connect(DB_FILE)
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE progress
                SET status = 'no_data', updated_at = ?
                WHERE ts_code = ?
            ''', (datetime.now().isoformat(), ts_code))
            conn.commit()
            conn.close()
            return True, 0, 0
    except Exception as e:
        # Record the error for retry on the next run
        conn = sqlite3.connect(DB_FILE)
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE progress
            SET status = 'error', error_msg = ?, updated_at = ?
            WHERE ts_code = ?
        ''', (str(e)[:200], datetime.now().isoformat(), ts_code))
        conn.commit()
        conn.close()
        return False, 0, 0

def fetch_all_stocks(pro, codes):
    """Fetch data for all pending stocks."""
    total = len(codes)
    batch_count = 0
    print("\nStarting fetch...")
    print(f"{total} stocks")
    print(f"Merging every {MERGE_BATCH_SIZE} stocks")
    print("-" * 50)
    for i, ts_code in enumerate(codes):
        success, records, size = fetch_stock_data(pro, ts_code)
        status = "✓" if success else "✗"
        print(f"[{i+1}/{total}] {ts_code} {status} {records} rows {size/1024:.1f} KB")
        # Check whether it is time for a batch merge
        batch_count += 1
        if batch_count >= MERGE_BATCH_SIZE:
            merge_batch_to_main()
            batch_count = 0
        # Rate-limit between requests
        if i < total - 1:
            time.sleep(REQUEST_INTERVAL)
    # Merge whatever is left over
    if batch_count > 0:
        merge_batch_to_main()
    print("\n" + "=" * 50)
    print("Fetch complete!")

def show_final_stats():
    """Print the final statistics."""
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    # Counts per status
    cursor.execute('SELECT status, COUNT(*) FROM progress GROUP BY status')
    stats = cursor.fetchall()
    # Total record count
    cursor.execute("SELECT SUM(record_count) FROM progress WHERE status = 'completed'")
    total_records = cursor.fetchone()[0] or 0
    # Merge history
    cursor.execute('SELECT COUNT(*), SUM(stock_count) FROM merge_log')
    merge_stats = cursor.fetchone()
    conn.close()
    print("\nFinal statistics:")
    print("-" * 30)
    for status, count in stats:
        print(f"  {status}: {count}")
    print(f"\n  Total records: {total_records}")
    print(f"  Merge batches: {merge_stats[0]}, stocks merged: {merge_stats[1] or 0}")
    # Master file size
    main_file = DATA_DIR / 'stock_daily_data.parquet'
    if main_file.exists():
        print(f"  Master file size: {main_file.stat().st_size / 1024 / 1024:.2f} MB")

def main():
    """Entry point."""
    print("=" * 60)
    print("A-share historical data fetcher V2 - performance-optimized")
    print("=" * 60)
    print(f"Date range: {START_DATE} ~ {END_DATE}")
    print(f"Data directory: {DATA_DIR}")
    print("=" * 60)
    # Initialization
    print("\nInitializing Tushare...")
    pro = setup_tushare()
    print("\nInitializing progress database...")
    init_progress_db()
    # Load the stock list
    print("\nLoading stock list...")
    stock_df = load_stock_list()
    codes = get_stock_codes_with_suffix(stock_df)
    # Initialize progress rows
    init_stock_progress(codes)
    # Collect pending stocks
    pending = get_pending_stocks()
    if not pending:
        print("\nAll stocks are done!")
        show_final_stats()
        return
    print(f"\nPending: {len(pending)} stocks")
    # Rough ETA: request intervals only, API latency not included
    print(f"Estimated time: {len(pending) * REQUEST_INTERVAL / 60:.1f} minutes")
    # Fetch
    fetch_all_stocks(pro, pending)
    # Final statistics
    show_final_stats()

if __name__ == '__main__':
    main()