Skip to content

Instantly share code, notes, and snippets.

@BundleOfKent
Created December 17, 2020 14:04
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BundleOfKent/4f87869ae3e94eb6331fb9e213e9f343 to your computer and use it in GitHub Desktop.
Save BundleOfKent/4f87869ae3e94eb6331fb9e213e9f343 to your computer and use it in GitHub Desktop.
from matplotlib import animation
N = 1000  # number of leading training samples used when evaluating each cost surface
def costs_2(x, y, w_a, w_b, seed_, *, hidden_sizes=None):
    """Return the cost of a randomly initialised network as a function of two of its weights.

    The network is x -> hidden_0 units -> hidden_1 units -> 10 outputs, with a
    sigmoid (``expit``) after every layer and no bias terms. All weight matrices
    are drawn from a standard normal distribution by a generator seeded with
    ``seed_``. The entries ``w1[5, 5]`` and ``w1[5, 6]`` of the
    hidden_0 -> hidden_1 (second hidden layer) weight matrix are then replaced by
    ``w_a`` and ``w_b``.

    The draws are identical to the ``np.random.seed(seed_)`` + ``np.random.randn``
    sequence. A private ``RandomState`` is used so that the global NumPy RNG is
    not reseeded on every call.

    Args:
        x: Input samples, one per row; the input width is taken from the last axis
            (previously hard-coded to 784).
        y: Targets with 10 columns, one row per sample. Presumably the one-hot
            labels (``y_train_oh``); confirm against callers.
        w_a: Value for ``w1[5, 5]``; a scalar or any size-1 array.
        w_b: Value for ``w1[5, 6]``; a scalar or any size-1 array.
        seed_: Seed for the weight initialisation.
        hidden_sizes: Optional ``(hidden_0, hidden_1)`` pair; defaults to the
            module-level ``hidden_0`` / ``hidden_1`` globals.

    Returns:
        The per-sample sum of squared errors, averaged over samples.

    Raises:
        ValueError: If ``w_a`` or ``w_b`` has more than one element.
    """
    if hidden_sizes is None:
        hidden_sizes = (hidden_0, hidden_1)  # module-level globals defined elsewhere in the file
    n_hidden_0, n_hidden_1 = hidden_sizes
    rng = np.random.RandomState(seed_)
    # Same draw order and shapes as the original, so a given seed yields the same weights.
    w0 = rng.randn(n_hidden_0, np.shape(x)[-1])
    w1 = rng.randn(n_hidden_1, n_hidden_0)
    w2 = rng.randn(10, n_hidden_1)
    # .item() accepts plain scalars and size-1 arrays alike (callers pass np.array([[v]])).
    # It also avoids the implicit ndim>0 array -> scalar conversion that NumPy 1.25 deprecated.
    w1[5, 5] = np.asarray(w_a).item()  # w5–5(1)
    w1[5, 6] = np.asarray(w_b).item()  # w5–6(1)
    a0 = expit(w0 @ x.T)   # first hidden layer activations
    a1 = expit(w1 @ a0)    # second hidden layer activations
    pred = expit(w2 @ a1)  # output activations, one column per sample for 2-D x
    return np.mean(np.sum((y.T - pred) ** 2, axis=0))
# Evaluate each seed's cost landscape over the (M1, M2) grid using the first N samples.
# NOTE(review): M1 and M2 are assumed to be same-shaped coordinate grids
# (e.g. from np.meshgrid) defined earlier in the file — confirm.
_x_sub, _y_sub = X_train[0:N], y_train_oh[0:N]  # slice once instead of once per seed


def _cost_grid(seed):
    """Return costs_2 evaluated at every (M1, M2) grid point for ``seed``, flattened in ravel order."""
    return np.array([costs_2(_x_sub, _y_sub, mp1, mp2, seed)
                     for mp1, mp2 in zip(np.ravel(M1), np.ravel(M2))])


zs_158 = _cost_grid(158)
Z_158 = zs_158.reshape(M1.shape)
zs_20 = _cost_grid(20)
Z_20 = zs_20.reshape(M1.shape)
zs_41 = _cost_grid(41)
Z_41 = zs_41.reshape(M1.shape)
# Seed 106 matches the variable name and the 'seed:106' subplot title (was 140).
zs_106 = _cost_grid(106)
Z_106 = zs_106.reshape(M1.shape)
fontsize_ = 19       # font size for the axis labels
titlefontsize_ = 16  # font size for each subplot title

# Build a 2x2 grid of 3-D subplots, filled row by row (positions 1..4).
fig = plt.figure(figsize=(8.2, 8.2))
ax0, ax1, ax2, ax3 = (fig.add_subplot(2, 2, pos, projection='3d') for pos in range(1, 5))


def _style_landscape_axis(axis, title):
    """Give a 3-D subplot its seed title, the shared axis labels, and no tick labels."""
    axis.set_title(title, fontsize=titlefontsize_)
    axis.set_xlabel(r'$w_a$', fontsize=fontsize_, labelpad=-4)
    axis.set_ylabel(r'$w_b$', fontsize=fontsize_, labelpad=-9)
    axis.set_zlabel("costs", fontsize=fontsize_, labelpad=-7)
    # Tick values are hidden; only the shape of the surface matters here.
    axis.set_xticklabels([])
    axis.set_yticklabels([])
    axis.set_zticklabels([])


_style_landscape_axis(ax0, 'seed:158')
_style_landscape_axis(ax1, 'seed:20')
_style_landscape_axis(ax2, 'seed:41')
_style_landscape_axis(ax3, 'seed:106')
# Rotate plots around the z-axis:
def rotate(angle):
    """Animation callback: view all four subplots from elevation 50 at azimuth ``angle``."""
    for axis in (ax0, ax1, ax2, ax3):
        axis.view_init(elev=50, azim=angle)
# Draw each seed's loss landscape on its own subplot.
def _draw_landscape(axis, heights):
    """Plot ``heights`` over the (M1, M2) grid on ``axis`` with the shared terrain styling."""
    axis.plot_surface(M1, M2, heights, cmap='terrain', antialiased=True,
                      cstride=1, rstride=1, alpha=0.99)


_draw_landscape(ax0, Z_158)
_draw_landscape(ax1, Z_20)
_draw_landscape(ax2, Z_41)
_draw_landscape(ax3, Z_106)
plt.tight_layout()
# One full turn in 2-degree steps. The stop is exclusive so 360 (the same view as 0) is not
# rendered twice; the old range(0, 362, 2) made the looping GIF hitch at the seam.
rot_animation = animation.FuncAnimation(fig, rotate, frames=np.arange(0, 360, 2), interval=100)
# NOTE: matplotlib's 'imagemagick' writer needs an ImageMagick installation on the system.
rot_animation.save('RotLoss_1000.gif', dpi=80, writer='imagemagick')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment