Loading Public Datasets, In-Memory Data, and CSV Files with sklearn
This article shows how to load data from various sources into datasets that sklearn can use. It covers the following kinds of data sources:

- predefined public datasets
- in-memory data
- CSV files
- data files in arbitrary formats
- sparse-format data files

The datasets sklearn consumes are generally numpy ndarrays or pandas DataFrames.
```python
import numpy as np
import pandas as pd
import sklearn
import os
import urllib
import tarfile
```

1. Predefined public data sources
More datasets are listed at: https://scikitlearn.com.cn/0.21.3/47/
The MNIST dataset
The following example trains a binary classifier that predicts whether an image is the digit 5.
```python
from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', version=1)
X, y = mnist['data'].to_numpy(), mnist['target'].to_numpy()

X_train, X_test = X[:6000], X[6000:]
y_train, y_test = y[:6000].astype(np.uint8), y[6000:].astype(np.uint8)
y_train_5 = (y_train == 5)  # binary labels: True for images of a 5
y_test_5 = (y_test == 5)

from sklearn.linear_model import SGDClassifier
model = SGDClassifier(loss='hinge')
model.fit(X_train, y_train_5)
print(model.predict([X[0]]))
```

Output:

```
[ True]
```

The iris dataset
This is a very famous dataset: 150 iris flowers from three species (Iris setosa, Iris versicolor, and Iris virginica), with the length and width of each flower's sepals and petals.
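As an aside, recent scikit-learn versions (0.23 and later) can return this dataset directly as a pandas DataFrame via the `as_frame` parameter — a convenient sketch when you prefer DataFrames over raw ndarrays:

```python
# Load iris as a pandas DataFrame instead of raw ndarrays.
# Requires scikit-learn >= 0.23 for the as_frame parameter.
from sklearn.datasets import load_iris

iris = load_iris(as_frame=True)
df = iris.frame                # features and target combined in one DataFrame
print(df.shape)                # (150, 5): 4 feature columns + 1 target column
print(df.columns.tolist())
```

The combined frame keeps the feature names as column labels, so downstream pandas operations (filtering, plotting, describe) work directly.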
```python
from sklearn import datasets

iris = datasets.load_iris()
```

Let's take a look at the dataset. Note that sklearn dataset objects all contain the following keys:
```python
print(iris.keys())
print(iris['data'][:10], iris['target'][:], iris['frame'],
      iris['target_names'], iris['DESCR'], iris['feature_names'])
```

Output:

```
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename'])
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5.  3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
None
['setosa' 'versicolor' 'virginica']
.. _iris_dataset:

Iris plants dataset
-------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica

    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is
taken from Fisher's paper. Note that it's the same as in R, but not as in
the UCI Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the pattern
recognition literature.  Fisher's paper is a classic in the field and is
referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers
to a type of iris plant.  One class is linearly separable from the other
2; the latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
```

2. In-memory data
In this example, we generate a numpy ndarray in memory, then fit the data with linear regression.
```python
X = 2 * np.random.rand(100, 1)
y = 3 * X + 4 + np.random.rand(100, 1)

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(X, y)
print(model.intercept_, model.coef_)
```

Output:

```
[4.45291269] [[2.99295562]]
```

We can also use a pandas DataFrame as the model's input.
```python
X = pd.DataFrame(2 * np.random.rand(100, 1))
y = pd.DataFrame(3 * X + 4 + np.random.rand(100, 1))

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(X, y)
print(model.intercept_, model.coef_)
```

Output:

```
[4.45003988] [[3.02825472]]
```

When we work with data from CSV files below, we will likewise convert it to a pandas.DataFrame in most cases.
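A side note on the fitted intercepts above: both runs report an intercept near 4.5 rather than the 4 in the generating formula `y = 3X + 4`. That is because `np.random.rand` draws noise from [0, 1), whose mean is 0.5, and that constant offset is absorbed into the intercept. A quick sketch to verify (seeded, with more samples to reduce variance):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = 2 * rng.random((10000, 1))           # same generating process, more samples
y = 3 * X + 4 + rng.random((10000, 1))   # uniform noise in [0, 1), mean 0.5

model = LinearRegression().fit(X, y)
print(model.intercept_, model.coef_)     # intercept close to 4.5, slope close to 3
```

Subtracting 0.5 from the noise (or using zero-mean noise such as `rng.normal`) would make the intercept come out near 4 instead.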
3. Data from CSV files
We use the housing dataset as an example, fitting a linear regression to predict a district's median house value.
Since we don't have the data file locally, we first download it:
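The download code itself is not shown here, so the sketch below is one plausible way to fetch the file. The URL (the housing.tgz from Aurelien Geron's handson-ml2 repository), `HOUSING_PATH`, and `fetch_housing_data` are my assumptions, not given in the original text:

```python
# Hypothetical download helper -- the URL and paths are assumptions.
import os
import tarfile
import urllib.request

DOWNLOAD_URL = ("https://raw.githubusercontent.com/ageron/handson-ml2/"
                "master/datasets/housing/housing.tgz")
HOUSING_PATH = os.path.join("datasets", "housing")

def fetch_housing_data(url=DOWNLOAD_URL, path=HOUSING_PATH):
    """Download housing.tgz and extract housing.csv into `path`."""
    os.makedirs(path, exist_ok=True)
    tgz_path = os.path.join(path, "housing.tgz")
    urllib.request.urlretrieve(url, tgz_path)   # download the archive
    with tarfile.open(tgz_path) as tgz:
        tgz.extractall(path=path)               # leaves housing.csv in `path`

# Call fetch_housing_data() once before the read_csv step.
```

After a successful run, `os.path.join(HOUSING_PATH, 'housing.csv')` exists and can be read with pandas.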
Once the CSV file is ready, we load its contents with pandas.read_csv():
```python
housing = pd.read_csv(os.path.join(HOUSING_PATH, 'housing.csv'))
# take a quick look at a few rows
housing.head()
```

| longitude | latitude | housing_median_age | total_rooms | total_bedrooms | population | households | median_income | median_house_value | ocean_proximity |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| -122.23 | 37.88 | 41.0 | 880.0 | 129.0 | 322.0 | 126.0 | 8.3252 | 452600.0 | NEAR BAY |
| -122.22 | 37.86 | 21.0 | 7099.0 | 1106.0 | 2401.0 | 1138.0 | 8.3014 | 358500.0 | NEAR BAY |
| -122.24 | 37.85 | 52.0 | 1467.0 | 190.0 | 496.0 | 177.0 | 7.2574 | 352100.0 | NEAR BAY |
| -122.25 | 37.85 | 52.0 | 1274.0 | 235.0 | 558.0 | 219.0 | 5.6431 | 341300.0 | NEAR BAY |
| -122.25 | 37.85 | 52.0 | 1627.0 | 280.0 | 565.0 | 259.0 | 3.8462 | 342200.0 | NEAR BAY |
Since housing contains missing values, we need to fill them before fitting. Let's check where values are missing:
```python
housing.info()
```

```
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 10 columns):
 #   Column              Non-Null Count  Dtype
---  ------              --------------  -----
 0   longitude           20640 non-null  float64
 1   latitude            20640 non-null  float64
 2   housing_median_age  20640 non-null  float64
 3   total_rooms         20640 non-null  float64
 4   total_bedrooms      20433 non-null  float64
 5   population          20640 non-null  float64
 6   households          20640 non-null  float64
 7   median_income       20640 non-null  float64
 8   median_house_value  20640 non-null  float64
 9   ocean_proximity     20640 non-null  object
dtypes: float64(9), object(1)
memory usage: 1.6+ MB
```

We can see that total_bedrooms has missing values (only 20433 of 20640 entries are non-null); we fill them with the column's median. If many fields have missing values, sklearn's SimpleImputer can process them in batch; see the sklearn series article on data preprocessing.
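The batch alternative mentioned above can be sketched like this; the tiny DataFrame here is made up so the snippet runs without the housing file:

```python
# Fill missing numeric values with SimpleImputer (strategy='median'),
# which imputes every numeric column at once rather than one fillna
# call per column. The toy DataFrame is illustrative only.
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

df = pd.DataFrame({
    'total_rooms':    [880.0, 7099.0, 1467.0],
    'total_bedrooms': [129.0, np.nan, 190.0],
})

imputer = SimpleImputer(strategy='median')
filled = pd.DataFrame(imputer.fit_transform(df), columns=df.columns)
print(filled['total_bedrooms'].tolist())  # NaN replaced by the column median 159.5
```

On the real housing data you would fit the imputer on the numeric columns only, since SimpleImputer with a numeric strategy cannot handle the ocean_proximity text column.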
```python
median = housing['total_bedrooms'].median()
housing['total_bedrooms'].fillna(median, inplace=True)
```

Next, we separate the label from the features. For now, we simply drop the non-numeric feature ocean_proximity:
```python
housing_label = housing['median_house_value']
housing_feature = housing.drop(['median_house_value', 'ocean_proximity'], axis=1)

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(housing_feature, housing_label)
print(model.intercept_, model.coef_)
```

Output:

```
-3570118.06149459 [-4.26104026e+04 -4.24754782e+04  1.14445085e+03 -6.62091740e+00
  8.11609666e+01 -3.98732002e+01  7.93047225e+01  3.97522237e+04]
```

Full code
```python
housing = pd.read_csv(os.path.join(HOUSING_PATH, 'housing.csv'))

median = housing['total_bedrooms'].median()
housing['total_bedrooms'].fillna(median, inplace=True)

housing_label = housing['median_house_value']
housing_feature = housing.drop(['median_house_value', 'ocean_proximity'], axis=1)

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(housing_feature, housing_label)
print(model.intercept_, model.coef_)
```

Output:

```
-3570118.06149459 [-4.26104026e+04 -4.24754782e+04  1.14445085e+03 -6.62091740e+00
  8.11609666e+01 -3.98732002e+01  7.93047225e+01  3.97522237e+04]
```

Summary

We loaded predefined public datasets, in-memory ndarrays and DataFrames, and CSV files into forms that sklearn models can fit directly.