Completed part 2: implementing gradient descent

2019-07-20 23:21:32 +01:00
parent 1c8aec47c2
commit 08463c5d0b
15 changed files with 1685 additions and 411 deletions

View File

@@ -0,0 +1,48 @@
import numpy as np

def sigmoid(x):
    """
    Calculate sigmoid
    """
    return 1 / (1 + np.exp(-x))

x = np.array([0.5, 0.1, -0.2])
target = 0.6
learnrate = 0.5

weights_input_hidden = np.array([[0.5, -0.6],
                                 [0.1, -0.2],
                                 [0.1, 0.7]])
weights_hidden_output = np.array([0.1, -0.3])

# Forward pass
hidden_layer_input = np.dot(x, weights_input_hidden)
hidden_layer_output = sigmoid(hidden_layer_input)

output_layer_in = np.dot(hidden_layer_output, weights_hidden_output)
output = sigmoid(output_layer_in)

# Backwards pass
# Calculate output error
error = target - output

# Calculate error term for output layer
output_error_term = error * output * (1 - output)

# Calculate error term for hidden layer
hidden_error_term = (np.dot(output_error_term, weights_hidden_output) *
                     hidden_layer_output * (1 - hidden_layer_output))

# Calculate change in weights for hidden layer to output layer
delta_w_h_o = learnrate * output_error_term * hidden_layer_output

# Calculate change in weights for input layer to hidden layer
delta_w_i_h = learnrate * hidden_error_term * x[:, None]

print('Change in weights for hidden layer to output layer:')
print(delta_w_h_o)
print('Change in weights for input layer to hidden layer:')
print(delta_w_i_h)
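For reference, the quantities computed above are the standard backpropagation rules for a single sigmoid hidden layer; writing η for the learning rate, x_i for the inputs, a_j for the hidden activations, and ŷ for the network output:

\begin{aligned}
\delta^{o} &= (y - \hat{y})\,\hat{y}\,(1 - \hat{y}) \\
\delta^{h}_{j} &= \delta^{o}\, w^{ho}_{j}\, a_{j}\,(1 - a_{j}) \\
\Delta w^{ho}_{j} &= \eta\, \delta^{o}\, a_{j} \\
\Delta w^{ih}_{ij} &= \eta\, \delta^{h}_{j}\, x_{i}
\end{aligned}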

View File

@@ -0,0 +1,78 @@
import numpy as np
from data_prep import features, targets, features_test, targets_test

np.random.seed(21)

def sigmoid(x):
    """
    Calculate sigmoid
    """
    return 1 / (1 + np.exp(-x))

# Hyperparameters
n_hidden = 2  # number of hidden units
epochs = 900
learnrate = 0.005

n_records, n_features = features.shape
last_loss = None

# Initialize weights
weights_input_hidden = np.random.normal(scale=1 / n_features ** .5,
                                        size=(n_features, n_hidden))
weights_hidden_output = np.random.normal(scale=1 / n_features ** .5,
                                         size=n_hidden)

for e in range(epochs):
    del_w_input_hidden = np.zeros(weights_input_hidden.shape)
    del_w_hidden_output = np.zeros(weights_hidden_output.shape)
    for x, y in zip(features.values, targets):
        ## Forward pass ##
        # Calculate the output
        hidden_input = np.dot(x, weights_input_hidden)
        hidden_output = sigmoid(hidden_input)
        output = sigmoid(np.dot(hidden_output, weights_hidden_output))

        ## Backward pass ##
        # Calculate the network's prediction error
        error = y - output

        # Calculate error term for the output unit
        output_error_term = error * output * (1 - output)

        # Propagate errors to hidden layer:
        # the hidden layer's contribution to the error
        hidden_error = np.dot(output_error_term, weights_hidden_output)

        # Calculate the error term for the hidden layer
        hidden_error_term = hidden_error * hidden_output * (1 - hidden_output)

        # Accumulate the change in weights
        del_w_hidden_output += output_error_term * hidden_output
        del_w_input_hidden += hidden_error_term * x[:, None]

    # Update weights (divide by n_records to average over the samples)
    weights_input_hidden += learnrate * del_w_input_hidden / n_records
    weights_hidden_output += learnrate * del_w_hidden_output / n_records

    # Printing out the mean square error on the training set
    if e % (epochs / 10) == 0:
        hidden_output = sigmoid(np.dot(features, weights_input_hidden))
        out = sigmoid(np.dot(hidden_output, weights_hidden_output))
        loss = np.mean((out - targets) ** 2)

        if last_loss and last_loss < loss:
            print("Train loss: ", loss, " WARNING - Loss Increasing")
        else:
            print("Train loss: ", loss)
        last_loss = loss

# Calculate accuracy on test data
hidden = sigmoid(np.dot(features_test, weights_input_hidden))
out = sigmoid(np.dot(hidden, weights_hidden_output))
predictions = out > 0.5
accuracy = np.mean(predictions == targets_test)
print("Prediction accuracy: {:.3f}".format(accuracy))

View File

@@ -0,0 +1,401 @@
admit,gre,gpa,rank
0,380,3.61,3
1,660,3.67,3
1,800,4,1
1,640,3.19,4
0,520,2.93,4
1,760,3,2
1,560,2.98,1
0,400,3.08,2
1,540,3.39,3
0,700,3.92,2
0,800,4,4
0,440,3.22,1
1,760,4,1
0,700,3.08,2
1,700,4,1
0,480,3.44,3
0,780,3.87,4
0,360,2.56,3
0,800,3.75,2
1,540,3.81,1
0,500,3.17,3
1,660,3.63,2
0,600,2.82,4
0,680,3.19,4
1,760,3.35,2
1,800,3.66,1
1,620,3.61,1
1,520,3.74,4
1,780,3.22,2
0,520,3.29,1
0,540,3.78,4
0,760,3.35,3
0,600,3.4,3
1,800,4,3
0,360,3.14,1
0,400,3.05,2
0,580,3.25,1
0,520,2.9,3
1,500,3.13,2
1,520,2.68,3
0,560,2.42,2
1,580,3.32,2
1,600,3.15,2
0,500,3.31,3
0,700,2.94,2
1,460,3.45,3
1,580,3.46,2
0,500,2.97,4
0,440,2.48,4
0,400,3.35,3
0,640,3.86,3
0,440,3.13,4
0,740,3.37,4
1,680,3.27,2
0,660,3.34,3
1,740,4,3
0,560,3.19,3
0,380,2.94,3
0,400,3.65,2
0,600,2.82,4
1,620,3.18,2
0,560,3.32,4
0,640,3.67,3
1,680,3.85,3
0,580,4,3
0,600,3.59,2
0,740,3.62,4
0,620,3.3,1
0,580,3.69,1
0,800,3.73,1
0,640,4,3
0,300,2.92,4
0,480,3.39,4
0,580,4,2
0,720,3.45,4
0,720,4,3
0,560,3.36,3
1,800,4,3
0,540,3.12,1
1,620,4,1
0,700,2.9,4
0,620,3.07,2
0,500,2.71,2
0,380,2.91,4
1,500,3.6,3
0,520,2.98,2
0,600,3.32,2
0,600,3.48,2
0,700,3.28,1
1,660,4,2
0,700,3.83,2
1,720,3.64,1
0,800,3.9,2
0,580,2.93,2
1,660,3.44,2
0,660,3.33,2
0,640,3.52,4
0,480,3.57,2
0,700,2.88,2
0,400,3.31,3
0,340,3.15,3
0,580,3.57,3
0,380,3.33,4
0,540,3.94,3
1,660,3.95,2
1,740,2.97,2
1,700,3.56,1
0,480,3.13,2
0,400,2.93,3
0,480,3.45,2
0,680,3.08,4
0,420,3.41,4
0,360,3,3
0,600,3.22,1
0,720,3.84,3
0,620,3.99,3
1,440,3.45,2
0,700,3.72,2
1,800,3.7,1
0,340,2.92,3
1,520,3.74,2
1,480,2.67,2
0,520,2.85,3
0,500,2.98,3
0,720,3.88,3
0,540,3.38,4
1,600,3.54,1
0,740,3.74,4
0,540,3.19,2
0,460,3.15,4
1,620,3.17,2
0,640,2.79,2
0,580,3.4,2
0,500,3.08,3
0,560,2.95,2
0,500,3.57,3
0,560,3.33,4
0,700,4,3
0,620,3.4,2
1,600,3.58,1
0,640,3.93,2
1,700,3.52,4
0,620,3.94,4
0,580,3.4,3
0,580,3.4,4
0,380,3.43,3
0,480,3.4,2
0,560,2.71,3
1,480,2.91,1
0,740,3.31,1
1,800,3.74,1
0,400,3.38,2
1,640,3.94,2
0,580,3.46,3
0,620,3.69,3
1,580,2.86,4
0,560,2.52,2
1,480,3.58,1
0,660,3.49,2
0,700,3.82,3
0,600,3.13,2
0,640,3.5,2
1,700,3.56,2
0,520,2.73,2
0,580,3.3,2
0,700,4,1
0,440,3.24,4
0,720,3.77,3
0,500,4,3
0,600,3.62,3
0,400,3.51,3
0,540,2.81,3
0,680,3.48,3
1,800,3.43,2
0,500,3.53,4
1,620,3.37,2
0,520,2.62,2
1,620,3.23,3
0,620,3.33,3
0,300,3.01,3
0,620,3.78,3
0,500,3.88,4
0,700,4,2
1,540,3.84,2
0,500,2.79,4
0,800,3.6,2
0,560,3.61,3
0,580,2.88,2
0,560,3.07,2
0,500,3.35,2
1,640,2.94,2
0,800,3.54,3
0,640,3.76,3
0,380,3.59,4
1,600,3.47,2
0,560,3.59,2
0,660,3.07,3
1,400,3.23,4
0,600,3.63,3
0,580,3.77,4
0,800,3.31,3
1,580,3.2,2
1,700,4,1
0,420,3.92,4
1,600,3.89,1
1,780,3.8,3
0,740,3.54,1
1,640,3.63,1
0,540,3.16,3
0,580,3.5,2
0,740,3.34,4
0,580,3.02,2
0,460,2.87,2
0,640,3.38,3
1,600,3.56,2
1,660,2.91,3
0,340,2.9,1
1,460,3.64,1
0,460,2.98,1
1,560,3.59,2
0,540,3.28,3
0,680,3.99,3
1,480,3.02,1
0,800,3.47,3
0,800,2.9,2
1,720,3.5,3
0,620,3.58,2
0,540,3.02,4
0,480,3.43,2
1,720,3.42,2
0,580,3.29,4
0,600,3.28,3
0,380,3.38,2
0,420,2.67,3
1,800,3.53,1
0,620,3.05,2
1,660,3.49,2
0,480,4,2
0,500,2.86,4
0,700,3.45,3
0,440,2.76,2
1,520,3.81,1
1,680,2.96,3
0,620,3.22,2
0,540,3.04,1
0,800,3.91,3
0,680,3.34,2
0,440,3.17,2
0,680,3.64,3
0,640,3.73,3
0,660,3.31,4
0,620,3.21,4
1,520,4,2
1,540,3.55,4
1,740,3.52,4
0,640,3.35,3
1,520,3.3,2
1,620,3.95,3
0,520,3.51,2
0,640,3.81,2
0,680,3.11,2
0,440,3.15,2
1,520,3.19,3
1,620,3.95,3
1,520,3.9,3
0,380,3.34,3
0,560,3.24,4
1,600,3.64,3
1,680,3.46,2
0,500,2.81,3
1,640,3.95,2
0,540,3.33,3
1,680,3.67,2
0,660,3.32,1
0,520,3.12,2
1,600,2.98,2
0,460,3.77,3
1,580,3.58,1
1,680,3,4
1,660,3.14,2
0,660,3.94,2
0,360,3.27,3
0,660,3.45,4
0,520,3.1,4
1,440,3.39,2
0,600,3.31,4
1,800,3.22,1
1,660,3.7,4
0,800,3.15,4
0,420,2.26,4
1,620,3.45,2
0,800,2.78,2
0,680,3.7,2
0,800,3.97,1
0,480,2.55,1
0,520,3.25,3
0,560,3.16,1
0,460,3.07,2
0,540,3.5,2
0,720,3.4,3
0,640,3.3,2
1,660,3.6,3
1,400,3.15,2
1,680,3.98,2
0,220,2.83,3
0,580,3.46,4
1,540,3.17,1
0,580,3.51,2
0,540,3.13,2
0,440,2.98,3
0,560,4,3
0,660,3.67,2
0,660,3.77,3
1,520,3.65,4
0,540,3.46,4
1,300,2.84,2
1,340,3,2
1,780,3.63,4
1,480,3.71,4
0,540,3.28,1
0,460,3.14,3
0,460,3.58,2
0,500,3.01,4
0,420,2.69,2
0,520,2.7,3
0,680,3.9,1
0,680,3.31,2
1,560,3.48,2
0,580,3.34,2
0,500,2.93,4
0,740,4,3
0,660,3.59,3
0,420,2.96,1
0,560,3.43,3
1,460,3.64,3
1,620,3.71,1
0,520,3.15,3
0,620,3.09,4
0,540,3.2,1
1,660,3.47,3
0,500,3.23,4
1,560,2.65,3
0,500,3.95,4
0,580,3.06,2
0,520,3.35,3
0,500,3.03,3
0,600,3.35,2
0,580,3.8,2
0,400,3.36,2
0,620,2.85,2
1,780,4,2
0,620,3.43,3
1,580,3.12,3
0,700,3.52,2
1,540,3.78,2
1,760,2.81,1
0,700,3.27,2
0,720,3.31,1
1,560,3.69,3
0,720,3.94,3
1,520,4,1
1,540,3.49,1
0,680,3.14,2
0,460,3.44,2
1,560,3.36,1
0,480,2.78,3
0,460,2.93,3
0,620,3.63,3
0,580,4,1
0,800,3.89,2
1,540,3.77,2
1,680,3.76,3
1,680,2.42,1
1,620,3.37,1
0,560,3.78,2
0,560,3.49,4
0,620,3.63,2
1,800,4,2
0,640,3.12,3
0,540,2.7,2
0,700,3.65,2
1,540,3.49,2
0,540,3.51,2
0,660,4,1
1,480,2.62,2
0,420,3.02,1
1,740,3.86,2
0,580,3.36,2
0,640,3.17,2
0,640,3.51,2
1,800,3.05,2
1,660,3.88,2
1,600,3.38,3
1,620,3.75,2
1,460,3.99,3
0,620,4,2
0,560,3.04,3
0,460,2.63,2
0,700,3.65,2
0,600,3.89,3

View File

@@ -0,0 +1,22 @@
import numpy as np
import pandas as pd

admissions = pd.read_csv('binary.csv')

# Make dummy variables for rank
data = pd.concat([admissions, pd.get_dummies(admissions['rank'], prefix='rank')], axis=1)
data = data.drop('rank', axis=1)

# Standardize features
for field in ['gre', 'gpa']:
    mean, std = data[field].mean(), data[field].std()
    data.loc[:, field] = (data[field] - mean) / std

# Split off random 10% of the data for testing
np.random.seed(21)
sample = np.random.choice(data.index, size=int(len(data) * 0.9), replace=False)
data, test_data = data.loc[sample], data.drop(sample)

# Split into features and targets
features, targets = data.drop('admit', axis=1), data['admit']
features_test, targets_test = test_data.drop('admit', axis=1), test_data['admit']
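A quick way to sanity-check what this module exposes is to print the array shapes. This snippet is illustrative only (not part of the commit); the expected numbers follow from the 400-row CSV, the 90/10 split, and the four rank dummy columns:

from data_prep import features, targets, features_test, targets_test

# gre, gpa, rank_1..rank_4 -> 6 feature columns; 360 training rows, 40 test rows
print(features.shape, targets.shape)            # expected: (360, 6) (360,)
print(features_test.shape, targets_test.shape)  # expected: (40, 6) (40,)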