Converting Python TensorFlow code to R

I ported the regression example from "First Contact with TensorFlow" from Python to R, keeping the translation as close to 1:1 as I could.

Python code
import numpy as np

num_points = 1000
vectors_set = []
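# range works on both Python 2 and 3; xrange exists only on Python 2
for i in range(num_points):
  x1 = np.random.normal(0.0, 0.55)
  # true relationship: y = 0.1 * x + 0.3, plus a little Gaussian noise
  y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
  vectors_set.append([x1, y1])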

x_data = [v[0] for v in vectors_set]
y_data = [v[1] for v in vectors_set]

import matplotlib.pyplot as plt

# give the points a label so that plt.legend() has something to show
plt.plot(x_data, y_data, 'ro', label='Original data')
plt.legend()
plt.show()

import tensorflow as tf

W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b

loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)

for step in range(8):
  # one gradient descent update of W and b
  sess.run(train)
  print(step, sess.run(W), sess.run(b))
  print(step, sess.run(loss))

  # plot the data together with the line fitted so far
  plt.plot(x_data, y_data, 'ro', label='Original data')
  plt.plot(x_data, sess.run(W) * x_data + sess.run(b), label='Fitted line')
  plt.xlabel('x')
  plt.xlim(-2,2)
  plt.ylim(0.1,0.6)
  plt.ylabel('y')
  plt.legend()
  plt.show()
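
To make it easier to see what the training loop above is doing, here is a minimal pure-Python sketch of gradient descent on the same mean squared error. It is not how TensorFlow implements it internally. The helper names (make_data, mse, gradient_step, fit), the fixed seed, and the starting point w0 = b0 = 0.0 are assumptions made for illustration (the TensorFlow code starts W from a random value in [-1, 1]).

```python
import random


def make_data(num_points=1000, seed=0):
    """Generate noisy points around y = 0.1 * x + 0.3, as in the post."""
    rng = random.Random(seed)  # fixed seed is an assumption, for repeatability
    xs, ys = [], []
    for _ in range(num_points):
        x = rng.gauss(0.0, 0.55)
        y = x * 0.1 + 0.3 + rng.gauss(0.0, 0.03)
        xs.append(x)
        ys.append(y)
    return xs, ys


def mse(w, b, xs, ys):
    """Mean squared error of the line y = w * x + b."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def gradient_step(w, b, xs, ys, learning_rate=0.5):
    """One plain gradient descent update on the mean squared error."""
    n = len(xs)
    # Partial derivatives of mean((w*x + b - y)^2) with respect to w and b.
    grad_w = sum(2.0 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    grad_b = sum(2.0 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return w - learning_rate * grad_w, b - learning_rate * grad_b


def fit(xs, ys, steps=8, learning_rate=0.5, w0=0.0, b0=0.0):
    """Run a few updates and record (step, w, b, loss) after each one."""
    w, b = w0, b0
    history = []
    for step in range(steps):
        w, b = gradient_step(w, b, xs, ys, learning_rate)
        history.append((step, w, b, mse(w, b, xs, ys)))
    return history
```

Calling fit(*make_data()) returns one (step, w, b, loss) tuple per step, much like the values printed in the loop above. With this data, w and b should move toward 0.1 and 0.3, with b settling much faster than w.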

R code
library(tensorflow)

num_points <- 1000
vset <- data.frame()
for(i in 1:num_points) {
  x1 <- rnorm(1, 0.0, 0.55)
  y1 <- x1 * 0.1 + 0.3 + rnorm(1, 0.0, 0.03)
  vset <- rbind(vset, data.frame(x1, y1))
}

x_data = vset[,1]
y_data = vset[,2]
plot(x_data, y_data, col='red', xlim=c(-2,2), ylim=c(0.1, 0.6), xlab='x', ylab='y')

W = tf$Variable(tf$random_uniform(shape(1L), -1.0, 1.0))
b = tf$Variable(tf$zeros(shape(1L)))
y = W * x_data + b

loss = tf$reduce_mean(tf$square(y - y_data))
optimizer = tf$train$GradientDescentOptimizer(0.5)
train = optimizer$minimize(loss)

init = tf$initialize_all_variables()

sess = tf$Session()
sess$run(init)

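# draw the 8 training steps in a 2 x 4 grid, saving the old graphics settings
opar <- par(mfrow=c(2,4))
for(step in 1:8) {
  sess$run(train)
  plot(x_data, y_data, col='red', xlim=c(-2,2), ylim=c(0.1, 0.6), xlab='x', ylab='y')
  lines(x_data, sess$run(W) * x_data + sess$run(b), col='blue')
}
# restore the previous graphics settings
par(opar)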
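
Either version can be sanity-checked against the closed-form least-squares fit, since the data is generated from y = 0.1 * x + 0.3 plus noise. The sketch below is a standalone pure-Python check, not part of the original example. The function name least_squares and the seed are assumptions for illustration.

```python
import random


def least_squares(xs, ys):
    """Closed-form slope and intercept for simple linear regression."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # slope = covariance(x, y) / variance(x)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Same data recipe as in the post; the seed is an arbitrary choice.
rng = random.Random(42)
xs = [rng.gauss(0.0, 0.55) for _ in range(1000)]
ys = [x * 0.1 + 0.3 + rng.gauss(0.0, 0.03) for x in xs]

slope, intercept = least_squares(xs, ys)
print(slope, intercept)  # should land close to 0.1 and 0.3
```

If W and b after the eight training steps are still some distance from these values, that is expected: eight updates are usually too few for W to converge fully, and running more steps should close the gap.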


by 윤석용
